stevengrove
initial commit
186701e
raw
history blame contribute delete
No virus
3.55 kB
# Copyright (c) Tencent Inc. All rights reserved.
from typing import List, Tuple, Union
from torch import Tensor
from mmdet.structures import OptSampleList, SampleList
from mmyolo.models.detectors import YOLODetector
from mmyolo.registry import MODELS
@MODELS.register_module()
class YOLOWorldDetector(YOLODetector):
    """Implementation of YOLOW Series.

    The backbone is called with both the image batch and a set of texts and
    returns image features together with text features; both are forwarded
    to the bbox head.

    Args:
        mm_neck (bool): If True, the neck is called with the image features
            and the text features; otherwise it receives the image features
            only. Defaults to False.
        num_train_classes (int): Value assigned to
            ``bbox_head.num_classes`` at the start of :meth:`loss`.
            Defaults to 80.
        num_test_classes (int): Stored on the instance. It is not read by
            the methods of this class; :meth:`predict` derives the class
            count from the text features instead. Defaults to 80.
    """

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes: int = 80,
                 num_test_classes: int = 80,
                 **kwargs) -> None:
        # These plain attributes are assigned before the parent constructor
        # runs; the order is kept as in the original implementation.
        self.mm_neck = mm_neck
        self.num_train_classes = num_train_classes
        self.num_test_classes = num_test_classes
        super().__init__(*args, **kwargs)

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Batch of input images.
            batch_data_samples (SampleList): Data samples (or a dict with a
                ``'texts'`` entry) that provide the texts and the targets
                consumed by the head.

        Returns:
            Union[dict, list]: Whatever ``bbox_head.loss`` returns.
        """
        # The head is switched to the training class count on every call
        # because `predict` overwrites the same attribute.
        self.bbox_head.num_classes = self.num_train_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        losses = self.bbox_head.loss(img_feats, txt_feats, batch_data_samples)
        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Batch of input images.
            batch_data_samples (SampleList): Data samples providing the
                texts; the predictions are attached to them.
            rescale (bool): Forwarded to ``bbox_head.predict``.
                Defaults to True.

        Returns:
            SampleList: ``batch_data_samples`` after
            ``add_pred_to_datasample`` has added the head's results.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        # The class count follows the text features actually supplied
        # (first dimension of the first entry) rather than the configured
        # `num_test_classes`.
        # NOTE(review): presumably that dimension is the number of prompts
        # for one image -- confirm against the backbone's output layout.
        self.bbox_head.num_classes = txt_feats[0].shape[0]
        results_list = self.bbox_head.predict(img_feats,
                                              txt_feats,
                                              batch_data_samples,
                                              rescale=rescale)
        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def reparameterize(self, texts: List[List[str]]) -> None:
        """Cache the texts used when no data samples are supplied.

        Args:
            texts (List[List[str]]): Texts stored as ``self.texts``;
                :meth:`extract_feat` passes them to the backbone whenever
                it is called with ``batch_data_samples=None``.
        """
        self.texts = texts

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs (Tensor): Batch of input images.
            batch_data_samples (OptSampleList): Optional data samples. When
                None, the texts cached by :meth:`reparameterize` are used.

        Returns:
            Tuple[List[Tensor]]: Raw output of ``bbox_head.forward``.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        results = self.bbox_head.forward(img_feats, txt_feats)
        return results

    def extract_feat(
        self, batch_inputs: Tensor,
        batch_data_samples: Union[OptSampleList, dict]
    ) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract features.

        Args:
            batch_inputs (Tensor): Batch of input images.
            batch_data_samples (list | dict | None): Source of the texts.
                A dict must contain a ``'texts'`` entry, a list must hold
                samples exposing a ``texts`` attribute, and None selects
                the texts cached by :meth:`reparameterize`.

        Returns:
            Tuple[Tuple[Tensor], Tensor]: Image features (after the neck,
            if the detector has one) and the text features produced by the
            backbone.

        Raises:
            AttributeError: If ``batch_data_samples`` is None and
                :meth:`reparameterize` has not been called.
            TypeError: If ``batch_data_samples`` is neither None, a dict
                nor a list.
        """
        if batch_data_samples is None:
            # Fail with an actionable message instead of the generic
            # missing-attribute error when nothing has been cached yet.
            if not hasattr(self, 'texts'):
                raise AttributeError(
                    'batch_data_samples is None and no texts are cached; '
                    'call `reparameterize(texts)` first.')
            texts = self.texts
        elif isinstance(batch_data_samples, dict):
            texts = batch_data_samples['texts']
        elif isinstance(batch_data_samples, list):
            texts = [data_sample.texts for data_sample in batch_data_samples]
        else:
            raise TypeError('batch_data_samples should be dict or list.')

        img_feats, txt_feats = self.backbone(batch_inputs, texts)
        if self.with_neck:
            if self.mm_neck:
                # Multi-modal neck: it also consumes the text features.
                img_feats = self.neck(img_feats, txt_feats)
            else:
                img_feats = self.neck(img_feats)
        return img_feats, txt_feats