完成模型日志输出

This commit is contained in:
yoiannis 2025-03-12 14:26:35 +08:00
parent a5ca9d04d7
commit 52587d8c26
3 changed files with 89 additions and 68 deletions

110
FED.py
View File

@ -9,6 +9,7 @@ import copy
from tqdm import tqdm
from data_loader import get_data_loader
from logger import *
from model.repvit import repvit_m1_1
from model.mobilenetv3 import MobileNetV3
@ -34,9 +35,9 @@ def prepare_data():
])
# 加载所有数据集(训练、验证、测试)
dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/JY_A',Cache='RAM')
dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_A',Cache='RAM')
dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_B',Cache='RAM')
dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/JY_A',Cache='RAM')
dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_A',Cache='RAM')
dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_B',Cache='RAM')
# 组织客户端数据集
client_datasets = [
@ -61,7 +62,7 @@ def prepare_data():
return client_datasets, public_loader, server_test_loader
# 客户端训练函数
def client_train(client_model, server_model, loader):
def client_train(client_model, server_model, loader, logger):
client_model.train()
server_model.eval()
@ -123,13 +124,12 @@ def client_train(client_model, server_model, loader):
avg_task = task_loss / len(loader)
avg_distill = distill_loss / len(loader)
epoch_acc = 100 * correct / total
print(f"\n{'='*40}")
print(f"Epoch {epoch+1} Summary:")
print(f"Average Loss: {avg_loss:.4f}")
print(f"Task Loss: {avg_task:.4f}")
print(f"Distill Loss: {avg_distill:.4f}")
print(f"Training Accuracy: {epoch_acc:.2f}%")
print(f"{'='*40}\n")
logger.info(f"\n客户端训练 Epoch {epoch+1}/{CLIENT_EPOCHS}")
logger.info(f"平均损失: {avg_loss:.4f}")
logger.info(f"任务损失: {avg_task:.4f}")
logger.info(f"蒸馏损失: {avg_distill:.4f}")
logger.info(f"训练准确率: {epoch_acc:.2f}%")
progress_bar.close()
return client_model.state_dict()
@ -173,7 +173,7 @@ def server_aggregate(server_model, client_models, public_loader):
total_loss += loss.item()
# 服务器知识更新
def server_update(server_model, client_models, public_loader):
def server_update(server_model, client_models, public_loader, logger):
server_model.train()
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
@ -209,7 +209,7 @@ def server_update(server_model, client_models, public_loader):
"Current Loss": f"{loss.item():.4f}"
})
print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
logger.info(f"服务器更新完成 | 平均损失: {total_loss/len(public_loader):.4f}")
def test_model(model, test_loader): # 添加对DataLoader的支持
@ -230,38 +230,48 @@ def test_model(model, test_loader): # 添加对DataLoader的支持
# 主训练流程
def main():
# 初始化模型(保持不变)
# 创建运行目录
run_dir = create_run_dir()
models_dir = os.path.join(run_dir, "models")
os.makedirs(models_dir, exist_ok=True)
# 初始化日志
logger = Logger(run_dir)
# 初始化模型
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
# 加载数据集
# 加载数据
client_datasets, public_loader, server_test_loader = prepare_data()
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
# 记录关键配置
logger.info(f"\n{'#'*30} 训练配置 {'#'*30}")
logger.info(f"设备: {device}")
logger.info(f"客户端数量: {NUM_CLIENTS}")
logger.info(f"训练轮次: {NUM_ROUNDS}")
logger.info(f"本地训练epoch数: {CLIENT_EPOCHS}")
logger.info(f"蒸馏温度: {TEMP}")
logger.info(f"运行目录: {run_dir}")
logger.info(f"{'#'*70}\n")
# 主训练循环
for round in range(NUM_ROUNDS):
print(f"\n{'#'*50}")
print(f"Round {round+1}/{NUM_ROUNDS}")
print(f"{'#'*50}")
logger.info(f"\n{'#'*30} 联邦训练轮次 [{round+1}/{NUM_ROUNDS}] {'#'*30}")
# 客户端选择
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
print(f"Selected clients: {selected_clients}")
logger.info(f"选中客户端: {selected_clients}")
# 客户端训练
client_params = []
for cid in selected_clients:
print(f"\nTraining Client {cid}")
logger.info(f"\n{'='*20} 客户端 {cid} 训练开始 {'='*20}")
local_model = copy.deepcopy(client_models[cid])
local_model.load_state_dict(client_models[cid].state_dict())
# 传入客户端的训练集
updated_params = client_train(
local_model,
global_server_model,
client_datasets[cid]['train'] # 使用训练集
)
updated_params = client_train(local_model, global_server_model,
client_datasets[cid]['train'], logger)
client_params.append(updated_params)
logger.info(f"\n{'='*20} 客户端 {cid} 训练完成 {'='*20}\n")
# 模型聚合
global_client_params = aggregate(client_params)
@ -269,40 +279,40 @@ def main():
model.load_state_dict(global_client_params)
# 服务器更新
print("\nServer updating...")
server_update(global_server_model, client_models, public_loader)
logger.info("\n服务器知识蒸馏更新...")
server_update(global_server_model, client_models, public_loader, logger)
# 测试性能
# 性能测试
server_acc = test_model(global_server_model, server_test_loader)
client_accuracies = [
test_model(client_models[i],
client_datasets[i]['test']) # 动态创建测试loader
test_model(client_models[i], client_datasets[i]['test'])
for i in range(NUM_CLIENTS)
]
print(f"\nRound {round+1} Results:")
print(f"Server Accuracy: {server_acc:.2f}%")
# 记录本轮结果
logger.info(f"\n本轮结果:")
logger.info(f"服务器测试准确率: {server_acc:.2f}%")
for i, acc in enumerate(client_accuracies):
print(f"Client {i} Accuracy: {acc:.2f}%")
round_progress.update(1)
logger.info(f"客户端 {i} 测试准确率: {acc:.2f}%")
# 保存模型
torch.save(global_server_model.state_dict(), "server_model.pth")
# 保存最终模型
torch.save(global_server_model.state_dict(), os.path.join(models_dir, "server_model.pth"))
for i in range(NUM_CLIENTS):
torch.save(client_models[i].state_dict(), f"client{i}_model.pth")
torch.save(client_models[i].state_dict(),
os.path.join(models_dir, f"client{i}_model.pth"))
logger.info(f"\n模型已保存至: {models_dir}")
# 最终测试
print("\nFinal Evaluation:")
logger.info("\n最终测试结果:")
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
print(f"Server Accuracy: {test_model(server_model, server_test_loader):.2f}%")
server_model.load_state_dict(torch.load(os.path.join(models_dir, "server_model.pth"),weights_only=True))
logger.info(f"服务器最终准确率: {test_model(server_model, server_test_loader):.2f}%")
for i in range(NUM_CLIENTS):
client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
client_model.load_state_dict(torch.load(f"client{i}_model.pth",weights_only=True))
client_model.load_state_dict(torch.load(os.path.join(models_dir, f"client{i}_model.pth"),weights_only=True))
test_loader = client_datasets[i]['test']
print(f"Client {i} Accuracy: {test_model(client_model, test_loader):.2f}%")
logger.info(f"客户端 {i} 最终准确率: {test_model(client_model, test_loader):.2f}%")
if __name__ == "__main__":
main()

View File

@ -1,12 +1,10 @@
import os
from logger import logger
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm # 导入 tqdm
import logging
class ImageClassificationDataset(Dataset):
def __init__(self, root_dir, transform=None, Cache=False):
@ -19,7 +17,7 @@ class ImageClassificationDataset(Dataset):
self.labels = []
self.Cache = Cache
logger.log("info",
print("info",
"init the dataloader"
)
@ -38,7 +36,7 @@ class ImageClassificationDataset(Dataset):
self.image_paths.append(img_path)
self.labels.append(self.class_to_idx[cls_name])
except Exception as e:
logger.log("info", f"Read image error: {img_path} - {e}")
print("info", f"Read image error: {img_path} - {e}")
def __len__(self):
return len(self.labels)

View File

@ -1,27 +1,27 @@
import logging
import os
from datetime import datetime
from config import config
class Logger:
def __init__(self):
os.makedirs(config.log_dir, exist_ok=True)
log_file = os.path.join(
config.log_dir,
f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
)
def __init__(self, log_dir):
self.log_dir = log_dir
os.makedirs(log_dir, exist_ok=True)
self.logger = logging.getLogger(__name__)
self.logger.setLevel(config.log_level)
# 日志文件路径
log_file = os.path.join(log_dir, "training.log")
# 文件输出
# 创建日志记录器
self.logger = logging.getLogger("FedKD")
self.logger.setLevel(logging.DEBUG)
# 文件处理器(包含详细信息)
file_handler = logging.FileHandler(log_file)
file_formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(file_formatter)
# 控制台输出
# 控制台处理器(简洁输出)
console_handler = logging.StreamHandler()
console_formatter = logging.Formatter("%(message)s")
console_handler.setFormatter(console_formatter)
@ -29,7 +29,20 @@ class Logger:
self.logger.addHandler(file_handler)
self.logger.addHandler(console_handler)
def log(self, level, message):
getattr(self.logger, level)(message)
def info(self, message):
self.logger.info(message)
def debug(self, message):
self.logger.debug(message)
def warning(self, message):
self.logger.warning(message)
def error(self, message):
self.logger.error(message)
logger = Logger()
def create_run_dir(base_dir="runs"):
    """Create a new, uniquely named run directory and return its path.

    The directory is named after the current local time
    (``YYYYMMDD_HHMMSS``) and placed under *base_dir*; missing parent
    directories are created as needed.

    The timestamp only has one-second resolution, so two runs started
    within the same second would otherwise end up sharing one directory
    (and therefore the same log file and saved models).  To keep every
    run isolated, the directory is created exclusively and a numeric
    suffix (``_1``, ``_2``, ...) is appended when the plain timestamp
    name is already taken.

    Args:
        base_dir: Parent directory that holds all run directories.

    Returns:
        Path of the directory that was created for this run.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(base_dir, timestamp)

    attempt = 0
    while True:
        candidate = run_dir if attempt == 0 else f"{run_dir}_{attempt}"
        try:
            # No exist_ok: an existing leaf must raise so that a directory
            # belonging to another run is never silently reused.
            os.makedirs(candidate)
        except FileExistsError:
            attempt += 1
            continue
        return candidate