# TA_EC/FED.py
# Last modified 2025-03-12 14:00:50 +08:00 (308 lines, 10 KiB as listed by the repository viewer).
#
# NOTE: the repository viewer flagged this file as containing ambiguous Unicode
# characters (characters that may be confused with other characters). If that is
# intentional, the warning can safely be ignored.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
from tqdm import tqdm
from data_loader import get_data_loader
from model.repvit import repvit_m1_1
from model.mobilenetv3 import MobileNetV3
# Configuration
NUM_CLIENTS: int = 2         # number of client models / client datasets used by main()
NUM_ROUNDS: int = 10         # federated communication rounds run by main()
CLIENT_EPOCHS: int = 2       # local epochs per client per round (client_train)
BATCH_SIZE: int = 32         # not referenced by the code shown in this file
TEMP: float = 2.0            # distillation temperature used by client_train
CLASS_NUM: list = [9, 9, 9]  # class counts: [0] -> server model, [i + 1] -> client i

# Use CUDA when PyTorch reports it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# 数据准备
import os
from torchvision.datasets import ImageFolder
def prepare_data(dataset_root='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3', cache='RAM'):
    """Load the three site datasets and assign them to the server and the clients.

    Dataset A (sub-directory ``JY_A``) belongs to the server. Its train split is
    the public distillation set, and its test split is the server test set.
    Dataset B (``ZY_A``) is client 0's private data. Dataset C (``ZY_B``) is
    client 1's. A's validation split is loaded but not used.

    Args:
        dataset_root: Directory containing the ``JY_A``, ``ZY_A`` and ``ZY_B``
            sub-directories. Defaults to the path that used to be hard-coded.
        cache: Forwarded unchanged to ``get_data_loader`` as its ``Cache`` argument.

    Returns:
        ``(client_datasets, public_loader, server_test_loader)``.
        ``client_datasets`` is a list with one ``{'train', 'val', 'test'}`` dict
        per client. Each value is whatever ``get_data_loader`` returns; the rest
        of this file iterates those objects as ``(data, target)`` batches.
    """
    def _load_site(site):
        # Unpacked by callers as (train, val, test) for one site directory.
        return get_data_loader(root_dir=os.path.join(dataset_root, site), Cache=cache)

    # Load every dataset (train / val / test).
    dataset_A_train, dataset_A_val, dataset_A_test = _load_site('JY_A')
    dataset_B_train, dataset_B_val, dataset_B_test = _load_site('ZY_A')
    dataset_C_train, dataset_C_val, dataset_C_test = _load_site('ZY_B')

    # Per-client private datasets: index 0 -> site B, index 1 -> site C.
    client_datasets = [
        {'train': dataset_B_train, 'val': dataset_B_val, 'test': dataset_B_test},
        {'train': dataset_C_train, 'val': dataset_C_val, 'test': dataset_C_test},
    ]
    # The server's public (distillation) set is A's train split.
    public_loader = dataset_A_train
    # The server is evaluated on A's test split.
    server_test_loader = dataset_A_test
    return client_datasets, public_loader, server_test_loader
# Client-side local training
def client_train(client_model, server_model, loader, lr=0.1):
    """Train ``client_model`` locally with a task loss plus distillation from ``server_model``.

    For ``CLIENT_EPOCHS`` epochs, each batch minimises
    ``cross_entropy(client, target) + TEMP**2 * KL(softmax(server/TEMP) || softmax(client/TEMP))``
    with SGD. The server model is the teacher: it is put in eval mode and
    queried under ``torch.no_grad()``, so it is never updated here.

    Args:
        client_model: Student model; its parameters are updated in place.
        server_model: Teacher model. Its logits must have the same class
            dimension as the client's for the KL term.
        loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        lr: SGD learning rate (defaults to the previously hard-coded 0.1).

    Returns:
        ``client_model.state_dict()`` after training.
    """
    client_model.train()
    server_model.eval()
    optimizer = torch.optim.SGD(client_model.parameters(), lr=lr)
    num_batches = len(loader)

    for epoch in range(CLIENT_EPOCHS):
        epoch_loss = 0.0
        task_loss = 0.0
        distill_loss = 0.0
        correct = 0
        total = 0

        # Iterating the tqdm wrapper advances the bar by itself; an extra
        # manual update() would double-count every batch.
        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{CLIENT_EPOCHS}")
        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass; outputs are already on the data's device.
            client_output = client_model(data)
            with torch.no_grad():
                server_output = server_model(data)

            loss_task = F.cross_entropy(client_output, target)
            # Temperature-softened KD term, rescaled by TEMP**2 as is customary so
            # its gradient scale stays comparable to the task loss.
            loss_distill = F.kl_div(
                F.log_softmax(client_output / TEMP, dim=1),
                F.softmax(server_output / TEMP, dim=1),
                reduction="batchmean",
            ) * (TEMP ** 2)
            total_loss = loss_task + loss_distill

            total_loss.backward()
            optimizer.step()

            # Running statistics for this epoch.
            epoch_loss += total_loss.item()
            task_loss += loss_task.item()
            distill_loss += loss_distill.item()
            predicted = client_output.detach().argmax(dim=1)
            correct += (predicted == target).sum().item()
            total += target.size(0)

            # No newline in the postfix: it would break the single-line bar.
            progress_bar.set_postfix({
                "Epoch": f"{epoch+1}/{CLIENT_EPOCHS}",
                "Batch": f"{batch_idx+1}/{num_batches}",
                "Loss": f"{total_loss.item():.4f}",
                "Acc": f"{100*correct/total:.2f}%",
            })
        progress_bar.close()

        # An empty loader would otherwise divide by zero below.
        if num_batches == 0 or total == 0:
            print(f"Epoch {epoch+1}: loader produced no samples, skipping summary")
            continue

        avg_loss = epoch_loss / num_batches
        avg_task = task_loss / num_batches
        avg_distill = distill_loss / num_batches
        epoch_acc = 100 * correct / total
        print(f"\n{'='*40}")
        print(f"Epoch {epoch+1} Summary:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Task Loss: {avg_task:.4f}")
        print(f"Distill Loss: {avg_distill:.4f}")
        print(f"Training Accuracy: {epoch_acc:.2f}%")
        print(f"{'='*40}\n")
    return client_model.state_dict()
# Model-parameter aggregation (FedAvg)
def aggregate(client_params, weights=None):
    """Average a list of client ``state_dict``s key by key (FedAvg).

    Each entry is averaged in float32. The result is then cast back to the
    dtype of the first client's tensor. Integer buffers (e.g. a BatchNorm
    ``num_batches_tracked`` counter) are rounded and stay integers instead
    of silently becoming floats.

    Args:
        client_params: Non-empty sequence of state dicts with identical keys
            and tensor shapes.
        weights: Optional per-client non-negative weights, e.g. local sample
            counts. ``None`` gives the plain unweighted mean.

    Returns:
        Dict mapping every key of ``client_params[0]`` to its averaged tensor.

    Raises:
        ValueError: if ``client_params`` is empty, or if ``weights`` has the
            wrong length or does not sum to a positive value.
    """
    if not client_params:
        raise ValueError("aggregate() needs at least one client state_dict")
    if weights is not None:
        if len(weights) != len(client_params):
            raise ValueError("aggregate() got %d weights for %d clients"
                             % (len(weights), len(client_params)))
        weight_sum = float(sum(weights))
        if weight_sum <= 0:
            raise ValueError("aggregate() weights must sum to a positive value")

    global_params = {}
    for key, reference in client_params[0].items():
        stacked = torch.stack([params[key].float() for params in client_params])
        if weights is None:
            averaged = stacked.mean(dim=0)
        else:
            coeffs = torch.tensor(weights, dtype=stacked.dtype, device=stacked.device) / weight_sum
            # Shape (num_clients, 1, ..., 1) so the weights broadcast over each tensor.
            coeffs = coeffs.view(-1, *([1] * (stacked.dim() - 1)))
            averaged = (stacked * coeffs).sum(dim=0)
        if not reference.is_floating_point():
            averaged = averaged.round()
        global_params[key] = averaged.to(reference.dtype)
    return global_params
def server_aggregate(server_model, client_models, public_loader, lr=0.001):
    """Feature-level distillation: pull the server's features toward the mean client features.

    For each batch of ``public_loader`` (labels are ignored), the server is
    trained with Adam to minimise
    ``MSE(server.extract_features(x), mean_k client_k.extract_features(x))``.

    NOTE(review): ``extract_features`` is not defined in this file. It must
    exist on both the server and client model classes, and both must return
    tensors of identical shape for ``mse_loss``. Confirm this, because
    ``main`` uses different architectures for server and clients. This
    function is not called by the code shown.

    Args:
        server_model: Student model; updated in place.
        client_models: Teacher models; put in eval mode and run without gradients.
        public_loader: Iterable of ``(data, label)`` batches.
        lr: Adam learning rate (defaults to the previously hard-coded 0.001).

    Returns:
        Mean per-batch alignment loss, or 0.0 if the loader yields no batches.
    """
    server_model.train()
    # Teachers: eval() keeps BatchNorm/Dropout layers (if any) from updating or randomising.
    for model in client_models:
        model.eval()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=lr)
    total_loss = 0.0
    num_batches = 0
    for data, _ in public_loader:
        data = data.to(device)
        # Teacher features, averaged across clients per sample.
        with torch.no_grad():
            client_features = [model.extract_features(data) for model in client_models]
        target_features = torch.stack(client_features).mean(dim=0)

        server_features = server_model.extract_features(data)
        loss = F.mse_loss(server_features, target_features)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0
# Server knowledge update (logit distillation from the clients)
def server_update(server_model, client_models, public_loader, temperature=1.0):
    """Distil the client models into ``server_model`` on the public data.

    For each public batch (labels are ignored), the clients' logits are
    averaged across clients, separately for every sample, to form soft
    targets. The server is then trained with Adam (lr 0.001) on
    ``temperature**2 * KL(softmax(targets/T) || softmax(server/T))``.

    Args:
        server_model: Student model; updated in place.
        client_models: Teacher models; put in eval mode and run without gradients.
        public_loader: Iterable of ``(data, label)`` batches.
        temperature: Softmax temperature T. The default 1.0 gives the plain,
            untempered KL loss. Pass ``TEMP`` to match ``client_train``.
    """
    server_model.train()
    # Teachers: eval() stops BatchNorm layers (if any) from updating their running
    # statistics on the public data during these forward passes.
    for model in client_models:
        model.eval()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    total_loss = 0.0
    num_batches = 0

    progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")
    for batch_idx, (data, _) in enumerate(progress_bar):
        data = data.to(device)
        optimizer.zero_grad()

        with torch.no_grad():
            # Stack to (num_clients, batch, classes), assuming (batch, classes)
            # logits. Average over clients only. Averaging over the batch too
            # would give every sample the same target.
            soft_targets = torch.stack([model(data) for model in client_models]).mean(dim=0)

        server_output = server_model(data)
        loss = F.kl_div(
            F.log_softmax(server_output / temperature, dim=1),
            F.softmax(soft_targets / temperature, dim=1),
            reduction="batchmean",
        ) * (temperature ** 2)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        progress_bar.set_postfix({
            "Avg Loss": f"{total_loss/(batch_idx+1):.4f}",
            "Current Loss": f"{loss.item():.4f}"
        })

    if num_batches == 0:
        print("\nServer Update skipped: public loader produced no batches\n")
        return
    print(f"\nServer Update Complete | Average Loss: {total_loss/num_batches:.4f}\n")
def test_model(model, test_loader, desc="Testing"):  # works with DataLoader-style iterables
    """Return the top-1 accuracy of ``model`` on ``test_loader``, in percent.

    The model is switched to eval mode and run under ``torch.no_grad()``.

    Args:
        model: Classifier whose output is indexed as ``(batch, classes)`` logits
            (assumed; argmax is taken over dim 1).
        test_loader: Iterable of ``(data, target)`` batches.
        desc: Progress-bar label.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc=desc, unit="batch"):
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        raise ValueError("test_model(): test_loader yielded no samples")
    return 100 * correct / total
# Main federated training loop
def main():
    """Run ``NUM_ROUNDS`` rounds of federated training, then save and re-evaluate all models.

    Each round performs these steps:

    1. Sample clients.
    2. Train a copy of each selected client model with distillation from the
       server (``client_train``).
    3. FedAvg the resulting weights and load the average into *every* client
       model.
    4. Distil the clients back into the server on the public set
       (``server_update``).
    5. Report test accuracies.

    At the end, weights are written to ``server_model.pth`` and
    ``client{i}_model.pth`` in the working directory. They are then reloaded
    for a final evaluation.
    """
    # One RepViT server model and NUM_CLIENTS MobileNetV3 client models.
    global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
    client_models = [MobileNetV3(n_class=CLASS_NUM[i + 1]).to(device) for i in range(NUM_CLIENTS)]

    client_datasets, public_loader, server_test_loader = prepare_data()

    # Never request more distinct clients than exist (np.random.choice would raise).
    clients_per_round = min(2, NUM_CLIENTS)

    round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
    for rnd in range(NUM_ROUNDS):
        print(f"\n{'#'*50}")
        print(f"Round {rnd+1}/{NUM_ROUNDS}")
        print(f"{'#'*50}")

        # Client selection
        selected_clients = np.random.choice(NUM_CLIENTS, clients_per_round, replace=False)
        print(f"Selected clients: {selected_clients}")

        # Local training on each selected client's train split.
        client_params = []
        for cid in selected_clients:
            print(f"\nTraining Client {cid}")
            # deepcopy already carries the weights; no extra load_state_dict is needed.
            local_model = copy.deepcopy(client_models[cid])
            client_params.append(
                client_train(local_model, global_server_model, client_datasets[cid]['train'])
            )

        # FedAvg, then broadcast the averaged weights to every client model.
        global_client_params = aggregate(client_params)
        for model in client_models:
            model.load_state_dict(global_client_params)

        print("\nServer updating...")
        server_update(global_server_model, client_models, public_loader)

        # Evaluation: server on A's test split, each client on its own test split.
        server_acc = test_model(global_server_model, server_test_loader)
        client_accuracies = [
            test_model(client_models[i], client_datasets[i]['test'])
            for i in range(NUM_CLIENTS)
        ]
        print(f"\nRound {rnd+1} Results:")
        print(f"Server Accuracy: {server_acc:.2f}%")
        for i, acc in enumerate(client_accuracies):
            print(f"Client {i} Accuracy: {acc:.2f}%")
        round_progress.update(1)
    round_progress.close()

    # Save the models
    torch.save(global_server_model.state_dict(), "server_model.pth")
    for i in range(NUM_CLIENTS):
        torch.save(client_models[i].state_dict(), f"client{i}_model.pth")

    # Final evaluation: reload the checkpoints from disk onto the current device.
    print("\nFinal Evaluation:")
    server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
    server_model.load_state_dict(
        torch.load("server_model.pth", map_location=device, weights_only=True)
    )
    print(f"Server Accuracy: {test_model(server_model, server_test_loader):.2f}%")
    for i in range(NUM_CLIENTS):
        client_model = MobileNetV3(n_class=CLASS_NUM[i + 1]).to(device)
        client_model.load_state_dict(
            torch.load(f"client{i}_model.pth", map_location=device, weights_only=True)
        )
        test_loader = client_datasets[i]['test']
        print(f"Client {i} Accuracy: {test_model(client_model, test_loader):.2f}%")


if __name__ == "__main__":
    main()