TA_EC/utils.py

37 lines
1.2 KiB
Python
Raw Normal View History

2025-03-09 16:31:37 +00:00
from os import path
2025-03-10 11:42:47 +00:00
import os
2025-03-09 16:31:37 +00:00
import torch
from config import config
2025-03-10 11:42:47 +00:00
from logger import logger
2025-03-09 16:31:37 +00:00
2025-03-09 16:39:16 +00:00
# Directory under the configured output path where model weights are written.
weightdir = path.join(config.output_path, "weights")


def initialize():
    """Create the output, log, and weight directories if they do not exist.

    Uses ``os.makedirs(..., exist_ok=True)`` instead of the original
    ``path.exists()`` + ``os.mkdir()`` pair: ``mkdir`` fails with
    FileNotFoundError when a parent directory is missing, and the
    exists-then-create sequence is a check-then-act race if another
    process creates the directory in between.
    """
    os.makedirs(config.output_path, exist_ok=True)
    os.makedirs(config.log_dir, exist_ok=True)
    os.makedirs(weightdir, exist_ok=True)
2025-03-09 16:39:16 +00:00
2025-03-09 16:31:37 +00:00
def save_checkpoint(model, optimizer, epoch, is_best=False):
    """Serialize training state to the weights directory.

    Always writes ``last_model.pth``; when *is_best* is true, an identical
    ``best_model.pth`` is written first (same order as before).
    """
    state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }
    filenames = (["best_model.pth"] if is_best else []) + ["last_model.pth"]
    for filename in filenames:
        torch.save(state, path.join(weightdir, filename))
2025-03-09 16:31:37 +00:00
def load_checkpoint(model, optimizer=None):
    """Restore model (and optionally optimizer) state from ``config.checkpoint_path``.

    Parameters:
        model: module whose ``load_state_dict`` receives the saved weights.
        optimizer: optional optimizer to restore as well.

    Returns:
        The epoch number stored in the checkpoint, or 0 when no checkpoint
        file exists (training starts from scratch).

    Fix: the original ``try`` wrapped the whole body, so a
    ``FileNotFoundError`` raised anywhere after the load (e.g. inside
    ``load_state_dict`` or the logger) would be misreported as a missing
    checkpoint. The ``try`` now covers only the ``torch.load`` call.
    """
    try:
        checkpoint = torch.load(config.checkpoint_path)
    except FileNotFoundError:
        logger.log("warning", "No checkpoint found, starting from scratch")
        return 0
    model.load_state_dict(checkpoint["model_state"])
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    logger.log("info", f"Resuming training from epoch {checkpoint['epoch']}")
    return checkpoint["epoch"]