TA_EC/dataset/split.py
2025-03-12 14:00:50 +08:00

59 lines
2.1 KiB
Python

import os
import shutil
import random
# Tolerance for float error in ``ratio * count``: e.g. 0.7 * 90 evaluates to
# 62.99999999999999, which plain int() would truncate to 62 instead of 63.
_COUNT_EPS = 1e-9


def _split_count(ratio, total):
    """Return floor(ratio * total), robust to tiny floating-point undershoot."""
    # ratio is validated non-negative by the caller, so int() == floor here.
    return int(ratio * total + _COUNT_EPS)


def create_dataset_splits(base_dir, output_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1, *, seed=None):
    """Copy per-category files from *base_dir* into train/val/test splits.

    *base_dir* is expected to contain one sub-directory per category; plain
    files directly inside *base_dir* are skipped. For each category, the
    regular files directly inside it are shuffled and copied into
    ``output_dir/train/<category>``, ``output_dir/val/<category>`` and
    ``output_dir/test/<category>``. Nested sub-directories inside a category
    are not copied.

    The train split gets floor(train_ratio * n) files, the val split gets
    floor(val_ratio * n), and the test split gets every remaining file, so no
    file is dropped by rounding.

    Args:
        base_dir: Source directory with one sub-directory per category.
        output_dir: Destination root; created if missing. Nothing is ever
            deleted, so files left from an earlier run with a different
            split are not removed.
        train_ratio: Fraction of each category copied to ``train``.
        val_ratio: Fraction of each category copied to ``val``.
        test_ratio: Nominal fraction for ``test``. It is only used to check
            that the three ratios sum to 1; the test split actually receives
            the remainder.
        seed: Optional seed for a private RNG, making the split reproducible.
            When None, the global ``random`` module is used as before.

    Raises:
        ValueError: If any ratio is negative or the ratios do not sum to 1.
    """
    # Validate before touching the filesystem.
    ratios = (train_ratio, val_ratio, test_ratio)
    if any(r < 0 for r in ratios):
        raise ValueError(f"split ratios must be non-negative, got {ratios}")
    # Tolerance absorbs float noise such as 0.7 + 0.2 + 0.1 == 0.9999999999999999.
    if abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError(f"split ratios must sum to 1, got {ratios} (sum={sum(ratios)})")

    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle

    # Create the output root and the three split directories.
    split_dirs = {name: os.path.join(output_dir, name) for name in ('train', 'val', 'test')}
    for path in split_dirs.values():
        os.makedirs(path, exist_ok=True)

    # Sorted so that, with a fixed seed, the result does not depend on the
    # arbitrary order in which os.listdir returns entries.
    for category in sorted(os.listdir(base_dir)):
        category_path = os.path.join(base_dir, category)
        if not os.path.isdir(category_path):
            continue  # skip stray non-directory entries at the top level

        # Every regular file in the category folder is treated as a sample.
        images = sorted(
            f for f in os.listdir(category_path)
            if os.path.isfile(os.path.join(category_path, f))
        )
        shuffle(images)

        n_train = _split_count(train_ratio, len(images))
        n_val = _split_count(val_ratio, len(images))
        # Slicing (rather than indexing) can never run past the end of the
        # list, and test takes everything after train + val.
        splits = {
            'train': images[:n_train],
            'val': images[n_train:n_train + n_val],
            'test': images[n_train + n_val:],
        }

        for split_name, files in splits.items():
            # The category folder is created even when its split is empty.
            dest = os.path.join(split_dirs[split_name], category)
            os.makedirs(dest, exist_ok=True)
            for name in files:
                shutil.copy(os.path.join(category_path, name), dest)

    print("Dataset successfully split into train, validation, and test sets.")
# Usage example. Guarded so that importing this module does not start
# copying an entire dataset from hard-coded local paths.
if __name__ == "__main__":
    base_directory = 'L:/Grade_datasets/SPLIT/JY_A'
    output_directory = 'L:/Grade_datasets/train/JY_A'
    create_dataset_splits(base_directory, output_directory)