42 lines
1.7 KiB
Python
42 lines
1.7 KiB
Python
import os
|
||
import random
|
||
import shutil
|
||
|
||
def split_dataset(input_folder, output_folder, num_subsets=3, images_per_subset=400):
    """Split each category folder of an image dataset into equal sub-datasets.

    Every directory found under ``input_folder`` is treated as a category.
    Its image files are padded (by re-using randomly chosen images) or
    randomly sub-sampled to exactly ``num_subsets * images_per_subset``
    entries, shuffled, and copied into
    ``<output_folder>/subdataset_<k>/<category>/`` so that each sub-dataset
    receives ``images_per_subset`` files for that category.

    Args:
        input_folder: Root directory that contains the category folders.
        output_folder: Directory in which ``subdataset_1`` ...
            ``subdataset_<num_subsets>`` are created.
        num_subsets: Number of sub-datasets to produce (default 3).
        images_per_subset: Images per category in each sub-dataset
            (default 400).

    Raises:
        ValueError: If ``num_subsets`` or ``images_per_subset`` is < 1.
    """
    if num_subsets < 1 or images_per_subset < 1:
        raise ValueError('num_subsets and images_per_subset must be at least 1')

    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    images_per_category = num_subsets * images_per_subset

    def _normalized(path):
        # Canonical form so path comparisons are reliable across platforms.
        return os.path.normcase(os.path.abspath(path))

    # Create the output sub-directories.
    subdatasets = [os.path.join(output_folder, f'subdataset_{i+1}') for i in range(num_subsets)]
    for subdataset in subdatasets:
        os.makedirs(subdataset, exist_ok=True)

    # Output directories must never be mistaken for categories when the
    # output folder is located inside (or is) the input folder.
    excluded = {_normalized(output_folder)}
    excluded.update(_normalized(path) for path in subdatasets)

    # Destination paths written during this run; used to keep padded
    # duplicates from overwriting each other.
    written = set()

    # Visit every category folder.
    for root, dirs, _ in os.walk(input_folder):
        # Prune in place so os.walk does not descend into the output tree.
        dirs[:] = [d for d in dirs if _normalized(os.path.join(root, d)) not in excluded]

        for category in dirs:
            category_folder = os.path.join(root, category)
            # Case-insensitive match so files such as 'IMG.JPG' are included.
            images = [os.path.join(category_folder, f) for f in os.listdir(category_folder)
                      if f.lower().endswith(image_extensions)]

            # Nothing to pad from: random.choices() would raise on an empty list.
            if not images:
                continue

            # Pad with randomly re-used images when the category is too small.
            if len(images) < images_per_category:
                images.extend(random.choices(images, k=images_per_category - len(images)))

            # Shuffle, then keep exactly the required amount; for an oversized
            # category this is a uniform random sample.
            random.shuffle(images)
            del images[images_per_category:]

            # Create the category sub-folder inside every sub-dataset.
            for subdataset_path in subdatasets:
                os.makedirs(os.path.join(subdataset_path, category), exist_ok=True)

            # Distribute the images, images_per_subset to each sub-dataset.
            for i, image_path in enumerate(images):
                category_subfolder = os.path.join(subdatasets[i // images_per_subset], category)
                stem, ext = os.path.splitext(os.path.basename(image_path))
                destination = os.path.join(category_subfolder, stem + ext)

                # A padded duplicate landing in the same folder gets a numbered
                # suffix; otherwise it would overwrite the earlier copy and the
                # folder would end up with fewer files than requested.
                duplicate = 0
                while _normalized(destination) in written:
                    duplicate += 1
                    destination = os.path.join(category_subfolder, f'{stem}_dup{duplicate}{ext}')

                shutil.copy(image_path, destination)
                written.add(_normalized(destination))

    print(f'Dataset split into {num_subsets} subdatasets with {images_per_subset} images per category at {output_folder}')
|
||
|
||
if __name__ == "__main__":
    # Script entry point: split the background-moved dataset into the SPLIT
    # directory using the fixed locations on the L: drive.
    split_dataset(
        input_folder='L:/Grade_datasets/MOVE_BACKGROUND',
        output_folder='L:/Grade_datasets/SPLIT',
    )