""" PyTorch ZH-CLIP model.""" |
|
|
|
|
|
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.clip.modeling_clip import CLIPVisionConfig, CLIPVisionModel
from transformers.utils import ModelOutput, logging

from .configuration_zhclip import ZhCLIPConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "ZhCLIPConfig"
|
|
|
@dataclass
class ZhCLIPModelOutput(ModelOutput):
    """
    Output of [`ZhCLIPModel`], holding the projected text and image embeddings.

    Args:
        text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            Text embeddings obtained by mean-pooling the text encoder output and applying `text_projection`.
        image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            Image embeddings obtained by applying `visual_projection` to the pooled vision encoder output.
    """

    text_features: torch.FloatTensor = None
    image_features: torch.FloatTensor = None


class MeanPooler(nn.Module):
    """Mean pooling over the token dimension, ignoring padded positions."""

    def forward(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Zero out the hidden states of padded tokens, then average over the valid tokens only.
        masked_output = last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
|
|
|
|
|
class ZhCLIPPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization.
    """

    config_class = ZhCLIPConfig
    base_model_prefix = "zhclip"
    supports_gradient_checkpointing = False
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
|
|
|
|
class ZhCLIPModel(ZhCLIPPreTrainedModel): |
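    """
    ZH-CLIP dual-encoder model.

    Wraps a vision encoder (the vision tower of a `CLIPVisionModel`, or any `AutoModel`) and a text
    encoder (any `AutoModel`), and projects their pooled outputs into a shared embedding space of
    size `config.hidden_size`. The model can be built either from a `ZhCLIPConfig` alone, in which
    case the sub-models are instantiated from `config.vision_config` and `config.text_config`, or
    from already-constructed vision and text models, in which case a config is derived from them.
    """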
|
    def __init__(
        self,
        config: Optional[ZhCLIPConfig] = None,
        vision_model: Optional[PreTrainedModel] = None,
        text_model: Optional[PreTrainedModel] = None,
    ):
|
|
|
        if config is None and (vision_model is None or text_model is None):
            raise ValueError("Either a configuration or a vision and a text model has to be provided")

        if config is None:
            config = ZhCLIPConfig(vision_model.config, text_model.config)
        else:
            if not isinstance(config, self.config_class):
                raise ValueError(f"config: {config} has to be of type {self.config_class}")

        super().__init__(config)

        if vision_model is None:
            if isinstance(config.vision_config, CLIPVisionConfig):
                vision_model = CLIPVisionModel(config.vision_config).vision_model
            else:
                vision_model = AutoModel.from_config(config.vision_config)

        if text_model is None:
            text_model = AutoModel.from_config(config.text_config)

        self.vision_model = vision_model
        self.text_model = text_model

        # Make sure the sub-models reference the shared config so that config updates stay in sync.
        self.vision_model.config = self.config.vision_config
        self.text_model.config = self.config.text_config

        self.vision_embed_dim = config.vision_config.hidden_size
        self.text_embed_dim = config.text_config.hidden_size
        self.coattention_dim = config.hidden_size

        # Text side: mean-pooled token embeddings followed by a two-layer MLP projection;
        # image side: a single linear projection of the pooled vision output.
        mlp_hidden_size = (self.text_embed_dim + self.coattention_dim) // 2
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_embed_dim, mlp_hidden_size, bias=False),
            nn.GELU(),
            nn.Linear(mlp_hidden_size, self.coattention_dim, bias=False),
        )
        self.text_pooler = MeanPooler()
        self.visual_projection = nn.Linear(self.vision_embed_dim, self.coattention_dim)
|
|
|
|
|
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        patch_ids=None,
        extend_token_type_ids=None,
        return_loss: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], ZhCLIPModelOutput]:
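        r"""
        Encode text and images and return their projected embeddings as a [`ZhCLIPModelOutput`].
        `patch_ids`, `extend_token_type_ids` and `return_loss` are accepted for interface
        compatibility but are not used by this forward pass.

        Example (a minimal sketch; it assumes `model` is an already loaded `ZhCLIPModel` and that
        `input_ids`, `attention_mask` and `pixel_values` come from the matching tokenizer and image
        processor for the checkpoint):

        ```python
        >>> outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
        >>> text_features = outputs.text_features    # (batch_size, config.hidden_size)
        >>> image_features = outputs.image_features  # (batch_size, config.hidden_size)
        ```"""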
|
|
|
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            return_dict=return_dict,
        )
        text_features = self.get_text_features(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            return_dict=return_dict,
        )
        return ZhCLIPModelOutput(
            image_features=image_features,
            text_features=text_features,
        )
|
|
|
|
|
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Fast initialization is not supported for composite (vision + text) models,
        # so it is disabled before deferring to the standard loading logic.
        kwargs["_fast_init"] = False
        return super().from_pretrained(*args, **kwargs)
|
|
|
|
|
    def get_text_features(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
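        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            mean-pooling the text encoder's last hidden state (using the attention mask) and applying the
            `text_projection` MLP.
        """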
|
        # If no attention mask is given, derive one from the padding tokens so that the same mask
        # is used both by the text encoder and by the mean pooler.
        if attention_mask is None:
            attention_mask = (input_ids != self.config.pad_token_id).long()
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        text_pool = self.text_pooler(text_outputs[0], attention_mask)
        text_feat = self.text_projection(text_pool)
        return text_feat
|
|
|
|
|
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained
            by applying `visual_projection` to the pooled output of the vision encoder.

        Example (shown with the upstream `CLIPModel`; usage is analogous for `ZhCLIPModel` with its own
        processor and checkpoint):

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
|
|
|
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]
        image_features = self.visual_projection(pooled_output)

        return image_features
|
|