更新联邦学习代码
This commit is contained in:
parent
4cb9790dee
commit
a5ca9d04d7
139
FED.py
139
FED.py
@ -8,6 +8,7 @@ import copy
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from data_loader import get_data_loader
|
||||
from model.repvit import repvit_m1_1
|
||||
from model.mobilenetv3 import MobileNetV3
|
||||
|
||||
@ -17,7 +18,7 @@ NUM_ROUNDS = 10
|
||||
CLIENT_EPOCHS = 2
|
||||
BATCH_SIZE = 32
|
||||
TEMP = 2.0 # 蒸馏温度
|
||||
CLASS_NUM = [3, 3, 3]
|
||||
CLASS_NUM = [9, 9, 9]
|
||||
|
||||
# 设备配置
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@ -32,31 +33,39 @@ def prepare_data():
|
||||
transforms.ToTensor()
|
||||
])
|
||||
|
||||
# Load datasets
|
||||
dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform)
|
||||
dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform)
|
||||
dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform)
|
||||
# 加载所有数据集(训练、验证、测试)
|
||||
dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/JY_A',Cache='RAM')
|
||||
dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_A',Cache='RAM')
|
||||
dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_B',Cache='RAM')
|
||||
|
||||
# Assign datasets to clients
|
||||
client_datasets = [dataset_B, dataset_C]
|
||||
# 组织客户端数据集
|
||||
client_datasets = [
|
||||
{ # Client 0
|
||||
'train': dataset_B_train,
|
||||
'val': dataset_B_val,
|
||||
'test': dataset_B_test
|
||||
},
|
||||
{ # Client 1
|
||||
'train': dataset_C_train,
|
||||
'val': dataset_C_val,
|
||||
'test': dataset_C_test
|
||||
}
|
||||
]
|
||||
|
||||
# Server dataset (A) for public updates
|
||||
public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True)
|
||||
# 公共数据集(使用A的训练集)
|
||||
public_loader = dataset_A_train
|
||||
|
||||
return client_datasets, public_loader
|
||||
# 服务器测试集(使用A的测试集)
|
||||
server_test_loader = dataset_A_test
|
||||
|
||||
return client_datasets, public_loader, server_test_loader
|
||||
|
||||
# 客户端训练函数
|
||||
def client_train(client_model, server_model, dataset):
|
||||
def client_train(client_model, server_model, loader):
|
||||
client_model.train()
|
||||
server_model.eval()
|
||||
|
||||
optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
|
||||
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
|
||||
# 训练进度条
|
||||
progress_bar = tqdm(total=CLIENT_EPOCHS*len(loader),
|
||||
desc="Client Training",
|
||||
unit="batch")
|
||||
|
||||
for epoch in range(CLIENT_EPOCHS):
|
||||
epoch_loss = 0.0
|
||||
@ -65,16 +74,19 @@ def client_train(client_model, server_model, dataset):
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
for batch_idx, (data, target) in enumerate(loader):
|
||||
# 训练进度条
|
||||
progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{CLIENT_EPOCHS}")
|
||||
|
||||
for batch_idx, (data, target) in enumerate(progress_bar):
|
||||
data, target = data.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
|
||||
# 前向传播
|
||||
client_output = client_model(data)
|
||||
client_output = client_model(data).to(device)
|
||||
|
||||
# 获取教师模型输出
|
||||
with torch.no_grad():
|
||||
server_output = server_model(data)
|
||||
server_output = server_model(data).to(device)
|
||||
|
||||
# 计算损失
|
||||
loss_task = F.cross_entropy(client_output, target)
|
||||
@ -166,9 +178,10 @@ def server_update(server_model, client_models, public_loader):
|
||||
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
|
||||
|
||||
total_loss = 0.0
|
||||
|
||||
progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")
|
||||
|
||||
for batch_idx, (data, _) in enumerate(progress_bar):
|
||||
for batch_idx, (data, target) in enumerate(progress_bar):
|
||||
data = data.to(device)
|
||||
optimizer.zero_grad()
|
||||
|
||||
@ -199,97 +212,97 @@ def server_update(server_model, client_models, public_loader):
|
||||
print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
|
||||
|
||||
|
||||
def test_model(model, test_loader):
|
||||
def test_model(model, test_loader): # 添加对DataLoader的支持
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
progress_bar = tqdm(test_loader, desc="Server Updating", unit="batch")
|
||||
|
||||
for batch_idx, (data, target) in enumerate(progress_bar):
|
||||
data, target = data.to(device), target.to(device)
|
||||
output = model(data)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
total += target.size(0)
|
||||
correct += (predicted == target).sum().item()
|
||||
accuracy = 100 * correct / total
|
||||
return accuracy
|
||||
return 100 * correct / total
|
||||
|
||||
|
||||
# 主训练流程
|
||||
def main():
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor()
|
||||
])
|
||||
# Initialize models
|
||||
# 初始化模型(保持不变)
|
||||
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
||||
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
|
||||
|
||||
# Prepare data
|
||||
client_datasets, public_loader = prepare_data()
|
||||
|
||||
# Test dataset (using dataset A's test set for simplicity)
|
||||
test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
|
||||
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
|
||||
# 加载数据集
|
||||
client_datasets, public_loader, server_test_loader = prepare_data()
|
||||
|
||||
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
|
||||
|
||||
for round in range(NUM_ROUNDS):
|
||||
print(f"\n{'#'*50}")
|
||||
print(f"Federated Round {round+1}/{NUM_ROUNDS}")
|
||||
print(f"Round {round+1}/{NUM_ROUNDS}")
|
||||
print(f"{'#'*50}")
|
||||
|
||||
# Client selection (only 2 clients)
|
||||
# 客户端选择
|
||||
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
|
||||
print(f"Selected Clients: {selected_clients}")
|
||||
print(f"Selected clients: {selected_clients}")
|
||||
|
||||
# Client local training
|
||||
# 客户端训练
|
||||
client_params = []
|
||||
for cid in selected_clients:
|
||||
print(f"\nTraining Client {cid}")
|
||||
local_model = copy.deepcopy(client_models[cid])
|
||||
local_model.load_state_dict(client_models[cid].state_dict())
|
||||
updated_params = client_train(local_model, global_server_model, client_datasets[cid])
|
||||
|
||||
# 传入客户端的训练集
|
||||
updated_params = client_train(
|
||||
local_model,
|
||||
global_server_model,
|
||||
client_datasets[cid]['train'] # 使用训练集
|
||||
)
|
||||
client_params.append(updated_params)
|
||||
|
||||
# Model aggregation
|
||||
# 模型聚合
|
||||
global_client_params = aggregate(client_params)
|
||||
for model in client_models:
|
||||
model.load_state_dict(global_client_params)
|
||||
|
||||
# Server knowledge update
|
||||
print("\nServer Updating...")
|
||||
# 服务器更新
|
||||
print("\nServer updating...")
|
||||
server_update(global_server_model, client_models, public_loader)
|
||||
|
||||
# Test model performance
|
||||
server_acc = test_model(global_server_model, test_loader)
|
||||
client_acc = test_model(client_models[0], test_loader)
|
||||
print(f"\nRound {round+1} Performance:")
|
||||
print(f"Global Model Accuracy: {server_acc:.2f}%")
|
||||
print(f"Client Model Accuracy: {client_acc:.2f}%")
|
||||
# 测试性能
|
||||
server_acc = test_model(global_server_model, server_test_loader)
|
||||
client_accuracies = [
|
||||
test_model(client_models[i],
|
||||
client_datasets[i]['test']) # 动态创建测试loader
|
||||
for i in range(NUM_CLIENTS)
|
||||
]
|
||||
|
||||
print(f"\nRound {round+1} Results:")
|
||||
print(f"Server Accuracy: {server_acc:.2f}%")
|
||||
for i, acc in enumerate(client_accuracies):
|
||||
print(f"Client {i} Accuracy: {acc:.2f}%")
|
||||
|
||||
round_progress.update(1)
|
||||
print(f"Round {round+1} completed")
|
||||
|
||||
print("Training completed!")
|
||||
|
||||
# Save trained models
|
||||
# 保存模型
|
||||
torch.save(global_server_model.state_dict(), "server_model.pth")
|
||||
for i in range(NUM_CLIENTS):
|
||||
torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth")
|
||||
print("Models saved successfully.")
|
||||
torch.save(client_models[i].state_dict(), f"client{i}_model.pth")
|
||||
|
||||
# Test server model
|
||||
# 最终测试
|
||||
print("\nFinal Evaluation:")
|
||||
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
||||
server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
|
||||
server_acc = test_model(server_model, test_loader)
|
||||
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
|
||||
print(f"Server Accuracy: {test_model(server_model, server_test_loader):.2f}%")
|
||||
|
||||
# Test client model
|
||||
for i in range(NUM_CLIENTS):
|
||||
client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
|
||||
client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True))
|
||||
client_acc = test_model(client_model, test_loader)
|
||||
print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%")
|
||||
client_model.load_state_dict(torch.load(f"client{i}_model.pth",weights_only=True))
|
||||
test_loader = client_datasets[i]['test']
|
||||
print(f"Client {i} Accuracy: {test_model(client_model, test_loader):.2f}%")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -5,8 +5,11 @@ import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
|
||||
from tqdm import tqdm # 导入 tqdm
|
||||
import logging
|
||||
|
||||
class ImageClassificationDataset(Dataset):
|
||||
def __init__(self, root_dir, transform=None,Cache=False):
|
||||
def __init__(self, root_dir, transform=None, Cache=False):
|
||||
self.root_dir = root_dir
|
||||
self.transform = transform
|
||||
self.classes = sorted(os.listdir(root_dir))
|
||||
@ -20,7 +23,8 @@ class ImageClassificationDataset(Dataset):
|
||||
"init the dataloader"
|
||||
)
|
||||
|
||||
for cls_name in self.classes:
|
||||
# 使用 tqdm 显示进度条
|
||||
for cls_name in tqdm(self.classes, desc="Loading images"):
|
||||
cls_dir = os.path.join(root_dir, cls_name)
|
||||
for img_name in os.listdir(cls_dir):
|
||||
try:
|
||||
@ -33,11 +37,8 @@ class ImageClassificationDataset(Dataset):
|
||||
else:
|
||||
self.image_paths.append(img_path)
|
||||
self.labels.append(self.class_to_idx[cls_name])
|
||||
except:
|
||||
logger.log("info",
|
||||
"read image error " +
|
||||
img_path
|
||||
)
|
||||
except Exception as e:
|
||||
logger.log("info", f"Read image error: {img_path} - {e}")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.labels)
|
||||
@ -51,7 +52,6 @@ class ImageClassificationDataset(Dataset):
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
if self.transform:
|
||||
image = self.transform(image)
|
||||
|
||||
return image, label
|
||||
|
||||
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False):
|
||||
|
@ -2,9 +2,8 @@ import os
|
||||
import shutil
|
||||
import random
|
||||
|
||||
def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
|
||||
def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
|
||||
# 确保比例总和为1
|
||||
assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must sum to 1"
|
||||
|
||||
# 创建输出目录
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@ -55,6 +54,6 @@ def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15,
|
||||
print("Dataset successfully split into train, validation, and test sets.")
|
||||
|
||||
# 使用示例
|
||||
base_directory = 'F:/dataset/02.TA_EC/EC27/JY_A'
|
||||
output_directory = 'F:/dataset/02.TA_EC/datasets/EC27'
|
||||
base_directory = 'L:/Grade_datasets/SPLIT/JY_A'
|
||||
output_directory = 'L:/Grade_datasets/train/JY_A'
|
||||
create_dataset_splits(base_directory, output_directory)
|
Loading…
Reference in New Issue
Block a user