TA_EC/config.py

25 lines
525 B
Python
Raw Normal View History

2025-03-09 16:31:37 +00:00
import torch
class Config:
    """Hyper-parameters and file-system paths used for training.

    Every setting is a plain class attribute. It can be read either from the
    class itself (``Config.batch_size``) or from the module-level ``config``
    instance created right after the class.
    """

    # --- Model parameters ---
    input_dim: int = 784
    hidden_dim: int = 256
    output_dim: int = 10

    # --- Training parameters ---
    # Chosen once, when the class body executes: CUDA if PyTorch reports it
    # as available, otherwise the CPU.
    device: torch.device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )
    batch_size: int = 64
    epochs: int = 100
    learning_rate: float = 1e-3
    save_path: str = "checkpoints/best_model.pth"

    # --- Logging parameters ---
    log_dir: str = "logs"
    log_level: str = "INFO"

    # --- Resuming training from a checkpoint ---
    resume: bool = False
    checkpoint_path: str = "checkpoints/last_checkpoint.pth"
    output_path: str = "runs/"


# Shared settings instance for other modules to import.
config = Config()