from functools import partial
from typing import Any, List, Optional, Mapping, Callable
from collections import OrderedDict
from argparse import Namespace

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as T
import PIL
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer

from .configuration_emu import EmuConfig
from .constants import *
from .modeling_llama import LlamaForCausalLM
from .visual import EVAVisionTransformer


class EmuPreTrainedModel(PreTrainedModel):
    config_class = EmuConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["LlamaDecoderLayer", "Block"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


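# Thin wrapper around the LLaMA causal LM that serves as Emu's decoder. The
# class name suggests the same backbone is used both for token classification
# (next-token prediction) and for visual-embedding regression during training.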
					
						
class EmuForClsAndRegression(EmuPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.lm = LlamaForCausalLM(config=config)

        # Align the embedding layer's padding index with the configured pad token.
        self.lm.model.embed_tokens.padding_idx = config.pad_token_id

    def get_num_layers(self):
        return len(self.lm.model.layers)


class EmuModel(EmuPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        vision_config = Namespace(**config.vision_config)

        # EVA-based vision transformer encoder.
        self.visual = EVAVisionTransformer(
            img_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
            embed_dim=vision_config.width,
            depth=vision_config.layers,
            num_heads=vision_config.width // vision_config.head_width,
            mlp_ratio=vision_config.mlp_ratio,
            qkv_bias=vision_config.qkv_bias,
            drop_path_rate=vision_config.drop_path_rate,
            norm_layer=partial(nn.LayerNorm, eps=vision_config.layer_norm_eps),
            xattn=vision_config.xattn,
            postnorm=vision_config.postnorm,
        )

        self.decoder = EmuForClsAndRegression(config)

        self.gradient_checkpointing = False

        # Number of visual tokens kept per image (n_query) / per video frame (v_query).
        self.n_query = vision_config.n_query
        self.v_query = vision_config.v_query

    @property
    def device(self):
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        return next(iter(self.parameters())).dtype

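    # encode_image: the ViT returns one embedding per patch plus a leading CLS
    # token. The CLS token is dropped, the patch grid is average-pooled down to
    # n_query tokens, and the result is returned as (batch, n_query, width).
    # The sqrt reshape assumes the patch count and n_query are perfect squares.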
					
						
    @torch.no_grad()
    def encode_image(self, image: torch.Tensor, *, n_query=None):
        n_query = n_query if n_query is not None else self.n_query

        image_embeds = self.visual(image)
        image_embeds = image_embeds[:, 1:, :]  # drop the CLS token
        b, n, c = image_embeds.shape
        sqrt_n = int(n**0.5)
        # Restore the 2D patch grid: (b, n, c) -> (b, c, sqrt_n, sqrt_n).
        image_embeds = image_embeds.permute(0, 2, 1).view(b, c, sqrt_n, sqrt_n)

        # Average-pool the grid down to n_query visual tokens.
        stride = int(sqrt_n // (n_query ** 0.5))
        image_embeds = F.avg_pool2d(image_embeds, kernel_size=(stride, stride), stride=stride)
        image_embeds = image_embeds.view(b, c, -1).permute(0, 2, 1).contiguous()
        return image_embeds


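# Worked example for encode_image (illustrative numbers, not fixed by this
# file): with a 224x224 input and 14x14 patches the ViT yields 16x16 = 256
# patch tokens; for n_query = 64 the stride is int(16 // 8) = 2, so avg_pool2d
# reduces the 16x16 grid to 8x8 = 64 visual tokens.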
					
						
class EmuForCausalLM(EmuPreTrainedModel):
    _auto_class = "AutoModelForCausalLM"

    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.model = EmuModel(config)

        # Linear bridges between the visual embedding width (d_model) and the
        # LM hidden size.
        self.project_down = nn.Linear(config.hidden_size, config.d_model, bias=False)
        self.project_up = nn.Linear(config.d_model, config.hidden_size, bias=False)

        self.n_query = self.model.n_query
        self.v_query = self.model.v_query

        # Placeholder strings expanded into one special token per visual
        # embedding slot, bracketed by image begin/end tokens.
        self.image_placeholder = DEFAULT_IMG_TOKEN + DEFAULT_IMAGE_TOKEN * self.n_query + DEFAULT_IMG_END_TOKEN
        self.video_placeholder = DEFAULT_IMG_TOKEN + DEFAULT_gIMG_TOKEN * self.v_query + DEFAULT_IMG_END_TOKEN

    @property
    def device(self):
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        return next(iter(self.parameters())).dtype

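    # generate() builds the multimodal prompt at the embedding level: token
    # embeddings are computed for the full (placeholder-expanded) sequence,
    # then the rows at image/video placeholder positions are overwritten in
    # place with the projected visual embeddings before delegating to the
    # underlying LM's generate().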
					
						
    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        num_beams=5,
        max_new_tokens=10,
        min_len=1,
        do_sample=False,
        penalty_alpha=None,
        top_p=None,
        top_k=None,
        temperature=None,
        length_penalty=-1,
        repetition_penalty=1.0,
        **kwargs,
    ):
        # Embed on the model's device rather than a hard-coded "cuda" target,
        # so non-CUDA runs still work.
        text_embeds = self.model.decoder.lm.model.embed_tokens(input_ids).to(self.device)

        if image is not None:
            prompt_image_embeds = self.model.encode_image(image, n_query=self.n_query)
            _, _, c = prompt_image_embeds.shape
            prompt_image_embeds = prompt_image_embeds.view(-1, c)
            prompt_image_embeds = self.project_up(prompt_image_embeds)
            # Scatter the projected image embeddings into the placeholder slots.
            image_idx = (input_ids == IMAGE)
            text_embeds[image_idx] = prompt_image_embeds.to(text_embeds.device)

        if video is not None:
            prompt_video_embeds = self.model.encode_image(video, n_query=self.v_query)
            _, _, c = prompt_video_embeds.shape
            prompt_video_embeds = prompt_video_embeds.view(-1, c)
            prompt_video_embeds = self.project_up(prompt_video_embeds)
            video_idx = (input_ids == VIDEO)
            text_embeds[video_idx] = prompt_video_embeds.to(text_embeds.device)

        outputs = self.model.decoder.lm.generate(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            min_length=min_len,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            penalty_alpha=penalty_alpha,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            **kwargs,
        )

        return outputs

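    # Images are resized to the configured square resolution and normalized
    # with the OpenAI CLIP dataset statistics (OPENAI_DATASET_MEAN /
    # OPENAI_DATASET_STD from constants.py), matching the vision encoder's
    # pretraining.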
					
						
    def prepare_image_input(self, images):
        image_size: int = self.config.vision_config["image_size"]
        transform = T.Compose(
            [
                T.Resize(
                    (image_size, image_size), interpolation=T.InterpolationMode.BICUBIC
                ),
                T.ToTensor(),
                T.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD),
            ]
        )
        images = [transform(image) for image in images]
        return torch.stack(images, 0)

    def _prepare_chat_template(self, text, system_msg=""):
        # Wrap each prompt with the USER/ASSISTANT role tokens expected by the
        # chat-tuned checkpoints.
        text = [
            system_msg + USER_TOKEN + ": " + t + ASSISTANT_TOKEN + ":"
            for t in text
        ]
        return text

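    # prepare_text_input expands each user-facing placeholder
    # (DEFAULT_IMG_PLACEHOLDER / DEFAULT_VID_PLACEHOLDER) into a run of special
    # image tokens, one per visual embedding slot, so that token positions line
    # up with the embeddings spliced in by generate().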
					
						
    def prepare_text_input(
        self,
        text: List[str],
        tokenizer: PreTrainedTokenizer,
        image_placeholder: str = DEFAULT_IMG_PLACEHOLDER,
        video_placeholder: str = DEFAULT_VID_PLACEHOLDER,
    ):
        text = [
            t.replace(image_placeholder, self.image_placeholder).replace(video_placeholder, self.video_placeholder)
            for t in text
        ]
        # Returns a BatchEncoding carrying both input_ids and attention_mask.
        inputs = tokenizer(text, padding="longest", return_tensors="pt")
        return inputs

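    # build_input_ids is the single preprocessing entry point: it applies the
    # chat template when needed, preprocesses images/videos, tokenizes the
    # placeholder-expanded text, and returns a dict ready to pass to generate().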
					
						
    def build_input_ids(
        self,
        text: List[str],
        tokenizer: PreTrainedTokenizer,
        image: Optional[List["PIL.Image.Image"]] = None,
        video: Optional[List["PIL.Image.Image"]] = None,
        system_msg: str = "",
        to_cuda: bool = True,
    ):
        if self.config.model_version == "chat":
            text = self._prepare_chat_template(text, system_msg)

        if image is not None:
            image = self.prepare_image_input(image)
        if video is not None:
            video = self.prepare_image_input(video)
        inputs = self.prepare_text_input(text, tokenizer)
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask

        if to_cuda:
            input_ids = input_ids.to("cuda")
            attention_mask = attention_mask.to("cuda")
            if image is not None:
                image = image.to("cuda")
            if video is not None:
                video = video.to("cuda")

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "image": image,
            "video": video,
        }
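

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). Run it from a separate script,
# since this module uses relative imports. The checkpoint path, the image file,
# and the literal placeholder string are assumptions; substitute your own
# checkpoint and the DEFAULT_IMG_PLACEHOLDER value from its constants.py.
#
#     import PIL.Image
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     path = "/path/to/emu-checkpoint"  # hypothetical
#     tokenizer = AutoTokenizer.from_pretrained(path)
#     model = AutoModelForCausalLM.from_pretrained(
#         path, trust_remote_code=True
#     ).to("cuda").eval()
#
#     image = PIL.Image.open("example.jpg").convert("RGB")  # hypothetical
#     inputs = model.build_input_ids(
#         text=["[IMG]Describe the image."],  # placeholder string is assumed
#         tokenizer=tokenizer,
#         image=[image],
#     )
#     output_ids = model.generate(**inputs, num_beams=5, max_new_tokens=64)
#     print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
# ---------------------------------------------------------------------------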