Update cache

yoiannis 2025-03-11 23:12:23 +08:00
parent 30eeff4b1d
commit f43e21c09d
7 changed files with 125 additions and 47 deletions

README.md Normal file (+4)

@@ -0,0 +1,4 @@
+# 1. Related work
+```
+https://github.com/Hao840/OFAKD
+```


@@ -8,7 +8,7 @@ class Config:
     # Training parameters
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    batch_size = 32
+    batch_size = 128
     epochs = 150
     learning_rate = 0.001
     save_path = "checkpoints/best_model.pth"

@@ -22,4 +22,6 @@ class Config:
     checkpoint_path = "checkpoints/last_checkpoint.pth"
     output_path = "runs/"
+    cache = 'RAM'
+
 config = Config()
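A note on the new `cache` field: the data-loader changes below interpret `Cache == 'RAM'` as "decode and transform every image once at dataset construction", and anything else (the default `False`) as lazy per-item loading from disk. A minimal usage sketch, assuming the `get_data_loader` signature introduced in this commit; the dataset path is hypothetical:

```
from config import config
from data_loader import get_data_loader

data_root = "/path/to/dataset"  # hypothetical root with train/ val/ test/ subdirs

# 'RAM' trades memory for fast epochs; False re-reads images in __getitem__.
fast_loaders = get_data_loader(data_root, batch_size=config.batch_size, Cache=config.cache)
lazy_loaders = get_data_loader(data_root, batch_size=config.batch_size, Cache=False)
```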


@ -1,31 +1,68 @@
import os import os
from logger import logger
from PIL import Image from PIL import Image
import numpy as np
import torch import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class ClassifyDataset(Dataset): class ImageClassificationDataset(Dataset):
def __init__(self, data_dir,transforms = None): def __init__(self, root_dir, transform=None,Cache=False):
self.data_dir = data_dir self.root_dir = root_dir
# Assume the dataset is structured with subdirectories for each class self.transform = transform
self.transform = transforms self.classes = sorted(os.listdir(root_dir))
self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform) self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
self.image_size = (3, 224, 224) self.image_paths = []
self.image = []
self.labels = []
self.Cache = Cache
logger.log("info",
"init the dataloader"
)
for cls_name in self.classes:
cls_dir = os.path.join(root_dir, cls_name)
for img_name in os.listdir(cls_dir):
try:
img_path = os.path.join(cls_dir, img_name)
imgs = Image.open(img_path).convert('RGB')
if Cache == 'RAM':
if self.transform:
imgs = self.transform(imgs)
self.image.append(imgs)
else:
self.image_paths.append(img_path)
self.labels.append(self.class_to_idx[cls_name])
except:
logger.log("info",
"read image error " +
img_path
)
def __len__(self): def __len__(self):
return len(self.dataset) return len(self.labels)
def __getitem__(self, idx): def __getitem__(self, idx):
try: label = self.labels[idx]
image, label = self.dataset[idx] if self.Cache == 'RAM':
return image, label image = self.image[idx]
except Exception as e: else:
black_image = np.zeros((224, 224, 3), dtype=np.uint8) img_path = self.image_paths[idx]
return self.transform(Image.fromarray(black_image)), 0 # -1 作为默认标签 image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
def create_data_loaders(data_dir,batch_size=64): return image, label
# Define transformations for training data augmentation and normalization
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False):
# Define the transform for the training data and for the validation data
transform = transforms.Compose([
transforms.Resize((224, 224)), # Resize images to 224x224
transforms.ToTensor(), # Convert PIL Image to Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize the images
])
# Define transformations for training data augmentation and normalization
train_transforms = transforms.Compose([ train_transforms = transforms.Compose([
transforms.Resize((224, 224)), transforms.Resize((224, 224)),
transforms.ToTensor(), transforms.ToTensor(),
@@ -40,17 +77,38 @@ def create_data_loaders(data_dir, batch_size=64):
     ])

     # Load the datasets with ImageFolder
-    train_dir = data_dir + '/train'
-    valid_dir = data_dir + '/val'
-    test_dir = data_dir + '/test'
+    train_dir = root_dir + '/train'
+    valid_dir = root_dir + '/val'
+    test_dir = root_dir + '/test'

-    train_data = ClassifyDataset(train_dir, transforms=train_transforms)
-    valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms)
-    test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms)
+    train_data = ImageClassificationDataset(train_dir, transform=train_transforms, Cache=Cache)
+    valid_data = ImageClassificationDataset(valid_dir, transform=valid_test_transforms, Cache=Cache)
+    test_data = ImageClassificationDataset(test_dir, transform=valid_test_transforms, Cache=Cache)

-    # Create the DataLoaders with batch size 64
-    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
-    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
-    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
+    # Create the data loaders
+    train_loader = DataLoader(
+        train_data,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=num_workers,
+        pin_memory=pin_memory
+    )
+    valid_loader = DataLoader(
+        valid_data,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=pin_memory
+    )
+    test_loader = DataLoader(
+        test_data,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=pin_memory
+    )

     return train_loader, valid_loader, test_loader
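Two properties of the new RAM cache worth noting. First, when `Cache='RAM'` the transform runs once in `__init__`, so any random augmentation later added to `train_transforms` would be frozen per image for the whole run. Second, after the transform each cached image is a 3×224×224 float32 tensor, so memory grows by roughly 0.6 MB per image; a back-of-envelope sketch (the dataset size is hypothetical):

```
# Each cached image: float32 tensor of shape (3, 224, 224).
bytes_per_image = 3 * 224 * 224 * 4            # 602,112 bytes, about 0.57 MiB
n_images = 50_000                              # hypothetical dataset size
print(f"RAM cache needs about {bytes_per_image * n_images / 2**30:.1f} GiB")  # ~28.0 GiB
```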


@@ -106,15 +106,15 @@ def process_images(input_folder, background_image_path, output_base):
     Recursively process all subfolders while preserving the directory structure
     """
     # Preprocess the background paths (only needs to run once)
-    if os.path.isfile(background_image_path):
-        background_paths = [background_image_path]
-    else:
-        valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
-        background_paths = [
-            os.path.join(background_image_path, f)
-            for f in os.listdir(background_image_path)
-            if os.path.splitext(f)[1].lower() in valid_ext
-        ]
+    # if os.path.isfile(background_image_path):
+    #     background_paths = [background_image_path]
+    # else:
+    #     valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
+    #     background_paths = [
+    #         os.path.join(background_image_path, f)
+    #         for f in os.listdir(background_image_path)
+    #         if os.path.splitext(f)[1].lower() in valid_ext
+    #     ]

     # Recursively walk the input directory
     for root, dirs, files in os.walk(input_folder):

@@ -136,10 +136,10 @@ def process_images(input_folder, background_image_path, output_base):
             try:
                 # Background removal
-                foreground = remove_background(input_path)
-                result = edge_fill2(foreground)
+                result = remove_background(input_path)
+                # result = edge_fill2(result)

                 # Save the result
                 cv2.imwrite(output_path, result)

@@ -150,8 +150,8 @@ def process_images(input_folder, background_image_path, output_base):
 # Usage example
-input_directory = 'L:/Tobacco/2023_JY/20230821/SOURCE'
+input_directory = 'L:/Grade_datasets/JY_A'
 background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B'
-output_directory = 'L:/Test'
+output_directory = 'L:/Grade_datasets/MOVE_BACKGROUND'
 process_images(input_directory, background_image_path, output_directory)
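With the background-list preprocessing commented out, the `background_image_path` argument is now effectively unused inside `process_images`. Separately, `cv2.imwrite` reports failure by returning `False` rather than raising, so a missing output directory fails silently; a small defensive sketch around the save step (the `save_image` helper is hypothetical):

```
import os
import cv2

def save_image(output_path, result):
    # cv2.imwrite returns False (no exception) when the write fails,
    # e.g. because the target directory does not exist yet.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    if not cv2.imwrite(output_path, result):
        print(f"failed to write {output_path}")
```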

dataset/test.py Normal file (+12)

@@ -0,0 +1,12 @@
+import os
+
+def debug_walk_with_links(input_folder):
+    for root, dirs, files in os.walk(input_folder):
+        print(f'Root: {root}')
+        print(f'Dirs: {dirs}')
+        print(f'Files: {files}')
+        print('-' * 40)
+
+if __name__ == "__main__":
+    input_folder = 'L:/Grade_datasets'
+    debug_walk_with_links(input_folder)
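Despite its name, `debug_walk_with_links` does not descend into symlinked directories: `os.walk` defaults to `followlinks=False`. If the intent is to follow links under `L:/Grade_datasets`, the flag has to be set explicitly:

```
# followlinks=True makes os.walk recurse into directory symlinks
# (watch out for link cycles, which can loop forever).
for root, dirs, files in os.walk(input_folder, followlinks=True):
    print(f'Root: {root}')
```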


@@ -7,6 +7,7 @@ from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor
 from model.repvit import *
+from model.mobilenetv3 import *
 from data_loader import *
 from utils import *

@@ -14,11 +15,11 @@ def main():
     # Initialize components
     initialize()
-    model = repvit_m1_1(num_classes=10).to(config.device)
+    model = repvit_m1_0(num_classes=9).to(config.device)
     optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
     criterion = nn.CrossEntropyLoss()
-    train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27', batch_size=config.batch_size)
+    train_loader, valid_loader, test_loader = get_data_loader('/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27', batch_size=config.batch_size, Cache='RAM')

     # Initialize the trainer
     trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
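One inconsistency worth flagging: this call hard-codes `Cache='RAM'` even though the same commit adds `cache = 'RAM'` to `Config`. Reading the flag from the config keeps a single switch for the caching behavior, as sketched after the config diff above:

```
# Single switch for caching, using the Config field added in this commit.
train_loader, valid_loader, test_loader = get_data_loader(
    '/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27',
    batch_size=config.batch_size, Cache=config.cache)
```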


@@ -4,6 +4,7 @@ from torch.utils.data import DataLoader
 from config import config
 from logger import logger
 from utils import save_checkpoint, load_checkpoint
+import time

 class Trainer:
     def __init__(self, model, train_loader, val_loader, optimizer, criterion):

@@ -21,7 +22,7 @@ class Trainer:
         self.model.train()
         total_loss = 0.0
         progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
+        time_start = time.time()
         for batch_idx, (data, target) in enumerate(progress_bar):
             data, target = data.to(config.device), target.to(config.device)
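The hunk ends before showing where `time_start` is consumed; presumably the elapsed time is reported once the batch loop finishes. A minimal sketch of that completion, reusing the `logger.log("info", ...)` call style seen in the data-loader changes:

```
# After the batch loop: report wall-clock time for the epoch.
elapsed = time.time() - time_start
logger.log("info", f"epoch {epoch + 1} finished in {elapsed:.1f}s")
```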