# Copyright (c) Tencent Inc. All rights reserved.
from typing import List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from mmdet.structures import OptSampleList, SampleList
from mmyolo.models.detectors import YOLODetector
from mmyolo.registry import MODELS


@MODELS.register_module()
class YOLOWorldDetector(YOLODetector):
    """Implementation of YOLO World Series"""

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes: int = 80,
                 num_test_classes: int = 80,
                 **kwargs) -> None:
        self.mm_neck = mm_neck
        self.num_train_classes = num_train_classes
        self.num_test_classes = num_test_classes
        super().__init__(*args, **kwargs)

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples."""
        self.bbox_head.num_classes = self.num_train_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        losses = self.bbox_head.loss(img_feats, txt_feats,
                                     batch_data_samples)
        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        # the number of test classes follows the current text vocabulary
        # rather than the fixed `num_test_classes`
        # self.bbox_head.num_classes = self.num_test_classes
        self.bbox_head.num_classes = txt_feats[0].shape[0]
        results_list = self.bbox_head.predict(img_feats,
                                              txt_feats,
                                              batch_data_samples,
                                              rescale=rescale)
        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def reparameterize(self, texts: List[List[str]]) -> None:
        """Encode the text prompts once and cache their embeddings in the
        detector, so later forward passes can skip the text branch."""
        self.texts = texts
        self.text_feats = self.backbone.forward_text(texts)

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process.

        Usually includes backbone, neck and head forward without any
        post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        results = self.bbox_head.forward(img_feats, txt_feats)
        return results

    def extract_feat(
            self, batch_inputs: Tensor,
            batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract features."""
        txt_feats = None
        # resolve the text prompts / cached text features for this batch
        if batch_data_samples is None:
            texts = self.texts
            txt_feats = self.text_feats
        elif isinstance(batch_data_samples,
                        dict) and 'texts' in batch_data_samples:
            texts = batch_data_samples['texts']
        elif isinstance(batch_data_samples, list) and hasattr(
                batch_data_samples[0], 'texts'):
            texts = [data_sample.texts for data_sample in batch_data_samples]
        elif hasattr(self, 'text_feats'):
            texts = self.texts
            txt_feats = self.text_feats
        else:
            raise TypeError('batch_data_samples should be dict or list.')

        if txt_feats is not None:
            # text features are already cached, forward the image branch only
            img_feats = self.backbone.forward_image(batch_inputs)
        else:
            img_feats, txt_feats = self.backbone(batch_inputs, texts)
        if self.with_neck:
            if self.mm_neck:
                img_feats = self.neck(img_feats, txt_feats)
            else:
                img_feats = self.neck(img_feats)
        return img_feats, txt_feats


@MODELS.register_module()
class YOLOWorldPromptDetector(YOLODetector):
    """Implementation of YOLO World Series"""

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes: int = 80,
                 num_test_classes: int = 80,
                 prompt_dim: int = 512,
                 num_prompts: int = 80,
                 embedding_path: str = '',
                 freeze_prompt: bool = False,
                 use_mlp_adapter: bool = False,
                 **kwargs) -> None:
        self.mm_neck = mm_neck
        self.num_training_classes = num_train_classes
        self.num_test_classes = num_test_classes
        self.prompt_dim = prompt_dim
        self.num_prompts = num_prompts
        self.freeze_prompt = freeze_prompt
        self.use_mlp_adapter = use_mlp_adapter
        super().__init__(*args, **kwargs)

        if len(embedding_path) > 0:
            # load pre-computed prompt embeddings
            self.embeddings = nn.Parameter(
                torch.from_numpy(np.load(embedding_path)).float())
        else:
            # random initialization
            embeddings = nn.functional.normalize(
                torch.randn((num_prompts, prompt_dim)), dim=-1)
            self.embeddings = nn.Parameter(embeddings)

        self.embeddings.requires_grad = not self.freeze_prompt

        if use_mlp_adapter:
            self.adapter = nn.Sequential(
                nn.Linear(prompt_dim, prompt_dim * 2), nn.ReLU(True),
                nn.Linear(prompt_dim * 2, prompt_dim))
        else:
            self.adapter = None

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples."""
        self.bbox_head.num_classes = self.num_training_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        losses = self.bbox_head.loss(img_feats, txt_feats,
                                     batch_data_samples)
        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        self.bbox_head.num_classes = self.num_test_classes
        results_list = self.bbox_head.predict(img_feats,
                                              txt_feats,
                                              batch_data_samples,
                                              rescale=rescale)
        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process.

        Usually includes backbone, neck and head forward without any
        post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        results = self.bbox_head.forward(img_feats, txt_feats)
        return results

    def extract_feat(
            self, batch_inputs: Tensor,
            batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract features."""
        # image features only; the text branch is replaced by learned prompts
        img_feats, _ = self.backbone(batch_inputs, None)
        # use the prompt embeddings as text features
        txt_feats = self.embeddings[None]
        if self.adapter is not None:
            txt_feats = self.adapter(txt_feats) + txt_feats
            txt_feats = nn.functional.normalize(txt_feats, dim=-1, p=2)
        txt_feats = txt_feats.repeat(img_feats[0].shape[0], 1, 1)
        if self.with_neck:
            if self.mm_neck:
                img_feats = self.neck(img_feats, txt_feats)
            else:
                img_feats = self.neck(img_feats)
        return img_feats, txt_feats