Merge branch 'master' of http://git.yoiannis.top/paper/TA_EC

Commit: 4cb9790dee

FED.py (158 lines changed)
@@ -12,35 +12,38 @@ from model.repvit import repvit_m1_1
 from model.mobilenetv3 import MobileNetV3
 
 # Configuration parameters
-NUM_CLIENTS = 4
-NUM_ROUNDS = 3
-CLIENT_EPOCHS = 5
+NUM_CLIENTS = 2
+NUM_ROUNDS = 10
+CLIENT_EPOCHS = 2
 BATCH_SIZE = 32
 TEMP = 2.0  # distillation temperature
+CLASS_NUM = [3, 3, 3]
 
 # Device configuration
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Data preparation
-def prepare_data(num_clients):
+import os
+from torchvision.datasets import ImageFolder
+
+def prepare_data():
     transform = transforms.Compose([
-        transforms.Resize((224, 224)),  # resize the images to 224x224
-        transforms.Grayscale(num_output_channels=3),
+        transforms.Resize((224, 224)),
         transforms.ToTensor()
     ])
 
-    train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
-
-    # Non-IID data split (2 classes per client)
-    client_data = {i: [] for i in range(num_clients)}
-    labels = train_set.targets.numpy()
-    for label in range(10):
-        label_idx = np.where(labels == label)[0]
-        np.random.shuffle(label_idx)
-        split = np.array_split(label_idx, num_clients//2)
-        for i, idx in enumerate(split):
-            client_data[i*2 + label%2].extend(idx)
-
-    return [Subset(train_set, ids) for ids in client_data.values()]
+    # Load datasets
+    dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform)
+    dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform)
+    dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform)
+
+    # Assign datasets to clients
+    client_datasets = [dataset_B, dataset_C]
+
+    # Server dataset (A) for public updates
+    public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True)
+
+    return client_datasets, public_loader
 
 # Client training function
 def client_train(client_model, server_model, dataset):
@@ -103,13 +106,6 @@ def client_train(client_model, server_model, dataset):
             })
             progress_bar.update(1)
 
-            # Print detailed information every 10 batches
-            if (batch_idx + 1) % 10 == 0:
-                progress_bar.write(f"\nEpoch {epoch+1} | Batch {batch_idx+1}")
-                progress_bar.write(f"Task Loss: {loss_task:.4f}")
-                progress_bar.write(f"Distill Loss: {loss_distill:.4f}")
-                progress_bar.write(f"Total Loss: {total_loss:.4f}")
-                progress_bar.write(f"Batch Accuracy: {100*correct/total:.2f}%\n")
         # Print summary information at the end of each epoch
         avg_loss = epoch_loss / len(loader)
         avg_task = task_loss / len(loader)
@@ -133,6 +129,37 @@ def aggregate(client_params):
         global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0)
     return global_params
 
+def server_aggregate(server_model, client_models, public_loader):
+    server_model.train()
+    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
+
+    for data, _ in public_loader:
+        data = data.to(device)
+
+        # Collect client-model features
+        client_features = []
+        with torch.no_grad():
+            for model in client_models:
+                features = model.extract_features(data)  # a feature-extraction method has to be implemented
+                client_features.append(features)
+
+        # Compute the feature-distillation target
+        target_features = torch.stack(client_features).mean(dim=0)
+
+        # Server forward pass
+        server_features = server_model.extract_features(data)
+
+        # Feature-alignment loss
+        loss = F.mse_loss(server_features, target_features)
+
+        # Backpropagation
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # Update running statistics
+        total_loss += loss.item()
+
 # Server knowledge update
 def server_update(server_model, client_models, public_loader):
     server_model.train()
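A note on the server_aggregate hunk above: F.mse_loss(server_features, target_features) only works if every client's extract_features output and the server's output have identical shapes (later hunks add extract_features to both MobileNetV3 and RepViT, whose raw feature maps generally differ, so a projection step may be needed in practice), and total_loss is accumulated without being initialised inside the function. A minimal sketch of the alignment step, using dummy tensors of an assumed common shape in place of real features:

    import torch
    import torch.nn.functional as F

    # Dummy feature maps standing in for extract_features() outputs.
    # The (batch, channels, H, W) shape is an assumption, not taken from the repo.
    server_features = torch.randn(8, 512, 7, 7)
    client_features = [torch.randn(8, 512, 7, 7) for _ in range(2)]

    total_loss = 0.0  # accumulator (used without initialisation in the committed code)
    target_features = torch.stack(client_features).mean(dim=0)  # average the client features
    loss = F.mse_loss(server_features, target_features)         # feature-alignment loss
    total_loss += loss.item()
    print(f"alignment loss: {total_loss:.4f}")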
@@ -189,63 +216,51 @@ def test_model(model, test_loader):
 
 # Main training loop
 def main():
-    # Initialize models
-    global_server_model = repvit_m1_1(num_classes=10).to(device)
-    client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)]
-
-    round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
-
-    # Prepare data
-    client_datasets = prepare_data(NUM_CLIENTS)
-    public_loader = DataLoader(
-        datasets.MNIST("./data", train=False, download=True,
-                       transform=transforms.Compose([
-                           transforms.Resize((224, 224)),   # resize the images to 224x224
-                           transforms.Grayscale(num_output_channels=3),
-                           transforms.ToTensor()            # convert the images to tensors
-                       ])),
-        batch_size=100, shuffle=True)
-
-    test_dataset = datasets.MNIST(
-        "./data",
-        train=False,
-        transform=transforms.Compose([
-            transforms.Resize((224, 224)),   # resize the images to 224x224
-            transforms.Grayscale(num_output_channels=3),
-            transforms.ToTensor()            # convert the images to tensors
-        ])
-    )
+    transform = transforms.Compose([
+        transforms.Resize((224, 224)),
+        transforms.ToTensor()
+    ])
+    # Initialize models
+    global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
+    client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
+
+    # Prepare data
+    client_datasets, public_loader = prepare_data()
+
+    # Test dataset (using dataset A's test set for simplicity)
+    test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
     test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
 
+    round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
+
     for round in range(NUM_ROUNDS):
         print(f"\n{'#'*50}")
         print(f"Federated Round {round+1}/{NUM_ROUNDS}")
         print(f"{'#'*50}")
 
-        # Client selection
+        # Client selection (only 2 clients)
         selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
         print(f"Selected Clients: {selected_clients}")
 
         # Client local training
         client_params = []
         for cid in selected_clients:
             print(f"\nTraining Client {cid}")
             local_model = copy.deepcopy(client_models[cid])
             local_model.load_state_dict(client_models[cid].state_dict())
 
             updated_params = client_train(local_model, global_server_model, client_datasets[cid])
             client_params.append(updated_params)
 
         # Model aggregation
         global_client_params = aggregate(client_params)
         for model in client_models:
             model.load_state_dict(global_client_params)
 
         # Server knowledge update
         print("\nServer Updating...")
         server_update(global_server_model, client_models, public_loader)
 
         # Test model performance
         server_acc = test_model(global_server_model, test_loader)
         client_acc = test_model(client_models[0], test_loader)
         print(f"\nRound {round+1} Performance:")
@@ -253,29 +268,28 @@ def main():
         print(f"Client Model Accuracy: {client_acc:.2f}%")
 
         round_progress.update(1)
 
         print(f"Round {round+1} completed")
 
     print("Training completed!")
 
     # Save the trained models
     torch.save(global_server_model.state_dict(), "server_model.pth")
-    torch.save(client_models[0].state_dict(), "client_model.pth")
+    for i in range(NUM_CLIENTS):
+        torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth")
     print("Models saved successfully.")
 
-    # Build the test data loader
-    # Test the server model
-    server_model = repvit_m1_1(num_classes=10).to(device)
-    server_model.load_state_dict(torch.load("server_model.pth"))
+    # Test server model
+    server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
+    server_model.load_state_dict(torch.load("server_model.pth", weights_only=True))
     server_acc = test_model(server_model, test_loader)
     print(f"Server Model Test Accuracy: {server_acc:.2f}%")
 
-    # Test the client model
-    client_model = MobileNetV3(n_class=10).to(device)
-    client_model.load_state_dict(torch.load("client_model.pth"))
-    client_acc = test_model(client_model, test_loader)
-    print(f"Client Model Test Accuracy: {client_acc:.2f}%")
+    # Test client models
+    for i in range(NUM_CLIENTS):
+        client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
+        client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth", weights_only=True))
+        client_acc = test_model(client_model, test_loader)
+        print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%")
 
 if __name__ == "__main__":
     main()
config.py

@@ -8,7 +8,7 @@ class Config:
 
     # Training parameters
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    batch_size = 32
+    batch_size = 128
     epochs = 150
     learning_rate = 0.001
     save_path = "checkpoints/best_model.pth"
@@ -22,4 +22,6 @@ class Config:
     checkpoint_path = "checkpoints/last_checkpoint.pth"
     output_path = "runs/"
 
+    cache = 'RAM'
+
 config = Config()
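The new cache = 'RAM' field matches the Cache='RAM' literal that main.py (further down in this diff) passes to get_data_loader. A small sketch of driving the loader from the config object instead of hard-coding the mode; this wiring is an assumption, not part of the commit:

    from config import config
    from data_loader import get_data_loader

    # Hypothetical: take the cache mode from config rather than a literal string.
    train_loader, valid_loader, test_loader = get_data_loader(
        '/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27',  # path as used in main.py
        batch_size=config.batch_size,
        Cache=config.cache,
    )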
data_loader.py (118 lines changed)

@@ -1,31 +1,68 @@
 import os
+from logger import logger
 from PIL import Image
-import numpy as np
 import torch
-from torchvision import datasets, transforms
 from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
 
-class ClassifyDataset(Dataset):
-    def __init__(self, data_dir, transforms=None):
-        self.data_dir = data_dir
-        # Assume the dataset is structured with subdirectories for each class
-        self.transform = transforms
-        self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
-        self.image_size = (3, 224, 224)
+class ImageClassificationDataset(Dataset):
+    def __init__(self, root_dir, transform=None, Cache=False):
+        self.root_dir = root_dir
+        self.transform = transform
+        self.classes = sorted(os.listdir(root_dir))
+        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
+        self.image_paths = []
+        self.image = []
+        self.labels = []
+        self.Cache = Cache
+
+        logger.log("info",
+                   "init the dataloader"
+                   )
+
+        for cls_name in self.classes:
+            cls_dir = os.path.join(root_dir, cls_name)
+            for img_name in os.listdir(cls_dir):
+                try:
+                    img_path = os.path.join(cls_dir, img_name)
+                    imgs = Image.open(img_path).convert('RGB')
+                    if Cache == 'RAM':
+                        if self.transform:
+                            imgs = self.transform(imgs)
+                        self.image.append(imgs)
+                    else:
+                        self.image_paths.append(img_path)
+                    self.labels.append(self.class_to_idx[cls_name])
+                except:
+                    logger.log("info",
+                               "read image error " +
+                               img_path
+                               )
 
     def __len__(self):
-        return len(self.dataset)
+        return len(self.labels)
 
     def __getitem__(self, idx):
-        try:
-            image, label = self.dataset[idx]
-            return image, label
-        except Exception as e:
-            black_image = np.zeros((224, 224, 3), dtype=np.uint8)
-            return self.transform(Image.fromarray(black_image)), 0  # 0 as the default label
+        label = self.labels[idx]
+        if self.Cache == 'RAM':
+            image = self.image[idx]
+        else:
+            img_path = self.image_paths[idx]
+            image = Image.open(img_path).convert('RGB')
+            if self.transform:
+                image = self.transform(image)
 
-def create_data_loaders(data_dir, batch_size=64):
-    # Define transformations for training data augmentation and normalization
+        return image, label
+
+def get_data_loader(root_dir, batch_size=64, num_workers=4, pin_memory=True, Cache=False):
+    # Define the transform for the training data and for the validation data
+    transform = transforms.Compose([
+        transforms.Resize((224, 224)),  # Resize images to 224x224
+        transforms.ToTensor(),          # Convert PIL Image to Tensor
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize the images
+    ])
+
+    # Define transformations for training data augmentation and normalization
     train_transforms = transforms.Compose([
         transforms.Resize((224, 224)),
         transforms.ToTensor(),
@@ -40,17 +77,38 @@ def create_data_loaders(data_dir, batch_size=64):
     ])
 
     # Load the datasets with ImageFolder
-    train_dir = data_dir + '/train'
-    valid_dir = data_dir + '/val'
-    test_dir = data_dir + '/test'
+    train_dir = root_dir + '/train'
+    valid_dir = root_dir + '/val'
+    test_dir = root_dir + '/test'
 
-    train_data = ClassifyDataset(train_dir, transforms=train_transforms)
-    valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms)
-    test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms)
+    train_data = ImageClassificationDataset(train_dir, transform=train_transforms, Cache=Cache)
+    valid_data = ImageClassificationDataset(valid_dir, transform=valid_test_transforms, Cache=Cache)
+    test_data = ImageClassificationDataset(test_dir, transform=valid_test_transforms, Cache=Cache)
 
-    # Create the DataLoaders with batch size 64
-    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
-    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
-    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
+    # Create the training data loader
+    train_loader = DataLoader(
+        train_data,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=num_workers,
+        pin_memory=pin_memory
+    )
+
+    # Create the validation data loader
+    valid_loader = DataLoader(
+        valid_data,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=pin_memory
+    )
+
+    # Create the test data loader
+    test_loader = DataLoader(
+        test_data,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=pin_memory
+    )
 
     return train_loader, valid_loader, test_loader
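A short usage sketch for the reworked loader, based on the signature added above; the dataset path is illustrative, and root_dir is expected to contain train/, val/ and test/ sub-folders with one sub-directory per class:

    from data_loader import get_data_loader

    # Cache='RAM' pre-loads and pre-transforms every image at construction time,
    # trading memory for faster epochs; Cache=False reads images lazily from disk.
    train_loader, valid_loader, test_loader = get_data_loader(
        'datasets/EC27',   # illustrative path
        batch_size=64,
        num_workers=4,
        pin_memory=True,
        Cache='RAM',
    )
    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)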
dataset/test.py (new file, 12 lines)

@@ -0,0 +1,12 @@
+import os
+
+def debug_walk_with_links(input_folder):
+    for root, dirs, files in os.walk(input_folder):
+        print(f'Root: {root}')
+        print(f'Dirs: {dirs}')
+        print(f'Files: {files}')
+        print('-' * 40)
+
+if __name__ == "__main__":
+    input_folder = 'L:/Grade_datasets'
+    debug_walk_with_links(input_folder)
main.py (5 lines changed)

@@ -7,6 +7,7 @@ from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor
 
 from model.repvit import *
+from model.mobilenetv3 import *
 from data_loader import *
 from utils import *
 
@@ -14,11 +15,11 @@ def main():
     # Initialize components
     initialize()
 
-    model = repvit_m1_1(num_classes=10).to(config.device)
+    model = repvit_m1_0(num_classes=9).to(config.device)
     optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
     criterion = nn.CrossEntropyLoss()
 
-    train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27', batch_size=config.batch_size)
+    train_loader, valid_loader, test_loader = get_data_loader('/home/yoiannis/deep_learning/dataset/02.TA_EC/datasets/EC27', batch_size=config.batch_size, Cache='RAM')
 
     # Initialize the trainer
     trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
model/mobilenetv3.py

@@ -200,6 +200,11 @@ class MobileNetV3(nn.Module):
 
         self._initialize_weights()
 
+
+    def extract_features(self, x):
+        x = self.features(x)
+        return x
+
     def forward(self, x):
         x = self.features(x)
         x = x.mean(3).mean(2)
model/repvit.py

@@ -236,6 +236,10 @@ class RepViT(nn.Module):
         self.features = nn.ModuleList(layers)
         self.classifier = Classfier(output_channel, num_classes, distillation)
 
+    def extract_features(self, x):
+        for f in self.features:
+            x = f(x)
+        return x
     def forward(self, x):
         # x = self.features(x)
         for f in self.features:
Trainer class (file name not shown in this capture)

@@ -4,6 +4,7 @@ from torch.utils.data import DataLoader
 from config import config
 from logger import logger
 from utils import save_checkpoint, load_checkpoint
+import time
 
 class Trainer:
     def __init__(self, model, train_loader, val_loader, optimizer, criterion):
@@ -21,7 +22,7 @@ class Trainer:
         self.model.train()
         total_loss = 0.0
         progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
+        time_start = time.time()
         for batch_idx, (data, target) in enumerate(progress_bar):
             data, target = data.to(config.device), target.to(config.device)
 