更新联邦学习代码
This commit is contained in:
parent
4cb9790dee
commit
a5ca9d04d7
141
FED.py
141
FED.py
@ -8,6 +8,7 @@ import copy
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from data_loader import get_data_loader
|
||||||
from model.repvit import repvit_m1_1
|
from model.repvit import repvit_m1_1
|
||||||
from model.mobilenetv3 import MobileNetV3
|
from model.mobilenetv3 import MobileNetV3
|
||||||
|
|
||||||
@ -17,7 +18,7 @@ NUM_ROUNDS = 10
|
|||||||
CLIENT_EPOCHS = 2
|
CLIENT_EPOCHS = 2
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
TEMP = 2.0 # 蒸馏温度
|
TEMP = 2.0 # 蒸馏温度
|
||||||
CLASS_NUM = [3, 3, 3]
|
CLASS_NUM = [9, 9, 9]
|
||||||
|
|
||||||
# 设备配置
|
# 设备配置
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
@ -32,49 +33,60 @@ def prepare_data():
|
|||||||
transforms.ToTensor()
|
transforms.ToTensor()
|
||||||
])
|
])
|
||||||
|
|
||||||
# Load datasets
|
# 加载所有数据集(训练、验证、测试)
|
||||||
dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform)
|
dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/JY_A',Cache='RAM')
|
||||||
dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform)
|
dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_A',Cache='RAM')
|
||||||
dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform)
|
dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_B',Cache='RAM')
|
||||||
|
|
||||||
# Assign datasets to clients
|
# 组织客户端数据集
|
||||||
client_datasets = [dataset_B, dataset_C]
|
client_datasets = [
|
||||||
|
{ # Client 0
|
||||||
|
'train': dataset_B_train,
|
||||||
|
'val': dataset_B_val,
|
||||||
|
'test': dataset_B_test
|
||||||
|
},
|
||||||
|
{ # Client 1
|
||||||
|
'train': dataset_C_train,
|
||||||
|
'val': dataset_C_val,
|
||||||
|
'test': dataset_C_test
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
# Server dataset (A) for public updates
|
# 公共数据集(使用A的训练集)
|
||||||
public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True)
|
public_loader = dataset_A_train
|
||||||
|
|
||||||
return client_datasets, public_loader
|
# 服务器测试集(使用A的测试集)
|
||||||
|
server_test_loader = dataset_A_test
|
||||||
|
|
||||||
|
return client_datasets, public_loader, server_test_loader
|
||||||
|
|
||||||
# 客户端训练函数
|
# 客户端训练函数
|
||||||
def client_train(client_model, server_model, dataset):
|
def client_train(client_model, server_model, loader):
|
||||||
client_model.train()
|
client_model.train()
|
||||||
server_model.eval()
|
server_model.eval()
|
||||||
|
|
||||||
optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
|
optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
|
||||||
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
|
||||||
|
|
||||||
# 训练进度条
|
|
||||||
progress_bar = tqdm(total=CLIENT_EPOCHS*len(loader),
|
|
||||||
desc="Client Training",
|
|
||||||
unit="batch")
|
|
||||||
|
|
||||||
for epoch in range(CLIENT_EPOCHS):
|
for epoch in range(CLIENT_EPOCHS):
|
||||||
epoch_loss = 0.0
|
epoch_loss = 0.0
|
||||||
task_loss = 0.0
|
task_loss = 0.0
|
||||||
distill_loss = 0.0
|
distill_loss = 0.0
|
||||||
correct = 0
|
correct = 0
|
||||||
total = 0
|
total = 0
|
||||||
|
|
||||||
|
# 训练进度条
|
||||||
|
progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{CLIENT_EPOCHS}")
|
||||||
|
|
||||||
for batch_idx, (data, target) in enumerate(loader):
|
for batch_idx, (data, target) in enumerate(progress_bar):
|
||||||
data, target = data.to(device), target.to(device)
|
data, target = data.to(device), target.to(device)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
client_output = client_model(data)
|
client_output = client_model(data).to(device)
|
||||||
|
|
||||||
# 获取教师模型输出
|
# 获取教师模型输出
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
server_output = server_model(data)
|
server_output = server_model(data).to(device)
|
||||||
|
|
||||||
# 计算损失
|
# 计算损失
|
||||||
loss_task = F.cross_entropy(client_output, target)
|
loss_task = F.cross_entropy(client_output, target)
|
||||||
@ -166,9 +178,10 @@ def server_update(server_model, client_models, public_loader):
|
|||||||
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
|
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
|
||||||
|
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
|
|
||||||
progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")
|
progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")
|
||||||
|
|
||||||
for batch_idx, (data, _) in enumerate(progress_bar):
|
for batch_idx, (data, target) in enumerate(progress_bar):
|
||||||
data = data.to(device)
|
data = data.to(device)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
@ -199,97 +212,97 @@ def server_update(server_model, client_models, public_loader):
|
|||||||
print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
|
print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
|
||||||
|
|
||||||
|
|
||||||
def test_model(model, test_loader):
|
def test_model(model, test_loader): # 添加对DataLoader的支持
|
||||||
model.eval()
|
model.eval()
|
||||||
correct = 0
|
correct = 0
|
||||||
total = 0
|
total = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for data, target in test_loader:
|
progress_bar = tqdm(test_loader, desc="Server Updating", unit="batch")
|
||||||
|
|
||||||
|
for batch_idx, (data, target) in enumerate(progress_bar):
|
||||||
data, target = data.to(device), target.to(device)
|
data, target = data.to(device), target.to(device)
|
||||||
output = model(data)
|
output = model(data)
|
||||||
_, predicted = torch.max(output.data, 1)
|
_, predicted = torch.max(output.data, 1)
|
||||||
total += target.size(0)
|
total += target.size(0)
|
||||||
correct += (predicted == target).sum().item()
|
correct += (predicted == target).sum().item()
|
||||||
accuracy = 100 * correct / total
|
return 100 * correct / total
|
||||||
return accuracy
|
|
||||||
|
|
||||||
|
|
||||||
# 主训练流程
|
# 主训练流程
|
||||||
def main():
|
def main():
|
||||||
transform = transforms.Compose([
|
# 初始化模型(保持不变)
|
||||||
transforms.Resize((224, 224)),
|
|
||||||
transforms.ToTensor()
|
|
||||||
])
|
|
||||||
# Initialize models
|
|
||||||
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
||||||
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
|
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
|
||||||
|
|
||||||
# Prepare data
|
# 加载数据集
|
||||||
client_datasets, public_loader = prepare_data()
|
client_datasets, public_loader, server_test_loader = prepare_data()
|
||||||
|
|
||||||
# Test dataset (using dataset A's test set for simplicity)
|
|
||||||
test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
|
|
||||||
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
|
|
||||||
|
|
||||||
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
|
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
|
||||||
|
|
||||||
for round in range(NUM_ROUNDS):
|
for round in range(NUM_ROUNDS):
|
||||||
print(f"\n{'#'*50}")
|
print(f"\n{'#'*50}")
|
||||||
print(f"Federated Round {round+1}/{NUM_ROUNDS}")
|
print(f"Round {round+1}/{NUM_ROUNDS}")
|
||||||
print(f"{'#'*50}")
|
print(f"{'#'*50}")
|
||||||
|
|
||||||
# Client selection (only 2 clients)
|
# 客户端选择
|
||||||
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
|
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
|
||||||
print(f"Selected Clients: {selected_clients}")
|
print(f"Selected clients: {selected_clients}")
|
||||||
|
|
||||||
# Client local training
|
# 客户端训练
|
||||||
client_params = []
|
client_params = []
|
||||||
for cid in selected_clients:
|
for cid in selected_clients:
|
||||||
print(f"\nTraining Client {cid}")
|
print(f"\nTraining Client {cid}")
|
||||||
local_model = copy.deepcopy(client_models[cid])
|
local_model = copy.deepcopy(client_models[cid])
|
||||||
local_model.load_state_dict(client_models[cid].state_dict())
|
local_model.load_state_dict(client_models[cid].state_dict())
|
||||||
updated_params = client_train(local_model, global_server_model, client_datasets[cid])
|
|
||||||
|
# 传入客户端的训练集
|
||||||
|
updated_params = client_train(
|
||||||
|
local_model,
|
||||||
|
global_server_model,
|
||||||
|
client_datasets[cid]['train'] # 使用训练集
|
||||||
|
)
|
||||||
client_params.append(updated_params)
|
client_params.append(updated_params)
|
||||||
|
|
||||||
# Model aggregation
|
# 模型聚合
|
||||||
global_client_params = aggregate(client_params)
|
global_client_params = aggregate(client_params)
|
||||||
for model in client_models:
|
for model in client_models:
|
||||||
model.load_state_dict(global_client_params)
|
model.load_state_dict(global_client_params)
|
||||||
|
|
||||||
# Server knowledge update
|
# 服务器更新
|
||||||
print("\nServer Updating...")
|
print("\nServer updating...")
|
||||||
server_update(global_server_model, client_models, public_loader)
|
server_update(global_server_model, client_models, public_loader)
|
||||||
|
|
||||||
# Test model performance
|
# 测试性能
|
||||||
server_acc = test_model(global_server_model, test_loader)
|
server_acc = test_model(global_server_model, server_test_loader)
|
||||||
client_acc = test_model(client_models[0], test_loader)
|
client_accuracies = [
|
||||||
print(f"\nRound {round+1} Performance:")
|
test_model(client_models[i],
|
||||||
print(f"Global Model Accuracy: {server_acc:.2f}%")
|
client_datasets[i]['test']) # 动态创建测试loader
|
||||||
print(f"Client Model Accuracy: {client_acc:.2f}%")
|
for i in range(NUM_CLIENTS)
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"\nRound {round+1} Results:")
|
||||||
|
print(f"Server Accuracy: {server_acc:.2f}%")
|
||||||
|
for i, acc in enumerate(client_accuracies):
|
||||||
|
print(f"Client {i} Accuracy: {acc:.2f}%")
|
||||||
|
|
||||||
round_progress.update(1)
|
round_progress.update(1)
|
||||||
print(f"Round {round+1} completed")
|
|
||||||
|
|
||||||
print("Training completed!")
|
# 保存模型
|
||||||
|
|
||||||
# Save trained models
|
|
||||||
torch.save(global_server_model.state_dict(), "server_model.pth")
|
torch.save(global_server_model.state_dict(), "server_model.pth")
|
||||||
for i in range(NUM_CLIENTS):
|
for i in range(NUM_CLIENTS):
|
||||||
torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth")
|
torch.save(client_models[i].state_dict(), f"client{i}_model.pth")
|
||||||
print("Models saved successfully.")
|
|
||||||
|
|
||||||
# Test server model
|
# 最终测试
|
||||||
|
print("\nFinal Evaluation:")
|
||||||
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
||||||
server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
|
server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
|
||||||
server_acc = test_model(server_model, test_loader)
|
print(f"Server Accuracy: {test_model(server_model, server_test_loader):.2f}%")
|
||||||
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
|
|
||||||
|
|
||||||
# Test client model
|
|
||||||
for i in range(NUM_CLIENTS):
|
for i in range(NUM_CLIENTS):
|
||||||
client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
|
client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
|
||||||
client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True))
|
client_model.load_state_dict(torch.load(f"client{i}_model.pth",weights_only=True))
|
||||||
client_acc = test_model(client_model, test_loader)
|
test_loader = client_datasets[i]['test']
|
||||||
print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%")
|
print(f"Client {i} Accuracy: {test_model(client_model, test_loader):.2f}%")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
@ -5,8 +5,11 @@ import torch
|
|||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
|
from tqdm import tqdm # 导入 tqdm
|
||||||
|
import logging
|
||||||
|
|
||||||
class ImageClassificationDataset(Dataset):
|
class ImageClassificationDataset(Dataset):
|
||||||
def __init__(self, root_dir, transform=None,Cache=False):
|
def __init__(self, root_dir, transform=None, Cache=False):
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
self.classes = sorted(os.listdir(root_dir))
|
self.classes = sorted(os.listdir(root_dir))
|
||||||
@ -20,7 +23,8 @@ class ImageClassificationDataset(Dataset):
|
|||||||
"init the dataloader"
|
"init the dataloader"
|
||||||
)
|
)
|
||||||
|
|
||||||
for cls_name in self.classes:
|
# 使用 tqdm 显示进度条
|
||||||
|
for cls_name in tqdm(self.classes, desc="Loading images"):
|
||||||
cls_dir = os.path.join(root_dir, cls_name)
|
cls_dir = os.path.join(root_dir, cls_name)
|
||||||
for img_name in os.listdir(cls_dir):
|
for img_name in os.listdir(cls_dir):
|
||||||
try:
|
try:
|
||||||
@ -33,11 +37,8 @@ class ImageClassificationDataset(Dataset):
|
|||||||
else:
|
else:
|
||||||
self.image_paths.append(img_path)
|
self.image_paths.append(img_path)
|
||||||
self.labels.append(self.class_to_idx[cls_name])
|
self.labels.append(self.class_to_idx[cls_name])
|
||||||
except:
|
except Exception as e:
|
||||||
logger.log("info",
|
logger.log("info", f"Read image error: {img_path} - {e}")
|
||||||
"read image error " +
|
|
||||||
img_path
|
|
||||||
)
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.labels)
|
return len(self.labels)
|
||||||
@ -46,12 +47,11 @@ class ImageClassificationDataset(Dataset):
|
|||||||
label = self.labels[idx]
|
label = self.labels[idx]
|
||||||
if self.Cache == 'RAM':
|
if self.Cache == 'RAM':
|
||||||
image = self.image[idx]
|
image = self.image[idx]
|
||||||
else:
|
else:
|
||||||
img_path = self.image_paths[idx]
|
img_path = self.image_paths[idx]
|
||||||
image = Image.open(img_path).convert('RGB')
|
image = Image.open(img_path).convert('RGB')
|
||||||
if self.transform:
|
if self.transform:
|
||||||
image = self.transform(image)
|
image = self.transform(image)
|
||||||
|
|
||||||
return image, label
|
return image, label
|
||||||
|
|
||||||
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False):
|
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False):
|
||||||
|
@ -2,9 +2,8 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import random
|
import random
|
||||||
|
|
||||||
def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
|
def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
|
||||||
# 确保比例总和为1
|
# 确保比例总和为1
|
||||||
assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must sum to 1"
|
|
||||||
|
|
||||||
# 创建输出目录
|
# 创建输出目录
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
@ -55,6 +54,6 @@ def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15,
|
|||||||
print("Dataset successfully split into train, validation, and test sets.")
|
print("Dataset successfully split into train, validation, and test sets.")
|
||||||
|
|
||||||
# 使用示例
|
# 使用示例
|
||||||
base_directory = 'F:/dataset/02.TA_EC/EC27/JY_A'
|
base_directory = 'L:/Grade_datasets/SPLIT/JY_A'
|
||||||
output_directory = 'F:/dataset/02.TA_EC/datasets/EC27'
|
output_directory = 'L:/Grade_datasets/train/JY_A'
|
||||||
create_dataset_splits(base_directory, output_directory)
|
create_dataset_splits(base_directory, output_directory)
|
Loading…
Reference in New Issue
Block a user