import torch
from tqdm import tqdm
from torch.utils.data import DataLoader

from config import config
from logger import logger
from utils import save_checkpoint, load_checkpoint


class Trainer:
    """Wraps the training/validation loop and checkpointing for a single model."""

    def __init__(self, model, train_loader, val_loader, optimizer, criterion):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.start_epoch = 0
        if config.resume:
            # Resume from the last saved checkpoint; returns the epoch to restart from
            self.start_epoch = load_checkpoint(model, optimizer)

    def train_epoch(self, epoch):
        """Run one training epoch and return the average loss."""
        self.model.train()
        total_loss = 0.0
        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")

        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(config.device), target.to(config.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({
                "Loss": f"{loss.item():.4f}",
                "Avg Loss": f"{total_loss/(batch_idx+1):.4f}"
            })

            # Log every 100 batches
            if batch_idx % 100 == 0:
                logger.log(
                    "info",
                    f"Train Epoch: {epoch} [{batch_idx}/{len(self.train_loader)}] "
                    f"Loss: {loss.item():.4f}"
                )

        return total_loss / len(self.train_loader)

    def validate(self):
        """Evaluate on the validation set and return (average loss, accuracy %)."""
        self.model.eval()
        total_loss = 0.0
        correct = 0

        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(config.device), target.to(config.device)
                output = self.model(data)
                total_loss += self.criterion(output, target).item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()

        avg_loss = total_loss / len(self.val_loader)
        accuracy = 100. * correct / len(self.val_loader.dataset)
        logger.log(
            "info",
            f"Validation - Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%"
        )
        return avg_loss, accuracy

    def train(self):
        """Full training loop: train, validate, and checkpoint each epoch."""
        best_acc = 0.0
        for epoch in range(self.start_epoch, config.epochs):
            train_loss = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()

            # Save the best-performing model
            if val_acc > best_acc:
                best_acc = val_acc
                save_checkpoint(self.model, self.optimizer, epoch, is_best=True)

            # Save a regular checkpoint for resuming
            save_checkpoint(self.model, self.optimizer, epoch)