Spaces:
Build error
Build error
# Copyright (c) Tencent Inc. All rights reserved. | |
from typing import List, Tuple, Union | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
from mmdet.structures import OptSampleList, SampleList | |
from mmyolo.models.detectors import YOLODetector | |
from mmyolo.registry import MODELS | |
class YOLOWorldDetector(YOLODetector):
    """Implementation of the YOLO-World (YOLOW) detector series.

    The detector combines image features with text features. Text features
    either come from the backbone (computed from the ``texts`` carried by the
    data samples) or from a cache filled by :meth:`reparameterize`.

    Args:
        mm_neck (bool): If True, the neck is called with both image and text
            features; otherwise it is called with image features only.
        num_train_classes (int): Value written to ``bbox_head.num_classes``
            before the loss is computed.
        num_test_classes (int): Stored on the instance. ``predict`` in this
            class derives the class count from the text features instead.
    """

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes=80,
                 num_test_classes=80,
                 **kwargs) -> None:
        # NOTE(review): these attributes are assigned before the parent
        # constructor runs, as in the original ordering -- presumably so they
        # already exist while the parent builds its sub-modules; confirm.
        self.mm_neck = mm_neck
        self.num_train_classes = num_train_classes
        self.num_test_classes = num_test_classes
        super().__init__(*args, **kwargs)

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples."""
        # The head must use the training class count while computing losses.
        self.bbox_head.num_classes = self.num_train_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        losses = self.bbox_head.loss(img_feats, txt_feats, batch_data_samples)
        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)

        # The class count follows the text features actually in use (size of
        # the first dimension of the first entry) rather than the fixed
        # ``num_test_classes``.
        self.bbox_head.num_classes = txt_feats[0].shape[0]
        results_list = self.bbox_head.predict(img_feats,
                                              txt_feats,
                                              batch_data_samples,
                                              rescale=rescale)

        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def reparameterize(self, texts: List[List[str]]) -> None:
        """Encode ``texts`` once and cache the result on the detector.

        The cached ``texts`` / ``text_feats`` are used by ``extract_feat``
        when the data samples do not provide texts.
        """
        # encode text embeddings into the detector
        self.texts = texts
        self.text_feats = self.backbone.forward_text(texts)

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        results = self.bbox_head.forward(img_feats, txt_feats)
        return results

    def extract_feat(
            self, batch_inputs: Tensor,
            batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract image and text features.

        The text source is resolved in this order:

        1. ``batch_data_samples is None``: cached texts / text features.
        2. A dict containing the key ``'texts'``.
        3. A non-empty list whose first element has a ``texts`` attribute.
        4. Cached text features, if ``reparameterize`` has set them.

        Returns:
            tuple: ``(img_feats, txt_feats)``; ``img_feats`` has been passed
            through the neck when the detector has one.

        Raises:
            TypeError: If none of the text sources above is available.
        """
        txt_feats = None
        if batch_data_samples is None:
            texts = self.texts
            txt_feats = self.text_feats
        elif (isinstance(batch_data_samples, dict)
              and 'texts' in batch_data_samples):
            texts = batch_data_samples['texts']
        elif (isinstance(batch_data_samples, list) and batch_data_samples
              and hasattr(batch_data_samples[0], 'texts')):
            # The non-empty check keeps an empty list from raising IndexError
            # on ``[0]``; it falls through to the cached features instead.
            texts = [data_sample.texts for data_sample in batch_data_samples]
        elif hasattr(self, 'text_feats'):
            texts = self.texts
            txt_feats = self.text_feats
        else:
            raise TypeError('batch_data_samples should be dict or list.')

        if txt_feats is not None:
            # forward image only
            img_feats = self.backbone.forward_image(batch_inputs)
        else:
            img_feats, txt_feats = self.backbone(batch_inputs, texts)

        if self.with_neck:
            if self.mm_neck:
                img_feats = self.neck(img_feats, txt_feats)
            else:
                img_feats = self.neck(img_feats)
        return img_feats, txt_feats
class YOLOWorldPromptDetector(YOLODetector):
    """Implementation of the YOLO-World series with prompt embeddings.

    Instead of encoding texts, this detector holds one embedding matrix
    (``self.embeddings``) that is used as the text features for every image.

    Args:
        mm_neck (bool): If True, the neck is called with both image and text
            features; otherwise it is called with image features only.
        num_train_classes (int): Value written to ``bbox_head.num_classes``
            before the loss is computed (stored as
            ``self.num_training_classes``).
        num_test_classes (int): Value written to ``bbox_head.num_classes``
            before prediction.
        prompt_dim (int): Embedding width used for random initialisation and
            for the adapter layers.
        num_prompts (int): Number of rows of the randomly initialised
            embedding matrix.
        embedding_path (str | os.PathLike | None): File passed to
            ``numpy.load`` to initialise the embeddings. An empty string or
            ``None`` selects random initialisation.
        freeze_prompt (bool): If truthy, the embeddings do not require
            gradients.
        use_mlp_adapter (bool): If truthy, a two-layer MLP is applied to the
            embeddings with a residual connection in ``extract_feat``.
    """

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes=80,
                 num_test_classes=80,
                 prompt_dim=512,
                 num_prompts=80,
                 embedding_path='',
                 freeze_prompt=False,
                 use_mlp_adapter=False,
                 **kwargs) -> None:
        # NOTE(review): these attributes are assigned before the parent
        # constructor runs, as in the original ordering -- presumably so they
        # already exist while the parent builds its sub-modules; confirm.
        self.mm_neck = mm_neck
        self.num_training_classes = num_train_classes
        self.num_test_classes = num_test_classes
        self.prompt_dim = prompt_dim
        self.num_prompts = num_prompts
        self.freeze_prompt = freeze_prompt
        self.use_mlp_adapter = use_mlp_adapter
        super().__init__(*args, **kwargs)

        # Truthiness instead of ``len(...) > 0`` so that ``None`` and
        # ``pathlib.Path`` values are accepted as well as strings.
        if embedding_path:
            import numpy as np
            self.embeddings = torch.nn.Parameter(
                torch.from_numpy(np.load(embedding_path)).float())
        else:
            # random init, each row normalised along the last dimension
            embeddings = nn.functional.normalize(
                torch.randn((num_prompts, prompt_dim)), dim=-1)
            self.embeddings = nn.Parameter(embeddings)

        # Frozen prompts take no gradient; otherwise they are trainable.
        self.embeddings.requires_grad = not self.freeze_prompt

        if use_mlp_adapter:
            self.adapter = nn.Sequential(
                nn.Linear(prompt_dim, prompt_dim * 2),
                nn.ReLU(True),
                nn.Linear(prompt_dim * 2, prompt_dim))
        else:
            self.adapter = None

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples."""
        # The head must use the training class count while computing losses.
        self.bbox_head.num_classes = self.num_training_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        losses = self.bbox_head.loss(img_feats, txt_feats, batch_data_samples)
        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)

        self.bbox_head.num_classes = self.num_test_classes
        results_list = self.bbox_head.predict(img_feats,
                                              txt_feats,
                                              batch_data_samples,
                                              rescale=rescale)

        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        results = self.bbox_head.forward(img_feats, txt_feats)
        return results

    def extract_feat(
            self, batch_inputs: Tensor,
            batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract image features and build text features from the prompts.

        ``batch_data_samples`` is accepted for interface compatibility but is
        not read here.

        Returns:
            tuple: ``(img_feats, txt_feats)``; ``img_feats`` has been passed
            through the neck when the detector has one.
        """
        # only image features: the backbone's second output is discarded
        img_feats, _ = self.backbone(batch_inputs, None)

        # use embeddings, with a new leading dimension of size 1
        txt_feats = self.embeddings[None]
        if self.adapter is not None:
            # residual adapter, then L2 re-normalisation on the last dim
            txt_feats = self.adapter(txt_feats) + txt_feats
            txt_feats = nn.functional.normalize(txt_feats, dim=-1, p=2)
        # Tile along the leading dimension to match the first dimension of
        # the first image feature map (presumably the batch size -- confirm).
        txt_feats = txt_feats.repeat(img_feats[0].shape[0], 1, 1)

        if self.with_neck:
            if self.mm_neck:
                img_feats = self.neck(img_feats, txt_feats)
            else:
                img_feats = self.neck(img_feats)
        return img_feats, txt_feats