# TA_EC/config.py
import torch
class Config:
    """Container for model, training, logging and checkpoint settings.

    Every setting is a plain class attribute, so values can be read from the
    class itself (``Config.batch_size``) or from any instance of it.
    """

    # Model parameters
    input_dim = 784  # NOTE(review): presumably flattened 28x28 inputs; confirm
    hidden_dim = 256
    output_dim = 10

    # Training parameters
    # Evaluated once, when the class body runs: CUDA if available, else CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 128
    epochs = 150
    learning_rate = 0.001
    save_path = "checkpoints/best_model.pth"

    # Logging parameters
    log_dir = "runs/logs"
    log_level = "INFO"

    # Resuming training from a checkpoint
    resume = False
    checkpoint_path = "checkpoints/last_checkpoint.pth"
    output_path = "runs/"
    # NOTE(review): the consumer of this flag is not visible in this file;
    # presumably it selects in-memory caching of the dataset; confirm.
    cache = 'RAM'


# Module-level instance of Config.
config = Config()