From 52587d8c2651a2eabfbe6914b5d43918b32eaa04 Mon Sep 17 00:00:00 2001 From: yoiannis <13330431063> Date: Wed, 12 Mar 2025 14:26:35 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E6=A8=A1=E5=9E=8B=E6=97=A5?= =?UTF-8?q?=E5=BF=97=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- FED.py | 110 +++++++++++++++++++++++++++---------------------- data_loader.py | 6 +-- logger.py | 41 +++++++++++------- 3 files changed, 89 insertions(+), 68 deletions(-) diff --git a/FED.py b/FED.py index 8ec2835..b08ebe7 100644 --- a/FED.py +++ b/FED.py @@ -9,6 +9,7 @@ import copy from tqdm import tqdm from data_loader import get_data_loader +from logger import * from model.repvit import repvit_m1_1 from model.mobilenetv3 import MobileNetV3 @@ -34,9 +35,9 @@ def prepare_data(): ]) # 加载所有数据集(训练、验证、测试) - dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/JY_A',Cache='RAM') - dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_A',Cache='RAM') - dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_B',Cache='RAM') + dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/JY_A',Cache='RAM') + dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_A',Cache='RAM') + dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_B',Cache='RAM') # 组织客户端数据集 client_datasets = [ @@ -61,7 +62,7 @@ def prepare_data(): return client_datasets, public_loader, server_test_loader # 客户端训练函数 -def client_train(client_model, server_model, loader): +def client_train(client_model, server_model, loader, logger): client_model.train() server_model.eval() @@ -123,13 +124,12 @@ def client_train(client_model, server_model, loader): avg_task = task_loss / len(loader) avg_distill = distill_loss / len(loader) epoch_acc = 100 * correct / total - print(f"\n{'='*40}") - print(f"Epoch {epoch+1} Summary:") - print(f"Average Loss: {avg_loss:.4f}") - print(f"Task Loss: {avg_task:.4f}") - print(f"Distill Loss: {avg_distill:.4f}") - print(f"Training Accuracy: {epoch_acc:.2f}%") - print(f"{'='*40}\n") + + logger.info(f"\n客户端训练 Epoch {epoch+1}/{CLIENT_EPOCHS}") + logger.info(f"平均损失: {avg_loss:.4f}") + logger.info(f"任务损失: {avg_task:.4f}") + logger.info(f"蒸馏损失: {avg_distill:.4f}") + logger.info(f"训练准确率: {epoch_acc:.2f}%") progress_bar.close() return client_model.state_dict() @@ -173,7 +173,7 @@ def server_aggregate(server_model, client_models, public_loader): total_loss += loss.item() # 服务器知识更新 -def server_update(server_model, client_models, public_loader): +def server_update(server_model, client_models, public_loader, logger): server_model.train() optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001) @@ -209,7 +209,7 @@ def server_update(server_model, client_models, public_loader): "Current Loss": f"{loss.item():.4f}" }) - print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n") + logger.info(f"服务器更新完成 | 平均损失: {total_loss/len(public_loader):.4f}") def test_model(model, test_loader): # 添加对DataLoader的支持 @@ -230,38 +230,48 @@ def test_model(model, test_loader): # 添加对DataLoader的支持 # 主训练流程 def main(): - # 初始化模型(保持不变) + # 创建运行目录 + run_dir = create_run_dir() + models_dir = os.path.join(run_dir, "models") + os.makedirs(models_dir, exist_ok=True) + + # 初始化日志 + logger = Logger(run_dir) + + # 初始化模型 global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device) client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)] - # 加载数据集 + # 加载数据 client_datasets, public_loader, server_test_loader = prepare_data() - round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round") - + # 记录关键配置 + logger.info(f"\n{'#'*30} 训练配置 {'#'*30}") + logger.info(f"设备: {device}") + logger.info(f"客户端数量: {NUM_CLIENTS}") + logger.info(f"训练轮次: {NUM_ROUNDS}") + logger.info(f"本地训练epoch数: {CLIENT_EPOCHS}") + logger.info(f"蒸馏温度: {TEMP}") + logger.info(f"运行目录: {run_dir}") + logger.info(f"{'#'*70}\n") + + # 主训练循环 for round in range(NUM_ROUNDS): - print(f"\n{'#'*50}") - print(f"Round {round+1}/{NUM_ROUNDS}") - print(f"{'#'*50}") + logger.info(f"\n{'#'*30} 联邦训练轮次 [{round+1}/{NUM_ROUNDS}] {'#'*30}") # 客户端选择 selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False) - print(f"Selected clients: {selected_clients}") + logger.info(f"选中客户端: {selected_clients}") # 客户端训练 client_params = [] for cid in selected_clients: - print(f"\nTraining Client {cid}") + logger.info(f"\n{'='*20} 客户端 {cid} 训练开始 {'='*20}") local_model = copy.deepcopy(client_models[cid]) - local_model.load_state_dict(client_models[cid].state_dict()) - - # 传入客户端的训练集 - updated_params = client_train( - local_model, - global_server_model, - client_datasets[cid]['train'] # 使用训练集 - ) + updated_params = client_train(local_model, global_server_model, + client_datasets[cid]['train'], logger) client_params.append(updated_params) + logger.info(f"\n{'='*20} 客户端 {cid} 训练完成 {'='*20}\n") # 模型聚合 global_client_params = aggregate(client_params) @@ -269,40 +279,40 @@ def main(): model.load_state_dict(global_client_params) # 服务器更新 - print("\nServer updating...") - server_update(global_server_model, client_models, public_loader) + logger.info("\n服务器知识蒸馏更新...") + server_update(global_server_model, client_models, public_loader, logger) - # 测试性能 + # 性能测试 server_acc = test_model(global_server_model, server_test_loader) client_accuracies = [ - test_model(client_models[i], - client_datasets[i]['test']) # 动态创建测试loader + test_model(client_models[i], client_datasets[i]['test']) for i in range(NUM_CLIENTS) ] - print(f"\nRound {round+1} Results:") - print(f"Server Accuracy: {server_acc:.2f}%") + # 记录本轮结果 + logger.info(f"\n本轮结果:") + logger.info(f"服务器测试准确率: {server_acc:.2f}%") for i, acc in enumerate(client_accuracies): - print(f"Client {i} Accuracy: {acc:.2f}%") - - round_progress.update(1) + logger.info(f"客户端 {i} 测试准确率: {acc:.2f}%") - # 保存模型 - torch.save(global_server_model.state_dict(), "server_model.pth") + # 保存最终模型 + torch.save(global_server_model.state_dict(), os.path.join(models_dir, "server_model.pth")) for i in range(NUM_CLIENTS): - torch.save(client_models[i].state_dict(), f"client{i}_model.pth") - + torch.save(client_models[i].state_dict(), + os.path.join(models_dir, f"client{i}_model.pth")) + logger.info(f"\n模型已保存至: {models_dir}") + # 最终测试 - print("\nFinal Evaluation:") + logger.info("\n最终测试结果:") server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device) - server_model.load_state_dict(torch.load("server_model.pth",weights_only=True)) - print(f"Server Accuracy: {test_model(server_model, server_test_loader):.2f}%") + server_model.load_state_dict(torch.load(os.path.join(models_dir, "server_model.pth"),weights_only=True)) + logger.info(f"服务器最终准确率: {test_model(server_model, server_test_loader):.2f}%") for i in range(NUM_CLIENTS): client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) - client_model.load_state_dict(torch.load(f"client{i}_model.pth",weights_only=True)) + client_model.load_state_dict(torch.load(os.path.join(models_dir, f"client{i}_model.pth"),weights_only=True)) test_loader = client_datasets[i]['test'] - print(f"Client {i} Accuracy: {test_model(client_model, test_loader):.2f}%") - + logger.info(f"客户端 {i} 最终准确率: {test_model(client_model, test_loader):.2f}%") + if __name__ == "__main__": main() \ No newline at end of file diff --git a/data_loader.py b/data_loader.py index b11b718..7283d13 100644 --- a/data_loader.py +++ b/data_loader.py @@ -1,12 +1,10 @@ import os -from logger import logger from PIL import Image import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms from tqdm import tqdm # 导入 tqdm -import logging class ImageClassificationDataset(Dataset): def __init__(self, root_dir, transform=None, Cache=False): @@ -19,7 +17,7 @@ class ImageClassificationDataset(Dataset): self.labels = [] self.Cache = Cache - logger.log("info", + print("info", "init the dataloader" ) @@ -38,7 +36,7 @@ class ImageClassificationDataset(Dataset): self.image_paths.append(img_path) self.labels.append(self.class_to_idx[cls_name]) except Exception as e: - logger.log("info", f"Read image error: {img_path} - {e}") + print("info", f"Read image error: {img_path} - {e}") def __len__(self): return len(self.labels) diff --git a/logger.py b/logger.py index 21257f4..d56245f 100644 --- a/logger.py +++ b/logger.py @@ -1,27 +1,27 @@ import logging import os from datetime import datetime -from config import config class Logger: - def __init__(self): - os.makedirs(config.log_dir, exist_ok=True) - log_file = os.path.join( - config.log_dir, - f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" - ) + def __init__(self, log_dir): + self.log_dir = log_dir + os.makedirs(log_dir, exist_ok=True) - self.logger = logging.getLogger(__name__) - self.logger.setLevel(config.log_level) + # 日志文件路径 + log_file = os.path.join(log_dir, "training.log") - # 文件输出 + # 创建日志记录器 + self.logger = logging.getLogger("FedKD") + self.logger.setLevel(logging.DEBUG) + + # 文件处理器(包含详细信息) file_handler = logging.FileHandler(log_file) file_formatter = logging.Formatter( "%(asctime)s - %(levelname)s - %(message)s" ) file_handler.setFormatter(file_formatter) - # 控制台输出 + # 控制台处理器(简洁输出) console_handler = logging.StreamHandler() console_formatter = logging.Formatter("%(message)s") console_handler.setFormatter(console_formatter) @@ -29,7 +29,20 @@ class Logger: self.logger.addHandler(file_handler) self.logger.addHandler(console_handler) - def log(self, level, message): - getattr(self.logger, level)(message) + def info(self, message): + self.logger.info(message) + + def debug(self, message): + self.logger.debug(message) + + def warning(self, message): + self.logger.warning(message) + + def error(self, message): + self.logger.error(message) -logger = Logger() \ No newline at end of file +def create_run_dir(base_dir="runs"): + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + run_dir = os.path.join(base_dir, timestamp) + os.makedirs(run_dir, exist_ok=True) + return run_dir \ No newline at end of file