同步代码

This commit is contained in:
yoiannis 2025-03-10 00:31:37 +08:00
parent 49b2110b5f
commit 049155a8b9
9 changed files with 274 additions and 1 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
data
*.pth
*.pyc
logs

25
config.py Normal file
View File

@ -0,0 +1,25 @@
import torch
class Config:
    """Central configuration container shared by the whole project.

    Imported everywhere via the module-level ``config`` singleton below.
    """
    # Model parameters
    input_dim = 784
    hidden_dim = 256
    output_dim = 10
    # Training parameters
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # evaluated once at import time
    batch_size = 64
    epochs = 100
    learning_rate = 0.001
    save_path = "checkpoints/best_model.pth"  # best-accuracy model weights
    # Logging parameters
    log_dir = "logs"
    log_level = "INFO"
    # Checkpoint resume
    resume = False  # set True to resume training from checkpoint_path
    checkpoint_path = "checkpoints/last_checkpoint.pth"
    # NOTE(review): output_path ("runs/") is not obviously consistent with
    # checkpoint_path above — confirm which location consumers actually read.
    output_path = "runs/"


# Shared singleton instance used by the rest of the project.
config = Config()

56
data_loader.py Normal file
View File

@ -0,0 +1,56 @@
import os
from PIL import Image
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
class ClassifyDataset(Dataset):
    """Image-classification dataset wrapping ``torchvision.datasets.ImageFolder``.

    Corrupt or unreadable images do not abort training: ``__getitem__`` falls
    back to an all-black 224x224 image labelled as class 0.
    """

    def __init__(self, data_dir, transforms=None):
        # Assume the dataset is structured with one subdirectory per class.
        self.data_dir = data_dir
        self.transform = transforms
        self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
        self.image_size = (3, 224, 224)  # expected (C, H, W) of returned samples

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        try:
            image, label = self.dataset[idx]
            return image, label
        except Exception:
            # Unreadable/corrupt file: substitute a black image so a
            # DataLoader worker does not crash mid-epoch.
            fallback = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
            # Bug fix: only apply the transform when one was given —
            # previously this called self.transform(...) even when it was None.
            if self.transform is not None:
                fallback = self.transform(fallback)
            # Fallback label is 0 (NOTE: this silently mislabels corrupt
            # samples as class 0; the old comment claimed -1).
            return fallback, 0
def create_data_loaders(data_dir, batch_size=64):
    """Build train/val/test DataLoaders from ``data_dir``.

    Expects ``data_dir`` to contain ``train``/``val``/``test`` subdirectories,
    each laid out ImageFolder-style (one subdirectory per class).

    Args:
        data_dir: Root directory holding the three dataset splits.
        batch_size: Batch size used by all three loaders.

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader)``.
    """
    # ImageNet-style resize + normalization. NOTE(review): train and eval
    # pipelines are identical — no augmentation is applied to the training
    # split, despite the previous comment claiming otherwise.
    eval_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    train_transforms = eval_transforms

    # Use os.path.join instead of string concatenation for the split paths.
    train_data = ClassifyDataset(os.path.join(data_dir, 'train'), transforms=train_transforms)
    valid_data = ClassifyDataset(os.path.join(data_dir, 'val'), transforms=eval_transforms)
    test_data = ClassifyDataset(os.path.join(data_dir, 'test'), transforms=eval_transforms)

    # Only the training loader shuffles.
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
    return train_loader, valid_loader, test_loader

35
logger.py Normal file
View File

@ -0,0 +1,35 @@
import logging
import os
from datetime import datetime
from config import config
class Logger:
    """Logging helper writing to a timestamped file under ``config.log_dir``
    and to the console simultaneously."""

    def __init__(self):
        os.makedirs(config.log_dir, exist_ok=True)
        log_file = os.path.join(
            config.log_dir,
            f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
        )
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(config.log_level)
        # Bug fix: logging.getLogger returns a process-wide singleton, so
        # constructing Logger more than once used to attach a second pair of
        # handlers and duplicate every log line. Only attach handlers once.
        if not self.logger.handlers:
            # File output: timestamped, leveled records.
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(
                logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
            )
            # Console output: bare message only.
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(logging.Formatter("%(message)s"))
            self.logger.addHandler(file_handler)
            self.logger.addHandler(console_handler)

    def log(self, level, message):
        """Dispatch ``message`` at ``level`` (e.g. 'info', 'warning')."""
        getattr(self.logger, level)(message)


# Shared instance used project-wide via ``from logger import logger``.
logger = Logger()

25
main.py Normal file
View File

@ -0,0 +1,25 @@
from torch.utils.data import Dataset, DataLoader
from trainner import Trainer
from config import config
from logger import logger
from torch import optim, nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from model.repvit import *
from data_loader import *
def main(data_dir='F:/dataset/02.TA_EC/datasets/EC27'):
    """Wire up model, optimizer and data, then run training.

    Args:
        data_dir: Dataset root containing train/val/test splits. Defaults to
            the previously hard-coded path for backward compatibility; pass a
            different path instead of editing this file.
    """
    # Model, optimizer, loss
    model = repvit_m1_1(num_classes=10).to(config.device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): test_loader is unused — test-set evaluation is not wired up.
    train_loader, valid_loader, test_loader = create_data_loaders(data_dir)
    # Build and run the trainer
    trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
    trainer.train()


if __name__ == "__main__":
    main()

View File

@ -19,7 +19,7 @@ def _make_divisible(v, divisor, min_value=None):
new_v += divisor
return new_v
from timm.models.layers import SqueezeExcite
from timm.layers import SqueezeExcite
import torch

22
predict.py Normal file
View File

@ -0,0 +1,22 @@
import torch
from model import create_model
from config import config
from utils import load_checkpoint
class Predictor:
    """Inference helper: loads the best saved model and predicts class indices."""

    def __init__(self):
        self.model = create_model()
        load_checkpoint(self.model)  # restore the best weights from disk
        self.model.eval()

    def predict(self, input_data):
        """Return predicted class indices for ``input_data`` as a numpy array."""
        batch = torch.tensor(input_data).float().to(config.device)
        with torch.no_grad():
            logits = self.model(batch)
        predictions = logits.argmax(dim=1)
        return predictions.cpu().numpy()
# Usage example
if __name__ == "__main__":
    predictor = Predictor()
    sample_data = [...]  # placeholder — replace with real input; Ellipsis here will crash predict()
    print("Prediction:", predictor.predict(sample_data))

81
trainner.py Normal file
View File

@ -0,0 +1,81 @@
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from config import config
from logger import logger
from utils import save_checkpoint, load_checkpoint
class Trainer:
    """Training loop with tqdm progress, periodic logging, validation and
    per-epoch checkpointing (plus a separate best-accuracy checkpoint)."""

    def __init__(self, model, train_loader, val_loader, optimizer, criterion):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.start_epoch = 0
        # Optionally resume from the last saved checkpoint.
        if config.resume:
            self.start_epoch = load_checkpoint(model, optimizer)

    def train_epoch(self, epoch):
        """Run one training epoch; returns the mean batch loss."""
        self.model.train()
        total_loss = 0.0
        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(config.device), target.to(config.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            progress_bar.set_postfix({
                "Loss": f"{loss.item():.4f}",
                "Avg Loss": f"{total_loss/(batch_idx+1):.4f}"
            })
            # Log every 100 batches.
            if batch_idx % 100 == 0:
                # Consistency fix: log the 1-based epoch number so it matches
                # the tqdm description (previously logged the 0-based index).
                logger.log("info",
                    f"Train Epoch: {epoch+1} [{batch_idx}/{len(self.train_loader)}] "
                    f"Loss: {loss.item():.4f}"
                )
        return total_loss / len(self.train_loader)

    def validate(self):
        """Evaluate on the validation set; returns (avg_loss, accuracy_pct)."""
        self.model.eval()
        total_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(config.device), target.to(config.device)
                output = self.model(data)
                total_loss += self.criterion(output, target).item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
        avg_loss = total_loss / len(self.val_loader)
        accuracy = 100. * correct / len(self.val_loader.dataset)
        logger.log("info",
            f"Validation - Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%"
        )
        return avg_loss, accuracy

    def train(self):
        """Full run: train + validate each epoch from ``start_epoch``."""
        best_acc = 0.0
        for epoch in range(self.start_epoch, config.epochs):
            train_loss = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()
            # Save a separate copy whenever validation accuracy improves.
            if val_acc > best_acc:
                best_acc = val_acc
                save_checkpoint(self.model, self.optimizer, epoch, is_best=True)
            # Always keep a rolling "last" checkpoint for resuming.
            save_checkpoint(self.model, self.optimizer, epoch)

28
utils.py Normal file
View File

@ -0,0 +1,28 @@
import os
from os import path

import torch

from config import config
import logger
# Bug fix: the module object bound by ``import logger`` has no ``log``
# attribute, so ``logger.log(...)`` below raised AttributeError. Rebind the
# name to the shared Logger instance, which does have ``.log``.
from logger import logger
def save_checkpoint(model, optimizer, epoch, is_best=False):
    """Persist model/optimizer state for ``epoch``.

    Always writes the rolling "last" checkpoint to ``config.checkpoint_path``
    (the path ``load_checkpoint`` reads on resume); additionally writes
    ``config.save_path`` when ``is_best`` is True.
    """
    checkpoint = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict()
    }
    if is_best:
        # Ensure the target directory exists before torch.save.
        os.makedirs(path.dirname(config.save_path) or ".", exist_ok=True)
        torch.save(checkpoint, config.save_path)
    # Bug fix: previously saved under path.join(config.output_path, ...)
    # ("runs/checkpoints/..."), but load_checkpoint reads
    # config.checkpoint_path ("checkpoints/...") — resume could never find
    # the file. Write where the loader actually looks.
    os.makedirs(path.dirname(config.checkpoint_path) or ".", exist_ok=True)
    torch.save(checkpoint, config.checkpoint_path)
    logger.log("info", f"Checkpoint saved at epoch {epoch}")
def load_checkpoint(model, optimizer=None):
    """Restore model (and optionally optimizer) state from the last checkpoint.

    Returns:
        The epoch stored in the checkpoint, or 0 when no checkpoint exists.
    """
    try:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host;
        # the bare torch.load used previously would fail in that case.
        checkpoint = torch.load(config.checkpoint_path, map_location=config.device)
    except FileNotFoundError:
        logger.log("warning", "No checkpoint found, starting from scratch")
        return 0
    # Keep the try body minimal: only torch.load can raise FileNotFoundError.
    model.load_state_dict(checkpoint["model_state"])
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    logger.log("info", f"Resuming training from epoch {checkpoint['epoch']}")
    return checkpoint["epoch"]