# TA_EC/config.py
# Last modified: 2025-03-11 23:12:23 +08:00
#
# Source listing info: 27 lines, 550 B, Python

import torch
class Config:
    """Central settings for the model, training loop, logging and resuming.

    Every setting is defined as a class attribute, so it can be read without
    an instance (``Config.epochs``) or through the module-level ``config``
    instance below.

    Keyword arguments passed to the constructor override the matching
    settings on the new instance only. The class-level defaults, and
    therefore every other instance, are left unchanged::

        quick = Config(epochs=5, batch_size=32)

    Raises:
        TypeError: if a keyword argument does not name an existing public
            setting (e.g. a typo such as ``epoch=5``).
    """

    # Model parameters
    input_dim = 784
    hidden_dim = 256
    output_dim = 10

    # Training parameters
    # Chosen once, when the class body runs at import time: CUDA if torch
    # reports it available, otherwise CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 128
    epochs = 150
    learning_rate = 0.001
    save_path = "checkpoints/best_model.pth"

    # Logging parameters
    log_dir = "runs/logs"
    # NOTE(review): presumably a `logging` level name — confirm with the logger setup.
    log_level = "INFO"

    # Resume training from a checkpoint
    resume = False
    checkpoint_path = "checkpoints/last_checkpoint.pth"

    # Other settings
    output_path = "runs/"
    # NOTE(review): meaning of 'RAM' and its alternatives is not visible here;
    # presumably selects where data is cached — confirm against the data loader.
    cache = 'RAM'

    def __init__(self, **overrides):
        """Create a config, optionally overriding defaults for this instance.

        Args:
            **overrides: setting name -> value. Each name must already be a
                public attribute of the class. Values are stored as given,
                with no type conversion.

        Raises:
            TypeError: if any name is unknown or starts with an underscore.
        """
        cls = type(self)
        for name, value in overrides.items():
            # Only allow replacing existing public settings, so a misspelled
            # name fails loudly instead of silently adding a new attribute.
            if name.startswith("_") or not hasattr(cls, name):
                raise TypeError(
                    f"{cls.__name__} got an unexpected setting {name!r}"
                )
            setattr(self, name, value)


# Module-level instance holding all default settings.
config = Config()