Add multi-label learning
parent 21c1778a4a
commit a7b08c786e

1  .gitignore (vendored)
@@ -163,7 +163,6 @@ weights/
 pnnx*

 # Autogenerated files for tests
-/ultralytics/assets/

 # dataset cache
 *.cache
13  train.py
@@ -6,13 +6,14 @@ warnings.filterwarnings('ignore')
 from ultralytics import YOLO

 if __name__ == '__main__':
-    model = YOLO('ultralytics/cfg/models/v8/yolov8n-cls.yaml')
-    # model.load('yolov8n.pt') # loading pretrain weights
-    model.train(data='G:/skin-cancer-detection',
+    # model = YOLO('ultralytics/cfg/models/v8/yolov8n-cls.yaml')  # single-task learning
+    # model.train(data='G:/dataset/split',
+    model = YOLO('ultralytics/cfg/models/v8/yolov8-mtlcls.yaml', task='MTL')  # multi-task learning
+    model.train(data='G:/dataset/test/ml.yaml',
                 cache=False,
-                imgsz=640,
-                epochs=300,
-                batch=32,
+                imgsz=224,
+                epochs=2,
+                batch=64,
                 close_mosaic=0,
                 workers=8,  # if training inexplicably hangs on Windows, try setting workers=0
                 optimizer='SGD',  # using SGD
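Because the MTL task cannot be inferred from a saved weights file, reloading the trained model needs the same `task='MTL'` flag so that `YOLO` routes to the trainer/validator/predictor registered for it further below. A minimal follow-up sketch (the run directory and paths are illustrative, not taken from the commit):

```python
# Hedged usage sketch: validate the trained multi-task weights.
from ultralytics import YOLO

model = YOLO('runs/classify/train/weights/best.pt', task='MTL')  # hypothetical save path
metrics = model.val(data='G:/dataset/test/ml.yaml')  # per-task top-1/top-5 accuracies
```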
BIN  ultralytics/assets/bus.jpg (new file; binary not shown; 134 KiB)
BIN  ultralytics/assets/zidane.jpg (new file; binary not shown; 49 KiB)
ultralytics/cfg/__init__.py
@@ -31,13 +31,14 @@ from ultralytics.utils import (

 # Define valid tasks and modes
 MODES = {"train", "val", "predict", "export", "track", "benchmark"}
-TASKS = {"detect", "segment", "classify", "pose", "obb"}
+TASKS = {"detect", "segment", "classify", "pose", "obb", "MTL"}
 TASK2DATA = {
     "detect": "coco8.yaml",
     "segment": "coco8-seg.yaml",
     "classify": "imagenet10",
     "pose": "coco8-pose.yaml",
     "obb": "dota8.yaml",
+    "MTL": "imagenet10"
 }
 TASK2MODEL = {
     "detect": "yolov8n.pt",
@@ -45,6 +46,7 @@ TASK2MODEL = {
     "classify": "yolov8n-cls.pt",
     "pose": "yolov8n-pose.pt",
     "obb": "yolov8n-obb.pt",
+    "MTL": "yolov8n-cls.pt"
 }
 TASK2METRIC = {
     "detect": "metrics/mAP50-95(B)",
29  ultralytics/cfg/models/v8/yolov8-mtlcls.yaml (new file)
@@ -0,0 +1,29 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify
+
+# Parameters
+nc: 1000 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.33, 0.25, 1024]
+  s: [0.33, 0.50, 1024]
+  m: [0.67, 0.75, 1024]
+  l: [1.00, 1.00, 1024]
+  x: [1.00, 1.25, 1024]
+
+# YOLOv8.0n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2f, [1024, True]]
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, MTLClassify, [nc]] # Classify
ultralytics/data/__init__.py
@@ -6,6 +6,7 @@ from .dataset import (
     ClassificationDataset,
     GroundingDataset,
     SemanticDataset,
+    MTLDataset,
     YOLOConcatDataset,
     YOLODataset,
     YOLOMultiModalDataset,
@@ -14,6 +15,7 @@ from .dataset import (
 __all__ = (
     "BaseDataset",
     "ClassificationDataset",
+    "MTLDataset",
     "SemanticDataset",
     "YOLODataset",
     "YOLOMultiModalDataset",
ultralytics/data/dataset.py
@@ -5,15 +5,17 @@ import json
 from collections import defaultdict
 from itertools import repeat
 from multiprocessing.pool import ThreadPool
+import os
 from pathlib import Path

 import cv2
 import numpy as np
 import torch
+import torchvision
 from PIL import Image
 from torch.utils.data import ConcatDataset

-from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
+from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, yaml_load
 from ultralytics.utils.ops import resample_segments

 from .augment import (
@@ -504,3 +506,159 @@ class ClassificationDataset:
         x["msgs"] = msgs  # warnings
         save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
         return samples
+
+
+class MTLDataset:
+    """
+    Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
+    augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
+    learning models, with optional image transformations and caching mechanisms to speed up training.
+
+    This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching
+    images in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification
+    process to ensure data integrity and consistency.
+
+    Attributes:
+        cache_ram (bool): Indicates if caching in RAM is enabled.
+        cache_disk (bool): Indicates if caching on disk is enabled.
+        samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
+            file (if caching on disk), and optionally the loaded image array (if caching in RAM).
+        torch_transforms (callable): PyTorch transforms to be applied to the images.
+    """
+
+    def __init__(self, data, args, augment=False, prefix=""):
+        """
+        Initialize YOLO object with root, image size, augmentations, and cache settings.
+
+        Args:
+            root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
+            args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
+                parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
+                of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
+                `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
+            augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
+            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
+                debugging. Default is an empty string.
+        """
+        # import torchvision  # scope for faster 'import ultralytics'
+
+        # # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
+        # self.base = torchvision.datasets.ImageFolder(root=root)
+        # self.samples = self.base.samples
+        # self.root = self.base.root
+        # # load the label file
+        self.data = data
+        self.root = data.rsplit('/', 1)[0]

+        # load the args.data YAML
+        config = yaml_load(args.data, append_filename=True)  # dictionary
+
+        self.root = config["path"]
+        self.names = []
+        for i in range(len(config["tasks"])):
+            self.names.append(config["tasks"][i])
+
+        # # Initialize attributes
+        if augment and args.fraction < 1.0:  # reduce training fraction
+            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
+        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
+        # self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
+        # self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
+        self.cache_ram = False
+        self.cache_disk = False
+        self.samples = self.load_data()  # filter out bad images
+        # self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
+        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
+        self.torch_transforms = (
+            classify_augmentations(
+                size=args.imgsz,
+                scale=scale,
+                hflip=args.fliplr,
+                vflip=args.flipud,
+                erasing=args.erasing,
+                auto_augment=args.auto_augment,
+                hsv_h=args.hsv_h,
+                hsv_s=args.hsv_s,
+                hsv_v=args.hsv_v,
+            )
+            if augment
+            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
+        )
+
+    def load_data(self):
+        samples = []
+        with open(self.data, 'r', encoding='utf-8') as file:
+            for line in file:
+                parts = line.strip().split(',')
+                path = parts.pop(0)
+                label = []
+                for i in range(len(parts)):
+                    try:
+                        label.append([k for k, v in self.names[i].items() if v == parts[i]][0])
+                    except:
+                        label.append([k for k, v in self.names[i].items() if v == int(parts[i])][0])
+                samples.append([os.path.join(self.root, path), label, None, None])
+
+        return samples
+
+    def __getitem__(self, i):
+        """Returns subset of data and targets corresponding to given indices."""
+        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
+        if self.cache_ram:
+            if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
+                im = self.samples[i][3] = cv2.imread(f)
+        elif self.cache_disk:
+            if not fn.exists():  # load npy
+                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
+            im = np.load(fn)
+        else:  # read image
+            im = cv2.imread(f)  # BGR
+        # Convert NumPy array to PIL image
+        if im is not None:
+            im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
+            sample = self.torch_transforms(im)
+            return {"img": sample, "cls": j}
+        return None
+
+    def __len__(self) -> int:
+        """Return the total number of samples in the dataset."""
+        return len(self.samples)
+
+    def verify_images(self):
+        """Verify all images in dataset."""
+
+        desc = f"{self.prefix}Scanning {self.root}..."
+        path = Path(self.root).with_suffix(".cache")  # *.cache file path
+
+        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
+            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
+            nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
+            if LOCAL_RANK in {-1, 0}:
+                d = f"{desc} {nf} images, {nc} corrupt"
+                TQDM(None, desc=d, total=n, initial=n)
+                if cache["msgs"]:
+                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
+            return samples
+
+        # Run scan if *.cache retrieval failed
+        nf, nc, msgs, samples, x = 0, 0, [], [], {}
+        with ThreadPool(NUM_THREADS) as pool:
+            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
+            pbar = TQDM(results, desc=desc, total=len(self.samples))
+            for sample, nf_f, nc_f, msg in pbar:
+                if nf_f:
+                    samples.append(sample)
+                if msg:
+                    msgs.append(msg)
+                nf += nf_f
+                nc += nc_f
+                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
+            pbar.close()
+        if msgs:
+            LOGGER.info("\n".join(msgs))
+        x["hash"] = get_hash([x[0] for x in self.samples])
+        x["results"] = nf, nc, len(samples), samples
+        x["msgs"] = msgs  # warnings
+        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
+        return samples
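`load_data` expects a plain-text label file in which each line is a relative image path followed by one comma-separated value per task; each value is reverse-looked-up in the matching `tasks` dict from the data YAML, with a fallback to integer comparison. A minimal sketch of that lookup (the task names and the file line are invented):

```python
# Hypothetical two-task setup mirroring MTLDataset.load_data's reverse lookup.
names = [{0: 'benign', 1: 'malignant'}, {0: 'low', 1: 'mid', 2: 'high'}]  # per-task index->name dicts

line = "train/img_0001.jpg,malignant,mid"  # one row of label/train.txt
parts = line.strip().split(',')
path = parts.pop(0)
label = [[k for k, v in names[i].items() if v == parts[i]][0] for i in range(len(parts))]
print(path, label)  # train/img_0001.jpg [1, 1]
```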
ultralytics/data/utils.py
@@ -423,6 +423,31 @@ def check_cls_dataset(dataset, split=""):
     return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}


+def check_mtlcls_dataset(file, split=""):
+    # Read YAML
+    data = yaml_load(file, append_filename=True)  # dictionary
+
+    # Checks
+    for k in "train", "val":
+        if k not in data:
+            if k != "val" or "validation" not in data:
+                raise SyntaxError(
+                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
+                )
+            LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
+            data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key
+    if "tasks" not in data:
+        raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
+
+    names = []
+    nc = []
+    for i in range(len(data["tasks"])):
+        nc.append(len(data["tasks"][i]))
+        names.append(check_class_names(data["tasks"][i]))
+
+    return {"train": data["train"], "val": data["val"], "test": data["test"], "nc": nc, "names": names}
+
+
 class HUBDatasetStats:
     """
     A class for generating HUB dataset JSON and `-hub` dataset directory.
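Taken together with `MTLDataset` above, the data YAML this function validates needs `path`, `train`/`val`/`test` entries pointing at the label txt files, a `tasks` list holding one index-to-name mapping per task, and (for the trainer's loss labels) a `tasks_name` string. A hedged sketch of such an `ml.yaml` — every name and path here is invented:

```yaml
# Hypothetical ml.yaml matching what check_mtlcls_dataset and MTLDataset read.
path: G:/dataset/test/split                    # dataset root; txt image paths are joined onto it
train: G:/dataset/test/split/label/train.txt
val: G:/dataset/test/split/label/val.txt
test: G:/dataset/test/split/label/test.txt
tasks:                                         # one index->name mapping per task
  - {0: benign, 1: malignant}
  - {0: low, 1: mid, 2: high}
tasks_name: '0:lesion 1:grade'                 # split() then ':' -> per-task loss names
```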
ultralytics/engine/trainer.py
@@ -22,7 +22,7 @@ from torch import distributed as dist
 from torch import nn, optim

 from ultralytics.cfg import get_cfg, get_save_dir
-from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset, check_mtlcls_dataset
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
 from ultralytics.utils import (
     DEFAULT_CFG,
@@ -413,10 +413,16 @@ class BaseTrainer:
                 loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
                 losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
                 if RANK in {-1, 0}:
-                    pbar.set_description(
-                        ("%11s" * 2 + "%11.4g" * (2 + loss_len))
-                        % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
-                    )
+                    if self.args.task == "MTL":
+                        pbar.set_description(
+                            ("%11s" * 2 + "%11.4g" * (2 + loss_len))
+                            % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"][0].shape[0], batch["img"].shape[-1])
+                        )
+                    else:
+                        pbar.set_description(
+                            ("%11s" * 2 + "%11.4g" * (2 + loss_len))
+                            % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
+                        )
                     self.run_callbacks("on_batch_end")
                     if self.args.plots and ni in self.plot_idx:
                         self.plot_training_samples(batch, ni)
@@ -542,6 +548,8 @@ class BaseTrainer:
         try:
             if self.args.task == "classify":
                 data = check_cls_dataset(self.args.data)
+            elif self.args.task == "MTL":
+                data = check_mtlcls_dataset(self.args.data)
             elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
                 "detect",
                 "segment",
ultralytics/engine/validator.py
@@ -28,7 +28,7 @@ import numpy as np
 import torch

 from ultralytics.cfg import get_cfg, get_save_dir
-from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset, check_mtlcls_dataset
 from ultralytics.nn.autobackend import AutoBackend
 from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
 from ultralytics.utils.checks import check_imgsz
@@ -101,6 +101,11 @@ class BaseValidator:

         self.plots = {}
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
+
+    def mean_if_list(self, lst):
+        if isinstance(lst, list) and lst:  # make sure lst is a list and not empty
+            return round(sum(lst) / len(self.names), 5)
+        else:
+            return lst  # if not a list, or the list is empty, return v directly (it may already be a number)

     @smart_inference_mode()
     def __call__(self, trainer=None, model=None):
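A quick sanity check of the helper (values invented). Note that the divisor is `len(self.names)` — the number of tasks — rather than `len(lst)`; the two only coincide when every task contributed exactly one entry:

```python
# Standalone rendition of mean_if_list; `names` stands in for self.names.
def mean_if_list(lst, names):
    if isinstance(lst, list) and lst:
        return round(sum(lst) / len(names), 5)
    return lst

print(mean_if_list([0.91, 0.83], names=[{}, {}]))  # 0.87
print(mean_if_list(0.5, names=[{}, {}]))           # 0.5 (scalars pass through)
```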
@@ -140,7 +145,9 @@ class BaseValidator:
             self.args.batch = 1  # export.py models default to batch-size 1
             LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")

-        if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
+        if self.args.task == "MTL":
+            self.data = check_mtlcls_dataset(self.args.data, split=self.args.split)
+        elif str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
             self.data = check_det_dataset(self.args.data)
         elif self.args.task == "classify":
             self.data = check_cls_dataset(self.args.data, split=self.args.split)
@@ -202,7 +209,7 @@ class BaseValidator:
         if self.training:
             model.float()
             results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
-            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
+            return {k: round(float(v) if not isinstance(v, list) else self.mean_if_list(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
         else:
             LOGGER.info(
                 "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
7  ultralytics/models/yolo/MTL/__init__.py (new file)
@@ -0,0 +1,7 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from ultralytics.models.yolo.MTL.predict import ClassificationPredictor
+from ultralytics.models.yolo.MTL.train import ClassificationTrainer
+from ultralytics.models.yolo.MTL.val import ClassificationValidator
+
+__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
61  ultralytics/models/yolo/MTL/predict.py (new file)
@@ -0,0 +1,61 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import cv2
+import torch
+from PIL import Image
+
+from ultralytics.engine.predictor import BasePredictor
+from ultralytics.engine.results import Results
+from ultralytics.utils import DEFAULT_CFG, ops
+
+
+class ClassificationPredictor(BasePredictor):
+    """
+    A class extending the BasePredictor class for prediction based on a classification model.
+
+    Notes:
+        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+
+    Example:
+        ```python
+        from ultralytics.utils import ASSETS
+        from ultralytics.models.yolo.classify import ClassificationPredictor
+
+        args = dict(model='yolov8n-cls.pt', source=ASSETS)
+        predictor = ClassificationPredictor(overrides=args)
+        predictor.predict_cli()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes ClassificationPredictor setting the task to 'classify'."""
+        super().__init__(cfg, overrides, _callbacks)
+        self.args.task = "classify"
+        self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
+
+    def preprocess(self, img):
+        """Converts input image to model-compatible data type."""
+        if not isinstance(img, torch.Tensor):
+            is_legacy_transform = any(
+                self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
+            )
+            if is_legacy_transform:  # to handle legacy transforms
+                img = torch.stack([self.transforms(im) for im in img], dim=0)
+            else:
+                img = torch.stack(
+                    [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
+                )
+        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
+        return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
+
+    def postprocess(self, preds, img, orig_imgs):
+        """Post-processes predictions to return Results objects."""
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+        results = []
+        for i, pred in enumerate(preds):
+            orig_img = orig_imgs[i]
+            img_path = self.batch[0][i]
+            results.append(Results(orig_img, path=img_path, names=self.model.names, probs=pred))
+        return results
168  ultralytics/models/yolo/MTL/train.py (new file)
@@ -0,0 +1,168 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import gc
+from math import dist
+import math
+import time
+import warnings
+import numpy as np
+
+import torch
+import torchvision
+
+from ultralytics.data import MTLDataset, build_dataloader
+from ultralytics.engine.trainer import BaseTrainer
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import ClassificationModel, MTLClassificationModel, attempt_load_one_weight
+from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, TQDM, colorstr, yaml_load
+from ultralytics.utils.plotting import plot_images, plot_results
+from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
+
+
+class ClassificationTrainer(BaseTrainer):
+    """
+    A class extending the BaseTrainer class for training based on a classification model.
+
+    Notes:
+        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.classify import ClassificationTrainer
+
+        args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
+        trainer = ClassificationTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
+        if overrides is None:
+            overrides = {}
+        overrides["task"] = "MTL"
+        if overrides.get("imgsz") is None:
+            overrides["imgsz"] = 224
+        super().__init__(cfg, overrides, _callbacks)
+
+    def set_model_attributes(self):
+        """Set the YOLO model's class names from the loaded dataset."""
+        self.model.names = self.data["names"]
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Returns a modified PyTorch model configured for training YOLO."""
+        model = MTLClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+
+        for m in model.modules():
+            if not self.args.pretrained and hasattr(m, "reset_parameters"):
+                m.reset_parameters()
+            if isinstance(m, torch.nn.Dropout) and self.args.dropout:
+                m.p = self.args.dropout  # set dropout
+        for p in model.parameters():
+            p.requires_grad = True  # for training
+        return model
+
+    def setup_model(self):
+        """Load, create or download model for any task."""
+        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
+            return
+
+        model, ckpt = str(self.model), None
+        # Load a YOLO model locally, from torchvision, or from Ultralytics assets
+        if model.endswith(".pt"):
+            self.model, ckpt = attempt_load_one_weight(model, device="cpu")
+            for p in self.model.parameters():
+                p.requires_grad = True  # for training
+        elif model.split(".")[-1] in {"yaml", "yml"}:
+            self.model = self.get_model(cfg=model)
+        elif model in torchvision.models.__dict__:
+            self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
+        else:
+            raise FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
+        ClassificationModel.reshape_outputs(self.model, self.data["nc"])
+
+        return ckpt
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
+        return MTLDataset(data=img_path, args=self.args, augment=mode == "train")
+
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+        """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
+        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+            dataset = self.build_dataset(dataset_path, mode)
+
+        loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
+        # Attach inference transforms
+        if mode != "train":
+            if is_parallel(self.model):
+                self.model.module.transforms = loader.dataset.torch_transforms
+            else:
+                self.model.transforms = loader.dataset.torch_transforms
+        return loader
+
+    def preprocess_batch(self, batch):
+        """Preprocesses a batch of images and classes."""
+        batch["img"] = batch["img"].to(self.device)
+        for i in range(len(batch["cls"])):
+            batch["cls"][i] = batch["cls"][i].to(self.device)
+        return batch
+
+    def progress_string(self):
+        """Returns a formatted string showing training progress."""
+        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
+            "Epoch",
+            "GPU_mem",
+            *self.loss_names,
+            "Instances",
+            "Size",
+        )
+
+    def get_validator(self):
+        """Returns an instance of ClassificationValidator for validation."""
+        cfg = yaml_load(self.args.data)  # model dict
+        self.loss_names = ["loss"] + [pair.split(':')[1] for pair in cfg["tasks_name"].split()]
+        return yolo.MTL.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)
+
+    def label_loss_items(self, loss_items=None, prefix="train"):
+        """
+        Returns a loss dict with labelled training loss items tensor.
+
+        Not needed for classification but necessary for segmentation & detection
+        """
+        keys = [f"{prefix}/{x}" for x in self.loss_names]
+        if loss_items is None:
+            return keys
+        loss_items = [float(item) for item in loss_items]
+        return dict(zip(keys, loss_items))
+
+    def plot_metrics(self):
+        """Plots metrics from a CSV file."""
+        plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png
+
+    def final_eval(self):
+        """Evaluate trained model and save validation results."""
+        for f in self.last, self.best:
+            if f.exists():
+                strip_optimizer(f)  # strip optimizers
+                if f is self.best:
+                    LOGGER.info(f"\nValidating {f}...")
+                    self.validator.args.data = self.args.data
+                    self.validator.args.plots = self.args.plots
+                    self.metrics = self.validator(model=f)
+                    self.metrics.pop("fitness", None)
+                    self.run_callbacks("on_fit_epoch_end")
+        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
+
+    def plot_training_samples(self, batch, ni):
+        """Plots training samples with their annotations."""
+        plot_images(
+            images=batch["img"],
+            batch_idx=torch.arange(len(batch["img"])),
+            cls=batch["cls"][0].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )
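`get_validator` above derives the progress-bar loss columns by splitting the dataset YAML's `tasks_name` string on whitespace and then on `':'`. A small sketch with an invented value:

```python
# The tasks_name value is hypothetical; the parsing matches get_validator above.
tasks_name = "0:lesion 1:grade"
loss_names = ["loss"] + [pair.split(':')[1] for pair in tasks_name.split()]
print(loss_names)  # ['loss', 'lesion', 'grade']
```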
133  ultralytics/models/yolo/MTL/val.py (new file)
@@ -0,0 +1,133 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import torch
+import json
+
+from ultralytics.data import ClassificationDataset, build_dataloader
+from ultralytics.engine.validator import BaseValidator
+from ultralytics.utils import LOGGER
+from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix, MTLClassifyMetrics
+from ultralytics.utils.plotting import plot_images
+
+
+class ClassificationValidator(BaseValidator):
+    """
+    A class extending the BaseValidator class for validation based on a classification model.
+
+    Notes:
+        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.classify import ClassificationValidator
+
+        args = dict(model='yolov8n-cls.pt', data='imagenet10')
+        validator = ClassificationValidator(args=args)
+        validator()
+        ```
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.targets = None
+        self.pred = None
+        self.args.task = "MTL"
+        self.metrics = MTLClassifyMetrics()
+
+    def get_desc(self):
+        """Returns a formatted string summarizing classification metrics."""
+        return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
+
+    def init_metrics(self, model):
+        """Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
+        if isinstance(model.names[0], str):
+            for i in range(len(model.names)):
+                inner_s = model.names[i][1:-1]
+                pairs = inner_s.split(', ')
+                d = {}
+                for pair in pairs:
+                    key, value = pair.split(': ')
+                    d[int(key)] = value.strip("'")  # strip the single quotes around the value
+                model.names[i] = d
+        self.names = model.names
+        self.nc = []
+        for i in range(len(model.names)):
+            self.nc.append(len(model.names[i]))
+        self.confusion_matrix = [ConfusionMatrix(nc=item, conf=self.args.conf, task="classify") for item in self.nc]
+        self.pred = []
+        self.targets = []
+
+    def preprocess(self, batch):
+        """Preprocesses input batch and returns it."""
+        batch["img"] = batch["img"].to(self.device, non_blocking=True)
+        batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
+        for i in range(len(batch["cls"])):
+            batch["cls"][i] = batch["cls"][i].to(self.device)
+        return batch
+
+    def update_metrics(self, preds, batch):
+        """Updates running metrics with model predictions and batch targets."""
+        lpred = []
+        ltarget = []
+        for i in range(len(self.names)):
+            n5 = min(len(self.names[i]), 5)
+            lpred.append(preds[i].argsort(1, descending=True)[:, :n5])
+            self.pred.append(lpred)
+            ltarget.append(batch["cls"][i])
+            self.targets.append(ltarget)
+
+    def finalize_metrics(self, *args, **kwargs):
+        """Finalizes metrics of the model such as confusion_matrix and speed."""
+        for i in range(len(self.names)):
+            self.confusion_matrix[i].process_cls_preds([sublist[i] for sublist in self.pred], [sublist[i] for sublist in self.targets])
+            if self.args.plots:
+                for normalize in True, False:
+                    self.confusion_matrix[i].plot(
+                        save_dir=self.save_dir, names=self.names[i].values(), normalize=normalize, on_plot=self.on_plot
+                    )
+        self.metrics.speed = self.speed
+        self.metrics.confusion_matrix = self.confusion_matrix
+        self.metrics.save_dir = self.save_dir
+
+    def get_stats(self):
+        """Returns a dictionary of metrics obtained by processing targets and predictions."""
+        self.metrics.process(self.targets, self.pred)
+        return self.metrics.results_dict
+
+    def build_dataset(self, img_path):
+        """Creates and returns a ClassificationDataset instance using given image path and preprocessing parameters."""
+        return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
+
+    def get_dataloader(self, dataset_path, batch_size):
+        """Builds and returns a data loader for classification tasks with given parameters."""
+        dataset = self.build_dataset(dataset_path)
+        return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
+
+    def print_results(self):
+        """Prints evaluation metrics for YOLO object detection model."""
+        pf = "%22s" + "%11.3g" * len(self.metrics.srckeys)  # print format
+        for i in range(len(self.names)):
+            LOGGER.info(pf % ("all", self.metrics.top1[i], self.metrics.top5[i]))
+
+    def plot_val_samples(self, batch, ni):
+        """Plot validation image samples."""
+        # plot_images(
+        #     images=batch["img"],
+        #     batch_idx=torch.arange(len(batch["img"])),
+        #     cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+        #     fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+        #     names=self.names,
+        #     on_plot=self.on_plot,
+        # )
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots predicted bounding boxes on input images and saves the result."""
+        # plot_images(
+        #     batch["img"],
+        #     batch_idx=torch.arange(len(batch["img"])),
+        #     cls=torch.argmax(preds, dim=1),
+        #     fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+        #     names=self.names,
+        #     on_plot=self.on_plot,
+        # )  # pred
ultralytics/models/yolo/__init__.py
@@ -1,7 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
+from ultralytics.models.yolo import classify, MTL, detect, obb, pose, segment, world

 from .model import YOLO, YOLOWorld

-__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
+__all__ = "classify", "MTL", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
ultralytics/models/yolo/model.py
@@ -4,7 +4,7 @@ from pathlib import Path

 from ultralytics.engine.model import Model
 from ultralytics.models import yolo
-from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
+from ultralytics.nn.tasks import ClassificationModel, MTLClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
 from ultralytics.utils import ROOT, yaml_load


@@ -31,6 +31,12 @@ class YOLO(Model):
                 "trainer": yolo.classify.ClassificationTrainer,
                 "validator": yolo.classify.ClassificationValidator,
                 "predictor": yolo.classify.ClassificationPredictor,
             },
+            "MTL": {
+                "model": MTLClassificationModel,
+                "trainer": yolo.MTL.ClassificationTrainer,
+                "validator": yolo.MTL.ClassificationValidator,
+                "predictor": yolo.MTL.ClassificationPredictor,
+            },
             "detect": {
                 "model": DetectionModel,
ultralytics/nn/modules/head.py
@@ -15,7 +15,7 @@ from .conv import Conv
 from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
 from .utils import bias_init_with_prob, linear_init

-__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect"
+__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect", "MTLClassify"


 class Detect(nn.Module):
@@ -601,3 +601,55 @@ class v10Detect(Detect):

     def switch_to_deploy(self):
         del self.cv2, self.cv3
+
+
+class MTLClassify(nn.Module):
+    """YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
+
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
+        """Initializes YOLOv8 classification head with specified input and output channels, kernel size, stride,
+        padding, and groups.
+        """
+        super().__init__()
+        if isinstance(c2, list):
+            self.n_task = len(c2)
+        else:
+            self.n_task = 1
+        c_ = 1280  # efficientnet_b0 size
+        self.conv = Conv(c1, c_, k, s, p, g)
+        self.pool = nn.AdaptiveAvgPool2d(1)  # to x(b,c_,1,1)
+        self.drop = nn.Dropout(p=0.0, inplace=True)
+        self.linearpool = []
+        if isinstance(c2, list):
+            for i in range(self.n_task):
+                self.linearpool.append(nn.Linear(c_, c2[i]))  # to x(b,c2)
+        else:
+            self.linearpool.append(nn.Linear(c_, c2))
+
+    def forward(self, x):
+        """Performs a forward pass of the YOLO model on input image data."""
+        temp = x
+        if isinstance(x, list):
+            x = torch.cat(x, 1)
+        out = []
+        for i in range(self.n_task):
+            temp = self.conv(x)
+            temp = self.pool(temp).flatten(1)
+            temp = self.drop(temp)
+
+            if self.n_task != 1:
+                if temp.dtype == torch.float16:
+                    temp = self.linearpool[i].to('cuda').to(torch.float16)(temp.to('cuda'))
+                else:
+                    temp = self.linearpool[i].to('cuda').to(torch.float32)(temp.to('cuda'))
+                if self.training:
+                    out.append(temp.softmax(1))
+                else:
+                    out.append(temp)
+            else:
+                temp = self.linearpool[i](temp)
+                if self.training:
+                    out = (temp.softmax(1))
+                else:
+                    out = temp
+        return out
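A minimal smoke test of the head on a dummy feature map (shapes invented). Two things the code itself shows: `linearpool` is a plain Python list, so its `nn.Linear` layers are not registered submodules (hence the explicit `.to('cuda')` calls in the multi-task branch), and in training mode the multi-task branch emits softmax outputs rather than raw logits. This CPU sketch therefore exercises the single-task path; with `c2=[2, 3]` on a CUDA device the forward instead returns one tensor per task:

```python
# Hedged smoke test, assuming this commit's MTLClassify is importable.
import torch
from ultralytics.nn.modules.head import MTLClassify

head = MTLClassify(c1=256, c2=10)    # scalar c2 -> single-task path, no CUDA needed
head.eval()                          # eval mode returns logits, not softmax
x = torch.randn(4, 256, 7, 7)        # (batch, channels, H, W) feature map
print(head(x).shape)                 # torch.Size([4, 10])
```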
ultralytics/nn/tasks.py
@@ -61,11 +61,13 @@ from ultralytics.nn.modules import (
     v10Detect,
 )
 from ultralytics.nn.extra_modules import *
+from ultralytics.nn.modules.head import MTLClassify
 from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
 from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
 from ultralytics.utils.loss import (
     E2EDetectLoss,
     v8ClassificationLoss,
+    v8MTLClassificationLoss,
     v8DetectionLoss,
     v8OBBLoss,
     v8PoseLoss,
@@ -538,7 +540,12 @@ class ClassificationModel(BaseModel):
             raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
         self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
         self.stride = torch.Tensor([1])  # no stride constraints
-        self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
+        if isinstance(self.yaml["nc"], list):
+            self.names = []
+            for i in range(len(self.yaml["nc"])):
+                self.names.append({i: f"{j}" for j in range(self.yaml["nc"][i])})  # default names dict
+        else:
+            self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
         self.info()

     @staticmethod
@@ -566,6 +573,59 @@ class ClassificationModel(BaseModel):
         """Initialize the loss criterion for the ClassificationModel."""
         return v8ClassificationLoss()
+
+
+class MTLClassificationModel(BaseModel):
+    """YOLOv8 classification model."""
+
+    def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True):
+        """Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
+        super().__init__()
+        self._from_yaml(cfg, ch, nc, verbose)
+
+    def _from_yaml(self, cfg, ch, nc, verbose):
+        """Set YOLOv8 model configurations and define the model architecture."""
+        self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
+
+        # Define model
+        ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels
+        if nc and nc != self.yaml["nc"]:
+            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
+            self.yaml["nc"] = nc  # override YAML value
+        elif not nc and not self.yaml.get("nc", None):
+            raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
+        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
+        self.stride = torch.Tensor([1])  # no stride constraints
+        if isinstance(self.yaml["nc"], list):
+            self.names = []
+            for i in range(len(self.yaml["nc"])):
+                self.names.append({i: f"{j}" for j in range(self.yaml["nc"][i])})  # default names dict
+        else:
+            self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
+        self.info()
+
+    @staticmethod
+    def reshape_outputs(model, nc):
+        """Update a TorchVision classification model to class count 'n' if required."""
+        name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1]  # last module
+        if isinstance(m, Classify):  # YOLO Classify() head
+            if m.linear.out_features != nc:
+                m.linear = nn.Linear(m.linear.in_features, nc)
+        elif isinstance(m, nn.Linear):  # ResNet, EfficientNet
+            if m.out_features != nc:
+                setattr(model, name, nn.Linear(m.in_features, nc))
+        elif isinstance(m, nn.Sequential):
+            types = [type(x) for x in m]
+            if nn.Linear in types:
+                i = len(types) - 1 - types[::-1].index(nn.Linear)  # last nn.Linear index
+                if m[i].out_features != nc:
+                    m[i] = nn.Linear(m[i].in_features, nc)
+            elif nn.Conv2d in types:
+                i = len(types) - 1 - types[::-1].index(nn.Conv2d)  # last nn.Conv2d index
+                if m[i].out_channels != nc:
+                    m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
+
+    def init_criterion(self):
+        """Initialize the loss criterion for the ClassificationModel."""
+        return v8MTLClassificationLoss(self)
+
+
 class RTDETRDetectionModel(DetectionModel):
     """
@@ -1018,7 +1078,7 @@ def parse_model(d, ch, verbose=True, warehouse_manager=None):  # model_dict, input_channels(3)
                 args[j] = a
         n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
         if m in {
-            Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP, C1, C2, C2f, ELAN1, AConv, SPPELAN, C2fAttn, C3, C3TR,
+            Classify, MTLClassify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP, C1, C2, C2f, ELAN1, AConv, SPPELAN, C2fAttn, C3, C3TR,
             C3Ghost, nn.Conv2d, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3, PSA, SCDown, C2fCIB, C2f_Faster, C2f_ODConv,
             C2f_Faster_EMA, C2f_DBB, GSConv, GSConvns, VoVGSCSP, VoVGSCSPns, VoVGSCSPC, C2f_CloAtt, C3_CloAtt, SCConv, C2f_SCConv, C3_SCConv, C2f_ScConv, C3_ScConv,
             C3_EMSC, C3_EMSCP, C2f_EMSC, C2f_EMSCP, RCSOSA, KWConv, C2f_KW, C3_KW, DySnakeConv, C2f_DySnakeConv, C3_DySnakeConv,
@@ -1301,7 +1361,7 @@ def parse_model(d, ch, verbose=True, warehouse_manager=None):  # model_dict, input_channels(3)
         else:
             c2 = ch[f]

-        if isinstance(c2, list) and m not in {ChannelTransformer, PyramidContextExtraction, CrossLayerChannelAttention, CrossLayerSpatialAttention, MutilScaleEdgeInfoGenetator}:
+        if isinstance(c2, list) and m not in {MTLClassify, ChannelTransformer, PyramidContextExtraction, CrossLayerChannelAttention, CrossLayerSpatialAttention, MutilScaleEdgeInfoGenetator}:
             is_backbone = True
             m_ = m
             m_.backbone = True
ultralytics/utils/loss.py
@@ -984,6 +984,23 @@ class v8ClassificationLoss:
         loss_items = loss.detach()
         return loss, loss_items

+
+class v8MTLClassificationLoss:
+    """Criterion class for computing training losses."""
+
+    def __init__(self, model):
+        self.device = next(model.parameters()).device  # get model device
+
+    def __call__(self, preds, batch):
+        """Compute the classification loss between predictions and true labels."""
+        loss = torch.zeros(len(preds) + 1, device=self.device)
+        if isinstance(preds, list):
+            for i in range(len(preds)):
+                loss[i + 1] = torch.nn.functional.cross_entropy(preds[i], batch["cls"][i], reduction="mean")
+            loss[0] = loss.sum()
+            return loss.sum(), loss.detach()  # loss(box, cls, dfl)
+        else:
+            loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
+            loss_items = loss.detach()
+            return loss, loss_items
+
+
 class v8OBBLoss(v8DetectionLoss):
     def __init__(self, model):
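A minimal sketch of the multi-task branch with dummy logits and targets. Because `loss[0]` is set to the sum of the per-task cross-entropies before `loss.sum()` is returned, the scalar that reaches the optimizer is twice the per-task total; the detached vector is what feeds the progress-bar columns:

```python
import torch
import torch.nn.functional as F

preds = [torch.randn(8, 2), torch.randn(8, 3)]                    # one logits tensor per task
targets = [torch.randint(0, 2, (8,)), torch.randint(0, 3, (8,))]

loss = torch.zeros(len(preds) + 1)
for i in range(len(preds)):
    loss[i + 1] = F.cross_entropy(preds[i], targets[i], reduction="mean")
loss[0] = loss.sum()                                              # total of the task losses
print(loss, loss.sum())                                           # [total, task0, task1]; scalar = 2 * total
```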
|
@ -1830,6 +1830,85 @@ class ClassifyMetrics(SimpleClass):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class MTLClassifyMetrics(SimpleClass):
|
||||||
|
"""
|
||||||
|
Class for computing classification metrics including top-1 and top-5 accuracy.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
top1 (float): The top-1 accuracy.
|
||||||
|
top5 (float): The top-5 accuracy.
|
||||||
|
speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline.
|
||||||
|
|
||||||
|
Properties:
|
||||||
|
fitness (float): The fitness of the model, which is equal to top-5 accuracy.
|
||||||
|
results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.
|
||||||
|
keys (List[str]): A list of keys for the results_dict.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
process(targets, pred): Processes the targets and predictions to compute classification metrics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize a ClassifyMetrics instance."""
|
||||||
|
self.top1 = [0,0,0,0]
|
||||||
|
self.top5 = [0,0,0,0]
|
||||||
|
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
||||||
|
self.task = "MTL"
|
||||||
|
|
||||||
|
def process(self, targets, pred):
|
||||||
|
"""Target classes and predicted classes."""
|
||||||
|
# 多标签处理
|
||||||
|
gpred = []
|
||||||
|
gtargets = []
|
||||||
|
gcorrect = []
|
||||||
|
gacc = []
|
||||||
|
|
||||||
|
for i in range(len(targets[0])):
|
||||||
|
gpred.append( torch.cat([sublist[i] for sublist in pred]))
|
||||||
|
gtargets.append(torch.cat([sublist[i] for sublist in targets]))
|
||||||
|
gcorrect.append((gtargets[i][:, None] == gpred[i]).float())
|
||||||
|
gacc.append(torch.stack((gcorrect[i][:, 0], gcorrect[i].max(1).values), dim=1)) # (top1, top5) accuracy
|
||||||
|
self.top1[i], self.top5[i] = gacc[i].mean(0).tolist()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fitness(self):
|
||||||
|
"""Returns mean of top-1 and top-5 accuracies as fitness score."""
|
||||||
|
return sum((self.top1 + self.top5)) / 2
|
||||||
|
|
||||||
|
@property
|
||||||
|
def results_dict(self):
|
||||||
|
"""返回一个包含模型性能指标和适应度得分的字典。"""
|
||||||
|
results = {}
|
||||||
|
# 将 top1 和 top5 的值分别填入字典
|
||||||
|
for i in range(len(self.top1)):
|
||||||
|
results[f"metrics/accuracy_top1_{i}"] = self.top1[i]
|
||||||
|
results[f"metrics/accuracy_top5_{i}"] = self.top5[i]
|
||||||
|
# 添加 fitness 值
|
||||||
|
results["fitness"] = self.fitness
|
||||||
|
return results
|
||||||
|
|
||||||
|
@property
|
||||||
|
def keys(self):
|
||||||
|
metrics = []
|
||||||
|
"""Returns a list of keys for the results_dict property."""
|
||||||
|
for i in range(len(self.top1)):
|
||||||
|
metrics.append("metrics/accuracy_top1_"+str(i))
|
||||||
|
metrics.append("metrics/accuracy_top5_"+str(i))
|
||||||
|
return metrics
|
||||||
|
@property
|
||||||
|
def srckeys(self):
|
||||||
|
return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def curves(self):
|
||||||
|
"""Returns a list of curves for accessing specific metrics curves."""
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def curves_results(self):
|
||||||
|
"""Returns a list of curves for accessing specific metrics curves."""
|
||||||
|
return []
|
||||||
|
|
||||||
class OBBMetrics(SimpleClass):
|
class OBBMetrics(SimpleClass):
|
||||||
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
|
def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
|
||||||
"""Initialize an OBBMetrics instance with directory, plotting, callback, and class names."""
|
"""Initialize an OBBMetrics instance with directory, plotting, callback, and class names."""
|
||||||
|
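The per-task bookkeeping reduces to a hit matrix between targets and the stored top-k index lists; a tiny worked example with invented indices (note that `top1`/`top5` are pre-sized to four tasks, and `fitness` sums the two concatenated lists and halves the result rather than averaging per task):

```python
# Minimal rendition of one task's (top-1, top-k) computation in process().
import torch

pred = torch.tensor([[1, 0], [0, 1], [1, 0]])   # top-2 predicted class indices per sample
target = torch.tensor([1, 1, 0])                # ground-truth indices for one task

correct = (target[:, None] == pred).float()     # (N, k) hit matrix
acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1)
top1, topk = acc.mean(0).tolist()
print(top1, topk)                               # 0.333..., 1.0
```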
ultralytics/utils/torch_utils.py
@@ -328,12 +328,12 @@ def get_flops(model, imgsz=640):
         # Use stride size for input tensor
         # stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
         stride = 640
-        im = torch.empty((1, 3, stride, stride), device=p.device)  # input image in BCHW format
+        im = torch.zeros((1, 3, stride, stride), device=p.device)  # input image in BCHW format
         flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
         return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
     except Exception:
         # Use actual image size for input tensor (i.e. required for RTDETR models)
-        im = torch.empty((1, 3, *imgsz), device=p.device)  # input image in BCHW format
+        im = torch.zeros((1, 3, *imgsz), device=p.device)  # input image in BCHW format
         return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
     except Exception:
         return 0.0
@@ -350,14 +350,14 @@ def get_flops_with_torch_profiler(model, imgsz=640):
     try:
         # Use stride size for input tensor
         stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # max stride
-        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+        im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        with torch.profiler.profile(with_flops=True) as prof:
             model(im)
         flops = sum(x.flops for x in prof.key_averages()) / 1e9
         flops = flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
     except Exception:
         # Use actual image size for input tensor (i.e. required for RTDETR models)
-        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
+        im = torch.zeros((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
         with torch.profiler.profile(with_flops=True) as prof:
             model(im)
         flops = sum(x.flops for x in prof.key_averages()) / 1e9
74  utils/tobacco_label.py (new file)
@@ -0,0 +1,74 @@
+# Take a dataset path whose train/val/test split is already done
+# and generate label files; path joining happens via the config file
+
+from pathlib import Path
+# import splitfolders
+import os
+
+# read images and build the labels
+def find_images_in_dir(directory, extensions=('.jpg', '.png', '.jpeg')):
+    for root, dirs, files in os.walk(directory):
+        for file in files:
+            if any(file.lower().endswith(ext) for ext in extensions):
+                yield os.path.join(root, file)
+
+def path_subtract(full_path, base_path):
+    # convert both paths to absolute paths so they are comparable
+    full_path = os.path.abspath(full_path)
+    base_path = os.path.abspath(base_path)
+
+    # check whether base_path is a prefix of full_path
+    if full_path.startswith(base_path):
+        # use relpath to get the sub-path relative to base_path
+        relative_path = os.path.relpath(full_path, base_path)
+        return relative_path
+    else:
+        # if base_path is not a prefix of full_path, the subtraction is meaningless
+        return None
+
+def getlabel(root):
+    train_set = []
+    val_set = []
+    test_set = []
+    for image_path in find_images_in_dir(root):
+        # handle each file (train/val/test splits)
+        subpath = path_subtract(image_path, root)
+        Parts = (Path(subpath)).parts[1]
+
+        if 'train' in subpath:
+            train_set.append([subpath, ',', Parts.replace('_', ',')])
+        if 'val' in subpath:
+            val_set.append([subpath, ',', Parts.replace('_', ',')])
+        if 'test' in subpath:
+            test_set.append([subpath, ',', Parts.replace('_', ',')])
+
+    if not os.path.exists(os.path.join(root, 'label')):
+        # create the folder if it does not exist
+        os.makedirs(os.path.join(root, 'label'))
+
+    # write out to txt
+    if len(train_set) > 0:
+        with open(os.path.join(root, 'label', 'train.txt'), 'w', encoding='utf-8') as file:
+            # write each item in the list to the file, followed by a newline
+            for item in train_set:
+                for sub in item:
+                    file.write(sub)
+                file.write('\n')
+    if len(val_set) > 0:
+        with open(os.path.join(root, 'label', 'val.txt'), 'w', encoding='utf-8') as file:
+            # write each item in the list to the file, followed by a newline
+            for item in val_set:
+                for sub in item:
+                    file.write(sub)
+                file.write('\n')
+    if len(test_set) > 0:
+        with open(os.path.join(root, 'label', 'test.txt'), 'w', encoding='utf-8') as file:
+            # write each item in the list to the file, followed by a newline
+            for item in test_set:
+                for sub in item:
+                    file.write(sub)
+                file.write('\n')
+
+if __name__ == '__main__':
+    # train:validation:test=8:1:1
+    getlabel('G:/dataset/test/split')