From 382916643bd41e1ec202f19d8088ab5a91a30872 Mon Sep 17 00:00:00 2001 From: yoiannis Date: Wed, 12 Mar 2025 00:21:31 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E8=81=94=E9=82=A6=E8=92=B8?= =?UTF-8?q?=E9=A6=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- FED.py | 75 +++++++++++++++++++++++++++++++------------- model/mobilenetv3.py | 5 +++ model/repvit.py | 4 +++ 3 files changed, 62 insertions(+), 22 deletions(-) diff --git a/FED.py b/FED.py index b1b0d13..a727ca0 100644 --- a/FED.py +++ b/FED.py @@ -13,10 +13,11 @@ from model.mobilenetv3 import MobileNetV3 # 配置参数 NUM_CLIENTS = 2 -NUM_ROUNDS = 3 -CLIENT_EPOCHS = 5 +NUM_ROUNDS = 10 +CLIENT_EPOCHS = 2 BATCH_SIZE = 32 TEMP = 2.0 # 蒸馏温度 +CLASS_NUM = [3, 3, 3] # 设备配置 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -32,9 +33,9 @@ def prepare_data(): ]) # Load datasets - dataset_A = ImageFolder(root='./dataset_A/train', transform=transform) - dataset_B = ImageFolder(root='./dataset_B/train', transform=transform) - dataset_C = ImageFolder(root='./dataset_C/train', transform=transform) + dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform) + dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform) + dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform) # Assign datasets to clients client_datasets = [dataset_B, dataset_C] @@ -105,13 +106,6 @@ def client_train(client_model, server_model, dataset): }) progress_bar.update(1) - # 每10个batch打印详细信息 - if (batch_idx + 1) % 10 == 0: - progress_bar.write(f"\nEpoch {epoch+1} | Batch {batch_idx+1}") - progress_bar.write(f"Task Loss: {loss_task:.4f}") - progress_bar.write(f"Distill Loss: {loss_distill:.4f}") - progress_bar.write(f"Total Loss: {total_loss:.4f}") - progress_bar.write(f"Batch Accuracy: {100*correct/total:.2f}%\n") # 每个epoch结束打印汇总信息 avg_loss = epoch_loss / len(loader) avg_task = task_loss / len(loader) @@ -135,6 +129,37 @@ def aggregate(client_params): global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0) return global_params +def server_aggregate(server_model, client_models, public_loader): + server_model.train() + optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001) + + for data, _ in public_loader: + data = data.to(device) + + # 获取客户端模型特征 + client_features = [] + with torch.no_grad(): + for model in client_models: + features = model.extract_features(data) # 需要实现特征提取方法 + client_features.append(features) + + # 计算特征蒸馏目标 + target_features = torch.stack(client_features).mean(dim=0) + + # 服务器前向 + server_features = server_model.extract_features(data) + + # 特征对齐损失 + loss = F.mse_loss(server_features, target_features) + + # 反向传播 + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # 更新统计信息 + total_loss += loss.item() + # 服务器知识更新 def server_update(server_model, client_models, public_loader): server_model.train() @@ -191,15 +216,19 @@ def test_model(model, test_loader): # 主训练流程 def main(): + transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor() + ]) # Initialize models - global_server_model = repvit_m1_1(num_classes=10).to(device) - client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)] + global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device) + client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)] # Prepare data client_datasets, public_loader = prepare_data() # Test dataset (using dataset A's test set for simplicity) - test_dataset = ImageFolder(root='./dataset_A/test', transform=transform) + test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform) test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False) round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round") @@ -245,20 +274,22 @@ def main(): # Save trained models torch.save(global_server_model.state_dict(), "server_model.pth") - torch.save(client_models[0].state_dict(), "client_model.pth") + for i in range(NUM_CLIENTS): + torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth") print("Models saved successfully.") # Test server model - server_model = repvit_m1_1(num_classes=10).to(device) - server_model.load_state_dict(torch.load("server_model.pth")) + server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device) + server_model.load_state_dict(torch.load("server_model.pth",weights_only=True)) server_acc = test_model(server_model, test_loader) print(f"Server Model Test Accuracy: {server_acc:.2f}%") # Test client model - client_model = MobileNetV3(n_class=10).to(device) - client_model.load_state_dict(torch.load("client_model.pth")) - client_acc = test_model(client_model, test_loader) - print(f"Client Model Test Accuracy: {client_acc:.2f}%") + for i in range(NUM_CLIENTS): + client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) + client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True)) + client_acc = test_model(client_model, test_loader) + print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%") if __name__ == "__main__": main() \ No newline at end of file diff --git a/model/mobilenetv3.py b/model/mobilenetv3.py index 4692cf9..2de909a 100644 --- a/model/mobilenetv3.py +++ b/model/mobilenetv3.py @@ -200,6 +200,11 @@ class MobileNetV3(nn.Module): self._initialize_weights() + + def extract_features(self, x): + x = self.features(x) + return x + def forward(self, x): x = self.features(x) x = x.mean(3).mean(2) diff --git a/model/repvit.py b/model/repvit.py index 78197f8..27b07b9 100644 --- a/model/repvit.py +++ b/model/repvit.py @@ -236,6 +236,10 @@ class RepViT(nn.Module): self.features = nn.ModuleList(layers) self.classifier = Classfier(output_channel, num_classes, distillation) + def extract_features(self, x): + for f in self.features: + x = f(x) + return x def forward(self, x): # x = self.features(x) for f in self.features: