56 lines
2.2 KiB
Python
56 lines
2.2 KiB
Python
import os
|
|
from PIL import Image
|
|
import numpy as np
|
|
import torch
|
|
from torchvision import datasets, transforms
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
class ClassifyDataset(Dataset):
    """Image-classification dataset backed by ``torchvision.datasets.ImageFolder``.

    Samples that fail to load (e.g. unreadable files) are replaced by an
    all-black placeholder image paired with ``fallback_label``, so one bad
    file does not abort an epoch. Out-of-range indices still raise
    ``IndexError``.

    Args:
        data_dir: Root directory laid out as ``data_dir/<class_name>/<image>``,
            as required by ``ImageFolder``.
        transforms: Optional callable applied to every PIL image, including the
            placeholder. (This name shadows the ``torchvision.transforms``
            module inside ``__init__`` only; it is kept for caller
            compatibility.)
        fallback_label: Label returned with the placeholder image. Defaults to
            ``0`` to preserve earlier behavior. Note that this silently files
            broken samples under class 0; a value such as ``-100`` (the default
            ``ignore_index`` of ``torch.nn.CrossEntropyLoss``) avoids that.
    """

    def __init__(self, data_dir, transforms=None, fallback_label=0):
        self.data_dir = data_dir
        self.transform = transforms
        self.fallback_label = fallback_label
        # ImageFolder expects one subdirectory per class under data_dir.
        self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
        # (channels, height, width) of the placeholder used for failed loads.
        self.image_size = (3, 224, 224)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, or a placeholder if loading fails.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        try:
            image, label = self.dataset[idx]
        except IndexError:
            # Must propagate: Dataset defines no __iter__, so ``for x in dataset``
            # uses the sequence protocol, which only stops on IndexError.
            # Swallowing it would make iteration never terminate.
            raise
        except Exception as exc:
            import logging
            logging.getLogger(__name__).warning(
                "Failed to load sample %r from %s (%s); substituting a black image.",
                idx, self.data_dir, exc,
            )
            return self._placeholder(), self.fallback_label
        return image, label

    def _placeholder(self):
        """Build an all-black image of ``image_size``, transformed like a real sample."""
        channels, height, width = self.image_size
        black = Image.fromarray(np.zeros((height, width, channels), dtype=np.uint8))
        # With no transform, ImageFolder yields raw PIL images; mirror that here
        # instead of calling None.
        return self.transform(black) if self.transform is not None else black
|
|
|
|
def create_data_loaders(data_dir, batch_size=64, num_workers=0):
    """Build train/validation/test DataLoaders from an ImageFolder-style tree.

    ``data_dir`` must contain ``train``, ``val`` and ``test`` subdirectories,
    each laid out as ``<split>/<class_name>/<image>``.

    Args:
        data_dir: Dataset root, as ``str`` or ``os.PathLike``.
        batch_size: Samples per batch, shared by all three loaders.
        num_workers: Loader worker processes per DataLoader. ``0`` (the
            default, matching previous behavior) loads in the main process.

    Returns:
        ``(train_loader, valid_loader, test_loader)``. Only the training loader
        shuffles.
    """
    # Per-channel mean/std commonly used with ImageNet-pretrained models.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Training pipeline. It currently matches evaluation (resize + normalize,
    # no augmentation); augmentation would be added here.
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Validation/test pipeline: deterministic resize + normalize only.
    valid_test_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # os.path.join also accepts pathlib.Path roots, which string '+' would not.
    train_dir = os.path.join(data_dir, 'train')
    valid_dir = os.path.join(data_dir, 'val')
    test_dir = os.path.join(data_dir, 'test')

    train_data = ClassifyDataset(train_dir, transforms=train_transforms)
    valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms)
    test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms)

    train_loader = DataLoader(
        train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    valid_loader = DataLoader(valid_data, batch_size=batch_size, num_workers=num_workers)
    test_loader = DataLoader(test_data, batch_size=batch_size, num_workers=num_workers)

    return train_loader, valid_loader, test_loader