From f43e21c09d050396d3ad58ddc67ed89aea351c46 Mon Sep 17 00:00:00 2001 From: yoiannis <13330431063> Date: Tue, 11 Mar 2025 23:12:23 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E7=BC=93=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 4 ++ config.py | 4 +- data_loader.py | 118 +++++++++++++++++++++++++++++++++------------ dataset/recover.py | 26 +++++----- dataset/test.py | 12 +++++ main.py | 5 +- trainner.py | 3 +- 7 files changed, 125 insertions(+), 47 deletions(-) create mode 100644 README.md create mode 100644 dataset/test.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..08d1c83 --- /dev/null +++ b/README.md @@ -0,0 +1,4 @@ +# 1.相关知识 +``` + https://github.com/Hao840/OFAKD +``` \ No newline at end of file diff --git a/config.py b/config.py index 0e7d471..190fba4 100644 --- a/config.py +++ b/config.py @@ -8,7 +8,7 @@ class Config: # 训练参数 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - batch_size = 32 + batch_size = 128 epochs = 150 learning_rate = 0.001 save_path = "checkpoints/best_model.pth" @@ -22,4 +22,6 @@ class Config: checkpoint_path = "checkpoints/last_checkpoint.pth" output_path = "runs/" + cache = 'RAM' + config = Config() \ No newline at end of file diff --git a/data_loader.py b/data_loader.py index de98721..a06ac22 100644 --- a/data_loader.py +++ b/data_loader.py @@ -1,31 +1,68 @@ import os +from logger import logger from PIL import Image -import numpy as np import torch -from torchvision import datasets, transforms from torch.utils.data import Dataset, DataLoader +from torchvision import transforms -class ClassifyDataset(Dataset): - def __init__(self, data_dir,transforms = None): - self.data_dir = data_dir - # Assume the dataset is structured with subdirectories for each class - self.transform = transforms - self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform) - self.image_size = (3, 224, 224) +class ImageClassificationDataset(Dataset): + def __init__(self, root_dir, transform=None,Cache=False): + self.root_dir = root_dir + self.transform = transform + self.classes = sorted(os.listdir(root_dir)) + self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)} + self.image_paths = [] + self.image = [] + self.labels = [] + self.Cache = Cache + + logger.log("info", + "init the dataloader" + ) + + for cls_name in self.classes: + cls_dir = os.path.join(root_dir, cls_name) + for img_name in os.listdir(cls_dir): + try: + img_path = os.path.join(cls_dir, img_name) + imgs = Image.open(img_path).convert('RGB') + if Cache == 'RAM': + if self.transform: + imgs = self.transform(imgs) + self.image.append(imgs) + else: + self.image_paths.append(img_path) + self.labels.append(self.class_to_idx[cls_name]) + except: + logger.log("info", + "read image error " + + img_path + ) def __len__(self): - return len(self.dataset) + return len(self.labels) def __getitem__(self, idx): - try: - image, label = self.dataset[idx] - return image, label - except Exception as e: - black_image = np.zeros((224, 224, 3), dtype=np.uint8) - return self.transform(Image.fromarray(black_image)), 0 # -1 作为默认标签 - -def create_data_loaders(data_dir,batch_size=64): - # Define transformations for training data augmentation and normalization + label = self.labels[idx] + if self.Cache == 'RAM': + image = self.image[idx] + else: + img_path = self.image_paths[idx] + image = Image.open(img_path).convert('RGB') + if self.transform: + image = self.transform(image) + + return image, label + +def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False): + # Define the transform for the training data and for the validation data + transform = transforms.Compose([ + transforms.Resize((224, 224)), # Resize images to 224x224 + transforms.ToTensor(), # Convert PIL Image to Tensor + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize the images + ]) + + # Define transformations for training data augmentation and normalization train_transforms = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), @@ -40,17 +77,38 @@ def create_data_loaders(data_dir,batch_size=64): ]) # Load the datasets with ImageFolder - train_dir = data_dir + '/train' - valid_dir = data_dir + '/val' - test_dir = data_dir + '/test' + train_dir = root_dir + '/train' + valid_dir = root_dir + '/val' + test_dir = root_dir + '/test' - train_data = ClassifyDataset(train_dir, transforms=train_transforms) - valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms) - test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms) + train_data = ImageClassificationDataset(train_dir, transform=train_transforms,Cache=Cache) + valid_data = ImageClassificationDataset(valid_dir, transform=valid_test_transforms,Cache=Cache) + test_data = ImageClassificationDataset(test_dir, transform=valid_test_transforms,Cache=Cache) - # Create the DataLoaders with batch size 64 - train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True) - valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size) - test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size) - return train_loader, valid_loader, test_loader \ No newline at end of file + # Create the data loader + train_loader = DataLoader( + train_data, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=pin_memory + ) + + # Create the data loader + valid_loader = DataLoader( + valid_data, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory + ) + + # Create the data loader + test_loader = DataLoader( + test_data, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory + ) + + return train_loader, valid_loader, test_loader diff --git a/dataset/recover.py b/dataset/recover.py index 58eb430..19aab9a 100644 --- a/dataset/recover.py +++ b/dataset/recover.py @@ -106,15 +106,15 @@ def process_images(input_folder, background_image_path, output_base): 递归处理所有子文件夹并保持目录结构 """ # 预处理背景路径(只需执行一次) - if os.path.isfile(background_image_path): - background_paths = [background_image_path] - else: - valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp'] - background_paths = [ - os.path.join(background_image_path, f) - for f in os.listdir(background_image_path) - if os.path.splitext(f)[1].lower() in valid_ext - ] + # if os.path.isfile(background_image_path): + # background_paths = [background_image_path] + # else: + # valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp'] + # background_paths = [ + # os.path.join(background_image_path, f) + # for f in os.listdir(background_image_path) + # if os.path.splitext(f)[1].lower() in valid_ext + # ] # 递归遍历输入目录 for root, dirs, files in os.walk(input_folder): @@ -136,10 +136,10 @@ def process_images(input_folder, background_image_path, output_base): try: # 去背景处理 - foreground = remove_background(input_path) + result = remove_background(input_path) - result = edge_fill2(foreground) + # result = edge_fill2(result) # 保存结果 cv2.imwrite(output_path, result) @@ -150,8 +150,8 @@ def process_images(input_folder, background_image_path, output_base): # 使用示例 -input_directory = 'L:/Tobacco/2023_JY/20230821/SOURCE' +input_directory = 'L:/Grade_datasets/JY_A' background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B' -output_directory = 'L:/Test' +output_directory = 'L:/Grade_datasets/MOVE_BACKGROUND' process_images(input_directory, background_image_path, output_directory) \ No newline at end of file diff --git a/dataset/test.py b/dataset/test.py new file mode 100644 index 0000000..687fd25 --- /dev/null +++ b/dataset/test.py @@ -0,0 +1,12 @@ +import os + +def debug_walk_with_links(input_folder): + for root, dirs, files in os.walk(input_folder): + print(f'Root: {root}') + print(f'Dirs: {dirs}') + print(f'Files: {files}') + print('-' * 40) + +if __name__ == "__main__": + input_folder = 'L:/Grade_datasets' + debug_walk_with_links(input_folder) \ No newline at end of file diff --git a/main.py b/main.py index 2f55803..e1c9afd 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ from torchvision.datasets import MNIST from torchvision.transforms import ToTensor from model.repvit import * +from model.mobilenetv3 import * from data_loader import * from utils import * @@ -14,11 +15,11 @@ def main(): # 初始化组件 initialize() - model = repvit_m1_1(num_classes=10).to(config.device) + model = repvit_m1_0(num_classes=9).to(config.device) optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) criterion = nn.CrossEntropyLoss() - train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size) + train_loader, valid_loader, test_loader = get_data_loader('/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size,Cache='RAM') # 初始化训练器 trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion) diff --git a/trainner.py b/trainner.py index e3d44cc..2a3bcc8 100644 --- a/trainner.py +++ b/trainner.py @@ -4,6 +4,7 @@ from torch.utils.data import DataLoader from config import config from logger import logger from utils import save_checkpoint, load_checkpoint +import time class Trainer: def __init__(self, model, train_loader, val_loader, optimizer, criterion): @@ -21,7 +22,7 @@ class Trainer: self.model.train() total_loss = 0.0 progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}") - + time_start = time.time() for batch_idx, (data, target) in enumerate(progress_bar): data, target = data.to(config.device), target.to(config.device) From 617230e296b4ad040a3fc615707a0413f7cdbd38 Mon Sep 17 00:00:00 2001 From: yoiannis <13330431063> Date: Tue, 11 Mar 2025 23:12:34 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E8=81=94=E9=82=A6?= =?UTF-8?q?=E5=AD=A6=E4=B9=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- FED.py | 91 ++++++++++++++++++++++++---------------------------------- 1 file changed, 37 insertions(+), 54 deletions(-) diff --git a/FED.py b/FED.py index df6688a..b1b0d13 100644 --- a/FED.py +++ b/FED.py @@ -12,7 +12,7 @@ from model.repvit import repvit_m1_1 from model.mobilenetv3 import MobileNetV3 # 配置参数 -NUM_CLIENTS = 4 +NUM_CLIENTS = 2 NUM_ROUNDS = 3 CLIENT_EPOCHS = 5 BATCH_SIZE = 32 @@ -22,25 +22,27 @@ TEMP = 2.0 # 蒸馏温度 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 数据准备 -def prepare_data(num_clients): +import os +from torchvision.datasets import ImageFolder + +def prepare_data(): transform = transforms.Compose([ - transforms.Resize((224, 224)), # 将图像调整为 224x224 - transforms.Grayscale(num_output_channels=3), - transforms.ToTensor() - ]) - train_set = datasets.MNIST("./data", train=True, download=True, transform=transform) + transforms.Resize((224, 224)), + transforms.ToTensor() + ]) - # 非IID数据划分(每个客户端2个类别) - client_data = {i: [] for i in range(num_clients)} - labels = train_set.targets.numpy() - for label in range(10): - label_idx = np.where(labels == label)[0] - np.random.shuffle(label_idx) - split = np.array_split(label_idx, num_clients//2) - for i, idx in enumerate(split): - client_data[i*2 + label%2].extend(idx) + # Load datasets + dataset_A = ImageFolder(root='./dataset_A/train', transform=transform) + dataset_B = ImageFolder(root='./dataset_B/train', transform=transform) + dataset_C = ImageFolder(root='./dataset_C/train', transform=transform) - return [Subset(train_set, ids) for ids in client_data.values()] + # Assign datasets to clients + client_datasets = [dataset_B, dataset_C] + + # Server dataset (A) for public updates + public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True) + + return client_datasets, public_loader # 客户端训练函数 def client_train(client_model, server_model, dataset): @@ -189,63 +191,47 @@ def test_model(model, test_loader): # 主训练流程 def main(): - # 初始化模型 + # Initialize models global_server_model = repvit_m1_1(num_classes=10).to(device) client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)] - - round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round") - # 准备数据 - client_datasets = prepare_data(NUM_CLIENTS) - public_loader = DataLoader( - datasets.MNIST("./data", train=False, download=True, - transform= transforms.Compose([ - transforms.Resize((224, 224)), # 将图像调整为 224x224 - transforms.Grayscale(num_output_channels=3), - transforms.ToTensor() # 将图像转换为张量 - ])), - batch_size=100, shuffle=True) + # Prepare data + client_datasets, public_loader = prepare_data() - test_dataset = datasets.MNIST( - "./data", - train=False, - transform= transforms.Compose([ - transforms.Resize((224, 224)), # 将图像调整为 224x224 - transforms.Grayscale(num_output_channels=3), - transforms.ToTensor() # 将图像转换为张量 - ]) - ) + # Test dataset (using dataset A's test set for simplicity) + test_dataset = ImageFolder(root='./dataset_A/test', transform=transform) test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False) + round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round") + for round in range(NUM_ROUNDS): print(f"\n{'#'*50}") print(f"Federated Round {round+1}/{NUM_ROUNDS}") print(f"{'#'*50}") - # 客户端选择 + # Client selection (only 2 clients) selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False) print(f"Selected Clients: {selected_clients}") - # 客户端本地训练 + # Client local training client_params = [] for cid in selected_clients: print(f"\nTraining Client {cid}") local_model = copy.deepcopy(client_models[cid]) local_model.load_state_dict(client_models[cid].state_dict()) - updated_params = client_train(local_model, global_server_model, client_datasets[cid]) client_params.append(updated_params) - # 模型聚合 + # Model aggregation global_client_params = aggregate(client_params) for model in client_models: model.load_state_dict(global_client_params) - # 服务器知识更新 + # Server knowledge update print("\nServer Updating...") server_update(global_server_model, client_models, public_loader) - # 测试模型性能 + # Test model performance server_acc = test_model(global_server_model, test_loader) client_acc = test_model(client_models[0], test_loader) print(f"\nRound {round+1} Performance:") @@ -253,25 +239,22 @@ def main(): print(f"Client Model Accuracy: {client_acc:.2f}%") round_progress.update(1) - print(f"Round {round+1} completed") print("Training completed!") - - # 保存训练好的模型 + + # Save trained models torch.save(global_server_model.state_dict(), "server_model.pth") torch.save(client_models[0].state_dict(), "client_model.pth") print("Models saved successfully.") - - # 创建测试数据加载器 - - # 测试服务器模型 + + # Test server model server_model = repvit_m1_1(num_classes=10).to(device) server_model.load_state_dict(torch.load("server_model.pth")) server_acc = test_model(server_model, test_loader) print(f"Server Model Test Accuracy: {server_acc:.2f}%") - - # 测试客户端模型 + + # Test client model client_model = MobileNetV3(n_class=10).to(device) client_model.load_state_dict(torch.load("client_model.pth")) client_acc = test_model(client_model, test_loader) From 382916643bd41e1ec202f19d8088ab5a91a30872 Mon Sep 17 00:00:00 2001 From: yoiannis Date: Wed, 12 Mar 2025 00:21:31 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E5=AE=8C=E6=88=90=E8=81=94=E9=82=A6?= =?UTF-8?q?=E8=92=B8=E9=A6=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- FED.py | 75 +++++++++++++++++++++++++++++++------------- model/mobilenetv3.py | 5 +++ model/repvit.py | 4 +++ 3 files changed, 62 insertions(+), 22 deletions(-) diff --git a/FED.py b/FED.py index b1b0d13..a727ca0 100644 --- a/FED.py +++ b/FED.py @@ -13,10 +13,11 @@ from model.mobilenetv3 import MobileNetV3 # 配置参数 NUM_CLIENTS = 2 -NUM_ROUNDS = 3 -CLIENT_EPOCHS = 5 +NUM_ROUNDS = 10 +CLIENT_EPOCHS = 2 BATCH_SIZE = 32 TEMP = 2.0 # 蒸馏温度 +CLASS_NUM = [3, 3, 3] # 设备配置 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -32,9 +33,9 @@ def prepare_data(): ]) # Load datasets - dataset_A = ImageFolder(root='./dataset_A/train', transform=transform) - dataset_B = ImageFolder(root='./dataset_B/train', transform=transform) - dataset_C = ImageFolder(root='./dataset_C/train', transform=transform) + dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform) + dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform) + dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform) # Assign datasets to clients client_datasets = [dataset_B, dataset_C] @@ -105,13 +106,6 @@ def client_train(client_model, server_model, dataset): }) progress_bar.update(1) - # 每10个batch打印详细信息 - if (batch_idx + 1) % 10 == 0: - progress_bar.write(f"\nEpoch {epoch+1} | Batch {batch_idx+1}") - progress_bar.write(f"Task Loss: {loss_task:.4f}") - progress_bar.write(f"Distill Loss: {loss_distill:.4f}") - progress_bar.write(f"Total Loss: {total_loss:.4f}") - progress_bar.write(f"Batch Accuracy: {100*correct/total:.2f}%\n") # 每个epoch结束打印汇总信息 avg_loss = epoch_loss / len(loader) avg_task = task_loss / len(loader) @@ -135,6 +129,37 @@ def aggregate(client_params): global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0) return global_params +def server_aggregate(server_model, client_models, public_loader): + server_model.train() + optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001) + + for data, _ in public_loader: + data = data.to(device) + + # 获取客户端模型特征 + client_features = [] + with torch.no_grad(): + for model in client_models: + features = model.extract_features(data) # 需要实现特征提取方法 + client_features.append(features) + + # 计算特征蒸馏目标 + target_features = torch.stack(client_features).mean(dim=0) + + # 服务器前向 + server_features = server_model.extract_features(data) + + # 特征对齐损失 + loss = F.mse_loss(server_features, target_features) + + # 反向传播 + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # 更新统计信息 + total_loss += loss.item() + # 服务器知识更新 def server_update(server_model, client_models, public_loader): server_model.train() @@ -191,15 +216,19 @@ def test_model(model, test_loader): # 主训练流程 def main(): + transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor() + ]) # Initialize models - global_server_model = repvit_m1_1(num_classes=10).to(device) - client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)] + global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device) + client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)] # Prepare data client_datasets, public_loader = prepare_data() # Test dataset (using dataset A's test set for simplicity) - test_dataset = ImageFolder(root='./dataset_A/test', transform=transform) + test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform) test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False) round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round") @@ -245,20 +274,22 @@ def main(): # Save trained models torch.save(global_server_model.state_dict(), "server_model.pth") - torch.save(client_models[0].state_dict(), "client_model.pth") + for i in range(NUM_CLIENTS): + torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth") print("Models saved successfully.") # Test server model - server_model = repvit_m1_1(num_classes=10).to(device) - server_model.load_state_dict(torch.load("server_model.pth")) + server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device) + server_model.load_state_dict(torch.load("server_model.pth",weights_only=True)) server_acc = test_model(server_model, test_loader) print(f"Server Model Test Accuracy: {server_acc:.2f}%") # Test client model - client_model = MobileNetV3(n_class=10).to(device) - client_model.load_state_dict(torch.load("client_model.pth")) - client_acc = test_model(client_model, test_loader) - print(f"Client Model Test Accuracy: {client_acc:.2f}%") + for i in range(NUM_CLIENTS): + client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) + client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True)) + client_acc = test_model(client_model, test_loader) + print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%") if __name__ == "__main__": main() \ No newline at end of file diff --git a/model/mobilenetv3.py b/model/mobilenetv3.py index 4692cf9..2de909a 100644 --- a/model/mobilenetv3.py +++ b/model/mobilenetv3.py @@ -200,6 +200,11 @@ class MobileNetV3(nn.Module): self._initialize_weights() + + def extract_features(self, x): + x = self.features(x) + return x + def forward(self, x): x = self.features(x) x = x.mean(3).mean(2) diff --git a/model/repvit.py b/model/repvit.py index 78197f8..27b07b9 100644 --- a/model/repvit.py +++ b/model/repvit.py @@ -236,6 +236,10 @@ class RepViT(nn.Module): self.features = nn.ModuleList(layers) self.classifier = Classfier(output_channel, num_classes, distillation) + def extract_features(self, x): + for f in self.features: + x = f(x) + return x def forward(self, x): # x = self.features(x) for f in self.features: