Compare commits
No commits in common. "617230e296b4ad040a3fc615707a0413f7cdbd38" and "30eeff4b1dd8759d2b07605968e247b710d6afc5" have entirely different histories.
617230e296
...
30eeff4b1d
95
FED.py
95
FED.py
@ -12,7 +12,7 @@ from model.repvit import repvit_m1_1
|
||||
from model.mobilenetv3 import MobileNetV3
|
||||
|
||||
# 配置参数
|
||||
NUM_CLIENTS = 2
|
||||
NUM_CLIENTS = 4
|
||||
NUM_ROUNDS = 3
|
||||
CLIENT_EPOCHS = 5
|
||||
BATCH_SIZE = 32
|
||||
@ -22,27 +22,25 @@ TEMP = 2.0 # 蒸馏温度
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# 数据准备
|
||||
import os
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
def prepare_data():
|
||||
def prepare_data(num_clients):
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor()
|
||||
])
|
||||
transforms.Resize((224, 224)), # 将图像调整为 224x224
|
||||
transforms.Grayscale(num_output_channels=3),
|
||||
transforms.ToTensor()
|
||||
])
|
||||
train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
|
||||
|
||||
# Load datasets
|
||||
dataset_A = ImageFolder(root='./dataset_A/train', transform=transform)
|
||||
dataset_B = ImageFolder(root='./dataset_B/train', transform=transform)
|
||||
dataset_C = ImageFolder(root='./dataset_C/train', transform=transform)
|
||||
# 非IID数据划分(每个客户端2个类别)
|
||||
client_data = {i: [] for i in range(num_clients)}
|
||||
labels = train_set.targets.numpy()
|
||||
for label in range(10):
|
||||
label_idx = np.where(labels == label)[0]
|
||||
np.random.shuffle(label_idx)
|
||||
split = np.array_split(label_idx, num_clients//2)
|
||||
for i, idx in enumerate(split):
|
||||
client_data[i*2 + label%2].extend(idx)
|
||||
|
||||
# Assign datasets to clients
|
||||
client_datasets = [dataset_B, dataset_C]
|
||||
|
||||
# Server dataset (A) for public updates
|
||||
public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True)
|
||||
|
||||
return client_datasets, public_loader
|
||||
return [Subset(train_set, ids) for ids in client_data.values()]
|
||||
|
||||
# 客户端训练函数
|
||||
def client_train(client_model, server_model, dataset):
|
||||
@ -191,47 +189,63 @@ def test_model(model, test_loader):
|
||||
|
||||
# 主训练流程
|
||||
def main():
|
||||
# Initialize models
|
||||
# 初始化模型
|
||||
global_server_model = repvit_m1_1(num_classes=10).to(device)
|
||||
client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)]
|
||||
|
||||
# Prepare data
|
||||
client_datasets, public_loader = prepare_data()
|
||||
|
||||
# Test dataset (using dataset A's test set for simplicity)
|
||||
test_dataset = ImageFolder(root='./dataset_A/test', transform=transform)
|
||||
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
|
||||
|
||||
|
||||
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
|
||||
|
||||
# 准备数据
|
||||
client_datasets = prepare_data(NUM_CLIENTS)
|
||||
public_loader = DataLoader(
|
||||
datasets.MNIST("./data", train=False, download=True,
|
||||
transform= transforms.Compose([
|
||||
transforms.Resize((224, 224)), # 将图像调整为 224x224
|
||||
transforms.Grayscale(num_output_channels=3),
|
||||
transforms.ToTensor() # 将图像转换为张量
|
||||
])),
|
||||
batch_size=100, shuffle=True)
|
||||
|
||||
test_dataset = datasets.MNIST(
|
||||
"./data",
|
||||
train=False,
|
||||
transform= transforms.Compose([
|
||||
transforms.Resize((224, 224)), # 将图像调整为 224x224
|
||||
transforms.Grayscale(num_output_channels=3),
|
||||
transforms.ToTensor() # 将图像转换为张量
|
||||
])
|
||||
)
|
||||
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
|
||||
|
||||
for round in range(NUM_ROUNDS):
|
||||
print(f"\n{'#'*50}")
|
||||
print(f"Federated Round {round+1}/{NUM_ROUNDS}")
|
||||
print(f"{'#'*50}")
|
||||
|
||||
# Client selection (only 2 clients)
|
||||
# 客户端选择
|
||||
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
|
||||
print(f"Selected Clients: {selected_clients}")
|
||||
|
||||
# Client local training
|
||||
# 客户端本地训练
|
||||
client_params = []
|
||||
for cid in selected_clients:
|
||||
print(f"\nTraining Client {cid}")
|
||||
local_model = copy.deepcopy(client_models[cid])
|
||||
local_model.load_state_dict(client_models[cid].state_dict())
|
||||
|
||||
updated_params = client_train(local_model, global_server_model, client_datasets[cid])
|
||||
client_params.append(updated_params)
|
||||
|
||||
# Model aggregation
|
||||
# 模型聚合
|
||||
global_client_params = aggregate(client_params)
|
||||
for model in client_models:
|
||||
model.load_state_dict(global_client_params)
|
||||
|
||||
# Server knowledge update
|
||||
# 服务器知识更新
|
||||
print("\nServer Updating...")
|
||||
server_update(global_server_model, client_models, public_loader)
|
||||
|
||||
# Test model performance
|
||||
# 测试模型性能
|
||||
server_acc = test_model(global_server_model, test_loader)
|
||||
client_acc = test_model(client_models[0], test_loader)
|
||||
print(f"\nRound {round+1} Performance:")
|
||||
@ -239,22 +253,25 @@ def main():
|
||||
print(f"Client Model Accuracy: {client_acc:.2f}%")
|
||||
|
||||
round_progress.update(1)
|
||||
|
||||
print(f"Round {round+1} completed")
|
||||
|
||||
print("Training completed!")
|
||||
|
||||
# Save trained models
|
||||
|
||||
# 保存训练好的模型
|
||||
torch.save(global_server_model.state_dict(), "server_model.pth")
|
||||
torch.save(client_models[0].state_dict(), "client_model.pth")
|
||||
print("Models saved successfully.")
|
||||
|
||||
# Test server model
|
||||
|
||||
# 创建测试数据加载器
|
||||
|
||||
# 测试服务器模型
|
||||
server_model = repvit_m1_1(num_classes=10).to(device)
|
||||
server_model.load_state_dict(torch.load("server_model.pth"))
|
||||
server_acc = test_model(server_model, test_loader)
|
||||
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
|
||||
|
||||
# Test client model
|
||||
|
||||
# 测试客户端模型
|
||||
client_model = MobileNetV3(n_class=10).to(device)
|
||||
client_model.load_state_dict(torch.load("client_model.pth"))
|
||||
client_acc = test_model(client_model, test_loader)
|
||||
|
@ -8,7 +8,7 @@ class Config:
|
||||
|
||||
# 训练参数
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
batch_size = 128
|
||||
batch_size = 32
|
||||
epochs = 150
|
||||
learning_rate = 0.001
|
||||
save_path = "checkpoints/best_model.pth"
|
||||
@ -22,6 +22,4 @@ class Config:
|
||||
checkpoint_path = "checkpoints/last_checkpoint.pth"
|
||||
output_path = "runs/"
|
||||
|
||||
cache = 'RAM'
|
||||
|
||||
config = Config()
|
118
data_loader.py
118
data_loader.py
@ -1,68 +1,31 @@
|
||||
import os
|
||||
from logger import logger
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision import datasets, transforms
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
|
||||
class ImageClassificationDataset(Dataset):
|
||||
def __init__(self, root_dir, transform=None,Cache=False):
|
||||
self.root_dir = root_dir
|
||||
self.transform = transform
|
||||
self.classes = sorted(os.listdir(root_dir))
|
||||
self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
|
||||
self.image_paths = []
|
||||
self.image = []
|
||||
self.labels = []
|
||||
self.Cache = Cache
|
||||
|
||||
logger.log("info",
|
||||
"init the dataloader"
|
||||
)
|
||||
|
||||
for cls_name in self.classes:
|
||||
cls_dir = os.path.join(root_dir, cls_name)
|
||||
for img_name in os.listdir(cls_dir):
|
||||
try:
|
||||
img_path = os.path.join(cls_dir, img_name)
|
||||
imgs = Image.open(img_path).convert('RGB')
|
||||
if Cache == 'RAM':
|
||||
if self.transform:
|
||||
imgs = self.transform(imgs)
|
||||
self.image.append(imgs)
|
||||
else:
|
||||
self.image_paths.append(img_path)
|
||||
self.labels.append(self.class_to_idx[cls_name])
|
||||
except:
|
||||
logger.log("info",
|
||||
"read image error " +
|
||||
img_path
|
||||
)
|
||||
class ClassifyDataset(Dataset):
|
||||
def __init__(self, data_dir,transforms = None):
|
||||
self.data_dir = data_dir
|
||||
# Assume the dataset is structured with subdirectories for each class
|
||||
self.transform = transforms
|
||||
self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
|
||||
self.image_size = (3, 224, 224)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.labels)
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
label = self.labels[idx]
|
||||
if self.Cache == 'RAM':
|
||||
image = self.image[idx]
|
||||
else:
|
||||
img_path = self.image_paths[idx]
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
if self.transform:
|
||||
image = self.transform(image)
|
||||
|
||||
return image, label
|
||||
|
||||
def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cache=False):
|
||||
# Define the transform for the training data and for the validation data
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)), # Resize images to 224x224
|
||||
transforms.ToTensor(), # Convert PIL Image to Tensor
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize the images
|
||||
])
|
||||
|
||||
# Define transformations for training data augmentation and normalization
|
||||
try:
|
||||
image, label = self.dataset[idx]
|
||||
return image, label
|
||||
except Exception as e:
|
||||
black_image = np.zeros((224, 224, 3), dtype=np.uint8)
|
||||
return self.transform(Image.fromarray(black_image)), 0 # -1 作为默认标签
|
||||
|
||||
def create_data_loaders(data_dir,batch_size=64):
|
||||
# Define transformations for training data augmentation and normalization
|
||||
train_transforms = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor(),
|
||||
@ -77,38 +40,17 @@ def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True,Cach
|
||||
])
|
||||
|
||||
# Load the datasets with ImageFolder
|
||||
train_dir = root_dir + '/train'
|
||||
valid_dir = root_dir + '/val'
|
||||
test_dir = root_dir + '/test'
|
||||
train_dir = data_dir + '/train'
|
||||
valid_dir = data_dir + '/val'
|
||||
test_dir = data_dir + '/test'
|
||||
|
||||
train_data = ImageClassificationDataset(train_dir, transform=train_transforms,Cache=Cache)
|
||||
valid_data = ImageClassificationDataset(valid_dir, transform=valid_test_transforms,Cache=Cache)
|
||||
test_data = ImageClassificationDataset(test_dir, transform=valid_test_transforms,Cache=Cache)
|
||||
train_data = ClassifyDataset(train_dir, transforms=train_transforms)
|
||||
valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms)
|
||||
test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms)
|
||||
|
||||
# Create the DataLoaders with batch size 64
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
|
||||
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
|
||||
|
||||
# Create the data loader
|
||||
train_loader = DataLoader(
|
||||
train_data,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory
|
||||
)
|
||||
|
||||
# Create the data loader
|
||||
valid_loader = DataLoader(
|
||||
valid_data,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory
|
||||
)
|
||||
|
||||
# Create the data loader
|
||||
test_loader = DataLoader(
|
||||
test_data,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory
|
||||
)
|
||||
|
||||
return train_loader, valid_loader, test_loader
|
||||
return train_loader, valid_loader, test_loader
|
@ -106,15 +106,15 @@ def process_images(input_folder, background_image_path, output_base):
|
||||
递归处理所有子文件夹并保持目录结构
|
||||
"""
|
||||
# 预处理背景路径(只需执行一次)
|
||||
# if os.path.isfile(background_image_path):
|
||||
# background_paths = [background_image_path]
|
||||
# else:
|
||||
# valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
|
||||
# background_paths = [
|
||||
# os.path.join(background_image_path, f)
|
||||
# for f in os.listdir(background_image_path)
|
||||
# if os.path.splitext(f)[1].lower() in valid_ext
|
||||
# ]
|
||||
if os.path.isfile(background_image_path):
|
||||
background_paths = [background_image_path]
|
||||
else:
|
||||
valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
|
||||
background_paths = [
|
||||
os.path.join(background_image_path, f)
|
||||
for f in os.listdir(background_image_path)
|
||||
if os.path.splitext(f)[1].lower() in valid_ext
|
||||
]
|
||||
|
||||
# 递归遍历输入目录
|
||||
for root, dirs, files in os.walk(input_folder):
|
||||
@ -136,10 +136,10 @@ def process_images(input_folder, background_image_path, output_base):
|
||||
|
||||
try:
|
||||
# 去背景处理
|
||||
result = remove_background(input_path)
|
||||
foreground = remove_background(input_path)
|
||||
|
||||
|
||||
# result = edge_fill2(result)
|
||||
result = edge_fill2(foreground)
|
||||
|
||||
# 保存结果
|
||||
cv2.imwrite(output_path, result)
|
||||
@ -150,8 +150,8 @@ def process_images(input_folder, background_image_path, output_base):
|
||||
|
||||
|
||||
# 使用示例
|
||||
input_directory = 'L:/Grade_datasets/JY_A'
|
||||
input_directory = 'L:/Tobacco/2023_JY/20230821/SOURCE'
|
||||
background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B'
|
||||
output_directory = 'L:/Grade_datasets/MOVE_BACKGROUND'
|
||||
output_directory = 'L:/Test'
|
||||
|
||||
process_images(input_directory, background_image_path, output_directory)
|
@ -1,12 +0,0 @@
|
||||
import os
|
||||
|
||||
def debug_walk_with_links(input_folder):
|
||||
for root, dirs, files in os.walk(input_folder):
|
||||
print(f'Root: {root}')
|
||||
print(f'Dirs: {dirs}')
|
||||
print(f'Files: {files}')
|
||||
print('-' * 40)
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_folder = 'L:/Grade_datasets'
|
||||
debug_walk_with_links(input_folder)
|
5
main.py
5
main.py
@ -7,7 +7,6 @@ from torchvision.datasets import MNIST
|
||||
from torchvision.transforms import ToTensor
|
||||
|
||||
from model.repvit import *
|
||||
from model.mobilenetv3 import *
|
||||
from data_loader import *
|
||||
from utils import *
|
||||
|
||||
@ -15,11 +14,11 @@ def main():
|
||||
# 初始化组件
|
||||
initialize()
|
||||
|
||||
model = repvit_m1_0(num_classes=9).to(config.device)
|
||||
model = repvit_m1_1(num_classes=10).to(config.device)
|
||||
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
train_loader, valid_loader, test_loader = get_data_loader('/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size,Cache='RAM')
|
||||
train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27',batch_size=config.batch_size)
|
||||
|
||||
# 初始化训练器
|
||||
trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
|
||||
|
@ -4,7 +4,6 @@ from torch.utils.data import DataLoader
|
||||
from config import config
|
||||
from logger import logger
|
||||
from utils import save_checkpoint, load_checkpoint
|
||||
import time
|
||||
|
||||
class Trainer:
|
||||
def __init__(self, model, train_loader, val_loader, optimizer, criterion):
|
||||
@ -22,7 +21,7 @@ class Trainer:
|
||||
self.model.train()
|
||||
total_loss = 0.0
|
||||
progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
|
||||
time_start = time.time()
|
||||
|
||||
for batch_idx, (data, target) in enumerate(progress_bar):
|
||||
data, target = data.to(config.device), target.to(config.device)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user