from os import path
import os

import torch

from config import config
from logger import logger

# Directory where model checkpoints are written.
weightdir = path.join(config.output_path, "weights")


def initialize():
    """Create the output, log, and weight directories if they do not exist."""
    os.makedirs(config.output_path, exist_ok=True)
    os.makedirs(config.log_dir, exist_ok=True)
    os.makedirs(weightdir, exist_ok=True)


def save_checkpoint(model, optimizer, epoch, is_best=False):
    """Save model and optimizer state; also copy to best_model.pth when is_best."""
    checkpoint = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }
    if is_best:
        torch.save(checkpoint, path.join(weightdir, "best_model.pth"))
    # Always keep the most recent checkpoint so training can resume from it.
    torch.save(checkpoint, path.join(weightdir, "last_model.pth"))


def load_checkpoint(model, optimizer=None):
    """Restore model (and optionally optimizer) state from config.checkpoint_path.

    Returns the epoch recorded in the checkpoint, or 0 if no checkpoint exists.
    """
    try:
        checkpoint = torch.load(config.checkpoint_path)
    except FileNotFoundError:
        logger.log("warning", "No checkpoint found, starting from scratch")
        return 0
    model.load_state_dict(checkpoint["model_state"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    logger.log("info", f"Resuming training from epoch {checkpoint['epoch']}")
    return checkpoint["epoch"]
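

# A minimal smoke-test sketch, not part of the training pipeline: it assumes
# `config.output_path`, `config.log_dir`, and `config.checkpoint_path` are set,
# and uses a hypothetical toy nn.Linear model purely for illustration.
if __name__ == "__main__":
    import torch.nn as nn

    initialize()
    model = nn.Linear(4, 2)  # hypothetical toy model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    save_checkpoint(model, optimizer, epoch=1, is_best=True)
    # load_checkpoint reads config.checkpoint_path; it returns 0 if that path
    # does not point at a saved checkpoint (e.g. last_model.pth above).
    start_epoch = load_checkpoint(model, optimizer)
    print(f"Resuming at epoch {start_epoch}")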