TA_EC/FED.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
import os
from tqdm import tqdm
from torchvision.datasets import ImageFolder
from model.repvit import repvit_m1_1
from model.mobilenetv3 import MobileNetV3
# Configuration
NUM_CLIENTS = 2
NUM_ROUNDS = 10
CLIENT_EPOCHS = 2
BATCH_SIZE = 32
TEMP = 2.0  # distillation temperature
CLASS_NUM = [3, 3, 3]  # number of classes for [server, client 0, client 1]
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
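# Overview (as implemented below): two MobileNetV3 clients (CLASS_NUM[1], CLASS_NUM[2])
# train locally on their own ImageFolder datasets while distilling from the RepViT
# server model; their updated parameters are averaged FedAvg-style, and the server
# model is then refreshed by distilling the averaged client predictions on the
# public dataset (dataset A).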
# Data preparation
def prepare_data():
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    # Load datasets
    dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform)
    dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform)
    dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform)
    # Assign datasets to clients
    client_datasets = [dataset_B, dataset_C]
    # Server dataset (A) serves as the public dataset for server updates
    public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True)
    return client_datasets, public_loader
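# Note: the dataset roots above are local paths. The assumed layout follows
# torchvision's ImageFolder convention: each root directory contains one
# sub-folder per class, and each class sub-folder holds that class's images.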
# Client local training: supervised loss plus distillation from the server model
def client_train(client_model, server_model, dataset):
    client_model.train()
    server_model.eval()
    optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    # Training progress bar
    progress_bar = tqdm(total=CLIENT_EPOCHS * len(loader),
                        desc="Client Training",
                        unit="batch")
    for epoch in range(CLIENT_EPOCHS):
        epoch_loss = 0.0
        task_loss = 0.0
        distill_loss = 0.0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            # Forward pass of the client (student) model
            client_output = client_model(data)
            # Teacher (server) model output
            with torch.no_grad():
                server_output = server_model(data)
            # Losses: task cross-entropy + temperature-scaled distillation KL
            loss_task = F.cross_entropy(client_output, target)
            loss_distill = F.kl_div(
                F.log_softmax(client_output / TEMP, dim=1),
                F.softmax(server_output / TEMP, dim=1),
                reduction="batchmean"
            ) * (TEMP ** 2)
            total_loss = loss_task + loss_distill
            # Backward pass
            total_loss.backward()
            optimizer.step()
            # Running statistics
            epoch_loss += total_loss.item()
            task_loss += loss_task.item()
            distill_loss += loss_distill.item()
            _, predicted = torch.max(client_output.data, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)
            # Live progress-bar update
            progress_bar.set_postfix({
                "Epoch": f"{epoch+1}/{CLIENT_EPOCHS}",
                "Batch": f"{batch_idx+1}/{len(loader)}",
                "Loss": f"{total_loss.item():.4f}",
                "Acc": f"{100*correct/total:.2f}%",
            })
            progress_bar.update(1)
        # Per-epoch summary
        avg_loss = epoch_loss / len(loader)
        avg_task = task_loss / len(loader)
        avg_distill = distill_loss / len(loader)
        epoch_acc = 100 * correct / total
        print(f"\n{'='*40}")
        print(f"Epoch {epoch+1} Summary:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Task Loss: {avg_task:.4f}")
        print(f"Distill Loss: {avg_distill:.4f}")
        print(f"Training Accuracy: {epoch_acc:.2f}%")
        print(f"{'='*40}\n")
    progress_bar.close()
    return client_model.state_dict()
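# The client objective above is the standard knowledge-distillation loss
#   L = CE(student(x), y) + T^2 * KL( softmax(teacher(x)/T) || softmax(student(x)/T) )
# with T = TEMP; the T^2 factor keeps the gradient contribution of the soft-target
# term roughly independent of the chosen temperature (Hinton et al., 2015).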
# FedAvg-style aggregation of client model parameters (unweighted mean)
def aggregate(client_params):
    global_params = {}
    for key in client_params[0].keys():
        global_params[key] = torch.stack(
            [param[key].float() for param in client_params]
        ).mean(dim=0)
    return global_params
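# aggregate() above uses an unweighted mean, which coincides with FedAvg only when
# every client holds the same number of samples. A minimal sketch of the
# sample-size-weighted variant (hypothetical helper, not called anywhere in this
# file; client_sizes would be the local dataset sizes) could look like this:
def aggregate_weighted(client_params, client_sizes):
    """Weighted FedAvg sketch: average state_dicts weighted by local dataset size."""
    total = float(sum(client_sizes))
    weights = [n / total for n in client_sizes]
    global_params = {}
    for key in client_params[0].keys():
        stacked = torch.stack([p[key].float() for p in client_params])
        w = torch.tensor(weights, dtype=stacked.dtype, device=stacked.device)
        # Reshape weights so they broadcast over the parameter dimensions
        w = w.view(-1, *([1] * (stacked.dim() - 1)))
        global_params[key] = (w * stacked).sum(dim=0)
    return global_params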
# Server aggregation via feature alignment on the public data (not called by
# main(), which uses server_update() below). Requires each model to expose an
# extract_features() method.
def server_aggregate(server_model, client_models, public_loader):
    server_model.train()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    total_loss = 0.0
    for data, _ in public_loader:
        data = data.to(device)
        # Client model features
        client_features = []
        with torch.no_grad():
            for model in client_models:
                features = model.extract_features(data)  # feature-extraction method must be implemented
                client_features.append(features)
        # Feature-distillation target: mean of the client features
        target_features = torch.stack(client_features).mean(dim=0)
        # Server forward pass
        server_features = server_model.extract_features(data)
        # Feature-alignment loss
        loss = F.mse_loss(server_features, target_features)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Running statistics
        total_loss += loss.item()
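# server_aggregate() assumes an extract_features() method that neither MobileNetV3
# nor repvit_m1_1 is shown to provide here. The helper below is only an assumed
# sketch of one generic way to obtain penultimate-layer features by dropping the
# last child module (presumed to be the classifier head); the actual module
# structure of these models may differ, so treat it as illustrative.
def extract_penultimate_features(model, x):
    """Hypothetical helper: forward through all but the final child module."""
    backbone = nn.Sequential(*list(model.children())[:-1])
    feats = backbone(x)
    return torch.flatten(feats, 1)  # flatten to (batch, feature_dim)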
# Server knowledge update: distill the averaged client predictions on the
# public dataset into the server model
def server_update(server_model, client_models, public_loader):
    server_model.train()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    total_loss = 0.0
    progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")
    for batch_idx, (data, _) in enumerate(progress_bar):
        data = data.to(device)
        optimizer.zero_grad()
        # Soft targets: per-sample average of the client model outputs
        with torch.no_grad():
            client_outputs = [model(data) for model in client_models]
            soft_targets = torch.stack(client_outputs).mean(dim=0)
        # Distillation loss
        server_output = server_model(data)
        loss = F.kl_div(
            F.log_softmax(server_output, dim=1),
            F.softmax(soft_targets, dim=1),
            reduction="batchmean"
        )
        # Backward pass
        loss.backward()
        optimizer.step()
        # Running statistics
        total_loss += loss.item()
        progress_bar.set_postfix({
            "Avg Loss": f"{total_loss/(batch_idx+1):.4f}",
            "Current Loss": f"{loss.item():.4f}"
        })
    print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = 100 * correct / total
    return accuracy
# Main training loop
def main():
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    # Initialize models: one global server model and one model per client
    global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
    client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
    # Prepare data
    client_datasets, public_loader = prepare_data()
    # Test dataset (using dataset A's test set for simplicity)
    test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
    round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
    for round_num in range(NUM_ROUNDS):
        print(f"\n{'#'*50}")
        print(f"Federated Round {round_num+1}/{NUM_ROUNDS}")
        print(f"{'#'*50}")
        # Client selection (both clients are selected every round)
        selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
        print(f"Selected Clients: {selected_clients}")
        # Client local training
        client_params = []
        for cid in selected_clients:
            print(f"\nTraining Client {cid}")
            local_model = copy.deepcopy(client_models[cid])
            updated_params = client_train(local_model, global_server_model, client_datasets[cid])
            client_params.append(updated_params)
        # Model aggregation (FedAvg over the client updates)
        global_client_params = aggregate(client_params)
        for model in client_models:
            model.load_state_dict(global_client_params)
        # Server knowledge update on the public dataset
        print("\nServer Updating...")
        server_update(global_server_model, client_models, public_loader)
        # Test model performance
        server_acc = test_model(global_server_model, test_loader)
        client_acc = test_model(client_models[0], test_loader)
        print(f"\nRound {round_num+1} Performance:")
        print(f"Global Model Accuracy: {server_acc:.2f}%")
        print(f"Client Model Accuracy: {client_acc:.2f}%")
        round_progress.update(1)
        print(f"Round {round_num+1} completed")
    round_progress.close()
    print("Training completed!")
    # Save trained models
    torch.save(global_server_model.state_dict(), "server_model.pth")
    for i in range(NUM_CLIENTS):
        torch.save(client_models[i].state_dict(), f"client{i}_model.pth")
    print("Models saved successfully.")
    # Reload and test the server model
    server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
    server_model.load_state_dict(torch.load("server_model.pth", weights_only=True))
    server_acc = test_model(server_model, test_loader)
    print(f"Server Model Test Accuracy: {server_acc:.2f}%")
    # Reload and test the client models
    for i in range(NUM_CLIENTS):
        client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
        client_model.load_state_dict(torch.load(f"client{i}_model.pth", weights_only=True))
        client_acc = test_model(client_model, test_loader)
        print(f"Client {i} Model Test Accuracy: {client_acc:.2f}%")

if __name__ == "__main__":
    main()
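# Usage (assumed): run the script directly, e.g. `python FED.py`, with the three
# ImageFolder datasets available at G:/testdata/JY_A, G:/testdata/ZY_A and
# G:/testdata/ZY_B (train/ sub-folders, plus JY_A/test for evaluation), each laid
# out as one sub-folder per class. Trained weights are written to
# server_model.pth and client{0,1}_model.pth in the working directory.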