# File-viewer metadata (not part of the program): 60 lines, 2.2 KiB, Python.
import math
import os
import random
import shutil
|
|
|
|
def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=None):
    """Copy a folder-per-category dataset into train/val/test splits.

    Every immediate subdirectory of ``base_dir`` is treated as one category.
    The regular files inside it are shuffled and copied (never moved) into
    ``output_dir/train/<category>``, ``output_dir/val/<category>`` and
    ``output_dir/test/<category>``. Non-directory entries of ``base_dir``
    are skipped.

    Args:
        base_dir: Directory whose subdirectories each hold one category.
        output_dir: Destination root; created if it does not exist.
        train_ratio: Fraction of each category copied to the training split.
        val_ratio: Fraction of each category copied to the validation split.
        test_ratio: Nominal fraction for the test split. The train and
            validation counts are truncated to whole files, and the test
            split receives every file that is left over.
        seed: Optional seed for a private random generator, making the
            shuffle reproducible. When ``None`` the module-level ``random``
            state is used.

    Raises:
        ValueError: If a ratio is negative or the ratios do not sum to 1
            (within floating-point tolerance).
    """
    if min(train_ratio, val_ratio, test_ratio) < 0:
        raise ValueError("Ratios must be non-negative")
    # Compare with a tolerance: sums such as 0.6 + 0.3 + 0.1 are not exactly
    # 1.0 in binary floating point.
    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0):
        raise ValueError("Ratios must sum to 1")

    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle

    # Read the source first so that a missing base_dir fails before any
    # output directory is created. Sorting removes the dependence on the
    # arbitrary order returned by os.listdir.
    categories = sorted(os.listdir(base_dir))

    # Create the output directories (makedirs also creates output_dir itself).
    split_dirs = {
        split_name: os.path.join(output_dir, split_name)
        for split_name in ('train', 'val', 'test')
    }
    for split_dir in split_dirs.values():
        os.makedirs(split_dir, exist_ok=True)

    for category in categories:
        category_path = os.path.join(base_dir, category)
        if not os.path.isdir(category_path):
            continue  # skip entries that are not directories

        # Collect the regular files of this category, then shuffle them.
        images = sorted(
            entry for entry in os.listdir(category_path)
            if os.path.isfile(os.path.join(category_path, entry))
        )
        shuffle(images)

        # Compute the split points.
        total_images = len(images)
        train_end = int(train_ratio * total_images)
        val_end = train_end + int(val_ratio * total_images)

        assignments = {
            'train': images[:train_end],
            'val': images[train_end:val_end],
            'test': images[val_end:],
        }

        # Create the category folder in every split and copy its share.
        for split_name, split_images in assignments.items():
            destination = os.path.join(split_dirs[split_name], category)
            os.makedirs(destination, exist_ok=True)
            for image in split_images:
                shutil.copy(os.path.join(category_path, image), destination)

    print("Dataset successfully split into train, validation, and test sets.")
|
|
|
|
# Usage example
base_directory = 'F:/dataset/02.TA_EC/EC27/JY_A'
output_directory = 'F:/dataset/02.TA_EC/datasets/EC27'

if __name__ == "__main__":
    # Run the split only when executed as a script, so that importing this
    # module does not trigger a copy against the hard-coded paths above.
    create_dataset_splits(base_directory, output_directory)