From f43e21c09d050396d3ad58ddc67ed89aea351c46 Mon Sep 17 00:00:00 2001 From: yoiannis <13330431063> Date: Tue, 11 Mar 2025 23:12:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E7=BC=93=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 4 ++ config.py | 4 +- data_loader.py | 118 +++++++++++++++++++++++++++++++++------------ dataset/recover.py | 26 +++++----- dataset/test.py | 12 +++++ main.py | 5 +- trainner.py | 3 +- 7 files changed, 125 insertions(+), 47 deletions(-) create mode 100644 README.md create mode 100644 dataset/test.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..08d1c83 --- /dev/null +++ b/README.md @@ -0,0 +1,4 @@ +# 1.相关知识 +``` + https://github.com/Hao840/OFAKD +``` \ No newline at end of file diff --git a/config.py b/config.py index 0e7d471..190fba4 100644 --- a/config.py +++ b/config.py @@ -8,7 +8,7 @@ class Config: # 训练参数 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - batch_size = 32 + batch_size = 128 epochs = 150 learning_rate = 0.001 save_path = "checkpoints/best_model.pth" @@ -22,4 +22,6 @@ class Config: checkpoint_path = "checkpoints/last_checkpoint.pth" output_path = "runs/" + cache = 'RAM' + config = Config() \ No newline at end of file diff --git a/data_loader.py b/data_loader.py index de98721..a06ac22 100644 --- a/data_loader.py +++ b/data_loader.py @@ -1,31 +1,68 @@ import os +from logger import logger from PIL import Image -import numpy as np import torch -from torchvision import datasets, transforms from torch.utils.data import Dataset, DataLoader +from torchvision import transforms -class ClassifyDataset(Dataset): - def __init__(self, data_dir,transforms = None): - self.data_dir = data_dir - # Assume the dataset is structured with subdirectories for each class - self.transform = transforms - self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform) - self.image_size = (3, 224, 224) +class ImageClassificationDataset(Dataset): + def __init__(self, root_dir, transform=None,Cache=False): + self.root_dir = root_dir + self.transform = transform + self.classes = sorted(os.listdir(root_dir)) + self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)} + self.image_paths = [] + self.image = [] + self.labels = [] + self.Cache = Cache + + logger.log("info", + "init the dataloader" + ) + + for cls_name in self.classes: + cls_dir = os.path.join(root_dir, cls_name) + for img_name in os.listdir(cls_dir): + try: + img_path = os.path.join(cls_dir, img_name) + imgs = Image.open(img_path).convert('RGB') + if Cache == 'RAM': + if self.transform: + imgs = self.transform(imgs) + self.image.append(imgs) + else: + self.image_paths.append(img_path) + self.labels.append(self.class_to_idx[cls_name]) + except: + logger.log("info", + "read image error " + + img_path + ) def __len__(self): - return len(self.dataset) + return len(self.labels) def __getitem__(self, idx): - try: - image, label = self.dataset[idx] - return image, label - except Exception as e: - black_image = np.zeros((224, 224, 3), dtype=np.uint8) - return self.transform(Image.fromarray(black_image)), 0 # -1 作为默认标签 - -def create_data_loaders(data_dir,batch_size=64): - # Define transformations for training data augmentation and normalization + label = self.labels[idx] + if self.Cache == 'RAM': + image = self.image[idx] + else: + img_path = self.image_paths[idx] + image = Image.open(img_path).convert('RGB') + if self.transform: + image = self.transform(image) + + return image, label + +def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False): + # Define the transform for the training data and for the validation data + transform = transforms.Compose([ + transforms.Resize((224, 224)), # Resize images to 224x224 + transforms.ToTensor(), # Convert PIL Image to Tensor + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize the images + ]) + + # Define transformations for training data augmentation and normalization train_transforms = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), @@ -40,17 +77,38 @@ def create_data_loaders(data_dir,batch_size=64): ]) # Load the datasets with ImageFolder - train_dir = data_dir + '/train' - valid_dir = data_dir + '/val' - test_dir = data_dir + '/test' + train_dir = root_dir + '/train' + valid_dir = root_dir + '/val' + test_dir = root_dir + '/test' - train_data = ClassifyDataset(train_dir, transforms=train_transforms) - valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms) - test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms) + train_data = ImageClassificationDataset(train_dir, transform=train_transforms,Cache=Cache) + valid_data = ImageClassificationDataset(valid_dir, transform=valid_test_transforms,Cache=Cache) + test_data = ImageClassificationDataset(test_dir, transform=valid_test_transforms,Cache=Cache) - # Create the DataLoaders with batch size 64 - train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size) - test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size) - return train_loader, valid_loader, test_loader \ No newline at end of file + # Create the data loader + train_loader = DataLoader( + train_data, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=pin_memory + ) + + # Create the data loader + valid_loader = DataLoader( + valid_data, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory + ) + + # Create the data loader + test_loader = DataLoader( + test_data, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory + ) + + return train_loader, valid_loader, test_loader diff --git a/dataset/recover.py b/dataset/recover.py index 58eb430..19aab9a 100644 --- a/dataset/recover.py +++ b/dataset/recover.py @@ -106,15 +106,15 @@ def process_images(input_folder, background_image_path, output_base): 递归处理所有子文件夹并保持目录结构 """ # 预处理背景路径(只需执行一次) - if os.path.isfile(background_image_path): - background_paths = [background_image_path] - else: - valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp'] - background_paths = [ - os.path.join(background_image_path, f) - for f in os.listdir(background_image_path) - if os.path.splitext(f)[1].lower() in valid_ext - ] + # if os.path.isfile(background_image_path): + # background_paths = [background_image_path] + # else: + # valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp'] + # background_paths = [ + # os.path.join(background_image_path, f) + # for f in os.listdir(background_image_path) + # if os.path.splitext(f)[1].lower() in valid_ext + # ] # 递归遍历输入目录 for root, dirs, files in os.walk(input_folder): @@ -136,10 +136,10 @@ def process_images(input_folder, background_image_path, output_base): try: # 去背景处理 - foreground = remove_background(input_path) + result = remove_background(input_path) - result = edge_fill2(foreground) + # result = edge_fill2(result) # 保存结果 cv2.imwrite(output_path, result) @@ -150,8 +150,8 @@ def process_images(input_folder, background_image_path, output_base): # 使用示例 -input_directory = 'L:/Tobacco/2023_JY/20230821/SOURCE' +input_directory = 'L:/Grade_datasets/JY_A' background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B' -output_directory = 'L:/Test' +output_directory = 'L:/Grade_datasets/MOVE_BACKGROUND' process_images(input_directory, background_image_path, output_directory) \ No newline at end of file diff --git a/dataset/test.py b/dataset/test.py new file mode 100644 index 0000000..687fd25 --- /dev/null +++ b/dataset/test.py @@ -0,0 +1,12 @@ +import os + +def debug_walk_with_links(input_folder): + for root, dirs, files in os.walk(input_folder): + print(f'Root: {root}') + print(f'Dirs: {dirs}') + print(f'Files: {files}') + print('-' * 40) + +if __name__ == "__main__": + input_folder = 'L:/Grade_datasets' + debug_walk_with_links(input_folder) \ No newline at end of file diff --git a/main.py b/main.py index 2f55803..e1c9afd 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ from torchvision.datasets import MNIST from torchvision.transforms import ToTensor from model.repvit import * +from model.mobilenetv3 import * from data_loader import * from utils import * @@ -14,11 +15,11 @@ def main(): # 初始化组件 initialize() - model = repvit_m1_1(num_classes=10).to(config.device) + model = repvit_m1_0(num_classes=9).to(config.device) optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) criterion = nn.CrossEntropyLoss() - train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size) + train_loader, valid_loader, test_loader = get_data_loader('/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size,Cache='RAM') # 初始化训练器 trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion) diff --git a/trainner.py b/trainner.py index e3d44cc..2a3bcc8 100644 --- a/trainner.py +++ b/trainner.py @@ -4,6 +4,7 @@ from torch.utils.data import DataLoader from config import config from logger import logger from utils import save_checkpoint, load_checkpoint +import time class Trainer: def __init__(self, model, train_loader, val_loader, optimizer, criterion): @@ -21,7 +22,7 @@ class Trainer: self.model.train() total_loss = 0.0 progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}") - + time_start = time.time() for batch_idx, (data, target) in enumerate(progress_bar): data, target = data.to(config.device), target.to(config.device)