diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 75f4576..c9fbd5c 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -69,6 +69,7 @@ from ultralytics.utils.loss import (
     v8ClassificationLoss,
     v8MTLClassificationLoss,
     v8DetectionLoss,
+    v8MTLUWClassificationLoss,
     v8OBBLoss,
     v8PoseLoss,
     v8SegmentationLoss,
@@ -95,7 +96,6 @@ from ultralytics.nn.backbone.revcol import *
 from ultralytics.nn.backbone.lsknet import *
 from ultralytics.nn.backbone.SwinTransformer import *
 from ultralytics.nn.backbone.repvit import *
-from ultralytics.nn.backbone.resnet import *
 from ultralytics.nn.backbone.CSwomTramsformer import *
 from ultralytics.nn.backbone.UniRepLKNet import *
 from ultralytics.nn.backbone.TransNext import *
@@ -625,7 +625,7 @@ class MTLClassificationModel(BaseModel):
 
     def init_criterion(self):
         """Initialize the loss criterion for the ClassificationModel."""
-        return v8MTLClassificationLoss(self)
+        return v8MTLUWClassificationLoss(self)
 
 class RTDETRDetectionModel(DetectionModel):
     """
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index 9251049..76b0c3b 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -996,12 +996,65 @@ class v8MTLClassificationLoss:
             for i in range(len(preds)):
                 loss[i + 1] = torch.nn.functional.cross_entropy(preds[i], batch["cls"][i], reduction="mean")
             loss[0] = loss.sum()
-            return loss.sum(), loss.detach()  # loss(box, cls, dfl)
+            return loss[0], loss.detach()  # loss(box, cls, dfl)
         else:
             loss = (torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean"))
             loss_items = loss.detach()
             return loss, loss_items
 
+class v8MTLUWClassificationLoss(nn.Module):
+    """
+    Multi-task classification loss with learnable uncertainty weights.
+
+    Each task i owns a learnable log-variance s_i = log(sigma_i^2) and adds exp(-s_i) * CE_i + 0.5 * s_i to the
+    total. That is 1 / sigma_i^2 * CE_i + log(sigma_i) rewritten in s_i, which stays finite for any real s_i.
+
+    NOTE(review): `logvars` only trains if the optimizer receives this module's parameters; confirm that the
+    trainer collects them.
+    """
+
+    def __init__(self, model, task_numbers=3):
+        """
+        Create one learnable log-variance per task on the model's device.
+
+        Args:
+            model (nn.Module): Model whose first parameter determines the device.
+            task_numbers (int): Number of task heads that get an uncertainty weight.
+        """
+        super().__init__()
+        self.device = next(model.parameters()).device
+        self.task_numbers = task_numbers
+        # Zero log-variance gives weight 1 and no penalty, so training starts from a plain sum of task losses.
+        self.logvars = nn.Parameter(torch.zeros(task_numbers, device=self.device))
+
+    def forward(self, preds, batch):
+        """
+        Compute the classification loss between predictions and true labels.
+
+        Args:
+            preds (torch.Tensor | list[torch.Tensor]): Logits of one head, or one logits tensor per task.
+            batch (dict): Targets under "cls"; indexed per task when `preds` is a list.
+
+        Returns:
+            (tuple): Loss to backpropagate and its detached items. For a list of heads the items are the
+                weighted total followed by the unweighted per-task losses.
+
+        Raises:
+            ValueError: If `preds` is an empty list or has more entries than `task_numbers`.
+        """
+        if not isinstance(preds, list):
+            loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
+            return loss, loss.detach()
+        if not 0 < len(preds) <= self.task_numbers:
+            raise ValueError(f"expected 1 to {self.task_numbers} task outputs, got {len(preds)}")
+        task_losses = torch.stack(
+            [torch.nn.functional.cross_entropy(p, batch["cls"][i], reduction="mean") for i, p in enumerate(preds)]
+        )
+        logvars = self.logvars[: len(preds)].to(task_losses.device)
+        total = (torch.exp(-logvars) * task_losses + 0.5 * logvars).sum()
+        return total, torch.cat((total.detach().reshape(1), task_losses.detach()))  # loss(total, task_1, ...)
+
+
 class v8OBBLoss(v8DetectionLoss):
     def __init__(self, model):
         """