# TA_EC/main.py
# Last modified: 2025-03-11 23:12:23 +08:00
# (29 lines, 909 B, Python)

from torch.utils.data import Dataset, DataLoader
from trainner import Trainer
from config import config
from logger import logger
from torch import optim, nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from model.repvit import *
from model.mobilenetv3 import *
from data_loader import *
from utils import *
def main(
    data_root: str = "/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27",
    num_classes: int = 9,
) -> None:
    """Build a RepViT classifier with its optimizer and loss, then train it.

    Both parameters default to the values that were previously hard-coded,
    so ``main()`` with no arguments behaves exactly as before.

    Args:
        data_root: Dataset directory handed to ``get_data_loader``.
        num_classes: Number of output classes for ``repvit_m1_0``.
    """
    # Initialize components. ``initialize`` comes from one of the wildcard
    # imports; presumably global setup such as seeding/logging (TODO confirm).
    initialize()

    model = repvit_m1_0(num_classes=num_classes).to(config.device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss()

    # NOTE(review): the test split is built but never evaluated in this
    # script; wire it into a final evaluation step if that is intended.
    train_loader, valid_loader, _test_loader = get_data_loader(
        data_root,
        batch_size=config.batch_size,
        Cache="RAM",
    )

    # Initialize the trainer and run training on the train/valid splits.
    trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
    trainer.train()
# Run training only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()