Complete model logging output

yoiannis 2025-03-12 14:26:35 +08:00
parent a5ca9d04d7
commit 52587d8c26
3 changed files with 89 additions and 68 deletions

FED.py (104 changed lines)

@@ -9,6 +9,7 @@ import copy
 from tqdm import tqdm
 from data_loader import get_data_loader
+from logger import *
 from model.repvit import repvit_m1_1
 from model.mobilenetv3 import MobileNetV3
@@ -34,9 +35,9 @@ def prepare_data():
 ])
 # 加载所有数据集(训练、验证、测试)
-dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/JY_A',Cache='RAM')
-dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_A',Cache='RAM')
-dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_B',Cache='RAM')
+dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/JY_A',Cache='RAM')
+dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_A',Cache='RAM')
+dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_B',Cache='RAM')
 # 组织客户端数据集
 client_datasets = [
@@ -61,7 +62,7 @@ def prepare_data():
 return client_datasets, public_loader, server_test_loader
 # 客户端训练函数
-def client_train(client_model, server_model, loader):
+def client_train(client_model, server_model, loader, logger):
 client_model.train()
 server_model.eval()
@@ -123,13 +124,12 @@ def client_train(client_model, server_model, loader):
 avg_task = task_loss / len(loader)
 avg_distill = distill_loss / len(loader)
 epoch_acc = 100 * correct / total
-print(f"\n{'='*40}")
-print(f"Epoch {epoch+1} Summary:")
-print(f"Average Loss: {avg_loss:.4f}")
-print(f"Task Loss: {avg_task:.4f}")
-print(f"Distill Loss: {avg_distill:.4f}")
-print(f"Training Accuracy: {epoch_acc:.2f}%")
-print(f"{'='*40}\n")
+logger.info(f"\n客户端训练 Epoch {epoch+1}/{CLIENT_EPOCHS}")
+logger.info(f"平均损失: {avg_loss:.4f}")
+logger.info(f"任务损失: {avg_task:.4f}")
+logger.info(f"蒸馏损失: {avg_distill:.4f}")
+logger.info(f"训练准确率: {epoch_acc:.2f}%")
 progress_bar.close()
 return client_model.state_dict()
@@ -173,7 +173,7 @@ def server_aggregate(server_model, client_models, public_loader):
 total_loss += loss.item()
 # 服务器知识更新
-def server_update(server_model, client_models, public_loader):
+def server_update(server_model, client_models, public_loader, logger):
 server_model.train()
 optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
@@ -209,7 +209,7 @@ def server_update(server_model, client_models, public_loader):
 "Current Loss": f"{loss.item():.4f}"
 })
-print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
+logger.info(f"服务器更新完成 | 平均损失: {total_loss/len(public_loader):.4f}")
 def test_model(model, test_loader): # 添加对DataLoader的支持
@@ -230,38 +230,48 @@ def test_model(model, test_loader): # 添加对DataLoader的支持
 # 主训练流程
 def main():
-# 初始化模型(保持不变)
+# 创建运行目录
+run_dir = create_run_dir()
+models_dir = os.path.join(run_dir, "models")
+os.makedirs(models_dir, exist_ok=True)
+# 初始化日志
+logger = Logger(run_dir)
+# 初始化模型
 global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
 client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
-# 加载数据集
+# 加载数据
 client_datasets, public_loader, server_test_loader = prepare_data()
-round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
+# 记录关键配置
+logger.info(f"\n{'#'*30} 训练配置 {'#'*30}")
+logger.info(f"设备: {device}")
+logger.info(f"客户端数量: {NUM_CLIENTS}")
+logger.info(f"训练轮次: {NUM_ROUNDS}")
+logger.info(f"本地训练epoch数: {CLIENT_EPOCHS}")
+logger.info(f"蒸馏温度: {TEMP}")
+logger.info(f"运行目录: {run_dir}")
+logger.info(f"{'#'*70}\n")
+# 主训练循环
 for round in range(NUM_ROUNDS):
-print(f"\n{'#'*50}")
-print(f"Round {round+1}/{NUM_ROUNDS}")
-print(f"{'#'*50}")
+logger.info(f"\n{'#'*30} 联邦训练轮次 [{round+1}/{NUM_ROUNDS}] {'#'*30}")
 # 客户端选择
 selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
-print(f"Selected clients: {selected_clients}")
+logger.info(f"选中客户端: {selected_clients}")
 # 客户端训练
 client_params = []
 for cid in selected_clients:
-print(f"\nTraining Client {cid}")
+logger.info(f"\n{'='*20} 客户端 {cid} 训练开始 {'='*20}")
 local_model = copy.deepcopy(client_models[cid])
-local_model.load_state_dict(client_models[cid].state_dict())
-# 传入客户端的训练集
-updated_params = client_train(
-local_model,
-global_server_model,
-client_datasets[cid]['train'] # 使用训练集
-)
+updated_params = client_train(local_model, global_server_model,
+client_datasets[cid]['train'], logger)
 client_params.append(updated_params)
+logger.info(f"\n{'='*20} 客户端 {cid} 训练完成 {'='*20}\n")
 # 模型聚合
 global_client_params = aggregate(client_params)
@@ -269,40 +279,40 @@ def main():
 model.load_state_dict(global_client_params)
 # 服务器更新
-print("\nServer updating...")
-server_update(global_server_model, client_models, public_loader)
-# 测试性能
+logger.info("\n服务器知识蒸馏更新...")
+server_update(global_server_model, client_models, public_loader, logger)
+# 性能测试
 server_acc = test_model(global_server_model, server_test_loader)
 client_accuracies = [
-test_model(client_models[i],
-client_datasets[i]['test']) # 动态创建测试loader
+test_model(client_models[i], client_datasets[i]['test'])
 for i in range(NUM_CLIENTS)
 ]
-print(f"\nRound {round+1} Results:")
-print(f"Server Accuracy: {server_acc:.2f}%")
+# 记录本轮结果
+logger.info(f"\n本轮结果:")
+logger.info(f"服务器测试准确率: {server_acc:.2f}%")
 for i, acc in enumerate(client_accuracies):
-print(f"Client {i} Accuracy: {acc:.2f}%")
+logger.info(f"客户端 {i} 测试准确率: {acc:.2f}%")
-round_progress.update(1)
-# 保存模型
-torch.save(global_server_model.state_dict(), "server_model.pth")
+# 保存最终模型
+torch.save(global_server_model.state_dict(), os.path.join(models_dir, "server_model.pth"))
 for i in range(NUM_CLIENTS):
-torch.save(client_models[i].state_dict(), f"client{i}_model.pth")
+torch.save(client_models[i].state_dict(),
+os.path.join(models_dir, f"client{i}_model.pth"))
+logger.info(f"\n模型已保存至: {models_dir}")
 # 最终测试
-print("\nFinal Evaluation:")
+logger.info("\n最终测试结果:")
 server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
-server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
-print(f"Server Accuracy: {test_model(server_model, server_test_loader):.2f}%")
+server_model.load_state_dict(torch.load(os.path.join(models_dir, "server_model.pth"),weights_only=True))
+logger.info(f"服务器最终准确率: {test_model(server_model, server_test_loader):.2f}%")
 for i in range(NUM_CLIENTS):
 client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
-client_model.load_state_dict(torch.load(f"client{i}_model.pth",weights_only=True))
+client_model.load_state_dict(torch.load(os.path.join(models_dir, f"client{i}_model.pth"),weights_only=True))
 test_loader = client_datasets[i]['test']
-print(f"Client {i} Accuracy: {test_model(client_model, test_loader):.2f}%")
+logger.info(f"客户端 {i} 最终准确率: {test_model(client_model, test_loader):.2f}%")
 if __name__ == "__main__":
 main()

data_loader.py

@@ -1,12 +1,10 @@
 import os
-from logger import logger
 from PIL import Image
 import torch
 from torch.utils.data import Dataset, DataLoader
 from torchvision import transforms
 from tqdm import tqdm # 导入 tqdm
-import logging
 class ImageClassificationDataset(Dataset):
 def __init__(self, root_dir, transform=None, Cache=False):
@@ -19,7 +17,7 @@ class ImageClassificationDataset(Dataset):
 self.labels = []
 self.Cache = Cache
-logger.log("info",
+print("info",
 "init the dataloader"
 )
@@ -38,7 +36,7 @@ class ImageClassificationDataset(Dataset):
 self.image_paths.append(img_path)
 self.labels.append(self.class_to_idx[cls_name])
 except Exception as e:
-logger.log("info", f"Read image error: {img_path} - {e}")
+print("info", f"Read image error: {img_path} - {e}")
 def __len__(self):
 return len(self.labels)

logger.py

@@ -1,27 +1,27 @@
 import logging
 import os
 from datetime import datetime
-from config import config
 class Logger:
-def __init__(self):
-os.makedirs(config.log_dir, exist_ok=True)
-log_file = os.path.join(
-config.log_dir,
-f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
-)
-self.logger = logging.getLogger(__name__)
-self.logger.setLevel(config.log_level)
-# 文件输出
+def __init__(self, log_dir):
+self.log_dir = log_dir
+os.makedirs(log_dir, exist_ok=True)
+# 日志文件路径
+log_file = os.path.join(log_dir, "training.log")
+# 创建日志记录器
+self.logger = logging.getLogger("FedKD")
+self.logger.setLevel(logging.DEBUG)
+# 文件处理器(包含详细信息)
 file_handler = logging.FileHandler(log_file)
 file_formatter = logging.Formatter(
 "%(asctime)s - %(levelname)s - %(message)s"
 )
 file_handler.setFormatter(file_formatter)
-# 控制台输出
+# 控制台处理器(简洁输出)
 console_handler = logging.StreamHandler()
 console_formatter = logging.Formatter("%(message)s")
 console_handler.setFormatter(console_formatter)
@@ -29,7 +29,20 @@ class Logger:
 self.logger.addHandler(file_handler)
 self.logger.addHandler(console_handler)
-def log(self, level, message):
-getattr(self.logger, level)(message)
-logger = Logger()
+def info(self, message):
+self.logger.info(message)
+def debug(self, message):
+self.logger.debug(message)
+def warning(self, message):
+self.logger.warning(message)
+def error(self, message):
+self.logger.error(message)
+def create_run_dir(base_dir="runs"):
+timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+run_dir = os.path.join(base_dir, timestamp)
+os.makedirs(run_dir, exist_ok=True)
+return run_dir
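
For reference, a minimal usage sketch of the refactored logging utilities, mirroring how main() in FED.py wires them up after this change; the log message strings below are illustrative only, everything else comes from the diff above:

import os
from logger import Logger, create_run_dir

# create_run_dir() makes a timestamped folder under "runs/" and returns its path
run_dir = create_run_dir()
models_dir = os.path.join(run_dir, "models")  # FED.py keeps checkpoints next to the log
os.makedirs(models_dir, exist_ok=True)

# Logger(run_dir) writes <run_dir>/training.log with timestamp/level formatting
# and echoes the bare message text to the console via a StreamHandler
logger = Logger(run_dir)
logger.info("training started")        # hypothetical example messages
logger.warning("falling back to CPU")
logger.error("checkpoint not found")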