From a5ca9d04d7d730fc728ba691430b533d65167c1d Mon Sep 17 00:00:00 2001 From: yoiannis <13330431063> Date: Wed, 12 Mar 2025 14:00:50 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E8=81=94=E9=82=A6=E5=AD=A6?= =?UTF-8?q?=E4=B9=A0=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- FED.py | 141 ++++++++++++++++++++++++++--------------------- data_loader.py | 18 +++--- dataset/split.py | 7 +-- 3 files changed, 89 insertions(+), 77 deletions(-) diff --git a/FED.py b/FED.py index a727ca0..8ec2835 100644 --- a/FED.py +++ b/FED.py @@ -8,6 +8,7 @@ import copy from tqdm import tqdm +from data_loader import get_data_loader from model.repvit import repvit_m1_1 from model.mobilenetv3 import MobileNetV3 @@ -17,7 +18,7 @@ NUM_ROUNDS = 10 CLIENT_EPOCHS = 2 BATCH_SIZE = 32 TEMP = 2.0 # 蒸馏温度 -CLASS_NUM = [3, 3, 3] +CLASS_NUM = [9, 9, 9] # 设备配置 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -32,49 +33,60 @@ def prepare_data(): transforms.ToTensor() ]) - # Load datasets - dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform) - dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform) - dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform) + # 加载所有数据集(训练、验证、测试) + dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/JY_A',Cache='RAM') + dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_A',Cache='RAM') + dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_B',Cache='RAM') - # Assign datasets to clients - client_datasets = [dataset_B, dataset_C] + # 组织客户端数据集 + client_datasets = [ + { # Client 0 + 'train': dataset_B_train, + 'val': dataset_B_val, + 'test': dataset_B_test + }, + { # Client 1 + 'train': dataset_C_train, + 'val': dataset_C_val, + 'test': dataset_C_test + } + ] - # Server dataset (A) for public updates - public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True) + # 公共数据集(使用A的训练集) + public_loader = dataset_A_train - return client_datasets, public_loader + # 服务器测试集(使用A的测试集) + server_test_loader = dataset_A_test + + return client_datasets, public_loader, server_test_loader # 客户端训练函数 -def client_train(client_model, server_model, dataset): +def client_train(client_model, server_model, loader): client_model.train() server_model.eval() optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1) - loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - - # 训练进度条 - progress_bar = tqdm(total=CLIENT_EPOCHS*len(loader), - desc="Client Training", - unit="batch") - + for epoch in range(CLIENT_EPOCHS): epoch_loss = 0.0 task_loss = 0.0 distill_loss = 0.0 correct = 0 total = 0 + + # 训练进度条 + progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{CLIENT_EPOCHS}") - for batch_idx, (data, target) in enumerate(loader): + for batch_idx, (data, target) in enumerate(progress_bar): data, target = data.to(device), target.to(device) optimizer.zero_grad() # 前向传播 - client_output = client_model(data) + client_output = client_model(data).to(device) # 获取教师模型输出 with torch.no_grad(): - server_output = server_model(data) + server_output = server_model(data).to(device) # 计算损失 loss_task = F.cross_entropy(client_output, target) @@ -166,9 +178,10 @@ def server_update(server_model, client_models, public_loader): optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001) total_loss = 0.0 + progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch") - for batch_idx, (data, _) in enumerate(progress_bar): + for batch_idx, (data, target) in enumerate(progress_bar): data = data.to(device) optimizer.zero_grad() @@ -199,97 +212,97 @@ def server_update(server_model, client_models, public_loader): print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n") -def test_model(model, test_loader): +def test_model(model, test_loader): # 添加对DataLoader的支持 model.eval() correct = 0 total = 0 with torch.no_grad(): - for data, target in test_loader: + progress_bar = tqdm(test_loader, desc="Server Updating", unit="batch") + + for batch_idx, (data, target) in enumerate(progress_bar): data, target = data.to(device), target.to(device) output = model(data) _, predicted = torch.max(output.data, 1) total += target.size(0) correct += (predicted == target).sum().item() - accuracy = 100 * correct / total - return accuracy + return 100 * correct / total # 主训练流程 def main(): - transform = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor() - ]) - # Initialize models + # 初始化模型(保持不变) global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device) client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)] - # Prepare data - client_datasets, public_loader = prepare_data() - - # Test dataset (using dataset A's test set for simplicity) - test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform) - test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False) + # 加载数据集 + client_datasets, public_loader, server_test_loader = prepare_data() round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round") for round in range(NUM_ROUNDS): print(f"\n{'#'*50}") - print(f"Federated Round {round+1}/{NUM_ROUNDS}") + print(f"Round {round+1}/{NUM_ROUNDS}") print(f"{'#'*50}") - # Client selection (only 2 clients) + # 客户端选择 selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False) - print(f"Selected Clients: {selected_clients}") + print(f"Selected clients: {selected_clients}") - # Client local training + # 客户端训练 client_params = [] for cid in selected_clients: print(f"\nTraining Client {cid}") local_model = copy.deepcopy(client_models[cid]) local_model.load_state_dict(client_models[cid].state_dict()) - updated_params = client_train(local_model, global_server_model, client_datasets[cid]) + + # 传入客户端的训练集 + updated_params = client_train( + local_model, + global_server_model, + client_datasets[cid]['train'] # 使用训练集 + ) client_params.append(updated_params) - # Model aggregation + # 模型聚合 global_client_params = aggregate(client_params) for model in client_models: model.load_state_dict(global_client_params) - # Server knowledge update - print("\nServer Updating...") + # 服务器更新 + print("\nServer updating...") server_update(global_server_model, client_models, public_loader) - # Test model performance - server_acc = test_model(global_server_model, test_loader) - client_acc = test_model(client_models[0], test_loader) - print(f"\nRound {round+1} Performance:") - print(f"Global Model Accuracy: {server_acc:.2f}%") - print(f"Client Model Accuracy: {client_acc:.2f}%") + # 测试性能 + server_acc = test_model(global_server_model, server_test_loader) + client_accuracies = [ + test_model(client_models[i], + client_datasets[i]['test']) # 动态创建测试loader + for i in range(NUM_CLIENTS) + ] + + print(f"\nRound {round+1} Results:") + print(f"Server Accuracy: {server_acc:.2f}%") + for i, acc in enumerate(client_accuracies): + print(f"Client {i} Accuracy: {acc:.2f}%") round_progress.update(1) - print(f"Round {round+1} completed") - print("Training completed!") - - # Save trained models + # 保存模型 torch.save(global_server_model.state_dict(), "server_model.pth") for i in range(NUM_CLIENTS): - torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth") - print("Models saved successfully.") + torch.save(client_models[i].state_dict(), f"client{i}_model.pth") - # Test server model + # 最终测试 + print("\nFinal Evaluation:") server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device) server_model.load_state_dict(torch.load("server_model.pth",weights_only=True)) - server_acc = test_model(server_model, test_loader) - print(f"Server Model Test Accuracy: {server_acc:.2f}%") + print(f"Server Accuracy: {test_model(server_model, server_test_loader):.2f}%") - # Test client model for i in range(NUM_CLIENTS): client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) - client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True)) - client_acc = test_model(client_model, test_loader) - print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%") + client_model.load_state_dict(torch.load(f"client{i}_model.pth",weights_only=True)) + test_loader = client_datasets[i]['test'] + print(f"Client {i} Accuracy: {test_model(client_model, test_loader):.2f}%") if __name__ == "__main__": main() \ No newline at end of file diff --git a/data_loader.py b/data_loader.py index a06ac22..b11b718 100644 --- a/data_loader.py +++ b/data_loader.py @@ -5,8 +5,11 @@ import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms +from tqdm import tqdm # 导入 tqdm +import logging + class ImageClassificationDataset(Dataset): - def __init__(self, root_dir, transform=None,Cache=False): + def __init__(self, root_dir, transform=None, Cache=False): self.root_dir = root_dir self.transform = transform self.classes = sorted(os.listdir(root_dir)) @@ -20,7 +23,8 @@ class ImageClassificationDataset(Dataset): "init the dataloader" ) - for cls_name in self.classes: + # 使用 tqdm 显示进度条 + for cls_name in tqdm(self.classes, desc="Loading images"): cls_dir = os.path.join(root_dir, cls_name) for img_name in os.listdir(cls_dir): try: @@ -33,11 +37,8 @@ class ImageClassificationDataset(Dataset): else: self.image_paths.append(img_path) self.labels.append(self.class_to_idx[cls_name]) - except: - logger.log("info", - "read image error " + - img_path - ) + except Exception as e: + logger.log("info", f"Read image error: {img_path} - {e}") def __len__(self): return len(self.labels) @@ -46,12 +47,11 @@ class ImageClassificationDataset(Dataset): label = self.labels[idx] if self.Cache == 'RAM': image = self.image[idx] - else: + else: img_path = self.image_paths[idx] image = Image.open(img_path).convert('RGB') if self.transform: image = self.transform(image) - return image, label def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False): diff --git a/dataset/split.py b/dataset/split.py index f58a46f..568f797 100644 --- a/dataset/split.py +++ b/dataset/split.py @@ -2,9 +2,8 @@ import os import shutil import random -def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15): +def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1): # 确保比例总和为1 - assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must sum to 1" # 创建输出目录 os.makedirs(output_dir, exist_ok=True) @@ -55,6 +54,6 @@ def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15, print("Dataset successfully split into train, validation, and test sets.") # 使用示例 -base_directory = 'F:/dataset/02.TA_EC/EC27/JY_A' -output_directory = 'F:/dataset/02.TA_EC/datasets/EC27' +base_directory = 'L:/Grade_datasets/SPLIT/JY_A' +output_directory = 'L:/Grade_datasets/train/JY_A' create_dataset_splits(base_directory, output_directory) \ No newline at end of file