74 lines
2.9 KiB
Python
74 lines
2.9 KiB
Python
![]() |
# Input: root path of a dataset that has already been split into training and test sets.
|
|||
|
# Paths are joined in the configuration file.
|
|||
|
|
|||
|
from pathlib import Path
|
|||
|
# import splitfolders
|
|||
|
import os
|
|||
|
|
|||
|
# Read the images and assign labels.
|
|||
|
def find_images_in_dir(directory, extensions=('.jpg', '.png', '.jpeg')):
    """Recursively yield the paths of image files found under *directory*.

    Sub-directories and files are visited in sorted order. This makes the
    sequence of yielded paths (and any label file built from it)
    reproducible across runs and filesystems.

    Args:
        directory: Root directory to search.
        extensions: File-name suffix(es) to accept. A single string or an
            iterable of strings. Matching is case-insensitive on both the
            file name and the suffixes, so ``'.JPG'`` and ``'.jpg'`` are
            equivalent.

    Yields:
        str: ``os.path.join(root, file)`` for every matching file.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)
    # Lower-case once so upper-case suffixes supplied by the caller still match.
    # str.endswith accepts a tuple of candidates.
    suffixes = tuple(ext.lower() for ext in extensions)

    for root, dirs, files in os.walk(directory):
        # Sorting ``dirs`` in place makes os.walk descend in sorted order.
        dirs.sort()
        for file in sorted(files):
            if file.lower().endswith(suffixes):
                yield os.path.join(root, file)
|
|||
|
|
|||
|
def path_subtract(full_path, base_path):
    """Return *full_path* relative to *base_path*, or ``None`` if it is not inside it.

    Containment is decided per path component, not per character. For
    example, ``/data/split2/a.jpg`` is NOT considered to lie inside
    ``/data/split``. Paths are compared after ``os.path.normcase``, which
    makes the check case-insensitive on Windows.

    Args:
        full_path: Path to express relatively (absolute, or relative to the CWD).
        base_path: Directory that is expected to contain *full_path*.

    Returns:
        The relative path as a string (``'.'`` when both paths are the same),
        or ``None`` when *full_path* is not located under *base_path*.
    """
    # Convert both paths to absolute form so they are comparable.
    full_path = os.path.abspath(full_path)
    base_path = os.path.abspath(base_path)
    norm_base = os.path.normcase(base_path)

    try:
        common = os.path.commonpath([os.path.normcase(full_path), norm_base])
    except ValueError:
        # Raised e.g. for paths on different Windows drives: no shared base.
        return None

    if common != norm_base:
        # base_path is not an ancestor of full_path, so the "difference" is meaningless.
        return None
    return os.path.relpath(full_path, base_path)
|
|||
|
|
|||
|
def _write_label_file(path, lines):
    """Write *lines* to *path* as UTF-8 text, one per line; skip when *lines* is empty."""
    if not lines:
        return
    with open(path, 'w', encoding='utf-8') as file:
        file.writelines(line + '\n' for line in lines)


def getlabel(root):
    """Write one label file per split for an image dataset organised by split/class.

    Expected layout (as produced by e.g. ``splitfolders``)::

        root/
            train/<class>/<image>
            val/<class>/<image>
            test/<class>/<image>

    For each split, ``root/label/<split>.txt`` receives one line per image:
    ``<image path relative to root>,<label>``. ``<label>`` is the class
    directory name with every ``_`` replaced by ``,``. Presumably this
    encodes multiple label fields in the directory name; confirm with the
    consumer of these files.

    Rules:
    - The split is taken from the exact name of the first path component
      under *root*.
    - Images that are not at least ``<split>/<class>/<file>`` deep, or whose
      first component is not a known split, are skipped.
    - A split's file is written only if it has at least one image.
    - ``root/label`` is always created.

    Args:
        root: Dataset root directory.
    """
    splits = ('train', 'val', 'test')
    entries = {split: [] for split in splits}

    for image_path in find_images_in_dir(root):
        subpath = path_subtract(image_path, root)
        if subpath is None:
            continue
        parts = Path(subpath).parts
        # Need <split>/<class>/<file> at minimum; a shallower path has no
        # class directory to take the label from.
        if len(parts) < 3:
            continue
        # Match the split on the exact top-level directory name. A substring
        # test on the whole path would put e.g. 'train/contest/x.jpg' into
        # both the train and the test set.
        split = parts[0]
        if split not in entries:
            continue
        label = parts[1].replace('_', ',')
        entries[split].append(subpath + ',' + label)

    label_dir = os.path.join(root, 'label')
    os.makedirs(label_dir, exist_ok=True)

    for split in splits:
        _write_label_file(os.path.join(label_dir, split + '.txt'), entries[split])
|
|||
|
|
|||
|
if __name__ == '__main__':
    import sys

    # The dataset root is expected to be split already. The author notes a
    # ratio of train:validation:test = 8:1:1; this script only writes the
    # label files. An alternative root may be passed as the first
    # command-line argument; otherwise the original default is used.
    dataset_root = sys.argv[1] if len(sys.argv) > 1 else 'G:/dataset/test/split'
    getlabel(dataset_root)
|