完成模型日志输出
This commit is contained in:
parent
a5ca9d04d7
commit
52587d8c26
104
FED.py
104
FED.py
@ -9,6 +9,7 @@ import copy
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from data_loader import get_data_loader
|
from data_loader import get_data_loader
|
||||||
|
from logger import *
|
||||||
from model.repvit import repvit_m1_1
|
from model.repvit import repvit_m1_1
|
||||||
from model.mobilenetv3 import MobileNetV3
|
from model.mobilenetv3 import MobileNetV3
|
||||||
|
|
||||||
@ -34,9 +35,9 @@ def prepare_data():
|
|||||||
])
|
])
|
||||||
|
|
||||||
# 加载所有数据集(训练、验证、测试)
|
# 加载所有数据集(训练、验证、测试)
|
||||||
dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/JY_A',Cache='RAM')
|
dataset_A_train,dataset_A_val,dataset_A_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/JY_A',Cache='RAM')
|
||||||
dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_A',Cache='RAM')
|
dataset_B_train,dataset_B_val,dataset_B_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_A',Cache='RAM')
|
||||||
dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/03.TA_EC_FD3/ZY_B',Cache='RAM')
|
dataset_C_train,dataset_C_val,dataset_C_test = get_data_loader(root_dir='/home/yoiannis/deep_learning/dataset/debug/ZY_B',Cache='RAM')
|
||||||
|
|
||||||
# 组织客户端数据集
|
# 组织客户端数据集
|
||||||
client_datasets = [
|
client_datasets = [
|
||||||
@ -61,7 +62,7 @@ def prepare_data():
|
|||||||
return client_datasets, public_loader, server_test_loader
|
return client_datasets, public_loader, server_test_loader
|
||||||
|
|
||||||
# 客户端训练函数
|
# 客户端训练函数
|
||||||
def client_train(client_model, server_model, loader):
|
def client_train(client_model, server_model, loader, logger):
|
||||||
client_model.train()
|
client_model.train()
|
||||||
server_model.eval()
|
server_model.eval()
|
||||||
|
|
||||||
@ -123,13 +124,12 @@ def client_train(client_model, server_model, loader):
|
|||||||
avg_task = task_loss / len(loader)
|
avg_task = task_loss / len(loader)
|
||||||
avg_distill = distill_loss / len(loader)
|
avg_distill = distill_loss / len(loader)
|
||||||
epoch_acc = 100 * correct / total
|
epoch_acc = 100 * correct / total
|
||||||
print(f"\n{'='*40}")
|
|
||||||
print(f"Epoch {epoch+1} Summary:")
|
logger.info(f"\n客户端训练 Epoch {epoch+1}/{CLIENT_EPOCHS}")
|
||||||
print(f"Average Loss: {avg_loss:.4f}")
|
logger.info(f"平均损失: {avg_loss:.4f}")
|
||||||
print(f"Task Loss: {avg_task:.4f}")
|
logger.info(f"任务损失: {avg_task:.4f}")
|
||||||
print(f"Distill Loss: {avg_distill:.4f}")
|
logger.info(f"蒸馏损失: {avg_distill:.4f}")
|
||||||
print(f"Training Accuracy: {epoch_acc:.2f}%")
|
logger.info(f"训练准确率: {epoch_acc:.2f}%")
|
||||||
print(f"{'='*40}\n")
|
|
||||||
|
|
||||||
progress_bar.close()
|
progress_bar.close()
|
||||||
return client_model.state_dict()
|
return client_model.state_dict()
|
||||||
@ -173,7 +173,7 @@ def server_aggregate(server_model, client_models, public_loader):
|
|||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
|
|
||||||
# 服务器知识更新
|
# 服务器知识更新
|
||||||
def server_update(server_model, client_models, public_loader):
|
def server_update(server_model, client_models, public_loader, logger):
|
||||||
server_model.train()
|
server_model.train()
|
||||||
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
|
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
|
||||||
|
|
||||||
@ -209,7 +209,7 @@ def server_update(server_model, client_models, public_loader):
|
|||||||
"Current Loss": f"{loss.item():.4f}"
|
"Current Loss": f"{loss.item():.4f}"
|
||||||
})
|
})
|
||||||
|
|
||||||
print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
|
logger.info(f"服务器更新完成 | 平均损失: {total_loss/len(public_loader):.4f}")
|
||||||
|
|
||||||
|
|
||||||
def test_model(model, test_loader): # 添加对DataLoader的支持
|
def test_model(model, test_loader): # 添加对DataLoader的支持
|
||||||
@ -230,38 +230,48 @@ def test_model(model, test_loader): # 添加对DataLoader的支持
|
|||||||
|
|
||||||
# 主训练流程
|
# 主训练流程
|
||||||
def main():
|
def main():
|
||||||
# 初始化模型(保持不变)
|
# 创建运行目录
|
||||||
|
run_dir = create_run_dir()
|
||||||
|
models_dir = os.path.join(run_dir, "models")
|
||||||
|
os.makedirs(models_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 初始化日志
|
||||||
|
logger = Logger(run_dir)
|
||||||
|
|
||||||
|
# 初始化模型
|
||||||
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
||||||
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
|
client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
|
||||||
|
|
||||||
# 加载数据集
|
# 加载数据
|
||||||
client_datasets, public_loader, server_test_loader = prepare_data()
|
client_datasets, public_loader, server_test_loader = prepare_data()
|
||||||
|
|
||||||
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
|
# 记录关键配置
|
||||||
|
logger.info(f"\n{'#'*30} 训练配置 {'#'*30}")
|
||||||
|
logger.info(f"设备: {device}")
|
||||||
|
logger.info(f"客户端数量: {NUM_CLIENTS}")
|
||||||
|
logger.info(f"训练轮次: {NUM_ROUNDS}")
|
||||||
|
logger.info(f"本地训练epoch数: {CLIENT_EPOCHS}")
|
||||||
|
logger.info(f"蒸馏温度: {TEMP}")
|
||||||
|
logger.info(f"运行目录: {run_dir}")
|
||||||
|
logger.info(f"{'#'*70}\n")
|
||||||
|
|
||||||
|
# 主训练循环
|
||||||
for round in range(NUM_ROUNDS):
|
for round in range(NUM_ROUNDS):
|
||||||
print(f"\n{'#'*50}")
|
logger.info(f"\n{'#'*30} 联邦训练轮次 [{round+1}/{NUM_ROUNDS}] {'#'*30}")
|
||||||
print(f"Round {round+1}/{NUM_ROUNDS}")
|
|
||||||
print(f"{'#'*50}")
|
|
||||||
|
|
||||||
# 客户端选择
|
# 客户端选择
|
||||||
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
|
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
|
||||||
print(f"Selected clients: {selected_clients}")
|
logger.info(f"选中客户端: {selected_clients}")
|
||||||
|
|
||||||
# 客户端训练
|
# 客户端训练
|
||||||
client_params = []
|
client_params = []
|
||||||
for cid in selected_clients:
|
for cid in selected_clients:
|
||||||
print(f"\nTraining Client {cid}")
|
logger.info(f"\n{'='*20} 客户端 {cid} 训练开始 {'='*20}")
|
||||||
local_model = copy.deepcopy(client_models[cid])
|
local_model = copy.deepcopy(client_models[cid])
|
||||||
local_model.load_state_dict(client_models[cid].state_dict())
|
updated_params = client_train(local_model, global_server_model,
|
||||||
|
client_datasets[cid]['train'], logger)
|
||||||
# 传入客户端的训练集
|
|
||||||
updated_params = client_train(
|
|
||||||
local_model,
|
|
||||||
global_server_model,
|
|
||||||
client_datasets[cid]['train'] # 使用训练集
|
|
||||||
)
|
|
||||||
client_params.append(updated_params)
|
client_params.append(updated_params)
|
||||||
|
logger.info(f"\n{'='*20} 客户端 {cid} 训练完成 {'='*20}\n")
|
||||||
|
|
||||||
# 模型聚合
|
# 模型聚合
|
||||||
global_client_params = aggregate(client_params)
|
global_client_params = aggregate(client_params)
|
||||||
@ -269,40 +279,40 @@ def main():
|
|||||||
model.load_state_dict(global_client_params)
|
model.load_state_dict(global_client_params)
|
||||||
|
|
||||||
# 服务器更新
|
# 服务器更新
|
||||||
print("\nServer updating...")
|
logger.info("\n服务器知识蒸馏更新...")
|
||||||
server_update(global_server_model, client_models, public_loader)
|
server_update(global_server_model, client_models, public_loader, logger)
|
||||||
|
|
||||||
# 测试性能
|
# 性能测试
|
||||||
server_acc = test_model(global_server_model, server_test_loader)
|
server_acc = test_model(global_server_model, server_test_loader)
|
||||||
client_accuracies = [
|
client_accuracies = [
|
||||||
test_model(client_models[i],
|
test_model(client_models[i], client_datasets[i]['test'])
|
||||||
client_datasets[i]['test']) # 动态创建测试loader
|
|
||||||
for i in range(NUM_CLIENTS)
|
for i in range(NUM_CLIENTS)
|
||||||
]
|
]
|
||||||
|
|
||||||
print(f"\nRound {round+1} Results:")
|
# 记录本轮结果
|
||||||
print(f"Server Accuracy: {server_acc:.2f}%")
|
logger.info(f"\n本轮结果:")
|
||||||
|
logger.info(f"服务器测试准确率: {server_acc:.2f}%")
|
||||||
for i, acc in enumerate(client_accuracies):
|
for i, acc in enumerate(client_accuracies):
|
||||||
print(f"Client {i} Accuracy: {acc:.2f}%")
|
logger.info(f"客户端 {i} 测试准确率: {acc:.2f}%")
|
||||||
|
|
||||||
round_progress.update(1)
|
# 保存最终模型
|
||||||
|
torch.save(global_server_model.state_dict(), os.path.join(models_dir, "server_model.pth"))
|
||||||
# 保存模型
|
|
||||||
torch.save(global_server_model.state_dict(), "server_model.pth")
|
|
||||||
for i in range(NUM_CLIENTS):
|
for i in range(NUM_CLIENTS):
|
||||||
torch.save(client_models[i].state_dict(), f"client{i}_model.pth")
|
torch.save(client_models[i].state_dict(),
|
||||||
|
os.path.join(models_dir, f"client{i}_model.pth"))
|
||||||
|
logger.info(f"\n模型已保存至: {models_dir}")
|
||||||
|
|
||||||
# 最终测试
|
# 最终测试
|
||||||
print("\nFinal Evaluation:")
|
logger.info("\n最终测试结果:")
|
||||||
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
|
||||||
server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
|
server_model.load_state_dict(torch.load(os.path.join(models_dir, "server_model.pth"),weights_only=True))
|
||||||
print(f"Server Accuracy: {test_model(server_model, server_test_loader):.2f}%")
|
logger.info(f"服务器最终准确率: {test_model(server_model, server_test_loader):.2f}%")
|
||||||
|
|
||||||
for i in range(NUM_CLIENTS):
|
for i in range(NUM_CLIENTS):
|
||||||
client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
|
client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
|
||||||
client_model.load_state_dict(torch.load(f"client{i}_model.pth",weights_only=True))
|
client_model.load_state_dict(torch.load(os.path.join(models_dir, f"client{i}_model.pth"),weights_only=True))
|
||||||
test_loader = client_datasets[i]['test']
|
test_loader = client_datasets[i]['test']
|
||||||
print(f"Client {i} Accuracy: {test_model(client_model, test_loader):.2f}%")
|
logger.info(f"客户端 {i} 最终准确率: {test_model(client_model, test_loader):.2f}%")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
@ -1,12 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
from logger import logger
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
from tqdm import tqdm # 导入 tqdm
|
from tqdm import tqdm # 导入 tqdm
|
||||||
import logging
|
|
||||||
|
|
||||||
class ImageClassificationDataset(Dataset):
|
class ImageClassificationDataset(Dataset):
|
||||||
def __init__(self, root_dir, transform=None, Cache=False):
|
def __init__(self, root_dir, transform=None, Cache=False):
|
||||||
@ -19,7 +17,7 @@ class ImageClassificationDataset(Dataset):
|
|||||||
self.labels = []
|
self.labels = []
|
||||||
self.Cache = Cache
|
self.Cache = Cache
|
||||||
|
|
||||||
logger.log("info",
|
print("info",
|
||||||
"init the dataloader"
|
"init the dataloader"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -38,7 +36,7 @@ class ImageClassificationDataset(Dataset):
|
|||||||
self.image_paths.append(img_path)
|
self.image_paths.append(img_path)
|
||||||
self.labels.append(self.class_to_idx[cls_name])
|
self.labels.append(self.class_to_idx[cls_name])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.log("info", f"Read image error: {img_path} - {e}")
|
print("info", f"Read image error: {img_path} - {e}")
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.labels)
|
return len(self.labels)
|
||||||
|
41
logger.py
41
logger.py
@ -1,27 +1,27 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from config import config
|
|
||||||
|
|
||||||
class Logger:
|
class Logger:
|
||||||
def __init__(self):
|
def __init__(self, log_dir):
|
||||||
os.makedirs(config.log_dir, exist_ok=True)
|
self.log_dir = log_dir
|
||||||
log_file = os.path.join(
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
config.log_dir,
|
|
||||||
f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.logger = logging.getLogger(__name__)
|
# 日志文件路径
|
||||||
self.logger.setLevel(config.log_level)
|
log_file = os.path.join(log_dir, "training.log")
|
||||||
|
|
||||||
# 文件输出
|
# 创建日志记录器
|
||||||
|
self.logger = logging.getLogger("FedKD")
|
||||||
|
self.logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
# 文件处理器(包含详细信息)
|
||||||
file_handler = logging.FileHandler(log_file)
|
file_handler = logging.FileHandler(log_file)
|
||||||
file_formatter = logging.Formatter(
|
file_formatter = logging.Formatter(
|
||||||
"%(asctime)s - %(levelname)s - %(message)s"
|
"%(asctime)s - %(levelname)s - %(message)s"
|
||||||
)
|
)
|
||||||
file_handler.setFormatter(file_formatter)
|
file_handler.setFormatter(file_formatter)
|
||||||
|
|
||||||
# 控制台输出
|
# 控制台处理器(简洁输出)
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = logging.StreamHandler()
|
||||||
console_formatter = logging.Formatter("%(message)s")
|
console_formatter = logging.Formatter("%(message)s")
|
||||||
console_handler.setFormatter(console_formatter)
|
console_handler.setFormatter(console_formatter)
|
||||||
@ -29,7 +29,20 @@ class Logger:
|
|||||||
self.logger.addHandler(file_handler)
|
self.logger.addHandler(file_handler)
|
||||||
self.logger.addHandler(console_handler)
|
self.logger.addHandler(console_handler)
|
||||||
|
|
||||||
def log(self, level, message):
|
def info(self, message):
|
||||||
getattr(self.logger, level)(message)
|
self.logger.info(message)
|
||||||
|
|
||||||
logger = Logger()
|
def debug(self, message):
|
||||||
|
self.logger.debug(message)
|
||||||
|
|
||||||
|
def warning(self, message):
|
||||||
|
self.logger.warning(message)
|
||||||
|
|
||||||
|
def error(self, message):
|
||||||
|
self.logger.error(message)
|
||||||
|
|
||||||
|
def create_run_dir(base_dir="runs"):
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
run_dir = os.path.join(base_dir, timestamp)
|
||||||
|
os.makedirs(run_dir, exist_ok=True)
|
||||||
|
return run_dir
|
Loading…
Reference in New Issue
Block a user