完成联邦蒸馏
This commit is contained in:
parent
617230e296
commit
382916643b
73
FED.py
73
FED.py
@ -13,10 +13,11 @@ from model.mobilenetv3 import MobileNetV3
|
||||
|
||||
# 配置参数
|
||||
NUM_CLIENTS = 2
|
||||
NUM_ROUNDS = 3
|
||||
CLIENT_EPOCHS = 5
|
||||
NUM_ROUNDS = 10
|
||||
CLIENT_EPOCHS = 2
|
||||
BATCH_SIZE = 32
|
||||
TEMP = 2.0 # 蒸馏温度
|
||||
CLASS_NUM = [3, 3, 3]
|
||||
|
||||
# 设备配置
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@ -32,9 +33,9 @@ def prepare_data():
|
||||
])
|
||||
|
||||
# Load datasets
|
||||
dataset_A = ImageFolder(root='./dataset_A/train', transform=transform)
|
||||
dataset_B = ImageFolder(root='./dataset_B/train', transform=transform)
|
||||
dataset_C = ImageFolder(root='./dataset_C/train', transform=transform)
|
||||
dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform)
|
||||
dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform)
|
||||
dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform)
|
||||
|
||||
# Assign datasets to clients
|
||||
client_datasets = [dataset_B, dataset_C]
|
||||
@ -105,13 +106,6 @@ def client_train(client_model, server_model, dataset):
|
||||
})
|
||||
progress_bar.update(1)
|
||||
|
||||
# 每10个batch打印详细信息
|
||||
if (batch_idx + 1) % 10 == 0:
|
||||
progress_bar.write(f"\nEpoch {epoch+1} | Batch {batch_idx+1}")
|
||||
progress_bar.write(f"Task Loss: {loss_task:.4f}")
|
||||
progress_bar.write(f"Distill Loss: {loss_distill:.4f}")
|
||||
progress_bar.write(f"Total Loss: {total_loss:.4f}")
|
||||
progress_bar.write(f"Batch Accuracy: {100*correct/total:.2f}%\n")
|
||||
# 每个epoch结束打印汇总信息
|
||||
avg_loss = epoch_loss / len(loader)
|
||||
avg_task = task_loss / len(loader)
|
||||
@ -135,6 +129,37 @@ def aggregate(client_params):
|
||||
global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0)
|
||||
return global_params
|
||||
|
||||
def server_aggregate(server_model, client_models, public_loader):
|
||||
server_model.train()
|
||||
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
|
||||
|
||||
for data, _ in public_loader:
|
||||
data = data.to(device)
|
||||
|
||||
# 获取客户端模型特征
|
||||
client_features = []
|
||||
with torch.no_grad():
|
||||
for model in client_models:
|
||||
features = model.extract_features(data) # 需要实现特征提取方法
|
||||
client_features.append(features)
|
||||
|
||||
# 计算特征蒸馏目标
|
||||
target_features = torch.stack(client_features).mean(dim=0)
|
||||
|
||||
# 服务器前向
|
||||
server_features = server_model.extract_features(data)
|
||||
|
||||
# 特征对齐损失
|
||||
loss = F.mse_loss(server_features, target_features)
|
||||
|
||||
# 反向传播
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# 更新统计信息
|
||||
total_loss += loss.item()
|
||||
|
||||
# 服务器知识更新
|
||||
def server_update(server_model, client_models, public_loader):
|
||||
server_model.train()
|
||||
@ -191,15 +216,19 @@ def test_model(model, test_loader):
|
||||
|
||||
# 主训练流程
|
||||
def main():
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor()
|
||||
])
|
||||
# Initialize models
|
||||
global_server_model = repvit_m1_1(num_classes=10).to(device)
|
||||
client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)]
|
||||
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
||||
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
|
||||
|
||||
# Prepare data
|
||||
client_datasets, public_loader = prepare_data()
|
||||
|
||||
# Test dataset (using dataset A's test set for simplicity)
|
||||
test_dataset = ImageFolder(root='./dataset_A/test', transform=transform)
|
||||
test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
|
||||
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
|
||||
|
||||
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
|
||||
@ -245,20 +274,22 @@ def main():
|
||||
|
||||
# Save trained models
|
||||
torch.save(global_server_model.state_dict(), "server_model.pth")
|
||||
torch.save(client_models[0].state_dict(), "client_model.pth")
|
||||
for i in range(NUM_CLIENTS):
|
||||
torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth")
|
||||
print("Models saved successfully.")
|
||||
|
||||
# Test server model
|
||||
server_model = repvit_m1_1(num_classes=10).to(device)
|
||||
server_model.load_state_dict(torch.load("server_model.pth"))
|
||||
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
||||
server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
|
||||
server_acc = test_model(server_model, test_loader)
|
||||
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
|
||||
|
||||
# Test client model
|
||||
client_model = MobileNetV3(n_class=10).to(device)
|
||||
client_model.load_state_dict(torch.load("client_model.pth"))
|
||||
for i in range(NUM_CLIENTS):
|
||||
client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
|
||||
client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True))
|
||||
client_acc = test_model(client_model, test_loader)
|
||||
print(f"Client Model Test Accuracy: {client_acc:.2f}%")
|
||||
print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -200,6 +200,11 @@ class MobileNetV3(nn.Module):
|
||||
|
||||
self._initialize_weights()
|
||||
|
||||
|
||||
def extract_features(self, x):
|
||||
x = self.features(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = x.mean(3).mean(2)
|
||||
|
@ -236,6 +236,10 @@ class RepViT(nn.Module):
|
||||
self.features = nn.ModuleList(layers)
|
||||
self.classifier = Classfier(output_channel, num_classes, distillation)
|
||||
|
||||
def extract_features(self, x):
|
||||
for f in self.features:
|
||||
x = f(x)
|
||||
return x
|
||||
def forward(self, x):
|
||||
# x = self.features(x)
|
||||
for f in self.features:
|
||||
|
Loading…
Reference in New Issue
Block a user