TA_EC/data_loader.py
2025-03-11 23:12:23 +08:00

115 lines
3.8 KiB
Python

import os
from logger import logger
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class ImageClassificationDataset(Dataset):
    """Image-classification dataset read from a ``root_dir/<class>/<image>`` tree.

    Every sub-directory of ``root_dir`` is treated as one class; class indices
    follow the sorted order of the sub-directory names. Plain files that sit
    directly in ``root_dir`` are ignored.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each RGB PIL image.
        Cache: When equal to ``'RAM'``, every image is decoded (and
            transformed, if ``transform`` is given) once in ``__init__`` and
            kept in memory. Any other value stores only the file paths and
            decodes on each ``__getitem__`` call.
    """

    def __init__(self, root_dir, transform=None, Cache=False):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = self._list_class_dirs(root_dir)
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.image_paths = []  # populated when Cache != 'RAM'
        self.image = []        # populated when Cache == 'RAM'
        self.labels = []       # one class index per retained sample
        self.Cache = Cache
        logger.log("info", "init the dataloader")

        for cls_name in self.classes:
            cls_dir = os.path.join(root_dir, cls_name)
            label = self.class_to_idx[cls_name]
            for img_name in os.listdir(cls_dir):
                # Bound before the try so the error message can always use it.
                img_path = os.path.join(cls_dir, img_name)
                try:
                    # Decoding happens in both modes so that unreadable files
                    # are dropped here rather than failing later in
                    # __getitem__.
                    sample = self._load_rgb(img_path)
                    if Cache == 'RAM' and self.transform:
                        sample = self.transform(sample)
                except Exception:
                    # Deliberately best-effort: skip the file and keep going.
                    # ``Exception`` (not a bare except) lets KeyboardInterrupt
                    # and SystemExit abort a long scan.
                    logger.log("info", "read image error " + img_path)
                    continue

                if Cache == 'RAM':
                    self.image.append(sample)
                else:
                    self.image_paths.append(img_path)
                self.labels.append(label)

    @staticmethod
    def _list_class_dirs(root_dir):
        """Return the sorted names of the sub-directories of ``root_dir``."""
        return sorted(
            entry
            for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )

    @staticmethod
    def _load_rgb(img_path):
        """Open ``img_path``, return an RGB copy and close the source file."""
        with Image.open(img_path) as raw:
            return raw.convert('RGB')

    def __len__(self):
        """Return the number of samples that were read successfully."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        label = self.labels[idx]
        if self.Cache == 'RAM':
            # Cached entries were already transformed in __init__.
            return self.image[idx], label

        image = self._load_rgb(self.image_paths[idx])
        if self.transform:
            image = self.transform(image)
        return image, label
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True, Cache=False):
    """Build the train, validation and test ``DataLoader`` objects.

    ``root_dir`` must contain ``train``, ``val`` and ``test`` sub-directories,
    each laid out as expected by ``ImageClassificationDataset``.

    Args:
        root_dir: Directory holding the ``train``/``val``/``test`` splits.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker count passed to every ``DataLoader``.
        pin_memory: ``pin_memory`` flag passed to every ``DataLoader``.
        Cache: Forwarded unchanged to ``ImageClassificationDataset``.

    Returns:
        A ``(train_loader, valid_loader, test_loader)`` tuple. Only the
        training loader shuffles its samples.
    """

    def _make_transform():
        # Resize to 224x224, convert to a tensor, then normalize each channel
        # with fixed mean/std constants (the values commonly quoted for
        # ImageNet).
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    # The pipelines are currently identical (no augmentation is applied to
    # the training split); they stay separate objects so the training one can
    # diverge without touching evaluation.
    train_transforms = _make_transform()
    valid_test_transforms = _make_transform()

    train_dir = os.path.join(root_dir, 'train')
    valid_dir = os.path.join(root_dir, 'val')
    test_dir = os.path.join(root_dir, 'test')

    train_data = ImageClassificationDataset(train_dir, transform=train_transforms, Cache=Cache)
    valid_data = ImageClassificationDataset(valid_dir, transform=valid_test_transforms, Cache=Cache)
    test_data = ImageClassificationDataset(test_dir, transform=valid_test_transforms, Cache=Cache)

    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    # Evaluation loaders keep the dataset order (no shuffling).
    valid_loader = DataLoader(
        valid_data,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        test_data,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return train_loader, valid_loader, test_loader