diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..1269488
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+data
diff --git a/FED.py b/FED.py
new file mode 100644
index 0000000..a62b3e9
--- /dev/null
+++ b/FED.py
@@ -0,0 +1,212 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import datasets, transforms
+from torch.utils.data import DataLoader, Subset
+import numpy as np
+import copy
+
+# Configuration
+NUM_CLIENTS = 10
+NUM_ROUNDS = 3
+CLIENT_EPOCHS = 2
+BATCH_SIZE = 32
+TEMP = 2.0  # distillation temperature
+
+# Device configuration
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Central (server-side) large model
+class ServerModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(784, 512)
+        self.fc2 = nn.Linear(512, 256)
+        self.fc3 = nn.Linear(256, 10)
+
+    def forward(self, x):
+        x = x.view(-1, 784)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+# On-device (client-side) small model
+class ClientModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(784, 64)
+        self.fc2 = nn.Linear(64, 10)
+
+    def forward(self, x):
+        x = x.view(-1, 784)
+        x = F.relu(self.fc1(x))
+        return self.fc2(x)
+
+# Data preparation
+def prepare_data(num_clients):
+    transform = transforms.Compose([transforms.ToTensor()])
+    train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
+
+    # Non-IID split: even-indexed clients receive only even digits and
+    # odd-indexed clients only odd digits, i.e. 5 classes per client.
+    client_data = {i: [] for i in range(num_clients)}
+    labels = train_set.targets.numpy()
+    for label in range(10):
+        label_idx = np.where(labels == label)[0]
+        np.random.shuffle(label_idx)
+        split = np.array_split(label_idx, num_clients // 2)
+        for i, idx in enumerate(split):
+            client_data[i * 2 + label % 2].extend(idx)
+
+    return [Subset(train_set, ids) for ids in client_data.values()]
+
+# Client-side local training, distilling from the server model
+def client_train(client_model, server_model, dataset):
+    client_model.train()
+    server_model.eval()
+
+    optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
+    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
+
+    for _ in range(CLIENT_EPOCHS):
+        for data, target in loader:
+            data, target = data.to(device), target.to(device)
+            optimizer.zero_grad()
+
+            # Small-model (client) forward pass
+            client_output = client_model(data)
+
+            # Large-model (server) forward pass for distillation; no gradients needed
+            with torch.no_grad():
+                server_output = server_model(data)
+
+            # Joint loss: task cross-entropy plus temperature-scaled KL distillation;
+            # the TEMP**2 factor keeps the soft-target gradient scale comparable
+            loss_task = F.cross_entropy(client_output, target)
+            loss_distill = F.kl_div(
+                F.log_softmax(client_output / TEMP, dim=1),
+                F.softmax(server_output / TEMP, dim=1),
+                reduction="batchmean"
+            ) * (TEMP ** 2)
+
+            total_loss = loss_task + loss_distill
+            total_loss.backward()
+            optimizer.step()
+
+    return client_model.state_dict()
+
+# Parameter aggregation (FedAvg, unweighted mean across clients)
+def aggregate(client_params):
+    global_params = {}
+    for key in client_params[0].keys():
+        global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0)
+    return global_params
+
+# Server-side knowledge update: distill the client ensemble into the server model
+def server_update(server_model, client_models, public_loader):
+    server_model.train()
+    for model in client_models:
+        model.eval()
+    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
+
+    for data, _ in public_loader:
+        data = data.to(device)
+        optimizer.zero_grad()
+
+        # Per-example soft targets: average the client logits for each sample,
+        # so every example keeps its own target distribution
+        with torch.no_grad():
+            client_outputs = [model(data) for model in client_models]
+            soft_targets = torch.stack(client_outputs).mean(dim=0)
+
+        # Distillation step
+        server_output = server_model(data)
+        loss = F.kl_div(
+            F.log_softmax(server_output, dim=1),
+            F.softmax(soft_targets, dim=1),
+            reduction="batchmean"
+        )
+
+        loss.backward()
+        optimizer.step()
+
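+# A minimal sketch (not wired into main() below): classic FedAvg weights each
+# client's parameters by its local sample count, rather than the plain mean
+# used by aggregate() above. `num_samples` is a hypothetical list with one
+# count per participating client; nothing in this file computes it yet.
+def aggregate_weighted(client_params, num_samples):
+    total = float(sum(num_samples))
+    global_params = {}
+    for key in client_params[0].keys():
+        global_params[key] = torch.stack([
+            params[key].float() * (n / total)
+            for params, n in zip(client_params, num_samples)
+        ]).sum(dim=0)
+    return global_params
+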
+def test_model(model, test_loader):
+    model.eval()
+    correct = 0
+    total = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            predicted = output.argmax(dim=1)
+            total += target.size(0)
+            correct += (predicted == target).sum().item()
+    accuracy = 100 * correct / total
+    return accuracy
+
+# Main training flow
+def main():
+    # Initialize models
+    global_server_model = ServerModel().to(device)
+    client_models = [ClientModel().to(device) for _ in range(NUM_CLIENTS)]
+
+    # Prepare data. NOTE: the MNIST test split doubles as the "public" set
+    # used for server distillation, so the accuracies reported below are
+    # optimistic; a separate unlabeled pool would be cleaner.
+    client_datasets = prepare_data(NUM_CLIENTS)
+    public_loader = DataLoader(
+        datasets.MNIST("./data", train=False, download=True,
+                       transform=transforms.ToTensor()),
+        batch_size=100, shuffle=True)
+
+    for round_num in range(NUM_ROUNDS):
+        # Client selection (5 of the 10 clients per round)
+        selected_clients = np.random.choice(NUM_CLIENTS, 5, replace=False)
+
+        # Local training on the selected clients
+        client_params = []
+        for cid in selected_clients:
+            # "Download" the current global client model for local training
+            local_model = copy.deepcopy(client_models[cid])
+
+            # Local training
+            updated_params = client_train(
+                local_model,
+                global_server_model,
+                client_datasets[cid]
+            )
+            client_params.append(updated_params)
+
+        # Aggregate the client updates and broadcast to all clients
+        global_client_params = aggregate(client_params)
+        for model in client_models:
+            model.load_state_dict(global_client_params)
+
+        # Server-side knowledge update
+        server_update(global_server_model, client_models, public_loader)
+
+        print(f"Round {round_num+1} completed")
+
+    print("Training completed!")
+
+    # Save the trained models
+    torch.save(global_server_model.state_dict(), "server_model.pth")
+    torch.save(client_models[0].state_dict(), "client_model.pth")
+    print("Models saved successfully.")
+
+    # Build the test loader
+    test_dataset = datasets.MNIST(
+        "./data",
+        train=False,
+        transform=transforms.ToTensor()
+    )
+    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
+
+    # Evaluate the server model
+    server_model = ServerModel().to(device)
+    server_model.load_state_dict(torch.load("server_model.pth", map_location=device))
+    server_acc = test_model(server_model, test_loader)
+    print(f"Server Model Test Accuracy: {server_acc:.2f}%")
+
+    # Evaluate the client model
+    client_model = ClientModel().to(device)
+    client_model.load_state_dict(torch.load("client_model.pth", map_location=device))
+    client_acc = test_model(client_model, test_loader)
+    print(f"Client Model Test Accuracy: {client_acc:.2f}%")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/main.py b/main.py
index e69de29..4f88c47 100644
--- a/main.py
+++ b/main.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+
+# Device configuration
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Hyperparameters
+num_epochs_teacher = 10  # teacher training epochs
+num_epochs_student = 20  # student training epochs
+batch_size = 64
+learning_rate = 0.001
+temperature = 5  # distillation temperature
+alpha = 0.3  # weight on the hard-label loss; (1 - alpha) weights the distillation term
+
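+# A minimal helper sketch: the combined distillation objective used inline in
+# the training loop below, factored into one function for reference. The T**2
+# factor keeps soft-target gradient magnitudes comparable across temperatures
+# (Hinton et al., 2015). Illustrative only; the loop below does not call it.
+def distillation_loss(student_logits, teacher_logits, labels, T, alpha):
+    hard = nn.functional.cross_entropy(student_logits, labels)
+    soft = nn.functional.kl_div(
+        nn.functional.log_softmax(student_logits / T, dim=1),
+        nn.functional.softmax(teacher_logits / T, dim=1),
+        reduction="batchmean"
+    ) * (T ** 2)
+    return alpha * hard + (1 - alpha) * soft
+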
+# Dataset preparation (CIFAR-10 as the running example)
+transform = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+])
+
+# Assumption:
+# dataset A is the first 25,000 images of the CIFAR-10 training set (teacher data)
+# dataset B is the last 25,000 images (student/distillation data)
+dataset_A = torchvision.datasets.CIFAR10(
+    root='./data', train=True, download=True, transform=transform)
+dataset_A = torch.utils.data.Subset(dataset_A, range(25000))
+
+dataset_B = torchvision.datasets.CIFAR10(
+    root='./data', train=True, download=True, transform=transform)
+dataset_B = torch.utils.data.Subset(dataset_B, range(25000, 50000))
+
+train_loader_A = DataLoader(dataset_A, batch_size=batch_size, shuffle=True)
+train_loader_B = DataLoader(dataset_B, batch_size=batch_size, shuffle=True)
+
+# Teacher model (ResNet-18)
+class TeacherModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.resnet = torchvision.models.resnet18(weights=None)  # from scratch; `pretrained=` is deprecated
+        self.resnet.fc = nn.Linear(512, 10)  # CIFAR-10 has 10 classes
+
+    def forward(self, x):
+        return self.resnet(x)
+
+# Student model (a much smaller CNN)
+class StudentModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 16, 3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2),
+            nn.Conv2d(16, 32, 3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2)
+        )
+        self.classifier = nn.Sequential(
+            nn.Linear(32 * 8 * 8, 128),
+            nn.ReLU(),
+            nn.Linear(128, 10)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.size(0), -1)
+        x = self.classifier(x)
+        return x
+
+# Train the teacher on dataset A
+teacher = TeacherModel().to(device)
+criterion = nn.CrossEntropyLoss()
+optimizer = optim.Adam(teacher.parameters(), lr=learning_rate)
+
+print("Training Teacher Model...")
+for epoch in range(num_epochs_teacher):
+    teacher.train()
+    total_loss = 0
+    for images, labels in train_loader_A:
+        images = images.to(device)
+        labels = labels.to(device)
+
+        outputs = teacher(images)
+        loss = criterion(outputs, labels)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        total_loss += loss.item()
+
+    print(f"Teacher Epoch [{epoch+1}/{num_epochs_teacher}], Loss: {total_loss / len(train_loader_A):.4f}")
+
+# Distill the teacher into the student on dataset B
+student = StudentModel().to(device)
+optimizer = optim.Adam(student.parameters(), lr=learning_rate)
+criterion_hard = nn.CrossEntropyLoss()
+criterion_soft = nn.KLDivLoss(reduction="batchmean")
+
+print("\nDistilling Knowledge to Student...")
+teacher.eval()  # keep the teacher in evaluation mode during distillation
+
+for epoch in range(num_epochs_student):
+    student.train()
+    total_loss = 0
+
+    for images, labels in train_loader_B:
+        images = images.to(device)
+        labels = labels.to(device)
+
+        # Teacher predictions (no gradients)
+        with torch.no_grad():
+            teacher_logits = teacher(images)
+
+        # Student predictions
+        student_logits = student(images)
+
+        # Hard-label loss
+        hard_loss = criterion_hard(student_logits, labels)
+
+        # Soft-label loss with temperature scaling; temperature**2 restores
+        # the gradient scale reduced by the softened distributions
+        soft_loss = criterion_soft(
+            nn.functional.log_softmax(student_logits / temperature, dim=1),
+            nn.functional.softmax(teacher_logits / temperature, dim=1)
+        ) * (temperature ** 2)
+
+        # Combined loss
+        loss = alpha * hard_loss + (1 - alpha) * soft_loss
+
+        # Backpropagation
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        total_loss += loss.item()
+
+    avg_loss = total_loss / len(train_loader_B)
+    print(f"Student Epoch [{epoch+1}/{num_epochs_student}], Loss: {avg_loss:.4f}")
+
+print("Knowledge distillation complete!")
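+
+# A minimal evaluation sketch (not in the original script): compare teacher
+# and student accuracy on the standard CIFAR-10 test split. The names
+# `test_set` and `evaluate` are illustrative additions.
+test_set = torchvision.datasets.CIFAR10(
+    root='./data', train=False, download=True, transform=transform)
+test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
+
+def evaluate(model):
+    model.eval()
+    correct, total = 0, 0
+    with torch.no_grad():
+        for images, labels in test_loader:
+            images, labels = images.to(device), labels.to(device)
+            preds = model(images).argmax(dim=1)
+            correct += (preds == labels).sum().item()
+            total += labels.size(0)
+    return 100.0 * correct / total
+
+print(f"Teacher Test Accuracy: {evaluate(teacher):.2f}%")
+print(f"Student Test Accuracy: {evaluate(student):.2f}%")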