Spaces:
Running
Running
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import torch | |
| from transformers import T5EncoderModel, T5Tokenizer, T5TokenizerFast | |
| from .base import ProcessorMixin | |
class T5Processor(ProcessorMixin):
    r"""
    Processor for the T5 family of models. This processor is used to encode text inputs and return the embeddings
    and attention masks for the input text.

    Args:
        output_names (`List[str]`):
            The names of the outputs that the processor should return. Exactly two names are required: the first is
            used as the key for the embeddings of the input text and the second as the key for the attention mask
            of the input text.
        input_names (`Dict[str, Any]`, *optional*):
            Optional mapping of input names, stored on the instance as `self.input_names`. At most four entries are
            allowed (`forward` takes four inputs). How the mapping is consumed is defined by `ProcessorMixin`
            (not visible here).
        use_attention_mask (`bool`, defaults to `False`):
            Whether to pass the tokenizer's attention mask to the text encoder. When `False`, the encoder is called
            without a mask, so padding positions are not masked out inside the encoder. The attention mask is
            returned in either case.

    Raises:
        ValueError: If `input_names` has more than four entries, or `output_names` does not have exactly two.
    """

    def __init__(
        self,
        output_names: List[str],
        input_names: Optional[Dict[str, Any]] = None,
        *,
        use_attention_mask: bool = False,
    ):
        super().__init__()

        self.output_names = output_names
        self.input_names = input_names
        self.use_attention_mask = use_attention_mask

        # Explicit checks rather than `assert`: asserts are removed when Python runs with `-O`.
        if input_names is not None and len(input_names) > 4:
            raise ValueError(f"`input_names` can have at most 4 entries, but got {len(input_names)}.")
        if len(self.output_names) != 2:
            raise ValueError(f"`output_names` must have exactly 2 entries, but got {len(self.output_names)}.")

    def forward(
        self,
        tokenizer: Union[T5Tokenizer, T5TokenizerFast],
        text_encoder: T5EncoderModel,
        caption: Union[str, List[str]],
        max_sequence_length: int,
    ) -> Dict[str, torch.Tensor]:
        r"""
        Encode the input text and return the embeddings and attention mask for the input text.

        Args:
            tokenizer (`Union[T5Tokenizer, T5TokenizerFast]`):
                The tokenizer used to tokenize the input text.
            text_encoder (`T5EncoderModel`):
                The text encoder used to encode the input text.
            caption (`Union[str, List[str]]`):
                The input text to be encoded. A single string is treated as a batch of one.
            max_sequence_length (`int`):
                The maximum sequence length of the input text. Every caption is padded or truncated to exactly this
                many tokens.

        Returns:
            `Dict[str, torch.Tensor]`: A dictionary with two entries:
                - `output_names[0]`: the encoder's first output (its last hidden state), cast to the text encoder's
                  dtype and placed on its device.
                - `output_names[1]`: a boolean attention mask of shape `(batch_size, max_sequence_length)` on the
                  text encoder's device, `True` for real tokens and `False` for padding.
        """
        if isinstance(caption, str):
            caption = [caption]

        device = text_encoder.device
        dtype = text_encoder.dtype
        batch_size = len(caption)

        # Pad every caption to exactly `max_sequence_length` tokens (truncating longer ones) so the batch is
        # rectangular.
        text_inputs = tokenizer(
            caption,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask.bool().to(device)

        # Only mask padding inside the encoder when requested; the mask itself is returned either way.
        te_mask = prompt_attention_mask if self.use_attention_mask else None
        prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=te_mask)[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)

        return {
            self.output_names[0]: prompt_embeds,
            self.output_names[1]: prompt_attention_mask,
        }