TA_EC/FED.py
2025-03-04 12:55:40 +08:00

212 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
# Configuration parameters
NUM_CLIENTS = 10  # number of client models / data partitions created in main()
NUM_ROUNDS = 3  # number of federated rounds run by main()
CLIENT_EPOCHS = 2  # local passes over a client's data per call to client_train()
BATCH_SIZE = 32  # mini-batch size of the client-side DataLoader
TEMP = 2.0  # distillation temperature used to soften logits in client_train()
# Device configuration: use CUDA when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Central (server-side) large model
class ServerModel(nn.Module):
    """Three-layer MLP (784 -> 512 -> 256 -> 10) with ReLU activations.

    ``forward`` flattens its input into rows of 784 features and returns the
    raw outputs of the last linear layer (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay stable.
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return a ``(rows, 10)`` tensor of logits for ``x``."""
        hidden = x.view(-1, self.fc1.in_features)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)
# Edge (client-side) small model
class ClientModel(nn.Module):
    """Two-layer MLP (784 -> 64 -> 10) with a single ReLU hidden layer.

    ``forward`` flattens its input into rows of 784 features and returns the
    raw outputs of the last linear layer (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay stable.
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Return a ``(rows, 10)`` tensor of logits for ``x``."""
        flat = x.view(-1, self.fc1.in_features)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)
# Data preparation
def prepare_data(num_clients, train_set=None):
    """Split a labelled dataset into non-IID per-client subsets.

    Every client is assigned two classes: client ``c`` owns the classes at
    positions ``2c`` and ``2c + 1`` (wrapping around) of the sorted list of
    distinct labels. The samples of each class are shuffled with NumPy's
    global RNG and divided as evenly as possible among the clients that own
    that class. Classes owned by no client (only possible when
    ``2 * num_clients`` is smaller than the number of classes) are left out.

    Args:
        num_clients: Number of partitions to create; must be >= 1.
        train_set: Optional dataset exposing a ``targets`` sequence of labels.
            When omitted, the MNIST training split is downloaded to
            ``./data`` and used.

    Returns:
        A list of ``num_clients`` ``Subset`` objects over ``train_set``, in
        client-id order.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError("num_clients must be a positive integer")
    if train_set is None:
        transform = transforms.Compose([transforms.ToTensor()])
        train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)

    labels = np.asarray(train_set.targets)
    classes = np.unique(labels)
    num_classes = len(classes)
    # Two classes per client, unless the dataset has fewer than two classes.
    classes_per_client = min(2, num_classes)

    # owners[pos] lists the clients that receive samples of classes[pos].
    owners = [[] for _ in range(num_classes)]
    for cid in range(num_clients):
        for offset in range(classes_per_client):
            owners[(cid * classes_per_client + offset) % num_classes].append(cid)

    client_data = [[] for _ in range(num_clients)]
    for pos, label in enumerate(classes):
        if not owners[pos]:
            continue
        label_idx = np.where(labels == label)[0]
        np.random.shuffle(label_idx)
        shares = np.array_split(label_idx, len(owners[pos]))
        for cid, share in zip(owners[pos], shares):
            # Store plain Python ints so the indices work with any dataset.
            client_data[cid].extend(share.tolist())
    return [Subset(train_set, ids) for ids in client_data]
# Client-side training routine
def client_train(client_model, server_model, dataset):
    """Train ``client_model`` on ``dataset`` with the server model as teacher.

    For ``CLIENT_EPOCHS`` passes over ``dataset`` (shuffled mini-batches of
    ``BATCH_SIZE``) the client is optimised with SGD (lr=0.1) on the sum of
    the cross-entropy against the labels and the temperature-scaled KL
    divergence between the client's and the server's softened predictions.

    Args:
        client_model: Student model; trained in place and left in train mode.
        server_model: Teacher model; switched to eval mode and never updated.
        dataset: Dataset yielding ``(input, label)`` pairs.

    Returns:
        The ``state_dict`` of the trained ``client_model``.
    """
    client_model.train()
    server_model.eval()
    sgd = torch.optim.SGD(client_model.parameters(), lr=0.1)
    batches = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    # The T^2 factor scales the distillation term applied to the softened logits.
    temp_squared = TEMP**2

    for _epoch in range(CLIENT_EPOCHS):
        for inputs, labels in batches:
            inputs = inputs.to(device)
            labels = labels.to(device)
            sgd.zero_grad()

            # Student prediction (tracked by autograd).
            student_logits = client_model(inputs)
            # Teacher prediction; no gradient flows into the server model.
            with torch.no_grad():
                teacher_logits = server_model(inputs)

            # Joint objective: supervised loss + distillation loss.
            hard_loss = F.cross_entropy(student_logits, labels)
            student_log_probs = F.log_softmax(student_logits / TEMP, dim=1)
            teacher_probs = F.softmax(teacher_logits / TEMP, dim=1)
            soft_loss = (
                F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
                * temp_squared
            )

            joint_loss = hard_loss + soft_loss
            joint_loss.backward()
            sgd.step()

    return client_model.state_dict()
# Model parameter aggregation (FedAvg)
def aggregate(client_params):
    """Average a list of state dicts entry by entry.

    Args:
        client_params: Non-empty list of state dicts that share the keys of
            the first one.

    Returns:
        A dict with the keys of ``client_params[0]``; each value is the
        unweighted element-wise mean, computed in float32, of that entry
        across all clients.
    """
    reference = client_params[0]
    return {
        name: torch.mean(
            torch.stack([state[name].float() for state in client_params]),
            dim=0,
        )
        for name in reference
    }
# Server-side knowledge update
def server_update(server_model, client_models, public_loader):
    """Distil the client ensemble into ``server_model`` on unlabeled data.

    For every batch of ``public_loader`` the logits of all client models are
    averaged per sample, and the server is trained with Adam (lr=0.001) to
    minimise the KL divergence between its own softmax output and the softmax
    of that ensemble average. Labels in ``public_loader`` are ignored.

    Args:
        server_model: Student model; updated in place and left in train mode.
        client_models: Teacher models; switched to eval mode and not updated.
        public_loader: Iterable of ``(data, label)`` batches.
    """
    server_model.train()
    # Teachers only provide targets, so they run in inference mode.
    for teacher in client_models:
        teacher.eval()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    # Move batches to wherever the server model lives instead of relying on a
    # module-level device variable.
    model_device = next(server_model.parameters()).device

    for data, _ in public_loader:
        data = data.to(model_device)
        optimizer.zero_grad()

        # Per-sample ensemble logits, shape (batch, classes). Averaging over
        # the batch dimension here would hand every sample the same target
        # and erase all input-dependent information.
        with torch.no_grad():
            soft_targets = torch.stack(
                [teacher(data) for teacher in client_models]
            ).mean(dim=0)

        # Distillation step
        server_output = server_model(data)
        loss = F.kl_div(
            F.log_softmax(server_output, dim=1),
            F.softmax(soft_targets, dim=1),
            reduction="batchmean",
        )
        loss.backward()
        optimizer.step()
def test_model(model, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` in percent.

    Args:
        model: Classifier returning one row of class scores per sample; it is
            switched to eval mode.
        test_loader: Iterable of ``(data, target)`` batches.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader
        yields no samples.
    """
    model.eval()
    # Evaluate on the device the model lives on; a parameter-free model
    # leaves the batches where they are.
    first_param = next(model.parameters(), None)
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            if first_param is not None:
                data = data.to(first_param.device)
                target = target.to(first_param.device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return 100 * correct / total


# The name starts with "test_", so pytest would try to collect this helper as
# a test case whenever it is imported into a test module; opt out explicitly.
test_model.__test__ = False
# Main training flow
def main():
    """Run the federated distillation experiment end to end.

    Each of the ``NUM_ROUNDS`` rounds samples up to five clients, trains a
    copy of each sampled client model locally with the server model as
    teacher, averages the resulting parameters into every client model, and
    then distils the client models into the server model on the public
    loader. Afterwards both models are saved to ``server_model.pth`` and
    ``client_model.pth``, reloaded, and evaluated on the MNIST test split.
    """
    # Initialise the models
    global_server_model = ServerModel().to(device)
    client_models = [ClientModel().to(device) for _ in range(NUM_CLIENTS)]

    # Prepare the data
    client_datasets = prepare_data(NUM_CLIENTS)
    # NOTE(review): the public distillation set is the MNIST test split, the
    # same images used for the final accuracy evaluation below. Labels are not
    # used during distillation, but the inputs overlap - confirm this is intended.
    public_loader = DataLoader(
        datasets.MNIST("./data", train=False, download=True,
                       transform=transforms.ToTensor()),
        batch_size=100, shuffle=True)

    # Sampling without replacement fails if more clients are requested than exist.
    clients_per_round = min(5, NUM_CLIENTS)

    for round_idx in range(NUM_ROUNDS):
        # Client selection
        selected_clients = np.random.choice(NUM_CLIENTS, clients_per_round, replace=False)

        # Local training on the selected clients
        client_params = []
        for cid in selected_clients:
            # The deep copy already carries the current global client weights.
            local_model = copy.deepcopy(client_models[cid])
            updated_params = client_train(
                local_model,
                global_server_model,
                client_datasets[cid],
            )
            client_params.append(updated_params)

        # Model aggregation: every client receives the averaged parameters
        global_client_params = aggregate(client_params)
        for model in client_models:
            model.load_state_dict(global_client_params)

        # Server knowledge update
        server_update(global_server_model, client_models, public_loader)
        print(f"Round {round_idx + 1} completed")

    print("Training completed!")

    # Save the trained models
    torch.save(global_server_model.state_dict(), "server_model.pth")
    torch.save(client_models[0].state_dict(), "client_model.pth")
    print("Models saved successfully.")

    # Build the test data loader
    test_dataset = datasets.MNIST(
        "./data",
        train=False,
        transform=transforms.ToTensor(),
    )
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

    # Evaluate the server model from its saved checkpoint
    server_model = ServerModel().to(device)
    server_model.load_state_dict(torch.load("server_model.pth", map_location=device))
    server_acc = test_model(server_model, test_loader)
    print(f"Server Model Test Accuracy: {server_acc:.2f}%")

    # Evaluate the client model from its saved checkpoint
    client_model = ClientModel().to(device)
    client_model.load_state_dict(torch.load("client_model.pth", map_location=device))
    client_acc = test_model(client_model, test_loader)
    print(f"Client Model Test Accuracy: {client_acc:.2f}%")


if __name__ == "__main__":
    main()