115 lines
3.8 KiB
Python
115 lines
3.8 KiB
Python
import os
|
|
from logger import logger
|
|
from PIL import Image
|
|
import torch
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torchvision import transforms
|
|
|
|
class ImageClassificationDataset(Dataset):
    """Dataset over an image-folder tree: one subdirectory of ``root_dir`` per class.

    Directory layout expected::

        root_dir/
            class_a/ img1.jpg, img2.jpg, ...
            class_b/ ...

    Parameters
    ----------
    root_dir : str
        Directory whose immediate subdirectories are the class names.
    transform : callable, optional
        Transform applied to each PIL image (eagerly when caching in RAM,
        lazily in ``__getitem__`` otherwise).
    Cache : bool | str
        When exactly ``'RAM'``, every image is decoded (and transformed)
        once at construction time and held in memory; any other value
        stores only file paths and decodes on access.
    """

    def __init__(self, root_dir, transform=None, Cache=False):
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so the class -> index mapping is deterministic across runs;
        # filter to directories so a stray file in root_dir is not treated
        # as a class (os.listdir on it would raise).
        self.classes = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.image_paths = []   # used when not caching in RAM
        self.image = []         # decoded (and transformed) images, RAM mode only
        self.labels = []        # parallel to image_paths / image
        self.Cache = Cache

        logger.log("info",
                   "init the dataloader"
                   )

        for cls_name in self.classes:
            cls_dir = os.path.join(root_dir, cls_name)
            for img_name in os.listdir(cls_dir):
                # Bind img_path before the try so the error log below can
                # always reference it.
                img_path = os.path.join(cls_dir, img_name)
                try:
                    imgs = Image.open(img_path).convert('RGB')
                    if Cache == 'RAM':
                        if self.transform:
                            imgs = self.transform(imgs)
                        self.image.append(imgs)
                    else:
                        self.image_paths.append(img_path)
                    self.labels.append(self.class_to_idx[cls_name])
                # Skip unreadable/corrupt files but never swallow
                # KeyboardInterrupt/SystemExit (the old bare `except:` did).
                except Exception:
                    logger.log("info",
                               "read image error " +
                               img_path
                               )

    def __len__(self) -> int:
        # labels is appended in lockstep with image/image_paths in both modes.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)``; image is pre-transformed in RAM mode."""
        label = self.labels[idx]
        if self.Cache == 'RAM':
            image = self.image[idx]
        else:
            img_path = self.image_paths[idx]
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)

        return image, label
|
|
|
|
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True, Cache=False):
    """Build train/validation/test ``DataLoader``s from an image-folder tree.

    Expects ``root_dir`` to contain ``train``, ``val`` and ``test``
    subdirectories, each laid out as one subdirectory per class
    (see :class:`ImageClassificationDataset`).

    Parameters
    ----------
    root_dir : str
        Root containing the ``train``/``val``/``test`` splits.
    batch_size : int
        Batch size for all three loaders.
    num_workers : int
        Worker processes per loader.
    pin_memory : bool
        Passed through to ``DataLoader`` (speeds host->GPU copies).
    Cache : bool | str
        Forwarded to ``ImageClassificationDataset`` (``'RAM'`` caches
        decoded images in memory).

    Returns
    -------
    tuple[DataLoader, DataLoader, DataLoader]
        ``(train_loader, valid_loader, test_loader)``; only the training
        loader shuffles.
    """
    # ImageNet mean/std, shared by every split.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training transform; kept separate from eval so augmentation can be
    # added here later without touching val/test.
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize images to 224x224
        transforms.ToTensor(),          # Convert PIL Image to Tensor
        normalize,
    ])

    # Validation/test transform: deterministic resize + normalize only.
    valid_test_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # Split directories (os.path.join instead of string concatenation).
    train_dir = os.path.join(root_dir, 'train')
    valid_dir = os.path.join(root_dir, 'val')
    test_dir = os.path.join(root_dir, 'test')

    train_data = ImageClassificationDataset(train_dir, transform=train_transforms, Cache=Cache)
    valid_data = ImageClassificationDataset(valid_dir, transform=valid_test_transforms, Cache=Cache)
    test_data = ImageClassificationDataset(test_dir, transform=valid_test_transforms, Cache=Cache)

    # Only the training loader shuffles; eval order stays stable.
    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    valid_loader = DataLoader(
        valid_data,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    test_loader = DataLoader(
        test_data,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    return train_loader, valid_loader, test_loader
|