更新训练代码

This commit is contained in:
yoiannis 2025-03-10 19:42:47 +08:00
parent 59ae44bc92
commit 30eeff4b1d
4 changed files with 9 additions and 16 deletions

View File

@ -8,13 +8,13 @@ class Config:
# 训练参数 # 训练参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64 batch_size = 32
epochs = 100 epochs = 150
learning_rate = 0.001 learning_rate = 0.001
save_path = "checkpoints/best_model.pth" save_path = "checkpoints/best_model.pth"
# 日志参数 # 日志参数
log_dir = "logs" log_dir = "runs/logs"
log_level = "INFO" log_level = "INFO"
# 断点续训 # 断点续训

View File

@ -18,7 +18,7 @@ def main():
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27') train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size)
# 初始化训练器 # 初始化训练器
trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion) trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)

View File

@ -37,13 +37,6 @@ class Trainer:
"Avg Loss": f"{total_loss/(batch_idx+1):.4f}" "Avg Loss": f"{total_loss/(batch_idx+1):.4f}"
}) })
# 每100个batch记录一次日志
if batch_idx % 100 == 0:
logger.log("info",
f"Train Epoch: {epoch} [{batch_idx}/{len(self.train_loader)}] "
f"Loss: {loss.item():.4f}"
)
return total_loss / len(self.train_loader) return total_loss / len(self.train_loader)
def validate(self): def validate(self):

View File

@ -1,16 +1,17 @@
from os import path from os import path
import os
import torch import torch
from config import config from config import config
import logger from logger import logger
weightdir = path.join(config.output_path, "weights") weightdir = path.join(config.output_path, "weights")
def initialize(): def initialize():
if not path.exists(config.output_path): if not path.exists(config.output_path):
path.mkdir(config.output_path) os.mkdir(config.output_path)
if not path.exists(config.log_dir): if not path.exists(config.log_dir):
path.mkdir(config.log_dir) os.mkdir(config.log_dir)
if not path.exists(weightdir): if not path.exists(weightdir):
path.mkdir(weightdir) os.mkdir(weightdir)
def save_checkpoint(model, optimizer, epoch, is_best=False): def save_checkpoint(model, optimizer, epoch, is_best=False):
checkpoint = { checkpoint = {
@ -22,7 +23,6 @@ def save_checkpoint(model, optimizer, epoch, is_best=False):
if is_best: if is_best:
torch.save(checkpoint, path.join(weightdir,"best_model.pth")) torch.save(checkpoint, path.join(weightdir,"best_model.pth"))
torch.save(checkpoint, path.join(weightdir, "last_model.pth")) torch.save(checkpoint, path.join(weightdir, "last_model.pth"))
logger.log("info", f"Checkpoint saved at epoch {epoch}")
def load_checkpoint(model, optimizer=None): def load_checkpoint(model, optimizer=None):
try: try: