import torch


class Config:
    """Central settings for the model, training loop, logging and checkpoint resumption.

    Every setting is a plain class attribute, so it can be read either from the
    class itself (``Config.batch_size``) or from the module-level ``config``
    instance created below.
    """

    # Model parameters
    input_dim: int = 784
    hidden_dim: int = 256
    output_dim: int = 10

    # Training parameters
    # Evaluated once, when this class body runs (at import time): CUDA if available, else CPU.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size: int = 32
    epochs: int = 150
    learning_rate: float = 0.001
    save_path: str = "checkpoints/best_model.pth"

    # Logging parameters
    log_dir: str = "runs/logs"
    log_level: str = "INFO"

    # Resume-from-checkpoint settings
    # NOTE(review): presumably `resume` makes the training script load `checkpoint_path` — confirm in caller.
    resume: bool = False
    checkpoint_path: str = "checkpoints/last_checkpoint.pth"
    output_path: str = "runs/"


# Module-level instance so callers can share one Config object.
config = Config()