"""Federated learning with two-way knowledge distillation on MNIST.

Small client models are trained locally on label-skewed partitions while
distilling from a larger server model; their weights are averaged (FedAvg),
and the server model is then distilled from the aggregated clients on an
unlabeled "public" batch stream.
"""

import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

# Configuration
NUM_CLIENTS = 10
NUM_ROUNDS = 3
CLIENT_EPOCHS = 2
BATCH_SIZE = 32
TEMP = 2.0  # distillation temperature
CLIENTS_PER_ROUND = 5  # number of clients sampled in each round

# Device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ServerModel(nn.Module):
    """Large central model: 784 -> 512 -> 256 -> 10 fully connected network."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        """Flatten the input to 784 features per row and return 10 logits."""
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class ClientModel(nn.Module):
    """Small on-device model: 784 -> 64 -> 10 fully connected network."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Flatten the input to 784 features per row and return 10 logits."""
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


def prepare_data(num_clients):
    """Split the MNIST training set into label-skewed (non-IID) client subsets.

    The samples of each digit are shuffled and cut into ``num_clients // 2``
    slices. Slice ``i`` of an even digit goes to client ``2 * i`` and slice
    ``i`` of an odd digit goes to client ``2 * i + 1``. Consequently
    even-indexed clients only hold even digits and odd-indexed clients only
    hold odd digits (five classes each, not two).

    When ``num_clients`` is odd, the last client receives no samples.

    Args:
        num_clients: Number of client partitions to create (at least 2).

    Returns:
        A list of ``num_clients`` ``Subset`` objects over the training set.

    Raises:
        ValueError: If ``num_clients`` is smaller than 2.
    """
    groups = num_clients // 2
    if groups < 1:
        raise ValueError(f"num_clients must be at least 2, got {num_clients}")

    transform = transforms.Compose([transforms.ToTensor()])
    train_set = datasets.MNIST(
        "./data", train=True, download=True, transform=transform
    )

    client_data = {i: [] for i in range(num_clients)}
    labels = train_set.targets.numpy()

    for label in range(10):
        label_idx = np.where(labels == label)[0]
        np.random.shuffle(label_idx)
        for i, idx in enumerate(np.array_split(label_idx, groups)):
            # The parity of the digit selects the even or odd client of pair i.
            client_data[i * 2 + label % 2].extend(idx.tolist())

    return [Subset(train_set, ids) for ids in client_data.values()]


def client_train(client_model, server_model, dataset):
    """Train a client model locally with a task loss plus distillation loss.

    The client minimizes cross-entropy on its labels plus the
    temperature-scaled KL divergence towards the (frozen) server model's
    predictions on the same batch.

    Args:
        client_model: Model to train; it is updated in place.
        server_model: Teacher model; only used for inference.
        dataset: Dataset yielding ``(data, target)`` pairs.

    Returns:
        The ``state_dict`` of the trained client model.
    """
    client_model.train()
    server_model.eval()
    optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    for _ in range(CLIENT_EPOCHS):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            # Student (client) predictions.
            client_output = client_model(data)

            # Teacher (server) predictions; no gradient flows to the server.
            with torch.no_grad():
                server_output = server_model(data)

            # Joint loss: supervised term + distillation term. The TEMP**2
            # factor keeps the soft-target gradients on the same scale as the
            # hard-label gradients.
            loss_task = F.cross_entropy(client_output, target)
            loss_distill = F.kl_div(
                F.log_softmax(client_output / TEMP, dim=1),
                F.softmax(server_output / TEMP, dim=1),
                reduction="batchmean",
            ) * (TEMP ** 2)

            total_loss = loss_task + loss_distill
            total_loss.backward()
            optimizer.step()

    return client_model.state_dict()


def aggregate(client_params):
    """Average client state dicts entry by entry (unweighted FedAvg).

    Args:
        client_params: Non-empty list of state dicts sharing the same keys.

    Returns:
        A dict mapping each key to the element-wise mean (as float tensors)
        of that entry across all clients.

    Raises:
        ValueError: If ``client_params`` is empty.
    """
    if not client_params:
        raise ValueError("aggregate() requires at least one client state dict")

    return {
        key: torch.stack(
            [params[key].float() for params in client_params]
        ).mean(dim=0)
        for key in client_params[0]
    }


def server_update(server_model, client_models, public_loader):
    """Distill the client ensemble into the server model on public data.

    For every public batch, the per-sample logits of all client models are
    averaged across clients and used as soft targets for the server model.
    Labels from ``public_loader`` are ignored.

    Args:
        server_model: Model to update in place.
        client_models: Non-empty sequence of teacher models.
        public_loader: Iterable yielding ``(data, _)`` batches.
    """
    server_model.train()
    # Teachers are used for inference only, mirroring client_train().
    for model in client_models:
        model.eval()

    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)

    for data, _ in public_loader:
        data = data.to(device)
        optimizer.zero_grad()

        # Average the client logits across clients while keeping the batch
        # dimension, so each sample gets its own soft target. (Averaging over
        # the batch dimension instead would teach the server one identical
        # distribution for every input.)
        with torch.no_grad():
            client_logits = torch.stack([model(data) for model in client_models])
            soft_targets = client_logits.mean(dim=0)

        # Distillation step.
        server_output = server_model(data)
        loss = F.kl_div(
            F.log_softmax(server_output, dim=1),
            F.softmax(soft_targets, dim=1),
            reduction="batchmean",
        )
        loss.backward()
        optimizer.step()


def test_model(model, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` in percent.

    Args:
        model: Model to evaluate; it is switched to eval mode.
        test_loader: Iterable yielding ``(data, target)`` batches.

    Returns:
        Accuracy as a float in ``[0, 100]``; ``0.0`` if the loader yields no
        samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    if total == 0:
        return 0.0
    return 100 * correct / total


def main():
    """Run the federated distillation rounds, save both models and test them."""
    # Initialize models.
    global_server_model = ServerModel().to(device)
    client_models = [ClientModel().to(device) for _ in range(NUM_CLIENTS)]

    # Prepare data.
    client_datasets = prepare_data(NUM_CLIENTS)
    # NOTE(review): the "public" distillation data is the MNIST test split,
    # which is also used for the final evaluation below. Labels are not used
    # during distillation, but the reported accuracy is measured on inputs
    # the server model has already seen. Consider a held-out split.
    public_loader = DataLoader(
        datasets.MNIST(
            "./data",
            train=False,
            download=True,
            transform=transforms.ToTensor(),
        ),
        batch_size=100,
        shuffle=True,
    )

    # Never request more clients than exist.
    num_selected = min(CLIENTS_PER_ROUND, NUM_CLIENTS)

    for round_idx in range(NUM_ROUNDS):
        # Client selection.
        selected_clients = np.random.choice(
            NUM_CLIENTS, num_selected, replace=False
        )

        # Local training on each selected client.
        client_params = []
        for cid in selected_clients:
            # "Download" the current client model; deepcopy already carries
            # the weights, so no extra load_state_dict is needed.
            local_model = copy.deepcopy(client_models[cid])
            updated_params = client_train(
                local_model, global_server_model, client_datasets[cid]
            )
            client_params.append(updated_params)

        # Aggregate and broadcast the averaged weights to every client.
        global_client_params = aggregate(client_params)
        for model in client_models:
            model.load_state_dict(global_client_params)

        # Update the server model from the aggregated clients.
        server_update(global_server_model, client_models, public_loader)

        print(f"Round {round_idx + 1} completed")

    print("Training completed!")

    # Save the trained models.
    torch.save(global_server_model.state_dict(), "server_model.pth")
    torch.save(client_models[0].state_dict(), "client_model.pth")
    print("Models saved successfully.")

    # Build the test loader.
    test_dataset = datasets.MNIST(
        "./data", train=False, transform=transforms.ToTensor()
    )
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

    # Evaluate the server model. map_location keeps the checkpoint loadable
    # on a machine without the device it was saved from.
    server_model = ServerModel().to(device)
    server_model.load_state_dict(
        torch.load("server_model.pth", map_location=device)
    )
    server_acc = test_model(server_model, test_loader)
    print(f"Server Model Test Accuracy: {server_acc:.2f}%")

    # Evaluate the client model.
    client_model = ClientModel().to(device)
    client_model.load_state_dict(
        torch.load("client_model.pth", map_location=device)
    )
    client_acc = test_model(client_model, test_loader)
    print(f"Client Model Test Accuracy: {client_acc:.2f}%")


if __name__ == "__main__":
    main()