# TA_EC/dataset/split.py
# (Python, 60 lines, 2.2 KiB; timestamp 2025-03-09 14:36:22 +00:00)
import math
import os
import random
import shutil
def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=None):
    """Copy a folder-per-category dataset into train/val/test subsets.

    ``base_dir`` is expected to contain one sub-directory per category. For
    every category the regular files inside it are shuffled and copied to
    ``output_dir/train/<category>``, ``output_dir/val/<category>`` and
    ``output_dir/test/<category>``. Entries of ``base_dir`` that are not
    directories are skipped, as are non-file entries inside a category.

    The train and validation counts are ``int(ratio * n)`` (truncated); the
    test split receives whatever remains, so no file is ever dropped.

    Args:
        base_dir: Directory holding the category sub-directories.
        output_dir: Destination root; created if missing.
        train_ratio: Fraction of each category copied to ``train``.
        val_ratio: Fraction of each category copied to ``val``.
        test_ratio: Fraction of each category copied to ``test``.
        seed: Optional seed for a reproducible split. When ``None`` the
            module-level ``random`` state is used, as before.

    Raises:
        ValueError: If a ratio is negative or the ratios do not sum to 1
            (within floating-point tolerance).

    Note:
        Existing files in ``output_dir`` are not removed. Re-running with a
        different shuffle into the same destination can leave one image in
        several splits.
    """
    ratios = (train_ratio, val_ratio, test_ratio)
    if any(ratio < 0 for ratio in ratios):
        raise ValueError("Ratios must be non-negative")
    # Exact float equality rejects valid inputs such as 0.6 + 0.3 + 0.1,
    # which evaluates to 0.9999999999999999.
    if not math.isclose(sum(ratios), 1.0, abs_tol=1e-9):
        raise ValueError("Ratios must sum to 1")

    # Create the output root and one directory per split.
    os.makedirs(output_dir, exist_ok=True)
    split_dirs = {
        split: os.path.join(output_dir, split)
        for split in ("train", "val", "test")
    }
    for split_dir in split_dirs.values():
        os.makedirs(split_dir, exist_ok=True)

    # A private generator keeps a seeded run from disturbing global state.
    rng = random if seed is None else random.Random(seed)

    # os.listdir order is arbitrary; sorting makes a seeded run repeatable.
    for category in sorted(os.listdir(base_dir)):
        category_path = os.path.join(base_dir, category)
        if not os.path.isdir(category_path):
            continue  # skip stray files sitting next to the category folders

        # Collect the files of this category and shuffle them.
        images = sorted(
            entry
            for entry in os.listdir(category_path)
            if os.path.isfile(os.path.join(category_path, entry))
        )
        rng.shuffle(images)

        # Compute the split points; the test split takes the remainder.
        total_images = len(images)
        train_end = int(train_ratio * total_images)
        val_end = train_end + int(val_ratio * total_images)
        partitions = {
            "train": images[:train_end],
            "val": images[train_end:val_end],
            "test": images[val_end:],
        }

        # Create the category folder in every split, then copy its share.
        for split, file_names in partitions.items():
            destination = os.path.join(split_dirs[split], category)
            os.makedirs(destination, exist_ok=True)
            for file_name in file_names:
                shutil.copy(os.path.join(category_path, file_name), destination)

    print("Dataset successfully split into train, validation, and test sets.")
# Usage example. The paths stay at module level so they remain importable, but
# the copy only runs when this file is executed as a script, not on import.
base_directory = 'F:/dataset/02.TA_EC/EC27/JY_A'
output_directory = 'F:/dataset/02.TA_EC/datasets/EC27'
if __name__ == "__main__":
    create_dataset_splits(base_directory, output_directory)