# Project configuration: model, training, logging and checkpoint-resume settings.
import torch


class Config:
    """Hyper-parameters, paths and run options for the project.

    Every setting is a class attribute, so it can be read from the class
    itself (``Config.batch_size``) or from any instance
    (``Config().batch_size``).
    """

    # Model parameters
    input_dim: int = 784
    hidden_dim: int = 256
    output_dim: int = 10

    # Training parameters
    # Chosen once, when this class body executes (i.e. at import time):
    # CUDA if torch reports it available, otherwise CPU.
    device: torch.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    batch_size: int = 128
    epochs: int = 150
    learning_rate: float = 0.001
    save_path: str = "checkpoints/best_model.pth"

    # Logging parameters
    log_dir: str = "runs/logs"
    log_level: str = "INFO"

    # Resuming training from a checkpoint
    resume: bool = False
    checkpoint_path: str = "checkpoints/last_checkpoint.pth"
    # NOTE(review): grouped under "resume" in the original layout; its exact
    # consumer is not visible here — confirm what is written under this path.
    output_path: str = "runs/"

    # NOTE(review): the meaning of "RAM" is not visible in this file —
    # presumably selects where data is cached; confirm against the reader.
    cache: str = "RAM"


# Module-level instance exposing the settings defined on ``Config``.
config = Config()