74 lines
2.9 KiB
Python
74 lines
2.9 KiB
Python
# 输入数据集路径,划分完成训练集与测试集
|
||
# 配置文件里面进行路径拼接
|
||
|
||
from pathlib import Path
|
||
# import splitfolders
|
||
import os
|
||
|
||
# 读取,设置标签
|
||
def find_images_in_dir(directory, extensions=('.jpg', '.png', '.jpeg')):
|
||
for root, dirs, files in os.walk(directory):
|
||
for file in files:
|
||
if any(file.lower().endswith(ext) for ext in extensions):
|
||
yield os.path.join(root, file)
|
||
|
||
def path_subtract(full_path, base_path):
|
||
# 将两个路径都转换为绝对路径,以确保它们是可比较的
|
||
full_path = os.path.abspath(full_path)
|
||
base_path = os.path.abspath(base_path)
|
||
|
||
# 检查base_path是否是full_path的前缀
|
||
if full_path.startswith(base_path):
|
||
# 使用relpath函数获取相对于base_path的子路径
|
||
relative_path = os.path.relpath(full_path, base_path)
|
||
return relative_path
|
||
else:
|
||
# 如果base_path不是full_path的前缀,则不能相减(或者说结果无意义)
|
||
return None
|
||
|
||
def getlabel(root):
|
||
train_set = []
|
||
val_set = []
|
||
test_set = []
|
||
for image_path in find_images_in_dir(root):
|
||
# 处理每一个文件,训练集
|
||
subpath = path_subtract(image_path,root)
|
||
Parts = (Path(subpath)).parts[1]
|
||
|
||
if 'train' in subpath:
|
||
train_set.append([subpath,',',Parts.replace('_', ',') ])
|
||
if 'val' in subpath:
|
||
val_set.append([subpath,',',Parts.replace('_', ',') ])
|
||
if 'test' in subpath:
|
||
test_set.append([subpath,',',Parts.replace('_', ',') ])
|
||
|
||
if not os.path.exists(os.path.join(root,'label')):
|
||
# 如果文件夹不存在,则创建它
|
||
os.makedirs(os.path.join(root,'label'))
|
||
|
||
# 输出到txt
|
||
if len(train_set) > 0:
|
||
with open(os.path.join(root,'label','train.txt'), 'w', encoding='utf-8') as file:
|
||
# 使用循环将列表中的每一项写入文件,并在每项后面加上换行符
|
||
for item in train_set:
|
||
for sub in item:
|
||
file.write(sub)
|
||
file.write('\n')
|
||
if len(val_set) > 0:
|
||
with open(os.path.join(root,'label','val.txt'), 'w', encoding='utf-8') as file:
|
||
# 使用循环将列表中的每一项写入文件,并在每项后面加上换行符
|
||
for item in val_set:
|
||
for sub in item:
|
||
file.write(sub)
|
||
file.write('\n')
|
||
if len(test_set) > 0:
|
||
with open(os.path.join(root,'label','test.txt'), 'w', encoding='utf-8') as file:
|
||
# 使用循环将列表中的每一项写入文件,并在每项后面加上换行符
|
||
for item in test_set:
|
||
for sub in item:
|
||
file.write(sub)
|
||
file.write('\n')
|
||
|
||
if __name__ == '__main__':
|
||
# train:validation:test=8:1:1
|
||
getlabel('G:/dataset/test/split') |