更新缓存
This commit is contained in:
parent
30eeff4b1d
commit
f43e21c09d
@ -8,7 +8,7 @@ class Config:
|
||||
|
||||
# 训练参数
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
batch_size = 32
|
||||
batch_size = 128
|
||||
epochs = 150
|
||||
learning_rate = 0.001
|
||||
save_path = "checkpoints/best_model.pth"
|
||||
@ -22,4 +22,6 @@ class Config:
|
||||
checkpoint_path = "checkpoints/last_checkpoint.pth"
|
||||
output_path = "runs/"
|
||||
|
||||
cache = 'RAM'
|
||||
|
||||
config = Config()
|
112
data_loader.py
112
data_loader.py
@ -1,30 +1,67 @@
|
||||
import os
|
||||
from logger import logger
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision import datasets, transforms
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
|
||||
class ClassifyDataset(Dataset):
|
||||
def __init__(self, data_dir,transforms = None):
|
||||
self.data_dir = data_dir
|
||||
# Assume the dataset is structured with subdirectories for each class
|
||||
self.transform = transforms
|
||||
self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
|
||||
self.image_size = (3, 224, 224)
|
||||
class ImageClassificationDataset(Dataset):
    """Image-folder classification dataset with optional in-memory caching.

    ``root_dir`` must contain one sub-directory per class. The sorted
    sub-directory names define the integer labels (see ``class_to_idx``).

    Args:
        root_dir: Directory holding one sub-directory per class.
        transform: Optional callable applied to each RGB ``PIL.Image``.
        Cache: Pass ``'RAM'`` to decode and transform every image once in
            ``__init__`` and keep the results in memory. Any other value
            (default ``False``) stores only the file paths and decodes
            lazily in ``__getitem__``.

    Images that cannot be read are logged and skipped, so they never
    contribute a label.
    """

    def __init__(self, root_dir, transform=None, Cache=False):
        self.root_dir = root_dir
        self.transform = transform
        # Only directories are classes; a stray file in root_dir would
        # otherwise become a bogus class and crash os.listdir() below.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {
            cls_name: idx for idx, cls_name in enumerate(self.classes)
        }
        self.image_paths = []  # filled when images are loaded lazily
        self.image = []        # filled when Cache == 'RAM'
        self.labels = []       # one label per successfully read image
        self.Cache = Cache

        logger.log("info", "init the dataloader")

        for cls_name in self.classes:
            cls_dir = os.path.join(root_dir, cls_name)
            label = self.class_to_idx[cls_name]
            for img_name in os.listdir(cls_dir):
                img_path = os.path.join(cls_dir, img_name)
                try:
                    # Decoded in both modes so unreadable files are
                    # filtered out up front rather than failing mid-epoch.
                    sample = self._load_rgb(img_path)
                    if Cache == 'RAM':
                        # NOTE(review): the transform runs once, here, so
                        # any random augmentation it contains is sampled
                        # once per image, not per epoch - confirm intended.
                        if self.transform is not None:
                            sample = self.transform(sample)
                        self.image.append(sample)
                    else:
                        self.image_paths.append(img_path)
                    self.labels.append(label)
                except Exception as exc:
                    # Best effort: skip the bad file, but keep the reason
                    # and let KeyboardInterrupt/SystemExit propagate.
                    logger.log(
                        "info",
                        "read image error " + img_path + ": " + repr(exc)
                    )

    @staticmethod
    def _load_rgb(img_path):
        """Return an RGB copy of the image at ``img_path``, closing the file."""
        with Image.open(img_path) as handle:
            return handle.convert('RGB')

    def __len__(self):
        """Return the number of successfully indexed images."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        In ``'RAM'`` mode the cached, already transformed image is returned;
        otherwise the file is decoded and transformed on demand.
        """
        label = self.labels[idx]
        if self.Cache == 'RAM':
            return self.image[idx], label

        image = self._load_rgb(self.image_paths[idx])
        if self.transform is not None:
            image = self.transform(image)
        return image, label
|
||||
|
||||
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False):
|
||||
# Define the transform for the training data and for the validation data
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)), # Resize images to 224x224
|
||||
transforms.ToTensor(), # Convert PIL Image to Tensor
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize the images
|
||||
])
|
||||
|
||||
def create_data_loaders(data_dir,batch_size=64):
|
||||
# Define transformations for training data augmentation and normalization
|
||||
train_transforms = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
@ -40,17 +77,38 @@ def create_data_loaders(data_dir,batch_size=64):
|
||||
])
|
||||
|
||||
# Load the datasets with ImageFolder
|
||||
train_dir = data_dir + '/train'
|
||||
valid_dir = data_dir + '/val'
|
||||
test_dir = data_dir + '/test'
|
||||
train_dir = root_dir + '/train'
|
||||
valid_dir = root_dir + '/val'
|
||||
test_dir = root_dir + '/test'
|
||||
|
||||
train_data = ClassifyDataset(train_dir, transforms=train_transforms)
|
||||
valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms)
|
||||
test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms)
|
||||
train_data = ImageClassificationDataset(train_dir, transform=train_transforms,Cache=Cache)
|
||||
valid_data = ImageClassificationDataset(valid_dir, transform=valid_test_transforms,Cache=Cache)
|
||||
test_data = ImageClassificationDataset(test_dir, transform=valid_test_transforms,Cache=Cache)
|
||||
|
||||
# Create the DataLoaders with batch size 64
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
|
||||
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
|
||||
|
||||
# Create the data loader
|
||||
train_loader = DataLoader(
|
||||
train_data,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory
|
||||
)
|
||||
|
||||
# Create the data loader
|
||||
valid_loader = DataLoader(
|
||||
valid_data,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory
|
||||
)
|
||||
|
||||
# Create the data loader
|
||||
test_loader = DataLoader(
|
||||
test_data,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory
|
||||
)
|
||||
|
||||
return train_loader, valid_loader, test_loader
|
@ -106,15 +106,15 @@ def process_images(input_folder, background_image_path, output_base):
|
||||
递归处理所有子文件夹并保持目录结构
|
||||
"""
|
||||
# 预处理背景路径(只需执行一次)
|
||||
if os.path.isfile(background_image_path):
|
||||
background_paths = [background_image_path]
|
||||
else:
|
||||
valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
|
||||
background_paths = [
|
||||
os.path.join(background_image_path, f)
|
||||
for f in os.listdir(background_image_path)
|
||||
if os.path.splitext(f)[1].lower() in valid_ext
|
||||
]
|
||||
# if os.path.isfile(background_image_path):
|
||||
# background_paths = [background_image_path]
|
||||
# else:
|
||||
# valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
|
||||
# background_paths = [
|
||||
# os.path.join(background_image_path, f)
|
||||
# for f in os.listdir(background_image_path)
|
||||
# if os.path.splitext(f)[1].lower() in valid_ext
|
||||
# ]
|
||||
|
||||
# 递归遍历输入目录
|
||||
for root, dirs, files in os.walk(input_folder):
|
||||
@ -136,10 +136,10 @@ def process_images(input_folder, background_image_path, output_base):
|
||||
|
||||
try:
|
||||
# 去背景处理
|
||||
foreground = remove_background(input_path)
|
||||
result = remove_background(input_path)
|
||||
|
||||
|
||||
result = edge_fill2(foreground)
|
||||
# result = edge_fill2(result)
|
||||
|
||||
# 保存结果
|
||||
cv2.imwrite(output_path, result)
|
||||
@ -150,8 +150,8 @@ def process_images(input_folder, background_image_path, output_base):
|
||||
|
||||
|
||||
# 使用示例
|
||||
input_directory = 'L:/Tobacco/2023_JY/20230821/SOURCE'
|
||||
input_directory = 'L:/Grade_datasets/JY_A'
|
||||
background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B'
|
||||
output_directory = 'L:/Test'
|
||||
output_directory = 'L:/Grade_datasets/MOVE_BACKGROUND'
|
||||
|
||||
process_images(input_directory, background_image_path, output_directory)
|
12
dataset/test.py
Normal file
12
dataset/test.py
Normal file
@ -0,0 +1,12 @@
|
||||
import os
|
||||
|
||||
def debug_walk_with_links(input_folder, followlinks=False):
    """Print the directory tree under *input_folder* as ``os.walk`` sees it.

    For every directory visited, the root path, its sub-directory names and
    its file names are printed, followed by a 40-dash separator line.

    Args:
        input_folder: Directory to start walking from.
        followlinks: Forwarded to ``os.walk``. The default ``False`` keeps
            the previous behavior: directory symlinks are listed under
            ``Dirs`` but not descended into. Pass ``True`` to traverse
            them (beware of link cycles).
    """
    separator = '-' * 40
    for root, dirs, files in os.walk(input_folder, followlinks=followlinks):
        print(f'Root: {root}')
        print(f'Dirs: {dirs}')
        print(f'Files: {files}')
        print(separator)


if __name__ == "__main__":
    input_folder = 'L:/Grade_datasets'
    debug_walk_with_links(input_folder)
|
5
main.py
5
main.py
@ -7,6 +7,7 @@ from torchvision.datasets import MNIST
|
||||
from torchvision.transforms import ToTensor
|
||||
|
||||
from model.repvit import *
|
||||
from model.mobilenetv3 import *
|
||||
from data_loader import *
|
||||
from utils import *
|
||||
|
||||
@ -14,11 +15,11 @@ def main():
|
||||
# 初始化组件
|
||||
initialize()
|
||||
|
||||
model = repvit_m1_1(num_classes=10).to(config.device)
|
||||
model = repvit_m1_0(num_classes=9).to(config.device)
|
||||
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size)
|
||||
train_loader, valid_loader, test_loader = get_data_loader('/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size,Cache='RAM')
|
||||
|
||||
# 初始化训练器
|
||||
trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
|
||||
|
@ -4,6 +4,7 @@ from torch.utils.data import DataLoader
|
||||
from config import config
|
||||
from logger import logger
|
||||
from utils import save_checkpoint, load_checkpoint
|
||||
import time
|
||||
|
||||
class Trainer:
|
||||
def __init__(self, model, train_loader, val_loader, optimizer, criterion):
|
||||
@ -21,7 +22,7 @@ class Trainer:
|
||||
self.model.train()
|
||||
total_loss = 0.0
|
||||
progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
|
||||
|
||||
time_start = time.time()
|
||||
for batch_idx, (data, target) in enumerate(progress_bar):
|
||||
data, target = data.to(config.device), target.to(config.device)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user