TA_EC/dataset/splitdataset.py
2025-03-12 09:38:47 +08:00

42 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import random
import shutil
def split_dataset(input_folder, output_folder, num_subdatasets=3, images_per_subdataset=400):
    """Split a per-category image dataset into equally sized sub-datasets.

    Every directory found while walking ``input_folder`` is treated as a
    category. For each category that directly contains at least one image,
    ``num_subdatasets * images_per_subdataset`` images are selected and copied
    to ``output_folder/subdataset_<k>/<category>/``:

    * fewer images than required: the list is padded with randomly chosen
      repeats, and each repeat is stored under a ``_dup<n>`` suffix so that it
      becomes a separate file instead of overwriting the first copy;
    * more images than required: a random subset of the required size is used.

    Directories without any image are skipped.

    Args:
        input_folder: Root directory holding the category folders.
        output_folder: Directory that receives the ``subdataset_<k>`` folders;
            created if missing.
        num_subdatasets: Number of sub-datasets to create (default 3).
        images_per_subdataset: Images per category in each sub-dataset
            (default 400).

    Raises:
        ValueError: If ``num_subdatasets`` or ``images_per_subdataset`` is
            smaller than 1.
    """
    if num_subdatasets < 1 or images_per_subdataset < 1:
        raise ValueError('num_subdatasets and images_per_subdataset must be >= 1')

    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    required = num_subdatasets * images_per_subdataset

    # Create the output sub-dataset directories.
    subdatasets = [os.path.join(output_folder, f'subdataset_{i + 1}')
                   for i in range(num_subdatasets)]
    for subdataset in subdatasets:
        os.makedirs(subdataset, exist_ok=True)

    output_root = os.path.normcase(os.path.abspath(output_folder))

    # Visit every category folder.
    for root, dirs, _ in os.walk(input_folder):
        # Prune the output tree from the walk so that an output folder nested
        # inside the input folder is never re-ingested as a category.
        dirs[:] = [d for d in dirs
                   if os.path.normcase(os.path.abspath(os.path.join(root, d))) != output_root]

        for category in dirs:
            category_folder = os.path.join(root, category)
            # Case-insensitive extension match so '.JPG' etc. are included;
            # sorted so a seeded RNG yields a reproducible split.
            images = [os.path.join(category_folder, f)
                      for f in sorted(os.listdir(category_folder))
                      if f.lower().endswith(image_extensions)]

            # random.choices() raises IndexError on an empty population.
            if not images:
                continue

            # Pad with random repeats when there are too few images.
            if len(images) < required:
                images.extend(random.choices(images, k=required - len(images)))

            # Shuffle, then drop any surplus so the index below stays in range.
            random.shuffle(images)
            del images[required:]

            # Create the category folder inside every sub-dataset.
            category_subfolders = [os.path.join(subdataset_path, category)
                                   for subdataset_path in subdatasets]
            for category_subfolder in category_subfolders:
                os.makedirs(category_subfolder, exist_ok=True)

            # File names already written to each sub-dataset for this category.
            used_names = [set() for _ in subdatasets]

            # Distribute the images, images_per_subdataset per sub-dataset.
            for i, image_path in enumerate(images):
                subdataset_index = i // images_per_subdataset
                stem, ext = os.path.splitext(os.path.basename(image_path))
                target_name = stem + ext
                duplicate_no = 0
                # A repeated source would overwrite its earlier copy; give it
                # a unique name instead.
                while target_name in used_names[subdataset_index]:
                    duplicate_no += 1
                    target_name = f'{stem}_dup{duplicate_no}{ext}'
                used_names[subdataset_index].add(target_name)
                shutil.copy(image_path,
                            os.path.join(category_subfolders[subdataset_index], target_name))

    print(f'Dataset split into {num_subdatasets} subdatasets with '
          f'{images_per_subdataset} images per category at {output_folder}')
if __name__ == "__main__":
    # Script entry point: split the background-removed grade dataset into the
    # SPLIT directory using the default sub-dataset layout.
    split_dataset(
        input_folder='L:/Grade_datasets/MOVE_BACKGROUND',
        output_folder='L:/Grade_datasets/SPLIT',
    )