# Federated distillation training script (server: RepViT teacher, clients: MobileNetV3 students).
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from torchvision import datasets, transforms
|
||
from torch.utils.data import DataLoader, Subset
|
||
import numpy as np
|
||
import copy
|
||
|
||
from tqdm import tqdm
|
||
|
||
from model.repvit import repvit_m1_1
|
||
from model.mobilenetv3 import MobileNetV3
|
||
|
||
# Configuration parameters
NUM_CLIENTS = 2
NUM_ROUNDS = 10
CLIENT_EPOCHS = 2
BATCH_SIZE = 32
TEMP = 2.0 # distillation temperature (used by the client-side KD loss in client_train)
CLASS_NUM = [3, 3, 3]  # class counts: index 0 -> server model, index i+1 -> client i (see main)

# Device configuration: prefer CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data preparation imports
import os
from torchvision.datasets import ImageFolder
def prepare_data(data_root='G:/testdata'):
    """Load the three site datasets and split them into client/server roles.

    Args:
        data_root: Base directory holding the JY_A/ZY_A/ZY_B dataset folders.
            The default preserves the original hard-coded location.

    Returns:
        tuple: (client_datasets, public_loader) where client_datasets is
        [dataset_B, dataset_C] (one per client) and public_loader iterates
        dataset A, used by the server for public knowledge updates.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # Load datasets (ImageFolder infers class labels from sub-directory names).
    dataset_A = ImageFolder(root=f'{data_root}/JY_A/train', transform=transform)
    dataset_B = ImageFolder(root=f'{data_root}/ZY_A/train', transform=transform)
    dataset_C = ImageFolder(root=f'{data_root}/ZY_B/train', transform=transform)

    # Assign datasets B and C to the two clients.
    client_datasets = [dataset_B, dataset_C]

    # Server dataset (A) for public updates.
    public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True)

    return client_datasets, public_loader
|
||
# Client training function
def client_train(client_model, server_model, dataset, lr=0.1):
    """Train a client model locally with task loss plus KD from the server.

    Args:
        client_model: the client's local model; updated in place.
        server_model: frozen teacher model providing soft targets.
        dataset: the client's local training dataset.
        lr: SGD learning rate (default keeps the original hard-coded 0.1).

    Returns:
        The trained client model's state_dict.
    """
    client_model.train()
    server_model.eval()

    optimizer = torch.optim.SGD(client_model.parameters(), lr=lr)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # One progress bar covering all epochs' batches.
    progress_bar = tqdm(total=CLIENT_EPOCHS*len(loader),
                        desc="Client Training",
                        unit="batch")

    for epoch in range(CLIENT_EPOCHS):
        epoch_loss = 0.0
        task_loss = 0.0
        distill_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            # Student forward pass.
            client_output = client_model(data)

            # Teacher (server) soft targets; no gradient through the teacher.
            with torch.no_grad():
                server_output = server_model(data)

            # Task loss + temperature-scaled KD loss; the TEMP**2 factor keeps
            # the KD gradient magnitude comparable to the task loss.
            loss_task = F.cross_entropy(client_output, target)
            loss_distill = F.kl_div(
                F.log_softmax(client_output/TEMP, dim=1),
                F.softmax(server_output/TEMP, dim=1),
                reduction="batchmean"
            ) * (TEMP**2)
            total_loss = loss_task + loss_distill

            # Backward pass and parameter update.
            total_loss.backward()
            optimizer.step()

            # Running statistics.
            epoch_loss += total_loss.item()
            task_loss += loss_task.item()
            distill_loss += loss_distill.item()
            _, predicted = torch.max(client_output.data, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)

            # Live progress display.
            # (Bug fix: the original "Acc" value ended with "\n", which broke
            # tqdm's single-line bar rendering on every update.)
            progress_bar.set_postfix({
                "Epoch": f"{epoch+1}/{CLIENT_EPOCHS}",
                "Batch": f"{batch_idx+1}/{len(loader)}",
                "Loss": f"{total_loss.item():.4f}",
                "Acc": f"{100*correct/total:.2f}%",
            })
            progress_bar.update(1)

        # Per-epoch summary.
        avg_loss = epoch_loss / len(loader)
        avg_task = task_loss / len(loader)
        avg_distill = distill_loss / len(loader)
        epoch_acc = 100 * correct / total
        print(f"\n{'='*40}")
        print(f"Epoch {epoch+1} Summary:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Task Loss: {avg_task:.4f}")
        print(f"Distill Loss: {avg_distill:.4f}")
        print(f"Training Accuracy: {epoch_acc:.2f}%")
        print(f"{'='*40}\n")

    progress_bar.close()
    return client_model.state_dict()
|
||
# Model parameter aggregation (FedAvg)
def aggregate(client_params):
    """Return the element-wise mean of the given client state_dicts (FedAvg).

    Parameters are cast to float so integer buffers can be averaged too.
    """
    return {
        name: torch.stack([sd[name].float() for sd in client_params]).mean(dim=0)
        for name in client_params[0]
    }
|
||
def server_aggregate(server_model, client_models, public_loader):
    """Align the server model's features with the averaged client features.

    Runs one pass over the public data, training the server so its extracted
    features match the mean of the (frozen) client models' features.

    Args:
        server_model: model to be trained; must provide extract_features().
        client_models: frozen client models; must provide extract_features().
        public_loader: DataLoader over the shared/public dataset.

    Returns:
        The average feature-alignment (MSE) loss over the loader.
    """
    server_model.train()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)

    # Bug fix: total_loss was accumulated without ever being initialized,
    # which raised NameError on the first batch.
    total_loss = 0.0

    for data, _ in public_loader:
        data = data.to(device)

        # Collect frozen client features.
        # NOTE(review): extract_features must be implemented on the models —
        # confirm against model/repvit.py and model/mobilenetv3.py.
        client_features = []
        with torch.no_grad():
            for model in client_models:
                features = model.extract_features(data)
                client_features.append(features)

        # Distillation target: per-sample mean of client features.
        target_features = torch.stack(client_features).mean(dim=0)

        # Server forward pass.
        server_features = server_model.extract_features(data)

        # Feature-alignment loss.
        loss = F.mse_loss(server_features, target_features)

        # Backward pass and update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update running statistics.
        total_loss += loss.item()

    return total_loss / max(len(public_loader), 1)
# Server knowledge update
def server_update(server_model, client_models, public_loader):
    """Distill the averaged client predictions into the server model.

    For each public batch, the soft target is the per-sample mean of the
    client models' logits; the server is trained with a KL divergence loss
    against those targets.

    Args:
        server_model: model to be trained on the public data.
        client_models: frozen client models providing soft targets.
        public_loader: DataLoader over the shared/public dataset.
    """
    server_model.train()
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)

    total_loss = 0.0
    progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")

    for batch_idx, (data, _) in enumerate(progress_bar):
        data = data.to(device)
        optimizer.zero_grad()

        # Per-sample average of the client outputs.
        # (Bug fix: the original called .mean(dim=0, keepdim=True) on each
        # client output first, collapsing the batch dimension so a single
        # averaged row was broadcast as the target for every sample.)
        with torch.no_grad():
            client_outputs = [model(data) for model in client_models]
            soft_targets = torch.stack(client_outputs).mean(dim=0)

        # Distillation: KL between server log-probs and averaged client probs.
        server_output = server_model(data)
        loss = F.kl_div(
            F.log_softmax(server_output, dim=1),
            F.softmax(soft_targets, dim=1),
            reduction="batchmean"
        )

        # Backward pass and update.
        loss.backward()
        optimizer.step()

        # Update running statistics.
        total_loss += loss.item()
        progress_bar.set_postfix({
            "Avg Loss": f"{total_loss/(batch_idx+1):.4f}",
            "Current Loss": f"{loss.item():.4f}"
        })

    print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
|
||
|
||
def test_model(model, test_loader):
    """Evaluate *model* on *test_loader* and return top-1 accuracy in percent.

    Args:
        model: classifier returning per-class logits.
        test_loader: DataLoader yielding (data, target) batches.

    Returns:
        Accuracy as a float percentage; 0.0 for an empty loader.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    # Guard against an empty loader (original raised ZeroDivisionError).
    return 100 * correct / total if total else 0.0
|
||
|
||
# Main training flow
def main():
    """Run the full federated pipeline for NUM_ROUNDS rounds.

    Each round: train selected clients locally (with server KD), FedAvg the
    client weights, distill client knowledge back into the server on public
    data, and report accuracies. Finally save and re-test all models.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    # Initialize models: RepViT server (teacher) and MobileNetV3 clients.
    global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
    client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]

    # Prepare data
    client_datasets, public_loader = prepare_data()

    # Test dataset (using dataset A's test set for simplicity)
    test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

    round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")

    # `round_idx` instead of `round` to avoid shadowing the builtin.
    for round_idx in range(NUM_ROUNDS):
        print(f"\n{'#'*50}")
        print(f"Federated Round {round_idx+1}/{NUM_ROUNDS}")
        print(f"{'#'*50}")

        # Client selection (fix: use NUM_CLIENTS instead of a hard-coded 2).
        selected_clients = np.random.choice(NUM_CLIENTS, NUM_CLIENTS, replace=False)
        print(f"Selected Clients: {selected_clients}")

        # Client local training
        client_params = []
        for cid in selected_clients:
            print(f"\nTraining Client {cid}")
            # deepcopy already duplicates the weights; the original's extra
            # load_state_dict call right after it was redundant.
            local_model = copy.deepcopy(client_models[cid])
            updated_params = client_train(local_model, global_server_model, client_datasets[cid])
            client_params.append(updated_params)

        # Model aggregation: every client adopts the FedAvg'd weights.
        global_client_params = aggregate(client_params)
        for model in client_models:
            model.load_state_dict(global_client_params)

        # Server knowledge update
        print("\nServer Updating...")
        server_update(global_server_model, client_models, public_loader)

        # Test model performance
        server_acc = test_model(global_server_model, test_loader)
        client_acc = test_model(client_models[0], test_loader)
        print(f"\nRound {round_idx+1} Performance:")
        print(f"Global Model Accuracy: {server_acc:.2f}%")
        print(f"Client Model Accuracy: {client_acc:.2f}%")

        round_progress.update(1)
        print(f"Round {round_idx+1} completed")

    round_progress.close()  # fix: the rounds progress bar was never closed
    print("Training completed!")

    # Save trained models
    torch.save(global_server_model.state_dict(), "server_model.pth")
    for i in range(NUM_CLIENTS):
        torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth")
    print("Models saved successfully.")

    # Reload and test the server model from disk.
    server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
    server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
    server_acc = test_model(server_model, test_loader)
    print(f"Server Model Test Accuracy: {server_acc:.2f}%")

    # Reload and test each client model from disk.
    for i in range(NUM_CLIENTS):
        client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
        client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True))
        client_acc = test_model(client_model, test_loader)
        print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%")
|
||
# Script entry point: run the full federated training pipeline.
if __name__ == "__main__":
    main()