固定损失权重

This commit is contained in:
yoiannis 2025-03-02 17:21:21 +08:00
parent b8b3255eeb
commit 7ee9f27471
2 changed files with 30 additions and 3 deletions

View File

@ -69,6 +69,7 @@ from ultralytics.utils.loss import (
v8ClassificationLoss, v8ClassificationLoss,
v8MTLClassificationLoss, v8MTLClassificationLoss,
v8DetectionLoss, v8DetectionLoss,
v8MTLUWClassificationLoss,
v8OBBLoss, v8OBBLoss,
v8PoseLoss, v8PoseLoss,
v8SegmentationLoss, v8SegmentationLoss,
@ -95,7 +96,6 @@ from ultralytics.nn.backbone.revcol import *
from ultralytics.nn.backbone.lsknet import * from ultralytics.nn.backbone.lsknet import *
from ultralytics.nn.backbone.SwinTransformer import * from ultralytics.nn.backbone.SwinTransformer import *
from ultralytics.nn.backbone.repvit import * from ultralytics.nn.backbone.repvit import *
from ultralytics.nn.backbone.resnet import *
from ultralytics.nn.backbone.CSwomTramsformer import * from ultralytics.nn.backbone.CSwomTramsformer import *
from ultralytics.nn.backbone.UniRepLKNet import * from ultralytics.nn.backbone.UniRepLKNet import *
from ultralytics.nn.backbone.TransNext import * from ultralytics.nn.backbone.TransNext import *
@ -625,7 +625,7 @@ class MTLClassificationModel(BaseModel):
def init_criterion(self): def init_criterion(self):
"""Initialize the loss criterion for the ClassificationModel.""" """Initialize the loss criterion for the ClassificationModel."""
return v8MTLClassificationLoss(self) return v8MTLUWClassificationLoss(self)
class RTDETRDetectionModel(DetectionModel): class RTDETRDetectionModel(DetectionModel):
""" """

View File

@ -996,12 +996,39 @@ class v8MTLClassificationLoss:
for i in range(len(preds)): for i in range(len(preds)):
loss[i + 1] = torch.nn.functional.cross_entropy(preds[i], batch["cls"][i], reduction="mean") loss[i + 1] = torch.nn.functional.cross_entropy(preds[i], batch["cls"][i], reduction="mean")
loss[0] = loss.sum() loss[0] = loss.sum()
return loss.sum(), loss.detach() # loss(box, cls, dfl) return loss[0], loss.detach() # loss(box, cls, dfl)
else: else:
loss = (torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")) loss = (torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean"))
loss_items = loss.detach() loss_items = loss.detach()
return loss, loss_items return loss, loss_items
class v8MTLUWClassificationLoss(nn.Module):
"""Criterion class for computing training losses with learnable uncertainty weights."""
def __init__(self, model,task_numbers = 3):
super().__init__()
self.device = next(model.parameters()).device
self.logvars = torch.ones(task_numbers, device=self.device)
# self.register_parameter('logvars', self.logvars)
# self.task_numbers = task_numbers
def forward(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
loss = torch.zeros(len(preds) + 1, device=self.device)
total_loss = torch.zeros(1, device=self.device)
if isinstance(preds, list):
for i in range(len(preds)):
loss[i + 1] = torch.nn.functional.cross_entropy(preds[i], batch["cls"][i], reduction="mean")
total_loss += (1.0 / (self.logvars[i] ** 2) * loss[i + 1] + torch.log(self.logvars[i]))
loss[0] = total_loss
return loss[0], loss.detach() # loss(box, cls, dfl)
else:
loss = (torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean"))
loss_items = loss.detach()
return loss, loss_items
class v8OBBLoss(v8DetectionLoss): class v8OBBLoss(v8DetectionLoss):
def __init__(self, model): def __init__(self, model):
""" """