TA_EC/dataset/splitbc_compsition.py
2025-03-09 22:36:22 +08:00

184 lines
6.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
import cv2
import numpy as np
import os
def remove_background(image_path):
    """Black out the background of an image and return the masked image.

    This is a classical-CV stand-in for a learned segmentation model
    (the original author mentions U2-Net as a possible replacement).
    It works in two stages:

    1. An Otsu threshold on the grayscale image, cleaned up with a
       dilate/erode pass, keeps only the large connected regions.
    2. An HSV range filter on the stage-1 result, smoothed with a median
       blur, produces the final mask.

    Args:
        image_path: Path of the image file to process.

    Returns:
        The image as loaded by ``cv2.imread`` (BGR), multiplied by the final
        0/1 mask so that every pixel outside the mask is zero. Only the
        image is returned; the mask itself is not.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image file does not exist: {image_path}")

    source_image = cv2.imread(image_path)
    if source_image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Failed to decode image: {image_path}")

    # A contour must cover at least this fraction of all mask pixels to be kept.
    min_area_ratio = 0.3
    # HSV bounds passed to cv2.inRange (OpenCV stores 8-bit hue as 0-179).
    hsv_lower = np.array([0, 46, 46])
    hsv_upper = np.array([45, 255, 255])
    otsu_flags = cv2.THRESH_BINARY | cv2.THRESH_OTSU

    # ---- Stage 1: coarse mask from grayscale intensity -------------------
    gray = cv2.cvtColor(source_image, cv2.COLOR_BGR2GRAY)
    # maxval=1 gives a 0/1 mask that can be used directly as a multiplier.
    _, coarse_mask = cv2.threshold(gray, 0, 1, otsu_flags)

    # Dilate with a 5x5 kernel, then erode with a smaller 2x2 kernel.
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    erode_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    coarse_mask = cv2.erode(cv2.dilate(coarse_mask, dilate_kernel), erode_kernel)

    foreground_pixels = cv2.countNonZero(coarse_mask)

    # findContours returns (contours, hierarchy) on OpenCV 4.x and
    # (image, contours, hierarchy) on 3.x; [-2] is the contour list in both.
    contours = cv2.findContours(
        coarse_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS
    )[-2]
    large_contours = [
        contour
        for contour in contours
        if cv2.contourArea(contour) >= foreground_pixels * min_area_ratio
    ]

    # Rasterise the surviving contours as filled regions with value 1.
    region_mask = np.zeros_like(coarse_mask)
    cv2.drawContours(region_mask, large_contours, -1, 1, -1)

    region_mask_bgr = cv2.cvtColor(region_mask, cv2.COLOR_GRAY2BGR)
    coarse_foreground = cv2.multiply(source_image, region_mask_bgr)

    # ---- Stage 2: refine by colour in HSV space --------------------------
    hsv = cv2.cvtColor(coarse_foreground, cv2.COLOR_BGR2HSV)
    colour_mask = cv2.inRange(hsv, hsv_lower, hsv_upper)
    # Convert the 0/255 inRange output to a 0/1 multiplier.
    _, colour_mask = cv2.threshold(colour_mask, 0, 1, otsu_flags)
    # Median filter (7x7) removes isolated speckles from the mask.
    colour_mask = cv2.medianBlur(colour_mask, 7)

    colour_mask_bgr = cv2.cvtColor(colour_mask, cv2.COLOR_GRAY2BGR)
    return cv2.multiply(source_image, colour_mask_bgr)
def synthesize_background(foreground, background_image_path):
    """Composite a black-background foreground onto a background image.

    Every foreground pixel that is not (near-)black is kept; all other
    pixels are taken from the background, which is resized to the
    foreground's dimensions.

    Args:
        foreground: BGR image array whose background has been set to black
            (for example the output of ``remove_background``).
        background_image_path: Either a single image file, or a directory
            from which one image (.jpg/.jpeg/.png/.bmp/.webp) is picked at
            random.

    Returns:
        The composited BGR image, same height and width as ``foreground``.
        Nothing is written to disk.

    Raises:
        FileNotFoundError: If ``background_image_path`` does not exist.
        ValueError: If no usable background image is found, or the selected
            background cannot be decoded by OpenCV.
    """
    if not os.path.exists(background_image_path):
        raise FileNotFoundError(f"指定的背景路径不存在: {background_image_path}")

    # Collect candidate background files.
    background_paths = []
    if os.path.isfile(background_image_path):
        background_paths = [background_image_path]
    elif os.path.isdir(background_image_path):
        valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
        for filename in os.listdir(background_image_path):
            candidate = os.path.join(background_image_path, filename)
            has_image_ext = os.path.splitext(filename)[1].lower() in valid_extensions
            # Skip sub-directories that merely look like images by name.
            if has_image_ext and os.path.isfile(candidate):
                background_paths.append(candidate)

    if not background_paths:
        raise ValueError(f"目录中未找到支持的图像文件: {background_image_path}")

    selected_bg_path = random.choice(background_paths)

    background = cv2.imread(selected_bg_path)
    if background is None:
        # cv2.imread returns None (no exception) when decoding fails.
        raise ValueError(f"无法读取背景图像: {selected_bg_path}")

    # cv2.resize takes (width, height), i.e. (shape[1], shape[0]).
    background = cv2.resize(background, (foreground.shape[1], foreground.shape[0]))

    # Build the foreground mask: grayscale values above 1 count as foreground.
    gray = cv2.cvtColor(foreground, cv2.COLOR_BGR2GRAY)
    _, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    # Blur then re-threshold at 200 to get a clean binary mask again.
    mask = cv2.GaussianBlur(mask, (5, 5), 0)
    _, mask = cv2.threshold(mask, 200, 255, cv2.THRESH_BINARY)

    # Morphological closing with a 2x2 kernel fills pinholes in the mask.
    kernel = np.ones((2, 2), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1)

    # The inverted mask selects the pixels that come from the background.
    mask_inv = cv2.bitwise_not(mask)

    background_roi = cv2.bitwise_and(background, background, mask=mask_inv)
    foreground_roi = cv2.bitwise_and(foreground, foreground, mask=mask)

    # The two regions are disjoint, so adding them merges the images.
    return cv2.add(foreground_roi, background_roi)
def process_images(input_folder, background_image_path, output_base):
    """Replace the background of every image under a folder tree.

    Walks ``input_folder`` recursively, mirrors its directory structure
    under ``output_base``, and for each .png/.jpg/.jpeg/.bmp file removes
    the background, composites the result onto a randomly chosen background
    and writes it to the mirrored location under the same file name.

    Failures on individual images are reported with ``print`` and do not
    stop the run.

    Args:
        input_folder: Root directory containing the images to process.
        background_image_path: A single background image file, or a
            directory of background images (.jpg/.jpeg/.png/.bmp/.webp).
        output_base: Root directory for the results; created if needed.

    Raises:
        ValueError: If no background image is available, since no input
            image could be processed in that case.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp')
    background_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

    # Resolve the list of background candidates once, up front.
    if os.path.isfile(background_image_path):
        background_paths = [background_image_path]
    else:
        background_paths = []
        for name in os.listdir(background_image_path):
            candidate = os.path.join(background_image_path, name)
            has_image_ext = os.path.splitext(name)[1].lower() in background_extensions
            if has_image_ext and os.path.isfile(candidate):
                background_paths.append(candidate)

    if not background_paths:
        raise ValueError(
            f"No supported background images found in: {background_image_path}"
        )

    for root, _dirs, files in os.walk(input_folder):
        # Mirror the current sub-directory under the output root.
        relative_path = os.path.relpath(root, input_folder)
        output_dir = os.path.join(output_base, relative_path)
        os.makedirs(output_dir, exist_ok=True)

        for filename in files:
            if not filename.lower().endswith(image_extensions):
                continue

            input_path = os.path.join(root, filename)
            output_path = os.path.join(output_dir, filename)
            try:
                foreground = remove_background(input_path)
                # A fresh background is drawn for every image.
                bg_path = random.choice(background_paths)
                # synthesize_background loads and resizes the background itself.
                result = synthesize_background(foreground, bg_path)
                # cv2.imwrite returns False on failure instead of raising.
                if cv2.imwrite(output_path, result):
                    print(f"Processed: {input_path} -> {output_path}")
                else:
                    print(f"Error processing {input_path}: could not write {output_path}")
            except Exception as e:
                # Per-file boundary: report and continue with the next image.
                print(f"Error processing {input_path}: {e}")
# Usage example: paths are kept at module level so they stay importable.
input_directory = 'F:/dataset/02.TA_EC/EC27/JY_A/'
background_image_path = 'F:/dataset/02.TA_EC/rundata/BACKGROUND/ZY_B'
output_directory = 'F:/dataset/02.TA_EC/rundata/test'

# Only run the batch job when executed as a script, not on import.
if __name__ == "__main__":
    process_images(input_directory, background_image_path, output_directory)