完成联邦蒸馏
This commit is contained in:
parent
617230e296
commit
382916643b
75
FED.py
75
FED.py
@ -13,10 +13,11 @@ from model.mobilenetv3 import MobileNetV3
|
|||||||
|
|
||||||
# 配置参数
|
# 配置参数
|
||||||
NUM_CLIENTS = 2
|
NUM_CLIENTS = 2
|
||||||
NUM_ROUNDS = 3
|
NUM_ROUNDS = 10
|
||||||
CLIENT_EPOCHS = 5
|
CLIENT_EPOCHS = 2
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
TEMP = 2.0 # 蒸馏温度
|
TEMP = 2.0 # 蒸馏温度
|
||||||
|
CLASS_NUM = [3, 3, 3]
|
||||||
|
|
||||||
# 设备配置
|
# 设备配置
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
@ -32,9 +33,9 @@ def prepare_data():
|
|||||||
])
|
])
|
||||||
|
|
||||||
# Load datasets
|
# Load datasets
|
||||||
dataset_A = ImageFolder(root='./dataset_A/train', transform=transform)
|
dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform)
|
||||||
dataset_B = ImageFolder(root='./dataset_B/train', transform=transform)
|
dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform)
|
||||||
dataset_C = ImageFolder(root='./dataset_C/train', transform=transform)
|
dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform)
|
||||||
|
|
||||||
# Assign datasets to clients
|
# Assign datasets to clients
|
||||||
client_datasets = [dataset_B, dataset_C]
|
client_datasets = [dataset_B, dataset_C]
|
||||||
@ -105,13 +106,6 @@ def client_train(client_model, server_model, dataset):
|
|||||||
})
|
})
|
||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
|
|
||||||
# 每10个batch打印详细信息
|
|
||||||
if (batch_idx + 1) % 10 == 0:
|
|
||||||
progress_bar.write(f"\nEpoch {epoch+1} | Batch {batch_idx+1}")
|
|
||||||
progress_bar.write(f"Task Loss: {loss_task:.4f}")
|
|
||||||
progress_bar.write(f"Distill Loss: {loss_distill:.4f}")
|
|
||||||
progress_bar.write(f"Total Loss: {total_loss:.4f}")
|
|
||||||
progress_bar.write(f"Batch Accuracy: {100*correct/total:.2f}%\n")
|
|
||||||
# 每个epoch结束打印汇总信息
|
# 每个epoch结束打印汇总信息
|
||||||
avg_loss = epoch_loss / len(loader)
|
avg_loss = epoch_loss / len(loader)
|
||||||
avg_task = task_loss / len(loader)
|
avg_task = task_loss / len(loader)
|
||||||
@ -135,6 +129,37 @@ def aggregate(client_params):
|
|||||||
global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0)
|
global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0)
|
||||||
return global_params
|
return global_params
|
||||||
|
|
||||||
|
def server_aggregate(server_model, client_models, public_loader):
|
||||||
|
server_model.train()
|
||||||
|
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
for data, _ in public_loader:
|
||||||
|
data = data.to(device)
|
||||||
|
|
||||||
|
# 获取客户端模型特征
|
||||||
|
client_features = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for model in client_models:
|
||||||
|
features = model.extract_features(data) # 需要实现特征提取方法
|
||||||
|
client_features.append(features)
|
||||||
|
|
||||||
|
# 计算特征蒸馏目标
|
||||||
|
target_features = torch.stack(client_features).mean(dim=0)
|
||||||
|
|
||||||
|
# 服务器前向
|
||||||
|
server_features = server_model.extract_features(data)
|
||||||
|
|
||||||
|
# 特征对齐损失
|
||||||
|
loss = F.mse_loss(server_features, target_features)
|
||||||
|
|
||||||
|
# 反向传播
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# 更新统计信息
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
# 服务器知识更新
|
# 服务器知识更新
|
||||||
def server_update(server_model, client_models, public_loader):
|
def server_update(server_model, client_models, public_loader):
|
||||||
server_model.train()
|
server_model.train()
|
||||||
@ -191,15 +216,19 @@ def test_model(model, test_loader):
|
|||||||
|
|
||||||
# 主训练流程
|
# 主训练流程
|
||||||
def main():
|
def main():
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)),
|
||||||
|
transforms.ToTensor()
|
||||||
|
])
|
||||||
# Initialize models
|
# Initialize models
|
||||||
global_server_model = repvit_m1_1(num_classes=10).to(device)
|
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
||||||
client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)]
|
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
|
||||||
|
|
||||||
# Prepare data
|
# Prepare data
|
||||||
client_datasets, public_loader = prepare_data()
|
client_datasets, public_loader = prepare_data()
|
||||||
|
|
||||||
# Test dataset (using dataset A's test set for simplicity)
|
# Test dataset (using dataset A's test set for simplicity)
|
||||||
test_dataset = ImageFolder(root='./dataset_A/test', transform=transform)
|
test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
|
||||||
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
|
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
|
||||||
|
|
||||||
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
|
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
|
||||||
@ -245,20 +274,22 @@ def main():
|
|||||||
|
|
||||||
# Save trained models
|
# Save trained models
|
||||||
torch.save(global_server_model.state_dict(), "server_model.pth")
|
torch.save(global_server_model.state_dict(), "server_model.pth")
|
||||||
torch.save(client_models[0].state_dict(), "client_model.pth")
|
for i in range(NUM_CLIENTS):
|
||||||
|
torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth")
|
||||||
print("Models saved successfully.")
|
print("Models saved successfully.")
|
||||||
|
|
||||||
# Test server model
|
# Test server model
|
||||||
server_model = repvit_m1_1(num_classes=10).to(device)
|
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
||||||
server_model.load_state_dict(torch.load("server_model.pth"))
|
server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
|
||||||
server_acc = test_model(server_model, test_loader)
|
server_acc = test_model(server_model, test_loader)
|
||||||
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
|
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
|
||||||
|
|
||||||
# Test client model
|
# Test client model
|
||||||
client_model = MobileNetV3(n_class=10).to(device)
|
for i in range(NUM_CLIENTS):
|
||||||
client_model.load_state_dict(torch.load("client_model.pth"))
|
client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
|
||||||
client_acc = test_model(client_model, test_loader)
|
client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True))
|
||||||
print(f"Client Model Test Accuracy: {client_acc:.2f}%")
|
client_acc = test_model(client_model, test_loader)
|
||||||
|
print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
@ -200,6 +200,11 @@ class MobileNetV3(nn.Module):
|
|||||||
|
|
||||||
self._initialize_weights()
|
self._initialize_weights()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_features(self, x):
|
||||||
|
x = self.features(x)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.features(x)
|
x = self.features(x)
|
||||||
x = x.mean(3).mean(2)
|
x = x.mean(3).mean(2)
|
||||||
|
@ -236,6 +236,10 @@ class RepViT(nn.Module):
|
|||||||
self.features = nn.ModuleList(layers)
|
self.features = nn.ModuleList(layers)
|
||||||
self.classifier = Classfier(output_channel, num_classes, distillation)
|
self.classifier = Classfier(output_channel, num_classes, distillation)
|
||||||
|
|
||||||
|
def extract_features(self, x):
|
||||||
|
for f in self.features:
|
||||||
|
x = f(x)
|
||||||
|
return x
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# x = self.features(x)
|
# x = self.features(x)
|
||||||
for f in self.features:
|
for f in self.features:
|
||||||
|
Loading…
Reference in New Issue
Block a user