"""Training entry point: build a RepViT classifier and train it with the project's Trainer."""
from torch.utils.data import Dataset, DataLoader
from trainner import Trainer
from config import config
from logger import logger
from torch import optim, nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
# NOTE(review): the star imports below supply `repvit_m1_1`, `create_data_loaders`
# and `initialize`; consider replacing them with explicit imports once the exact
# source module of each name is confirmed.
from model.repvit import *
from data_loader import *
from utils import *

# Dataset root used when `main` is called without arguments (the original
# hard-coded location). Pass `data_root` to `main` to train on another copy.
DEFAULT_DATA_ROOT = "F:/dataset/02.TA_EC/datasets/EC27"


def main(data_root: str = DEFAULT_DATA_ROOT, num_classes: int = 10) -> None:
    """Build the model, optimizer, loss and data loaders, then run training.

    Args:
        data_root: Dataset path forwarded to `create_data_loaders`.
            Defaults to `DEFAULT_DATA_ROOT`.
        num_classes: Number of output classes passed to `repvit_m1_1`.
            NOTE(review): the default of 10 is kept from the original code, but
            the dataset directory is named "EC27" — confirm the class count.
    """
    # Initialize project-level components (`initialize` is presumably
    # provided by one of the star imports above — TODO confirm which).
    initialize()

    model = repvit_m1_1(num_classes=num_classes).to(config.device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss()

    # NOTE(review): `test_loader` is created but never used in this script;
    # a final evaluation on the test split may be missing.
    train_loader, valid_loader, test_loader = create_data_loaders(
        data_root, batch_size=config.batch_size
    )

    # Initialize the trainer with the train/validation loaders and run training.
    trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
    trainer.train()


if __name__ == "__main__":
    main()