from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import CLIPTextConfig
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.clip.modeling_clip import (
    CLIPPreTrainedModel,
    CLIPTextTransformer,
)
from transformers.models.roberta.modeling_roberta import RobertaLMHead


class CLIPTextModelForMaskedLM(CLIPPreTrainedModel):
    """CLIP text encoder (`CLIPTextTransformer`) with a masked-language-modeling
    head (`RobertaLMHead`) on top of the final hidden states."""

    config_class = CLIPTextConfig
    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = CLIPTextTransformer(config)
        # RobertaLMHead only reads `hidden_size`, `vocab_size`, and
        # `layer_norm_eps` from the config, all of which CLIPTextConfig defines,
        # so it can be reused directly on top of the CLIP text encoder.
        self.lm_head = RobertaLMHead(config)

        # Initialize weights and apply final processing.
        self.post_init()

    # Embedding accessors used by `resize_token_embeddings` and weight tying.
    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.text_model.embeddings.token_embedding = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head.decoder

    def set_output_embeddings(self, value: nn.Module) -> None:
        self.lm_head.decoder = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Encode the (masked) token ids with the CLIP text transformer.
        outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project each token's hidden state onto the vocabulary.
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        mlm_loss = None
        if labels is not None:
            # Standard MLM loss: positions labelled -100 (the default
            # `ignore_index`) do not contribute to the loss.
            loss_fct = nn.CrossEntropyLoss()
            mlm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((mlm_loss,) + output) if mlm_loss is not None else output

        return MaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
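

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). It assumes a
# CLIP tokenizer checkpoint such as "openai/clip-vit-base-patch32" is
# available locally or downloadable, and runs randomly initialized weights, so
# the predictions are not meaningful -- it only demonstrates the expected
# shapes and the MLM loss path.
if __name__ == "__main__":
    from transformers import CLIPTokenizerFast

    tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
    config = CLIPTextConfig()  # default vocab matches the CLIP tokenizer
    model = CLIPTextModelForMaskedLM(config)

    inputs = tokenizer("a photo of a dog", return_tensors="pt")
    # Use the inputs themselves as labels so the loss branch is exercised;
    # real MLM training would mask tokens and set unmasked labels to -100.
    outputs = model(**inputs, labels=inputs["input_ids"])
    print(outputs.logits.shape)  # (batch, seq_len, vocab_size)
    print(outputs.loss)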