from typing import Any, Dict, List, Optional, Union

import torch
from transformers import LlamaModel, LlamaTokenizer, LlamaTokenizerFast

from .base import ProcessorMixin


# Default prompt template for video captioning. `crop_start` is the number of tokens the template
# contributes before the caption text; those tokens are cropped from the encoder output in
# `LlamaProcessor.forward`, so the template text and `crop_start` must stay in sync.
DEFAULT_PROMPT_TEMPLATE = {
    "template": (
        "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
        "1. The main content and theme of the video."
        "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
        "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
        "4. background environment, light, style and atmosphere."
        "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    ),
    "crop_start": 95,
}
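
# If the template text above is edited, `crop_start` must be recomputed to match. A minimal sketch,
# mirroring the fallback logic in `LlamaProcessor.forward` (assumes `tokenizer` is an already loaded
# Llama tokenizer; not executed at import time):
#
#     template_ids = tokenizer(DEFAULT_PROMPT_TEMPLATE["template"], return_tensors="pt")["input_ids"]
#     DEFAULT_PROMPT_TEMPLATE["crop_start"] = template_ids.shape[-1] - 2  # drop "{}" and trailing <|eot_id|>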


class LlamaProcessor(ProcessorMixin):
    r"""
    Processor for the Llama family of models. This processor encodes text inputs and returns the
    embeddings and attention mask for the input text.

    Args:
        output_names (`List[str]`):
            The names of the outputs that the processor should return. The first name is used for the
            embeddings of the input text and the second for its attention mask.
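
    Example (illustrative sketch; the checkpoint name and output names below are placeholders, and a
    real run requires access to the corresponding Llama weights):

        >>> from transformers import AutoTokenizer, LlamaModel
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
        >>> text_encoder = LlamaModel.from_pretrained("meta-llama/Meta-Llama-3-8B")
        >>> processor = LlamaProcessor(output_names=["encoder_hidden_states", "encoder_attention_mask"])
        >>> outputs = processor.forward(tokenizer, text_encoder, "A red car drives along a coastal road", 256)
        >>> outputs["encoder_hidden_states"].shape  # (batch_size, max_sequence_length, hidden_size)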
    """

    def __init__(self, output_names: List[str] = None):
        super().__init__()

        if output_names is None or len(output_names) != 2:
            raise ValueError(
                "`output_names` must contain exactly two names: one for the prompt embeddings and one "
                "for the prompt attention mask."
            )
        self.output_names = output_names

    def forward(
        self,
        tokenizer: Union[LlamaTokenizer, LlamaTokenizerFast],
        text_encoder: LlamaModel,
        caption: Union[str, List[str]],
        max_sequence_length: int,
        prompt_template: Optional[Dict[str, Any]] = None,
        num_layers_to_skip: int = 2,
    ) -> Dict[str, torch.Tensor]:
r""" |
|
|
Encode the input text and return the embeddings and attention mask for the input text. |
|
|
|
|
|
Args: |
|
|
tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`): |
|
|
The tokenizer used to tokenize the input text. |
|
|
text_encoder (`LlamaModel`): |
|
|
The text encoder used to encode the input text. |
|
|
caption (`Union[str, List[str]]`): |
|
|
The input text to be encoded. |
|
|
max_sequence_length (`int`): |
|
|
The maximum sequence length of the input text. |
|
|
prompt_template (`Optional[Dict[str, Any]]`): |
|
|
The prompt template to be used to encode the input text. |
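
        Example (illustrative; assumes `tokenizer`, `text_encoder`, and `processor` are set up as in
        the class-level example, and shows a hypothetical custom template without a precomputed
        `crop_start`, so the fallback computation below is exercised):

            >>> template = {"template": "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"}
            >>> outputs = processor.forward(
            ...     tokenizer, text_encoder, "A cat plays piano", 128, prompt_template=template
            ... )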
        """
        if prompt_template is None:
            prompt_template = DEFAULT_PROMPT_TEMPLATE
        if isinstance(caption, str):
            caption = [caption]

        device = text_encoder.device
        dtype = text_encoder.dtype

        batch_size = len(caption)
        caption = [prompt_template["template"].format(c) for c in caption]

        crop_start = prompt_template.get("crop_start", None)
        if crop_start is None:
            # The template does not specify `crop_start`; derive it by tokenizing the bare template.
            prompt_template_input = tokenizer(
                prompt_template["template"],
                padding="max_length",
                return_tensors="pt",
                return_length=False,
                return_overflowing_tokens=False,
                return_attention_mask=False,
            )
            crop_start = prompt_template_input["input_ids"].shape[-1]
            # Remove the "{}" placeholder and the trailing <|eot_id|> token from the count.
            crop_start -= 2

        # Reserve room for the template prefix, which is cropped from the encoder output below.
        max_sequence_length += crop_start
        text_inputs = tokenizer(
            caption,
            max_length=max_sequence_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_length=False,
            return_overflowing_tokens=False,
            return_attention_mask=True,
        )
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_attention_mask = text_inputs.attention_mask.bool().to(device)

        # Use an intermediate hidden state: `hidden_states[-1]` is the final layer, so skipping
        # `num_layers_to_skip` layers selects `hidden_states[-(num_layers_to_skip + 1)]`.
        prompt_embeds = text_encoder(
            text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
        ).hidden_states[-(num_layers_to_skip + 1)]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        if crop_start is not None and crop_start > 0:
            # Drop the template prefix so only the caption tokens remain.
            prompt_embeds = prompt_embeds[:, crop_start:]
            prompt_attention_mask = prompt_attention_mask[:, crop_start:]

        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)

        return {
            self.output_names[0]: prompt_embeds,
            self.output_names[1]: prompt_attention_mask,
        }