TA_EC/data_loader.py

56 lines
2.2 KiB
Python
Raw Normal View History

2025-03-09 16:31:37 +00:00
import logging
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
class ClassifyDataset(Dataset):
    """Image-classification dataset wrapping ``torchvision.datasets.ImageFolder``.

    ``data_dir`` is expected to contain one sub-directory per class (the
    ImageFolder layout). A sample that fails to load is replaced by an
    all-black placeholder image so that one unreadable file does not abort
    an epoch; the failure is logged as a warning.

    Args:
        data_dir: Root directory in ImageFolder layout.
        transforms: Optional callable applied to each PIL image. The name is
            kept for keyword-argument compatibility with existing callers.
        fallback_label: Label returned with the placeholder image for an
            unreadable sample. Defaults to ``0`` (the historical behavior);
            note this assigns such samples to class 0.
    """

    def __init__(self, data_dir, transforms=None, fallback_label=0):
        self.data_dir = data_dir
        self.transform = transforms
        self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
        # (channels, height, width) of the placeholder used for unreadable samples.
        self.image_size = (3, 224, 224)
        self.fallback_label = fallback_label

    def __len__(self):
        return len(self.dataset)

    def _placeholder_image(self):
        """Return an all-black image of ``self.image_size``, transformed if a transform is set."""
        channels, height, width = self.image_size
        black = Image.fromarray(np.zeros((height, width, channels), dtype=np.uint8))
        # Without a transform, ImageFolder yields raw PIL images, so mirror that.
        return black if self.transform is None else self.transform(black)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``.

        Raises:
            IndexError: if ``idx`` is out of range. This must propagate: the
                class defines no ``__iter__``, so plain iteration relies on
                ``IndexError`` to know where the sequence ends.
        """
        try:
            return self.dataset[idx]
        except IndexError:
            raise
        except Exception:
            # Best-effort: substitute a placeholder rather than crash the epoch.
            logging.getLogger(__name__).warning(
                "Failed to load sample %s from %s; using a black placeholder image",
                idx,
                self.data_dir,
                exc_info=True,
            )
            return self._placeholder_image(), self.fallback_label
def create_data_loaders(data_dir, batch_size=64, num_workers=0):
    """Build train/validation/test DataLoaders from an ImageFolder-style tree.

    ``data_dir`` must contain ``train``, ``val`` and ``test`` sub-directories,
    each holding one sub-directory per class.

    Args:
        data_dir: Root directory (``str`` or ``os.PathLike``).
        batch_size: Samples per batch for all three loaders.
        num_workers: Worker processes passed to each ``DataLoader``. ``0`` (the
            default) loads in the main process, as before.

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader)``. Only the training
        loader shuffles.
    """
    # ImageNet channel statistics used for normalization.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Training preprocessing: resize, tensorize, normalize. No augmentation
    # is applied; it currently matches the evaluation pipeline exactly.
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])
    # Validation/test preprocessing: deterministic resize, tensorize, normalize.
    valid_test_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])

    # os.path.join accepts both str and PathLike roots (``+`` fails for Path).
    train_dir = os.path.join(data_dir, 'train')
    valid_dir = os.path.join(data_dir, 'val')
    test_dir = os.path.join(data_dir, 'test')

    train_data = ClassifyDataset(train_dir, transforms=train_transforms)
    valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms)
    test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    valid_loader = DataLoader(valid_data, batch_size=batch_size,
                              num_workers=num_workers)
    test_loader = DataLoader(test_data, batch_size=batch_size,
                             num_workers=num_workers)
    return train_loader, valid_loader, test_loader