import logging
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

_LOGGER = logging.getLogger(__name__)

# Per-channel mean / std handed to ``transforms.Normalize`` (the values the
# torchvision documentation lists for its ImageNet-pretrained models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


class ClassifyDataset(Dataset):
    """Image-classification dataset that wraps ``datasets.ImageFolder``.

    ``data_dir`` is expected to contain one sub-directory per class. Samples
    that fail to load are replaced by a black placeholder image instead of
    aborting the whole epoch.

    Args:
        data_dir: Root directory handed to ``ImageFolder``.
        transforms: Optional callable applied to every PIL image (and to the
            placeholder image). ``None`` leaves images untransformed.
    """

    def __init__(self, data_dir, transforms=None):
        self.data_dir = data_dir
        # NOTE: the ``transforms`` parameter shadows the torchvision module of
        # the same name inside this method; the name is kept for callers that
        # pass it by keyword.
        self.transform = transforms
        # Assume the dataset is structured with subdirectories for each class.
        self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
        # (channels, height, width) of the placeholder sample.
        self.image_size = (3, 224, 224)

    def __len__(self):
        """Return the number of samples found by ``ImageFolder``."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range. This must propagate so
                that plain iteration over the dataset terminates.

        Any other failure while loading the sample is logged and answered
        with a black placeholder image labelled ``0``.
        """
        try:
            image, label = self.dataset[idx]
        except IndexError:
            # Out-of-range access is a caller error / end-of-iteration signal,
            # not a corrupt sample; never mask it with a placeholder.
            raise
        except Exception:
            _LOGGER.warning(
                "Failed to load sample %s from %s; substituting a black image",
                idx,
                self.data_dir,
                exc_info=True,
            )
            return self._placeholder_sample()
        return image, label

    def _placeholder_sample(self):
        """Build the black fallback sample used when a real one cannot load."""
        _, height, width = self.image_size
        black_image = Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8))
        if self.transform is not None:
            # Without this guard the fallback itself raised TypeError when the
            # dataset was created with the default ``transforms=None``.
            black_image = self.transform(black_image)
        # The placeholder is labelled 0 (not -1), so downstream code cannot
        # tell it apart from a genuine class-0 sample.
        return black_image, 0


def create_data_loaders(data_dir, batch_size=64):
    """Create train / validation / test ``DataLoader`` objects.

    Args:
        data_dir: Directory containing ``train``, ``val`` and ``test``
            sub-directories, each laid out one folder per class.
        batch_size: Number of samples per batch for all three loaders.

    Returns:
        A ``(train_loader, valid_loader, test_loader)`` tuple. Only the
        training loader shuffles its samples.

    Errors raised by ``ImageFolder`` while scanning a split directory
    propagate to the caller.
    """
    # Training pipeline: currently resize + tensor conversion + normalization
    # only; no augmentation is applied.
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])

    # Validation / test pipeline (same steps as training at the moment).
    valid_test_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])

    # os.path.join instead of string concatenation so that path-like objects
    # and platform-specific separators are handled.
    train_dir = os.path.join(data_dir, 'train')
    valid_dir = os.path.join(data_dir, 'val')
    test_dir = os.path.join(data_dir, 'test')

    train_data = ClassifyDataset(train_dir, transforms=train_transforms)
    valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms)
    test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms)

    # Create the DataLoaders with the requested batch size.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_data, batch_size=batch_size)
    test_loader = DataLoader(test_data, batch_size=batch_size)

    return train_loader, valid_loader, test_loader