更新缓存
This commit is contained in:
parent
30eeff4b1d
commit
f43e21c09d
@ -8,7 +8,7 @@ class Config:
|
|||||||
|
|
||||||
# 训练参数
|
# 训练参数
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
batch_size = 32
|
batch_size = 128
|
||||||
epochs = 150
|
epochs = 150
|
||||||
learning_rate = 0.001
|
learning_rate = 0.001
|
||||||
save_path = "checkpoints/best_model.pth"
|
save_path = "checkpoints/best_model.pth"
|
||||||
@ -22,4 +22,6 @@ class Config:
|
|||||||
checkpoint_path = "checkpoints/last_checkpoint.pth"
|
checkpoint_path = "checkpoints/last_checkpoint.pth"
|
||||||
output_path = "runs/"
|
output_path = "runs/"
|
||||||
|
|
||||||
|
cache = 'RAM'
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
112
data_loader.py
112
data_loader.py
@ -1,30 +1,67 @@
|
|||||||
import os
|
import os
|
||||||
|
from logger import logger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from torchvision import datasets, transforms
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
class ClassifyDataset(Dataset):
|
class ImageClassificationDataset(Dataset):
|
||||||
def __init__(self, data_dir,transforms = None):
|
def __init__(self, root_dir, transform=None,Cache=False):
|
||||||
self.data_dir = data_dir
|
self.root_dir = root_dir
|
||||||
# Assume the dataset is structured with subdirectories for each class
|
self.transform = transform
|
||||||
self.transform = transforms
|
self.classes = sorted(os.listdir(root_dir))
|
||||||
self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
|
self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
|
||||||
self.image_size = (3, 224, 224)
|
self.image_paths = []
|
||||||
|
self.image = []
|
||||||
|
self.labels = []
|
||||||
|
self.Cache = Cache
|
||||||
|
|
||||||
|
logger.log("info",
|
||||||
|
"init the dataloader"
|
||||||
|
)
|
||||||
|
|
||||||
|
for cls_name in self.classes:
|
||||||
|
cls_dir = os.path.join(root_dir, cls_name)
|
||||||
|
for img_name in os.listdir(cls_dir):
|
||||||
|
try:
|
||||||
|
img_path = os.path.join(cls_dir, img_name)
|
||||||
|
imgs = Image.open(img_path).convert('RGB')
|
||||||
|
if Cache == 'RAM':
|
||||||
|
if self.transform:
|
||||||
|
imgs = self.transform(imgs)
|
||||||
|
self.image.append(imgs)
|
||||||
|
else:
|
||||||
|
self.image_paths.append(img_path)
|
||||||
|
self.labels.append(self.class_to_idx[cls_name])
|
||||||
|
except:
|
||||||
|
logger.log("info",
|
||||||
|
"read image error " +
|
||||||
|
img_path
|
||||||
|
)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.dataset)
|
return len(self.labels)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
try:
|
label = self.labels[idx]
|
||||||
image, label = self.dataset[idx]
|
if self.Cache == 'RAM':
|
||||||
return image, label
|
image = self.image[idx]
|
||||||
except Exception as e:
|
else:
|
||||||
black_image = np.zeros((224, 224, 3), dtype=np.uint8)
|
img_path = self.image_paths[idx]
|
||||||
return self.transform(Image.fromarray(black_image)), 0 # -1 作为默认标签
|
image = Image.open(img_path).convert('RGB')
|
||||||
|
if self.transform:
|
||||||
|
image = self.transform(image)
|
||||||
|
|
||||||
|
return image, label
|
||||||
|
|
||||||
|
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False):
|
||||||
|
# Define the transform for the training data and for the validation data
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)), # Resize images to 224x224
|
||||||
|
transforms.ToTensor(), # Convert PIL Image to Tensor
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize the images
|
||||||
|
])
|
||||||
|
|
||||||
def create_data_loaders(data_dir,batch_size=64):
|
|
||||||
# Define transformations for training data augmentation and normalization
|
# Define transformations for training data augmentation and normalization
|
||||||
train_transforms = transforms.Compose([
|
train_transforms = transforms.Compose([
|
||||||
transforms.Resize((224, 224)),
|
transforms.Resize((224, 224)),
|
||||||
@ -40,17 +77,38 @@ def create_data_loaders(data_dir,batch_size=64):
|
|||||||
])
|
])
|
||||||
|
|
||||||
# Load the datasets with ImageFolder
|
# Load the datasets with ImageFolder
|
||||||
train_dir = data_dir + '/train'
|
train_dir = root_dir + '/train'
|
||||||
valid_dir = data_dir + '/val'
|
valid_dir = root_dir + '/val'
|
||||||
test_dir = data_dir + '/test'
|
test_dir = root_dir + '/test'
|
||||||
|
|
||||||
train_data = ClassifyDataset(train_dir, transforms=train_transforms)
|
train_data = ImageClassificationDataset(train_dir, transform=train_transforms,Cache=Cache)
|
||||||
valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms)
|
valid_data = ImageClassificationDataset(valid_dir, transform=valid_test_transforms,Cache=Cache)
|
||||||
test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms)
|
test_data = ImageClassificationDataset(test_dir, transform=valid_test_transforms,Cache=Cache)
|
||||||
|
|
||||||
# Create the DataLoaders with batch size 64
|
|
||||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
|
# Create the data loader
|
||||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
|
train_loader = DataLoader(
|
||||||
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
|
train_data,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=pin_memory
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the data loader
|
||||||
|
valid_loader = DataLoader(
|
||||||
|
valid_data,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=pin_memory
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the data loader
|
||||||
|
test_loader = DataLoader(
|
||||||
|
test_data,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=pin_memory
|
||||||
|
)
|
||||||
|
|
||||||
return train_loader, valid_loader, test_loader
|
return train_loader, valid_loader, test_loader
|
@ -106,15 +106,15 @@ def process_images(input_folder, background_image_path, output_base):
|
|||||||
递归处理所有子文件夹并保持目录结构
|
递归处理所有子文件夹并保持目录结构
|
||||||
"""
|
"""
|
||||||
# 预处理背景路径(只需执行一次)
|
# 预处理背景路径(只需执行一次)
|
||||||
if os.path.isfile(background_image_path):
|
# if os.path.isfile(background_image_path):
|
||||||
background_paths = [background_image_path]
|
# background_paths = [background_image_path]
|
||||||
else:
|
# else:
|
||||||
valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
|
# valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
|
||||||
background_paths = [
|
# background_paths = [
|
||||||
os.path.join(background_image_path, f)
|
# os.path.join(background_image_path, f)
|
||||||
for f in os.listdir(background_image_path)
|
# for f in os.listdir(background_image_path)
|
||||||
if os.path.splitext(f)[1].lower() in valid_ext
|
# if os.path.splitext(f)[1].lower() in valid_ext
|
||||||
]
|
# ]
|
||||||
|
|
||||||
# 递归遍历输入目录
|
# 递归遍历输入目录
|
||||||
for root, dirs, files in os.walk(input_folder):
|
for root, dirs, files in os.walk(input_folder):
|
||||||
@ -136,10 +136,10 @@ def process_images(input_folder, background_image_path, output_base):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 去背景处理
|
# 去背景处理
|
||||||
foreground = remove_background(input_path)
|
result = remove_background(input_path)
|
||||||
|
|
||||||
|
|
||||||
result = edge_fill2(foreground)
|
# result = edge_fill2(result)
|
||||||
|
|
||||||
# 保存结果
|
# 保存结果
|
||||||
cv2.imwrite(output_path, result)
|
cv2.imwrite(output_path, result)
|
||||||
@ -150,8 +150,8 @@ def process_images(input_folder, background_image_path, output_base):
|
|||||||
|
|
||||||
|
|
||||||
# 使用示例
|
# 使用示例
|
||||||
input_directory = 'L:/Tobacco/2023_JY/20230821/SOURCE'
|
input_directory = 'L:/Grade_datasets/JY_A'
|
||||||
background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B'
|
background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B'
|
||||||
output_directory = 'L:/Test'
|
output_directory = 'L:/Grade_datasets/MOVE_BACKGROUND'
|
||||||
|
|
||||||
process_images(input_directory, background_image_path, output_directory)
|
process_images(input_directory, background_image_path, output_directory)
|
12
dataset/test.py
Normal file
12
dataset/test.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
def debug_walk_with_links(input_folder):
|
||||||
|
for root, dirs, files in os.walk(input_folder):
|
||||||
|
print(f'Root: {root}')
|
||||||
|
print(f'Dirs: {dirs}')
|
||||||
|
print(f'Files: {files}')
|
||||||
|
print('-' * 40)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
input_folder = 'L:/Grade_datasets'
|
||||||
|
debug_walk_with_links(input_folder)
|
5
main.py
5
main.py
@ -7,6 +7,7 @@ from torchvision.datasets import MNIST
|
|||||||
from torchvision.transforms import ToTensor
|
from torchvision.transforms import ToTensor
|
||||||
|
|
||||||
from model.repvit import *
|
from model.repvit import *
|
||||||
|
from model.mobilenetv3 import *
|
||||||
from data_loader import *
|
from data_loader import *
|
||||||
from utils import *
|
from utils import *
|
||||||
|
|
||||||
@ -14,11 +15,11 @@ def main():
|
|||||||
# 初始化组件
|
# 初始化组件
|
||||||
initialize()
|
initialize()
|
||||||
|
|
||||||
model = repvit_m1_1(num_classes=10).to(config.device)
|
model = repvit_m1_0(num_classes=9).to(config.device)
|
||||||
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size)
|
train_loader, valid_loader, test_loader = get_data_loader('/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size,Cache='RAM')
|
||||||
|
|
||||||
# 初始化训练器
|
# 初始化训练器
|
||||||
trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
|
trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
|
||||||
|
@ -4,6 +4,7 @@ from torch.utils.data import DataLoader
|
|||||||
from config import config
|
from config import config
|
||||||
from logger import logger
|
from logger import logger
|
||||||
from utils import save_checkpoint, load_checkpoint
|
from utils import save_checkpoint, load_checkpoint
|
||||||
|
import time
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(self, model, train_loader, val_loader, optimizer, criterion):
|
def __init__(self, model, train_loader, val_loader, optimizer, criterion):
|
||||||
@ -21,7 +22,7 @@ class Trainer:
|
|||||||
self.model.train()
|
self.model.train()
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
|
progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
|
||||||
|
time_start = time.time()
|
||||||
for batch_idx, (data, target) in enumerate(progress_bar):
|
for batch_idx, (data, target) in enumerate(progress_bar):
|
||||||
data, target = data.to(config.device), target.to(config.device)
|
data, target = data.to(config.device), target.to(config.device)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user