From 49b2110b5f1b84ef3c68ffeadaf8b7d0b3f80122 Mon Sep 17 00:00:00 2001 From: yoiannis Date: Sun, 9 Mar 2025 22:36:22 +0800 Subject: [PATCH] update code for kd --- .gitignore | 2 + FED.py | 179 ++++++++---- dataset/recover.py | 157 ++++++++++ dataset/split.py | 60 ++++ dataset/splitbc_compsition.py | 184 ++++++++++++ model/__init__.py | 0 model/mobilenetv3.py | 248 ++++++++++++++++ model/repvit.py | 520 ++++++++++++++++++++++++++++++++++ main.py => test.py | 0 9 files changed, 1295 insertions(+), 55 deletions(-) create mode 100644 dataset/recover.py create mode 100644 dataset/split.py create mode 100644 dataset/splitbc_compsition.py create mode 100644 model/__init__.py create mode 100644 model/mobilenetv3.py create mode 100644 model/repvit.py rename main.py => test.py (100%) diff --git a/.gitignore b/.gitignore index 1269488..a06a09e 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ data +*.pth +*.pyc diff --git a/FED.py b/FED.py index a62b3e9..df6688a 100644 --- a/FED.py +++ b/FED.py @@ -6,45 +6,28 @@ from torch.utils.data import DataLoader, Subset import numpy as np import copy +from tqdm import tqdm + +from model.repvit import repvit_m1_1 +from model.mobilenetv3 import MobileNetV3 + # 配置参数 -NUM_CLIENTS = 10 +NUM_CLIENTS = 4 NUM_ROUNDS = 3 -CLIENT_EPOCHS = 2 +CLIENT_EPOCHS = 5 BATCH_SIZE = 32 TEMP = 2.0 # 蒸馏温度 # 设备配置 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# 定义中心大模型 -class ServerModel(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(784, 512) - self.fc2 = nn.Linear(512, 256) - self.fc3 = nn.Linear(256, 10) - - def forward(self, x): - x = x.view(-1, 784) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - return self.fc3(x) - -# 定义端侧小模型 -class ClientModel(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(784, 64) - self.fc2 = nn.Linear(64, 10) - - def forward(self, x): - x = x.view(-1, 784) - x = F.relu(self.fc1(x)) - return self.fc2(x) - # 数据准备 def prepare_data(num_clients): - transform = transforms.Compose([transforms.ToTensor()]) + transform = transforms.Compose([ + transforms.Resize((224, 224)), # 将图像调整为 224x224 + transforms.Grayscale(num_output_channels=3), + transforms.ToTensor() + ]) train_set = datasets.MNIST("./data", train=True, download=True, transform=transform) # 非IID数据划分(每个客户端2个类别) @@ -67,30 +50,80 @@ def client_train(client_model, server_model, dataset): optimizer = torch.optim.SGD(client_model.parameters(), lr=0.1) loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - for _ in range(CLIENT_EPOCHS): - for data, target in loader: + # 训练进度条 + progress_bar = tqdm(total=CLIENT_EPOCHS*len(loader), + desc="Client Training", + unit="batch") + + for epoch in range(CLIENT_EPOCHS): + epoch_loss = 0.0 + task_loss = 0.0 + distill_loss = 0.0 + correct = 0 + total = 0 + + for batch_idx, (data, target) in enumerate(loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() - # 获取小模型输出 + # 前向传播 client_output = client_model(data) - # 获取大模型输出(知识蒸馏) + # 获取教师模型输出 with torch.no_grad(): server_output = server_model(data) - # 计算联合损失 + # 计算损失 loss_task = F.cross_entropy(client_output, target) loss_distill = F.kl_div( F.log_softmax(client_output/TEMP, dim=1), F.softmax(server_output/TEMP, dim=1), reduction="batchmean" ) * (TEMP**2) - total_loss = loss_task + loss_distill + + # 反向传播 total_loss.backward() optimizer.step() + + # 统计指标 + epoch_loss += total_loss.item() + task_loss += loss_task.item() + distill_loss += loss_distill.item() + _, predicted = torch.max(client_output.data, 1) + correct += (predicted == target).sum().item() + total += target.size(0) + + # 实时更新进度条 + progress_bar.set_postfix({ + "Epoch": f"{epoch+1}/{CLIENT_EPOCHS}", + "Batch": f"{batch_idx+1}/{len(loader)}", + "Loss": f"{total_loss.item():.4f}", + "Acc": f"{100*correct/total:.2f}%\n", + }) + progress_bar.update(1) + + # 每10个batch打印详细信息 + if (batch_idx + 1) % 10 == 0: + progress_bar.write(f"\nEpoch {epoch+1} | Batch {batch_idx+1}") + progress_bar.write(f"Task Loss: {loss_task:.4f}") + progress_bar.write(f"Distill Loss: {loss_distill:.4f}") + progress_bar.write(f"Total Loss: {total_loss:.4f}") + progress_bar.write(f"Batch Accuracy: {100*correct/total:.2f}%\n") + # 每个epoch结束打印汇总信息 + avg_loss = epoch_loss / len(loader) + avg_task = task_loss / len(loader) + avg_distill = distill_loss / len(loader) + epoch_acc = 100 * correct / total + print(f"\n{'='*40}") + print(f"Epoch {epoch+1} Summary:") + print(f"Average Loss: {avg_loss:.4f}") + print(f"Task Loss: {avg_task:.4f}") + print(f"Distill Loss: {avg_distill:.4f}") + print(f"Training Accuracy: {epoch_acc:.2f}%") + print(f"{'='*40}\n") + progress_bar.close() return client_model.state_dict() # 模型参数聚合(FedAvg) @@ -105,7 +138,10 @@ def server_update(server_model, client_models, public_loader): server_model.train() optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001) - for data, _ in public_loader: + total_loss = 0.0 + progress_bar = tqdm(public_loader, desc="Server Updating", unit="batch") + + for batch_idx, (data, _) in enumerate(progress_bar): data = data.to(device) optimizer.zero_grad() @@ -122,8 +158,19 @@ def server_update(server_model, client_models, public_loader): reduction="batchmean" ) + # 反向传播 loss.backward() optimizer.step() + + # 更新统计信息 + total_loss += loss.item() + progress_bar.set_postfix({ + "Avg Loss": f"{total_loss/(batch_idx+1):.4f}", + "Current Loss": f"{loss.item():.4f}" + }) + + print(f"\nServer Update Complete | Average Loss: {total_loss/len(public_loader):.4f}\n") + def test_model(model, test_loader): model.eval() @@ -139,36 +186,54 @@ def test_model(model, test_loader): accuracy = 100 * correct / total return accuracy + # 主训练流程 def main(): # 初始化模型 - global_server_model = ServerModel().to(device) - client_models = [ClientModel().to(device) for _ in range(NUM_CLIENTS)] + global_server_model = repvit_m1_1(num_classes=10).to(device) + client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)] + + round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round") # 准备数据 client_datasets = prepare_data(NUM_CLIENTS) public_loader = DataLoader( datasets.MNIST("./data", train=False, download=True, - transform=transforms.ToTensor()), + transform= transforms.Compose([ + transforms.Resize((224, 224)), # 将图像调整为 224x224 + transforms.Grayscale(num_output_channels=3), + transforms.ToTensor() # 将图像转换为张量 + ])), batch_size=100, shuffle=True) + test_dataset = datasets.MNIST( + "./data", + train=False, + transform= transforms.Compose([ + transforms.Resize((224, 224)), # 将图像调整为 224x224 + transforms.Grayscale(num_output_channels=3), + transforms.ToTensor() # 将图像转换为张量 + ]) + ) + test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False) + for round in range(NUM_ROUNDS): + print(f"\n{'#'*50}") + print(f"Federated Round {round+1}/{NUM_ROUNDS}") + print(f"{'#'*50}") + # 客户端选择 - selected_clients = np.random.choice(NUM_CLIENTS, 5, replace=False) + selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False) + print(f"Selected Clients: {selected_clients}") # 客户端本地训练 client_params = [] for cid in selected_clients: - # 下载全局模型 + print(f"\nTraining Client {cid}") local_model = copy.deepcopy(client_models[cid]) local_model.load_state_dict(client_models[cid].state_dict()) - # 本地训练 - updated_params = client_train( - local_model, - global_server_model, - client_datasets[cid] - ) + updated_params = client_train(local_model, global_server_model, client_datasets[cid]) client_params.append(updated_params) # 模型聚合 @@ -177,8 +242,18 @@ def main(): model.load_state_dict(global_client_params) # 服务器知识更新 + print("\nServer Updating...") server_update(global_server_model, client_models, public_loader) + # 测试模型性能 + server_acc = test_model(global_server_model, test_loader) + client_acc = test_model(client_models[0], test_loader) + print(f"\nRound {round+1} Performance:") + print(f"Global Model Accuracy: {server_acc:.2f}%") + print(f"Client Model Accuracy: {client_acc:.2f}%") + + round_progress.update(1) + print(f"Round {round+1} completed") print("Training completed!") @@ -189,21 +264,15 @@ def main(): print("Models saved successfully.") # 创建测试数据加载器 - test_dataset = datasets.MNIST( - "./data", - train=False, - transform=transforms.ToTensor() - ) - test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False) # 测试服务器模型 - server_model = ServerModel().to(device) + server_model = repvit_m1_1(num_classes=10).to(device) server_model.load_state_dict(torch.load("server_model.pth")) server_acc = test_model(server_model, test_loader) print(f"Server Model Test Accuracy: {server_acc:.2f}%") # 测试客户端模型 - client_model = ClientModel().to(device) + client_model = MobileNetV3(n_class=10).to(device) client_model.load_state_dict(torch.load("client_model.pth")) client_acc = test_model(client_model, test_loader) print(f"Client Model Test Accuracy: {client_acc:.2f}%") diff --git a/dataset/recover.py b/dataset/recover.py new file mode 100644 index 0000000..58eb430 --- /dev/null +++ b/dataset/recover.py @@ -0,0 +1,157 @@ +import random +import cv2 +import numpy as np +import os + +# 假设我们有一个函数 `remove_background` 使用某种方法去背景,返回前景掩码和前景图像 +def remove_background(image_path): + # 加载图像 + source_image = cv2.imread(image_path) + # 这里可以使用一个预训练的模型去背景,比如 U2-Net。为了简化,假设我们得到一个二值掩码 + # 掩码生成逻辑可以替换为实际的模型推理 + # 转换为灰度图像 + GRAY = cv2.cvtColor(source_image, cv2.COLOR_BGR2GRAY) + + # 二值化处理 + _, mask_threshold = cv2.threshold(GRAY, 0, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU) + + # 定义结构元素 + element = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)) + element1 = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2)) + + # 膨胀和腐蚀操作 + mask_dilate = cv2.dilate(mask_threshold, element) + mask_erode = cv2.erode(mask_dilate, element1) + + # 计算非零像素数量 + count2 = cv2.countNonZero(mask_erode) + + # 查找轮廓 + contours, hierarchy = cv2.findContours(mask_erode, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS) + + # 过滤轮廓 + contours = [c for c in contours if cv2.contourArea(c) >= count2 * 0.3] + + # 绘制轮廓 + mask = np.zeros_like(mask_erode) + cv2.drawContours(mask, contours, -1, 1, -1) + + # 将掩码转换为3通道 + mask_cvtColor = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + + # 应用掩码 + source_image_multiply = cv2.multiply(source_image, mask_cvtColor) + + # 转换为HSV颜色空间 + imgHSV = cv2.cvtColor(source_image_multiply, cv2.COLOR_BGR2HSV) + + # 定义HSV范围 + scalarL = np.array([0, 46, 46]) + scalarH = np.array([45, 255, 255]) + + # 根据HSV范围生成掩码 + mask_inRange = cv2.inRange(imgHSV, scalarL, scalarH) + + # 二值化处理 + _, mask_tthreshold = cv2.threshold(mask_inRange, 0, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU) + + # 中值滤波 + mask_medianBlur = cv2.medianBlur(mask_tthreshold, 7) + + # 将掩码转换为3通道 + mask_scvtColor = cv2.cvtColor(mask_medianBlur, cv2.COLOR_GRAY2BGR) + + # 应用掩码 + source_image = cv2.multiply(source_image, mask_scvtColor) + + return source_image + +def synthesize_background(foreground, background): + + # 创建前景掩膜(非黑色区域) + gray = cv2.cvtColor(foreground, cv2.COLOR_BGR2GRAY) + _, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY) # 阈值设为1以保留所有非纯黑像素 + + mask = cv2.GaussianBlur(mask, (5,5), 0) # 高斯模糊柔化边缘 + _, mask = cv2.threshold(mask, 200, 255, cv2.THRESH_BINARY) # 重新二值化 + + # 精准形态学处理 + kernel = np.ones((2,2), np.uint8) + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1) # 闭运算填充小孔 + + # 反转掩膜用于获取背景区域 + mask_inv = cv2.bitwise_not(mask) + + # 提取背景和前景的ROI区域 + background_roi = cv2.bitwise_and(background, background, mask=mask_inv) + foreground_roi = cv2.bitwise_and(foreground, foreground, mask=mask) + + # 合成图像 + result = cv2.add(foreground_roi, background_roi) + + return result + +def edge_fill2(img): + (height, width, p) = img.shape + H = 2384 + W = 1560 + top = bottom=int((W - height) / 2) + left= right= int((H - width) / 2) + img_result = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0) + return img_result + + +def process_images(input_folder, background_image_path, output_base): + """ + 递归处理所有子文件夹并保持目录结构 + """ + # 预处理背景路径(只需执行一次) + if os.path.isfile(background_image_path): + background_paths = [background_image_path] + else: + valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp'] + background_paths = [ + os.path.join(background_image_path, f) + for f in os.listdir(background_image_path) + if os.path.splitext(f)[1].lower() in valid_ext + ] + + # 递归遍历输入目录 + for root, dirs, files in os.walk(input_folder): + # 计算相对路径 + relative_path = os.path.relpath(root, input_folder) + + # 创建对应的输出目录 + output_dir = os.path.join(output_base, relative_path) + os.makedirs(output_dir, exist_ok=True) + + # 处理当前目录的文件 + for filename in files: + input_path = os.path.join(root, filename) + output_path = os.path.join(output_dir, filename) + + # 跳过非图像文件 + if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')): + continue + + try: + # 去背景处理 + foreground = remove_background(input_path) + + + result = edge_fill2(foreground) + + # 保存结果 + cv2.imwrite(output_path, result) + print(f"Processed: {input_path} -> {output_path}") + + except Exception as e: + print(f"Error processing {input_path}: {str(e)}") + + +# 使用示例 +input_directory = 'L:/Tobacco/2023_JY/20230821/SOURCE' +background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B' +output_directory = 'L:/Test' + +process_images(input_directory, background_image_path, output_directory) \ No newline at end of file diff --git a/dataset/split.py b/dataset/split.py new file mode 100644 index 0000000..f58a46f --- /dev/null +++ b/dataset/split.py @@ -0,0 +1,60 @@ +import os +import shutil +import random + +def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15): + # 确保比例总和为1 + assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must sum to 1" + + # 创建输出目录 + os.makedirs(output_dir, exist_ok=True) + train_dir = os.path.join(output_dir, 'train') + val_dir = os.path.join(output_dir, 'val') + test_dir = os.path.join(output_dir, 'test') + + os.makedirs(train_dir, exist_ok=True) + os.makedirs(val_dir, exist_ok=True) + os.makedirs(test_dir, exist_ok=True) + + categories = os.listdir(base_dir) + + for category in categories: + category_path = os.path.join(base_dir, category) + + if not os.path.isdir(category_path): + continue # 跳过非目录文件 + + # 获取类别文件夹中的所有图像文件 + images = [f for f in os.listdir(category_path) if os.path.isfile(os.path.join(category_path, f))] + + # 打乱图像文件顺序 + random.shuffle(images) + + # 计算分割点 + total_images = len(images) + train_end = int(train_ratio * total_images) + val_end = train_end + int(val_ratio * total_images) + + # 创建类别文件夹于各个输出目录中 + os.makedirs(os.path.join(train_dir, category), exist_ok=True) + os.makedirs(os.path.join(val_dir, category), exist_ok=True) + os.makedirs(os.path.join(test_dir, category), exist_ok=True) + + # 分配图像到训练集 + for i in range(train_end): + shutil.copy(os.path.join(category_path, images[i]), os.path.join(train_dir, category)) + + # 分配图像到验证集 + for i in range(train_end, val_end): + shutil.copy(os.path.join(category_path, images[i]), os.path.join(val_dir, category)) + + # 分配图像到测试集 + for i in range(val_end, total_images): + shutil.copy(os.path.join(category_path, images[i]), os.path.join(test_dir, category)) + + print("Dataset successfully split into train, validation, and test sets.") + +# 使用示例 +base_directory = 'F:/dataset/02.TA_EC/EC27/JY_A' +output_directory = 'F:/dataset/02.TA_EC/datasets/EC27' +create_dataset_splits(base_directory, output_directory) \ No newline at end of file diff --git a/dataset/splitbc_compsition.py b/dataset/splitbc_compsition.py new file mode 100644 index 0000000..4a7f1e7 --- /dev/null +++ b/dataset/splitbc_compsition.py @@ -0,0 +1,184 @@ +import random +import cv2 +import numpy as np +import os + +# 假设我们有一个函数 `remove_background` 使用某种方法去背景,返回前景掩码和前景图像 +def remove_background(image_path): + # 加载图像 + source_image = cv2.imread(image_path) + # 这里可以使用一个预训练的模型去背景,比如 U2-Net。为了简化,假设我们得到一个二值掩码 + # 掩码生成逻辑可以替换为实际的模型推理 + # 转换为灰度图像 + GRAY = cv2.cvtColor(source_image, cv2.COLOR_BGR2GRAY) + + # 二值化处理 + _, mask_threshold = cv2.threshold(GRAY, 0, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU) + + # 定义结构元素 + element = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)) + element1 = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2)) + + # 膨胀和腐蚀操作 + mask_dilate = cv2.dilate(mask_threshold, element) + mask_erode = cv2.erode(mask_dilate, element1) + + # 计算非零像素数量 + count2 = cv2.countNonZero(mask_erode) + + # 查找轮廓 + contours, hierarchy = cv2.findContours(mask_erode, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS) + + # 过滤轮廓 + contours = [c for c in contours if cv2.contourArea(c) >= count2 * 0.3] + + # 绘制轮廓 + mask = np.zeros_like(mask_erode) + cv2.drawContours(mask, contours, -1, 1, -1) + + # 将掩码转换为3通道 + mask_cvtColor = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) + + # 应用掩码 + source_image_multiply = cv2.multiply(source_image, mask_cvtColor) + + # 转换为HSV颜色空间 + imgHSV = cv2.cvtColor(source_image_multiply, cv2.COLOR_BGR2HSV) + + # 定义HSV范围 + scalarL = np.array([0, 46, 46]) + scalarH = np.array([45, 255, 255]) + + # 根据HSV范围生成掩码 + mask_inRange = cv2.inRange(imgHSV, scalarL, scalarH) + + # 二值化处理 + _, mask_tthreshold = cv2.threshold(mask_inRange, 0, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU) + + # 中值滤波 + mask_medianBlur = cv2.medianBlur(mask_tthreshold, 7) + + # 将掩码转换为3通道 + mask_scvtColor = cv2.cvtColor(mask_medianBlur, cv2.COLOR_GRAY2BGR) + + # 应用掩码 + source_image = cv2.multiply(source_image, mask_scvtColor) + + return source_image + +def synthesize_background(foreground, background_image_path): + # 验证输入路径有效性 + if not os.path.exists(background_image_path): + raise FileNotFoundError(f"指定的背景路径不存在: {background_image_path}") + + # 获取所有可用背景图像路径 + background_paths = [] + if os.path.isfile(background_image_path): + # 如果是单个图像文件 + background_paths = [background_image_path] + elif os.path.isdir(background_image_path): + # 如果是目录,搜索常见图像格式 + valid_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.webp'] + for filename in os.listdir(background_image_path): + if os.path.splitext(filename)[1].lower() in valid_extensions: + background_paths.append(os.path.join(background_image_path, filename)) + + # 验证找到的图像文件数量 + if not background_paths: + raise ValueError(f"目录中未找到支持的图像文件: {background_image_path}") + + # 随机选择背景图像 + selected_bg_path = random.choice(background_paths) + + # 加载并验证背景图像 + background = cv2.imread(selected_bg_path) + + # 调整背景大小与前景一致 + background = cv2.resize(background, (foreground.shape[1], foreground.shape[0])) + + # 创建前景掩膜(非黑色区域) + gray = cv2.cvtColor(foreground, cv2.COLOR_BGR2GRAY) + _, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY) # 阈值设为1以保留所有非纯黑像素 + + mask = cv2.GaussianBlur(mask, (5,5), 0) # 高斯模糊柔化边缘 + _, mask = cv2.threshold(mask, 200, 255, cv2.THRESH_BINARY) # 重新二值化 + + # 精准形态学处理 + kernel = np.ones((2,2), np.uint8) + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1) # 闭运算填充小孔 + + # 反转掩膜用于获取背景区域 + mask_inv = cv2.bitwise_not(mask) + + # 提取背景和前景的ROI区域 + background_roi = cv2.bitwise_and(background, background, mask=mask_inv) + foreground_roi = cv2.bitwise_and(foreground, foreground, mask=mask) + + # 合成图像 + result = cv2.add(foreground_roi, background_roi) + + # 保存结果 + return result + +def process_images(input_folder, background_image_path, output_base): + """ + 递归处理所有子文件夹并保持目录结构 + """ + # 预处理背景路径(只需执行一次) + if os.path.isfile(background_image_path): + background_paths = [background_image_path] + else: + valid_ext = ['.jpg', '.jpeg', '.png', '.bmp', '.webp'] + background_paths = [ + os.path.join(background_image_path, f) + for f in os.listdir(background_image_path) + if os.path.splitext(f)[1].lower() in valid_ext + ] + + # 递归遍历输入目录 + for root, dirs, files in os.walk(input_folder): + # 计算相对路径 + relative_path = os.path.relpath(root, input_folder) + + # 创建对应的输出目录 + output_dir = os.path.join(output_base, relative_path) + os.makedirs(output_dir, exist_ok=True) + + # 处理当前目录的文件 + for filename in files: + input_path = os.path.join(root, filename) + output_path = os.path.join(output_dir, filename) + + # 跳过非图像文件 + if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')): + continue + + try: + # 去背景处理 + foreground = remove_background(input_path) + + # 随机选择背景(每次处理都重新选择) + bg_path = random.choice(background_paths) + background = cv2.imread(bg_path) + + # 调整背景尺寸 + h, w = foreground.shape[:2] + resized_bg = cv2.resize(background, (w, h)) + + # 合成背景 + result = synthesize_background(foreground, bg_path) + + # 保存结果 + cv2.imwrite(output_path, result) + print(f"Processed: {input_path} -> {output_path}") + + except Exception as e: + print(f"Error processing {input_path}: {str(e)}") + + +# 使用示例 +input_directory = 'F:/dataset/02.TA_EC/EC27/JY_A/' +background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B' +output_directory = 'F:/dataset/02.TA_EC/rundata/test' + +process_images(input_directory, background_image_path, output_directory) \ No newline at end of file diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/mobilenetv3.py b/model/mobilenetv3.py new file mode 100644 index 0000000..4692cf9 --- /dev/null +++ b/model/mobilenetv3.py @@ -0,0 +1,248 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +__all__ = ['MobileNetV3', 'mobilenetv3'] + + +def conv_bn(inp, oup, stride, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU): + return nn.Sequential( + conv_layer(inp, oup, 3, stride, 1, bias=False), + norm_layer(oup), + nlin_layer(inplace=True) + ) + + +def conv_1x1_bn(inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU): + return nn.Sequential( + conv_layer(inp, oup, 1, 1, 0, bias=False), + norm_layer(oup), + nlin_layer(inplace=True) + ) + + +class Hswish(nn.Module): + def __init__(self, inplace=True): + super(Hswish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x * F.relu6(x + 3., inplace=self.inplace) / 6. + + +class Hsigmoid(nn.Module): + def __init__(self, inplace=True): + super(Hsigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return F.relu6(x + 3., inplace=self.inplace) / 6. + + +class SEModule(nn.Module): + def __init__(self, channel, reduction=4): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + Hsigmoid() + # nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + +class Identity(nn.Module): + def __init__(self, channel): + super(Identity, self).__init__() + + def forward(self, x): + return x + + +def make_divisible(x, divisible_by=8): + import numpy as np + return int(np.ceil(x * 1. / divisible_by) * divisible_by) + + +class MobileBottleneck(nn.Module): + def __init__(self, inp, oup, kernel, stride, exp, se=False, nl='RE'): + super(MobileBottleneck, self).__init__() + assert stride in [1, 2] + assert kernel in [3, 5] + padding = (kernel - 1) // 2 + self.use_res_connect = stride == 1 and inp == oup + + conv_layer = nn.Conv2d + norm_layer = nn.BatchNorm2d + if nl == 'RE': + nlin_layer = nn.ReLU # or ReLU6 + elif nl == 'HS': + nlin_layer = Hswish + else: + raise NotImplementedError + if se: + SELayer = SEModule + else: + SELayer = Identity + + self.conv = nn.Sequential( + # pw + conv_layer(inp, exp, 1, 1, 0, bias=False), + norm_layer(exp), + nlin_layer(inplace=True), + # dw + conv_layer(exp, exp, kernel, stride, padding, groups=exp, bias=False), + norm_layer(exp), + SELayer(exp), + nlin_layer(inplace=True), + # pw-linear + conv_layer(exp, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV3(nn.Module): + def __init__(self, n_class=1000, input_size=224, dropout=0.8, mode='small', width_mult=1.0): + super(MobileNetV3, self).__init__() + input_channel = 16 + last_channel = 1280 + if mode == 'large': + # refer to Table 1 in paper + mobile_setting = [ + # k, exp, c, se, nl, s, + [3, 16, 16, False, 'RE', 1], + [3, 64, 24, False, 'RE', 2], + [3, 72, 24, False, 'RE', 1], + [5, 72, 40, True, 'RE', 2], + [5, 120, 40, True, 'RE', 1], + [5, 120, 40, True, 'RE', 1], + [3, 240, 80, False, 'HS', 2], + [3, 200, 80, False, 'HS', 1], + [3, 184, 80, False, 'HS', 1], + [3, 184, 80, False, 'HS', 1], + [3, 480, 112, True, 'HS', 1], + [3, 672, 112, True, 'HS', 1], + [5, 672, 160, True, 'HS', 2], + [5, 960, 160, True, 'HS', 1], + [5, 960, 160, True, 'HS', 1], + ] + elif mode == 'small': + # refer to Table 2 in paper + mobile_setting = [ + # k, exp, c, se, nl, s, + [3, 16, 16, True, 'RE', 2], + [3, 72, 24, False, 'RE', 2], + [3, 88, 24, False, 'RE', 1], + [5, 96, 40, True, 'HS', 2], + [5, 240, 40, True, 'HS', 1], + [5, 240, 40, True, 'HS', 1], + [5, 120, 48, True, 'HS', 1], + [5, 144, 48, True, 'HS', 1], + [5, 288, 96, True, 'HS', 2], + [5, 576, 96, True, 'HS', 1], + [5, 576, 96, True, 'HS', 1], + ] + else: + raise NotImplementedError + + # building first layer + assert input_size % 32 == 0 + last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel + self.features = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)] + self.classifier = [] + + # building mobile blocks + for k, exp, c, se, nl, s in mobile_setting: + output_channel = make_divisible(c * width_mult) + exp_channel = make_divisible(exp * width_mult) + self.features.append(MobileBottleneck(input_channel, output_channel, k, s, exp_channel, se, nl)) + input_channel = output_channel + + # building last several layers + if mode == 'large': + last_conv = make_divisible(960 * width_mult) + self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish)) + self.features.append(nn.AdaptiveAvgPool2d(1)) + self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) + self.features.append(Hswish(inplace=True)) + elif mode == 'small': + last_conv = make_divisible(576 * width_mult) + self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish)) + # self.features.append(SEModule(last_conv)) # refer to paper Table2, but I think this is a mistake + self.features.append(nn.AdaptiveAvgPool2d(1)) + self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) + self.features.append(Hswish(inplace=True)) + else: + raise NotImplementedError + + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(p=dropout), # refer to paper section 6 + nn.Linear(last_channel, n_class), + ) + + self._initialize_weights() + + def forward(self, x): + x = self.features(x) + x = x.mean(3).mean(2) + x = self.classifier(x) + return x + + def _initialize_weights(self): + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.zeros_(m.bias) + + +def mobilenetv3(pretrained=False, **kwargs): + model = MobileNetV3(**kwargs) + if pretrained: + state_dict = torch.load('mobilenetv3_small_67.4.pth.tar') + model.load_state_dict(state_dict, strict=True) + # raise NotImplementedError + return model + + +if __name__ == '__main__': + net = mobilenetv3() + print('mobilenetv3:\n', net) + print('Total params: %.2fM' % (sum(p.numel() for p in net.parameters())/1000000.0)) + input_size=(1, 3, 224, 224) + # pip install --upgrade git+https://github.com/kuan-wang/pytorch-OpCounter.git + from thop import profile + input_tensor = torch.randn(input_size) + flops, params = profile(net, inputs=(input_tensor,)) + # print(flops) + # print(params) + print('Total params: %.2fM' % (params/1000000.0)) + print('Total flops: %.2f GMACs' % ((flops/1000000000.0) / 2.0)) + x = torch.randn(input_size) + out = net(x) \ No newline at end of file diff --git a/model/repvit.py b/model/repvit.py new file mode 100644 index 0000000..99f4bd5 --- /dev/null +++ b/model/repvit.py @@ -0,0 +1,520 @@ +import torch.nn as nn + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + +from timm.models.layers import SqueezeExcite + +import torch + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', torch.nn.BatchNorm2d(b)) + torch.nn.init.constant_(self.bn.weight, bn_weight_init) + torch.nn.init.constant_(self.bn.bias, 0) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups, + device=c.weight.device) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + +class Residual(torch.nn.Module): + def __init__(self, m, drop=0.): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1, + device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + @torch.no_grad() + def fuse(self): + if isinstance(self.m, Conv2d_BN): + m = self.m.fuse() + assert(m.groups == m.in_channels) + identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1) + identity = torch.nn.functional.pad(identity, [1,1,1,1]) + m.weight += identity.to(m.weight.device) + return m + elif isinstance(self.m, torch.nn.Conv2d): + m = self.m + assert(m.groups != m.in_channels) + identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1) + identity = torch.nn.functional.pad(identity, [1,1,1,1]) + m.weight += identity.to(m.weight.device) + return m + else: + return self + + +class RepVGGDW(torch.nn.Module): + def __init__(self, ed) -> None: + super().__init__() + self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed) + self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed) + self.dim = ed + self.bn = torch.nn.BatchNorm2d(ed) + + def forward(self, x): + return self.bn((self.conv(x) + self.conv1(x)) + x) + + @torch.no_grad() + def fuse(self): + conv = self.conv.fuse() + conv1 = self.conv1 + + conv_w = conv.weight + conv_b = conv.bias + conv1_w = conv1.weight + conv1_b = conv1.bias + + conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1]) + + identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1]) + + final_conv_w = conv_w + conv1_w + identity + final_conv_b = conv_b + conv1_b + + conv.weight.data.copy_(final_conv_w) + conv.bias.data.copy_(final_conv_b) + + bn = self.bn + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = conv.weight * w[:, None, None, None] + b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + conv.weight.data.copy_(w) + conv.bias.data.copy_(b) + return conv + + +class RepViTBlock(nn.Module): + def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs): + super(RepViTBlock, self).__init__() + assert stride in [1, 2] + + self.identity = stride == 1 and inp == oup + assert(hidden_dim == 2 * inp) + + if stride == 2: + self.token_mixer = nn.Sequential( + Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp), + SqueezeExcite(inp, 0.25) if use_se else nn.Identity(), + Conv2d_BN(inp, oup, ks=1, stride=1, pad=0) + ) + self.channel_mixer = Residual(nn.Sequential( + # pw + Conv2d_BN(oup, 2 * oup, 1, 1, 0), + nn.GELU() if use_hs else nn.GELU(), + # pw-linear + Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0), + )) + else: + assert(self.identity) + self.token_mixer = nn.Sequential( + RepVGGDW(inp), + SqueezeExcite(inp, 0.25) if use_se else nn.Identity(), + ) + self.channel_mixer = Residual(nn.Sequential( + # pw + Conv2d_BN(inp, hidden_dim, 1, 1, 0), + nn.GELU() if use_hs else nn.GELU(), + # pw-linear + Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0), + )) + + def forward(self, x): + return self.channel_mixer(self.token_mixer(x)) + +from timm.models.vision_transformer import trunc_normal_ +class BN_Linear(torch.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', torch.nn.BatchNorm1d(a)) + self.add_module('l', torch.nn.Linear(a, b, bias=bias)) + trunc_normal_(self.l.weight, std=std) + if bias: + torch.nn.init.constant_(self.l.bias, 0) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + b = bn.bias - self.bn.running_mean * \ + self.bn.weight / (bn.running_var + bn.eps)**0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + +class Classfier(nn.Module): + def __init__(self, dim, num_classes, distillation=True): + super().__init__() + self.classifier = BN_Linear(dim, num_classes) if num_classes > 0 else torch.nn.Identity() + self.distillation = distillation + if distillation: + self.classifier_dist = BN_Linear(dim, num_classes) if num_classes > 0 else torch.nn.Identity() + + def forward(self, x): + if self.distillation: + x = self.classifier(x), self.classifier_dist(x) + if not self.training: + x = (x[0] + x[1]) / 2 + else: + x = self.classifier(x) + return x + + @torch.no_grad() + def fuse(self): + classifier = self.classifier.fuse() + if self.distillation: + classifier_dist = self.classifier_dist.fuse() + classifier.weight += classifier_dist.weight + classifier.bias += classifier_dist.bias + classifier.weight /= 2 + classifier.bias /= 2 + return classifier + else: + return classifier + +class RepViT(nn.Module): + def __init__(self, cfgs, num_classes=1000, distillation=False): + super(RepViT, self).__init__() + # setting of inverted residual blocks + self.cfgs = cfgs + + # building first layer + input_channel = self.cfgs[0][2] + patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(), + Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1)) + layers = [patch_embed] + # building inverted residual blocks + block = RepViTBlock + for k, t, c, use_se, use_hs, s in self.cfgs: + output_channel = _make_divisible(c, 8) + exp_size = _make_divisible(input_channel * t, 8) + layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs)) + input_channel = output_channel + self.features = nn.ModuleList(layers) + self.classifier = Classfier(output_channel, num_classes, distillation) + + def forward(self, x): + # x = self.features(x) + for f in self.features: + x = f(x) + x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1) + x = self.classifier(x) + return x + +from timm.models import register_model + + + +def repvit_m0_6(pretrained=False, num_classes = 1000, distillation=False): + """ + Constructs a MobileNetV3-Large model + """ + cfgs = [ + [3, 2, 40, 1, 0, 1], + [3, 2, 40, 0, 0, 1], + [3, 2, 80, 0, 0, 2], + [3, 2, 80, 1, 0, 1], + [3, 2, 80, 0, 0, 1], + [3, 2, 160, 0, 1, 2], + [3, 2, 160, 1, 1, 1], + [3, 2, 160, 0, 1, 1], + [3, 2, 160, 1, 1, 1], + [3, 2, 160, 0, 1, 1], + [3, 2, 160, 1, 1, 1], + [3, 2, 160, 0, 1, 1], + [3, 2, 160, 1, 1, 1], + [3, 2, 160, 0, 1, 1], + [3, 2, 160, 0, 1, 1], + [3, 2, 320, 0, 1, 2], + [3, 2, 320, 1, 1, 1], + ] + return RepViT(cfgs, num_classes=num_classes, distillation=distillation) + + +def repvit_m0_9(pretrained=False, num_classes = 1000, distillation=False): + """ + Constructs a MobileNetV3-Large model + """ + cfgs = [ + # k, t, c, SE, HS, s + [3, 2, 48, 1, 0, 1], + [3, 2, 48, 0, 0, 1], + [3, 2, 48, 0, 0, 1], + [3, 2, 96, 0, 0, 2], + [3, 2, 96, 1, 0, 1], + [3, 2, 96, 0, 0, 1], + [3, 2, 96, 0, 0, 1], + [3, 2, 192, 0, 1, 2], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 384, 0, 1, 2], + [3, 2, 384, 1, 1, 1], + [3, 2, 384, 0, 1, 1] + ] + return RepViT(cfgs, num_classes=num_classes, distillation=distillation) + + +def repvit_m1_0(pretrained=False, num_classes = 1000, distillation=False): + """ + Constructs a MobileNetV3-Large model + """ + cfgs = [ + # k, t, c, SE, HS, s + [3, 2, 56, 1, 0, 1], + [3, 2, 56, 0, 0, 1], + [3, 2, 56, 0, 0, 1], + [3, 2, 112, 0, 0, 2], + [3, 2, 112, 1, 0, 1], + [3, 2, 112, 0, 0, 1], + [3, 2, 112, 0, 0, 1], + [3, 2, 224, 0, 1, 2], + [3, 2, 224, 1, 1, 1], + [3, 2, 224, 0, 1, 1], + [3, 2, 224, 1, 1, 1], + [3, 2, 224, 0, 1, 1], + [3, 2, 224, 1, 1, 1], + [3, 2, 224, 0, 1, 1], + [3, 2, 224, 1, 1, 1], + [3, 2, 224, 0, 1, 1], + [3, 2, 224, 1, 1, 1], + [3, 2, 224, 0, 1, 1], + [3, 2, 224, 1, 1, 1], + [3, 2, 224, 0, 1, 1], + [3, 2, 224, 1, 1, 1], + [3, 2, 224, 0, 1, 1], + [3, 2, 224, 0, 1, 1], + [3, 2, 448, 0, 1, 2], + [3, 2, 448, 1, 1, 1], + [3, 2, 448, 0, 1, 1] + ] + return RepViT(cfgs, num_classes=num_classes, distillation=distillation) + + + +def repvit_m1_1(pretrained=False, num_classes = 1000, distillation=False): + """ + Constructs a MobileNetV3-Large model + """ + cfgs = [ + # k, t, c, SE, HS, s + [3, 2, 64, 1, 0, 1], + [3, 2, 64, 0, 0, 1], + [3, 2, 64, 0, 0, 1], + [3, 2, 128, 0, 0, 2], + [3, 2, 128, 1, 0, 1], + [3, 2, 128, 0, 0, 1], + [3, 2, 128, 0, 0, 1], + [3, 2, 256, 0, 1, 2], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 512, 0, 1, 2], + [3, 2, 512, 1, 1, 1], + [3, 2, 512, 0, 1, 1] + ] + return RepViT(cfgs, num_classes=num_classes, distillation=distillation) + + + +def repvit_m1_5(pretrained=False, num_classes = 1000, distillation=False): + """ + Constructs a MobileNetV3-Large model + """ + cfgs = [ + # k, t, c, SE, HS, s + [3, 2, 64, 1, 0, 1], + [3, 2, 64, 0, 0, 1], + [3, 2, 64, 1, 0, 1], + [3, 2, 64, 0, 0, 1], + [3, 2, 64, 0, 0, 1], + [3, 2, 128, 0, 0, 2], + [3, 2, 128, 1, 0, 1], + [3, 2, 128, 0, 0, 1], + [3, 2, 128, 1, 0, 1], + [3, 2, 128, 0, 0, 1], + [3, 2, 128, 0, 0, 1], + [3, 2, 256, 0, 1, 2], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 512, 0, 1, 2], + [3, 2, 512, 1, 1, 1], + [3, 2, 512, 0, 1, 1], + [3, 2, 512, 1, 1, 1], + [3, 2, 512, 0, 1, 1] + ] + return RepViT(cfgs, num_classes=num_classes, distillation=distillation) + + + + +def repvit_m2_3(pretrained=False, num_classes = 1000, distillation=False): + """ + Constructs a MobileNetV3-Large model + """ + cfgs = [ + # k, t, c, SE, HS, s + [3, 2, 80, 1, 0, 1], + [3, 2, 80, 0, 0, 1], + [3, 2, 80, 1, 0, 1], + [3, 2, 80, 0, 0, 1], + [3, 2, 80, 1, 0, 1], + [3, 2, 80, 0, 0, 1], + [3, 2, 80, 0, 0, 1], + [3, 2, 160, 0, 0, 2], + [3, 2, 160, 1, 0, 1], + [3, 2, 160, 0, 0, 1], + [3, 2, 160, 1, 0, 1], + [3, 2, 160, 0, 0, 1], + [3, 2, 160, 1, 0, 1], + [3, 2, 160, 0, 0, 1], + [3, 2, 160, 0, 0, 1], + [3, 2, 320, 0, 1, 2], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 320, 1, 1, 1], + [3, 2, 320, 0, 1, 1], + # [3, 2, 320, 1, 1, 1], + # [3, 2, 320, 0, 1, 1], + [3, 2, 320, 0, 1, 1], + [3, 2, 640, 0, 1, 2], + [3, 2, 640, 1, 1, 1], + [3, 2, 640, 0, 1, 1], + # [3, 2, 640, 1, 1, 1], + # [3, 2, 640, 0, 1, 1] + ] + return RepViT(cfgs, num_classes=num_classes, distillation=distillation) + +if __name__ == '__main__': + net = repvit_m1_1() + print('mobilenetv3:\n', net) + print('Total params: %.2fM' % (sum(p.numel() for p in net.parameters())/1000000.0)) + input_size=(4, 3, 224, 224) + # pip install --upgrade git+https://github.com/kuan-wang/pytorch-OpCounter.git + from thop import profile + input_tensor = torch.randn(input_size) + flops, params = profile(net, inputs=(input_tensor,)) + # print(flops) + # print(params) + print('Total params: %.2fM' % (params/1000000.0)) + print('Total flops: %.2f GMACs' % ((flops/1000000000.0) / 2.0)) + x = torch.randn(input_size) + out = net(x) \ No newline at end of file diff --git a/main.py b/test.py similarity index 100% rename from main.py rename to test.py