Update training code

yoiannis 2025-03-10 19:42:47 +08:00
parent 59ae44bc92
commit 30eeff4b1d
4 changed files with 9 additions and 16 deletions


@@ -8,13 +8,13 @@ class Config:
     # Training parameters
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    batch_size = 64
-    epochs = 100
+    batch_size = 32
+    epochs = 150
     learning_rate = 0.001
     save_path = "checkpoints/best_model.pth"
     # Logging parameters
-    log_dir = "logs"
+    log_dir = "runs/logs"
     log_level = "INFO"
     # Resume-from-checkpoint settings
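For context, here is a minimal sketch of how `config.py` might look after this commit. It is a hypothetical reconstruction, not the actual file: the module-level `config = Config()` instance is inferred from `from config import config` in the checkpoint file below, and `output_path` is an assumption (it is read there but never defined in the visible hunks).

```python
# config.py -- hypothetical reconstruction, not the actual file contents
import torch

class Config:
    # Training parameters
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 32                   # halved from 64 by this commit
    epochs = 150                      # raised from 100 by this commit
    learning_rate = 0.001
    save_path = "checkpoints/best_model.pth"
    output_path = "runs"              # assumed; the checkpoint code reads config.output_path
    # Logging parameters
    log_dir = "runs/logs"             # moved from "logs" by this commit
    log_level = "INFO"

# Module-level instance, matching `from config import config` used elsewhere
config = Config()
```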


@@ -18,7 +18,7 @@ def main():
     optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
     criterion = nn.CrossEntropyLoss()
-    train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27')
+    train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27', batch_size=config.batch_size)
     # Initialize the trainer
     trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
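The body of `create_data_loaders` is not part of this diff, so the following is only a hypothetical sketch of a signature that would accept the new `batch_size` keyword. The toy `TensorDataset` and the 27-class guess (read off the "EC27" directory name) stand in for the real dataset loading:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

def create_data_loaders(root: str, batch_size: int = 32):
    # Toy tensors stand in for whatever loading `root` really drives;
    # 27 classes is a guess based on the "EC27" directory name.
    dataset = TensorDataset(torch.randn(1000, 3, 32, 32),
                            torch.randint(0, 27, (1000,)))
    n = len(dataset)
    n_train, n_valid = int(0.8 * n), int(0.1 * n)
    train_set, valid_set, test_set = random_split(
        dataset, [n_train, n_valid, n - n_train - n_valid])
    # Only the training split is shuffled; all three share one batch size.
    return (DataLoader(train_set, batch_size=batch_size, shuffle=True),
            DataLoader(valid_set, batch_size=batch_size),
            DataLoader(test_set, batch_size=batch_size))
```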


@@ -37,13 +37,6 @@ class Trainer:
                     "Avg Loss": f"{total_loss/(batch_idx+1):.4f}"
                 })
-                # Log every 100 batches
-                if batch_idx % 100 == 0:
-                    logger.log("info",
-                        f"Train Epoch: {epoch} [{batch_idx}/{len(self.train_loader)}] "
-                        f"Loss: {loss.item():.4f}"
-                    )
         return total_loss / len(self.train_loader)
     def validate(self):
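The per-100-batch `logger.log` call is dropped; the running average already shown in the progress-bar postfix carries the same per-batch information. A minimal sketch of that postfix pattern, assuming the bar is tqdm (which the `set_postfix`-style dict above suggests):

```python
from tqdm import tqdm

losses = [0.9, 0.7, 0.5, 0.4]          # stand-in per-batch losses
total_loss = 0.0
progress = tqdm(losses, desc="Train")
for batch_idx, loss in enumerate(progress):
    total_loss += loss
    # Same fields as the dict in the hunk above: current and running-average loss.
    progress.set_postfix({"Loss": f"{loss:.4f}",
                          "Avg Loss": f"{total_loss/(batch_idx+1):.4f}"})
```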


@@ -1,16 +1,17 @@
 from os import path
+import os
 import torch
 from config import config
-import logger
+from logger import logger
 weightdir = path.join(config.output_path, "weights")
 def initialize():
     if not path.exists(config.output_path):
-        path.mkdir(config.output_path)
+        os.mkdir(config.output_path)
     if not path.exists(config.log_dir):
-        path.mkdir(config.log_dir)
+        os.mkdir(config.log_dir)
     if not path.exists(weightdir):
-        path.mkdir(weightdir)
+        os.mkdir(weightdir)
 def save_checkpoint(model, optimizer, epoch, is_best=False):
     checkpoint = {
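The real bug fixed here: `os.path` has no `mkdir`, so every `path.mkdir(...)` call raised `AttributeError` at runtime, and `os.mkdir` is the correct function. A sketch of an even more compact variant (not what the commit does): `os.makedirs` with `exist_ok=True` drops the three existence checks and also creates intermediate directories. That matters now that `log_dir` is the nested `"runs/logs"`: plain `os.mkdir("runs/logs")` raises `FileNotFoundError` unless `runs` already exists.

```python
import os

def initialize(output_path: str, log_dir: str, weightdir: str) -> None:
    for d in (output_path, log_dir, weightdir):
        # Creates parents as needed ("runs" before "runs/logs") and is a
        # no-op for directories that already exist.
        os.makedirs(d, exist_ok=True)
```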
@@ -22,7 +23,6 @@ def save_checkpoint(model, optimizer, epoch, is_best=False):
     if is_best:
         torch.save(checkpoint, path.join(weightdir, "best_model.pth"))
     torch.save(checkpoint, path.join(weightdir, "last_model.pth"))
-    logger.log("info", f"Checkpoint saved at epoch {epoch}")
 def load_checkpoint(model, optimizer=None):
     try:
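`load_checkpoint` is truncated in this diff; the following is only a sketch of how it might complete, mirroring `save_checkpoint` above. The checkpoint key names (`model_state`, `optimizer_state`, `epoch`) and the `weightdir` default are assumptions, not taken from the commit:

```python
from os import path
import torch

def load_checkpoint(model, optimizer=None, weightdir="weights"):
    # Key names and the weightdir default are assumptions for illustration.
    try:
        checkpoint = torch.load(path.join(weightdir, "last_model.pth"))
    except FileNotFoundError:
        return 0  # nothing to resume from; start at epoch 0
    model.load_state_dict(checkpoint["model_state"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"] + 1  # resume from the following epoch
```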