更新训练代码
This commit is contained in:
parent
59ae44bc92
commit
30eeff4b1d
@ -8,13 +8,13 @@ class Config:
|
|||||||
|
|
||||||
# 训练参数
|
# 训练参数
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
batch_size = 64
|
batch_size = 32
|
||||||
epochs = 100
|
epochs = 150
|
||||||
learning_rate = 0.001
|
learning_rate = 0.001
|
||||||
save_path = "checkpoints/best_model.pth"
|
save_path = "checkpoints/best_model.pth"
|
||||||
|
|
||||||
# 日志参数
|
# 日志参数
|
||||||
log_dir = "logs"
|
log_dir = "runs/logs"
|
||||||
log_level = "INFO"
|
log_level = "INFO"
|
||||||
|
|
||||||
# 断点续训
|
# 断点续训
|
||||||
|
2
main.py
2
main.py
@ -18,7 +18,7 @@ def main():
|
|||||||
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27')
|
train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size)
|
||||||
|
|
||||||
# 初始化训练器
|
# 初始化训练器
|
||||||
trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
|
trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
|
||||||
|
@ -37,13 +37,6 @@ class Trainer:
|
|||||||
"Avg Loss": f"{total_loss/(batch_idx+1):.4f}"
|
"Avg Loss": f"{total_loss/(batch_idx+1):.4f}"
|
||||||
})
|
})
|
||||||
|
|
||||||
# 每100个batch记录一次日志
|
|
||||||
if batch_idx % 100 == 0:
|
|
||||||
logger.log("info",
|
|
||||||
f"Train Epoch: {epoch} [{batch_idx}/{len(self.train_loader)}] "
|
|
||||||
f"Loss: {loss.item():.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return total_loss / len(self.train_loader)
|
return total_loss / len(self.train_loader)
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
|
10
utils.py
10
utils.py
@ -1,16 +1,17 @@
|
|||||||
from os import path
|
from os import path
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
from config import config
|
from config import config
|
||||||
import logger
|
from logger import logger
|
||||||
|
|
||||||
weightdir = path.join(config.output_path, "weights")
|
weightdir = path.join(config.output_path, "weights")
|
||||||
def initialize():
|
def initialize():
|
||||||
if not path.exists(config.output_path):
|
if not path.exists(config.output_path):
|
||||||
path.mkdir(config.output_path)
|
os.mkdir(config.output_path)
|
||||||
if not path.exists(config.log_dir):
|
if not path.exists(config.log_dir):
|
||||||
path.mkdir(config.log_dir)
|
os.mkdir(config.log_dir)
|
||||||
if not path.exists(weightdir):
|
if not path.exists(weightdir):
|
||||||
path.mkdir(weightdir)
|
os.mkdir(weightdir)
|
||||||
|
|
||||||
def save_checkpoint(model, optimizer, epoch, is_best=False):
|
def save_checkpoint(model, optimizer, epoch, is_best=False):
|
||||||
checkpoint = {
|
checkpoint = {
|
||||||
@ -22,7 +23,6 @@ def save_checkpoint(model, optimizer, epoch, is_best=False):
|
|||||||
if is_best:
|
if is_best:
|
||||||
torch.save(checkpoint, path.join(weightdir,"best_model.pth"))
|
torch.save(checkpoint, path.join(weightdir,"best_model.pth"))
|
||||||
torch.save(checkpoint, path.join(weightdir, "last_model.pth"))
|
torch.save(checkpoint, path.join(weightdir, "last_model.pth"))
|
||||||
logger.log("info", f"Checkpoint saved at epoch {epoch}")
|
|
||||||
|
|
||||||
def load_checkpoint(model, optimizer=None):
|
def load_checkpoint(model, optimizer=None):
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user