更新联邦学习代码

This commit is contained in:
yoiannis 2025-03-12 14:00:50 +08:00
parent 4cb9790dee
commit a5ca9d04d7
3 changed files with 89 additions and 77 deletions

141
FED.py
View File

@ -8,6 +8,7 @@ import copy
from tqdm import tqdm from tqdm import tqdm
from data_loader import get_data_loader
from model.repvit import repvit_m1_1 from model.repvit import repvit_m1_1
from model.mobilenetv3 import MobileNetV3 from model.mobilenetv3 import MobileNetV3
@ -17,7 +18,7 @@ NUM_ROUNDS = 10
CLIENT_EPOCHS = 2 CLIENT_EPOCHS = 2
BATCH_SIZE = 32 BATCH_SIZE = 32
TEMP = 2.0 # 蒸馏温度 TEMP = 2.0 # 蒸馏温度
CLASS_NUM = [3, 3, 3] CLASS_NUM = [9, 9, 9]
# 设备配置 # 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -32,49 +33,60 @@ def prepare_data():
transforms.ToTensor() transforms.ToTensor()
]) ])
# Load datasets # 加载所有数据集(训练、验证、测试)
dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform) dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/JY_A',Cache='RAM')
dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform) dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_A',Cache='RAM')
dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform) dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_B',Cache='RAM')
# Assign datasets to clients # 组织客户端数据集
client_datasets = [dataset_B, dataset_C] client_datasets = [
{ # Client 0
'train': dataset_B_train,
'val': dataset_B_val,
'test': dataset_B_test
},
{ # Client 1
'train': dataset_C_train,
'val': dataset_C_val,
'test': dataset_C_test
}
]
# Server dataset (A) for public updates # 公共数据集使用A的训练集
public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True) public_loader = dataset_A_train
return client_datasets, public_loader # 服务器测试集使用A的测试集
server_test_loader = dataset_A_test
return client_datasets, public_loader, server_test_loader
# 客户端训练函数 # 客户端训练函数
def client_train(client_model, server_model, dataset): def client_train(client_model, server_model, loader):
client_model.train() client_model.train()
server_model.eval() server_model.eval()
optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1) optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# 训练进度条
progress_bar = tqdm(total=CLIENT_EPOCHS*len(loader),
desc="Client Training",
unit="batch")
for epoch in range(CLIENT_EPOCHS): for epoch in range(CLIENT_EPOCHS):
epoch_loss = 0.0 epoch_loss = 0.0
task_loss = 0.0 task_loss = 0.0
distill_loss = 0.0 distill_loss = 0.0
correct = 0 correct = 0
total = 0 total = 0
# 训练进度条
progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{CLIENT_EPOCHS}")
for batch_idx, (data, target) in enumerate(loader): for batch_idx, (data, target) in enumerate(progress_bar):
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
optimizer.zero_grad() optimizer.zero_grad()
# 前向传播 # 前向传播
client_output = client_model(data) client_output = client_model(data).to(device)
# 获取教师模型输出 # 获取教师模型输出
with torch.no_grad(): with torch.no_grad():
server_output = server_model(data) server_output = server_model(data).to(device)
# 计算损失 # 计算损失
loss_task = F.cross_entropy(client_output, target) loss_task = F.cross_entropy(client_output, target)
@ -166,9 +178,10 @@ def server_update(server_model, client_models, public_loader):
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001) optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
total_loss = 0.0 total_loss = 0.0
progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch") progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")
for batch_idx, (data, _) in enumerate(progress_bar): for batch_idx, (data, target) in enumerate(progress_bar):
data = data.to(device) data = data.to(device)
optimizer.zero_grad() optimizer.zero_grad()
@ -199,97 +212,97 @@ def server_update(server_model, client_models, public_loader):
print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n") print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
def test_model(model, test_loader): def test_model(model, test_loader): # 添加对DataLoader的支持
model.eval() model.eval()
correct = 0 correct = 0
total = 0 total = 0
with torch.no_grad(): with torch.no_grad():
for data, target in test_loader: progress_bar = tqdm(test_loader, desc="Server Updating", unit="batch")
for batch_idx, (data, target) in enumerate(progress_bar):
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
output = model(data) output = model(data)
_, predicted = torch.max(output.data, 1) _, predicted = torch.max(output.data, 1)
total += target.size(0) total += target.size(0)
correct += (predicted == target).sum().item() correct += (predicted == target).sum().item()
accuracy = 100 * correct / total return 100 * correct / total
return accuracy
# 主训练流程 # 主训练流程
def main(): def main():
transform = transforms.Compose([ # 初始化模型(保持不变)
transforms.Resize((224, 224)),
transforms.ToTensor()
])
# Initialize models
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device) global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)] client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
# Prepare data # 加载数据集
client_datasets, public_loader = prepare_data() client_datasets, public_loader, server_test_loader = prepare_data()
# Test dataset (using dataset A's test set for simplicity)
test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round") round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
for round in range(NUM_ROUNDS): for round in range(NUM_ROUNDS):
print(f"\n{'#'*50}") print(f"\n{'#'*50}")
print(f"Federated Round {round+1}/{NUM_ROUNDS}") print(f"Round {round+1}/{NUM_ROUNDS}")
print(f"{'#'*50}") print(f"{'#'*50}")
# Client selection (only 2 clients) # 客户端选择
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False) selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
print(f"Selected Clients: {selected_clients}") print(f"Selected clients: {selected_clients}")
# Client local training # 客户端训练
client_params = [] client_params = []
for cid in selected_clients: for cid in selected_clients:
print(f"\nTraining Client {cid}") print(f"\nTraining Client {cid}")
local_model = copy.deepcopy(client_models[cid]) local_model = copy.deepcopy(client_models[cid])
local_model.load_state_dict(client_models[cid].state_dict()) local_model.load_state_dict(client_models[cid].state_dict())
updated_params = client_train(local_model, global_server_model, client_datasets[cid])
# 传入客户端的训练集
updated_params = client_train(
local_model,
global_server_model,
client_datasets[cid]['train'] # 使用训练集
)
client_params.append(updated_params) client_params.append(updated_params)
# Model aggregation # 模型聚合
global_client_params = aggregate(client_params) global_client_params = aggregate(client_params)
for model in client_models: for model in client_models:
model.load_state_dict(global_client_params) model.load_state_dict(global_client_params)
# Server knowledge update # 服务器更新
print("\nServer Updating...") print("\nServer updating...")
server_update(global_server_model, client_models, public_loader) server_update(global_server_model, client_models, public_loader)
# Test model performance # 测试性能
server_acc = test_model(global_server_model, test_loader) server_acc = test_model(global_server_model, server_test_loader)
client_acc = test_model(client_models[0], test_loader) client_accuracies = [
print(f"\nRound {round+1} Performance:") test_model(client_models[i],
print(f"Global Model Accuracy: {server_acc:.2f}%") client_datasets[i]['test']) # 动态创建测试loader
print(f"Client Model Accuracy: {client_acc:.2f}%") for i in range(NUM_CLIENTS)
]
print(f"\nRound {round+1} Results:")
print(f"Server Accuracy: {server_acc:.2f}%")
for i, acc in enumerate(client_accuracies):
print(f"Client {i} Accuracy: {acc:.2f}%")
round_progress.update(1) round_progress.update(1)
print(f"Round {round+1} completed")
print("Training completed!") # 保存模型
# Save trained models
torch.save(global_server_model.state_dict(), "server_model.pth") torch.save(global_server_model.state_dict(), "server_model.pth")
for i in range(NUM_CLIENTS): for i in range(NUM_CLIENTS):
torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth") torch.save(client_models[i].state_dict(), f"client{i}_model.pth")
print("Models saved successfully.")
# Test server model # 最终测试
print("\nFinal Evaluation:")
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device) server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
server_model.load_state_dict(torch.load("server_model.pth",weights_only=True)) server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
server_acc = test_model(server_model, test_loader) print(f"Server Accuracy: {test_model(server_model, server_test_loader):.2f}%")
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
# Test client model
for i in range(NUM_CLIENTS): for i in range(NUM_CLIENTS):
client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True)) client_model.load_state_dict(torch.load(f"client{i}_model.pth",weights_only=True))
client_acc = test_model(client_model, test_loader) test_loader = client_datasets[i]['test']
print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%") print(f"Client {i} Accuracy: {test_model(client_model, test_loader):.2f}%")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -5,8 +5,11 @@ import torch
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from torchvision import transforms from torchvision import transforms
from tqdm import tqdm # 导入 tqdm
import logging
class ImageClassificationDataset(Dataset): class ImageClassificationDataset(Dataset):
def __init__(self, root_dir, transform=None,Cache=False): def __init__(self, root_dir, transform=None, Cache=False):
self.root_dir = root_dir self.root_dir = root_dir
self.transform = transform self.transform = transform
self.classes = sorted(os.listdir(root_dir)) self.classes = sorted(os.listdir(root_dir))
@ -20,7 +23,8 @@ class ImageClassificationDataset(Dataset):
"init the dataloader" "init the dataloader"
) )
for cls_name in self.classes: # 使用 tqdm 显示进度条
for cls_name in tqdm(self.classes, desc="Loading images"):
cls_dir = os.path.join(root_dir, cls_name) cls_dir = os.path.join(root_dir, cls_name)
for img_name in os.listdir(cls_dir): for img_name in os.listdir(cls_dir):
try: try:
@ -33,11 +37,8 @@ class ImageClassificationDataset(Dataset):
else: else:
self.image_paths.append(img_path) self.image_paths.append(img_path)
self.labels.append(self.class_to_idx[cls_name]) self.labels.append(self.class_to_idx[cls_name])
except: except Exception as e:
logger.log("info", logger.log("info", f"Read image error: {img_path} - {e}")
"read image error " +
img_path
)
def __len__(self): def __len__(self):
return len(self.labels) return len(self.labels)
@ -46,12 +47,11 @@ class ImageClassificationDataset(Dataset):
label = self.labels[idx] label = self.labels[idx]
if self.Cache == 'RAM': if self.Cache == 'RAM':
image = self.image[idx] image = self.image[idx]
else: else:
img_path = self.image_paths[idx] img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB') image = Image.open(img_path).convert('RGB')
if self.transform: if self.transform:
image = self.transform(image) image = self.transform(image)
return image, label return image, label
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False): def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False):

View File

@ -2,9 +2,8 @@ import os
import shutil import shutil
import random import random
def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15): def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
# 确保比例总和为1 # 确保比例总和为1
assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must sum to 1"
# 创建输出目录 # 创建输出目录
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
@ -55,6 +54,6 @@ def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15,
print("Dataset successfully split into train, validation, and test sets.") print("Dataset successfully split into train, validation, and test sets.")
# 使用示例 # 使用示例
base_directory = 'F:/dataset/02.TA_EC/EC27/JY_A' base_directory = 'L:/Grade_datasets/SPLIT/JY_A'
output_directory = 'F:/dataset/02.TA_EC/datasets/EC27' output_directory = 'L:/Grade_datasets/train/JY_A'
create_dataset_splits(base_directory, output_directory) create_dataset_splits(base_directory, output_directory)