修改文件夹
This commit is contained in:
parent
049155a8b9
commit
59ae44bc92
3
main.py
3
main.py
@ -8,9 +8,12 @@ from torchvision.transforms import ToTensor
|
||||
|
||||
from model.repvit import *
|
||||
from data_loader import *
|
||||
from utils import *
|
||||
|
||||
def main():
|
||||
# 初始化组件
|
||||
initialize()
|
||||
|
||||
model = repvit_m1_1(num_classes=10).to(config.device)
|
||||
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
13
utils.py
13
utils.py
@ -3,6 +3,15 @@ import torch
|
||||
from config import config
|
||||
import logger
|
||||
|
||||
weightdir = path.join(config.output_path, "weights")
|
||||
def initialize():
|
||||
if not path.exists(config.output_path):
|
||||
path.mkdir(config.output_path)
|
||||
if not path.exists(config.log_dir):
|
||||
path.mkdir(config.log_dir)
|
||||
if not path.exists(weightdir):
|
||||
path.mkdir(weightdir)
|
||||
|
||||
def save_checkpoint(model, optimizer, epoch, is_best=False):
|
||||
checkpoint = {
|
||||
"epoch": epoch,
|
||||
@ -11,8 +20,8 @@ def save_checkpoint(model, optimizer, epoch, is_best=False):
|
||||
}
|
||||
|
||||
if is_best:
|
||||
torch.save(checkpoint, config.save_path)
|
||||
torch.save(checkpoint, path.join(config.output_path,"checkpoints/last_checkpoint.pth"))
|
||||
torch.save(checkpoint, path.join(weightdir,"best_model.pth"))
|
||||
torch.save(checkpoint, path.join(weightdir, "last_model.pth"))
|
||||
logger.log("info", f"Checkpoint saved at epoch {epoch}")
|
||||
|
||||
def load_checkpoint(model, optimizer=None):
|
||||
|
Loading…
Reference in New Issue
Block a user