import os

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from logger import logger


class ImageClassificationDataset(Dataset):
    """Image dataset over a class-per-subdirectory tree.

    Expected layout: ``root_dir/<class_name>/<image files>``. Each class
    label is the alphabetical index of its directory name.

    Args:
        root_dir: Root directory containing one subdirectory per class.
        transform: Optional callable applied to each PIL image.
        Cache: When equal to the string ``'RAM'``, every image is decoded
            (and transformed) eagerly in ``__init__`` and kept in memory;
            any other value stores only paths and decodes lazily in
            ``__getitem__``. (Name kept non-PEP8 for backward compatibility.)
    """

    def __init__(self, root_dir, transform=None, Cache=False):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.image_paths = []
        self.image = []  # populated only when Cache == 'RAM'
        self.labels = []
        self.Cache = Cache
        logger.log("info", "init the dataloader")

        for cls_name in self.classes:
            cls_dir = os.path.join(root_dir, cls_name)
            for img_name in os.listdir(cls_dir):
                img_path = os.path.join(cls_dir, img_name)
                try:
                    # Open eagerly even when not caching so unreadable files
                    # are filtered out up front. The context manager closes
                    # the file handle (the original leaked one per image),
                    # and .convert forces a full decode first.
                    with Image.open(img_path) as raw:
                        img = raw.convert('RGB')
                except (OSError, ValueError):
                    # Was a bare `except:`; narrowed to the decode errors
                    # PIL actually raises so real bugs are no longer hidden.
                    logger.log("info", "read image error " + img_path)
                    continue
                if Cache == 'RAM':
                    if self.transform:
                        img = self.transform(img)
                    self.image.append(img)
                else:
                    self.image_paths.append(img_path)
                self.labels.append(self.class_to_idx[cls_name])

    def __len__(self):
        """Return the number of successfully indexed samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Cached mode returns the pre-transformed in-memory image; otherwise
        the file is decoded here and the transform (if any) applied.
        """
        label = self.labels[idx]
        if self.Cache == 'RAM':
            image = self.image[idx]
        else:
            img_path = self.image_paths[idx]
            with Image.open(img_path) as raw:
                image = raw.convert('RGB')
            if self.transform:
                image = self.transform(image)
        return image, label


def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True, Cache=False):
    """Build train/val/test DataLoaders from ``root_dir/{train,val,test}``.

    Args:
        root_dir: Directory containing ``train``, ``val`` and ``test``
            subdirectories, each laid out one-subdirectory-per-class.
        batch_size: Samples per batch for all three loaders.
        num_workers: Worker processes per DataLoader.
        pin_memory: Forwarded to DataLoader (speeds up host-to-GPU copies).
        Cache: Forwarded to ``ImageClassificationDataset``.

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader)``. Only the
        training loader shuffles.
    """
    # ImageNet mean/std normalization; identical for all splits since no
    # augmentation is applied. (A previously defined, unused duplicate
    # Compose was removed.)
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    valid_test_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # os.path.join instead of '+' concatenation for portable path handling.
    train_dir = os.path.join(root_dir, 'train')
    valid_dir = os.path.join(root_dir, 'val')
    test_dir = os.path.join(root_dir, 'test')

    train_data = ImageClassificationDataset(train_dir, transform=train_transforms, Cache=Cache)
    valid_data = ImageClassificationDataset(valid_dir, transform=valid_test_transforms, Cache=Cache)
    test_data = ImageClassificationDataset(test_dir, transform=valid_test_transforms, Cache=Cache)

    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,  # only the training split is shuffled
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    valid_loader = DataLoader(
        valid_data,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        test_data,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return train_loader, valid_loader, test_loader