diff --git a/config.py b/config.py
index cdd3336..0e7d471 100644
--- a/config.py
+++ b/config.py
@@ -8,13 +8,13 @@ class Config:
 
     # Training parameters
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    batch_size = 64
-    epochs = 100
+    batch_size = 32
+    epochs = 150
     learning_rate = 0.001
     save_path = "checkpoints/best_model.pth"
 
     # Logging parameters
-    log_dir = "logs"
+    log_dir = "runs/logs"
     log_level = "INFO"
 
     # Resume from checkpoint
diff --git a/main.py b/main.py
index fb9b14e..2f55803 100644
--- a/main.py
+++ b/main.py
@@ -18,7 +18,7 @@ def main():
     optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
     criterion = nn.CrossEntropyLoss()
 
-    train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27')
+    train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27', batch_size=config.batch_size)
 
     # Initialize the trainer
     trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
diff --git a/trainner.py b/trainner.py
index 672f516..e3d44cc 100644
--- a/trainner.py
+++ b/trainner.py
@@ -37,13 +37,6 @@ class Trainer:
                     "Avg Loss": f"{total_loss/(batch_idx+1):.4f}"
                 })
 
-                # Log once every 100 batches
-                if batch_idx % 100 == 0:
-                    logger.log("info",
-                        f"Train Epoch: {epoch} [{batch_idx}/{len(self.train_loader)}] "
-                        f"Loss: {loss.item():.4f}"
-                    )
-
         return total_loss / len(self.train_loader)
 
     def validate(self):
diff --git a/utils.py b/utils.py
index c1bc38c..ffdf250 100644
--- a/utils.py
+++ b/utils.py
@@ -1,16 +1,17 @@
 from os import path
+import os
 import torch
 from config import config
-import logger
+from logger import logger
 weightdir = path.join(config.output_path, "weights")
 
 def initialize():
     if not path.exists(config.output_path):
-        path.mkdir(config.output_path)
+        os.mkdir(config.output_path)
     if not path.exists(config.log_dir):
-        path.mkdir(config.log_dir)
+        os.mkdir(config.log_dir)
     if not path.exists(weightdir):
-        path.mkdir(weightdir)
+        os.mkdir(weightdir)
 
 def save_checkpoint(model, optimizer, epoch, is_best=False):
     checkpoint = {
@@ -22,7 +23,6 @@ def save_checkpoint(model, optimizer, epoch, is_best=False):
     if is_best:
         torch.save(checkpoint, path.join(weightdir,"best_model.pth"))
     torch.save(checkpoint, path.join(weightdir, "last_model.pth"))
-    logger.log("info", f"Checkpoint saved at epoch {epoch}")
 
 def load_checkpoint(model, optimizer=None):
     try:
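
Note on the main.py change: passing `batch_size=config.batch_size` keeps the data loaders in sync with the new `batch_size = 32` in config.py instead of relying on a default baked into `create_data_loaders`. The loader factory itself is not part of this diff; the sketch below is only an assumption about its shape (the train/valid/test folder layout and the use of `ImageFolder` are hypothetical), showing how the keyword would typically be plumbed through to each `DataLoader`:

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def create_data_loaders(data_dir, batch_size=32):
    # Hypothetical sketch: the real function in this repo may build the EC27
    # splits differently; the point is that batch_size reaches every DataLoader.
    transform = transforms.Compose([transforms.ToTensor()])
    splits = {
        name: datasets.ImageFolder(f"{data_dir}/{name}", transform=transform)
        for name in ("train", "valid", "test")
    }
    train_loader = DataLoader(splits["train"], batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(splits["valid"], batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(splits["test"], batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader, test_loader
```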
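
Note on the utils.py `initialize()` change: `os.path` (imported here as `path`) has no `mkdir` function, so the switch to `os.mkdir` is what makes the directory setup run at all. One caveat worth flagging: with `log_dir` now set to the nested path `"runs/logs"`, `os.mkdir` raises `FileNotFoundError` if `runs/` does not exist yet. A minimal sketch of a more defensive variant, assuming the same `config` fields and `weightdir` as above, would use `os.makedirs` with `exist_ok=True`:

```python
import os
from os import path

from config import config

weightdir = path.join(config.output_path, "weights")

def initialize():
    # makedirs also creates intermediate directories (e.g. "runs/" for "runs/logs"),
    # and exist_ok=True turns repeated runs into a no-op instead of an error.
    for directory in (config.output_path, config.log_dir, weightdir):
        os.makedirs(directory, exist_ok=True)
```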