From 59ae44bc92219b1d56eecca9e16c8ae8f434c9ec Mon Sep 17 00:00:00 2001 From: yoiannis Date: Mon, 10 Mar 2025 00:39:16 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=96=87=E4=BB=B6=E5=A4=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 3 +++ utils.py | 13 +++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index f513138..fb9b14e 100644 --- a/main.py +++ b/main.py @@ -8,9 +8,12 @@ from torchvision.transforms import ToTensor from model.repvit import * from data_loader import * +from utils import * def main(): # 初始化组件 + initialize() + model = repvit_m1_1(num_classes=10).to(config.device) optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) criterion = nn.CrossEntropyLoss() diff --git a/utils.py b/utils.py index e0167cb..c1bc38c 100644 --- a/utils.py +++ b/utils.py @@ -3,6 +3,15 @@ import torch from config import config import logger +weightdir = path.join(config.output_path, "weights") +def initialize(): + if not path.exists(config.output_path): + path.mkdir(config.output_path) + if not path.exists(config.log_dir): + path.mkdir(config.log_dir) + if not path.exists(weightdir): + path.mkdir(weightdir) + def save_checkpoint(model, optimizer, epoch, is_best=False): checkpoint = { "epoch": epoch, @@ -11,8 +20,8 @@ def save_checkpoint(model, optimizer, epoch, is_best=False): } if is_best: - torch.save(checkpoint, config.save_path) - torch.save(checkpoint, path.join(config.output_path,"checkpoints/last_checkpoint.pth")) + torch.save(checkpoint, path.join(weightdir,"best_model.pth")) + torch.save(checkpoint, path.join(weightdir, "last_model.pth")) logger.log("info", f"Checkpoint saved at epoch {epoch}") def load_checkpoint(model, optimizer=None):