2025-03-09 16:31:37 +00:00
|
|
|
import torch
|
|
|
|
|
|
|
|
class Config:
    """Central hyper-parameter and path settings for training.

    Every default is a class attribute, so ``Config.batch_size`` works
    without an instance. Keyword overrides given to the constructor are
    stored on that instance only; the class-level defaults stay unchanged.
    """

    # Model parameters
    input_dim: int = 784
    hidden_dim: int = 256
    output_dim: int = 10

    # Training parameters
    # Evaluated once, when the class body executes (i.e. at import time).
    device: torch.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    batch_size: int = 32
    epochs: int = 150
    learning_rate: float = 0.001
    save_path: str = "checkpoints/best_model.pth"

    # Logging parameters
    log_dir: str = "runs/logs"
    log_level: str = "INFO"

    # Resuming training from a checkpoint
    resume: bool = False
    checkpoint_path: str = "checkpoints/last_checkpoint.pth"
    output_path: str = "runs/"

    def __init__(self, **overrides):
        """Create a settings object, optionally overriding some defaults.

        Args:
            **overrides: setting name -> new value. Each name must be an
                existing public setting of this class. A ``device`` value is
                passed through ``torch.device`` so plain strings such as
                ``"cpu"`` are accepted.

        Raises:
            TypeError: if any override names an unknown setting. Nothing is
                applied in that case.
        """
        cls = type(self)
        # Validate every name before assigning any, so a typo never leaves a
        # half-updated instance behind.
        unknown = [
            name for name in overrides
            if name.startswith("_") or not hasattr(cls, name)
        ]
        if unknown:
            raise TypeError(
                f"{cls.__name__} got unexpected setting(s): "
                + ", ".join(repr(name) for name in unknown)
            )
        for name, value in overrides.items():
            if name == "device":
                value = torch.device(value)
            setattr(self, name, value)
|
|
|
|
|
|
|
|
# Module-level instance holding the default settings.
config: Config = Config()
|