更新联邦学习代码

This commit is contained in:
yoiannis 2025-03-12 14:00:50 +08:00
parent 4cb9790dee
commit a5ca9d04d7
3 changed files with 89 additions and 77 deletions

141
FED.py
View File

@ -8,6 +8,7 @@ import copy
from tqdm import tqdm
from data_loader import get_data_loader
from model.repvit import repvit_m1_1
from model.mobilenetv3 import MobileNetV3
@ -17,7 +18,7 @@ NUM_ROUNDS = 10
CLIENT_EPOCHS = 2
BATCH_SIZE = 32
TEMP = 2.0 # 蒸馏温度
CLASS_NUM = [3, 3, 3]
CLASS_NUM = [9, 9, 9]
# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -32,49 +33,60 @@ def prepare_data():
transforms.ToTensor()
])
# Load datasets
dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform)
dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform)
dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform)
# 加载所有数据集(训练、验证、测试)
dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/JY_A',Cache='RAM')
dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_A',Cache='RAM')
dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_B',Cache='RAM')
# Assign datasets to clients
client_datasets = [dataset_B, dataset_C]
# 组织客户端数据集
client_datasets = [
{ # Client 0
'train': dataset_B_train,
'val': dataset_B_val,
'test': dataset_B_test
},
{ # Client 1
'train': dataset_C_train,
'val': dataset_C_val,
'test': dataset_C_test
}
]
# Server dataset (A) for public updates
public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True)
# 公共数据集使用A的训练集
public_loader = dataset_A_train
return client_datasets, public_loader
# 服务器测试集使用A的测试集
server_test_loader = dataset_A_test
return client_datasets, public_loader, server_test_loader
# 客户端训练函数
def client_train(client_model, server_model, dataset):
def client_train(client_model, server_model, loader):
client_model.train()
server_model.eval()
optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# 训练进度条
progress_bar = tqdm(total=CLIENT_EPOCHS*len(loader),
desc="Client Training",
unit="batch")
for epoch in range(CLIENT_EPOCHS):
epoch_loss = 0.0
task_loss = 0.0
distill_loss = 0.0
correct = 0
total = 0
# 训练进度条
progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{CLIENT_EPOCHS}")
for batch_idx, (data, target) in enumerate(loader):
for batch_idx, (data, target) in enumerate(progress_bar):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
# 前向传播
client_output = client_model(data)
client_output = client_model(data).to(device)
# 获取教师模型输出
with torch.no_grad():
server_output = server_model(data)
server_output = server_model(data).to(device)
# 计算损失
loss_task = F.cross_entropy(client_output, target)
@ -166,9 +178,10 @@ def server_update(server_model, client_models, public_loader):
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
total_loss = 0.0
progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")
for batch_idx, (data, _) in enumerate(progress_bar):
for batch_idx, (data, target) in enumerate(progress_bar):
data = data.to(device)
optimizer.zero_grad()
@ -199,97 +212,97 @@ def server_update(server_model, client_models, public_loader):
print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
def test_model(model, test_loader):
    """Evaluate *model* on *test_loader* and return top-1 accuracy in percent.

    Accepts any iterable of ``(data, target)`` batches (e.g. a
    ``torch.utils.data.DataLoader``); batches are moved to the module-level
    ``device`` before the forward pass.

    Returns:
        float: ``100 * correct / total`` over all samples seen.

    Raises:
        ZeroDivisionError: if *test_loader* yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        # Bug fix: the bar label used to read "Server Updating" — copy-pasted
        # from server_update(); this loop only evaluates, it never trains.
        progress_bar = tqdm(test_loader, desc="Testing", unit="batch")
        for data, target in progress_bar:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Predicted class = argmax over the logits dimension.
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100 * correct / total
# 主训练流程
def main():
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
# Initialize models
# 初始化模型(保持不变)
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
# Prepare data
client_datasets, public_loader = prepare_data()
# Test dataset (using dataset A's test set for simplicity)
test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
# 加载数据集
client_datasets, public_loader, server_test_loader = prepare_data()
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
for round in range(NUM_ROUNDS):
print(f"\n{'#'*50}")
print(f"Federated Round {round+1}/{NUM_ROUNDS}")
print(f"Round {round+1}/{NUM_ROUNDS}")
print(f"{'#'*50}")
# Client selection (only 2 clients)
# 客户端选择
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
print(f"Selected Clients: {selected_clients}")
print(f"Selected clients: {selected_clients}")
# Client local training
# 客户端训练
client_params = []
for cid in selected_clients:
print(f"\nTraining Client {cid}")
local_model = copy.deepcopy(client_models[cid])
local_model.load_state_dict(client_models[cid].state_dict())
updated_params = client_train(local_model, global_server_model, client_datasets[cid])
# 传入客户端的训练集
updated_params = client_train(
local_model,
global_server_model,
client_datasets[cid]['train'] # 使用训练集
)
client_params.append(updated_params)
# Model aggregation
# 模型聚合
global_client_params = aggregate(client_params)
for model in client_models:
model.load_state_dict(global_client_params)
# Server knowledge update
print("\nServer Updating...")
# 服务器更新
print("\nServer updating...")
server_update(global_server_model, client_models, public_loader)
# Test model performance
server_acc = test_model(global_server_model, test_loader)
client_acc = test_model(client_models[0], test_loader)
print(f"\nRound {round+1} Performance:")
print(f"Global Model Accuracy: {server_acc:.2f}%")
print(f"Client Model Accuracy: {client_acc:.2f}%")
# 测试性能
server_acc = test_model(global_server_model, server_test_loader)
client_accuracies = [
test_model(client_models[i],
client_datasets[i]['test']) # 动态创建测试loader
for i in range(NUM_CLIENTS)
]
print(f"\nRound {round+1} Results:")
print(f"Server Accuracy: {server_acc:.2f}%")
for i, acc in enumerate(client_accuracies):
print(f"Client {i} Accuracy: {acc:.2f}%")
round_progress.update(1)
print(f"Round {round+1} completed")
print("Training completed!")
# Save trained models
# 保存模型
torch.save(global_server_model.state_dict(), "server_model.pth")
for i in range(NUM_CLIENTS):
torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth")
print("Models saved successfully.")
torch.save(client_models[i].state_dict(), f"client{i}_model.pth")
# Test server model
# 最终测试
print("\nFinal Evaluation:")
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
server_acc = test_model(server_model, test_loader)
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
print(f"Server Accuracy: {test_model(server_model, server_test_loader):.2f}%")
# Test client model
for i in range(NUM_CLIENTS):
client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True))
client_acc = test_model(client_model, test_loader)
print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%")
client_model.load_state_dict(torch.load(f"client{i}_model.pth",weights_only=True))
test_loader = client_datasets[i]['test']
print(f"Client {i} Accuracy: {test_model(client_model, test_loader):.2f}%")
if __name__ == "__main__":
main()

View File

@ -5,8 +5,11 @@ import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm # 导入 tqdm
import logging
class ImageClassificationDataset(Dataset):
def __init__(self, root_dir, transform=None,Cache=False):
def __init__(self, root_dir, transform=None, Cache=False):
self.root_dir = root_dir
self.transform = transform
self.classes = sorted(os.listdir(root_dir))
@ -20,7 +23,8 @@ class ImageClassificationDataset(Dataset):
"init the dataloader"
)
for cls_name in self.classes:
# 使用 tqdm 显示进度条
for cls_name in tqdm(self.classes, desc="Loading images"):
cls_dir = os.path.join(root_dir, cls_name)
for img_name in os.listdir(cls_dir):
try:
@ -33,11 +37,8 @@ class ImageClassificationDataset(Dataset):
else:
self.image_paths.append(img_path)
self.labels.append(self.class_to_idx[cls_name])
except:
logger.log("info",
"read image error " +
img_path
)
except Exception as e:
logger.log("info", f"Read image error: {img_path} - {e}")
def __len__(self):
return len(self.labels)
@ -46,12 +47,11 @@ class ImageClassificationDataset(Dataset):
label = self.labels[idx]
if self.Cache == 'RAM':
image = self.image[idx]
else:
else:
img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image, label
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False):

View File

@ -2,9 +2,8 @@ import os
import shutil
import random
def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
# 确保比例总和为1
assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must sum to 1"
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
@ -55,6 +54,6 @@ def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15,
print("Dataset successfully split into train, validation, and test sets.")
# 使用示例
base_directory = 'F:/dataset/02.TA_EC/EC27/JY_A'
output_directory = 'F:/dataset/02.TA_EC/datasets/EC27'
base_directory = 'L:/Grade_datasets/SPLIT/JY_A'
output_directory = 'L:/Grade_datasets/train/JY_A'
create_dataset_splits(base_directory, output_directory)