"""CLIP text encoder with a RoBERTa-style masked-language-modelling head."""

from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import CLIPTextConfig
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.clip.modeling_clip import (
    CLIPPreTrainedModel,
    CLIPTextTransformer,
)
from transformers.models.roberta.modeling_roberta import RobertaLMHead


class CLIPTextModelForMaskedLM(CLIPPreTrainedModel):
    """CLIP text transformer topped with a ``RobertaLMHead`` for masked LM.

    The backbone is ``CLIPTextTransformer``. Its per-token hidden states
    (``outputs[0]``) are projected to vocabulary logits by ``RobertaLMHead``.
    When ``labels`` are given, a cross-entropy MLM loss is also returned.

    NOTE(review): in the transformers releases I am aware of,
    ``CLIPTextTransformer`` always builds a causal attention mask. Each
    position would then attend only to itself and earlier tokens, which is
    unusual for a masked-LM objective. Confirm this is intended for the
    installed transformers version.
    """

    config_class = CLIPTextConfig
    # Keep each CLIP encoder layer on a single device when sharding with
    # ``device_map``.
    _no_split_modules = ["CLIPEncoderLayer"]
    # ``get_output_embeddings`` exposes ``lm_head.decoder``, so the standard
    # PreTrainedModel tying may share its weight with the token embedding.
    # RobertaLMHead (transformers 4.x; verify for the installed version) also
    # shares ``decoder.bias`` with its own bias. Declaring these lets
    # save/load treat them as intentionally shared tensors.
    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

    def __init__(self, config: CLIPTextConfig):
        """Build the CLIP text backbone and the MLM head from ``config``.

        Args:
            config: CLIP text configuration. ``RobertaLMHead`` reads its
                sizing fields from it, e.g. ``vocab_size``.
        """
        super().__init__(config)
        self.text_model = CLIPTextTransformer(config)
        self.lm_head = RobertaLMHead(config)
        # Runs the PreTrainedModel weight initialisation and tying hooks.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding module of the CLIP text backbone."""
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value: nn.Module) -> None:
        """Replace the token embedding module of the CLIP text backbone."""
        self.text_model.embeddings.token_embedding = value

    def get_output_embeddings(self) -> nn.Module:
        """Return the vocabulary projection (``lm_head.decoder``)."""
        return self.lm_head.decoder

    def set_output_embeddings(self, value: nn.Module) -> None:
        """Replace the vocabulary projection (``lm_head.decoder``)."""
        self.lm_head.decoder = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        """Run the text encoder and the MLM head, optionally computing loss.

        Args:
            input_ids: Token ids, passed straight to ``CLIPTextTransformer``.
            attention_mask: Padding mask, passed to the backbone.
            position_ids: Position ids, passed to the backbone.
            labels: Target token ids, flattened to 1-D for the loss. Positions
                set to -100 are ignored (the ``CrossEntropyLoss`` default).
            output_attentions: Whether the backbone returns attention maps.
            output_hidden_states: Whether the backbone returns all hidden
                states.
            return_dict: Return a ``MaskedLMOutput`` if true, otherwise a
                tuple. Defaults to ``config.use_return_dict``.

        Returns:
            ``MaskedLMOutput(loss, logits, hidden_states, attentions)``, or the
            tuple ``([loss,] logits, *extra backbone outputs)``. The loss is
            included only when ``labels`` is given.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # outputs[0] is the per-token last hidden state.
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        mlm_loss = None
        if labels is not None:
            # Under device_map sharding the head may live on a different device
            # than the inputs. Move labels next to the logits, as
            # RobertaForMaskedLM does, so the loss does not hit a device
            # mismatch.
            labels = labels.to(prediction_scores.device)
            loss_fct = nn.CrossEntropyLoss()  # ignore_index=-100 by default
            mlm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )

        if not return_dict:
            # Skip outputs[1] (the backbone's pooled output). Keep any
            # hidden_states/attentions that follow it.
            output = (prediction_scores,) + outputs[2:]
            return ((mlm_loss,) + output) if mlm_loss is not None else output

        return MaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )