TA_EC/utils.py

28 lines
997 B
Python
Raw Normal View History

2025-03-09 16:31:37 +00:00
import os
from os import path

import torch

from config import config
import logger
def save_checkpoint(model, optimizer, epoch, is_best=False):
    """Persist model and optimizer state so training can be resumed.

    Always writes the latest checkpoint to
    ``<config.output_path>/checkpoints/last_checkpoint.pth``; when
    ``is_best`` is True, additionally writes a copy to ``config.save_path``.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        epoch: epoch number stored alongside the states.
        is_best: when True, also save the checkpoint to ``config.save_path``.
    """
    checkpoint = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }
    if is_best:
        torch.save(checkpoint, config.save_path)
    # Fix: torch.save raises FileNotFoundError if the checkpoints/ directory
    # has not been created yet — make sure it exists before writing.
    last_path = path.join(config.output_path, "checkpoints/last_checkpoint.pth")
    os.makedirs(path.dirname(last_path), exist_ok=True)
    torch.save(checkpoint, last_path)
    logger.log("info", f"Checkpoint saved at epoch {epoch}")
def load_checkpoint(model, optimizer=None):
    """Restore state from ``config.checkpoint_path`` into *model* (and *optimizer*).

    Returns the epoch recorded in the checkpoint so the caller can resume
    training from it, or 0 when no checkpoint file exists yet.
    """
    # Guard clause: a missing checkpoint file is the expected cold-start case.
    try:
        state = torch.load(config.checkpoint_path)
    except FileNotFoundError:
        logger.log("warning", "No checkpoint found, starting from scratch")
        return 0

    model.load_state_dict(state["model_state"])
    if optimizer:
        optimizer.load_state_dict(state["optimizer_state"])
    logger.log("info", f"Resuming training from epoch {state['epoch']}")
    return state["epoch"]