修改文件夹

This commit is contained in:
yoiannis 2025-03-10 00:39:16 +08:00
parent 049155a8b9
commit 59ae44bc92
2 changed files with 14 additions and 2 deletions

View File

@ -8,9 +8,12 @@ from torchvision.transforms import ToTensor
from model.repvit import *
from data_loader import *
from utils import *
def main():
# 初始化组件
initialize()
model = repvit_m1_1(num_classes=10).to(config.device)
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.CrossEntropyLoss()

View File

@ -3,6 +3,15 @@ import torch
from config import config
import logger
weightdir = path.join(config.output_path, "weights")
def initialize():
if not path.exists(config.output_path):
path.mkdir(config.output_path)
if not path.exists(config.log_dir):
path.mkdir(config.log_dir)
if not path.exists(weightdir):
path.mkdir(weightdir)
def save_checkpoint(model, optimizer, epoch, is_best=False):
checkpoint = {
"epoch": epoch,
@ -11,8 +20,8 @@ def save_checkpoint(model, optimizer, epoch, is_best=False):
}
if is_best:
torch.save(checkpoint, config.save_path)
torch.save(checkpoint, path.join(config.output_path,"checkpoints/last_checkpoint.pth"))
torch.save(checkpoint, path.join(weightdir,"best_model.pth"))
torch.save(checkpoint, path.join(weightdir, "last_model.pth"))
logger.log("info", f"Checkpoint saved at epoch {epoch}")
def load_checkpoint(model, optimizer=None):