ultralytics/utils/tobacco_label.py
2025-03-01 18:30:01 +08:00

74 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

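# Input: the dataset root path; the train / test split is expected to be done already.
# The paths are joined inside the configuration file.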
from pathlib import Path
# import splitfolders
import os
# Read the images and assign their labels
def find_images_in_dir(directory, extensions=('.jpg', '.png', '.jpeg')):
    """Lazily yield the path of every image file found below ``directory``.

    The tree is traversed with ``os.walk``; a file counts as an image when its
    lower-cased name ends with one of ``extensions``.

    Args:
        directory: Root folder to search recursively.
        extensions: Lower-case filename suffixes (including the dot) that
            identify image files.

    Yields:
        ``os.path.join(folder, filename)`` for each matching file, in the
        order ``os.walk`` reports them.
    """
    suffixes = tuple(extensions)
    for folder, _subdirs, filenames in os.walk(directory):
        for filename in filenames:
            # str.endswith accepts a tuple, so one call tests every suffix.
            if filename.lower().endswith(suffixes):
                yield os.path.join(folder, filename)
def path_subtract(full_path, base_path):
    """Return ``full_path`` relative to ``base_path``, or ``None`` if it is not inside it.

    Both arguments are converted to absolute paths first so that they are
    comparable. Containment is decided component by component, so a sibling
    such as ``/data/foobar`` is not mistaken for a child of ``/data/foo``
    (which a plain string-prefix test would do).

    Args:
        full_path: Path expected to lie inside ``base_path``.
        base_path: Directory that ``full_path`` is measured against.

    Returns:
        The relative sub-path as a string (``'.'`` when both paths are the
        same), or ``None`` when ``base_path`` is not an ancestor of
        ``full_path``.
    """
    # Convert both paths to absolute form so they can be compared.
    full_path = os.path.abspath(full_path)
    base_path = os.path.abspath(base_path)
    try:
        common = os.path.commonpath([full_path, base_path])
    except ValueError:
        # Raised e.g. for paths on different drives on Windows: nothing shared.
        return None
    # base_path is an ancestor only if the shared prefix is base_path itself;
    # normcase keeps the comparison case-insensitive where the OS is.
    if os.path.normcase(common) != os.path.normcase(base_path):
        return None
    return os.path.relpath(full_path, base_path)
def _write_label_file(label_path, lines):
    """Write ``lines`` (each already newline-terminated) to ``label_path`` as UTF-8."""
    with open(label_path, 'w', encoding='utf-8') as file:
        file.writelines(lines)


def getlabel(root):
    """Generate ``train.txt`` / ``val.txt`` / ``test.txt`` label files under ``root/label``.

    Every image found below ``root`` is described by one output line of the
    form ``<path relative to root>,<label>``. The label is the name of the
    image's second-level folder with each ``_`` replaced by ``,``. The image
    is assigned to a split when the name of its top-level folder contains
    ``train``, ``val`` or ``test``.

    Images that do not sit at least two folders deep (``<split>/<class>/...``)
    are skipped, because they have no class folder to take a label from.

    Args:
        root: Dataset root directory.

    Side effects:
        Creates ``root/label`` if needed and writes one ``<split>.txt`` file
        for each split that has at least one image; a split without images
        gets no file.
    """
    records = {'train': [], 'val': [], 'test': []}
    for image_path in find_images_in_dir(root):
        subpath = path_subtract(image_path, root)
        if subpath is None:
            # Not located under root; nothing sensible to record.
            continue
        parts = Path(subpath).parts
        if len(parts) < 3:
            # Needs <split>/<class>/<file>; otherwise there is no class
            # folder (indexing parts[1] blindly would fail or return the
            # file name itself).
            continue
        split_dir, class_dir = parts[0], parts[1]
        label = class_dir.replace('_', ',')
        line = f'{subpath},{label}\n'
        # Decide the split from the top-level folder only, so class-folder or
        # file names that happen to contain "train"/"val"/"test" are ignored.
        for split_name, lines in records.items():
            if split_name in split_dir:
                lines.append(line)

    # Create the output folder if it does not exist yet.
    label_dir = os.path.join(root, 'label')
    os.makedirs(label_dir, exist_ok=True)

    # Write one txt file per non-empty split.
    for split_name, lines in records.items():
        if lines:
            _write_label_file(os.path.join(label_dir, f'{split_name}.txt'), lines)
if __name__ == '__main__':
    import sys

    # Split ratio noted by the original author: train:validation:test = 8:1:1
    # An optional first command-line argument overrides the default dataset root.
    dataset_root = sys.argv[1] if len(sys.argv) > 1 else 'G:/dataset/test/split'
    getlabel(dataset_root)