更新训练代码
This commit is contained in:
parent
59ae44bc92
commit
30eeff4b1d
@ -8,13 +8,13 @@ class Config:
|
||||
|
||||
# 训练参数
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
batch_size = 64
|
||||
epochs = 100
|
||||
batch_size = 32
|
||||
epochs = 150
|
||||
learning_rate = 0.001
|
||||
save_path = "checkpoints/best_model.pth"
|
||||
|
||||
# 日志参数
|
||||
log_dir = "logs"
|
||||
log_dir = "runs/logs"
|
||||
log_level = "INFO"
|
||||
|
||||
# 断点续训
|
||||
|
2
main.py
2
main.py
@ -18,7 +18,7 @@ def main():
|
||||
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27')
|
||||
train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size)
|
||||
|
||||
# 初始化训练器
|
||||
trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
|
||||
|
@ -37,13 +37,6 @@ class Trainer:
|
||||
"Avg Loss": f"{total_loss/(batch_idx+1):.4f}"
|
||||
})
|
||||
|
||||
# 每100个batch记录一次日志
|
||||
if batch_idx % 100 == 0:
|
||||
logger.log("info",
|
||||
f"Train Epoch: {epoch} [{batch_idx}/{len(self.train_loader)}] "
|
||||
f"Loss: {loss.item():.4f}"
|
||||
)
|
||||
|
||||
return total_loss / len(self.train_loader)
|
||||
|
||||
def validate(self):
|
||||
|
10
utils.py
10
utils.py
@ -1,16 +1,17 @@
|
||||
from os import path
|
||||
import os
|
||||
import torch
|
||||
from config import config
|
||||
import logger
|
||||
from logger import logger
|
||||
|
||||
weightdir = path.join(config.output_path, "weights")
|
||||
def initialize():
|
||||
if not path.exists(config.output_path):
|
||||
path.mkdir(config.output_path)
|
||||
os.mkdir(config.output_path)
|
||||
if not path.exists(config.log_dir):
|
||||
path.mkdir(config.log_dir)
|
||||
os.mkdir(config.log_dir)
|
||||
if not path.exists(weightdir):
|
||||
path.mkdir(weightdir)
|
||||
os.mkdir(weightdir)
|
||||
|
||||
def save_checkpoint(model, optimizer, epoch, is_best=False):
|
||||
checkpoint = {
|
||||
@ -22,7 +23,6 @@ def save_checkpoint(model, optimizer, epoch, is_best=False):
|
||||
if is_best:
|
||||
torch.save(checkpoint, path.join(weightdir,"best_model.pth"))
|
||||
torch.save(checkpoint, path.join(weightdir, "last_model.pth"))
|
||||
logger.log("info", f"Checkpoint saved at epoch {epoch}")
|
||||
|
||||
def load_checkpoint(model, optimizer=None):
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user