TA_EC/trainner.py
import time

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from config import config
from logger import logger
from utils import save_checkpoint, load_checkpoint


class Trainer:
    def __init__(self, model, train_loader, val_loader, optimizer, criterion):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.start_epoch = 0
        if config.resume:
            self.start_epoch = load_checkpoint(model, optimizer)

    def train_epoch(self, epoch):
        """Run one training epoch and return the average loss."""
        self.model.train()
        total_loss = 0.0
        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
        time_start = time.time()
        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(config.device), target.to(config.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            progress_bar.set_postfix({
                "Loss": f"{loss.item():.4f}",
                "Avg Loss": f"{total_loss/(batch_idx+1):.4f}"
            })
        # Average loss over the epoch, plus the wall-clock time it took.
        avg_loss = total_loss / len(self.train_loader)
        elapsed = time.time() - time_start
        logger.log("info", f"Epoch {epoch+1} - Train Loss: {avg_loss:.4f} | Time: {elapsed:.1f}s")
        return avg_loss

    def validate(self):
        self.model.eval()
        total_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(config.device), target.to(config.device)
                output = self.model(data)
                total_loss += self.criterion(output, target).item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
        avg_loss = total_loss / len(self.val_loader)
        accuracy = 100. * correct / len(self.val_loader.dataset)
        logger.log("info",
            f"Validation - Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%"
        )
        return avg_loss, accuracy

    def train(self):
        best_acc = 0.0
        for epoch in range(self.start_epoch, config.epochs):
            train_loss = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()
            # Save the best-performing model so far
            if val_acc > best_acc:
                best_acc = val_acc
                save_checkpoint(self.model, self.optimizer, epoch, is_best=True)
            # Save a regular checkpoint every epoch
            save_checkpoint(self.model, self.optimizer, epoch)
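

# Usage sketch: a minimal example of how Trainer could be wired together.
# The MNIST dataset, the tiny linear model, Adam, and CrossEntropyLoss below
# are illustrative assumptions, not part of the TA_EC project itself; only
# Trainer, config, and the DataLoader import above come from this file.
if __name__ == "__main__":
    import torch.nn as nn
    from torchvision import datasets, transforms  # assumed dependency for this example

    # Hypothetical data pipeline; the real project presumably supplies its own dataset.
    transform = transforms.ToTensor()
    train_set = datasets.MNIST("data", train=True, download=True, transform=transform)
    val_set = datasets.MNIST("data", train=False, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=64)

    # Hypothetical model; any nn.Module that outputs class logits fits the Trainer interface.
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(config.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    Trainer(model, train_loader, val_loader, optimizer, criterion).train()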