25 lines
525 B
Python
25 lines
525 B
Python
![]() |
import torch
|
||
|
|
||
|
class Config:
    """Static hyperparameters and paths for model training.

    Every setting is a class attribute. It can be read from the class itself
    (``Config.batch_size``) or from any instance (``Config().batch_size``).
    """

    # Model parameters
    input_dim: int = 784
    hidden_dim: int = 256
    output_dim: int = 10

    # Training parameters
    # Resolved once, when this class body executes (i.e. at module import),
    # not per instance: "cuda" if a CUDA device is available, else "cpu".
    device: torch.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    batch_size: int = 64
    epochs: int = 100
    learning_rate: float = 0.001
    save_path: str = "checkpoints/best_model.pth"

    # Logging parameters
    log_dir: str = "logs"
    log_level: str = "INFO"

    # Resume-from-checkpoint settings
    resume: bool = False
    checkpoint_path: str = "checkpoints/last_checkpoint.pth"
    output_path: str = "runs/"


# Module-level Config instance; its attributes are the class-level defaults.
config: Config = Config()