Compare commits

..

No commits in common. "617230e296b4ad040a3fc615707a0413f7cdbd38" and "30eeff4b1dd8759d2b07605968e247b710d6afc5" have entirely different histories.

8 changed files with 103 additions and 164 deletions

95
FED.py
View File

@ -12,7 +12,7 @@ from model.repvit import repvit_m1_1
from model.mobilenetv3 import MobileNetV3
# 配置参数
NUM_CLIENTS = 2
NUM_CLIENTS = 4
NUM_ROUNDS = 3
CLIENT_EPOCHS = 5
BATCH_SIZE = 32
@ -22,27 +22,25 @@ TEMP = 2.0 # 蒸馏温度
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 数据准备
import os
from torchvision.datasets import ImageFolder
def prepare_data():
def prepare_data(num_clients):
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
transforms.Resize((224, 224)), # 将图像调整为 224x224
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor()
])
train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
# Load datasets
dataset_A = ImageFolder(root='./dataset_A/train', transform=transform)
dataset_B = ImageFolder(root='./dataset_B/train', transform=transform)
dataset_C = ImageFolder(root='./dataset_C/train', transform=transform)
# 非IID数据划分每个客户端2个类别
client_data = {i: [] for i in range(num_clients)}
labels = train_set.targets.numpy()
for label in range(10):
label_idx = np.where(labels == label)[0]
np.random.shuffle(label_idx)
split = np.array_split(label_idx, num_clients//2)
for i, idx in enumerate(split):
client_data[i*2 + label%2].extend(idx)
# Assign datasets to clients
client_datasets = [dataset_B, dataset_C]
# Server dataset (A) for public updates
public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True)
return client_datasets, public_loader
return [Subset(train_set, ids) for ids in client_data.values()]
# 客户端训练函数
def client_train(client_model, server_model, dataset):
@ -191,47 +189,63 @@ def test_model(model, test_loader):
# 主训练流程
def main():
# Initialize models
# 初始化模型
global_server_model = repvit_m1_1(num_classes=10).to(device)
client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)]
# Prepare data
client_datasets, public_loader = prepare_data()
# Test dataset (using dataset A's test set for simplicity)
test_dataset = ImageFolder(root='./dataset_A/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
# 准备数据
client_datasets = prepare_data(NUM_CLIENTS)
public_loader = DataLoader(
datasets.MNIST("./data", train=False, download=True,
transform= transforms.Compose([
transforms.Resize((224, 224)), # 将图像调整为 224x224
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor() # 将图像转换为张量
])),
batch_size=100, shuffle=True)
test_dataset = datasets.MNIST(
"./data",
train=False,
transform= transforms.Compose([
transforms.Resize((224, 224)), # 将图像调整为 224x224
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor() # 将图像转换为张量
])
)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
for round in range(NUM_ROUNDS):
print(f"\n{'#'*50}")
print(f"Federated Round {round+1}/{NUM_ROUNDS}")
print(f"{'#'*50}")
# Client selection (only 2 clients)
# 客户端选择
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
print(f"Selected Clients: {selected_clients}")
# Client local training
# 客户端本地训练
client_params = []
for cid in selected_clients:
print(f"\nTraining Client {cid}")
local_model = copy.deepcopy(client_models[cid])
local_model.load_state_dict(client_models[cid].state_dict())
updated_params = client_train(local_model, global_server_model, client_datasets[cid])
client_params.append(updated_params)
# Model aggregation
# 模型聚合
global_client_params = aggregate(client_params)
for model in client_models:
model.load_state_dict(global_client_params)
# Server knowledge update
# 服务器知识更新
print("\nServer Updating...")
server_update(global_server_model, client_models, public_loader)
# Test model performance
# 测试模型性能
server_acc = test_model(global_server_model, test_loader)
client_acc = test_model(client_models[0], test_loader)
print(f"\nRound {round+1} Performance:")
@ -239,22 +253,25 @@ def main():
print(f"Client Model Accuracy: {client_acc:.2f}%")
round_progress.update(1)
print(f"Round {round+1} completed")
print("Training completed!")
# Save trained models
# 保存训练好的模型
torch.save(global_server_model.state_dict(), "server_model.pth")
torch.save(client_models[0].state_dict(), "client_model.pth")
print("Models saved successfully.")
# Test server model
# 创建测试数据加载器
# 测试服务器模型
server_model = repvit_m1_1(num_classes=10).to(device)
server_model.load_state_dict(torch.load("server_model.pth"))
server_acc = test_model(server_model, test_loader)
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
# Test client model
# 测试客户端模型
client_model = MobileNetV3(n_class=10).to(device)
client_model.load_state_dict(torch.load("client_model.pth"))
client_acc = test_model(client_model, test_loader)

View File

@ -1,4 +0,0 @@
# 1.相关知识
```
https://github.com/Hao840/OFAKD
```

View File

@ -8,7 +8,7 @@ class Config:
# 训练参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
batch_size = 32
epochs = 150
learning_rate = 0.001
save_path = "checkpoints/best_model.pth"
@ -22,6 +22,4 @@ class Config:
checkpoint_path = "checkpoints/last_checkpoint.pth"
output_path = "runs/"
cache = 'RAM'
config = Config()

View File

@ -1,68 +1,31 @@
import os
from logger import logger
from PIL import Image
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class ImageClassificationDataset(Dataset):
def __init__(self, root_dir, transform=None,Cache=False):
self.root_dir = root_dir
self.transform = transform
self.classes = sorted(os.listdir(root_dir))
self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
self.image_paths = []
self.image = []
self.labels = []
self.Cache = Cache
logger.log("info",
"init the dataloader"
)
for cls_name in self.classes:
cls_dir = os.path.join(root_dir, cls_name)
for img_name in os.listdir(cls_dir):
try:
img_path = os.path.join(cls_dir, img_name)
imgs = Image.open(img_path).convert('RGB')
if Cache == 'RAM':
if self.transform:
imgs = self.transform(imgs)
self.image.append(imgs)
else:
self.image_paths.append(img_path)
self.labels.append(self.class_to_idx[cls_name])
except:
logger.log("info",
"read image error " +
img_path
)
class ClassifyDataset(Dataset):
def __init__(self, data_dir,transforms = None):
self.data_dir = data_dir
# Assume the dataset is structured with subdirectories for each class
self.transform = transforms
self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
self.image_size = (3, 224, 224)
def __len__(self):
return len(self.labels)
return len(self.dataset)
def __getitem__(self, idx):
label = self.labels[idx]
if self.Cache == 'RAM':
image = self.image[idx]
else:
img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image, label
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False):
# Define the transform for the training data and for the validation data
transform = transforms.Compose([
transforms.Resize((224, 224)), # Resize images to 224x224
transforms.ToTensor(), # Convert PIL Image to Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize the images
])
# Define transformations for training data augmentation and normalization
try:
image, label = self.dataset[idx]
return image, label
except Exception as e:
black_image = np.zeros((224, 224, 3), dtype=np.uint8)
return self.transform(Image.fromarray(black_image)), 0 # -1 作为默认标签
def create_data_loaders(data_dir,batch_size=64):
# Define transformations for training data augmentation and normalization
train_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
@ -77,38 +40,17 @@ def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cach
])
# Load the datasets with ImageFolder
train_dir = root_dir + '/train'
valid_dir = root_dir + '/val'
test_dir = root_dir + '/test'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/val'
test_dir = data_dir + '/test'
train_data = ImageClassificationDataset(train_dir, transform=train_transforms,Cache=Cache)
valid_data = ImageClassificationDataset(valid_dir, transform=valid_test_transforms,Cache=Cache)
test_data = ImageClassificationDataset(test_dir, transform=valid_test_transforms,Cache=Cache)
train_data = ClassifyDataset(train_dir, transforms=train_transforms)
valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms)
test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms)
# Create the DataLoaders with batch size 64
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
# Create the data loader
train_loader = DataLoader(
train_data,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory
)
# Create the data loader
valid_loader = DataLoader(
valid_data,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory
)
# Create the data loader
test_loader = DataLoader(
test_data,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory
)
return train_loader, valid_loader, test_loader
return train_loader, valid_loader, test_loader

View File

@ -106,15 +106,15 @@ def process_images(input_folder, background_image_path, output_base):
递归处理所有子文件夹并保持目录结构
"""
# 预处理背景路径(只需执行一次)
# if os.path.isfile(background_image_path):
# background_paths = [background_image_path]
# else:
# valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
# background_paths = [
# os.path.join(background_image_path, f)
# for f in os.listdir(background_image_path)
# if os.path.splitext(f)[1].lower() in valid_ext
# ]
if os.path.isfile(background_image_path):
background_paths = [background_image_path]
else:
valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
background_paths = [
os.path.join(background_image_path, f)
for f in os.listdir(background_image_path)
if os.path.splitext(f)[1].lower() in valid_ext
]
# 递归遍历输入目录
for root, dirs, files in os.walk(input_folder):
@ -136,10 +136,10 @@ def process_images(input_folder, background_image_path, output_base):
try:
# 去背景处理
result = remove_background(input_path)
foreground = remove_background(input_path)
# result = edge_fill2(result)
result = edge_fill2(foreground)
# 保存结果
cv2.imwrite(output_path, result)
@ -150,8 +150,8 @@ def process_images(input_folder, background_image_path, output_base):
# 使用示例
input_directory = 'L:/Grade_datasets/JY_A'
input_directory = 'L:/Tobacco/2023_JY/20230821/SOURCE'
background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B'
output_directory = 'L:/Grade_datasets/MOVE_BACKGROUND'
output_directory = 'L:/Test'
process_images(input_directory, background_image_path, output_directory)

View File

@ -1,12 +0,0 @@
import os
def debug_walk_with_links(input_folder):
    """Recursively walk *input_folder* and print each directory's contents.

    For every directory reached by os.walk, prints the directory path, its
    immediate subdirectories, and its files, followed by a separator line.
    Debugging aid only; returns None.
    """
    for current_root, subdir_names, file_names in os.walk(input_folder):
        print(f'Root: {current_root}')
        print(f'Dirs: {subdir_names}')
        print(f'Files: {file_names}')
        print('-' * 40)
if __name__ == "__main__":
input_folder = 'L:/Grade_datasets'
debug_walk_with_links(input_folder)

View File

@ -7,7 +7,6 @@ from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from model.repvit import *
from model.mobilenetv3 import *
from data_loader import *
from utils import *
@ -15,11 +14,11 @@ def main():
# 初始化组件
initialize()
model = repvit_m1_0(num_classes=9).to(config.device)
model = repvit_m1_1(num_classes=10).to(config.device)
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.CrossEntropyLoss()
train_loader, valid_loader, test_loader = get_data_loader('/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size,Cache='RAM')
train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size)
# 初始化训练器
trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)

View File

@ -4,7 +4,6 @@ from torch.utils.data import DataLoader
from config import config
from logger import logger
from utils import save_checkpoint, load_checkpoint
import time
class Trainer:
def __init__(self, model, train_loader, val_loader, optimizer, criterion):
@ -22,7 +21,7 @@ class Trainer:
self.model.train()
total_loss = 0.0
progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
time_start = time.time()
for batch_idx, (data, target) in enumerate(progress_bar):
data, target = data.to(config.device), target.to(config.device)