diff --git a/.gitignore b/.gitignore
index 7d419a7..085500a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,7 +163,6 @@ weights/
 pnnx*
 
 # Autogenerated files for tests
-/ultralytics/assets/
 
 # dataset cache
 *.cache
diff --git a/train.py b/train.py
index cb55960..5910681 100644
--- a/train.py
+++ b/train.py
@@ -6,13 +6,14 @@ warnings.filterwarnings('ignore')
 from ultralytics import YOLO
 
 if __name__ == '__main__':
-    model = YOLO('ultralytics/cfg/models/v8/yolov8n-cls.yaml')
-    # model.load('yolov8n.pt') # loading pretrain weights
-    model.train(data='G:/skin-cancer-detection',
+    # model = YOLO('ultralytics/cfg/models/v8/yolov8n-cls.yaml')  # single-task learning
+    # model.train(data='G:/dataset/split',
+    model = YOLO('ultralytics/cfg/models/v8/yolov8-mtlcls.yaml', task='MTL')  # multi-task learning
+    model.train(data='G:/dataset/test/ml.yaml',
                 cache=False,
-                imgsz=640,
-                epochs=300,
-                batch=32,
+                imgsz=224,
+                epochs=2,
+                batch=64,
                 close_mosaic=0,
                 workers=8,  # if training hangs inexplicably on Windows, try setting workers=0
                 optimizer='SGD',  # using SGD
diff --git a/ultralytics/assets/bus.jpg b/ultralytics/assets/bus.jpg
new file mode 100644
index 0000000..40eaaf5
Binary files /dev/null and b/ultralytics/assets/bus.jpg differ
diff --git a/ultralytics/assets/zidane.jpg b/ultralytics/assets/zidane.jpg
new file mode 100644
index 0000000..eeab1cd
Binary files /dev/null and b/ultralytics/assets/zidane.jpg differ
diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index 34886eb..6e8d46c 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -31,13 +31,14 @@ from ultralytics.utils import (
 
 # Define valid tasks and modes
 MODES = {"train", "val", "predict", "export", "track", "benchmark"}
-TASKS = {"detect", "segment", "classify", "pose", "obb"}
+TASKS = {"detect", "segment", "classify", "pose", "obb", "MTL"}
 TASK2DATA = {
     "detect": "coco8.yaml",
     "segment": "coco8-seg.yaml",
     "classify": "imagenet10",
     "pose": "coco8-pose.yaml",
     "obb": "dota8.yaml",
+    "MTL": "imagenet10",
 }
 TASK2MODEL = {
     "detect": "yolov8n.pt",
@@ -45,6 +46,7 @@ TASK2MODEL = {
     "classify": "yolov8n-cls.pt",
     "pose": "yolov8n-pose.pt",
     "obb": "yolov8n-obb.pt",
+    "MTL": "yolov8n-cls.pt",
 }
 TASK2METRIC = {
     "detect": "metrics/mAP50-95(B)",
diff --git a/ultralytics/cfg/models/v8/yolov8-mtlcls.yaml b/ultralytics/cfg/models/v8/yolov8-mtlcls.yaml
new file mode 100644
index 0000000..9a96994
--- /dev/null
+++ b/ultralytics/cfg/models/v8/yolov8-mtlcls.yaml
@@ -0,0 +1,29 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv8 multi-task image classification model. For usage examples see https://docs.ultralytics.com/tasks/classify
+
+# Parameters
+nc: 1000 # number of classes (an int for one task, or a list with one entry per task; overridden by the data YAML)
+scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.33, 0.25, 1024]
+  s: [0.33, 0.50, 1024]
+  m: [0.67, 0.75, 1024]
+  l: [1.00, 1.00, 1024]
+  x: [1.00, 1.25, 1024]
+
+# YOLOv8.0n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2f, [1024, True]]
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, MTLClassify, [nc]] # multi-task Classify
diff --git a/ultralytics/data/__init__.py b/ultralytics/data/__init__.py
index 834432e..5ceae5f 100644
--- a/ultralytics/data/__init__.py
+++ b/ultralytics/data/__init__.py
@@ -6,6 +6,7 @@ from .dataset import (
     ClassificationDataset,
     GroundingDataset,
     SemanticDataset,
+    MTLDataset,
     YOLOConcatDataset,
     YOLODataset,
     YOLOMultiModalDataset,
@@ -14,6 +15,7 @@ from .dataset import (
 __all__ = (
     "BaseDataset",
     "ClassificationDataset",
+    "MTLDataset",
     "SemanticDataset",
     "YOLODataset",
     "YOLOMultiModalDataset",
diff --git a/ultralytics/data/dataset.py b/ultralytics/data/dataset.py
index 8bbb08f..1848be9 100644
--- a/ultralytics/data/dataset.py
+++ b/ultralytics/data/dataset.py
@@ -5,15 +5,17 @@ import json
 from collections import defaultdict
 from itertools import repeat
 from multiprocessing.pool import ThreadPool
+import os
 from pathlib import Path
 
 import cv2
 import numpy as np
 import torch
 from PIL import Image
 from torch.utils.data import ConcatDataset
 
-from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
+from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, yaml_load
 from ultralytics.utils.ops import resample_segments
 
 from .augment import (
@@ -504,3 +506,159 @@ class ClassificationDataset:
         x["msgs"] = msgs  # warnings
         save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
         return samples
+
+
+class MTLDataset:
+    """
+    Dataset for YOLO multi-task classification, where each image carries one class label per task. Samples are read
+    from a comma-separated label file rather than an ImageFolder tree, and the same torchvision/Albumentations
+    augmentations as `ClassificationDataset` are applied. Hooks for RAM/disk caching are kept but currently disabled.
+
+    Attributes:
+        cache_ram (bool): Indicates if caching in RAM is enabled.
+        cache_disk (bool): Indicates if caching on disk is enabled.
+        samples (list): A list of [image path, per-task class indices, .npy cache path, cached image array] entries.
+        torch_transforms (callable): PyTorch transforms to be applied to the images.
+    """
+
+    def __init__(self, data, args, augment=False, prefix=""):
+        """
+        Initialize the MTL dataset with a label file, image size, augmentations, and cache settings.
+
+        Args:
+            data (str): Path to the label file; one `relative/image/path,label_task1,label_task2,...` line per image.
+            args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
+                parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
+                of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
+                `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
+            augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
+            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
+                debugging. Default is an empty string.
+        """
+        self.data = data  # path to the label file
+
+        # Load the data YAML referenced by args.data to get the dataset root and per-task class names
+        config = yaml_load(args.data, append_filename=True)  # dictionary
+        self.root = config["path"]
+        self.names = list(config["tasks"])  # one {index: name} dict per task
+
+        # Initialize attributes
+        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
+        self.cache_ram = False  # RAM/disk caching not yet supported for MTL datasets
+        self.cache_disk = False
+        self.samples = self.load_data()
+        if augment and args.fraction < 1.0:  # reduce training fraction (after samples are loaded)
+            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
+        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
+        self.torch_transforms = (
+            classify_augmentations(
+                size=args.imgsz,
+                scale=scale,
+                hflip=args.fliplr,
+                vflip=args.flipud,
+                erasing=args.erasing,
+                auto_augment=args.auto_augment,
+                hsv_h=args.hsv_h,
+                hsv_s=args.hsv_s,
+                hsv_v=args.hsv_v,
+            )
+            if augment
+            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
+        )
+
+    def load_data(self):
+        """Parse the label file into [path, per-task class indices, npy cache path, cached image] samples."""
+        samples = []
+        with open(self.data, 'r', encoding='utf-8') as file:
+            for line in file:
+                parts = line.strip().split(',')
+                path = parts.pop(0)
+                label = []
+                for i, part in enumerate(parts):
+                    # labels may be stored as class names or integer indices; map either back to the index key
+                    matches = [k for k, v in self.names[i].items() if str(v) == part]
+                    if not matches:
+                        raise ValueError(f"label '{part}' not found in task {i} names {self.names[i]}")
+                    label.append(matches[0])
+                samples.append([os.path.join(self.root, path), label, None, None])
+        return samples
+
+    def __getitem__(self, i):
+        """Returns subset of data and targets corresponding to given indices."""
+        f, j, fn, im = self.samples[i]  # filename, per-task indices, filename.with_suffix('.npy'), image
+        if self.cache_ram:
+            if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
+                im = self.samples[i][3] = cv2.imread(f)
+        elif self.cache_disk:
+            if not fn.exists():  # load npy
+                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
+            im = np.load(fn)
+        else:  # read image
+            im = cv2.imread(f)  # BGR
+        if im is None:
+            raise FileNotFoundError(f"Image not found or unreadable: {f}")
+        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))  # NumPy BGR array to PIL RGB image
+        sample = self.torch_transforms(im)
+        return {"img": sample, "cls": j}
+
+    def __len__(self) -> int:
+        """Return the total number of samples in the dataset."""
+        return len(self.samples)
+
+    def verify_images(self):
+        """Verify all images in dataset."""
+        desc = f"{self.prefix}Scanning {self.root}..."
f"{self.prefix}Scanning {self.root}..." + path = Path(self.root).with_suffix(".cache") # *.cache file path + + with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError): + cache = load_dataset_cache_file(path) # attempt to load a *.cache file + assert cache["version"] == DATASET_CACHE_VERSION # matches current version + assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash + nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total + if LOCAL_RANK in {-1, 0}: + d = f"{desc} {nf} images, {nc} corrupt" + TQDM(None, desc=d, total=n, initial=n) + if cache["msgs"]: + LOGGER.info("\n".join(cache["msgs"])) # display warnings + return samples + + # Run scan if *.cache retrieval failed + nf, nc, msgs, samples, x = 0, 0, [], [], {} + with ThreadPool(NUM_THREADS) as pool: + results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix))) + pbar = TQDM(results, desc=desc, total=len(self.samples)) + for sample, nf_f, nc_f, msg in pbar: + if nf_f: + samples.append(sample) + if msg: + msgs.append(msg) + nf += nf_f + nc += nc_f + pbar.desc = f"{desc} {nf} images, {nc} corrupt" + pbar.close() + if msgs: + LOGGER.info("\n".join(msgs)) + x["hash"] = get_hash([x[0] for x in self.samples]) + x["results"] = nf, nc, len(samples), samples + x["msgs"] = msgs # warnings + save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION) + return samples \ No newline at end of file diff --git a/ultralytics/data/utils.py b/ultralytics/data/utils.py index fa9bfbb..823bf49 100644 --- a/ultralytics/data/utils.py +++ b/ultralytics/data/utils.py @@ -423,6 +423,31 @@ def check_cls_dataset(dataset, split=""): return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names} +def check_mtlcls_dataset(file, split=""): + # Read YAML + data = yaml_load(file, append_filename=True) # dictionary + + # Checks + for k in "train", "val": + if k not in data: + if k != "val" or "validation" not in data: + raise SyntaxError( + emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.") + ) + LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.") + data["val"] = data.pop("validation") # replace 'validation' key with 'val' key + if "tasks" not in data : + raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs.")) + + names = [] + nc = [] + for i in range(len(data["tasks"])): + nc.append(len(data["tasks"][i])) + names.append(check_class_names(data["tasks"][i])) + + + return {"train": data["train"], "val": data["val"], "test": data["test"], "nc": nc, "names": names} + class HUBDatasetStats: """ A class for generating HUB dataset JSON and `-hub` dataset directory. 
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 85dbc4c..a557013 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -22,7 +22,7 @@ from torch import distributed as dist
 from torch import nn, optim
 
 from ultralytics.cfg import get_cfg, get_save_dir
-from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset, check_mtlcls_dataset
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
 from ultralytics.utils import (
     DEFAULT_CFG,
@@ -413,10 +413,16 @@ class BaseTrainer:
                     loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
                     losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
                     if RANK in {-1, 0}:
-                        pbar.set_description(
-                            ("%11s" * 2 + "%11.4g" * (2 + loss_len))
-                            % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
-                        )
+                        # MTL batches carry one label tensor per task; use task 0 for the instance count
+                        instances = batch["cls"][0].shape[0] if self.args.task == "MTL" else batch["cls"].shape[0]
+                        pbar.set_description(
+                            ("%11s" * 2 + "%11.4g" * (2 + loss_len))
+                            % (f"{epoch + 1}/{self.epochs}", mem, *losses, instances, batch["img"].shape[-1])
+                        )
                     self.run_callbacks("on_batch_end")
                     if self.args.plots and ni in self.plot_idx:
                         self.plot_training_samples(batch, ni)
@@ -542,6 +548,8 @@ class BaseTrainer:
         try:
             if self.args.task == "classify":
                 data = check_cls_dataset(self.args.data)
+            elif self.args.task == "MTL":
+                data = check_mtlcls_dataset(self.args.data)
             elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
                 "detect",
                 "segment",
diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py
index e2fbfff..63f844e 100644
--- a/ultralytics/engine/validator.py
+++ b/ultralytics/engine/validator.py
@@ -28,7 +28,7 @@ import numpy as np
 import torch
 
 from ultralytics.cfg import get_cfg, get_save_dir
-from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset, check_mtlcls_dataset
 from ultralytics.nn.autobackend import AutoBackend
 from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
 from ultralytics.utils.checks import check_imgsz
@@ -101,7 +101,12 @@ class BaseValidator:
         self.plots = {}
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
 
+    def mean_if_list(self, lst):
+        """Return the mean of a non-empty list rounded to 5 decimals; pass scalars through unchanged."""
+        if isinstance(lst, list) and lst:  # per-task metrics arrive as lists
+            return round(sum(lst) / len(lst), 5)
+        return lst  # already a number (or an empty list), return as-is
+
     @smart_inference_mode()
     def __call__(self, trainer=None, model=None):
         """Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
@@ -140,7 +145,9 @@ class BaseValidator:
                 self.args.batch = 1  # export.py models default to batch-size 1
                 LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
 
-            if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
+            if self.args.task == "MTL":
+                self.data = check_mtlcls_dataset(self.args.data, split=self.args.split)
+            elif str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
                 self.data = check_det_dataset(self.args.data)
             elif self.args.task == "classify":
                 self.data = check_cls_dataset(self.args.data, split=self.args.split)
@@ -202,7 +209,7 @@ class BaseValidator:
         if self.training:
             model.float()
             results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
-            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
+            # per-task metrics are lists; average them so every returned value stays a 5-decimal scalar
+            return {k: self.mean_if_list(v) if isinstance(v, list) else round(float(v), 5) for k, v in results.items()}
         else:
             LOGGER.info(
                 "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
diff --git a/ultralytics/models/yolo/MTL/__init__.py b/ultralytics/models/yolo/MTL/__init__.py
new file mode 100644
index 0000000..29cd0ae
--- /dev/null
+++ b/ultralytics/models/yolo/MTL/__init__.py
@@ -0,0 +1,7 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from ultralytics.models.yolo.MTL.predict import ClassificationPredictor
+from ultralytics.models.yolo.MTL.train import ClassificationTrainer
+from ultralytics.models.yolo.MTL.val import ClassificationValidator
+
+__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
\ No newline at end of file
diff --git a/ultralytics/models/yolo/MTL/predict.py b/ultralytics/models/yolo/MTL/predict.py
new file mode 100644
index 0000000..998196e
--- /dev/null
+++ b/ultralytics/models/yolo/MTL/predict.py
@@ -0,0 +1,61 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import cv2
+import torch
+from PIL import Image
+
+from ultralytics.engine.predictor import BasePredictor
+from ultralytics.engine.results import Results
+from ultralytics.utils import DEFAULT_CFG, ops
+
+
+class ClassificationPredictor(BasePredictor):
+    """
+    A class extending the BasePredictor class for prediction based on a classification model.
+
+    Notes:
+        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+
+    Example:
+        ```python
+        from ultralytics.utils import ASSETS
+        from ultralytics.models.yolo.MTL import ClassificationPredictor
+
+        args = dict(model='yolov8n-cls.pt', source=ASSETS)
+        predictor = ClassificationPredictor(overrides=args)
+        predictor.predict_cli()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes ClassificationPredictor setting the task to 'MTL'."""
+        super().__init__(cfg, overrides, _callbacks)
+        self.args.task = "MTL"
+        self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
+
+    def preprocess(self, img):
+        """Converts input image to model-compatible data type."""
+        if not isinstance(img, torch.Tensor):
+            is_legacy_transform = any(
+                self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
+            )
+            if is_legacy_transform:  # to handle legacy transforms
+                img = torch.stack([self.transforms(im) for im in img], dim=0)
+            else:
+                img = torch.stack(
+                    [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
+                )
+        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
+        return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
+
+    def postprocess(self, preds, img, orig_imgs):
+        """Post-processes predictions to return Results objects."""
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+        results = []
+        for i, pred in enumerate(preds):
+            orig_img = orig_imgs[i]
+            img_path = self.batch[0][i]
+            results.append(Results(orig_img, path=img_path, names=self.model.names, probs=pred))
+        return results
\ No newline at end of file
diff --git a/ultralytics/models/yolo/MTL/train.py b/ultralytics/models/yolo/MTL/train.py
new file mode 100644
index 0000000..75e2de3
--- /dev/null
+++ b/ultralytics/models/yolo/MTL/train.py
@@ -0,0 +1,168 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import torch
+import torchvision
+
+from ultralytics.data import MTLDataset, build_dataloader
+from ultralytics.engine.trainer import BaseTrainer
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import ClassificationModel, MTLClassificationModel, attempt_load_one_weight
+from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr, yaml_load
+from ultralytics.utils.plotting import plot_images, plot_results
+from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
+
+
+class ClassificationTrainer(BaseTrainer):
+    """
+    A class extending the BaseTrainer class for training based on a multi-task classification model.
+
+    Notes:
+        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.MTL import ClassificationTrainer
+
+        args = dict(model='yolov8-mtlcls.yaml', data='ml.yaml', epochs=3)
+        trainer = ClassificationTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
+        if overrides is None:
+            overrides = {}
+        overrides["task"] = "MTL"
+        if overrides.get("imgsz") is None:
+            overrides["imgsz"] = 224
+        super().__init__(cfg, overrides, _callbacks)
+
+    def set_model_attributes(self):
+        """Set the YOLO model's class names from the loaded dataset."""
+        self.model.names = self.data["names"]
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Returns a modified PyTorch model configured for training YOLO."""
+        model = MTLClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+
+        for m in model.modules():
+            if not self.args.pretrained and hasattr(m, "reset_parameters"):
+                m.reset_parameters()
+            if isinstance(m, torch.nn.Dropout) and self.args.dropout:
+                m.p = self.args.dropout  # set dropout
+        for p in model.parameters():
+            p.requires_grad = True  # for training
+        return model
+
+    def setup_model(self):
+        """Load, create or download model for any task."""
+        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
+            return
+
+        model, ckpt = str(self.model), None
+        # Load a YOLO model locally, from torchvision, or from Ultralytics assets
+        if model.endswith(".pt"):
+            self.model, ckpt = attempt_load_one_weight(model, device="cpu")
+            for p in self.model.parameters():
+                p.requires_grad = True  # for training
+        elif model.split(".")[-1] in {"yaml", "yml"}:
+            self.model = self.get_model(cfg=model)
+        elif model in torchvision.models.__dict__:
+            self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
+        else:
+            raise FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
Please check model name.") + ClassificationModel.reshape_outputs(self.model, self.data["nc"]) + + return ckpt + + def build_dataset(self, img_path, mode="train", batch=None): + """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.).""" + return MTLDataset(data=img_path, args=self.args, augment=mode == "train") + + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): + """Returns PyTorch DataLoader with transforms to preprocess images for inference.""" + with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP + dataset = self.build_dataset(dataset_path, mode) + + loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank) + # Attach inference transforms + if mode != "train": + if is_parallel(self.model): + self.model.module.transforms = loader.dataset.torch_transforms + else: + self.model.transforms = loader.dataset.torch_transforms + return loader + + def preprocess_batch(self, batch): + """Preprocesses a batch of images and classes.""" + batch["img"] = batch["img"].to(self.device) + for i in range(len(batch["cls"])): + batch["cls"][i] = batch["cls"][i].to(self.device) + return batch + + + def progress_string(self): + """Returns a formatted string showing training progress.""" + return ("\n" + "%11s" * (4 + len(self.loss_names))) % ( + "Epoch", + "GPU_mem", + *self.loss_names, + "Instances", + "Size", + ) + + def get_validator(self): + """Returns an instance of ClassificationValidator for validation.""" + cfg = yaml_load(self.args.data) # model dict + self.loss_names = ["loss"] + [pair.split(':')[1] for pair in cfg["tasks_name"].split()] + return yolo.MTL.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks) + + def label_loss_items(self, loss_items=None, prefix="train"): + """ + Returns a loss dict with labelled training loss items tensor. 
+        keys = [f"{prefix}/{x}" for x in self.loss_names]
+        if loss_items is None:
+            return keys
+        loss_items = [float(item) for item in loss_items]
+        return dict(zip(keys, loss_items))
+
+    def plot_metrics(self):
+        """Plots metrics from a CSV file."""
+        plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png
+
+    def final_eval(self):
+        """Evaluate trained model and save validation results."""
+        for f in self.last, self.best:
+            if f.exists():
+                strip_optimizer(f)  # strip optimizers
+                if f is self.best:
+                    LOGGER.info(f"\nValidating {f}...")
+                    self.validator.args.data = self.args.data
+                    self.validator.args.plots = self.args.plots
+                    self.metrics = self.validator(model=f)
+                    self.metrics.pop("fitness", None)
+                    self.run_callbacks("on_fit_epoch_end")
+        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
+
+    def plot_training_samples(self, batch, ni):
+        """Plots training samples with their annotations (task 0 labels only)."""
+        plot_images(
+            images=batch["img"],
+            batch_idx=torch.arange(len(batch["img"])),
+            cls=batch["cls"][0].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )
\ No newline at end of file
diff --git a/ultralytics/models/yolo/MTL/val.py b/ultralytics/models/yolo/MTL/val.py
new file mode 100644
index 0000000..738fbc1
--- /dev/null
+++ b/ultralytics/models/yolo/MTL/val.py
@@ -0,0 +1,133 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import torch
+
+from ultralytics.data import MTLDataset, build_dataloader
+from ultralytics.engine.validator import BaseValidator
+from ultralytics.utils import LOGGER
+from ultralytics.utils.metrics import ConfusionMatrix, MTLClassifyMetrics
+from ultralytics.utils.plotting import plot_images
+
+
+class ClassificationValidator(BaseValidator):
+    """
+    A class extending the BaseValidator class for validation based on a multi-task classification model.
+
+    Notes:
+        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.MTL import ClassificationValidator
+
+        args = dict(model='yolov8n-cls.pt', data='ml.yaml')
+        validator = ClassificationValidator(args=args)
+        validator()
+        ```
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.targets = None
+        self.pred = None
+        self.args.task = "MTL"
+        self.metrics = MTLClassifyMetrics()
+
+    def get_desc(self):
+        """Returns a formatted string summarizing classification metrics."""
+        return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
+
+    def init_metrics(self, model):
+        """Initialize confusion matrices, class names, and top-1/top-5 accuracy trackers."""
+        if isinstance(model.names[0], str):
+            # names loaded from a checkpoint may arrive serialized, e.g. "{0: 'benign', 1: 'malignant'}";
+            # parse each string back into an {index: name} dict
+            for i in range(len(model.names)):
+                inner_s = model.names[i][1:-1]
+                pairs = inner_s.split(', ')
+                d = {}
+                for pair in pairs:
+                    key, value = pair.split(': ')
+                    d[int(key)] = value.strip("'")  # strip the quotes around each value
+                model.names[i] = d
+        self.names = model.names
+        self.nc = [len(names) for names in model.names]  # classes per task
+        self.confusion_matrix = [ConfusionMatrix(nc=item, conf=self.args.conf, task="classify") for item in self.nc]
+        self.pred = []
+        self.targets = []
+
+    def preprocess(self, batch):
+        """Preprocesses input batch and returns it."""
+        batch["img"] = batch["img"].to(self.device, non_blocking=True)
+        batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
+        for i in range(len(batch["cls"])):  # one label tensor per task
+            batch["cls"][i] = batch["cls"][i].to(self.device)
+        return batch
+
+    def update_metrics(self, preds, batch):
+        """Updates running metrics with model predictions and batch targets."""
+        lpred = []
+        ltarget = []
+        for i in range(len(self.names)):  # collect per-task top-n predictions and targets for this batch
+            n5 = min(len(self.names[i]), 5)
+            lpred.append(preds[i].argsort(1, descending=True)[:, :n5])
+            ltarget.append(batch["cls"][i])
+        self.pred.append(lpred)
+        self.targets.append(ltarget)
+
+    def finalize_metrics(self, *args, **kwargs):
+        """Finalizes metrics of the model such as confusion_matrix and speed."""
+        for i in range(len(self.names)):
+            self.confusion_matrix[i].process_cls_preds(
+                [sublist[i] for sublist in self.pred], [sublist[i] for sublist in self.targets]
+            )
+            if self.args.plots:
+                for normalize in True, False:
+                    self.confusion_matrix[i].plot(
+                        save_dir=self.save_dir, names=self.names[i].values(), normalize=normalize, on_plot=self.on_plot
+                    )
+        self.metrics.speed = self.speed
+        self.metrics.confusion_matrix = self.confusion_matrix
+        self.metrics.save_dir = self.save_dir
+
+    def get_stats(self):
+        """Returns a dictionary of metrics obtained by processing targets and predictions."""
+        self.metrics.process(self.targets, self.pred)
+        return self.metrics.results_dict
+
+    def build_dataset(self, img_path):
+        """Creates and returns an MTLDataset instance for the given label-file path."""
+        return MTLDataset(data=img_path, args=self.args, augment=False, prefix=self.args.split)
+
+    def get_dataloader(self, dataset_path, batch_size):
+        """Builds and returns a data loader for classification tasks with given parameters."""
+        dataset = self.build_dataset(dataset_path)
+        return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
+
+    def print_results(self):
+        """Prints per-task top-1/top-5 accuracy for the YOLO multi-task classification model."""
+        pf = "%22s" + "%11.3g" * len(self.metrics.srckeys)  # print format
+        for i in range(len(self.names)):
+            LOGGER.info(pf % (f"task {i}", self.metrics.top1[i], self.metrics.top5[i]))
+
+    def plot_val_samples(self, batch, ni):
+        """Plot validation image samples (disabled: plot_images does not yet support per-task label lists)."""
+        # plot_images(
+        #     images=batch["img"],
+        #     batch_idx=torch.arange(len(batch["img"])),
+        #     cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+        #     fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+        #     names=self.names,
+        #     on_plot=self.on_plot,
+        # )
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots predictions on input images (disabled: plot_images does not yet support per-task label lists)."""
+        # plot_images(
+        #     batch["img"],
+        #     batch_idx=torch.arange(len(batch["img"])),
+        #     cls=torch.argmax(preds, dim=1),
+        #     fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+        #     names=self.names,
+        #     on_plot=self.on_plot,
+        # )  # pred
diff --git a/ultralytics/models/yolo/__init__.py b/ultralytics/models/yolo/__init__.py
index 8d9aedf..61265ea 100644
--- a/ultralytics/models/yolo/__init__.py
+++ b/ultralytics/models/yolo/__init__.py
@@ -1,7 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
+from ultralytics.models.yolo import classify, MTL, detect, obb, pose, segment, world
 
 from .model import YOLO, YOLOWorld
 
-__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
+__all__ = "classify", "MTL", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
diff --git a/ultralytics/models/yolo/model.py b/ultralytics/models/yolo/model.py
index d540322..5b008ef 100644
--- a/ultralytics/models/yolo/model.py
+++ b/ultralytics/models/yolo/model.py
@@ -4,7 +4,7 @@ from pathlib import Path
 
 from ultralytics.engine.model import Model
 from ultralytics.models import yolo
-from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
+from ultralytics.nn.tasks import ClassificationModel, MTLClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
 from ultralytics.utils import ROOT, yaml_load
 
 
@@ -31,6 +31,12 @@ class YOLO(Model):
                 "trainer": yolo.classify.ClassificationTrainer,
                 "validator": yolo.classify.ClassificationValidator,
                 "predictor": yolo.classify.ClassificationPredictor,
+            },
+            "MTL": {
+                "model": MTLClassificationModel,
+                "trainer": yolo.MTL.ClassificationTrainer,
+                "validator": yolo.MTL.ClassificationValidator,
+                "predictor": yolo.MTL.ClassificationPredictor,
             },
             "detect": {
                 "model": DetectionModel,
diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py
index 95fa825..c38d2ae 100644
--- a/ultralytics/nn/modules/head.py
+++ b/ultralytics/nn/modules/head.py
@@ -15,7 +15,7 @@ from .conv import Conv
 from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
 from .utils import bias_init_with_prob, linear_init
 
-__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect"
+__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect", "MTLClassify"
 
 
 class Detect(nn.Module):
@@ -600,4 +600,56 @@ class v10Detect(Detect):
         self.one2one_cv3 = copy.deepcopy(self.cv3)
 
     def switch_to_deploy(self):
-        del self.cv2, self.cv3
\ No newline at end of file
+        del self.cv2, self.cv3
+
+
+class MTLClassify(nn.Module):
+    """YOLOv8 multi-task classification head, i.e. x(b,c1,20,20) to a list of per-task x(b,c2[i]) outputs."""
+
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
+        """Initializes the multi-task classification head with input channels, per-task output classes (an int for a
+        single task or a list with one entry per task), kernel size, stride, padding, and groups.
+        """
+        super().__init__()
+        self.n_task = len(c2) if isinstance(c2, list) else 1
+        c_ = 1280  # efficientnet_b0 size
+        self.conv = Conv(c1, c_, k, s, p, g)
+        self.pool = nn.AdaptiveAvgPool2d(1)  # to x(b,c_,1,1)
+        self.drop = nn.Dropout(p=0.0, inplace=True)
+        # nn.ModuleList (not a plain Python list) so the per-task linear layers register as parameters
+        # and follow the model across devices and dtypes
+        self.linearpool = nn.ModuleList(nn.Linear(c_, nc) for nc in (c2 if isinstance(c2, list) else [c2]))
+
+    def forward(self, x):
+        """Performs a forward pass, returning per-task logits in training and softmax probabilities in inference."""
+        if isinstance(x, list):
+            x = torch.cat(x, 1)
+        feat = self.drop(self.pool(self.conv(x)).flatten(1))  # shared features for all task heads
+        out = [linear(feat) for linear in self.linearpool]
+        if not self.training:
+            out = [y.softmax(1) for y in out]  # probabilities at inference; raw logits feed the training loss
+        return out if self.n_task != 1 else out[0]
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 258ab74..75f4576 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -61,11 +61,13 @@ from ultralytics.nn.modules import (
     v10Detect,
 )
 from ultralytics.nn.extra_modules import *
+from ultralytics.nn.modules.head import MTLClassify
 from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
 from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
 from ultralytics.utils.loss import (
     E2EDetectLoss,
     v8ClassificationLoss,
+    v8MTLClassificationLoss,
     v8DetectionLoss,
     v8OBBLoss,
     v8PoseLoss,
@@ -538,7 +540,12 @@ class ClassificationModel(BaseModel):
             raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
         self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
         self.stride = torch.Tensor([1])  # no stride constraints
-        self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
+        if isinstance(self.yaml["nc"], list):  # multi-task: one default names dict per task
+            self.names = [{j: f"{j}" for j in range(nc_i)} for nc_i in self.yaml["nc"]]
+        else:
+            self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
         self.info()
 
     @staticmethod
@@ -566,7 +573,60 @@ class ClassificationModel(BaseModel):
         """Initialize the loss criterion for the ClassificationModel."""
         return v8ClassificationLoss()
 
+
+class MTLClassificationModel(BaseModel):
+    """YOLOv8 multi-task classification model."""
+
+    def __init__(self, cfg="yolov8-mtlcls.yaml", ch=3, nc=None, verbose=True):
+        """Init MTLClassificationModel with YAML, channels, per-task class counts, verbose flag."""
+        super().__init__()
+        self._from_yaml(cfg, ch, nc, verbose)
+
+    def _from_yaml(self, cfg, ch, nc, verbose):
+        """Set YOLOv8 model configurations and define the model architecture."""
+        self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
+
+        # Define model
+        ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels
+        if nc and nc != self.yaml["nc"]:
+            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
+            self.yaml["nc"] = nc  # override YAML value
+        elif not nc and not self.yaml.get("nc", None):
+            raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
+        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
+        self.stride = torch.Tensor([1])  # no stride constraints
+        if isinstance(self.yaml["nc"], list):  # multi-task: one default names dict per task
+            self.names = [{j: f"{j}" for j in range(nc_i)} for nc_i in self.yaml["nc"]]
+        else:
+            self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
+        self.info()
+
+    @staticmethod
+    def reshape_outputs(model, nc):
+        """Update a TorchVision classification model to class count 'nc' if required (single-task heads only)."""
+        name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1]  # last module
+        if isinstance(m, Classify):  # YOLO Classify() head
+            if m.linear.out_features != nc:
+                m.linear = nn.Linear(m.linear.in_features, nc)
+        elif isinstance(m, nn.Linear):  # ResNet, EfficientNet
+            if m.out_features != nc:
+                setattr(model, name, nn.Linear(m.in_features, nc))
+        elif isinstance(m, nn.Sequential):
+            types = [type(x) for x in m]
+            if nn.Linear in types:
+                i = len(types) - 1 - types[::-1].index(nn.Linear)  # last nn.Linear index
+                if m[i].out_features != nc:
+                    m[i] = nn.Linear(m[i].in_features, nc)
+            elif nn.Conv2d in types:
+                i = len(types) - 1 - types[::-1].index(nn.Conv2d)  # last nn.Conv2d index
+                if m[i].out_channels != nc:
+                    m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
+
+    def init_criterion(self):
+        """Initialize the loss criterion for the MTLClassificationModel."""
+        return v8MTLClassificationLoss(self)
+
 
 class RTDETRDetectionModel(DetectionModel):
     """
     RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
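Note: a minimal smoke test for `MTLClassificationModel`, assuming it is run from the repository root; the two class counts (7 and 3) are illustrative:

```python
# Hypothetical smoke test: build the MTL model from the new YAML with two
# illustrative tasks and confirm the forward pass yields one output per task.
import torch

from ultralytics.nn.tasks import MTLClassificationModel

model = MTLClassificationModel("ultralytics/cfg/models/v8/yolov8-mtlcls.yaml", nc=[7, 3])
model.eval()  # inference mode -> MTLClassify returns per-task softmax probabilities
with torch.no_grad():
    outs = model(torch.zeros(2, 3, 224, 224))  # batch of two 224x224 RGB images
print([tuple(o.shape) for o in outs])  # expected: [(2, 7), (2, 3)]
```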
@@ -1018,7 +1078,7 @@ def parse_model(d, ch, verbose=True, warehouse_manager=None):  # model_dict, input_channels(3)
                     args[j] = a
         n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
         if m in {
-            Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP, C1, C2, C2f, ELAN1, AConv, SPPELAN, C2fAttn, C3, C3TR,
+            Classify, MTLClassify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP, C1, C2, C2f, ELAN1, AConv, SPPELAN, C2fAttn, C3, C3TR,
             C3Ghost, nn.Conv2d, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3, PSA, SCDown, C2fCIB, C2f_Faster, C2f_ODConv, C2f_Faster_EMA, C2f_DBB,
             GSConv, GSConvns, VoVGSCSP, VoVGSCSPns, VoVGSCSPC, C2f_CloAtt, C3_CloAtt, SCConv, C2f_SCConv, C3_SCConv, C2f_ScConv, C3_ScConv,
             C3_EMSC, C3_EMSCP, C2f_EMSC, C2f_EMSCP, RCSOSA, KWConv, C2f_KW, C3_KW, DySnakeConv, C2f_DySnakeConv, C3_DySnakeConv,
@@ -1301,7 +1361,7 @@ def parse_model(d, ch, verbose=True, warehouse_manager=None):  # model_dict, input_channels(3)
         else:
             c2 = ch[f]
 
-        if isinstance(c2, list) and m not in {ChannelTransformer, PyramidContextExtraction, CrossLayerChannelAttention, CrossLayerSpatialAttention, MutilScaleEdgeInfoGenetator}:
+        if isinstance(c2, list) and m not in {MTLClassify, ChannelTransformer, PyramidContextExtraction, CrossLayerChannelAttention, CrossLayerSpatialAttention, MutilScaleEdgeInfoGenetator}:
             is_backbone = True
             m_ = m
             m_.backbone = True
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index b1af2e6..9251049 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -984,6 +984,23 @@ class v8ClassificationLoss:
         loss_items = loss.detach()
         return loss, loss_items
 
+
+class v8MTLClassificationLoss:
+    """Criterion class for computing multi-task classification training losses."""
+
+    def __init__(self, model):
+        self.device = next(model.parameters()).device  # get model device
+
+    def __call__(self, preds, batch):
+        """Compute per-task cross-entropy; returns (total loss, detached [total, per-task...] loss items)."""
+        if isinstance(preds, list):  # one logits tensor per task
+            loss = torch.zeros(len(preds) + 1, device=self.device)
+            for i in range(len(preds)):
+                loss[i + 1] = torch.nn.functional.cross_entropy(preds[i], batch["cls"][i], reduction="mean")
+            loss[0] = loss[1:].sum()  # slot 0 holds the total; summing the whole vector would double-count
+            return loss[0], loss.detach()
+        loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
+        return loss, loss.detach()
+
 
 class v8OBBLoss(v8DetectionLoss):
     def __init__(self, model):
diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py
index 97991e9..626c88f 100644
--- a/ultralytics/utils/metrics.py
+++ b/ultralytics/utils/metrics.py
@@ -1830,6 +1830,85 @@ class ClassifyMetrics(SimpleClass):
         return []
 
 
+class MTLClassifyMetrics(SimpleClass):
+    """
+    Class for computing per-task classification metrics including top-1 and top-5 accuracy.
+
+    Attributes:
+        top1 (List[float]): The top-1 accuracy for each task.
+        top5 (List[float]): The top-5 accuracy for each task.
+        speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline.
+
+    Properties:
+        fitness (float): The fitness of the model, the mean of top-1 and top-5 accuracy across tasks.
+        results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.
+        keys (List[str]): A list of keys for the results_dict.
+
+    Methods:
+        process(targets, pred): Processes the targets and predictions to compute classification metrics.
+ """ + + def __init__(self) -> None: + """Initialize a ClassifyMetrics instance.""" + self.top1 = [0,0,0,0] + self.top5 = [0,0,0,0] + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "MTL" + + def process(self, targets, pred): + """Target classes and predicted classes.""" + # 多标签处理 + gpred = [] + gtargets = [] + gcorrect = [] + gacc = [] + + for i in range(len(targets[0])): + gpred.append( torch.cat([sublist[i] for sublist in pred])) + gtargets.append(torch.cat([sublist[i] for sublist in targets])) + gcorrect.append((gtargets[i][:, None] == gpred[i]).float()) + gacc.append(torch.stack((gcorrect[i][:, 0], gcorrect[i].max(1).values), dim=1)) # (top1, top5) accuracy + self.top1[i], self.top5[i] = gacc[i].mean(0).tolist() + + @property + def fitness(self): + """Returns mean of top-1 and top-5 accuracies as fitness score.""" + return sum((self.top1 + self.top5)) / 2 + + @property + def results_dict(self): + """返回一个包含模型性能指标和适应度得分的字典。""" + results = {} + # 将 top1 和 top5 的值分别填入字典 + for i in range(len(self.top1)): + results[f"metrics/accuracy_top1_{i}"] = self.top1[i] + results[f"metrics/accuracy_top5_{i}"] = self.top5[i] + # 添加 fitness 值 + results["fitness"] = self.fitness + return results + + @property + def keys(self): + metrics = [] + """Returns a list of keys for the results_dict property.""" + for i in range(len(self.top1)): + metrics.append("metrics/accuracy_top1_"+str(i)) + metrics.append("metrics/accuracy_top5_"+str(i)) + return metrics + @property + def srckeys(self): + return ["metrics/accuracy_top1", "metrics/accuracy_top5"] + + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [] + + @property + def curves_results(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [] + class OBBMetrics(SimpleClass): def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None: """Initialize an OBBMetrics instance with directory, plotting, callback, and class names.""" diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index e6b1b4f..67c5d90 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -328,12 +328,12 @@ def get_flops(model, imgsz=640): # Use stride size for input tensor # stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32 # max stride stride = 640 - im = torch.empty((1, 3, stride, stride), device=p.device) # input image in BCHW format + im = torch.zeros((1, 3, stride, stride), device=p.device) # input image in BCHW format flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs except Exception: # Use actual image size for input tensor (i.e. 
-        im = torch.empty((1, 3, *imgsz), device=p.device)  # input image in BCHW format
+        im = torch.zeros((1, 3, *imgsz), device=p.device)  # input image in BCHW format
         return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
     except Exception:
         return 0.0
@@ -350,14 +350,14 @@ def get_flops_with_torch_profiler(model, imgsz=640):
     try:
         # Use stride size for input tensor
         stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # max stride
-        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+        im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
         with torch.profiler.profile(with_flops=True) as prof:
             model(im)
         flops = sum(x.flops for x in prof.key_averages()) / 1e9
         flops = flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
     except Exception:
         # Use actual image size for input tensor (i.e. required for RTDETR models)
-        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
+        im = torch.zeros((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
         with torch.profiler.profile(with_flops=True) as prof:
             model(im)
         flops = sum(x.flops for x in prof.key_averages()) / 1e9
diff --git a/utils/tobacco_label.py b/utils/tobacco_label.py
new file mode 100644
index 0000000..9946053
--- /dev/null
+++ b/utils/tobacco_label.py
@@ -0,0 +1,74 @@
+# Given a split dataset root, generate the train/val/test label files for multi-task training.
+# Image paths in the label files are relative to the root; the data YAML joins them back against `path`.
+
+from pathlib import Path
+# import splitfolders
+import os
+
+
+def find_images_in_dir(directory, extensions=('.jpg', '.png', '.jpeg')):
+    """Yield every image file under `directory`, recursively."""
+    for root, dirs, files in os.walk(directory):
+        for file in files:
+            if any(file.lower().endswith(ext) for ext in extensions):
+                yield os.path.join(root, file)
+
+
+def path_subtract(full_path, base_path):
+    """Return `full_path` relative to `base_path`, or None if it is not underneath it."""
+    # Convert both to absolute paths so they are comparable
+    full_path = os.path.abspath(full_path)
+    base_path = os.path.abspath(base_path)
+
+    # Check that base_path is a prefix of full_path
+    if full_path.startswith(base_path):
+        return os.path.relpath(full_path, base_path)
+    return None  # base_path is not a prefix, so the subtraction is meaningless
+
+
+def getlabel(root):
+    """Walk `root` and write label files; folder names encode per-task labels joined by underscores."""
+    train_set = []
+    val_set = []
+    test_set = []
+    for image_path in find_images_in_dir(root):
+        # Assign each image to a split based on its path; the class folder name (e.g. 'brown_2')
+        # becomes the comma-separated label list ('brown,2')
+        subpath = path_subtract(image_path, root)
+        label_part = Path(subpath).parts[1]
+
+        if 'train' in subpath:
+            train_set.append([subpath, ',', label_part.replace('_', ',')])
+        if 'val' in subpath:
+            val_set.append([subpath, ',', label_part.replace('_', ',')])
+        if 'test' in subpath:
+            test_set.append([subpath, ',', label_part.replace('_', ',')])
+
+    os.makedirs(os.path.join(root, 'label'), exist_ok=True)  # create the label folder if missing
+
+    # Write one line per image: relative/path,label_task1,label_task2,...
+    if len(train_set) > 0:
+        with open(os.path.join(root, 'label', 'train.txt'), 'w', encoding='utf-8') as file:
+            for item in train_set:
+                for sub in item:
+                    file.write(sub)
+                file.write('\n')
+    if len(val_set) > 0:
+        with open(os.path.join(root, 'label', 'val.txt'), 'w', encoding='utf-8') as file:
+            for item in val_set:
+                for sub in item:
+                    file.write(sub)
+                file.write('\n')
+    if len(test_set) > 0:
+        with open(os.path.join(root, 'label', 'test.txt'), 'w', encoding='utf-8') as file:
+            for item in test_set:
+                for sub in item:
+                    file.write(sub)
+                file.write('\n')
+
+
+if __name__ == '__main__':
+    # train:validation:test = 8:1:1
+    getlabel('G:/dataset/test/split')
\ No newline at end of file
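Note: for a two-task layout, `getlabel` emits one comma-separated line per image, with the class-folder name supplying the per-task labels. A hypothetical folder tree and the lines it would produce (paths illustrative; path separators follow the OS):

```text
G:/dataset/test/split/
  train/brown_2/img001.jpg   ->  label/train.txt: train/brown_2/img001.jpg,brown,2
  val/yellow_0/img042.jpg    ->  label/val.txt:   val/yellow_0/img042.jpg,yellow,0
```

These are exactly the files that the `train`/`val`/`test` keys of the data YAML point at and that `MTLDataset.load_data` parses.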