37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
from os import path
|
|
import os
|
|
import torch
|
|
from config import config
|
|
from logger import logger
|
|
|
|
# Directory where model checkpoints are written; lives under the configured
# output path (created on demand by initialize()).
weightdir = path.join(config.output_path, "weights")
|
|
def initialize():
    """Create the output, log, and weights directories used by this module.

    Uses ``os.makedirs(..., exist_ok=True)`` instead of the original
    ``path.exists`` + ``os.mkdir`` pair: that pattern is a check-then-act
    race (another process can create the directory between the check and
    the mkdir, raising ``FileExistsError``) and ``os.mkdir`` also fails
    when an intermediate parent directory is missing.
    """
    os.makedirs(config.output_path, exist_ok=True)
    os.makedirs(config.log_dir, exist_ok=True)
    # weightdir is nested under output_path; makedirs handles the parent too.
    os.makedirs(weightdir, exist_ok=True)
|
|
|
|
def save_checkpoint(model, optimizer, epoch, is_best=False):
    """Persist training state to the weights directory.

    Always writes ``last_model.pth``; when *is_best* is true, the same
    state is also written to ``best_model.pth`` first.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        epoch: epoch number stored alongside the states.
        is_best: additionally snapshot this state as the best model.
    """
    state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }
    # Preserve the original write order: best snapshot (if any), then last.
    filenames = ["best_model.pth"] if is_best else []
    filenames.append("last_model.pth")
    for filename in filenames:
        torch.save(state, path.join(weightdir, filename))
|
|
|
|
def load_checkpoint(model, optimizer=None):
    """Restore training state from ``config.checkpoint_path``.

    Args:
        model: module whose ``load_state_dict`` receives the saved weights.
        optimizer: if given, its state is restored as well.

    Returns:
        The epoch stored in the checkpoint, or 0 when no checkpoint file
        exists (training starts from scratch).
    """
    try:
        # Keep the try body minimal: only the file read should map to the
        # "no checkpoint" path. The original wrapped the whole restore, so a
        # FileNotFoundError raised anywhere inside (e.g. by a logger handler)
        # was silently treated as a missing checkpoint.
        # NOTE(review): consider torch.load(..., map_location="cpu") if
        # checkpoints saved on GPU must load on CPU-only hosts — confirm.
        checkpoint = torch.load(config.checkpoint_path)
    except FileNotFoundError:
        logger.log("warning", "No checkpoint found, starting from scratch")
        return 0

    model.load_state_dict(checkpoint["model_state"])
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    logger.log("info", f"Resuming training from epoch {checkpoint['epoch']}")
    return checkpoint["epoch"]