# TA_EC/FED.py
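"""Federated learning with knowledge distillation.

Two MobileNetV3 clients train locally on their own datasets while distilling
from a shared RepViT server model (client_train). Client weights are then
averaged FedAvg-style (aggregate), and the server model is updated by
distilling the clients' soft predictions on a public dataset (server_update).
"""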
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
from tqdm import tqdm
from data_loader import get_data_loader
from logger import *
from model.repvit import repvit_m1_1
from model.mobilenetv3 import MobileNetV3
# Configuration
NUM_CLIENTS = 2
NUM_ROUNDS = 10
CLIENT_EPOCHS = 2
BATCH_SIZE = 32
TEMP = 2.0  # distillation temperature
CLASS_NUM = [9, 9, 9]  # number of classes: [server / dataset A, client 0, client 1]
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Data preparation
import os
from torchvision.datasets import ImageFolder
def prepare_data():
    # NOTE: transform is defined here but not passed to get_data_loader
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    # Load all datasets (train, validation, test)
    dataset_A_train, dataset_A_val, dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/JY_A', Cache='RAM')
    dataset_B_train, dataset_B_val, dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_A', Cache='RAM')
    dataset_C_train, dataset_C_val, dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_B', Cache='RAM')
    # Organize the per-client datasets
    client_datasets = [
        {  # Client 0
            'train': dataset_B_train,
            'val': dataset_B_val,
            'test': dataset_B_test
        },
        {  # Client 1
            'train': dataset_C_train,
            'val': dataset_C_val,
            'test': dataset_C_test
        }
    ]
    # The public dataset is dataset A's training split
    public_loader = dataset_A_train
    # The server test set is dataset A's test split
    server_test_loader = dataset_A_test
    return client_datasets, public_loader, server_test_loader
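# The ImageFolder import and the transform inside prepare_data() are currently
# unused. For reference, a minimal sketch of an equivalent ImageFolder-based
# loader builder (hypothetical helper, not called by the pipeline; assumes a
# <root_dir>/{train,val,test} directory layout):
def build_imagefolder_loaders(root_dir, transform, batch_size=BATCH_SIZE):
    loaders = []
    for split in ("train", "val", "test"):
        dataset = ImageFolder(os.path.join(root_dir, split), transform=transform)
        loaders.append(DataLoader(dataset, batch_size=batch_size, shuffle=(split == "train")))
    return tuple(loaders)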
# Client training function
def client_train(client_model, server_model, loader, logger):
    client_model.train()
    server_model.eval()
    optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
    for epoch in range(CLIENT_EPOCHS):
        epoch_loss = 0.0
        task_loss = 0.0
        distill_loss = 0.0
        correct = 0
        total = 0
        # Training progress bar
        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{CLIENT_EPOCHS}")
        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            # Forward pass
            client_output = client_model(data)
            # Teacher (server) model output
            with torch.no_grad():
                server_output = server_model(data)
            # Losses: supervised task loss plus temperature-scaled
            # distillation loss (scaled by TEMP**2 to keep gradient magnitudes comparable)
            loss_task = F.cross_entropy(client_output, target)
            loss_distill = F.kl_div(
                F.log_softmax(client_output / TEMP, dim=1),
                F.softmax(server_output / TEMP, dim=1),
                reduction="batchmean"
            ) * (TEMP ** 2)
            total_loss = loss_task + loss_distill
            # Backward pass
            total_loss.backward()
            optimizer.step()
            # Accumulate metrics
            epoch_loss += total_loss.item()
            task_loss += loss_task.item()
            distill_loss += loss_distill.item()
            _, predicted = torch.max(client_output.data, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)
            # Update the progress bar in place
            progress_bar.set_postfix({
                "Epoch": f"{epoch+1}/{CLIENT_EPOCHS}",
                "Batch": f"{batch_idx+1}/{len(loader)}",
                "Loss": f"{total_loss.item():.4f}",
                "Acc": f"{100*correct/total:.2f}%",
            })
        # Per-epoch summary
        avg_loss = epoch_loss / len(loader)
        avg_task = task_loss / len(loader)
        avg_distill = distill_loss / len(loader)
        epoch_acc = 100 * correct / total
        logger.info(f"\nClient training epoch {epoch+1}/{CLIENT_EPOCHS}")
        logger.info(f"Average loss: {avg_loss:.4f}")
        logger.info(f"Task loss: {avg_task:.4f}")
        logger.info(f"Distillation loss: {avg_distill:.4f}")
        logger.info(f"Training accuracy: {epoch_acc:.2f}%")
        progress_bar.close()
    return client_model.state_dict()
# FedAvg-style parameter aggregation (unweighted mean over client models)
def aggregate(client_params):
    global_params = {}
    for key in client_params[0].keys():
        global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0)
    return global_params
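# aggregate() above averages client parameters uniformly. Standard FedAvg
# weights each client by its number of training samples; a minimal sketch of
# that variant (hypothetical helper, not used by the pipeline) is:
def aggregate_weighted(client_params, client_sizes):
    total = float(sum(client_sizes))
    global_params = {}
    for key in client_params[0].keys():
        stacked = torch.stack([p[key].float() * (n / total)
                               for p, n in zip(client_params, client_sizes)])
        global_params[key] = stacked.sum(dim=0)
    return global_params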
def server_aggregate(server_model, client_models, public_loader):
    server_model.train()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    total_loss = 0.0
    for data, _ in public_loader:
        data = data.to(device)
        # Collect client model features
        client_features = []
        with torch.no_grad():
            for model in client_models:
                features = model.extract_features(data)  # requires an extract_features() method
                client_features.append(features)
        # Feature distillation target: mean of the client features
        target_features = torch.stack(client_features).mean(dim=0)
        # Server forward pass
        server_features = server_model.extract_features(data)
        # Feature alignment loss
        loss = F.mse_loss(server_features, target_features)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Update running statistics
        total_loss += loss.item()
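# server_aggregate() assumes both model classes expose extract_features().
# If they do not, a minimal sketch of one way to capture intermediate features
# with a forward hook (hypothetical helper; the submodule name "features" is
# an assumption and must match the actual architecture):
def extract_features_via_hook(model, data, layer_name="features"):
    captured = {}
    def hook(_module, _inputs, output):
        captured["out"] = output
    handle = dict(model.named_modules())[layer_name].register_forward_hook(hook)
    try:
        model(data)
    finally:
        handle.remove()
    return captured["out"]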
# Server knowledge update
def server_update(server_model, client_models, public_loader, logger):
    server_model.train()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    total_loss = 0.0
    progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")
    for batch_idx, (data, target) in enumerate(progress_bar):
        data = data.to(device)
        optimizer.zero_grad()
        # Average the client models' per-sample outputs
        with torch.no_grad():
            client_outputs = [model(data) for model in client_models]
            soft_targets = torch.stack(client_outputs).mean(dim=0)
        # Distillation loss against the averaged client predictions
        server_output = server_model(data)
        loss = F.kl_div(
            F.log_softmax(server_output, dim=1),
            F.softmax(soft_targets, dim=1),
            reduction="batchmean"
        )
        # Backward pass
        loss.backward()
        optimizer.step()
        # Update running statistics
        total_loss += loss.item()
        progress_bar.set_postfix({
            "Avg Loss": f"{total_loss/(batch_idx+1):.4f}",
            "Current Loss": f"{loss.item():.4f}"
        })
    logger.info(f"Server update finished | average loss: {total_loss/len(public_loader):.4f}")
def test_model(model, test_loader):  # accepts a DataLoader directly
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc="Testing", unit="batch")
        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100 * correct / total
# Main training procedure
def main():
    # Create the run directory
    run_dir = create_run_dir()
    models_dir = os.path.join(run_dir, "models")
    os.makedirs(models_dir, exist_ok=True)
    # Initialize logging
    logger = Logger(run_dir)
    # Initialize models
    global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
    client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
    # Load data
    client_datasets, public_loader, server_test_loader = prepare_data()
    # Log the key configuration
    logger.info(f"\n{'#'*30} Training configuration {'#'*30}")
    logger.info(f"Device: {device}")
    logger.info(f"Number of clients: {NUM_CLIENTS}")
    logger.info(f"Federated rounds: {NUM_ROUNDS}")
    logger.info(f"Local epochs per round: {CLIENT_EPOCHS}")
    logger.info(f"Distillation temperature: {TEMP}")
    logger.info(f"Run directory: {run_dir}")
    logger.info(f"{'#'*70}\n")
    # Main federated training loop
    for round_idx in range(NUM_ROUNDS):
        logger.info(f"\n{'#'*30} Federated round [{round_idx+1}/{NUM_ROUNDS}] {'#'*30}")
        # Client selection
        selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
        logger.info(f"Selected clients: {selected_clients}")
        # Client training
        client_params = []
        for cid in selected_clients:
            logger.info(f"\n{'='*20} Client {cid} training started {'='*20}")
            local_model = copy.deepcopy(client_models[cid])
            updated_params = client_train(local_model, global_server_model,
                                          client_datasets[cid]['train'], logger)
            client_params.append(updated_params)
            logger.info(f"\n{'='*20} Client {cid} training finished {'='*20}\n")
        # Model aggregation
        global_client_params = aggregate(client_params)
        for model in client_models:
            model.load_state_dict(global_client_params)
        # Server update
        logger.info("\nServer knowledge-distillation update...")
        server_update(global_server_model, client_models, public_loader, logger)
        # Evaluation
        server_acc = test_model(global_server_model, server_test_loader)
        client_accuracies = [
            test_model(client_models[i], client_datasets[i]['test'])
            for i in range(NUM_CLIENTS)
        ]
        # Log this round's results
        logger.info(f"\nRound results:")
        logger.info(f"Server test accuracy: {server_acc:.2f}%")
        for i, acc in enumerate(client_accuracies):
            logger.info(f"Client {i} test accuracy: {acc:.2f}%")
    # Save the final models
    torch.save(global_server_model.state_dict(), os.path.join(models_dir, "server_model.pth"))
    for i in range(NUM_CLIENTS):
        torch.save(client_models[i].state_dict(),
                   os.path.join(models_dir, f"client{i}_model.pth"))
    logger.info(f"\nModels saved to: {models_dir}")
    # Final evaluation from the saved checkpoints
    logger.info("\nFinal test results:")
    server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
    server_model.load_state_dict(torch.load(os.path.join(models_dir, "server_model.pth"), weights_only=True))
    logger.info(f"Final server accuracy: {test_model(server_model, server_test_loader):.2f}%")
    for i in range(NUM_CLIENTS):
        client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
        client_model.load_state_dict(torch.load(os.path.join(models_dir, f"client{i}_model.pth"), weights_only=True))
        test_loader = client_datasets[i]['test']
        logger.info(f"Client {i} final accuracy: {test_model(client_model, test_loader):.2f}%")

if __name__ == "__main__":
    main()