"""Save and restore training checkpoints (model state, optimizer state, epoch)."""

import os
from os import path

import torch

from config import config
import logger


def _ensure_parent_dir(file_path):
    """Create the directory that will contain *file_path* if it is missing."""
    parent = path.dirname(file_path)
    # dirname() returns "" for a bare filename; os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)


def save_checkpoint(model, optimizer, epoch, is_best=False):
    """Persist model and optimizer state together with ``epoch``.

    The checkpoint is always written to
    ``<config.output_path>/checkpoints/last_checkpoint.pth``. When ``is_best``
    is true it is additionally written to ``config.save_path``. Missing parent
    directories are created first, because ``torch.save`` does not create them.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        optimizer: Object exposing ``state_dict()`` (e.g. a
            ``torch.optim.Optimizer``).
        epoch: Value stored under ``"epoch"`` and reported in the log message.
        is_best: If true, also write the checkpoint to ``config.save_path``.
    """
    checkpoint = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }

    if is_best:
        _ensure_parent_dir(config.save_path)
        torch.save(checkpoint, config.save_path)

    # Pass separate components to path.join instead of embedding a "/" separator.
    last_path = path.join(config.output_path, "checkpoints", "last_checkpoint.pth")
    # Without this, the first save into a fresh output_path fails on the missing dir.
    _ensure_parent_dir(last_path)
    torch.save(checkpoint, last_path)
    logger.log("info", f"Checkpoint saved at epoch {epoch}")


def load_checkpoint(model, optimizer=None, map_location=None):
    """Restore model (and optionally optimizer) state from ``config.checkpoint_path``.

    Args:
        model: Receives ``checkpoint["model_state"]`` via ``load_state_dict``.
        optimizer: If not ``None``, receives ``checkpoint["optimizer_state"]``
            via ``load_state_dict``.
        map_location: Forwarded to ``torch.load``. For example, ``"cpu"``
            loads a checkpoint that was saved on GPU onto a CPU-only machine.
            ``None`` (the default) keeps ``torch.load``'s default behaviour.

    Returns:
        The ``"epoch"`` value stored in the checkpoint, or ``0`` if no
        checkpoint file exists.

    NOTE(review): the stored epoch is returned as-is. If it is the last
    *completed* epoch, the caller's training loop should presumably start at
    ``epoch + 1``. Confirm this against the call site.
    """
    # Keep the try narrow: only the file load is expected to raise FileNotFoundError.
    try:
        checkpoint = torch.load(config.checkpoint_path, map_location=map_location)
    except FileNotFoundError:
        logger.log("warning", "No checkpoint found, starting from scratch")
        return 0

    model.load_state_dict(checkpoint["model_state"])
    # Skip only when no optimizer was passed, rather than relying on truthiness.
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state"])

    logger.log("info", f"Resuming training from epoch {checkpoint['epoch']}")
    return checkpoint["epoch"]