固定损失权重
This commit is contained in:
parent
b8b3255eeb
commit
7ee9f27471
@ -69,6 +69,7 @@ from ultralytics.utils.loss import (
|
|||||||
v8ClassificationLoss,
|
v8ClassificationLoss,
|
||||||
v8MTLClassificationLoss,
|
v8MTLClassificationLoss,
|
||||||
v8DetectionLoss,
|
v8DetectionLoss,
|
||||||
|
v8MTLUWClassificationLoss,
|
||||||
v8OBBLoss,
|
v8OBBLoss,
|
||||||
v8PoseLoss,
|
v8PoseLoss,
|
||||||
v8SegmentationLoss,
|
v8SegmentationLoss,
|
||||||
@ -95,7 +96,6 @@ from ultralytics.nn.backbone.revcol import *
|
|||||||
from ultralytics.nn.backbone.lsknet import *
|
from ultralytics.nn.backbone.lsknet import *
|
||||||
from ultralytics.nn.backbone.SwinTransformer import *
|
from ultralytics.nn.backbone.SwinTransformer import *
|
||||||
from ultralytics.nn.backbone.repvit import *
|
from ultralytics.nn.backbone.repvit import *
|
||||||
from ultralytics.nn.backbone.resnet import *
|
|
||||||
from ultralytics.nn.backbone.CSwomTramsformer import *
|
from ultralytics.nn.backbone.CSwomTramsformer import *
|
||||||
from ultralytics.nn.backbone.UniRepLKNet import *
|
from ultralytics.nn.backbone.UniRepLKNet import *
|
||||||
from ultralytics.nn.backbone.TransNext import *
|
from ultralytics.nn.backbone.TransNext import *
|
||||||
@ -625,7 +625,7 @@ class MTLClassificationModel(BaseModel):
|
|||||||
|
|
||||||
def init_criterion(self):
    """Initialize the loss criterion for the MTLClassificationModel.

    Returns the uncertainty-weighted multi-task classification criterion
    (``v8MTLUWClassificationLoss``) built around this model — replaces the
    previous plain ``v8MTLClassificationLoss``.
    """
    # NOTE(review): reconstructed from a garbled diff view; this is the
    # post-change ("new") side of the hunk — confirm against the repository.
    return v8MTLUWClassificationLoss(self)
|
|
||||||
class RTDETRDetectionModel(DetectionModel):
|
class RTDETRDetectionModel(DetectionModel):
|
||||||
"""
|
"""
|
||||||
|
@ -996,12 +996,39 @@ class v8MTLClassificationLoss:
|
|||||||
for i in range(len(preds)):
|
for i in range(len(preds)):
|
||||||
loss[i + 1] = torch.nn.functional.cross_entropy(preds[i], batch["cls"][i], reduction="mean")
|
loss[i + 1] = torch.nn.functional.cross_entropy(preds[i], batch["cls"][i], reduction="mean")
|
||||||
loss[0] = loss.sum()
|
loss[0] = loss.sum()
|
||||||
return loss.sum(), loss.detach() # loss(box, cls, dfl)
|
return loss[0], loss.detach() # loss(box, cls, dfl)
|
||||||
else:
|
else:
|
||||||
loss = (torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean"))
|
loss = (torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean"))
|
||||||
loss_items = loss.detach()
|
loss_items = loss.detach()
|
||||||
return loss, loss_items
|
return loss, loss_items
|
||||||
|
|
||||||
|
class v8MTLUWClassificationLoss(nn.Module):
    """Criterion for multi-task classification with uncertainty-style loss weighting.

    NOTE(review): the original docstring said "learnable uncertainty weights",
    but ``logvars`` is a plain tensor initialised to ones and is never
    registered as a parameter (the ``register_parameter`` call was commented
    out — consistent with the commit title "fix loss weights"). With all
    weights fixed at 1, ``1 / logvar**2 == 1`` and ``log(logvar) == 0``, so the
    combined loss reduces to a plain sum of the per-task losses. To make the
    weights learnable again, wrap ``logvars`` in ``nn.Parameter``.
    """

    def __init__(self, model, task_numbers=3):
        """Build the criterion on the same device as *model*.

        Args:
            model: module whose first parameter determines the compute device.
            task_numbers (int): number of classification tasks (default 3).
        """
        super().__init__()
        self.device = next(model.parameters()).device
        # Fixed (non-trainable) per-task weight terms; deliberately NOT an
        # nn.Parameter so the weights stay constant during training.
        self.logvars = torch.ones(task_numbers, device=self.device)

    def forward(self, preds, batch):
        """Compute the classification loss between predictions and true labels.

        Args:
            preds: either a list of per-task logit tensors, or a single logit
                tensor for the single-task case.
            batch (dict): must contain ``"cls"`` — a list of per-task target
                tensors when *preds* is a list, otherwise a single target tensor.

        Returns:
            tuple: ``(total_loss, loss_items)`` where ``loss_items`` is a
            detached breakdown — for the multi-task path, index 0 holds the
            weighted total and indices 1..N the raw per-task losses.
        """
        if isinstance(preds, list):
            # Allocate only on the multi-task path (the original allocated a
            # batch-sized tensor unconditionally, wasted in the else branch).
            loss = torch.zeros(len(preds) + 1, device=self.device)
            total_loss = torch.zeros(1, device=self.device)
            for i, pred in enumerate(preds):
                loss[i + 1] = torch.nn.functional.cross_entropy(pred, batch["cls"][i], reduction="mean")
                # Uncertainty weighting: L_i / sigma_i^2 + log(sigma_i); with
                # logvars fixed at 1 this is simply L_i.
                total_loss += loss[i + 1] / (self.logvars[i] ** 2) + torch.log(self.logvars[i])
            loss[0] = total_loss
            return loss[0], loss.detach()  # loss(total, task_1..task_N)
        # Single-task fallback: plain mean cross-entropy.
        loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
        return loss, loss.detach()
|
||||||
|
|
||||||
|
|
||||||
class v8OBBLoss(v8DetectionLoss):
|
class v8OBBLoss(v8DetectionLoss):
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user