update code for kd
This commit is contained in:
parent
4051332732
commit
49b2110b5f
2
.gitignore
vendored
2
.gitignore
vendored
@ -1 +1,3 @@
|
|||||||
data
|
data
|
||||||
|
*.pth
|
||||||
|
*.pyc
|
||||||
|
179
FED.py
179
FED.py
@ -6,45 +6,28 @@ from torch.utils.data import DataLoader, Subset
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from model.repvit import repvit_m1_1
|
||||||
|
from model.mobilenetv3 import MobileNetV3
|
||||||
|
|
||||||
# 配置参数
|
# 配置参数
|
||||||
NUM_CLIENTS = 10
|
NUM_CLIENTS = 4
|
||||||
NUM_ROUNDS = 3
|
NUM_ROUNDS = 3
|
||||||
CLIENT_EPOCHS = 2
|
CLIENT_EPOCHS = 5
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
TEMP = 2.0 # 蒸馏温度
|
TEMP = 2.0 # 蒸馏温度
|
||||||
|
|
||||||
# 设备配置
|
# 设备配置
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
# 定义中心大模型
|
|
||||||
class ServerModel(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.fc1 = nn.Linear(784, 512)
|
|
||||||
self.fc2 = nn.Linear(512, 256)
|
|
||||||
self.fc3 = nn.Linear(256, 10)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.view(-1, 784)
|
|
||||||
x = F.relu(self.fc1(x))
|
|
||||||
x = F.relu(self.fc2(x))
|
|
||||||
return self.fc3(x)
|
|
||||||
|
|
||||||
# 定义端侧小模型
|
|
||||||
class ClientModel(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.fc1 = nn.Linear(784, 64)
|
|
||||||
self.fc2 = nn.Linear(64, 10)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.view(-1, 784)
|
|
||||||
x = F.relu(self.fc1(x))
|
|
||||||
return self.fc2(x)
|
|
||||||
|
|
||||||
# 数据准备
|
# 数据准备
|
||||||
def prepare_data(num_clients):
|
def prepare_data(num_clients):
|
||||||
transform = transforms.Compose([transforms.ToTensor()])
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)), # 将图像调整为 224x224
|
||||||
|
transforms.Grayscale(num_output_channels=3),
|
||||||
|
transforms.ToTensor()
|
||||||
|
])
|
||||||
train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
|
train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
|
||||||
|
|
||||||
# 非IID数据划分(每个客户端2个类别)
|
# 非IID数据划分(每个客户端2个类别)
|
||||||
@ -67,30 +50,80 @@ def client_train(client_model, server_model, dataset):
|
|||||||
optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
|
optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1)
|
||||||
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
|
|
||||||
for _ in range(CLIENT_EPOCHS):
|
# 训练进度条
|
||||||
for data, target in loader:
|
progress_bar = tqdm(total=CLIENT_EPOCHS*len(loader),
|
||||||
|
desc="Client Training",
|
||||||
|
unit="batch")
|
||||||
|
|
||||||
|
for epoch in range(CLIENT_EPOCHS):
|
||||||
|
epoch_loss = 0.0
|
||||||
|
task_loss = 0.0
|
||||||
|
distill_loss = 0.0
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
for batch_idx, (data, target) in enumerate(loader):
|
||||||
data, target = data.to(device), target.to(device)
|
data, target = data.to(device), target.to(device)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# 获取小模型输出
|
# 前向传播
|
||||||
client_output = client_model(data)
|
client_output = client_model(data)
|
||||||
|
|
||||||
# 获取大模型输出(知识蒸馏)
|
# 获取教师模型输出
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
server_output = server_model(data)
|
server_output = server_model(data)
|
||||||
|
|
||||||
# 计算联合损失
|
# 计算损失
|
||||||
loss_task = F.cross_entropy(client_output, target)
|
loss_task = F.cross_entropy(client_output, target)
|
||||||
loss_distill = F.kl_div(
|
loss_distill = F.kl_div(
|
||||||
F.log_softmax(client_output/TEMP, dim=1),
|
F.log_softmax(client_output/TEMP, dim=1),
|
||||||
F.softmax(server_output/TEMP, dim=1),
|
F.softmax(server_output/TEMP, dim=1),
|
||||||
reduction="batchmean"
|
reduction="batchmean"
|
||||||
) * (TEMP**2)
|
) * (TEMP**2)
|
||||||
|
|
||||||
total_loss = loss_task + loss_distill
|
total_loss = loss_task + loss_distill
|
||||||
|
|
||||||
|
# 反向传播
|
||||||
total_loss.backward()
|
total_loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
# 统计指标
|
||||||
|
epoch_loss += total_loss.item()
|
||||||
|
task_loss += loss_task.item()
|
||||||
|
distill_loss += loss_distill.item()
|
||||||
|
_, predicted = torch.max(client_output.data, 1)
|
||||||
|
correct += (predicted == target).sum().item()
|
||||||
|
total += target.size(0)
|
||||||
|
|
||||||
|
# 实时更新进度条
|
||||||
|
progress_bar.set_postfix({
|
||||||
|
"Epoch": f"{epoch+1}/{CLIENT_EPOCHS}",
|
||||||
|
"Batch": f"{batch_idx+1}/{len(loader)}",
|
||||||
|
"Loss": f"{total_loss.item():.4f}",
|
||||||
|
"Acc": f"{100*correct/total:.2f}%\n",
|
||||||
|
})
|
||||||
|
progress_bar.update(1)
|
||||||
|
|
||||||
|
# 每10个batch打印详细信息
|
||||||
|
if (batch_idx + 1) % 10 == 0:
|
||||||
|
progress_bar.write(f"\nEpoch {epoch+1} | Batch {batch_idx+1}")
|
||||||
|
progress_bar.write(f"Task Loss: {loss_task:.4f}")
|
||||||
|
progress_bar.write(f"Distill Loss: {loss_distill:.4f}")
|
||||||
|
progress_bar.write(f"Total Loss: {total_loss:.4f}")
|
||||||
|
progress_bar.write(f"Batch Accuracy: {100*correct/total:.2f}%\n")
|
||||||
|
# 每个epoch结束打印汇总信息
|
||||||
|
avg_loss = epoch_loss / len(loader)
|
||||||
|
avg_task = task_loss / len(loader)
|
||||||
|
avg_distill = distill_loss / len(loader)
|
||||||
|
epoch_acc = 100 * correct / total
|
||||||
|
print(f"\n{'='*40}")
|
||||||
|
print(f"Epoch {epoch+1} Summary:")
|
||||||
|
print(f"Average Loss: {avg_loss:.4f}")
|
||||||
|
print(f"Task Loss: {avg_task:.4f}")
|
||||||
|
print(f"Distill Loss: {avg_distill:.4f}")
|
||||||
|
print(f"Training Accuracy: {epoch_acc:.2f}%")
|
||||||
|
print(f"{'='*40}\n")
|
||||||
|
|
||||||
|
progress_bar.close()
|
||||||
return client_model.state_dict()
|
return client_model.state_dict()
|
||||||
|
|
||||||
# 模型参数聚合(FedAvg)
|
# 模型参数聚合(FedAvg)
|
||||||
@ -105,7 +138,10 @@ def server_update(server_model, client_models, public_loader):
|
|||||||
server_model.train()
|
server_model.train()
|
||||||
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
|
optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
|
||||||
|
|
||||||
for data, _ in public_loader:
|
total_loss = 0.0
|
||||||
|
progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch")
|
||||||
|
|
||||||
|
for batch_idx, (data, _) in enumerate(progress_bar):
|
||||||
data = data.to(device)
|
data = data.to(device)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
@ -122,9 +158,20 @@ def server_update(server_model, client_models, public_loader):
|
|||||||
reduction="batchmean"
|
reduction="batchmean"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 反向传播
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
# 更新统计信息
|
||||||
|
total_loss += loss.item()
|
||||||
|
progress_bar.set_postfix({
|
||||||
|
"Avg Loss": f"{total_loss/(batch_idx+1):.4f}",
|
||||||
|
"Current Loss": f"{loss.item():.4f}"
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n")
|
||||||
|
|
||||||
|
|
||||||
def test_model(model, test_loader):
|
def test_model(model, test_loader):
|
||||||
model.eval()
|
model.eval()
|
||||||
correct = 0
|
correct = 0
|
||||||
@ -139,36 +186,54 @@ def test_model(model, test_loader):
|
|||||||
accuracy = 100 * correct / total
|
accuracy = 100 * correct / total
|
||||||
return accuracy
|
return accuracy
|
||||||
|
|
||||||
|
|
||||||
# 主训练流程
|
# 主训练流程
|
||||||
def main():
|
def main():
|
||||||
# 初始化模型
|
# 初始化模型
|
||||||
global_server_model = ServerModel().to(device)
|
global_server_model = repvit_m1_1(num_classes=10).to(device)
|
||||||
client_models = [ClientModel().to(device) for _ in range(NUM_CLIENTS)]
|
client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)]
|
||||||
|
|
||||||
|
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
|
||||||
|
|
||||||
# 准备数据
|
# 准备数据
|
||||||
client_datasets = prepare_data(NUM_CLIENTS)
|
client_datasets = prepare_data(NUM_CLIENTS)
|
||||||
public_loader = DataLoader(
|
public_loader = DataLoader(
|
||||||
datasets.MNIST("./data", train=False, download=True,
|
datasets.MNIST("./data", train=False, download=True,
|
||||||
transform=transforms.ToTensor()),
|
transform= transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)), # 将图像调整为 224x224
|
||||||
|
transforms.Grayscale(num_output_channels=3),
|
||||||
|
transforms.ToTensor() # 将图像转换为张量
|
||||||
|
])),
|
||||||
batch_size=100, shuffle=True)
|
batch_size=100, shuffle=True)
|
||||||
|
|
||||||
|
test_dataset = datasets.MNIST(
|
||||||
|
"./data",
|
||||||
|
train=False,
|
||||||
|
transform= transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)), # 将图像调整为 224x224
|
||||||
|
transforms.Grayscale(num_output_channels=3),
|
||||||
|
transforms.ToTensor() # 将图像转换为张量
|
||||||
|
])
|
||||||
|
)
|
||||||
|
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
|
||||||
|
|
||||||
for round in range(NUM_ROUNDS):
|
for round in range(NUM_ROUNDS):
|
||||||
|
print(f"\n{'#'*50}")
|
||||||
|
print(f"Federated Round {round+1}/{NUM_ROUNDS}")
|
||||||
|
print(f"{'#'*50}")
|
||||||
|
|
||||||
# 客户端选择
|
# 客户端选择
|
||||||
selected_clients = np.random.choice(NUM_CLIENTS, 5, replace=False)
|
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
|
||||||
|
print(f"Selected Clients: {selected_clients}")
|
||||||
|
|
||||||
# 客户端本地训练
|
# 客户端本地训练
|
||||||
client_params = []
|
client_params = []
|
||||||
for cid in selected_clients:
|
for cid in selected_clients:
|
||||||
# 下载全局模型
|
print(f"\nTraining Client {cid}")
|
||||||
local_model = copy.deepcopy(client_models[cid])
|
local_model = copy.deepcopy(client_models[cid])
|
||||||
local_model.load_state_dict(client_models[cid].state_dict())
|
local_model.load_state_dict(client_models[cid].state_dict())
|
||||||
|
|
||||||
# 本地训练
|
updated_params = client_train(local_model, global_server_model, client_datasets[cid])
|
||||||
updated_params = client_train(
|
|
||||||
local_model,
|
|
||||||
global_server_model,
|
|
||||||
client_datasets[cid]
|
|
||||||
)
|
|
||||||
client_params.append(updated_params)
|
client_params.append(updated_params)
|
||||||
|
|
||||||
# 模型聚合
|
# 模型聚合
|
||||||
@ -177,8 +242,18 @@ def main():
|
|||||||
model.load_state_dict(global_client_params)
|
model.load_state_dict(global_client_params)
|
||||||
|
|
||||||
# 服务器知识更新
|
# 服务器知识更新
|
||||||
|
print("\nServer Updating...")
|
||||||
server_update(global_server_model, client_models, public_loader)
|
server_update(global_server_model, client_models, public_loader)
|
||||||
|
|
||||||
|
# 测试模型性能
|
||||||
|
server_acc = test_model(global_server_model, test_loader)
|
||||||
|
client_acc = test_model(client_models[0], test_loader)
|
||||||
|
print(f"\nRound {round+1} Performance:")
|
||||||
|
print(f"Global Model Accuracy: {server_acc:.2f}%")
|
||||||
|
print(f"Client Model Accuracy: {client_acc:.2f}%")
|
||||||
|
|
||||||
|
round_progress.update(1)
|
||||||
|
|
||||||
print(f"Round {round+1} completed")
|
print(f"Round {round+1} completed")
|
||||||
|
|
||||||
print("Training completed!")
|
print("Training completed!")
|
||||||
@ -189,21 +264,15 @@ def main():
|
|||||||
print("Models saved successfully.")
|
print("Models saved successfully.")
|
||||||
|
|
||||||
# 创建测试数据加载器
|
# 创建测试数据加载器
|
||||||
test_dataset = datasets.MNIST(
|
|
||||||
"./data",
|
|
||||||
train=False,
|
|
||||||
transform=transforms.ToTensor()
|
|
||||||
)
|
|
||||||
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
|
|
||||||
|
|
||||||
# 测试服务器模型
|
# 测试服务器模型
|
||||||
server_model = ServerModel().to(device)
|
server_model = repvit_m1_1(num_classes=10).to(device)
|
||||||
server_model.load_state_dict(torch.load("server_model.pth"))
|
server_model.load_state_dict(torch.load("server_model.pth"))
|
||||||
server_acc = test_model(server_model, test_loader)
|
server_acc = test_model(server_model, test_loader)
|
||||||
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
|
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
|
||||||
|
|
||||||
# 测试客户端模型
|
# 测试客户端模型
|
||||||
client_model = ClientModel().to(device)
|
client_model = MobileNetV3(n_class=10).to(device)
|
||||||
client_model.load_state_dict(torch.load("client_model.pth"))
|
client_model.load_state_dict(torch.load("client_model.pth"))
|
||||||
client_acc = test_model(client_model, test_loader)
|
client_acc = test_model(client_model, test_loader)
|
||||||
print(f"Client Model Test Accuracy: {client_acc:.2f}%")
|
print(f"Client Model Test Accuracy: {client_acc:.2f}%")
|
||||||
|
157
dataset/recover.py
Normal file
157
dataset/recover.py
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
import random
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 假设我们有一个函数 `remove_background` 使用某种方法去背景,返回前景掩码和前景图像
|
||||||
|
def remove_background(image_path):
|
||||||
|
# 加载图像
|
||||||
|
source_image = cv2.imread(image_path)
|
||||||
|
# 这里可以使用一个预训练的模型去背景,比如 U2-Net。为了简化,假设我们得到一个二值掩码
|
||||||
|
# 掩码生成逻辑可以替换为实际的模型推理
|
||||||
|
# 转换为灰度图像
|
||||||
|
GRAY = cv2.cvtColor(source_image, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
||||||
|
# 二值化处理
|
||||||
|
_, mask_threshold = cv2.threshold(GRAY, 0, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
|
||||||
|
|
||||||
|
# 定义结构元素
|
||||||
|
element = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
|
||||||
|
element1 = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
|
||||||
|
|
||||||
|
# 膨胀和腐蚀操作
|
||||||
|
mask_dilate = cv2.dilate(mask_threshold, element)
|
||||||
|
mask_erode = cv2.erode(mask_dilate, element1)
|
||||||
|
|
||||||
|
# 计算非零像素数量
|
||||||
|
count2 = cv2.countNonZero(mask_erode)
|
||||||
|
|
||||||
|
# 查找轮廓
|
||||||
|
contours, hierarchy = cv2.findContours(mask_erode, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)
|
||||||
|
|
||||||
|
# 过滤轮廓
|
||||||
|
contours = [c for c in contours if cv2.contourArea(c) >= count2 * 0.3]
|
||||||
|
|
||||||
|
# 绘制轮廓
|
||||||
|
mask = np.zeros_like(mask_erode)
|
||||||
|
cv2.drawContours(mask, contours, -1, 1, -1)
|
||||||
|
|
||||||
|
# 将掩码转换为3通道
|
||||||
|
mask_cvtColor = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||||
|
|
||||||
|
# 应用掩码
|
||||||
|
source_image_multiply = cv2.multiply(source_image, mask_cvtColor)
|
||||||
|
|
||||||
|
# 转换为HSV颜色空间
|
||||||
|
imgHSV = cv2.cvtColor(source_image_multiply, cv2.COLOR_BGR2HSV)
|
||||||
|
|
||||||
|
# 定义HSV范围
|
||||||
|
scalarL = np.array([0, 46, 46])
|
||||||
|
scalarH = np.array([45, 255, 255])
|
||||||
|
|
||||||
|
# 根据HSV范围生成掩码
|
||||||
|
mask_inRange = cv2.inRange(imgHSV, scalarL, scalarH)
|
||||||
|
|
||||||
|
# 二值化处理
|
||||||
|
_, mask_tthreshold = cv2.threshold(mask_inRange, 0, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
|
||||||
|
|
||||||
|
# 中值滤波
|
||||||
|
mask_medianBlur = cv2.medianBlur(mask_tthreshold, 7)
|
||||||
|
|
||||||
|
# 将掩码转换为3通道
|
||||||
|
mask_scvtColor = cv2.cvtColor(mask_medianBlur, cv2.COLOR_GRAY2BGR)
|
||||||
|
|
||||||
|
# 应用掩码
|
||||||
|
source_image = cv2.multiply(source_image, mask_scvtColor)
|
||||||
|
|
||||||
|
return source_image
|
||||||
|
|
||||||
|
def synthesize_background(foreground, background):
|
||||||
|
|
||||||
|
# 创建前景掩膜(非黑色区域)
|
||||||
|
gray = cv2.cvtColor(foreground, cv2.COLOR_BGR2GRAY)
|
||||||
|
_, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY) # 阈值设为1以保留所有非纯黑像素
|
||||||
|
|
||||||
|
mask = cv2.GaussianBlur(mask, (5,5), 0) # 高斯模糊柔化边缘
|
||||||
|
_, mask = cv2.threshold(mask, 200, 255, cv2.THRESH_BINARY) # 重新二值化
|
||||||
|
|
||||||
|
# 精准形态学处理
|
||||||
|
kernel = np.ones((2,2), np.uint8)
|
||||||
|
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1) # 闭运算填充小孔
|
||||||
|
|
||||||
|
# 反转掩膜用于获取背景区域
|
||||||
|
mask_inv = cv2.bitwise_not(mask)
|
||||||
|
|
||||||
|
# 提取背景和前景的ROI区域
|
||||||
|
background_roi = cv2.bitwise_and(background, background, mask=mask_inv)
|
||||||
|
foreground_roi = cv2.bitwise_and(foreground, foreground, mask=mask)
|
||||||
|
|
||||||
|
# 合成图像
|
||||||
|
result = cv2.add(foreground_roi, background_roi)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def edge_fill2(img):
|
||||||
|
(height, width, p) = img.shape
|
||||||
|
H = 2384
|
||||||
|
W = 1560
|
||||||
|
top = bottom=int((W - height) / 2)
|
||||||
|
left= right= int((H - width) / 2)
|
||||||
|
img_result = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
|
||||||
|
return img_result
|
||||||
|
|
||||||
|
|
||||||
|
def process_images(input_folder, background_image_path, output_base):
|
||||||
|
"""
|
||||||
|
递归处理所有子文件夹并保持目录结构
|
||||||
|
"""
|
||||||
|
# 预处理背景路径(只需执行一次)
|
||||||
|
if os.path.isfile(background_image_path):
|
||||||
|
background_paths = [background_image_path]
|
||||||
|
else:
|
||||||
|
valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
|
||||||
|
background_paths = [
|
||||||
|
os.path.join(background_image_path, f)
|
||||||
|
for f in os.listdir(background_image_path)
|
||||||
|
if os.path.splitext(f)[1].lower() in valid_ext
|
||||||
|
]
|
||||||
|
|
||||||
|
# 递归遍历输入目录
|
||||||
|
for root, dirs, files in os.walk(input_folder):
|
||||||
|
# 计算相对路径
|
||||||
|
relative_path = os.path.relpath(root, input_folder)
|
||||||
|
|
||||||
|
# 创建对应的输出目录
|
||||||
|
output_dir = os.path.join(output_base, relative_path)
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 处理当前目录的文件
|
||||||
|
for filename in files:
|
||||||
|
input_path = os.path.join(root, filename)
|
||||||
|
output_path = os.path.join(output_dir, filename)
|
||||||
|
|
||||||
|
# 跳过非图像文件
|
||||||
|
if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 去背景处理
|
||||||
|
foreground = remove_background(input_path)
|
||||||
|
|
||||||
|
|
||||||
|
result = edge_fill2(foreground)
|
||||||
|
|
||||||
|
# 保存结果
|
||||||
|
cv2.imwrite(output_path, result)
|
||||||
|
print(f"Processed: {input_path} -> {output_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {input_path}: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
input_directory = 'L:/Tobacco/2023_JY/20230821/SOURCE'
|
||||||
|
background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B'
|
||||||
|
output_directory = 'L:/Test'
|
||||||
|
|
||||||
|
process_images(input_directory, background_image_path, output_directory)
|
60
dataset/split.py
Normal file
60
dataset/split.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import random
|
||||||
|
|
||||||
|
def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
|
||||||
|
# 确保比例总和为1
|
||||||
|
assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must sum to 1"
|
||||||
|
|
||||||
|
# 创建输出目录
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
train_dir = os.path.join(output_dir, 'train')
|
||||||
|
val_dir = os.path.join(output_dir, 'val')
|
||||||
|
test_dir = os.path.join(output_dir, 'test')
|
||||||
|
|
||||||
|
os.makedirs(train_dir, exist_ok=True)
|
||||||
|
os.makedirs(val_dir, exist_ok=True)
|
||||||
|
os.makedirs(test_dir, exist_ok=True)
|
||||||
|
|
||||||
|
categories = os.listdir(base_dir)
|
||||||
|
|
||||||
|
for category in categories:
|
||||||
|
category_path = os.path.join(base_dir, category)
|
||||||
|
|
||||||
|
if not os.path.isdir(category_path):
|
||||||
|
continue # 跳过非目录文件
|
||||||
|
|
||||||
|
# 获取类别文件夹中的所有图像文件
|
||||||
|
images = [f for f in os.listdir(category_path) if os.path.isfile(os.path.join(category_path, f))]
|
||||||
|
|
||||||
|
# 打乱图像文件顺序
|
||||||
|
random.shuffle(images)
|
||||||
|
|
||||||
|
# 计算分割点
|
||||||
|
total_images = len(images)
|
||||||
|
train_end = int(train_ratio * total_images)
|
||||||
|
val_end = train_end + int(val_ratio * total_images)
|
||||||
|
|
||||||
|
# 创建类别文件夹于各个输出目录中
|
||||||
|
os.makedirs(os.path.join(train_dir, category), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(val_dir, category), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(test_dir, category), exist_ok=True)
|
||||||
|
|
||||||
|
# 分配图像到训练集
|
||||||
|
for i in range(train_end):
|
||||||
|
shutil.copy(os.path.join(category_path, images[i]), os.path.join(train_dir, category))
|
||||||
|
|
||||||
|
# 分配图像到验证集
|
||||||
|
for i in range(train_end, val_end):
|
||||||
|
shutil.copy(os.path.join(category_path, images[i]), os.path.join(val_dir, category))
|
||||||
|
|
||||||
|
# 分配图像到测试集
|
||||||
|
for i in range(val_end, total_images):
|
||||||
|
shutil.copy(os.path.join(category_path, images[i]), os.path.join(test_dir, category))
|
||||||
|
|
||||||
|
print("Dataset successfully split into train, validation, and test sets.")
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
base_directory = 'F:/dataset/02.TA_EC/EC27/JY_A'
|
||||||
|
output_directory = 'F:/dataset/02.TA_EC/datasets/EC27'
|
||||||
|
create_dataset_splits(base_directory, output_directory)
|
184
dataset/splitbc_compsition.py
Normal file
184
dataset/splitbc_compsition.py
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
import random
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 假设我们有一个函数 `remove_background` 使用某种方法去背景,返回前景掩码和前景图像
|
||||||
|
def remove_background(image_path):
|
||||||
|
# 加载图像
|
||||||
|
source_image = cv2.imread(image_path)
|
||||||
|
# 这里可以使用一个预训练的模型去背景,比如 U2-Net。为了简化,假设我们得到一个二值掩码
|
||||||
|
# 掩码生成逻辑可以替换为实际的模型推理
|
||||||
|
# 转换为灰度图像
|
||||||
|
GRAY = cv2.cvtColor(source_image, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
||||||
|
# 二值化处理
|
||||||
|
_, mask_threshold = cv2.threshold(GRAY, 0, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
|
||||||
|
|
||||||
|
# 定义结构元素
|
||||||
|
element = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
|
||||||
|
element1 = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
|
||||||
|
|
||||||
|
# 膨胀和腐蚀操作
|
||||||
|
mask_dilate = cv2.dilate(mask_threshold, element)
|
||||||
|
mask_erode = cv2.erode(mask_dilate, element1)
|
||||||
|
|
||||||
|
# 计算非零像素数量
|
||||||
|
count2 = cv2.countNonZero(mask_erode)
|
||||||
|
|
||||||
|
# 查找轮廓
|
||||||
|
contours, hierarchy = cv2.findContours(mask_erode, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)
|
||||||
|
|
||||||
|
# 过滤轮廓
|
||||||
|
contours = [c for c in contours if cv2.contourArea(c) >= count2 * 0.3]
|
||||||
|
|
||||||
|
# 绘制轮廓
|
||||||
|
mask = np.zeros_like(mask_erode)
|
||||||
|
cv2.drawContours(mask, contours, -1, 1, -1)
|
||||||
|
|
||||||
|
# 将掩码转换为3通道
|
||||||
|
mask_cvtColor = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||||
|
|
||||||
|
# 应用掩码
|
||||||
|
source_image_multiply = cv2.multiply(source_image, mask_cvtColor)
|
||||||
|
|
||||||
|
# 转换为HSV颜色空间
|
||||||
|
imgHSV = cv2.cvtColor(source_image_multiply, cv2.COLOR_BGR2HSV)
|
||||||
|
|
||||||
|
# 定义HSV范围
|
||||||
|
scalarL = np.array([0, 46, 46])
|
||||||
|
scalarH = np.array([45, 255, 255])
|
||||||
|
|
||||||
|
# 根据HSV范围生成掩码
|
||||||
|
mask_inRange = cv2.inRange(imgHSV, scalarL, scalarH)
|
||||||
|
|
||||||
|
# 二值化处理
|
||||||
|
_, mask_tthreshold = cv2.threshold(mask_inRange, 0, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
|
||||||
|
|
||||||
|
# 中值滤波
|
||||||
|
mask_medianBlur = cv2.medianBlur(mask_tthreshold, 7)
|
||||||
|
|
||||||
|
# 将掩码转换为3通道
|
||||||
|
mask_scvtColor = cv2.cvtColor(mask_medianBlur, cv2.COLOR_GRAY2BGR)
|
||||||
|
|
||||||
|
# 应用掩码
|
||||||
|
source_image = cv2.multiply(source_image, mask_scvtColor)
|
||||||
|
|
||||||
|
return source_image
|
||||||
|
|
||||||
|
def synthesize_background(foreground, background_image_path):
|
||||||
|
# 验证输入路径有效性
|
||||||
|
if not os.path.exists(background_image_path):
|
||||||
|
raise FileNotFoundError(f"指定的背景路径不存在: {background_image_path}")
|
||||||
|
|
||||||
|
# 获取所有可用背景图像路径
|
||||||
|
background_paths = []
|
||||||
|
if os.path.isfile(background_image_path):
|
||||||
|
# 如果是单个图像文件
|
||||||
|
background_paths = [background_image_path]
|
||||||
|
elif os.path.isdir(background_image_path):
|
||||||
|
# 如果是目录,搜索常见图像格式
|
||||||
|
valid_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
|
||||||
|
for filename in os.listdir(background_image_path):
|
||||||
|
if os.path.splitext(filename)[1].lower() in valid_extensions:
|
||||||
|
background_paths.append(os.path.join(background_image_path, filename))
|
||||||
|
|
||||||
|
# 验证找到的图像文件数量
|
||||||
|
if not background_paths:
|
||||||
|
raise ValueError(f"目录中未找到支持的图像文件: {background_image_path}")
|
||||||
|
|
||||||
|
# 随机选择背景图像
|
||||||
|
selected_bg_path = random.choice(background_paths)
|
||||||
|
|
||||||
|
# 加载并验证背景图像
|
||||||
|
background = cv2.imread(selected_bg_path)
|
||||||
|
|
||||||
|
# 调整背景大小与前景一致
|
||||||
|
background = cv2.resize(background, (foreground.shape[1], foreground.shape[0]))
|
||||||
|
|
||||||
|
# 创建前景掩膜(非黑色区域)
|
||||||
|
gray = cv2.cvtColor(foreground, cv2.COLOR_BGR2GRAY)
|
||||||
|
_, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY) # 阈值设为1以保留所有非纯黑像素
|
||||||
|
|
||||||
|
mask = cv2.GaussianBlur(mask, (5,5), 0) # 高斯模糊柔化边缘
|
||||||
|
_, mask = cv2.threshold(mask, 200, 255, cv2.THRESH_BINARY) # 重新二值化
|
||||||
|
|
||||||
|
# 精准形态学处理
|
||||||
|
kernel = np.ones((2,2), np.uint8)
|
||||||
|
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1) # 闭运算填充小孔
|
||||||
|
|
||||||
|
# 反转掩膜用于获取背景区域
|
||||||
|
mask_inv = cv2.bitwise_not(mask)
|
||||||
|
|
||||||
|
# 提取背景和前景的ROI区域
|
||||||
|
background_roi = cv2.bitwise_and(background, background, mask=mask_inv)
|
||||||
|
foreground_roi = cv2.bitwise_and(foreground, foreground, mask=mask)
|
||||||
|
|
||||||
|
# 合成图像
|
||||||
|
result = cv2.add(foreground_roi, background_roi)
|
||||||
|
|
||||||
|
# 保存结果
|
||||||
|
return result
|
||||||
|
|
||||||
|
def process_images(input_folder, background_image_path, output_base):
|
||||||
|
"""
|
||||||
|
递归处理所有子文件夹并保持目录结构
|
||||||
|
"""
|
||||||
|
# 预处理背景路径(只需执行一次)
|
||||||
|
if os.path.isfile(background_image_path):
|
||||||
|
background_paths = [background_image_path]
|
||||||
|
else:
|
||||||
|
valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp']
|
||||||
|
background_paths = [
|
||||||
|
os.path.join(background_image_path, f)
|
||||||
|
for f in os.listdir(background_image_path)
|
||||||
|
if os.path.splitext(f)[1].lower() in valid_ext
|
||||||
|
]
|
||||||
|
|
||||||
|
# 递归遍历输入目录
|
||||||
|
for root, dirs, files in os.walk(input_folder):
|
||||||
|
# 计算相对路径
|
||||||
|
relative_path = os.path.relpath(root, input_folder)
|
||||||
|
|
||||||
|
# 创建对应的输出目录
|
||||||
|
output_dir = os.path.join(output_base, relative_path)
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 处理当前目录的文件
|
||||||
|
for filename in files:
|
||||||
|
input_path = os.path.join(root, filename)
|
||||||
|
output_path = os.path.join(output_dir, filename)
|
||||||
|
|
||||||
|
# 跳过非图像文件
|
||||||
|
if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 去背景处理
|
||||||
|
foreground = remove_background(input_path)
|
||||||
|
|
||||||
|
# 随机选择背景(每次处理都重新选择)
|
||||||
|
bg_path = random.choice(background_paths)
|
||||||
|
background = cv2.imread(bg_path)
|
||||||
|
|
||||||
|
# 调整背景尺寸
|
||||||
|
h, w = foreground.shape[:2]
|
||||||
|
resized_bg = cv2.resize(background, (w, h))
|
||||||
|
|
||||||
|
# 合成背景
|
||||||
|
result = synthesize_background(foreground, bg_path)
|
||||||
|
|
||||||
|
# 保存结果
|
||||||
|
cv2.imwrite(output_path, result)
|
||||||
|
print(f"Processed: {input_path} -> {output_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {input_path}: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
input_directory = 'F:/dataset/02.TA_EC/EC27/JY_A/'
|
||||||
|
background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B'
|
||||||
|
output_directory = 'F:/dataset/02.TA_EC/rundata/test'
|
||||||
|
|
||||||
|
process_images(input_directory, background_image_path, output_directory)
|
0
model/__init__.py
Normal file
0
model/__init__.py
Normal file
248
model/mobilenetv3.py
Normal file
248
model/mobilenetv3.py
Normal file
@ -0,0 +1,248 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['MobileNetV3', 'mobilenetv3']
|
||||||
|
|
||||||
|
|
||||||
|
def conv_bn(inp, oup, stride, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU):
|
||||||
|
return nn.Sequential(
|
||||||
|
conv_layer(inp, oup, 3, stride, 1, bias=False),
|
||||||
|
norm_layer(oup),
|
||||||
|
nlin_layer(inplace=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def conv_1x1_bn(inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU):
|
||||||
|
return nn.Sequential(
|
||||||
|
conv_layer(inp, oup, 1, 1, 0, bias=False),
|
||||||
|
norm_layer(oup),
|
||||||
|
nlin_layer(inplace=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Hswish(nn.Module):
|
||||||
|
def __init__(self, inplace=True):
|
||||||
|
super(Hswish, self).__init__()
|
||||||
|
self.inplace = inplace
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * F.relu6(x + 3., inplace=self.inplace) / 6.
|
||||||
|
|
||||||
|
|
||||||
|
class Hsigmoid(nn.Module):
|
||||||
|
def __init__(self, inplace=True):
|
||||||
|
super(Hsigmoid, self).__init__()
|
||||||
|
self.inplace = inplace
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return F.relu6(x + 3., inplace=self.inplace) / 6.
|
||||||
|
|
||||||
|
|
||||||
|
class SEModule(nn.Module):
|
||||||
|
def __init__(self, channel, reduction=4):
|
||||||
|
super(SEModule, self).__init__()
|
||||||
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(channel, channel // reduction, bias=False),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(channel // reduction, channel, bias=False),
|
||||||
|
Hsigmoid()
|
||||||
|
# nn.Sigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, c, _, _ = x.size()
|
||||||
|
y = self.avg_pool(x).view(b, c)
|
||||||
|
y = self.fc(y).view(b, c, 1, 1)
|
||||||
|
return x * y.expand_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Identity(nn.Module):
|
||||||
|
def __init__(self, channel):
|
||||||
|
super(Identity, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def make_divisible(x, divisible_by=8):
|
||||||
|
import numpy as np
|
||||||
|
return int(np.ceil(x * 1. / divisible_by) * divisible_by)
|
||||||
|
|
||||||
|
|
||||||
|
class MobileBottleneck(nn.Module):
|
||||||
|
def __init__(self, inp, oup, kernel, stride, exp, se=False, nl='RE'):
|
||||||
|
super(MobileBottleneck, self).__init__()
|
||||||
|
assert stride in [1, 2]
|
||||||
|
assert kernel in [3, 5]
|
||||||
|
padding = (kernel - 1) // 2
|
||||||
|
self.use_res_connect = stride == 1 and inp == oup
|
||||||
|
|
||||||
|
conv_layer = nn.Conv2d
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
if nl == 'RE':
|
||||||
|
nlin_layer = nn.ReLU # or ReLU6
|
||||||
|
elif nl == 'HS':
|
||||||
|
nlin_layer = Hswish
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
if se:
|
||||||
|
SELayer = SEModule
|
||||||
|
else:
|
||||||
|
SELayer = Identity
|
||||||
|
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
# pw
|
||||||
|
conv_layer(inp, exp, 1, 1, 0, bias=False),
|
||||||
|
norm_layer(exp),
|
||||||
|
nlin_layer(inplace=True),
|
||||||
|
# dw
|
||||||
|
conv_layer(exp, exp, kernel, stride, padding, groups=exp, bias=False),
|
||||||
|
norm_layer(exp),
|
||||||
|
SELayer(exp),
|
||||||
|
nlin_layer(inplace=True),
|
||||||
|
# pw-linear
|
||||||
|
conv_layer(exp, oup, 1, 1, 0, bias=False),
|
||||||
|
norm_layer(oup),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.use_res_connect:
|
||||||
|
return x + self.conv(x)
|
||||||
|
else:
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MobileNetV3(nn.Module):
|
||||||
|
def __init__(self, n_class=1000, input_size=224, dropout=0.8, mode='small', width_mult=1.0):
|
||||||
|
super(MobileNetV3, self).__init__()
|
||||||
|
input_channel = 16
|
||||||
|
last_channel = 1280
|
||||||
|
if mode == 'large':
|
||||||
|
# refer to Table 1 in paper
|
||||||
|
mobile_setting = [
|
||||||
|
# k, exp, c, se, nl, s,
|
||||||
|
[3, 16, 16, False, 'RE', 1],
|
||||||
|
[3, 64, 24, False, 'RE', 2],
|
||||||
|
[3, 72, 24, False, 'RE', 1],
|
||||||
|
[5, 72, 40, True, 'RE', 2],
|
||||||
|
[5, 120, 40, True, 'RE', 1],
|
||||||
|
[5, 120, 40, True, 'RE', 1],
|
||||||
|
[3, 240, 80, False, 'HS', 2],
|
||||||
|
[3, 200, 80, False, 'HS', 1],
|
||||||
|
[3, 184, 80, False, 'HS', 1],
|
||||||
|
[3, 184, 80, False, 'HS', 1],
|
||||||
|
[3, 480, 112, True, 'HS', 1],
|
||||||
|
[3, 672, 112, True, 'HS', 1],
|
||||||
|
[5, 672, 160, True, 'HS', 2],
|
||||||
|
[5, 960, 160, True, 'HS', 1],
|
||||||
|
[5, 960, 160, True, 'HS', 1],
|
||||||
|
]
|
||||||
|
elif mode == 'small':
|
||||||
|
# refer to Table 2 in paper
|
||||||
|
mobile_setting = [
|
||||||
|
# k, exp, c, se, nl, s,
|
||||||
|
[3, 16, 16, True, 'RE', 2],
|
||||||
|
[3, 72, 24, False, 'RE', 2],
|
||||||
|
[3, 88, 24, False, 'RE', 1],
|
||||||
|
[5, 96, 40, True, 'HS', 2],
|
||||||
|
[5, 240, 40, True, 'HS', 1],
|
||||||
|
[5, 240, 40, True, 'HS', 1],
|
||||||
|
[5, 120, 48, True, 'HS', 1],
|
||||||
|
[5, 144, 48, True, 'HS', 1],
|
||||||
|
[5, 288, 96, True, 'HS', 2],
|
||||||
|
[5, 576, 96, True, 'HS', 1],
|
||||||
|
[5, 576, 96, True, 'HS', 1],
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# building first layer
|
||||||
|
assert input_size % 32 == 0
|
||||||
|
last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel
|
||||||
|
self.features = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)]
|
||||||
|
self.classifier = []
|
||||||
|
|
||||||
|
# building mobile blocks
|
||||||
|
for k, exp, c, se, nl, s in mobile_setting:
|
||||||
|
output_channel = make_divisible(c * width_mult)
|
||||||
|
exp_channel = make_divisible(exp * width_mult)
|
||||||
|
self.features.append(MobileBottleneck(input_channel, output_channel, k, s, exp_channel, se, nl))
|
||||||
|
input_channel = output_channel
|
||||||
|
|
||||||
|
# building last several layers
|
||||||
|
if mode == 'large':
|
||||||
|
last_conv = make_divisible(960 * width_mult)
|
||||||
|
self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish))
|
||||||
|
self.features.append(nn.AdaptiveAvgPool2d(1))
|
||||||
|
self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0))
|
||||||
|
self.features.append(Hswish(inplace=True))
|
||||||
|
elif mode == 'small':
|
||||||
|
last_conv = make_divisible(576 * width_mult)
|
||||||
|
self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish))
|
||||||
|
# self.features.append(SEModule(last_conv)) # refer to paper Table2, but I think this is a mistake
|
||||||
|
self.features.append(nn.AdaptiveAvgPool2d(1))
|
||||||
|
self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0))
|
||||||
|
self.features.append(Hswish(inplace=True))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# make it nn.Sequential
|
||||||
|
self.features = nn.Sequential(*self.features)
|
||||||
|
|
||||||
|
# building classifier
|
||||||
|
self.classifier = nn.Sequential(
|
||||||
|
nn.Dropout(p=dropout), # refer to paper section 6
|
||||||
|
nn.Linear(last_channel, n_class),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._initialize_weights()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.features(x)
|
||||||
|
x = x.mean(3).mean(2)
|
||||||
|
x = self.classifier(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _initialize_weights(self):
|
||||||
|
# weight initialization
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
nn.init.ones_(m.weight)
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
nn.init.normal_(m.weight, 0, 0.01)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
|
||||||
|
def mobilenetv3(pretrained=False, **kwargs):
|
||||||
|
model = MobileNetV3(**kwargs)
|
||||||
|
if pretrained:
|
||||||
|
state_dict = torch.load('mobilenetv3_small_67.4.pth.tar')
|
||||||
|
model.load_state_dict(state_dict, strict=True)
|
||||||
|
# raise NotImplementedError
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
net = mobilenetv3()
|
||||||
|
print('mobilenetv3:\n', net)
|
||||||
|
print('Total params: %.2fM' % (sum(p.numel() for p in net.parameters())/1000000.0))
|
||||||
|
input_size=(1, 3, 224, 224)
|
||||||
|
# pip install --upgrade git+https://github.com/kuan-wang/pytorch-OpCounter.git
|
||||||
|
from thop import profile
|
||||||
|
input_tensor = torch.randn(input_size)
|
||||||
|
flops, params = profile(net, inputs=(input_tensor,))
|
||||||
|
# print(flops)
|
||||||
|
# print(params)
|
||||||
|
print('Total params: %.2fM' % (params/1000000.0))
|
||||||
|
print('Total flops: %.2f GMACs' % ((flops/1000000000.0) / 2.0))
|
||||||
|
x = torch.randn(input_size)
|
||||||
|
out = net(x)
|
520
model/repvit.py
Normal file
520
model/repvit.py
Normal file
@ -0,0 +1,520 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
def _make_divisible(v, divisor, min_value=None):
|
||||||
|
"""
|
||||||
|
This function is taken from the original tf repo.
|
||||||
|
It ensures that all layers have a channel number that is divisible by 8
|
||||||
|
It can be seen here:
|
||||||
|
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||||
|
:param v:
|
||||||
|
:param divisor:
|
||||||
|
:param min_value:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if min_value is None:
|
||||||
|
min_value = divisor
|
||||||
|
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||||
|
# Make sure that round down does not go down by more than 10%.
|
||||||
|
if new_v < 0.9 * v:
|
||||||
|
new_v += divisor
|
||||||
|
return new_v
|
||||||
|
|
||||||
|
from timm.models.layers import SqueezeExcite
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class Conv2d_BN(torch.nn.Sequential):
|
||||||
|
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
|
||||||
|
groups=1, bn_weight_init=1, resolution=-10000):
|
||||||
|
super().__init__()
|
||||||
|
self.add_module('c', torch.nn.Conv2d(
|
||||||
|
a, b, ks, stride, pad, dilation, groups, bias=False))
|
||||||
|
self.add_module('bn', torch.nn.BatchNorm2d(b))
|
||||||
|
torch.nn.init.constant_(self.bn.weight, bn_weight_init)
|
||||||
|
torch.nn.init.constant_(self.bn.bias, 0)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def fuse(self):
|
||||||
|
c, bn = self._modules.values()
|
||||||
|
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
||||||
|
w = c.weight * w[:, None, None, None]
|
||||||
|
b = bn.bias - bn.running_mean * bn.weight / \
|
||||||
|
(bn.running_var + bn.eps)**0.5
|
||||||
|
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
|
||||||
|
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
|
||||||
|
device=c.weight.device)
|
||||||
|
m.weight.data.copy_(w)
|
||||||
|
m.bias.data.copy_(b)
|
||||||
|
return m
|
||||||
|
|
||||||
|
class Residual(torch.nn.Module):
|
||||||
|
def __init__(self, m, drop=0.):
|
||||||
|
super().__init__()
|
||||||
|
self.m = m
|
||||||
|
self.drop = drop
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.training and self.drop > 0:
|
||||||
|
return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
|
||||||
|
device=x.device).ge_(self.drop).div(1 - self.drop).detach()
|
||||||
|
else:
|
||||||
|
return x + self.m(x)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def fuse(self):
|
||||||
|
if isinstance(self.m, Conv2d_BN):
|
||||||
|
m = self.m.fuse()
|
||||||
|
assert(m.groups == m.in_channels)
|
||||||
|
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
|
||||||
|
identity = torch.nn.functional.pad(identity, [1,1,1,1])
|
||||||
|
m.weight += identity.to(m.weight.device)
|
||||||
|
return m
|
||||||
|
elif isinstance(self.m, torch.nn.Conv2d):
|
||||||
|
m = self.m
|
||||||
|
assert(m.groups != m.in_channels)
|
||||||
|
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
|
||||||
|
identity = torch.nn.functional.pad(identity, [1,1,1,1])
|
||||||
|
m.weight += identity.to(m.weight.device)
|
||||||
|
return m
|
||||||
|
else:
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class RepVGGDW(torch.nn.Module):
|
||||||
|
def __init__(self, ed) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
|
||||||
|
self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
|
||||||
|
self.dim = ed
|
||||||
|
self.bn = torch.nn.BatchNorm2d(ed)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.bn((self.conv(x) + self.conv1(x)) + x)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def fuse(self):
|
||||||
|
conv = self.conv.fuse()
|
||||||
|
conv1 = self.conv1
|
||||||
|
|
||||||
|
conv_w = conv.weight
|
||||||
|
conv_b = conv.bias
|
||||||
|
conv1_w = conv1.weight
|
||||||
|
conv1_b = conv1.bias
|
||||||
|
|
||||||
|
conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])
|
||||||
|
|
||||||
|
identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])
|
||||||
|
|
||||||
|
final_conv_w = conv_w + conv1_w + identity
|
||||||
|
final_conv_b = conv_b + conv1_b
|
||||||
|
|
||||||
|
conv.weight.data.copy_(final_conv_w)
|
||||||
|
conv.bias.data.copy_(final_conv_b)
|
||||||
|
|
||||||
|
bn = self.bn
|
||||||
|
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
||||||
|
w = conv.weight * w[:, None, None, None]
|
||||||
|
b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
|
||||||
|
(bn.running_var + bn.eps)**0.5
|
||||||
|
conv.weight.data.copy_(w)
|
||||||
|
conv.bias.data.copy_(b)
|
||||||
|
return conv
|
||||||
|
|
||||||
|
|
||||||
|
class RepViTBlock(nn.Module):
|
||||||
|
def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
|
||||||
|
super(RepViTBlock, self).__init__()
|
||||||
|
assert stride in [1, 2]
|
||||||
|
|
||||||
|
self.identity = stride == 1 and inp == oup
|
||||||
|
assert(hidden_dim == 2 * inp)
|
||||||
|
|
||||||
|
if stride == 2:
|
||||||
|
self.token_mixer = nn.Sequential(
|
||||||
|
Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
|
||||||
|
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
|
||||||
|
Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
|
||||||
|
)
|
||||||
|
self.channel_mixer = Residual(nn.Sequential(
|
||||||
|
# pw
|
||||||
|
Conv2d_BN(oup, 2 * oup, 1, 1, 0),
|
||||||
|
nn.GELU() if use_hs else nn.GELU(),
|
||||||
|
# pw-linear
|
||||||
|
Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
assert(self.identity)
|
||||||
|
self.token_mixer = nn.Sequential(
|
||||||
|
RepVGGDW(inp),
|
||||||
|
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
|
||||||
|
)
|
||||||
|
self.channel_mixer = Residual(nn.Sequential(
|
||||||
|
# pw
|
||||||
|
Conv2d_BN(inp, hidden_dim, 1, 1, 0),
|
||||||
|
nn.GELU() if use_hs else nn.GELU(),
|
||||||
|
# pw-linear
|
||||||
|
Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
|
||||||
|
))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.channel_mixer(self.token_mixer(x))
|
||||||
|
|
||||||
|
from timm.models.vision_transformer import trunc_normal_
|
||||||
|
class BN_Linear(torch.nn.Sequential):
|
||||||
|
def __init__(self, a, b, bias=True, std=0.02):
|
||||||
|
super().__init__()
|
||||||
|
self.add_module('bn', torch.nn.BatchNorm1d(a))
|
||||||
|
self.add_module('l', torch.nn.Linear(a, b, bias=bias))
|
||||||
|
trunc_normal_(self.l.weight, std=std)
|
||||||
|
if bias:
|
||||||
|
torch.nn.init.constant_(self.l.bias, 0)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def fuse(self):
|
||||||
|
bn, l = self._modules.values()
|
||||||
|
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
||||||
|
b = bn.bias - self.bn.running_mean * \
|
||||||
|
self.bn.weight / (bn.running_var + bn.eps)**0.5
|
||||||
|
w = l.weight * w[None, :]
|
||||||
|
if l.bias is None:
|
||||||
|
b = b @ self.l.weight.T
|
||||||
|
else:
|
||||||
|
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
|
||||||
|
m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device)
|
||||||
|
m.weight.data.copy_(w)
|
||||||
|
m.bias.data.copy_(b)
|
||||||
|
return m
|
||||||
|
|
||||||
|
class Classfier(nn.Module):
|
||||||
|
def __init__(self, dim, num_classes, distillation=True):
|
||||||
|
super().__init__()
|
||||||
|
self.classifier = BN_Linear(dim, num_classes) if num_classes > 0 else torch.nn.Identity()
|
||||||
|
self.distillation = distillation
|
||||||
|
if distillation:
|
||||||
|
self.classifier_dist = BN_Linear(dim, num_classes) if num_classes > 0 else torch.nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.distillation:
|
||||||
|
x = self.classifier(x), self.classifier_dist(x)
|
||||||
|
if not self.training:
|
||||||
|
x = (x[0] + x[1]) / 2
|
||||||
|
else:
|
||||||
|
x = self.classifier(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def fuse(self):
|
||||||
|
classifier = self.classifier.fuse()
|
||||||
|
if self.distillation:
|
||||||
|
classifier_dist = self.classifier_dist.fuse()
|
||||||
|
classifier.weight += classifier_dist.weight
|
||||||
|
classifier.bias += classifier_dist.bias
|
||||||
|
classifier.weight /= 2
|
||||||
|
classifier.bias /= 2
|
||||||
|
return classifier
|
||||||
|
else:
|
||||||
|
return classifier
|
||||||
|
|
||||||
|
class RepViT(nn.Module):
|
||||||
|
def __init__(self, cfgs, num_classes=1000, distillation=False):
|
||||||
|
super(RepViT, self).__init__()
|
||||||
|
# setting of inverted residual blocks
|
||||||
|
self.cfgs = cfgs
|
||||||
|
|
||||||
|
# building first layer
|
||||||
|
input_channel = self.cfgs[0][2]
|
||||||
|
patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(),
|
||||||
|
Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1))
|
||||||
|
layers = [patch_embed]
|
||||||
|
# building inverted residual blocks
|
||||||
|
block = RepViTBlock
|
||||||
|
for k, t, c, use_se, use_hs, s in self.cfgs:
|
||||||
|
output_channel = _make_divisible(c, 8)
|
||||||
|
exp_size = _make_divisible(input_channel * t, 8)
|
||||||
|
layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
|
||||||
|
input_channel = output_channel
|
||||||
|
self.features = nn.ModuleList(layers)
|
||||||
|
self.classifier = Classfier(output_channel, num_classes, distillation)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x = self.features(x)
|
||||||
|
for f in self.features:
|
||||||
|
x = f(x)
|
||||||
|
x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
|
||||||
|
x = self.classifier(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
from timm.models import register_model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def repvit_m0_6(pretrained=False, num_classes = 1000, distillation=False):
|
||||||
|
"""
|
||||||
|
Constructs a MobileNetV3-Large model
|
||||||
|
"""
|
||||||
|
cfgs = [
|
||||||
|
[3, 2, 40, 1, 0, 1],
|
||||||
|
[3, 2, 40, 0, 0, 1],
|
||||||
|
[3, 2, 80, 0, 0, 2],
|
||||||
|
[3, 2, 80, 1, 0, 1],
|
||||||
|
[3, 2, 80, 0, 0, 1],
|
||||||
|
[3, 2, 160, 0, 1, 2],
|
||||||
|
[3, 2, 160, 1, 1, 1],
|
||||||
|
[3, 2, 160, 0, 1, 1],
|
||||||
|
[3, 2, 160, 1, 1, 1],
|
||||||
|
[3, 2, 160, 0, 1, 1],
|
||||||
|
[3, 2, 160, 1, 1, 1],
|
||||||
|
[3, 2, 160, 0, 1, 1],
|
||||||
|
[3, 2, 160, 1, 1, 1],
|
||||||
|
[3, 2, 160, 0, 1, 1],
|
||||||
|
[3, 2, 160, 0, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 2],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
]
|
||||||
|
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
|
||||||
|
|
||||||
|
|
||||||
|
def repvit_m0_9(pretrained=False, num_classes = 1000, distillation=False):
|
||||||
|
"""
|
||||||
|
Constructs a MobileNetV3-Large model
|
||||||
|
"""
|
||||||
|
cfgs = [
|
||||||
|
# k, t, c, SE, HS, s
|
||||||
|
[3, 2, 48, 1, 0, 1],
|
||||||
|
[3, 2, 48, 0, 0, 1],
|
||||||
|
[3, 2, 48, 0, 0, 1],
|
||||||
|
[3, 2, 96, 0, 0, 2],
|
||||||
|
[3, 2, 96, 1, 0, 1],
|
||||||
|
[3, 2, 96, 0, 0, 1],
|
||||||
|
[3, 2, 96, 0, 0, 1],
|
||||||
|
[3, 2, 192, 0, 1, 2],
|
||||||
|
[3, 2, 192, 1, 1, 1],
|
||||||
|
[3, 2, 192, 0, 1, 1],
|
||||||
|
[3, 2, 192, 1, 1, 1],
|
||||||
|
[3, 2, 192, 0, 1, 1],
|
||||||
|
[3, 2, 192, 1, 1, 1],
|
||||||
|
[3, 2, 192, 0, 1, 1],
|
||||||
|
[3, 2, 192, 1, 1, 1],
|
||||||
|
[3, 2, 192, 0, 1, 1],
|
||||||
|
[3, 2, 192, 1, 1, 1],
|
||||||
|
[3, 2, 192, 0, 1, 1],
|
||||||
|
[3, 2, 192, 1, 1, 1],
|
||||||
|
[3, 2, 192, 0, 1, 1],
|
||||||
|
[3, 2, 192, 1, 1, 1],
|
||||||
|
[3, 2, 192, 0, 1, 1],
|
||||||
|
[3, 2, 192, 0, 1, 1],
|
||||||
|
[3, 2, 384, 0, 1, 2],
|
||||||
|
[3, 2, 384, 1, 1, 1],
|
||||||
|
[3, 2, 384, 0, 1, 1]
|
||||||
|
]
|
||||||
|
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
|
||||||
|
|
||||||
|
|
||||||
|
def repvit_m1_0(pretrained=False, num_classes = 1000, distillation=False):
|
||||||
|
"""
|
||||||
|
Constructs a MobileNetV3-Large model
|
||||||
|
"""
|
||||||
|
cfgs = [
|
||||||
|
# k, t, c, SE, HS, s
|
||||||
|
[3, 2, 56, 1, 0, 1],
|
||||||
|
[3, 2, 56, 0, 0, 1],
|
||||||
|
[3, 2, 56, 0, 0, 1],
|
||||||
|
[3, 2, 112, 0, 0, 2],
|
||||||
|
[3, 2, 112, 1, 0, 1],
|
||||||
|
[3, 2, 112, 0, 0, 1],
|
||||||
|
[3, 2, 112, 0, 0, 1],
|
||||||
|
[3, 2, 224, 0, 1, 2],
|
||||||
|
[3, 2, 224, 1, 1, 1],
|
||||||
|
[3, 2, 224, 0, 1, 1],
|
||||||
|
[3, 2, 224, 1, 1, 1],
|
||||||
|
[3, 2, 224, 0, 1, 1],
|
||||||
|
[3, 2, 224, 1, 1, 1],
|
||||||
|
[3, 2, 224, 0, 1, 1],
|
||||||
|
[3, 2, 224, 1, 1, 1],
|
||||||
|
[3, 2, 224, 0, 1, 1],
|
||||||
|
[3, 2, 224, 1, 1, 1],
|
||||||
|
[3, 2, 224, 0, 1, 1],
|
||||||
|
[3, 2, 224, 1, 1, 1],
|
||||||
|
[3, 2, 224, 0, 1, 1],
|
||||||
|
[3, 2, 224, 1, 1, 1],
|
||||||
|
[3, 2, 224, 0, 1, 1],
|
||||||
|
[3, 2, 224, 0, 1, 1],
|
||||||
|
[3, 2, 448, 0, 1, 2],
|
||||||
|
[3, 2, 448, 1, 1, 1],
|
||||||
|
[3, 2, 448, 0, 1, 1]
|
||||||
|
]
|
||||||
|
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def repvit_m1_1(pretrained=False, num_classes = 1000, distillation=False):
|
||||||
|
"""
|
||||||
|
Constructs a MobileNetV3-Large model
|
||||||
|
"""
|
||||||
|
cfgs = [
|
||||||
|
# k, t, c, SE, HS, s
|
||||||
|
[3, 2, 64, 1, 0, 1],
|
||||||
|
[3, 2, 64, 0, 0, 1],
|
||||||
|
[3, 2, 64, 0, 0, 1],
|
||||||
|
[3, 2, 128, 0, 0, 2],
|
||||||
|
[3, 2, 128, 1, 0, 1],
|
||||||
|
[3, 2, 128, 0, 0, 1],
|
||||||
|
[3, 2, 128, 0, 0, 1],
|
||||||
|
[3, 2, 256, 0, 1, 2],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 512, 0, 1, 2],
|
||||||
|
[3, 2, 512, 1, 1, 1],
|
||||||
|
[3, 2, 512, 0, 1, 1]
|
||||||
|
]
|
||||||
|
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def repvit_m1_5(pretrained=False, num_classes = 1000, distillation=False):
|
||||||
|
"""
|
||||||
|
Constructs a MobileNetV3-Large model
|
||||||
|
"""
|
||||||
|
cfgs = [
|
||||||
|
# k, t, c, SE, HS, s
|
||||||
|
[3, 2, 64, 1, 0, 1],
|
||||||
|
[3, 2, 64, 0, 0, 1],
|
||||||
|
[3, 2, 64, 1, 0, 1],
|
||||||
|
[3, 2, 64, 0, 0, 1],
|
||||||
|
[3, 2, 64, 0, 0, 1],
|
||||||
|
[3, 2, 128, 0, 0, 2],
|
||||||
|
[3, 2, 128, 1, 0, 1],
|
||||||
|
[3, 2, 128, 0, 0, 1],
|
||||||
|
[3, 2, 128, 1, 0, 1],
|
||||||
|
[3, 2, 128, 0, 0, 1],
|
||||||
|
[3, 2, 128, 0, 0, 1],
|
||||||
|
[3, 2, 256, 0, 1, 2],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 1, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 256, 0, 1, 1],
|
||||||
|
[3, 2, 512, 0, 1, 2],
|
||||||
|
[3, 2, 512, 1, 1, 1],
|
||||||
|
[3, 2, 512, 0, 1, 1],
|
||||||
|
[3, 2, 512, 1, 1, 1],
|
||||||
|
[3, 2, 512, 0, 1, 1]
|
||||||
|
]
|
||||||
|
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def repvit_m2_3(pretrained=False, num_classes = 1000, distillation=False):
|
||||||
|
"""
|
||||||
|
Constructs a MobileNetV3-Large model
|
||||||
|
"""
|
||||||
|
cfgs = [
|
||||||
|
# k, t, c, SE, HS, s
|
||||||
|
[3, 2, 80, 1, 0, 1],
|
||||||
|
[3, 2, 80, 0, 0, 1],
|
||||||
|
[3, 2, 80, 1, 0, 1],
|
||||||
|
[3, 2, 80, 0, 0, 1],
|
||||||
|
[3, 2, 80, 1, 0, 1],
|
||||||
|
[3, 2, 80, 0, 0, 1],
|
||||||
|
[3, 2, 80, 0, 0, 1],
|
||||||
|
[3, 2, 160, 0, 0, 2],
|
||||||
|
[3, 2, 160, 1, 0, 1],
|
||||||
|
[3, 2, 160, 0, 0, 1],
|
||||||
|
[3, 2, 160, 1, 0, 1],
|
||||||
|
[3, 2, 160, 0, 0, 1],
|
||||||
|
[3, 2, 160, 1, 0, 1],
|
||||||
|
[3, 2, 160, 0, 0, 1],
|
||||||
|
[3, 2, 160, 0, 0, 1],
|
||||||
|
[3, 2, 320, 0, 1, 2],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 1, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
# [3, 2, 320, 1, 1, 1],
|
||||||
|
# [3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 320, 0, 1, 1],
|
||||||
|
[3, 2, 640, 0, 1, 2],
|
||||||
|
[3, 2, 640, 1, 1, 1],
|
||||||
|
[3, 2, 640, 0, 1, 1],
|
||||||
|
# [3, 2, 640, 1, 1, 1],
|
||||||
|
# [3, 2, 640, 0, 1, 1]
|
||||||
|
]
|
||||||
|
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
net = repvit_m1_1()
|
||||||
|
print('mobilenetv3:\n', net)
|
||||||
|
print('Total params: %.2fM' % (sum(p.numel() for p in net.parameters())/1000000.0))
|
||||||
|
input_size=(4, 3, 224, 224)
|
||||||
|
# pip install --upgrade git+https://github.com/kuan-wang/pytorch-OpCounter.git
|
||||||
|
from thop import profile
|
||||||
|
input_tensor = torch.randn(input_size)
|
||||||
|
flops, params = profile(net, inputs=(input_tensor,))
|
||||||
|
# print(flops)
|
||||||
|
# print(params)
|
||||||
|
print('Total params: %.2fM' % (params/1000000.0))
|
||||||
|
print('Total flops: %.2f GMACs' % ((flops/1000000000.0) / 2.0))
|
||||||
|
x = torch.randn(input_size)
|
||||||
|
out = net(x)
|
Loading…
Reference in New Issue
Block a user