add demo
parent ee7d8798af
commit 4051332732
1 .gitignore vendored Normal file
@@ -0,0 +1 @@
data
212 FED.py Normal file
@@ -0,0 +1,212 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy

# Configuration
NUM_CLIENTS = 10
NUM_ROUNDS = 3
CLIENT_EPOCHS = 2
BATCH_SIZE = 32
TEMP = 2.0  # distillation temperature

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

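# How the pieces below fit together: each client trains its small model on a
# private, non-IID shard while distilling from the current server model; the
# server then distills the averaged client predictions back into its large
# model on a public dataset. Knowledge flows both ways without sharing raw
# client data.
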
# Server-side (large) model
class ServerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Client-side (small) model
class ClientModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Data preparation
def prepare_data(num_clients):
    transform = transforms.Compose([transforms.ToTensor()])
    train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)

    # Non-IID split: each label is divided into num_clients//2 shards;
    # even labels go to even-indexed clients and odd labels to odd-indexed
    # ones, so each client sees only 5 of the 10 classes.
    client_data = {i: [] for i in range(num_clients)}
    labels = train_set.targets.numpy()
    for label in range(10):
        label_idx = np.where(labels == label)[0]
        np.random.shuffle(label_idx)
        split = np.array_split(label_idx, num_clients // 2)
        for i, idx in enumerate(split):
            client_data[i * 2 + label % 2].extend(idx)

    return [Subset(train_set, ids) for ids in client_data.values()]

# Client-side training: local task loss plus distillation from the server model
def client_train(client_model, server_model, dataset):
    client_model.train()
    server_model.eval()

    optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    for _ in range(CLIENT_EPOCHS):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            # Small-model (student) output
            client_output = client_model(data)

            # Large-model (teacher) output for knowledge distillation
            with torch.no_grad():
                server_output = server_model(data)

            # Combined loss: hard-label task loss + soft-label distillation loss
            loss_task = F.cross_entropy(client_output, target)
            loss_distill = F.kl_div(
                F.log_softmax(client_output / TEMP, dim=1),
                F.softmax(server_output / TEMP, dim=1),
                reduction="batchmean"
            ) * (TEMP ** 2)  # rescale gradients after temperature softening

            total_loss = loss_task + loss_distill
            total_loss.backward()
            optimizer.step()

    return client_model.state_dict()

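# Worked example of the temperature above: for logits [4.0, 2.0, 0.0],
# plain softmax gives roughly [0.87, 0.12, 0.02], while softmax at
# TEMP = 2.0 gives roughly [0.67, 0.24, 0.09], so the student also sees
# how the teacher ranks the wrong classes. The TEMP**2 factor keeps the
# distillation gradients on the same scale as the task loss.
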
# FedAvg parameter aggregation
def aggregate(client_params):
    global_params = {}
    for key in client_params[0].keys():
        global_params[key] = torch.stack(
            [param[key].float() for param in client_params]
        ).mean(dim=0)
    return global_params

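# aggregate() above takes an unweighted mean, which matches the near-equal
# shard sizes produced by prepare_data(). A minimal sketch of the
# sample-count-weighted variant used by standard FedAvg follows;
# `client_sizes` is a hypothetical argument (one entry per client's local
# example count) and this helper is not called by the training loop below.
def aggregate_weighted(client_params, client_sizes):
    total = float(sum(client_sizes))
    global_params = {}
    for key in client_params[0].keys():
        global_params[key] = sum(
            p[key].float() * (n / total)
            for p, n in zip(client_params, client_sizes)
        )
    return global_params
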
# Server-side knowledge update: distill the ensemble of client models
# into the server model on a public dataset
def server_update(server_model, client_models, public_loader):
    server_model.train()
    for model in client_models:
        model.eval()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)

    for data, _ in public_loader:
        data = data.to(device)
        optimizer.zero_grad()

        # Average the per-sample logits across the client models
        with torch.no_grad():
            client_outputs = [model(data) for model in client_models]
            soft_targets = torch.stack(client_outputs).mean(dim=0)

        # Distill the averaged client predictions into the server model
        server_output = server_model(data)
        loss = F.kl_div(
            F.log_softmax(server_output, dim=1),
            F.softmax(soft_targets, dim=1),
            reduction="batchmean"
        )

        loss.backward()
        optimizer.step()

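# Note that server_update() only ever consumes client predictions on the
# shared public set; raw client images and labels never leave the clients.
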
# Evaluate classification accuracy on a test loader
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = 100 * correct / total
    return accuracy

# Main training loop
def main():
    # Initialize models
    global_server_model = ServerModel().to(device)
    client_models = [ClientModel().to(device) for _ in range(NUM_CLIENTS)]

    # Prepare data
    client_datasets = prepare_data(NUM_CLIENTS)
    public_loader = DataLoader(
        datasets.MNIST("./data", train=False, download=True,
                       transform=transforms.ToTensor()),
        batch_size=100, shuffle=True)

    for round_num in range(NUM_ROUNDS):
        # Client selection
        selected_clients = np.random.choice(NUM_CLIENTS, 5, replace=False)

        # Local training on each selected client
        client_params = []
        for cid in selected_clients:
            # Work on a local copy of the client model
            local_model = copy.deepcopy(client_models[cid])

            # Local training
            updated_params = client_train(
                local_model,
                global_server_model,
                client_datasets[cid]
            )
            client_params.append(updated_params)

        # Aggregate client parameters (FedAvg)
        global_client_params = aggregate(client_params)
        for model in client_models:
            model.load_state_dict(global_client_params)

        # Server knowledge update
        server_update(global_server_model, client_models, public_loader)

        print(f"Round {round_num + 1} completed")

    print("Training completed!")

    # Save the trained models
    torch.save(global_server_model.state_dict(), "server_model.pth")
    torch.save(client_models[0].state_dict(), "client_model.pth")
    print("Models saved successfully.")

    # Build the test data loader
    test_dataset = datasets.MNIST(
        "./data",
        train=False,
        transform=transforms.ToTensor()
    )
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

    # Evaluate the server model
    server_model = ServerModel().to(device)
    server_model.load_state_dict(torch.load("server_model.pth"))
    server_acc = test_model(server_model, test_loader)
    print(f"Server Model Test Accuracy: {server_acc:.2f}%")

    # Evaluate the client model
    client_model = ClientModel().to(device)
    client_model.load_state_dict(torch.load("client_model.pth"))
    client_acc = test_model(client_model, test_loader)
    print(f"Client Model Test Accuracy: {client_acc:.2f}%")


if __name__ == "__main__":
    main()
140 main.py
@@ -0,0 +1,140 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_epochs_teacher = 10  # teacher training epochs
num_epochs_student = 20  # student training epochs
batch_size = 64
learning_rate = 0.001
temperature = 5          # distillation temperature
alpha = 0.3              # weight of the hard-label loss; (1 - alpha) weights the distillation loss

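# With these settings the combined objective used below is
#     loss = 0.3 * hard_loss + 0.7 * soft_loss,
# so the teacher's soft targets dominate the student's training signal.
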
# Dataset preparation (CIFAR-10 as an example)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Assumption:
# dataset A is the first 25,000 images of the CIFAR-10 training set
# dataset B is the last 25,000 images of the CIFAR-10 training set
dataset_A = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
dataset_A = torch.utils.data.Subset(dataset_A, range(25000))

dataset_B = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
dataset_B = torch.utils.data.Subset(dataset_B, range(25000, 50000))

train_loader_A = DataLoader(dataset_A, batch_size=batch_size, shuffle=True)
train_loader_B = DataLoader(dataset_B, batch_size=batch_size, shuffle=True)

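# Because the two subsets are disjoint, the student never trains on the
# teacher's images; knowledge is transferred only through the teacher's
# soft predictions on dataset B.
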
# Teacher model (ResNet-18)
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.resnet = torchvision.models.resnet18(weights=None)  # train from scratch
        self.resnet.fc = nn.Linear(512, 10)  # CIFAR-10 has 10 classes

    def forward(self, x):
        return self.resnet(x)

# Student model (a smaller CNN)
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(32 * 8 * 8, 128),  # 32x32 input pooled twice -> 8x8 feature maps
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

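# Rough size comparison (assuming the layer shapes above): the ResNet-18
# teacher has about 11.2M parameters with its 10-class head, while this
# student has about 0.27M, roughly a 40x reduction.
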
# Train the teacher model on dataset A
teacher = TeacherModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher.parameters(), lr=learning_rate)

print("Training Teacher Model...")
for epoch in range(num_epochs_teacher):
    teacher.train()
    for images, labels in train_loader_A:
        images = images.to(device)
        labels = labels.to(device)

        outputs = teacher(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Teacher Epoch [{epoch+1}/{num_epochs_teacher}]")

# Train the student model via knowledge distillation on dataset B
student = StudentModel().to(device)
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
criterion_hard = nn.CrossEntropyLoss()
criterion_soft = nn.KLDivLoss(reduction="batchmean")

print("\nDistilling Knowledge to Student...")
teacher.eval()  # put the teacher in evaluation mode

for epoch in range(num_epochs_student):
    student.train()
    total_loss = 0

    for images, labels in train_loader_B:
        images = images.to(device)
        labels = labels.to(device)

        # Teacher predictions (no gradients needed)
        with torch.no_grad():
            teacher_logits = teacher(images)

        # Student predictions
        student_logits = student(images)

        # Hard-label loss against the ground truth
        hard_loss = criterion_hard(student_logits, labels)

        # Soft-label loss with temperature scaling
        soft_loss = criterion_soft(
            nn.functional.log_softmax(student_logits / temperature, dim=1),
            nn.functional.softmax(teacher_logits / temperature, dim=1)
        ) * (temperature ** 2)  # rescale gradients after softening

        # Combined loss
        loss = alpha * hard_loss + (1 - alpha) * soft_loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader_B)
    print(f"Student Epoch [{epoch+1}/{num_epochs_student}], Loss: {avg_loss:.4f}")

print("Knowledge distillation complete!")