# TA_EC/FED.py
# (source-viewer residue: 212 lines, 6.7 KiB, Python — Raw | Normal View | History)
# Timestamp shown by the viewer: 2025-03-04 04:55:40 +00:00
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
NUM_CLIENTS = 10     # total number of simulated clients
NUM_ROUNDS = 3       # federated communication rounds
CLIENT_EPOCHS = 2    # local epochs per selected client per round
BATCH_SIZE = 32      # client-side mini-batch size
TEMP = 2.0           # distillation temperature for the client-side KL loss

# Device selection: use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Server-side ("large") model
class ServerModel(nn.Module):
    """Three-layer MLP: 784 -> 512 -> 256 -> 10, ReLU between layers.

    ``forward`` flattens its input with ``view(-1, 784)``, so a contiguous
    tensor whose element count is a multiple of 784 (e.g. a batch of 1x28x28
    images) is accepted. Returns raw logits of shape (N, 10).
    """

    def __init__(self):
        super().__init__()
        # Attribute names become state_dict keys (fc1.weight, ...); keep them
        # stable so previously saved checkpoints still load.
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        h = x.view(-1, 784)
        for hidden in (self.fc1, self.fc2):
            h = F.relu(hidden(h))
        return self.fc3(h)
# Client-side ("small") model
class ClientModel(nn.Module):
    """Two-layer MLP: 784 -> 64 -> 10 with one ReLU hidden layer.

    Input is flattened with ``view(-1, 784)``; the output is raw logits of
    shape (N, 10).
    """

    def __init__(self):
        super().__init__()
        # Names fc1/fc2 are the state_dict keys used when saving/loading.
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        hidden = F.relu(self.fc1(x.view(-1, 784)))
        return self.fc2(hidden)
# Data preparation
def prepare_data(num_clients, train_set=None):
    """Split a labelled training set across ``num_clients`` clients (non-IID).

    Every digit's sample indices are shuffled (NumPy global RNG) and cut into
    near-equal chunks. Even digits are dealt to even-numbered clients and odd
    digits to odd-numbered clients. Each client therefore holds the five
    digits that share its index parity, not two classes per client.

    Args:
        num_clients: number of client shards to produce. Must be >= 2 so that
            both the even-digit and odd-digit groups have at least one owner.
        train_set: optional dataset exposing a ``targets`` sequence/tensor of
            integer labels 0-9. Defaults to the MNIST training split,
            downloaded to ``./data`` if missing.

    Returns:
        A list of ``num_clients`` ``Subset`` views over ``train_set``, in
        client order.

    Raises:
        ValueError: if ``num_clients`` < 2.
    """
    if num_clients < 2:
        raise ValueError(f"num_clients must be >= 2, got {num_clients}")
    if train_set is None:
        transform = transforms.Compose([transforms.ToTensor()])
        train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)

    # np.asarray also accepts plain-list targets, not just tensors.
    labels = np.asarray(train_set.targets)
    client_data = {i: [] for i in range(num_clients)}
    for label in range(10):
        # Owners of this label's parity: p, p+2, p+4, ...
        # - Even num_clients: each parity group has num_clients // 2 owners.
        # - Odd num_clients: the extra client joins the even group instead of
        #   being left with no data.
        owners = range(label % 2, num_clients, 2)
        label_idx = np.where(labels == label)[0]
        np.random.shuffle(label_idx)
        for owner, chunk in zip(owners, np.array_split(label_idx, len(owners))):
            client_data[owner].extend(chunk.tolist())
    return [Subset(train_set, ids) for ids in client_data.values()]
# Client-side local training
def client_train(client_model, server_model, dataset):
    """Train one client locally with cross-entropy plus distillation from the server.

    The server model is a frozen teacher: it is put in eval mode and queried
    under ``torch.no_grad``. The per-batch objective is::

        CE(student, labels) + TEMP**2 * KL(softmax(teacher/TEMP) || softmax(student/TEMP))

    It is optimised with SGD (lr=0.1) for ``CLIENT_EPOCHS`` passes over
    ``dataset``, using shuffled batches of ``BATCH_SIZE``.

    Returns:
        ``client_model.state_dict()`` after training (the model itself is
        updated in place).
    """
    student, teacher = client_model, server_model
    student.train()
    teacher.eval()
    sgd = torch.optim.SGD(student.parameters(), lr=0.1)
    batches = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    for _epoch in range(CLIENT_EPOCHS):
        for inputs, labels in batches:
            inputs = inputs.to(device)
            labels = labels.to(device)
            sgd.zero_grad()
            student_logits = student(inputs)
            with torch.no_grad():
                teacher_logits = teacher(inputs)
            hard_loss = F.cross_entropy(student_logits, labels)
            soft_loss = F.kl_div(
                F.log_softmax(student_logits / TEMP, dim=1),
                F.softmax(teacher_logits / TEMP, dim=1),
                reduction="batchmean",
            )
            # T^2 scaling (Hinton et al., 2015) keeps soft-target gradient
            # magnitudes comparable as the temperature changes.
            (hard_loss + soft_loss * (TEMP ** 2)).backward()
            sgd.step()
    return student.state_dict()
# Model-parameter aggregation (FedAvg)
def aggregate(client_params, weights=None):
    """Average a sequence of client ``state_dict``s key by key (FedAvg).

    Args:
        client_params: non-empty sequence of state_dicts with identical keys
            and tensor shapes.
        weights: optional per-client weights, e.g. local sample counts as in
            McMahan et al.'s FedAvg. ``None`` (default) weights every client
            equally, i.e. a plain mean.

    Returns:
        A new dict mapping each key to the averaged tensor. Each result keeps
        the dtype of the matching entry in ``client_params[0]``; integer
        buffers are truncated back to integers after averaging.

    Raises:
        ValueError: if ``client_params`` is empty, or ``weights`` has the
            wrong length or a non-positive sum.
    """
    if not client_params:
        raise ValueError("aggregate() needs at least one client state_dict")
    norm_weights = None
    if weights is not None:
        if len(weights) != len(client_params):
            raise ValueError(
                f"got {len(weights)} weights for {len(client_params)} clients"
            )
        norm_weights = torch.as_tensor(weights, dtype=torch.float64)
        total = norm_weights.sum()
        if total <= 0:
            raise ValueError("weights must sum to a positive value")
        norm_weights = norm_weights / total

    global_params = {}
    for key, ref in client_params[0].items():
        # Work dtype: at least float32, and float64 stays float64, so
        # averaging neither truncates integers mid-computation nor downcasts
        # double-precision weights.
        work_dtype = torch.promote_types(ref.dtype, torch.float32)
        stacked = torch.stack([params[key].to(work_dtype) for params in client_params])
        if norm_weights is None:
            averaged = stacked.mean(dim=0)
        else:
            coeff = norm_weights.to(device=stacked.device, dtype=work_dtype)
            # Broadcast one coefficient per client over that client's tensor.
            coeff = coeff.view(-1, *([1] * (stacked.dim() - 1)))
            averaged = (stacked * coeff).sum(dim=0)
        global_params[key] = averaged.to(ref.dtype)
    return global_params
# Server-side knowledge update (ensemble distillation)
def server_update(server_model, client_models, public_loader):
    """Distil the averaged client ensemble into the server model.

    Makes one pass over ``public_loader``. For each batch, the client models'
    logits are averaged across clients, separately for every sample. The
    softmax of that average is the target of a KL-divergence loss on the
    server model's softmax output. Labels yielded by the loader are ignored.

    Args:
        server_model: student model, trained in place with Adam (lr=0.001).
            Batches are moved to the device of its first parameter.
        client_models: iterable of teacher models. They are switched to eval
            mode and queried under ``torch.no_grad``.
        public_loader: iterable of ``(data, label)`` batches; labels unused.

    Raises:
        ValueError: if ``client_models`` is empty.
    """
    teachers = list(client_models)  # iterated once per batch, so materialise
    if not teachers:
        raise ValueError("server_update() needs at least one client model")
    dev = next(server_model.parameters()).device
    server_model.train()
    for teacher in teachers:
        teacher.eval()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    for data, _ in public_loader:
        data = data.to(dev)
        optimizer.zero_grad()
        with torch.no_grad():
            # Stack the client logits along a new leading axis and average
            # over that axis only. Averaging each client's logits over the
            # batch dimension instead would give every sample the same target.
            soft_targets = torch.stack([teacher(data) for teacher in teachers]).mean(dim=0)
        server_output = server_model(data)
        loss = F.kl_div(
            F.log_softmax(server_output, dim=1),
            F.softmax(soft_targets, dim=1),
            reduction="batchmean",
        )
        loss.backward()
        optimizer.step()
def test_model(model, test_loader):
    """Return the top-1 accuracy (in percent) of ``model`` over ``test_loader``.

    The model is put in eval mode and run under ``torch.no_grad``. Each batch
    is moved to the device of the model's first parameter, or to the CPU if
    the model has no parameters.

    Args:
        model: classifier returning logits of shape (N, num_classes).
        test_loader: iterable of ``(data, target)`` batches with integer labels.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device("cpu")
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(dev), target.to(dev)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return 100 * correct / total


# The name starts with "test_", so pytest would collect this helper (and fail
# on its missing fixtures) if it is imported into a test module.
test_model.__test__ = False
# Main training pipeline
def main():
    """Run federated training, save both models, then report test accuracy.

    Each round:
      1. Sample up to 5 distinct clients.
      2. Train a copy of each selected client model locally, with
         distillation from the server model.
      3. FedAvg the resulting weights into every client model.
      4. Distil the client ensemble back into the server model on the public
         loader.

    Side effects: downloads MNIST to ./data, and writes server_model.pth and
    client_model.pth to the working directory.
    """
    # Initialise models
    global_server_model = ServerModel().to(device)
    client_models = [ClientModel().to(device) for _ in range(NUM_CLIENTS)]

    # Prepare data
    client_datasets = prepare_data(NUM_CLIENTS)
    # NOTE(review): the "public" distillation set is the MNIST *test* split,
    # which is also used for the final evaluation below. Its labels are never
    # read during training, but the reported accuracy is not on unseen inputs.
    public_loader = DataLoader(
        datasets.MNIST("./data", train=False, download=True,
                       transform=transforms.ToTensor()),
        batch_size=100, shuffle=True)

    # Sampling without replacement raises if asked for more clients than
    # exist, so cap the per-round sample size.
    clients_per_round = min(5, NUM_CLIENTS)

    for rnd in range(NUM_ROUNDS):  # `rnd`, not `round`, to avoid shadowing the builtin
        # Client selection
        selected_clients = np.random.choice(NUM_CLIENTS, clients_per_round, replace=False)

        # Local training on copies, so the shared client models change only
        # via aggregation. deepcopy already carries the weights.
        client_params = []
        for cid in selected_clients:
            local_model = copy.deepcopy(client_models[cid])
            updated_params = client_train(
                local_model,
                global_server_model,
                client_datasets[cid]
            )
            client_params.append(updated_params)

        # Aggregate, then broadcast the averaged weights to every client model.
        # NOTE(review): all client models are identical afterwards, so the
        # ensemble below is effectively a single model; passing one would be cheaper.
        global_client_params = aggregate(client_params)
        for model in client_models:
            model.load_state_dict(global_client_params)

        # Server knowledge update
        server_update(global_server_model, client_models, public_loader)
        print(f"Round {rnd+1} completed")
    print("Training completed!")

    # Save the trained models. After any round every client model holds the
    # same aggregated weights, so client 0 is representative.
    torch.save(global_server_model.state_dict(), "server_model.pth")
    torch.save(client_models[0].state_dict(), "client_model.pth")
    print("Models saved successfully.")

    # Build the test data loader
    test_dataset = datasets.MNIST(
        "./data",
        train=False,
        transform=transforms.ToTensor()
    )
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

    # Evaluate the server model (reloaded from the saved checkpoint)
    server_model = ServerModel().to(device)
    server_model.load_state_dict(torch.load("server_model.pth"))
    server_acc = test_model(server_model, test_loader)
    print(f"Server Model Test Accuracy: {server_acc:.2f}%")

    # Evaluate the client model (reloaded from the saved checkpoint)
    client_model = ClientModel().to(device)
    client_model.load_state_dict(torch.load("client_model.pth"))
    client_acc = test_model(client_model, test_loader)
    print(f"Client Model Test Accuracy: {client_acc:.2f}%")


if __name__ == "__main__":
    main()