# Source listing metadata (page-scrape residue): 29 lines, 909 B, Python
from torch.utils.data import Dataset, DataLoader
|
|
from trainner import Trainer
|
|
from config import config
|
|
from logger import logger
|
|
from torch import optim, nn
|
|
from torchvision.datasets import MNIST
|
|
from torchvision.transforms import ToTensor
|
|
|
|
from model.repvit import *
|
|
from model.mobilenetv3 import *
|
|
from data_loader import *
|
|
from utils import *
|
|
|
|
def main(data_dir='/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27',
         num_classes=9,
         cache='RAM'):
    """Build the RepViT model, optimizer, loss and data loaders, then train.

    Every argument defaults to the value this script previously hard-coded,
    so calling ``main()`` with no arguments behaves exactly as before.

    Args:
        data_dir: Dataset root passed positionally to ``get_data_loader``.
        num_classes: Number of output classes given to ``repvit_m1_0``.
        cache: Forwarded as ``get_data_loader``'s ``Cache`` keyword.
            NOTE(review): its meaning is defined in ``data_loader``; 'RAM'
            presumably means the dataset is cached in memory -- confirm.
    """
    # Initialize shared components. ``initialize`` comes from one of the
    # star-imported project modules; its exact effects are not visible here.
    initialize()

    # Model on the configured device, Adam at the configured learning rate,
    # and cross-entropy loss for classification.
    model = repvit_m1_0(num_classes=num_classes).to(config.device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss()

    # get_data_loader yields (train, valid, test) loaders; the test split is
    # not consumed by this entry point yet, hence the leading underscore.
    train_loader, valid_loader, _test_loader = get_data_loader(
        data_dir, batch_size=config.batch_size, Cache=cache)

    # Build the trainer and run training.
    trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
    trainer.train()
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()