完成联邦蒸馏

This commit is contained in:
yoiannis 2025-03-12 00:21:31 +08:00
parent 617230e296
commit 382916643b
3 changed files with 62 additions and 22 deletions

73
FED.py
View File

@ -13,10 +13,11 @@ from model.mobilenetv3 import MobileNetV3
# 配置参数 # 配置参数
NUM_CLIENTS = 2 NUM_CLIENTS = 2
NUM_ROUNDS = 3 NUM_ROUNDS = 10
CLIENT_EPOCHS = 5 CLIENT_EPOCHS = 2
BATCH_SIZE = 32 BATCH_SIZE = 32
TEMP = 2.0 # 蒸馏温度 TEMP = 2.0 # 蒸馏温度
CLASS_NUM = [3, 3, 3]
# 设备配置 # 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -32,9 +33,9 @@ def prepare_data():
]) ])
# Load datasets # Load datasets
dataset_A = ImageFolder(root='./dataset_A/train', transform=transform) dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform)
dataset_B = ImageFolder(root='./dataset_B/train', transform=transform) dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform)
dataset_C = ImageFolder(root='./dataset_C/train', transform=transform) dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform)
# Assign datasets to clients # Assign datasets to clients
client_datasets = [dataset_B, dataset_C] client_datasets = [dataset_B, dataset_C]
@ -105,13 +106,6 @@ def client_train(client_model, server_model, dataset):
}) })
progress_bar.update(1) progress_bar.update(1)
# 每10个batch打印详细信息
if (batch_idx + 1) % 10 == 0:
progress_bar.write(f"\nEpoch {epoch+1} | Batch {batch_idx+1}")
progress_bar.write(f"Task Loss: {loss_task:.4f}")
progress_bar.write(f"Distill Loss: {loss_distill:.4f}")
progress_bar.write(f"Total Loss: {total_loss:.4f}")
progress_bar.write(f"Batch Accuracy: {100*correct/total:.2f}%\n")
# 每个epoch结束打印汇总信息 # 每个epoch结束打印汇总信息
avg_loss = epoch_loss / len(loader) avg_loss = epoch_loss / len(loader)
avg_task = task_loss / len(loader) avg_task = task_loss / len(loader)
@ -135,6 +129,37 @@ def aggregate(client_params):
global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0) global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0)
return global_params return global_params
def server_aggregate(server_model, client_models, public_loader):
server_model.train()
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
for data, _ in public_loader:
data = data.to(device)
# 获取客户端模型特征
client_features = []
with torch.no_grad():
for model in client_models:
features = model.extract_features(data) # 需要实现特征提取方法
client_features.append(features)
# 计算特征蒸馏目标
target_features = torch.stack(client_features).mean(dim=0)
# 服务器前向
server_features = server_model.extract_features(data)
# 特征对齐损失
loss = F.mse_loss(server_features, target_features)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 更新统计信息
total_loss += loss.item()
# 服务器知识更新 # 服务器知识更新
def server_update(server_model, client_models, public_loader): def server_update(server_model, client_models, public_loader):
server_model.train() server_model.train()
@ -191,15 +216,19 @@ def test_model(model, test_loader):
# 主训练流程 # 主训练流程
def main(): def main():
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
# Initialize models # Initialize models
global_server_model = repvit_m1_1(num_classes=10).to(device) global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)] client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
# Prepare data # Prepare data
client_datasets, public_loader = prepare_data() client_datasets, public_loader = prepare_data()
# Test dataset (using dataset A's test set for simplicity) # Test dataset (using dataset A's test set for simplicity)
test_dataset = ImageFolder(root='./dataset_A/test', transform=transform) test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round") round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
@ -245,20 +274,22 @@ def main():
# Save trained models # Save trained models
torch.save(global_server_model.state_dict(), "server_model.pth") torch.save(global_server_model.state_dict(), "server_model.pth")
torch.save(client_models[0].state_dict(), "client_model.pth") for i in range(NUM_CLIENTS):
torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth")
print("Models saved successfully.") print("Models saved successfully.")
# Test server model # Test server model
server_model = repvit_m1_1(num_classes=10).to(device) server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
server_model.load_state_dict(torch.load("server_model.pth")) server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
server_acc = test_model(server_model, test_loader) server_acc = test_model(server_model, test_loader)
print(f"Server Model Test Accuracy: {server_acc:.2f}%") print(f"Server Model Test Accuracy: {server_acc:.2f}%")
# Test client model # Test client model
client_model = MobileNetV3(n_class=10).to(device) for i in range(NUM_CLIENTS):
client_model.load_state_dict(torch.load("client_model.pth")) client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True))
client_acc = test_model(client_model, test_loader) client_acc = test_model(client_model, test_loader)
print(f"Client Model Test Accuracy: {client_acc:.2f}%") print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -200,6 +200,11 @@ class MobileNetV3(nn.Module):
self._initialize_weights() self._initialize_weights()
def extract_features(self, x):
x = self.features(x)
return x
def forward(self, x): def forward(self, x):
x = self.features(x) x = self.features(x)
x = x.mean(3).mean(2) x = x.mean(3).mean(2)

View File

@ -236,6 +236,10 @@ class RepViT(nn.Module):
self.features = nn.ModuleList(layers) self.features = nn.ModuleList(layers)
self.classifier = Classfier(output_channel, num_classes, distillation) self.classifier = Classfier(output_channel, num_classes, distillation)
def extract_features(self, x):
for f in self.features:
x = f(x)
return x
def forward(self, x): def forward(self, x):
# x = self.features(x) # x = self.features(x)
for f in self.features: for f in self.features: