修改文件夹
This commit is contained in:
parent
049155a8b9
commit
59ae44bc92
3
main.py
3
main.py
@ -8,9 +8,12 @@ from torchvision.transforms import ToTensor
|
|||||||
|
|
||||||
from model.repvit import *
|
from model.repvit import *
|
||||||
from data_loader import *
|
from data_loader import *
|
||||||
|
from utils import *
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# 初始化组件
|
# 初始化组件
|
||||||
|
initialize()
|
||||||
|
|
||||||
model = repvit_m1_1(num_classes=10).to(config.device)
|
model = repvit_m1_1(num_classes=10).to(config.device)
|
||||||
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
13
utils.py
13
utils.py
@ -3,6 +3,15 @@ import torch
|
|||||||
from config import config
|
from config import config
|
||||||
import logger
|
import logger
|
||||||
|
|
||||||
|
weightdir = path.join(config.output_path, "weights")
|
||||||
|
def initialize():
|
||||||
|
if not path.exists(config.output_path):
|
||||||
|
path.mkdir(config.output_path)
|
||||||
|
if not path.exists(config.log_dir):
|
||||||
|
path.mkdir(config.log_dir)
|
||||||
|
if not path.exists(weightdir):
|
||||||
|
path.mkdir(weightdir)
|
||||||
|
|
||||||
def save_checkpoint(model, optimizer, epoch, is_best=False):
|
def save_checkpoint(model, optimizer, epoch, is_best=False):
|
||||||
checkpoint = {
|
checkpoint = {
|
||||||
"epoch": epoch,
|
"epoch": epoch,
|
||||||
@ -11,8 +20,8 @@ def save_checkpoint(model, optimizer, epoch, is_best=False):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if is_best:
|
if is_best:
|
||||||
torch.save(checkpoint, config.save_path)
|
torch.save(checkpoint, path.join(weightdir,"best_model.pth"))
|
||||||
torch.save(checkpoint, path.join(config.output_path,"checkpoints/last_checkpoint.pth"))
|
torch.save(checkpoint, path.join(weightdir, "last_model.pth"))
|
||||||
logger.log("info", f"Checkpoint saved at epoch {epoch}")
|
logger.log("info", f"Checkpoint saved at epoch {epoch}")
|
||||||
|
|
||||||
def load_checkpoint(model, optimizer=None):
|
def load_checkpoint(model, optimizer=None):
|
||||||
|
Loading…
Reference in New Issue
Block a user