diff --git a/.gitignore b/.gitignore
index a06a09e..b365258 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 data
 *.pth
 *.pyc
+logs
\ No newline at end of file
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..cdd3336
--- /dev/null
+++ b/config.py
@@ -0,0 +1,25 @@
+import torch
+
+class Config:
+    # Model parameters
+    input_dim = 784
+    hidden_dim = 256
+    output_dim = 10
+
+    # Training parameters
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    batch_size = 64
+    epochs = 100
+    learning_rate = 0.001
+    save_path = "checkpoints/best_model.pth"
+
+    # Logging parameters
+    log_dir = "logs"
+    log_level = "INFO"
+
+    # Resume from checkpoint
+    resume = False
+    checkpoint_path = "checkpoints/last_checkpoint.pth"
+    output_path = "runs/"
+
+config = Config()
\ No newline at end of file
diff --git a/data_loader.py b/data_loader.py
new file mode 100644
index 0000000..de98721
--- /dev/null
+++ b/data_loader.py
@@ -0,0 +1,56 @@
+import os
+from PIL import Image
+import numpy as np
+import torch
+from torchvision import datasets, transforms
+from torch.utils.data import Dataset, DataLoader
+
+class ClassifyDataset(Dataset):
+    def __init__(self, data_dir, transforms=None):
+        self.data_dir = data_dir
+        # Assume the dataset is structured with subdirectories for each class
+        self.transform = transforms
+        self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
+        self.image_size = (3, 224, 224)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        try:
+            image, label = self.dataset[idx]
+            return image, label
+        except Exception:
+            black_image = np.zeros((224, 224, 3), dtype=np.uint8)
+            return self.transform(Image.fromarray(black_image)), 0  # fall back to a black image with label 0
+
+def create_data_loaders(data_dir, batch_size=64):
+    # Define transformations for training data augmentation and normalization
+    train_transforms = transforms.Compose([
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+    # Define transformations for validation and test data (only normalization)
+    valid_test_transforms = transforms.Compose([
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+    # Load the datasets with ImageFolder
+    train_dir = data_dir + '/train'
+    valid_dir = data_dir + '/val'
+    test_dir = data_dir + '/test'
+
+    train_data = ClassifyDataset(train_dir, transforms=train_transforms)
+    valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms)
+    test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms)
+
+    # Create the DataLoaders with the configured batch size
+    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
+    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
+    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
+
+    return train_loader, valid_loader, test_loader
\ No newline at end of file
diff --git a/logger.py b/logger.py
new file mode 100644
index 0000000..21257f4
--- /dev/null
+++ b/logger.py
@@ -0,0 +1,37 @@
+import logging
+import os
+from datetime import datetime
+from config import config
+
+class Logger:
+    def __init__(self):
+        os.makedirs(config.log_dir, exist_ok=True)
+        log_file = os.path.join(
+            config.log_dir,
+            f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
+        )
+
+        self.logger = logging.getLogger(__name__)
+        self.logger.setLevel(config.log_level)
+
+        # File output
+        file_handler = logging.FileHandler(log_file)
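+        # Two handlers are attached: the file handler keeps the full timestamped
+        # record, while the console handler below prints bare messages.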
+        file_formatter = logging.Formatter(
+            "%(asctime)s - %(levelname)s - %(message)s"
+        )
+        file_handler.setFormatter(file_formatter)
+
+        # Console output
+        console_handler = logging.StreamHandler()
+        console_formatter = logging.Formatter("%(message)s")
+        console_handler.setFormatter(console_formatter)
+
+        self.logger.addHandler(file_handler)
+        self.logger.addHandler(console_handler)
+
+    def log(self, level, message):
+        getattr(self.logger, level)(message)
+
+logger = Logger()
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..f513138
--- /dev/null
+++ b/main.py
@@ -0,0 +1,22 @@
+from trainner import Trainer
+from config import config
+from logger import logger
+from torch import optim, nn
+
+from model.repvit import *
+from data_loader import *
+
+def main():
+    # Initialize the components
+    model = repvit_m1_1(num_classes=10).to(config.device)
+    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
+    criterion = nn.CrossEntropyLoss()
+
+    train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27')
+
+    # Initialize the trainer
+    trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
+    trainer.train()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/model/repvit.py b/model/repvit.py
index 99f4bd5..78197f8 100644
--- a/model/repvit.py
+++ b/model/repvit.py
@@ -19,7 +19,7 @@ def _make_divisible(v, divisor, min_value=None):
         new_v += divisor
     return new_v
 
-from timm.models.layers import SqueezeExcite
+from timm.layers import SqueezeExcite
 
 import torch
 
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..4a61c7c
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,23 @@
+import torch
+from model.repvit import repvit_m1_1
+from config import config
+from utils import load_checkpoint
+
+class Predictor:
+    def __init__(self):
+        self.model = repvit_m1_1(num_classes=10).to(config.device)
+        load_checkpoint(self.model)  # load weights from the last saved checkpoint
+        self.model.eval()
+
+    def predict(self, input_data):
+        with torch.no_grad():
+            input_tensor = torch.tensor(input_data).float().to(config.device)
+            output = self.model(input_tensor)
+            return output.argmax(dim=1).cpu().numpy()
+
+# Usage example
+if __name__ == "__main__":
+    predictor = Predictor()
+    sample_data = [...]  # input data
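+    # Assumption: input_data should be a batch of normalized 224x224 RGB images, shaped (N, 3, 224, 224)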
+    print("Prediction:", predictor.predict(sample_data))
\ No newline at end of file
diff --git a/trainner.py b/trainner.py
new file mode 100644
index 0000000..672f516
--- /dev/null
+++ b/trainner.py
@@ -0,0 +1,81 @@
+import torch
+from tqdm import tqdm
+from config import config
+from logger import logger
+from utils import save_checkpoint, load_checkpoint
+
+class Trainer:
+    def __init__(self, model, train_loader, val_loader, optimizer, criterion):
+        self.model = model
+        self.train_loader = train_loader
+        self.val_loader = val_loader
+        self.optimizer = optimizer
+        self.criterion = criterion
+        self.start_epoch = 0
+
+        if config.resume:
+            self.start_epoch = load_checkpoint(model, optimizer)
+
+    def train_epoch(self, epoch):
+        self.model.train()
+        total_loss = 0.0
+        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
+
+        for batch_idx, (data, target) in enumerate(progress_bar):
+            data, target = data.to(config.device), target.to(config.device)
+
+            self.optimizer.zero_grad()
+            output = self.model(data)
+            loss = self.criterion(output, target)
+            loss.backward()
+            self.optimizer.step()
+
+            total_loss += loss.item()
+            progress_bar.set_postfix({
+                "Loss": f"{loss.item():.4f}",
+                "Avg Loss": f"{total_loss/(batch_idx+1):.4f}"
+            })
+
+            # Log every 100 batches
+            if batch_idx % 100 == 0:
+                logger.log("info",
+                    f"Train Epoch: {epoch} [{batch_idx}/{len(self.train_loader)}] "
+                    f"Loss: {loss.item():.4f}"
+                )
+
+        return total_loss / len(self.train_loader)
+
+    def validate(self):
+        self.model.eval()
+        total_loss = 0.0
+        correct = 0
+
+        with torch.no_grad():
+            for data, target in self.val_loader:
+                data, target = data.to(config.device), target.to(config.device)
+                output = self.model(data)
+                total_loss += self.criterion(output, target).item()
+                pred = output.argmax(dim=1)
+                correct += pred.eq(target).sum().item()
+
+        avg_loss = total_loss / len(self.val_loader)
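+        # Note: loss is averaged per batch, while accuracy is computed over every sample in the dataset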
+        accuracy = 100. * correct / len(self.val_loader.dataset)
+        logger.log("info",
+            f"Validation - Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%"
+        )
+        return avg_loss, accuracy
+
+    def train(self):
+        best_acc = 0.0
+        for epoch in range(self.start_epoch, config.epochs):
+            train_loss = self.train_epoch(epoch)
+            val_loss, val_acc = self.validate()
+
+            # Save the best model
+            if val_acc > best_acc:
+                best_acc = val_acc
+                save_checkpoint(self.model, self.optimizer, epoch, is_best=True)
+
+            # Save a rolling checkpoint for resuming
+            save_checkpoint(self.model, self.optimizer, epoch)
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..e0167cb
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,31 @@
+import os
+import torch
+from config import config
+from logger import logger
+
+def save_checkpoint(model, optimizer, epoch, is_best=False):
+    checkpoint = {
+        "epoch": epoch,
+        "model_state": model.state_dict(),
+        "optimizer_state": optimizer.state_dict()
+    }
+
+    # Make sure the checkpoint directory exists before saving
+    os.makedirs(os.path.dirname(config.checkpoint_path), exist_ok=True)
+    if is_best:
+        torch.save(checkpoint, config.save_path)
+    # Always refresh the rolling checkpoint so load_checkpoint finds it at config.checkpoint_path
+    torch.save(checkpoint, config.checkpoint_path)
+    logger.log("info", f"Checkpoint saved at epoch {epoch}")
+
+def load_checkpoint(model, optimizer=None):
+    try:
+        checkpoint = torch.load(config.checkpoint_path)
+        model.load_state_dict(checkpoint["model_state"])
+        if optimizer:
+            optimizer.load_state_dict(checkpoint["optimizer_state"])
+        logger.log("info", f"Resuming training from epoch {checkpoint['epoch'] + 1}")
+        return checkpoint["epoch"] + 1  # the saved epoch already finished; resume from the next one
+    except FileNotFoundError:
+        logger.log("warning", "No checkpoint found, starting from scratch")
+        return 0
\ No newline at end of file