import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
import os
from tqdm import tqdm

from data_loader import get_data_loader
from logger import *
from model.repvit import repvit_m1_1
from model.mobilenetv3 import MobileNetV3

# Configuration
NUM_CLIENTS = 2
NUM_ROUNDS = 10
CLIENT_EPOCHS = 2
BATCH_SIZE = 32
TEMP = 2.0             # distillation temperature
CLASS_NUM = [9, 9, 9]  # number of classes: [server, client 0, client 1]

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Data preparation
def prepare_data():
    # Note: this transform is currently not passed to get_data_loader.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # Load all datasets (train / val / test splits)
    dataset_A_train, dataset_A_val, dataset_A_test = get_data_loader(
        root_dir='/home/yoiannis/deep_learning/dataset/debug/JY_A', Cache='RAM')
    dataset_B_train, dataset_B_val, dataset_B_test = get_data_loader(
        root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_A', Cache='RAM')
    dataset_C_train, dataset_C_val, dataset_C_test = get_data_loader(
        root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_B', Cache='RAM')

    # Organize the per-client datasets
    client_datasets = [
        {   # Client 0
            'train': dataset_B_train,
            'val': dataset_B_val,
            'test': dataset_B_test
        },
        {   # Client 1
            'train': dataset_C_train,
            'val': dataset_C_val,
            'test': dataset_C_test
        }
    ]

    # Public dataset (dataset A's training split)
    public_loader = dataset_A_train
    # Server test set (dataset A's test split)
    server_test_loader = dataset_A_test

    return client_datasets, public_loader, server_test_loader


# Client-side training: local task loss plus distillation from the server model
def client_train(client_model, server_model, loader, logger):
    client_model.train()
    server_model.eval()
    optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)

    for epoch in range(CLIENT_EPOCHS):
        epoch_loss = 0.0
        task_loss = 0.0
        distill_loss = 0.0
        correct = 0
        total = 0

        # Training progress bar
        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{CLIENT_EPOCHS}")
        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass
            client_output = client_model(data)

            # Teacher (server model) output
            with torch.no_grad():
                server_output = server_model(data)

            # Losses: task cross-entropy + temperature-scaled KL distillation
            loss_task = F.cross_entropy(client_output, target)
            loss_distill = F.kl_div(
                F.log_softmax(client_output / TEMP, dim=1),
                F.softmax(server_output / TEMP, dim=1),
                reduction="batchmean"
            ) * (TEMP ** 2)
            total_loss = loss_task + loss_distill

            # Backward pass
            total_loss.backward()
            optimizer.step()

            # Statistics
            epoch_loss += total_loss.item()
            task_loss += loss_task.item()
            distill_loss += loss_distill.item()
            _, predicted = torch.max(client_output.data, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)

            # Live progress-bar update (tqdm advances itself while iterating,
            # so no explicit update() call is needed)
            progress_bar.set_postfix({
                "Epoch": f"{epoch+1}/{CLIENT_EPOCHS}",
                "Batch": f"{batch_idx+1}/{len(loader)}",
                "Loss": f"{total_loss.item():.4f}",
                "Acc": f"{100*correct/total:.2f}%",
            })

        # Per-epoch summary
        avg_loss = epoch_loss / len(loader)
        avg_task = task_loss / len(loader)
        avg_distill = distill_loss / len(loader)
        epoch_acc = 100 * correct / total
        logger.info(f"\nClient training epoch {epoch+1}/{CLIENT_EPOCHS}")
        logger.info(f"Average loss: {avg_loss:.4f}")
        logger.info(f"Task loss: {avg_task:.4f}")
        logger.info(f"Distillation loss: {avg_distill:.4f}")
        logger.info(f"Training accuracy: {epoch_acc:.2f}%")
        progress_bar.close()

    return client_model.state_dict()


# Model parameter aggregation (FedAvg, unweighted mean of the state dicts)
def aggregate(client_params):
    global_params = {}
    for key in client_params[0].keys():
        global_params[key] = torch.stack(
            [param[key].float() for param in client_params]
        ).mean(dim=0)
    return global_params
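# --- Optional helper --------------------------------------------------------
# server_aggregate() below calls model.extract_features(), which neither
# repvit_m1_1 nor MobileNetV3 is guaranteed to provide. The function below is
# a minimal, hedged sketch of one way to obtain penultimate-layer features via
# a forward hook. It assumes the last child module of the network is the
# classification head; that assumption may not hold for every backbone, so
# adapt the hooked module to the actual architecture before relying on it.
def extract_penultimate_features(model, data):
    """Run a forward pass and return the tensor fed into the model's last child module."""
    captured = {}

    def hook(module, inputs, output):
        # inputs is a tuple of positional arguments; keep the first one
        captured["features"] = inputs[0].detach()

    last_module = list(model.children())[-1]  # assumed to be the classifier head
    handle = last_module.register_forward_hook(hook)
    try:
        model(data)
    finally:
        handle.remove()
    return captured.get("features")

# Usage sketch: features = extract_penultimate_features(client_model, data)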
# Server-side feature aggregation (alternative update path; not called from main())
def server_aggregate(server_model, client_models, public_loader):
    server_model.train()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    total_loss = 0.0

    for data, _ in public_loader:
        data = data.to(device)

        # Collect client-model features (requires a feature-extraction method;
        # one hedged option, extract_penultimate_features, is sketched above)
        client_features = []
        with torch.no_grad():
            for model in client_models:
                features = model.extract_features(data)
                client_features.append(features)

        # Feature-distillation target: mean of the client features
        target_features = torch.stack(client_features).mean(dim=0)

        # Server forward pass
        server_features = server_model.extract_features(data)

        # Feature-alignment loss
        loss = F.mse_loss(server_features, target_features)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Statistics
        total_loss += loss.item()


# Server knowledge update: distill the averaged client predictions on the
# public dataset into the server model
def server_update(server_model, client_models, public_loader, logger):
    server_model.train()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    total_loss = 0.0

    progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")
    for batch_idx, (data, _) in enumerate(progress_bar):
        data = data.to(device)
        optimizer.zero_grad()

        # Soft targets: per-sample client outputs, averaged across clients
        with torch.no_grad():
            client_outputs = [model(data) for model in client_models]
            soft_targets = torch.stack(client_outputs).mean(dim=0)

        # Distillation loss
        server_output = server_model(data)
        loss = F.kl_div(
            F.log_softmax(server_output, dim=1),
            F.softmax(soft_targets, dim=1),
            reduction="batchmean"
        )

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        total_loss += loss.item()
        progress_bar.set_postfix({
            "Avg Loss": f"{total_loss/(batch_idx+1):.4f}",
            "Current Loss": f"{loss.item():.4f}"
        })

    logger.info(f"Server update finished | average loss: {total_loss/len(public_loader):.4f}")


# Evaluate a model on a (data, target) DataLoader and return top-1 accuracy (%)
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc="Testing", unit="batch")
        for data, target in progress_bar:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100 * correct / total
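# --- Optional variant (not wired into main() below) --------------------------
# aggregate() above takes an unweighted mean of the client state dicts, which
# matches FedAvg only when every client holds the same number of samples. The
# sketch below weights each client by its sample count; the `client_sizes`
# argument is hypothetical and would have to be collected from the client
# loaders (e.g. len(loader.dataset)) before this could be used.
def aggregate_weighted(client_params, client_sizes):
    total = float(sum(client_sizes))
    weights = [size / total for size in client_sizes]
    global_params = {}
    for key in client_params[0].keys():
        global_params[key] = sum(
            w * params[key].float() for w, params in zip(weights, client_params)
        )
    return global_params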
# Main training flow
def main():
    # Create the run directory
    run_dir = create_run_dir()
    models_dir = os.path.join(run_dir, "models")
    os.makedirs(models_dir, exist_ok=True)

    # Initialize logging
    logger = Logger(run_dir)

    # Initialize models (RepViT server, MobileNetV3 clients)
    global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
    client_models = [MobileNetV3(n_class=CLASS_NUM[i + 1]).to(device)
                     for i in range(NUM_CLIENTS)]

    # Load data
    client_datasets, public_loader, server_test_loader = prepare_data()

    # Log the key configuration
    logger.info(f"\n{'#'*30} Training configuration {'#'*30}")
    logger.info(f"Device: {device}")
    logger.info(f"Number of clients: {NUM_CLIENTS}")
    logger.info(f"Federated rounds: {NUM_ROUNDS}")
    logger.info(f"Local epochs per round: {CLIENT_EPOCHS}")
    logger.info(f"Distillation temperature: {TEMP}")
    logger.info(f"Run directory: {run_dir}")
    logger.info(f"{'#'*70}\n")

    # Main training loop
    for round_idx in range(NUM_ROUNDS):
        logger.info(f"\n{'#'*30} Federated round [{round_idx+1}/{NUM_ROUNDS}] {'#'*30}")

        # Client selection (with NUM_CLIENTS == 2 this selects both clients)
        selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
        logger.info(f"Selected clients: {selected_clients}")

        # Client training
        client_params = []
        for cid in selected_clients:
            logger.info(f"\n{'='*20} Client {cid} training started {'='*20}")
            local_model = copy.deepcopy(client_models[cid])
            updated_params = client_train(local_model, global_server_model,
                                          client_datasets[cid]['train'], logger)
            client_params.append(updated_params)
            logger.info(f"\n{'='*20} Client {cid} training finished {'='*20}\n")

        # Model aggregation (FedAvg over the returned state dicts)
        global_client_params = aggregate(client_params)
        for model in client_models:
            model.load_state_dict(global_client_params)

        # Server update
        logger.info("\nServer knowledge-distillation update...")
        server_update(global_server_model, client_models, public_loader, logger)

        # Evaluate performance
        server_acc = test_model(global_server_model, server_test_loader)
        client_accuracies = [
            test_model(client_models[i], client_datasets[i]['test'])
            for i in range(NUM_CLIENTS)
        ]

        # Log this round's results
        logger.info("\nRound results:")
        logger.info(f"Server test accuracy: {server_acc:.2f}%")
        for i, acc in enumerate(client_accuracies):
            logger.info(f"Client {i} test accuracy: {acc:.2f}%")

    # Save the final models
    torch.save(global_server_model.state_dict(),
               os.path.join(models_dir, "server_model.pth"))
    for i in range(NUM_CLIENTS):
        torch.save(client_models[i].state_dict(),
                   os.path.join(models_dir, f"client{i}_model.pth"))
    logger.info(f"\nModels saved to: {models_dir}")

    # Final evaluation from the saved checkpoints
    logger.info("\nFinal test results:")
    server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
    server_model.load_state_dict(torch.load(os.path.join(models_dir, "server_model.pth"),
                                            weights_only=True))
    logger.info(f"Server final accuracy: {test_model(server_model, server_test_loader):.2f}%")
    for i in range(NUM_CLIENTS):
        client_model = MobileNetV3(n_class=CLASS_NUM[i + 1]).to(device)
        client_model.load_state_dict(torch.load(os.path.join(models_dir, f"client{i}_model.pth"),
                                                weights_only=True))
        test_loader = client_datasets[i]['test']
        logger.info(f"Client {i} final accuracy: {test_model(client_model, test_loader):.2f}%")


if __name__ == "__main__":
    main()