75 lines
2.6 KiB
Python
75 lines
2.6 KiB
Python
import torch
|
|
from tqdm import tqdm
|
|
from torch.utils.data import DataLoader
|
|
from config import config
|
|
from logger import logger
|
|
from utils import save_checkpoint, load_checkpoint
|
|
import time
|
|
|
|
class Trainer:
    """Drive the per-epoch train/validate loop for a model.

    Runtime settings are read from the project-level ``config`` object:
    ``config.device`` (where each batch is moved), ``config.epochs`` (the
    epoch count to train up to) and ``config.resume`` (whether to call
    ``load_checkpoint`` on construction).
    """

    def __init__(self, model, train_loader, val_loader, optimizer, criterion):
        """Store the training components and optionally resume from a checkpoint.

        Args:
            model: The model to train; it must be callable on a batch and
                provide ``train()`` / ``eval()`` (e.g. a ``torch.nn.Module``).
            train_loader: Iterable of ``(data, target)`` batches for training.
            val_loader: Iterable of ``(data, target)`` batches for validation.
            optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are called
                once per training batch.
            criterion: Callable ``criterion(output, target)`` returning a
                single-element loss tensor (``.backward()`` and ``.item()``
                are called on it).
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.start_epoch = 0

        if config.resume:
            # The return value of load_checkpoint is used as the epoch index
            # to continue from; presumably it also restores model/optimizer
            # state in place — verify against utils.load_checkpoint.
            self.start_epoch = load_checkpoint(model, optimizer)

    def train_epoch(self, epoch):
        """Train the model for one pass over ``train_loader``.

        Args:
            epoch: Zero-based epoch index; only used for the progress-bar label.

        Returns:
            The mean of the per-batch losses over the epoch.

        Raises:
            ValueError: If ``train_loader`` yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        # Using tqdm as a context manager closes the bar even if a batch raises.
        with tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}") as progress_bar:
            for data, target in progress_bar:
                data, target = data.to(config.device), target.to(config.device)

                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()

                # Read the scalar once: each .item() copies the value to the
                # host, which synchronizes with the device on GPU.
                batch_loss = loss.item()
                total_loss += batch_loss
                # Count batches actually seen instead of len(train_loader),
                # which is undefined for loaders without __len__.
                num_batches += 1

                progress_bar.set_postfix({
                    "Loss": f"{batch_loss:.4f}",
                    "Avg Loss": f"{total_loss/num_batches:.4f}",
                })

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        return total_loss / num_batches

    def validate(self):
        """Evaluate the model on ``val_loader`` with gradients disabled.

        Logs a one-line summary through ``logger``.

        Returns:
            ``(avg_loss, accuracy)``: ``avg_loss`` is the mean of the per-batch
            losses; ``accuracy`` is the percentage of compared elements where
            ``output.argmax(dim=1)`` equals ``target``.

        Raises:
            ValueError: If ``val_loader`` yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        num_batches = 0
        num_compared = 0

        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(config.device), target.to(config.device)
                output = self.model(data)
                total_loss += self.criterion(output, target).item()

                matches = output.argmax(dim=1).eq(target)
                correct += matches.sum().item()
                num_batches += 1
                # Use the number of elements actually compared as the
                # denominator rather than len(val_loader.dataset): the two
                # differ with drop_last=True, a subsetting/distributed sampler,
                # or an IterableDataset (which may have no __len__ at all).
                num_compared += matches.numel()

        if num_compared == 0:
            raise ValueError("val_loader yielded no samples")

        avg_loss = total_loss / num_batches
        accuracy = 100. * correct / num_compared

        logger.log("info",
                   f"Validation - Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%"
                   )

        return avg_loss, accuracy

    def train(self):
        """Train from ``self.start_epoch`` up to ``config.epochs``.

        After every epoch the model is validated. ``save_checkpoint`` is called
        with ``is_best=True`` whenever validation accuracy improves on the best
        seen in this run, and called again (without the flag) every epoch.
        """
        # NOTE(review): best_acc restarts at 0.0 on resume, so the first
        # improving epoch after resuming may overwrite a better "best"
        # checkpoint saved before the restart — restore it from the checkpoint
        # if that matters.
        best_acc = 0.0

        for epoch in range(self.start_epoch, config.epochs):
            train_loss = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()
            # The epoch's training loss was previously computed and discarded.
            logger.log("info",
                       f"Epoch {epoch+1}/{config.epochs} - Train Loss: {train_loss:.4f}"
                       )

            # Save the best model so far.
            if val_acc > best_acc:
                best_acc = val_acc
                save_checkpoint(self.model, self.optimizer, epoch, is_best=True)

            # Save the regular per-epoch checkpoint.
            save_checkpoint(self.model, self.optimizer, epoch)