# modeling_clip_masked_lm.py
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import CLIPTextConfig
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.clip.modeling_clip import (
    CLIPPreTrainedModel,
    CLIPTextTransformer,
)
from transformers.models.roberta.modeling_roberta import RobertaLMHead


class CLIPTextModelForMaskedLM(CLIPPreTrainedModel):
    """CLIP text encoder (``CLIPTextTransformer``) topped with RoBERTa's
    ``RobertaLMHead`` for masked language modeling."""

    config_class = CLIPTextConfig
    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = CLIPTextTransformer(config)
        self.lm_head = RobertaLMHead(config)
        # Initialize weights and apply final processing (e.g. weight tying).
        self.post_init()
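
    # The accessors below let generic ``transformers`` machinery such as
    # ``resize_token_embeddings`` and the weight tying applied in
    # ``post_init`` locate the token embedding and LM-head decoder on this
    # non-standard backbone.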
    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.text_model.embeddings.token_embedding = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head.decoder

    def set_output_embeddings(self, value: nn.Module) -> None:
        self.lm_head.decoder = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project the token-level hidden states to per-position vocabulary
        # logits through the MLM head.
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        mlm_loss = None
        if labels is not None:
            # Flatten to (batch * seq_len, vocab_size) vs. (batch * seq_len,);
            # positions labeled -100 are ignored by CrossEntropyLoss.
            loss_fct = nn.CrossEntropyLoss()
            mlm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            # Skip the pooled output at index 1; keep hidden states/attentions.
            output = (prediction_scores,) + outputs[2:]
            return ((mlm_loss,) + output) if mlm_loss is not None else output

        return MaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
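

# A minimal usage sketch, assuming a compatible tokenizer/checkpoint pair.
# The repository id below is hypothetical; substitute the real one, and note
# the tokenizer must define a mask token for MLM inference to make sense.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "your-org/clip-text-mlm"  # hypothetical checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = CLIPTextModelForMaskedLM.from_pretrained(repo_id).eval()

    text = f"a photo of a {tokenizer.mask_token}"
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Decode the highest-scoring token at each masked position.
    mask_positions = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(
        as_tuple=True
    )[1]
    predicted_ids = outputs.logits[0, mask_positions].argmax(dim=-1)
    print(tokenizer.decode(predicted_ids))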