# TA_EC/FED.py — federated learning with server/client knowledge distillation.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
from tqdm import tqdm
from data_loader import get_data_loader
from model.repvit import repvit_m1_1
from model.mobilenetv3 import MobileNetV3
# --- Configuration ---
NUM_CLIENTS = 2        # number of federated clients
NUM_ROUNDS = 10        # federated communication rounds
CLIENT_EPOCHS = 2      # local epochs per client per round
BATCH_SIZE = 32        # NOTE(review): not referenced in this file — loaders are built in data_loader; confirm before removing
TEMP = 2.0  # distillation temperature
CLASS_NUM = [9, 9, 9]  # class counts; presumably [dataset A / server, client 0, client 1] — confirm against CLASS_NUM[i+1] usage in main()
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Data preparation ---
import os
from torchvision.datasets import ImageFolder


def prepare_data(base_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3'):
    """Build the client, public and server evaluation loaders.

    Args:
        base_dir: root directory containing the JY_A / ZY_A / ZY_B
            sub-datasets. Defaults to the previously hard-coded location,
            so existing callers are unaffected.

    Returns:
        Tuple ``(client_datasets, public_loader, server_test_loader)`` where
        ``client_datasets`` is a list of dicts with 'train'/'val'/'test'
        entries, one dict per client.
    """
    # Load train/val/test loaders for the three sites.
    # NOTE(review): get_data_loader is assumed to return three loaders — confirm.
    # (Removed a `transforms.Compose` that was built here but never used.)
    dataset_A_train, dataset_A_val, dataset_A_test = get_data_loader(
        root_dir=os.path.join(base_dir, 'JY_A'), Cache='RAM')
    dataset_B_train, dataset_B_val, dataset_B_test = get_data_loader(
        root_dir=os.path.join(base_dir, 'ZY_A'), Cache='RAM')
    dataset_C_train, dataset_C_val, dataset_C_test = get_data_loader(
        root_dir=os.path.join(base_dir, 'ZY_B'), Cache='RAM')

    # Per-client loaders: client 0 -> ZY_A, client 1 -> ZY_B.
    client_datasets = [
        {'train': dataset_B_train, 'val': dataset_B_val, 'test': dataset_B_test},
        {'train': dataset_C_train, 'val': dataset_C_val, 'test': dataset_C_test},
    ]

    # Public (distillation) data: site A training set.
    public_loader = dataset_A_train
    # Server evaluation data: site A test set.
    server_test_loader = dataset_A_test
    return client_datasets, public_loader, server_test_loader
# --- Client-side local training (task loss + logit distillation) ---
def client_train(client_model, server_model, loader):
    """Train a client model locally, distilling from the frozen server model.

    Loss = cross-entropy on hard labels + temperature-scaled KL distillation
    against the server's logits (Hinton-style, scaled by TEMP**2).

    Args:
        client_model: the client's local network, updated in place.
        server_model: global server network used as the distillation teacher.
        loader: DataLoader yielding (data, target) batches for this client.

    Returns:
        The trained client model's ``state_dict``.
    """
    client_model.train()
    server_model.eval()
    optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)

    for epoch in range(CLIENT_EPOCHS):
        epoch_loss = 0.0
        task_loss = 0.0
        distill_loss = 0.0
        correct = 0
        total = 0

        # Training progress bar; iterating it advances it automatically.
        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{CLIENT_EPOCHS}")
        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass (output is already on `device`; the old extra
            # .to(device) on the outputs was a no-op and has been removed).
            client_output = client_model(data)

            # Teacher logits — no gradient through the server model.
            with torch.no_grad():
                server_output = server_model(data)

            # Compute losses.
            loss_task = F.cross_entropy(client_output, target)
            loss_distill = F.kl_div(
                F.log_softmax(client_output / TEMP, dim=1),
                F.softmax(server_output / TEMP, dim=1),
                reduction="batchmean"
            ) * (TEMP ** 2)
            total_loss = loss_task + loss_distill

            # Backward pass.
            total_loss.backward()
            optimizer.step()

            # Running statistics.
            epoch_loss += total_loss.item()
            task_loss += loss_task.item()
            distill_loss += loss_distill.item()
            _, predicted = torch.max(client_output.data, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)

            # Live progress info. Bug fixes: removed the stray "\n" in the
            # Acc field that broke the single-line display, and removed the
            # explicit progress_bar.update(1) that double-counted batches.
            progress_bar.set_postfix({
                "Epoch": f"{epoch+1}/{CLIENT_EPOCHS}",
                "Batch": f"{batch_idx+1}/{len(loader)}",
                "Loss": f"{total_loss.item():.4f}",
                "Acc": f"{100*correct/total:.2f}%",
            })

        # Per-epoch summary.
        avg_loss = epoch_loss / len(loader)
        avg_task = task_loss / len(loader)
        avg_distill = distill_loss / len(loader)
        epoch_acc = 100 * correct / total
        print(f"\n{'='*40}")
        print(f"Epoch {epoch+1} Summary:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Task Loss: {avg_task:.4f}")
        print(f"Distill Loss: {avg_distill:.4f}")
        print(f"Training Accuracy: {epoch_acc:.2f}%")
        print(f"{'='*40}\n")

        progress_bar.close()

    return client_model.state_dict()
# FedAvg parameter aggregation (uniform weighting across clients).
def aggregate(client_params):
    """Return the element-wise mean of every tensor across client state dicts."""
    reference_keys = client_params[0].keys()
    return {
        key: torch.stack([state[key].float() for state in client_params]).mean(dim=0)
        for key in reference_keys
    }
def server_aggregate(server_model, client_models, public_loader):
    """Feature-level aggregation: fit server features to the clients' mean features.

    For each public batch, the target is the mean of the client models'
    extracted features; the server is trained to match it with an MSE loss.

    Args:
        server_model: server network; must implement ``extract_features``.
        client_models: list of client networks; must implement ``extract_features``.
        public_loader: DataLoader over the shared public dataset.

    Returns:
        Sum of the MSE losses over all public batches.
    """
    server_model.train()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    # Bug fix: total_loss was read below without ever being initialized,
    # raising NameError on the first batch.
    total_loss = 0.0
    for data, _ in public_loader:
        data = data.to(device)
        # Client features — no gradients through the client models.
        client_features = []
        with torch.no_grad():
            for model in client_models:
                features = model.extract_features(data)  # NOTE(review): extract_features must exist on every model — confirm
                client_features.append(features)
        # Distillation target: mean feature map across clients.
        target_features = torch.stack(client_features).mean(dim=0)
        # Server forward pass.
        server_features = server_model.extract_features(data)
        # Feature-alignment loss.
        loss = F.mse_loss(server_features, target_features)
        # Backward pass.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Update statistics.
        total_loss += loss.item()
    return total_loss
# --- Server knowledge update (logit distillation from the client ensemble) ---
def server_update(server_model, client_models, public_loader):
    """Distill the averaged client predictions into the server on public data.

    Args:
        server_model: global server network, updated in place.
        client_models: list of client networks acting as the teacher ensemble.
        public_loader: DataLoader over the shared public dataset.
    """
    server_model.train()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    total_loss = 0.0

    progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")
    for batch_idx, (data, target) in enumerate(progress_bar):
        data = data.to(device)
        optimizer.zero_grad()
        # Soft targets: per-sample average of the client logits.
        # Bug fix: the old per-model .mean(dim=0, keepdim=True) collapsed the
        # batch dimension, so every sample was distilled toward a single
        # averaged logit row instead of its own ensemble prediction.
        with torch.no_grad():
            client_outputs = [model(data) for model in client_models]
            soft_targets = torch.stack(client_outputs).mean(dim=0)
        # Distillation step.
        server_output = server_model(data)
        loss = F.kl_div(
            F.log_softmax(server_output, dim=1),
            F.softmax(soft_targets, dim=1),
            reduction="batchmean"
        )
        # Backward pass.
        loss.backward()
        optimizer.step()
        # Update statistics.
        total_loss += loss.item()
        progress_bar.set_postfix({
            "Avg Loss": f"{total_loss/(batch_idx+1):.4f}",
            "Current Loss": f"{loss.item():.4f}"
        })
    print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
2025-03-04 04:55:40 +00:00
2025-03-12 06:00:50 +00:00
def test_model(model, test_loader): # 添加对DataLoader的支持
2025-03-04 04:55:40 +00:00
model.eval()
correct = 0
total = 0
with torch.no_grad():
2025-03-12 06:00:50 +00:00
progress_bar = tqdm(test_loader, desc="Server Updating", unit="batch")
for batch_idx, (data, target) in enumerate(progress_bar):
2025-03-04 04:55:40 +00:00
data, target = data.to(device), target.to(device)
output = model(data)
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
2025-03-12 06:00:50 +00:00
return 100 * correct / total
2025-03-04 04:55:40 +00:00
2025-03-09 14:36:22 +00:00
2025-03-04 04:55:40 +00:00
# 主训练流程
def main():
2025-03-12 06:00:50 +00:00
# 初始化模型(保持不变)
2025-03-11 16:21:31 +00:00
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
2025-03-04 04:55:40 +00:00
2025-03-12 06:00:50 +00:00
# 加载数据集
client_datasets, public_loader, server_test_loader = prepare_data()
2025-03-09 14:36:22 +00:00
2025-03-11 15:12:34 +00:00
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
2025-03-04 04:55:40 +00:00
for round in range(NUM_ROUNDS):
2025-03-09 14:36:22 +00:00
print(f"\n{'#'*50}")
2025-03-12 06:00:50 +00:00
print(f"Round {round+1}/{NUM_ROUNDS}")
2025-03-09 14:36:22 +00:00
print(f"{'#'*50}")
2025-03-12 06:00:50 +00:00
# 客户端选择
2025-03-09 14:36:22 +00:00
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
2025-03-12 06:00:50 +00:00
print(f"Selected clients: {selected_clients}")
2025-03-04 04:55:40 +00:00
2025-03-12 06:00:50 +00:00
# 客户端训练
2025-03-04 04:55:40 +00:00
client_params = []
for cid in selected_clients:
2025-03-09 14:36:22 +00:00
print(f"\nTraining Client {cid}")
2025-03-04 04:55:40 +00:00
local_model = copy.deepcopy(client_models[cid])
local_model.load_state_dict(client_models[cid].state_dict())
2025-03-12 06:00:50 +00:00
# 传入客户端的训练集
updated_params = client_train(
local_model,
global_server_model,
client_datasets[cid]['train'] # 使用训练集
)
2025-03-04 04:55:40 +00:00
client_params.append(updated_params)
2025-03-12 06:00:50 +00:00
# 模型聚合
2025-03-04 04:55:40 +00:00
global_client_params = aggregate(client_params)
for model in client_models:
model.load_state_dict(global_client_params)
2025-03-12 06:00:50 +00:00
# 服务器更新
print("\nServer updating...")
2025-03-04 04:55:40 +00:00
server_update(global_server_model, client_models, public_loader)
2025-03-12 06:00:50 +00:00
# 测试性能
server_acc = test_model(global_server_model, server_test_loader)
client_accuracies = [
test_model(client_models[i],
client_datasets[i]['test']) # 动态创建测试loader
for i in range(NUM_CLIENTS)
]
print(f"\nRound {round+1} Results:")
print(f"Server Accuracy: {server_acc:.2f}%")
for i, acc in enumerate(client_accuracies):
print(f"Client {i} Accuracy: {acc:.2f}%")
2025-03-09 14:36:22 +00:00
round_progress.update(1)
2025-03-11 15:12:34 +00:00
2025-03-12 06:00:50 +00:00
# 保存模型
2025-03-04 04:55:40 +00:00
torch.save(global_server_model.state_dict(), "server_model.pth")
2025-03-11 16:21:31 +00:00
for i in range(NUM_CLIENTS):
2025-03-12 06:00:50 +00:00
torch.save(client_models[i].state_dict(), f"client{i}_model.pth")
2025-03-11 15:12:34 +00:00
2025-03-12 06:00:50 +00:00
# 最终测试
print("\nFinal Evaluation:")
2025-03-11 16:21:31 +00:00
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
2025-03-12 06:00:50 +00:00
print(f"Server Accuracy: {test_model(server_model, server_test_loader):.2f}%")
2025-03-11 15:12:34 +00:00
2025-03-11 16:21:31 +00:00
for i in range(NUM_CLIENTS):
client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
2025-03-12 06:00:50 +00:00
client_model.load_state_dict(torch.load(f"client{i}_model.pth",weights_only=True))
test_loader = client_datasets[i]['test']
print(f"Client {i} Accuracy: {test_model(client_model, test_loader):.2f}%")
2025-03-04 04:55:40 +00:00
if __name__ == "__main__":
main()