完成模型日志输出
This commit is contained in:
parent
a5ca9d04d7
commit
52587d8c26
110
FED.py
110
FED.py
@ -9,6 +9,7 @@ import copy
|
||||
from tqdm import tqdm
|
||||
|
||||
from data_loader import get_data_loader
|
||||
from logger import *
|
||||
from model.repvit import repvit_m1_1
|
||||
from model.mobilenetv3 import MobileNetV3
|
||||
|
||||
@ -34,9 +35,9 @@ def prepare_data():
|
||||
])
|
||||
|
||||
# 加载所有数据集(训练、验证、测试)
|
||||
dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/JY_A',Cache='RAM')
|
||||
dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_A',Cache='RAM')
|
||||
dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_B',Cache='RAM')
|
||||
dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/JY_A',Cache='RAM')
|
||||
dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_A',Cache='RAM')
|
||||
dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_B',Cache='RAM')
|
||||
|
||||
# 组织客户端数据集
|
||||
client_datasets = [
|
||||
@ -61,7 +62,7 @@ def prepare_data():
|
||||
return client_datasets, public_loader, server_test_loader
|
||||
|
||||
# 客户端训练函数
|
||||
def client_train(client_model, server_model, loader):
|
||||
def client_train(client_model, server_model, loader, logger):
|
||||
client_model.train()
|
||||
server_model.eval()
|
||||
|
||||
@ -123,13 +124,12 @@ def client_train(client_model, server_model, loader):
|
||||
avg_task = task_loss / len(loader)
|
||||
avg_distill = distill_loss / len(loader)
|
||||
epoch_acc = 100 * correct / total
|
||||
print(f"\n{'='*40}")
|
||||
print(f"Epoch {epoch+1} Summary:")
|
||||
print(f"Average Loss: {avg_loss:.4f}")
|
||||
print(f"Task Loss: {avg_task:.4f}")
|
||||
print(f"Distill Loss: {avg_distill:.4f}")
|
||||
print(f"Training Accuracy: {epoch_acc:.2f}%")
|
||||
print(f"{'='*40}\n")
|
||||
|
||||
logger.info(f"\n客户端训练 Epoch {epoch+1}/{CLIENT_EPOCHS}")
|
||||
logger.info(f"平均损失: {avg_loss:.4f}")
|
||||
logger.info(f"任务损失: {avg_task:.4f}")
|
||||
logger.info(f"蒸馏损失: {avg_distill:.4f}")
|
||||
logger.info(f"训练准确率: {epoch_acc:.2f}%")
|
||||
|
||||
progress_bar.close()
|
||||
return client_model.state_dict()
|
||||
@ -173,7 +173,7 @@ def server_aggregate(server_model, client_models, public_loader):
|
||||
total_loss += loss.item()
|
||||
|
||||
# 服务器知识更新
|
||||
def server_update(server_model, client_models, public_loader):
|
||||
def server_update(server_model, client_models, public_loader, logger):
|
||||
server_model.train()
|
||||
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
|
||||
|
||||
@ -209,7 +209,7 @@ def server_update(server_model, client_models, public_loader):
|
||||
"Current Loss": f"{loss.item():.4f}"
|
||||
})
|
||||
|
||||
print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
|
||||
logger.info(f"服务器更新完成 | 平均损失: {total_loss/len(public_loader):.4f}")
|
||||
|
||||
|
||||
def test_model(model, test_loader): # 添加对DataLoader的支持
|
||||
@ -230,38 +230,48 @@ def test_model(model, test_loader): # 添加对DataLoader的支持
|
||||
|
||||
# 主训练流程
|
||||
def main():
|
||||
# 初始化模型(保持不变)
|
||||
# 创建运行目录
|
||||
run_dir = create_run_dir()
|
||||
models_dir = os.path.join(run_dir, "models")
|
||||
os.makedirs(models_dir, exist_ok=True)
|
||||
|
||||
# 初始化日志
|
||||
logger = Logger(run_dir)
|
||||
|
||||
# 初始化模型
|
||||
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
||||
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
|
||||
|
||||
# 加载数据集
|
||||
# 加载数据
|
||||
client_datasets, public_loader, server_test_loader = prepare_data()
|
||||
|
||||
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
|
||||
|
||||
# 记录关键配置
|
||||
logger.info(f"\n{'#'*30} 训练配置 {'#'*30}")
|
||||
logger.info(f"设备: {device}")
|
||||
logger.info(f"客户端数量: {NUM_CLIENTS}")
|
||||
logger.info(f"训练轮次: {NUM_ROUNDS}")
|
||||
logger.info(f"本地训练epoch数: {CLIENT_EPOCHS}")
|
||||
logger.info(f"蒸馏温度: {TEMP}")
|
||||
logger.info(f"运行目录: {run_dir}")
|
||||
logger.info(f"{'#'*70}\n")
|
||||
|
||||
# 主训练循环
|
||||
for round in range(NUM_ROUNDS):
|
||||
print(f"\n{'#'*50}")
|
||||
print(f"Round {round+1}/{NUM_ROUNDS}")
|
||||
print(f"{'#'*50}")
|
||||
logger.info(f"\n{'#'*30} 联邦训练轮次 [{round+1}/{NUM_ROUNDS}] {'#'*30}")
|
||||
|
||||
# 客户端选择
|
||||
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
|
||||
print(f"Selected clients: {selected_clients}")
|
||||
logger.info(f"选中客户端: {selected_clients}")
|
||||
|
||||
# 客户端训练
|
||||
client_params = []
|
||||
for cid in selected_clients:
|
||||
print(f"\nTraining Client {cid}")
|
||||
logger.info(f"\n{'='*20} 客户端 {cid} 训练开始 {'='*20}")
|
||||
local_model = copy.deepcopy(client_models[cid])
|
||||
local_model.load_state_dict(client_models[cid].state_dict())
|
||||
|
||||
# 传入客户端的训练集
|
||||
updated_params = client_train(
|
||||
local_model,
|
||||
global_server_model,
|
||||
client_datasets[cid]['train'] # 使用训练集
|
||||
)
|
||||
updated_params = client_train(local_model, global_server_model,
|
||||
client_datasets[cid]['train'], logger)
|
||||
client_params.append(updated_params)
|
||||
logger.info(f"\n{'='*20} 客户端 {cid} 训练完成 {'='*20}\n")
|
||||
|
||||
# 模型聚合
|
||||
global_client_params = aggregate(client_params)
|
||||
@ -269,40 +279,40 @@ def main():
|
||||
model.load_state_dict(global_client_params)
|
||||
|
||||
# 服务器更新
|
||||
print("\nServer updating...")
|
||||
server_update(global_server_model, client_models, public_loader)
|
||||
logger.info("\n服务器知识蒸馏更新...")
|
||||
server_update(global_server_model, client_models, public_loader, logger)
|
||||
|
||||
# 测试性能
|
||||
# 性能测试
|
||||
server_acc = test_model(global_server_model, server_test_loader)
|
||||
client_accuracies = [
|
||||
test_model(client_models[i],
|
||||
client_datasets[i]['test']) # 动态创建测试loader
|
||||
test_model(client_models[i], client_datasets[i]['test'])
|
||||
for i in range(NUM_CLIENTS)
|
||||
]
|
||||
|
||||
print(f"\nRound {round+1} Results:")
|
||||
print(f"Server Accuracy: {server_acc:.2f}%")
|
||||
# 记录本轮结果
|
||||
logger.info(f"\n本轮结果:")
|
||||
logger.info(f"服务器测试准确率: {server_acc:.2f}%")
|
||||
for i, acc in enumerate(client_accuracies):
|
||||
print(f"Client {i} Accuracy: {acc:.2f}%")
|
||||
|
||||
round_progress.update(1)
|
||||
logger.info(f"客户端 {i} 测试准确率: {acc:.2f}%")
|
||||
|
||||
# 保存模型
|
||||
torch.save(global_server_model.state_dict(), "server_model.pth")
|
||||
# 保存最终模型
|
||||
torch.save(global_server_model.state_dict(), os.path.join(models_dir, "server_model.pth"))
|
||||
for i in range(NUM_CLIENTS):
|
||||
torch.save(client_models[i].state_dict(), f"client{i}_model.pth")
|
||||
|
||||
torch.save(client_models[i].state_dict(),
|
||||
os.path.join(models_dir, f"client{i}_model.pth"))
|
||||
logger.info(f"\n模型已保存至: {models_dir}")
|
||||
|
||||
# 最终测试
|
||||
print("\nFinal Evaluation:")
|
||||
logger.info("\n最终测试结果:")
|
||||
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
||||
server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
|
||||
print(f"Server Accuracy: {test_model(server_model, server_test_loader):.2f}%")
|
||||
server_model.load_state_dict(torch.load(os.path.join(models_dir, "server_model.pth"),weights_only=True))
|
||||
logger.info(f"服务器最终准确率: {test_model(server_model, server_test_loader):.2f}%")
|
||||
|
||||
for i in range(NUM_CLIENTS):
|
||||
client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
|
||||
client_model.load_state_dict(torch.load(f"client{i}_model.pth",weights_only=True))
|
||||
client_model.load_state_dict(torch.load(os.path.join(models_dir, f"client{i}_model.pth"),weights_only=True))
|
||||
test_loader = client_datasets[i]['test']
|
||||
print(f"Client {i} Accuracy: {test_model(client_model, test_loader):.2f}%")
|
||||
|
||||
logger.info(f"客户端 {i} 最终准确率: {test_model(client_model, test_loader):.2f}%")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,12 +1,10 @@
|
||||
import os
|
||||
from logger import logger
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
|
||||
from tqdm import tqdm # 导入 tqdm
|
||||
import logging
|
||||
|
||||
class ImageClassificationDataset(Dataset):
|
||||
def __init__(self, root_dir, transform=None, Cache=False):
|
||||
@ -19,7 +17,7 @@ class ImageClassificationDataset(Dataset):
|
||||
self.labels = []
|
||||
self.Cache = Cache
|
||||
|
||||
logger.log("info",
|
||||
print("info",
|
||||
"init the dataloader"
|
||||
)
|
||||
|
||||
@ -38,7 +36,7 @@ class ImageClassificationDataset(Dataset):
|
||||
self.image_paths.append(img_path)
|
||||
self.labels.append(self.class_to_idx[cls_name])
|
||||
except Exception as e:
|
||||
logger.log("info", f"Read image error: {img_path} - {e}")
|
||||
print("info", f"Read image error: {img_path} - {e}")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.labels)
|
||||
|
41
logger.py
41
logger.py
@ -1,27 +1,27 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from config import config
|
||||
|
||||
class Logger:
    """Write training messages to ``<log_dir>/training.log`` and to the console.

    All instances share the named ``"FedKD"`` logger from the ``logging``
    module. Constructing a new ``Logger`` therefore *replaces* the handlers
    installed by any previous instance instead of stacking on top of them,
    so every message is emitted exactly once and only to the most recent
    log directory.

    Args:
        log_dir: Directory that will contain ``training.log``; created if it
            does not exist.
    """

    def __init__(self, log_dir):
        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)

        # Log file path
        log_file = os.path.join(log_dir, "training.log")

        # Create the logger. getLogger() returns the same object for the same
        # name, so handlers left over from an earlier Logger instance must be
        # removed (and closed) first; otherwise every message is duplicated
        # and the old log file handle is leaked.
        self.logger = logging.getLogger("FedKD")
        self.logger.setLevel(logging.DEBUG)
        for stale_handler in list(self.logger.handlers):
            self.logger.removeHandler(stale_handler)
            stale_handler.close()

        # File handler (detailed records). The training messages contain
        # non-ASCII text, so pin the encoding rather than relying on the
        # platform default.
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_formatter = logging.Formatter(
            "%(asctime)s - %(levelname)s - %(message)s"
        )
        file_handler.setFormatter(file_formatter)

        # Console handler (message text only)
        console_handler = logging.StreamHandler()
        console_formatter = logging.Formatter("%(message)s")
        console_handler.setFormatter(console_formatter)

        # NOTE(review): the diff this was taken from hides one unchanged line
        # right before the addHandler calls; it is assumed to be blank —
        # confirm against the full file.
        self.logger.addHandler(file_handler)
        self.logger.addHandler(console_handler)

    def info(self, message):
        """Log ``message`` at INFO level."""
        self.logger.info(message)

    def debug(self, message):
        """Log ``message`` at DEBUG level."""
        self.logger.debug(message)

    def warning(self, message):
        """Log ``message`` at WARNING level."""
        self.logger.warning(message)

    def error(self, message):
        """Log ``message`` at ERROR level."""
        self.logger.error(message)
|
||||
|
||||
logger = Logger()
|
||||
def create_run_dir(base_dir="runs"):
    """Create and return a fresh, timestamped run directory under ``base_dir``.

    The directory is named after the current local time formatted as
    ``YYYYmmdd_HHMMSS``. If that name is already taken (for example two runs
    started within the same second), a numeric suffix ``_1``, ``_2``, ... is
    appended so that separate runs never share a directory and overwrite each
    other's logs or saved models.

    Args:
        base_dir: Parent directory for all runs; created if it is missing.

    Returns:
        Path of the newly created run directory.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(base_dir, timestamp)

    candidate = run_dir
    attempt = 0
    while True:
        try:
            # Deliberately no exist_ok: an existing leaf directory means
            # another run already owns this name.
            os.makedirs(candidate)
        except FileExistsError:
            attempt += 1
            candidate = f"{run_dir}_{attempt}"
        else:
            return candidate
|
Loading…
Reference in New Issue
Block a user