# Copyright (c) Tencent Inc. All rights reserved.
from typing import List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from mmdet.structures import OptSampleList, SampleList
from mmyolo.models.detectors import YOLODetector
from mmyolo.registry import MODELS
@MODELS.register_module()
class YOLOWorldDetector(YOLODetector):
    """Implementation of YOLOW Series.

    An open-vocabulary detector: the backbone produces both image and text
    features, and the head scores image features against the text features,
    so the class set is defined by the input texts rather than fixed labels.
    """

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes=80,
                 num_test_classes=80,
                 **kwargs) -> None:
        # Record configuration before delegating to the parent constructor
        # so these attributes exist for anything super().__init__ triggers.
        self.mm_neck = mm_neck
        self.num_train_classes = num_train_classes
        self.num_test_classes = num_test_classes
        super().__init__(*args, **kwargs)

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples."""
        self.bbox_head.num_classes = self.num_train_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        return self.bbox_head.loss(img_feats, txt_feats, batch_data_samples)

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        # At test time the number of classes follows the number of text
        # prompts actually supplied, not a fixed ``num_test_classes``.
        self.bbox_head.num_classes = txt_feats[0].shape[0]
        results_list = self.bbox_head.predict(img_feats,
                                              txt_feats,
                                              batch_data_samples,
                                              rescale=rescale)
        return self.add_pred_to_datasample(batch_data_samples, results_list)

    def reparameterize(self, texts: List[List[str]]) -> None:
        """Cache text embeddings so later calls can skip the text encoder."""
        self.texts = texts
        self.text_feats = self.backbone.forward_text(texts)

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        return self.bbox_head.forward(img_feats, txt_feats)

    def extract_feat(
            self, batch_inputs: Tensor,
            batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract image and text features.

        Texts come, in order of precedence, from the cached reparameterized
        embeddings (when ``batch_data_samples`` is None), from the data
        samples themselves (dict or list form), or from the cached
        embeddings again as a fallback.
        """
        cached_feats = None
        if batch_data_samples is None:
            captions = self.texts
            cached_feats = self.text_feats
        elif isinstance(batch_data_samples,
                        dict) and 'texts' in batch_data_samples:
            captions = batch_data_samples['texts']
        elif isinstance(batch_data_samples, list) and hasattr(
                batch_data_samples[0], 'texts'):
            captions = [sample.texts for sample in batch_data_samples]
        elif hasattr(self, 'text_feats'):
            captions = self.texts
            cached_feats = self.text_feats
        else:
            raise TypeError('batch_data_samples should be dict or list.')

        if cached_feats is None:
            # Joint image + text forward through the backbone.
            img_feats, txt_feats = self.backbone(batch_inputs, captions)
        else:
            # Text features are cached: run the image branch only.
            img_feats = self.backbone.forward_image(batch_inputs)
            txt_feats = cached_feats

        if self.with_neck:
            img_feats = (self.neck(img_feats, txt_feats)
                         if self.mm_neck else self.neck(img_feats))
        return img_feats, txt_feats
@MODELS.register_module()
class YOLOWorldPromptDetector(YOLODetector):
    """Implementation of YOLO World Series with a learned prompt bank.

    Instead of encoding free-form text at runtime, this detector keeps a
    fixed set of prompt embeddings of shape ``(num_prompts, prompt_dim)``
    — loaded from ``embedding_path`` or randomly initialized — and feeds
    them to the neck/head in place of text-encoder output.
    """

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes=80,
                 num_test_classes=80,
                 prompt_dim=512,
                 num_prompts=80,
                 embedding_path='',
                 freeze_prompt=False,
                 use_mlp_adapter=False,
                 **kwargs) -> None:
        # Record configuration before delegating to the parent constructor
        # so these attributes exist for anything super().__init__ triggers.
        self.mm_neck = mm_neck
        self.num_training_classes = num_train_classes
        self.num_test_classes = num_test_classes
        self.prompt_dim = prompt_dim
        self.num_prompts = num_prompts
        self.freeze_prompt = freeze_prompt
        self.use_mlp_adapter = use_mlp_adapter
        super().__init__(*args, **kwargs)

        if embedding_path:
            # Pre-computed prompt embeddings (presumably text-encoder
            # output saved as a .npy file — shape should match
            # (num_prompts, prompt_dim); verify against the export script).
            self.embeddings = nn.Parameter(
                torch.from_numpy(np.load(embedding_path)).float())
        else:
            # Random initialization, L2-normalized along the feature dim.
            embeddings = nn.functional.normalize(
                torch.randn((num_prompts, prompt_dim)), dim=-1)
            self.embeddings = nn.Parameter(embeddings)
        # A frozen prompt bank is excluded from gradient updates.
        self.embeddings.requires_grad = not self.freeze_prompt

        if use_mlp_adapter:
            # Lightweight MLP (used residually in extract_feat) that lets a
            # frozen or pre-computed prompt bank be adapted during training.
            self.adapter = nn.Sequential(
                nn.Linear(prompt_dim, prompt_dim * 2), nn.ReLU(True),
                nn.Linear(prompt_dim * 2, prompt_dim))
        else:
            self.adapter = None

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples."""
        self.bbox_head.num_classes = self.num_training_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        losses = self.bbox_head.loss(img_feats, txt_feats, batch_data_samples)
        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        self.bbox_head.num_classes = self.num_test_classes
        results_list = self.bbox_head.predict(img_feats,
                                              txt_feats,
                                              batch_data_samples,
                                              rescale=rescale)
        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        results = self.bbox_head.forward(img_feats, txt_feats)
        return results

    def extract_feat(
            self, batch_inputs: Tensor,
            batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract image features; "text" features come from the prompt bank.

        The prompt embeddings are broadcast to the batch size of
        ``batch_inputs`` so downstream modules see one prompt set per image.
        """
        # Image branch only — the backbone's text input is unused here.
        img_feats, _ = self.backbone(batch_inputs, None)
        # Use the stored embeddings as text features (add a batch axis).
        txt_feats = self.embeddings[None]
        if self.adapter is not None:
            # Residual adapter, then re-normalize to unit length.
            txt_feats = self.adapter(txt_feats) + txt_feats
            txt_feats = nn.functional.normalize(txt_feats, dim=-1, p=2)
        txt_feats = txt_feats.repeat(img_feats[0].shape[0], 1, 1)
        if self.with_neck:
            if self.mm_neck:
                img_feats = self.neck(img_feats, txt_feats)
            else:
                img_feats = self.neck(img_feats)
        return img_feats, txt_feats