TA_EC/utils.py
2025-03-10 19:42:47 +08:00

37 lines
1.2 KiB
Python

from os import path
import os
import torch
from config import config
from logger import logger
# Directory under the configured output path where model checkpoints are written.
weightdir = path.join(config.output_path, "weights")
def initialize():
    """Create the output, log, and weights directories if they do not exist.

    Uses ``os.makedirs(..., exist_ok=True)`` instead of ``exists() + mkdir``:
    it also creates missing parent directories (``mkdir`` would raise
    FileNotFoundError for a nested path) and avoids the check-then-create
    race when several processes start at once.
    """
    os.makedirs(config.output_path, exist_ok=True)
    os.makedirs(config.log_dir, exist_ok=True)
    os.makedirs(weightdir, exist_ok=True)
def save_checkpoint(model, optimizer, epoch, is_best=False):
    """Persist the current training state to the weights directory.

    Always writes ``last_model.pth``; when *is_best* is True, also writes
    ``best_model.pth`` with the same contents.
    """
    state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }
    # Write the "best" snapshot first (when requested), then the rolling "last" one.
    filenames = (["best_model.pth"] if is_best else []) + ["last_model.pth"]
    for filename in filenames:
        torch.save(state, path.join(weightdir, filename))
def load_checkpoint(model, optimizer=None):
    """Restore model (and optionally optimizer) state from ``config.checkpoint_path``.

    Args:
        model: module whose parameters are restored via ``load_state_dict``.
        optimizer: optional optimizer whose state is restored when provided.

    Returns:
        The epoch number stored in the checkpoint, or 0 when no checkpoint
        file exists.
    """
    try:
        # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
        # CPU-only machine; load_state_dict then copies tensors onto whatever
        # device the model/optimizer parameters currently live on.
        checkpoint = torch.load(config.checkpoint_path, map_location="cpu")
    except FileNotFoundError:
        logger.log("warning", "No checkpoint found, starting from scratch")
        return 0
    # Keep the try body minimal: only torch.load is expected to raise
    # FileNotFoundError; failures below should surface, not be masked.
    model.load_state_dict(checkpoint["model_state"])
    if optimizer is not None:  # explicit None check, not truthiness
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    logger.log("info", f"Resuming training from epoch {checkpoint['epoch']}")
    return checkpoint["epoch"]