import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
from tqdm import tqdm
from model.repvit import repvit_m1_1
from model.mobilenetv3 import MobileNetV3

# Configuration
NUM_CLIENTS = 2        # number of federated clients
NUM_ROUNDS = 10        # federated communication rounds
CLIENT_EPOCHS = 2      # local epochs per client per round
BATCH_SIZE = 32
TEMP = 2.0             # knowledge-distillation temperature
CLASS_NUM = [3, 3, 3]  # class counts: [server dataset A, client dataset B, client dataset C]

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data preparation
import os
from torchvision.datasets import ImageFolder


def prepare_data():
    """Load the three ImageFolder datasets and split roles.

    Returns:
        tuple: (client_datasets, public_loader) where client_datasets is a
        list of the two client training datasets (B and C) and public_loader
        is a DataLoader over dataset A, used by the server for public
        knowledge updates.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    # Load datasets
    dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform)
    dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform)
    dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform)
    # Assign datasets to clients
    client_datasets = [dataset_B, dataset_C]
    # Server dataset (A) for public updates
    public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True)
    return client_datasets, public_loader


def client_train(client_model, server_model, dataset):
    """Train one client locally with task loss + distillation from the server.

    The server model acts as the (frozen) teacher; the client minimizes
    cross-entropy on its labels plus a temperature-scaled KL distillation
    term against the server's soft outputs.

    Args:
        client_model: student model, trained in place.
        server_model: teacher model, used in eval mode only.
        dataset: the client's local training dataset.

    Returns:
        dict: the trained client model's state_dict.
    """
    client_model.train()
    server_model.eval()
    optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Training progress bar
    progress_bar = tqdm(total=CLIENT_EPOCHS*len(loader), desc="Client Training", unit="batch")
    for epoch in range(CLIENT_EPOCHS):
        epoch_loss = 0.0
        task_loss = 0.0
        distill_loss = 0.0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            # Forward pass
            client_output = client_model(data)
            # Teacher outputs (no gradient through the server model)
            with torch.no_grad():
                server_output = server_model(data)
            # Hard-label task loss + soft-label distillation loss.
            # The TEMP**2 factor rescales KL gradients to match the
            # cross-entropy term (standard Hinton distillation).
            loss_task = F.cross_entropy(client_output, target)
            loss_distill = F.kl_div(
                F.log_softmax(client_output/TEMP, dim=1),
                F.softmax(server_output/TEMP, dim=1),
                reduction="batchmean"
            ) * (TEMP**2)
            total_loss = loss_task + loss_distill
            # Backward pass
            total_loss.backward()
            optimizer.step()
            # Accumulate statistics
            epoch_loss += total_loss.item()
            task_loss += loss_task.item()
            distill_loss += loss_distill.item()
            _, predicted = torch.max(client_output.data, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)
            # Live progress-bar update
            progress_bar.set_postfix({
                "Epoch": f"{epoch+1}/{CLIENT_EPOCHS}",
                "Batch": f"{batch_idx+1}/{len(loader)}",
                "Loss": f"{total_loss.item():.4f}",
                "Acc": f"{100*correct/total:.2f}%\n",
            })
            progress_bar.update(1)
        # Per-epoch summary
        avg_loss = epoch_loss / len(loader)
        avg_task = task_loss / len(loader)
        avg_distill = distill_loss / len(loader)
        epoch_acc = 100 * correct / total
        print(f"\n{'='*40}")
        print(f"Epoch {epoch+1} Summary:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Task Loss: {avg_task:.4f}")
        print(f"Distill Loss: {avg_distill:.4f}")
        print(f"Training Accuracy: {epoch_acc:.2f}%")
        print(f"{'='*40}\n")
    progress_bar.close()
    return client_model.state_dict()


def aggregate(client_params):
    """FedAvg aggregation: element-wise mean of client state_dicts.

    Args:
        client_params: non-empty list of state_dicts with identical keys.

    Returns:
        dict: averaged state_dict.
    """
    global_params = {}
    for key in client_params[0].keys():
        global_params[key] = torch.stack(
            [param[key].float() for param in client_params]
        ).mean(dim=0)
    return global_params


def server_aggregate(server_model, client_models, public_loader):
    """Align server features with the mean of client features (feature distillation).

    NOTE(review): this function is not called anywhere in this file, and it
    relies on an `extract_features` method that both model classes must
    implement — verify before wiring it in.
    """
    server_model.train()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    # BUG FIX: total_loss was used below without ever being initialized,
    # which raised NameError on the first batch.
    total_loss = 0.0
    for data, _ in public_loader:
        data = data.to(device)
        # Collect client features (teachers are frozen here)
        client_features = []
        with torch.no_grad():
            for model in client_models:
                features = model.extract_features(data)  # requires extract_features()
                client_features.append(features)
        # Distillation target: mean feature across clients
        target_features = torch.stack(client_features).mean(dim=0)
        # Server forward pass
        server_features = server_model.extract_features(data)
        # Feature-alignment loss
        loss = F.mse_loss(server_features, target_features)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate statistics
        total_loss += loss.item()


def server_update(server_model, client_models, public_loader):
    """Distill averaged client knowledge into the server model on public data.

    The server minimizes KL divergence between its log-softmax outputs and
    the softmax of the clients' averaged logits over the public dataset.

    Args:
        server_model: global server model, trained in place.
        client_models: list of (frozen) client models acting as teachers.
        public_loader: DataLoader over the public dataset.
    """
    server_model.train()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    total_loss = 0.0
    progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")
    for batch_idx, (data, _) in enumerate(progress_bar):
        data = data.to(device)
        optimizer.zero_grad()
        # Average the client models' outputs to form soft targets.
        # BUG FIX: the original applied .mean(dim=0, keepdim=True) to each
        # model's output, averaging over the *batch* dimension and collapsing
        # all per-sample targets into a single broadcast vector. Keep the
        # per-sample logits and average across models only.
        with torch.no_grad():
            client_outputs = [model(data) for model in client_models]
            soft_targets = torch.stack(client_outputs).mean(dim=0)
        # Distillation step (no temperature scaling here, matching original)
        server_output = server_model(data)
        loss = F.kl_div(
            F.log_softmax(server_output, dim=1),
            F.softmax(soft_targets, dim=1),
            reduction="batchmean"
        )
        # Backward pass
        loss.backward()
        optimizer.step()
        # Accumulate statistics
        total_loss += loss.item()
        progress_bar.set_postfix({
            "Avg Loss": f"{total_loss/(batch_idx+1):.4f}",
            "Current Loss": f"{loss.item():.4f}"
        })
    print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")


def test_model(model, test_loader):
    """Evaluate classification accuracy (%) of `model` on `test_loader`."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = 100 * correct / total
    return accuracy


def main():
    """Run the full federated distillation pipeline.

    Per round: selected clients train locally (distilling from the server),
    client weights are FedAvg-aggregated, then the server distills from the
    updated clients on the public dataset. Models are saved and re-tested
    at the end.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    # Initialize models
    global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
    client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
                     for i in range(NUM_CLIENTS)]
    # Prepare data
    client_datasets, public_loader = prepare_data()
    # Test dataset (using dataset A's test set for simplicity)
    test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

    round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
    # `round_num` instead of `round`: avoid shadowing the builtin.
    for round_num in range(NUM_ROUNDS):
        print(f"\n{'#'*50}")
        print(f"Federated Round {round_num+1}/{NUM_ROUNDS}")
        print(f"{'#'*50}")

        # Client selection (with NUM_CLIENTS == 2 this always selects both)
        selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
        print(f"Selected Clients: {selected_clients}")

        # Client local training
        client_params = []
        for cid in selected_clients:
            print(f"\nTraining Client {cid}")
            # deepcopy already clones the weights; no extra load_state_dict needed
            local_model = copy.deepcopy(client_models[cid])
            updated_params = client_train(local_model, global_server_model,
                                          client_datasets[cid])
            client_params.append(updated_params)

        # Model aggregation (FedAvg) broadcast back to every client
        global_client_params = aggregate(client_params)
        for model in client_models:
            model.load_state_dict(global_client_params)

        # Server knowledge update
        print("\nServer Updating...")
        server_update(global_server_model, client_models, public_loader)

        # Test model performance
        server_acc = test_model(global_server_model, test_loader)
        client_acc = test_model(client_models[0], test_loader)
        print(f"\nRound {round_num+1} Performance:")
        print(f"Global Model Accuracy: {server_acc:.2f}%")
        print(f"Client Model Accuracy: {client_acc:.2f}%")
        round_progress.update(1)
        print(f"Round {round_num+1} completed")
    round_progress.close()
    print("Training completed!")

    # Save trained models
    torch.save(global_server_model.state_dict(), "server_model.pth")
    for i in range(NUM_CLIENTS):
        torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth")
    print("Models saved successfully.")

    # Test server model (reloaded from disk as a sanity check)
    server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
    server_model.load_state_dict(torch.load("server_model.pth", weights_only=True))
    server_acc = test_model(server_model, test_loader)
    print(f"Server Model Test Accuracy: {server_acc:.2f}%")

    # Test client models (reloaded from disk)
    for i in range(NUM_CLIENTS):
        client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
        client_model.load_state_dict(
            torch.load("client"+str(i)+"_model.pth", weights_only=True))
        client_acc = test_model(client_model, test_loader)
        print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%")


if __name__ == "__main__":
    main()