# Copyright (c) Tencent Inc. All rights reserved.
import itertools
from typing import List, Sequence, Tuple

import torch
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from mmengine.model import BaseModule
from mmyolo.registry import MODELS
from mmdet.utils import OptMultiConfig, ConfigType
from transformers import AutoTokenizer, AutoModel, CLIPTextConfig
from transformers import CLIPTextModelWithProjection as CLIPTP


@MODELS.register_module()
class HuggingVisionBackbone(BaseModule):
    """Image backbone wrapping a HuggingFace vision model.

    Args:
        model_name (str): HuggingFace model name or local path.
        out_indices (Sequence[int]): Indices of the hidden states to return
            as multi-scale features.
        norm_eval (bool): Whether to keep BatchNorm layers in eval mode
            during training.
        frozen_modules (Sequence[str]): Name prefixes of sub-modules to
            freeze.
        init_cfg (OptMultiConfig): Initialization config.
    """

    def __init__(self,
                 model_name: str,
                 out_indices: Sequence[int] = (0, 1, 2, 3),
                 norm_eval: bool = True,
                 frozen_modules: Sequence[str] = (),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.out_indices = out_indices
        self.norm_eval = norm_eval
        self.frozen_modules = frozen_modules
        self.model = AutoModel.from_pretrained(model_name)

        self._freeze_modules()

    def forward(self, image: Tensor) -> Tuple[Tensor]:
        encoded_dict = self.model(pixel_values=image,
                                  output_hidden_states=True)
        hidden_states = encoded_dict.hidden_states
        # prefer spatially reshaped hidden states when the model provides
        # them (e.g. Swin), otherwise fall back to the plain hidden states
        img_feats = encoded_dict.get('reshaped_hidden_states', hidden_states)
        img_feats = [img_feats[i] for i in self.out_indices]
        return tuple(img_feats)

    def _freeze_modules(self) -> None:
        for name, module in self.model.named_modules():
            for frozen_name in self.frozen_modules:
                if name.startswith(frozen_name):
                    module.eval()
                    for param in module.parameters():
                        param.requires_grad = False
                    break

    def train(self, mode: bool = True):
        super().train(mode)
        self._freeze_modules()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() only has an effect on BatchNorm layers
                if isinstance(m, _BatchNorm):
                    m.eval()


@MODELS.register_module()
class HuggingCLIPLanguageBackbone(BaseModule):
    """Text backbone wrapping a HuggingFace CLIP text encoder.

    Args:
        model_name (str): HuggingFace model name or local path.
        frozen_modules (Sequence[str]): Name prefixes of sub-modules to
            freeze; pass ``('all', )`` to freeze the whole text encoder.
        dropout (float): Attention dropout of the CLIP text config.
        training_use_cache (bool): Whether to reuse cached text embeddings.
        init_cfg (OptMultiConfig): Initialization config.
    """

    def __init__(self,
                 model_name: str,
                 frozen_modules: Sequence[str] = (),
                 dropout: float = 0.0,
                 training_use_cache: bool = False,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.frozen_modules = frozen_modules
        self.training_use_cache = training_use_cache
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        clip_config = CLIPTextConfig.from_pretrained(
            model_name, attention_dropout=dropout)
        self.model = CLIPTP.from_pretrained(model_name, config=clip_config)

        self._freeze_modules()

    def forward_cache(self, text: List[List[str]]) -> Tensor:
        if not hasattr(self, 'cache'):
            self.cache = self.forward_text(text)
        return self.cache

    def forward(self, text: List[List[str]]) -> Tensor:
        # text embeddings are recomputed on every call; `forward_cache` is
        # kept for reusing the embeddings of a fixed vocabulary
        return self.forward_text(text)

    def forward_tokenizer(self, texts: List[List[str]]):
        if not hasattr(self, 'text'):
            text = list(itertools.chain(*texts))
            text = self.tokenizer(text=text, return_tensors='pt', padding=True)
            self.text = text.to(device=self.model.device)
        return self.text

    def forward_text(self, text: List[List[str]]) -> Tensor:
        num_per_batch = [len(t) for t in text]
        assert max(num_per_batch) == min(num_per_batch), (
            'number of sequences must be equal for all samples in a batch')
        # flatten the per-image prompt lists into one batch of strings
        text = list(itertools.chain(*text))
        text = self.tokenizer(text=text, return_tensors='pt', padding=True)
        text = text.to(device=self.model.device)
        txt_outputs = self.model(**text)
        txt_feats = txt_outputs.text_embeds
        # L2-normalize the embeddings, then restore the
        # (batch, num_texts, dim) layout
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        txt_feats = txt_feats.reshape(-1, num_per_batch[0],
                                      txt_feats.shape[-1])
        return txt_feats

    def _freeze_modules(self) -> None:
        if len(self.frozen_modules) == 0:
            # nothing to freeze
            return
        if self.frozen_modules[0] == 'all':
            self.model.eval()
            for _, module in self.model.named_modules():
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False
            return
        for name, module in self.model.named_modules():
            for frozen_name in self.frozen_modules:
                if name.startswith(frozen_name):
                    module.eval()
                    for param in module.parameters():
                        param.requires_grad = False
                    break

    def train(self, mode: bool = True):
        super().train(mode)
        self._freeze_modules()


@MODELS.register_module()
class PseudoLanguageBackbone(BaseModule):
    """Pseudo language backbone that looks up pre-computed text embeddings.

    Args:
        text_embed_path (str): Path to the text embedding file used during
            training, a ``{text: embed}`` dict saved with ``torch.save``.
        test_embed_path (str, optional): Path to the embedding file used at
            test time. Defaults to ``text_embed_path`` when not given.
        init_cfg (OptMultiConfig): Initialization config.
    """

    def __init__(self,
                 text_embed_path: str = '',
                 test_embed_path: str = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg)
        # {text: embed}
        self.text_embed = torch.load(text_embed_path, map_location='cpu')
        if test_embed_path is None:
            self.test_embed = self.text_embed
        else:
            self.test_embed = torch.load(test_embed_path, map_location='cpu')
        # dummy buffer used only to track the module's device
        self.register_buffer('buff', torch.zeros([1]))

    def forward_cache(self, text: List[List[str]]) -> Tensor:
        if not hasattr(self, 'cache'):
            self.cache = self.forward_text(text)
        return self.cache

    def forward(self, text: List[List[str]]) -> Tensor:
        if self.training:
            return self.forward_text(text)
        else:
            return self.forward_cache(text)

    def forward_text(self, text: List[List[str]]) -> Tensor:
        num_per_batch = [len(t) for t in text]
        assert max(num_per_batch) == min(num_per_batch), (
            'number of sequences must be equal for all samples in a batch')
        text = list(itertools.chain(*text))
        if self.training:
            text_embed_dict = self.text_embed
        else:
            text_embed_dict = self.test_embed
        text_embeds = torch.stack(
            [text_embed_dict[x.split('/')[0]] for x in text])
        # requires no grad and force to float
        text_embeds = text_embeds.to(
            self.buff.device).requires_grad_(False).float()
        text_embeds = text_embeds.reshape(-1, num_per_batch[0],
                                          text_embeds.shape[-1])
        return text_embeds
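
# Sketch of the embedding file that `PseudoLanguageBackbone` expects. The
# exact format is an assumption inferred from the lookup in `forward_text`
# above: a dict mapping each category name to a 1-D embedding tensor, saved
# with `torch.save`. The file name and the 512-dim size are illustrative only:
#
#   embeds = {'person': torch.randn(512), 'dog': torch.randn(512)}
#   torch.save(embeds, 'text_embeds.pth')
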

@MODELS.register_module()
class MultiModalYOLOBackbone(BaseModule):
    """Multi-modal backbone pairing an image model with a text model.

    Both sub-models are built from their configs through the ``MODELS``
    registry.
    """

    def __init__(self,
                 image_model: ConfigType,
                 text_model: ConfigType,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg)
        self.image_model = MODELS.build(image_model)
        self.text_model = MODELS.build(text_model)

    def forward(self, image: Tensor,
                text: List[List[str]]) -> Tuple[Tuple[Tensor], Tensor]:
        img_feats = self.image_model(image)
        txt_feats = self.text_model(text)
        return img_feats, txt_feats
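

# Minimal usage sketch of the HuggingFace-backed backbones. It assumes access
# to the HuggingFace Hub; 'microsoft/swin-tiny-patch4-window7-224' and
# 'openai/clip-vit-base-patch32' are illustrative checkpoint names only, and
# any compatible vision / CLIP text checkpoint should behave the same way.
if __name__ == '__main__':
    image_backbone = HuggingVisionBackbone(
        model_name='microsoft/swin-tiny-patch4-window7-224',
        out_indices=(1, 2, 3))
    text_backbone = HuggingCLIPLanguageBackbone(
        model_name='openai/clip-vit-base-patch32',
        frozen_modules=('all', ))

    images = torch.randn(2, 3, 224, 224)    # dummy batch of two images
    texts = [['person', 'dog', 'cat']] * 2  # same prompt list per image

    img_feats = image_backbone(images)      # tuple of multi-scale feature maps
    txt_feats = text_backbone(texts)        # shape (2, 3, embed_dim)
    print([f.shape for f in img_feats], txt_feats.shape)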