This commit is contained in:
yoiannis 2025-03-12 09:38:48 +08:00
commit 4cb9790dee
9 changed files with 207 additions and 106 deletions

154
FED.py
View File

@ -12,35 +12,38 @@ from model.repvit import repvit_m1_1
from model.mobilenetv3 import MobileNetV3 from model.mobilenetv3 import MobileNetV3
# 配置参数 # 配置参数
NUM_CLIENTS = 4 NUM_CLIENTS = 2
NUM_ROUNDS = 3 NUM_ROUNDS = 10
CLIENT_EPOCHS = 5 CLIENT_EPOCHS = 2
BATCH_SIZE = 32 BATCH_SIZE = 32
TEMP = 2.0 # 蒸馏温度 TEMP = 2.0 # 蒸馏温度
CLASS_NUM = [3, 3, 3]
# 设备配置 # 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 数据准备 # 数据准备
def prepare_data(num_clients): import os
from torchvision.datasets import ImageFolder
def prepare_data():
transform = transforms.Compose([ transform = transforms.Compose([
transforms.Resize((224, 224)), # 将图像调整为 224x224 transforms.Resize((224, 224)),
transforms.Grayscale(num_output_channels=3), transforms.ToTensor()
transforms.ToTensor() ])
])
train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
# 非IID数据划分每个客户端2个类别 # Load datasets
client_data = {i: [] for i in range(num_clients)} dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform)
labels = train_set.targets.numpy() dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform)
for label in range(10): dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform)
label_idx = np.where(labels == label)[0]
np.random.shuffle(label_idx)
split = np.array_split(label_idx, num_clients//2)
for i, idx in enumerate(split):
client_data[i*2 + label%2].extend(idx)
return [Subset(train_set, ids) for ids in client_data.values()] # Assign datasets to clients
client_datasets = [dataset_B, dataset_C]
# Server dataset (A) for public updates
public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True)
return client_datasets, public_loader
# 客户端训练函数 # 客户端训练函数
def client_train(client_model, server_model, dataset): def client_train(client_model, server_model, dataset):
@ -103,13 +106,6 @@ def client_train(client_model, server_model, dataset):
}) })
progress_bar.update(1) progress_bar.update(1)
# 每10个batch打印详细信息
if (batch_idx + 1) % 10 == 0:
progress_bar.write(f"\nEpoch {epoch+1} | Batch {batch_idx+1}")
progress_bar.write(f"Task Loss: {loss_task:.4f}")
progress_bar.write(f"Distill Loss: {loss_distill:.4f}")
progress_bar.write(f"Total Loss: {total_loss:.4f}")
progress_bar.write(f"Batch Accuracy: {100*correct/total:.2f}%\n")
# 每个epoch结束打印汇总信息 # 每个epoch结束打印汇总信息
avg_loss = epoch_loss / len(loader) avg_loss = epoch_loss / len(loader)
avg_task = task_loss / len(loader) avg_task = task_loss / len(loader)
@ -133,6 +129,37 @@ def aggregate(client_params):
global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0) global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0)
return global_params return global_params
def server_aggregate(server_model, client_models, public_loader):
server_model.train()
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
for data, _ in public_loader:
data = data.to(device)
# 获取客户端模型特征
client_features = []
with torch.no_grad():
for model in client_models:
features = model.extract_features(data) # 需要实现特征提取方法
client_features.append(features)
# 计算特征蒸馏目标
target_features = torch.stack(client_features).mean(dim=0)
# 服务器前向
server_features = server_model.extract_features(data)
# 特征对齐损失
loss = F.mse_loss(server_features, target_features)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 更新统计信息
total_loss += loss.item()
# 服务器知识更新 # 服务器知识更新
def server_update(server_model, client_models, public_loader): def server_update(server_model, client_models, public_loader):
server_model.train() server_model.train()
@ -189,63 +216,51 @@ def test_model(model, test_loader):
# 主训练流程 # 主训练流程
def main(): def main():
# 初始化模型 transform = transforms.Compose([
global_server_model = repvit_m1_1(num_classes=10).to(device) transforms.Resize((224, 224)),
client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)] transforms.ToTensor()
])
# Initialize models
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
# Prepare data
client_datasets, public_loader = prepare_data()
# Test dataset (using dataset A's test set for simplicity)
test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round") round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
# 准备数据
client_datasets = prepare_data(NUM_CLIENTS)
public_loader = DataLoader(
datasets.MNIST("./data", train=False, download=True,
transform= transforms.Compose([
transforms.Resize((224, 224)), # 将图像调整为 224x224
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor() # 将图像转换为张量
])),
batch_size=100, shuffle=True)
test_dataset = datasets.MNIST(
"./data",
train=False,
transform= transforms.Compose([
transforms.Resize((224, 224)), # 将图像调整为 224x224
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor() # 将图像转换为张量
])
)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
for round in range(NUM_ROUNDS): for round in range(NUM_ROUNDS):
print(f"\n{'#'*50}") print(f"\n{'#'*50}")
print(f"Federated Round {round+1}/{NUM_ROUNDS}") print(f"Federated Round {round+1}/{NUM_ROUNDS}")
print(f"{'#'*50}") print(f"{'#'*50}")
# 客户端选择 # Client selection (only 2 clients)
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False) selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
print(f"Selected Clients: {selected_clients}") print(f"Selected Clients: {selected_clients}")
# 客户端本地训练 # Client local training
client_params = [] client_params = []
for cid in selected_clients: for cid in selected_clients:
print(f"\nTraining Client {cid}") print(f"\nTraining Client {cid}")
local_model = copy.deepcopy(client_models[cid]) local_model = copy.deepcopy(client_models[cid])
local_model.load_state_dict(client_models[cid].state_dict()) local_model.load_state_dict(client_models[cid].state_dict())
updated_params = client_train(local_model, global_server_model, client_datasets[cid]) updated_params = client_train(local_model, global_server_model, client_datasets[cid])
client_params.append(updated_params) client_params.append(updated_params)
# 模型聚合 # Model aggregation
global_client_params = aggregate(client_params) global_client_params = aggregate(client_params)
for model in client_models: for model in client_models:
model.load_state_dict(global_client_params) model.load_state_dict(global_client_params)
# 服务器知识更新 # Server knowledge update
print("\nServer Updating...") print("\nServer Updating...")
server_update(global_server_model, client_models, public_loader) server_update(global_server_model, client_models, public_loader)
# 测试模型性能 # Test model performance
server_acc = test_model(global_server_model, test_loader) server_acc = test_model(global_server_model, test_loader)
client_acc = test_model(client_models[0], test_loader) client_acc = test_model(client_models[0], test_loader)
print(f"\nRound {round+1} Performance:") print(f"\nRound {round+1} Performance:")
@ -253,29 +268,28 @@ def main():
print(f"Client Model Accuracy: {client_acc:.2f}%") print(f"Client Model Accuracy: {client_acc:.2f}%")
round_progress.update(1) round_progress.update(1)
print(f"Round {round+1} completed") print(f"Round {round+1} completed")
print("Training completed!") print("Training completed!")
# 保存训练好的模型 # Save trained models
torch.save(global_server_model.state_dict(), "server_model.pth") torch.save(global_server_model.state_dict(), "server_model.pth")
torch.save(client_models[0].state_dict(), "client_model.pth") for i in range(NUM_CLIENTS):
torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth")
print("Models saved successfully.") print("Models saved successfully.")
# 创建测试数据加载器 # Test server model
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
# 测试服务器模型 server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
server_model = repvit_m1_1(num_classes=10).to(device)
server_model.load_state_dict(torch.load("server_model.pth"))
server_acc = test_model(server_model, test_loader) server_acc = test_model(server_model, test_loader)
print(f"Server Model Test Accuracy: {server_acc:.2f}%") print(f"Server Model Test Accuracy: {server_acc:.2f}%")
# 测试客户端模型 # Test client model
client_model = MobileNetV3(n_class=10).to(device) for i in range(NUM_CLIENTS):
client_model.load_state_dict(torch.load("client_model.pth")) client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
client_acc = test_model(client_model, test_loader) client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True))
print(f"Client Model Test Accuracy: {client_acc:.2f}%") client_acc = test_model(client_model, test_loader)
print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

4
README.md Normal file
View File

@ -0,0 +1,4 @@
# 1.相关知识
```
https://github.com/Hao840/OFAKD
```

View File

@ -8,7 +8,7 @@ class Config:
# 训练参数 # 训练参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32 batch_size = 128
epochs = 150 epochs = 150
learning_rate = 0.001 learning_rate = 0.001
save_path = "checkpoints/best_model.pth" save_path = "checkpoints/best_model.pth"
@ -22,4 +22,6 @@ class Config:
checkpoint_path = "checkpoints/last_checkpoint.pth" checkpoint_path = "checkpoints/last_checkpoint.pth"
output_path = "runs/" output_path = "runs/"
cache = 'RAM'
config = Config() config = Config()

View File

@ -1,31 +1,68 @@
import os import os
from logger import logger
from PIL import Image from PIL import Image
import numpy as np
import torch import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class ClassifyDataset(Dataset): class ImageClassificationDataset(Dataset):
def __init__(self, data_dir,transforms = None): def __init__(self, root_dir, transform=None,Cache=False):
self.data_dir = data_dir self.root_dir = root_dir
# Assume the dataset is structured with subdirectories for each class self.transform = transform
self.transform = transforms self.classes = sorted(os.listdir(root_dir))
self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform) self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
self.image_size = (3, 224, 224) self.image_paths = []
self.image = []
self.labels = []
self.Cache = Cache
logger.log("info",
"init the dataloader"
)
for cls_name in self.classes:
cls_dir = os.path.join(root_dir, cls_name)
for img_name in os.listdir(cls_dir):
try:
img_path = os.path.join(cls_dir, img_name)
imgs = Image.open(img_path).convert('RGB')
if Cache == 'RAM':
if self.transform:
imgs = self.transform(imgs)
self.image.append(imgs)
else:
self.image_paths.append(img_path)
self.labels.append(self.class_to_idx[cls_name])
except:
logger.log("info",
"read image error " +
img_path
)
def __len__(self): def __len__(self):
return len(self.dataset) return len(self.labels)
def __getitem__(self, idx): def __getitem__(self, idx):
try: label = self.labels[idx]
image, label = self.dataset[idx] if self.Cache == 'RAM':
return image, label image = self.image[idx]
except Exception as e: else:
black_image = np.zeros((224, 224, 3), dtype=np.uint8) img_path = self.image_paths[idx]
return self.transform(Image.fromarray(black_image)), 0 # -1 作为默认标签 image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
def create_data_loaders(data_dir,batch_size=64): return image, label
# Define transformations for training data augmentation and normalization
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False):
# Define the transform for the training data and for the validation data
transform = transforms.Compose([
transforms.Resize((224, 224)), # Resize images to 224x224
transforms.ToTensor(), # Convert PIL Image to Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize the images
])
# Define transformations for training data augmentation and normalization
train_transforms = transforms.Compose([ train_transforms = transforms.Compose([
transforms.Resize((224, 224)), transforms.Resize((224, 224)),
transforms.ToTensor(), transforms.ToTensor(),
@ -40,17 +77,38 @@ def create_data_loaders(data_dir,batch_size=64):
]) ])
# Load the datasets with ImageFolder # Load the datasets with ImageFolder
train_dir = data_dir + '/train' train_dir = root_dir + '/train'
valid_dir = data_dir + '/val' valid_dir = root_dir + '/val'
test_dir = data_dir + '/test' test_dir = root_dir + '/test'
train_data = ClassifyDataset(train_dir, transforms=train_transforms) train_data = ImageClassificationDataset(train_dir, transform=train_transforms,Cache=Cache)
valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms) valid_data = ImageClassificationDataset(valid_dir, transform=valid_test_transforms,Cache=Cache)
test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms) test_data = ImageClassificationDataset(test_dir, transform=valid_test_transforms,Cache=Cache)
# Create the DataLoaders with batch size 64
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True) # Create the data loader
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size) train_loader = DataLoader(
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size) train_data,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory
)
# Create the data loader
valid_loader = DataLoader(
valid_data,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory
)
# Create the data loader
test_loader = DataLoader(
test_data,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory
)
return train_loader, valid_loader, test_loader return train_loader, valid_loader, test_loader

12
dataset/test.py Normal file
View File

@ -0,0 +1,12 @@
import os
def debug_walk_with_links(input_folder):
for root, dirs, files in os.walk(input_folder):
print(f'Root: {root}')
print(f'Dirs: {dirs}')
print(f'Files: {files}')
print('-' * 40)
if __name__ == "__main__":
input_folder = 'L:/Grade_datasets'
debug_walk_with_links(input_folder)

View File

@ -7,6 +7,7 @@ from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor from torchvision.transforms import ToTensor
from model.repvit import * from model.repvit import *
from model.mobilenetv3 import *
from data_loader import * from data_loader import *
from utils import * from utils import *
@ -14,11 +15,11 @@ def main():
# 初始化组件 # 初始化组件
initialize() initialize()
model = repvit_m1_1(num_classes=10).to(config.device) model = repvit_m1_0(num_classes=9).to(config.device)
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size) train_loader, valid_loader, test_loader = get_data_loader('/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size,Cache='RAM')
# 初始化训练器 # 初始化训练器
trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion) trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)

View File

@ -200,6 +200,11 @@ class MobileNetV3(nn.Module):
self._initialize_weights() self._initialize_weights()
def extract_features(self, x):
x = self.features(x)
return x
def forward(self, x): def forward(self, x):
x = self.features(x) x = self.features(x)
x = x.mean(3).mean(2) x = x.mean(3).mean(2)

View File

@ -236,6 +236,10 @@ class RepViT(nn.Module):
self.features = nn.ModuleList(layers) self.features = nn.ModuleList(layers)
self.classifier = Classfier(output_channel, num_classes, distillation) self.classifier = Classfier(output_channel, num_classes, distillation)
def extract_features(self, x):
for f in self.features:
x = f(x)
return x
def forward(self, x): def forward(self, x):
# x = self.features(x) # x = self.features(x)
for f in self.features: for f in self.features:

View File

@ -4,6 +4,7 @@ from torch.utils.data import DataLoader
from config import config from config import config
from logger import logger from logger import logger
from utils import save_checkpoint, load_checkpoint from utils import save_checkpoint, load_checkpoint
import time
class Trainer: class Trainer:
def __init__(self, model, train_loader, val_loader, optimizer, criterion): def __init__(self, model, train_loader, val_loader, optimizer, criterion):
@ -21,7 +22,7 @@ class Trainer:
self.model.train() self.model.train()
total_loss = 0.0 total_loss = 0.0
progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}") progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
time_start = time.time()
for batch_idx, (data, target) in enumerate(progress_bar): for batch_idx, (data, target) in enumerate(progress_bar):
data, target = data.to(config.device), target.to(config.device) data, target = data.to(config.device), target.to(config.device)