from typing import Any, Dict, List, Optional, Union

import torch
from transformers import LlamaModel, LlamaTokenizer, LlamaTokenizerFast

from .base import ProcessorMixin


DEFAULT_PROMPT_TEMPLATE = {
    "template": (
        "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
        "1. The main content and theme of the video."
        "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
        "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
        "4. background environment, light, style and atmosphere."
        "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    ),
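    # `crop_start` is the number of leading tokens that the system portion of this exact
    # template occupies once tokenized; `forward` crops that many tokens from the encoder
    # output so only the user prompt remains. Recompute it if the template text changes.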
    "crop_start": 95,
}


class LlamaProcessor(ProcessorMixin):
    r"""
    Processor for the Llama family of models. It encodes text inputs and returns the prompt
    embeddings along with the attention mask for the input text.

    Args:
        output_names (`List[str]`):
            The names of the outputs that the processor should return. The first output is the
            embeddings of the input text and the second output is the attention mask for the
            input text.
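Example (illustrative sketch: the checkpoint name and the two output names below are
    placeholders, and `forward` is called directly rather than through any dispatch
    `ProcessorMixin` may provide; any Llama checkpoint whose chat format matches
    `DEFAULT_PROMPT_TEMPLATE` can be substituted):

    ```python
    from transformers import AutoTokenizer, LlamaModel

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    text_encoder = LlamaModel.from_pretrained("meta-llama/Meta-Llama-3-8B")

    processor = LlamaProcessor(output_names=["encoder_hidden_states", "encoder_attention_mask"])
    outputs = processor.forward(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        caption="A red fox running through fresh snow",
        max_sequence_length=256,
    )
    outputs["encoder_hidden_states"]  # shape: [1, 256, hidden_size]
    outputs["encoder_attention_mask"]  # shape: [1, 256]
    ```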
    """

    def __init__(self, output_names: Optional[List[str]] = None):
        super().__init__()

        self.output_names = output_names

        assert self.output_names is not None and len(self.output_names) == 2, (
            "LlamaProcessor expects exactly two output names: one for the prompt embeddings "
            "and one for the prompt attention mask."
        )

    def forward(
        self,
        tokenizer: Union[LlamaTokenizer, LlamaTokenizerFast],
        text_encoder: LlamaModel,
        caption: Union[str, List[str]],
        max_sequence_length: int,
        prompt_template: Optional[Dict[str, Any]] = None,
        num_layers_to_skip: int = 2,
    ) -> Dict[str, torch.Tensor]:
						|  | r""" | 
					
						
						|  | Encode the input text and return the embeddings and attention mask for the input text. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`): | 
					
						
						|  | The tokenizer used to tokenize the input text. | 
					
						
						|  | text_encoder (`LlamaModel`): | 
					
						
						|  | The text encoder used to encode the input text. | 
					
						
						|  | caption (`Union[str, List[str]]`): | 
					
						
						|  | The input text to be encoded. | 
					
						
						|  | max_sequence_length (`int`): | 
					
						
						|  | The maximum sequence length of the input text. | 
					
						
						|  | prompt_template (`Optional[Dict[str, Any]]`): | 
					
						
						|  | The prompt template to be used to encode the input text. | 
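            num_layers_to_skip (`int`, defaults to `2`):
                The number of final hidden layers of the text encoder to skip; the returned
                embeddings are taken from `hidden_states[-(num_layers_to_skip + 1)]`.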
        """
        if prompt_template is None:
            prompt_template = DEFAULT_PROMPT_TEMPLATE
        if isinstance(caption, str):
            caption = [caption]

        device = text_encoder.device
        dtype = text_encoder.dtype

        batch_size = len(caption)
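        # Wrap each caption in the chat template so the encoder sees the full
        # system + user prompt; the system portion is cropped back out below.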
        caption = [prompt_template["template"].format(c) for c in caption]
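
        # When the template does not specify where the user prompt starts, infer the
        # offset by tokenizing the bare template and measuring its length.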
        crop_start = prompt_template.get("crop_start", None)
        if crop_start is None:
            prompt_template_input = tokenizer(
                prompt_template["template"],
                padding="max_length",
                return_tensors="pt",
                return_length=False,
                return_overflowing_tokens=False,
                return_attention_mask=False,
            )
            crop_start = prompt_template_input["input_ids"].shape[-1]
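            # Remove the trailing <|eot_id|> token and the "{}" placeholder from the count.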
            crop_start -= 2

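        # Extend the token budget by the template length so that, after cropping, the
        # user prompt still gets the full `max_sequence_length` tokens.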
        max_sequence_length += crop_start
        text_inputs = tokenizer(
            caption,
            max_length=max_sequence_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_length=False,
            return_overflowing_tokens=False,
            return_attention_mask=True,
        )
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_attention_mask = text_inputs.attention_mask.bool().to(device)

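        # Run the encoder with hidden states exposed and pick an intermediate layer:
        # hidden_states[-1] is the final layer, so -(num_layers_to_skip + 1) skips the
        # last `num_layers_to_skip` layers.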
        prompt_embeds = text_encoder(
            text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
        ).hidden_states[-(num_layers_to_skip + 1)]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

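        # Drop the leading template tokens so only the user caption embeddings remain.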
        if crop_start is not None and crop_start > 0:
            prompt_embeds = prompt_embeds[:, crop_start:]
            prompt_attention_mask = prompt_attention_mask[:, crop_start:]

        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)

        return {
            self.output_names[0]: prompt_embeds,
            self.output_names[1]: prompt_attention_mask,
        }