from dataclasses import dataclass
from typing import List, Optional, Tuple, Union, Dict, Any
import math

import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F

from transformers import PreTrainedModel, AutoConfig, AutoModel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.modeling_outputs import ModelOutput
from transformers.utils import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.models.auto import AutoModel, AutoModelForCausalLM, CONFIG_MAPPING
from transformers.generation import GenerationMixin
from transformers import LlamaForCausalLM, Qwen2ForCausalLM
# from models.modeling_qwen2 import Qwen2ForCausalLM
from models.modeling_qwen2_vl_fast import Qwen2VLForCausalLM
from models.utils import _pad_input, _unpad_input

logger = logging.get_logger(__name__)
class LlavaConfig(PretrainedConfig):
    model_type = "llava"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="default",
        vision_feature_layer=-2,
        image_newline_idx=32002,
        image_new_idx=32003,
        projection_head="MLP",
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer
        self.image_newline_idx = image_newline_idx
        self.image_new_idx = image_new_idx
        self.projection_head = projection_head

        self.vision_config = vision_config
        if isinstance(self.vision_config, dict):
            vision_config["model_type"] = (
                vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
            )
            if 'auto_map' in vision_config:
                repo_id, class_ref = vision_config['auto_map']['AutoConfig'].split("--")
                config_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
                self.vision_config = config_class(**vision_config)
            elif vision_config["model_type"] in CONFIG_MAPPING:
                self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
            else:
                raise ValueError(f'vision_config["model_type"] = {vision_config["model_type"]} not supported!')

        self.text_config = text_config
        if isinstance(self.text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
            if 'auto_map' in text_config:
                repo_id, class_ref = text_config['auto_map']['AutoConfig'].split("--")
                config_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
                self.text_config = config_class(**text_config)
            elif text_config["model_type"] in CONFIG_MAPPING:
                self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
            else:
                raise ValueError(f'text_config["model_type"] = {text_config["model_type"]} not supported!')

        super().__init__(**kwargs)
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
@dataclass
class LlavaCausalLMOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    position_ids: Optional[torch.LongTensor] = None
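
# add_split_tokens appends two learned separator embeddings to each image's patch
# sequence: an "image newline" embedding after every row of the (assumed square)
# patch grid, and an "image new" embedding after the whole image.
# Worked example: 576 patches (a 24x24 grid) become 24 * (24 + 1) + 1 = 601 visual tokens.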
def add_split_tokens(image_features, image_newline_embed, image_new_embed):
    num_images, num_image_patches, embed_dim = image_features.shape
    num_height_patches, num_width_patches = int(math.sqrt(num_image_patches)), int(math.sqrt(num_image_patches))

    # add image_newline
    image_features = image_features.view(num_images, num_height_patches, num_width_patches, embed_dim)
    image_features = torch.cat([
        image_features,
        image_newline_embed.expand((num_images, num_height_patches, 1, embed_dim))
    ], dim=2)
    num_image_patches += num_height_patches
    image_features = image_features.view(num_images, num_image_patches, embed_dim)

    # add image_new
    image_features = torch.cat([
        image_features,
        image_new_embed.expand((num_images, 1, embed_dim))
    ], dim=1)

    return image_features
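
# Two projector variants are defined below; TarsierForConditionalGeneration selects one
# via config.projection_head ("MLP" -> LlavaMultiModalProjector, "Pixel_Shuffle" ->
# PixelShuffleMultiModalProjector, "auto_map" -> a dynamically loaded class, None ->
# identity). Both map vision-tower features to the LM hidden size and append the split
# tokens produced by add_split_tokens above.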
class LlavaMultiModalProjector(nn.Module):
    def __init__(self, config: LlavaConfig):
        super().__init__()
        self.config = config
        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)

        image_newline_idx = torch.tensor([config.image_newline_idx], dtype=torch.long)
        image_new_idx = torch.tensor([config.image_new_idx], dtype=torch.long)
        self.register_buffer('image_newline_idx', image_newline_idx, persistent=False)
        self.register_buffer('image_new_idx', image_new_idx, persistent=False)

    def forward(self, image_features, input_embeddings):
        selected_image_feature = image_features[self.config.vision_feature_layer]

        if self.config.vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.config.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
            )

        hidden_states = self.linear_1(selected_image_feature)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)

        image_newline_embed = input_embeddings(self.image_newline_idx).squeeze()
        image_new_embed = input_embeddings(self.image_new_idx).squeeze()
        hidden_states = add_split_tokens(hidden_states, image_newline_embed, image_new_embed)

        return hidden_states
class PixelShuffleMultiModalProjector(nn.Module):
    def __init__(self, config: LlavaConfig):
        super().__init__()
        self.config = config
        self.downsample_ratio = 0.5

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        self.mlp = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        image_newline_idx = torch.tensor([config.image_newline_idx], dtype=torch.long)
        image_new_idx = torch.tensor([config.image_new_idx], dtype=torch.long)
        self.register_buffer('image_newline_idx', image_newline_idx, persistent=False)
        self.register_buffer('image_new_idx', image_new_idx, persistent=False)

    def forward(self, image_features, input_embeddings):
        selected_image_feature = image_features[self.config.vision_feature_layer]

        if self.config.vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.config.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
            )

        image_features = self.pixel_shuffle(selected_image_feature)
        hidden_states = self.mlp(image_features)

        image_newline_embed = input_embeddings(self.image_newline_idx).squeeze()
        image_new_embed = input_embeddings(self.image_new_idx).squeeze()
        hidden_states = add_split_tokens(hidden_states, image_newline_embed, image_new_embed)

        return hidden_states

    def pixel_shuffle(self, x, scale_factor=0.5):
        if scale_factor == 1:
            return x
        n, wh, c = x.shape
        h, w = int(math.sqrt(wh)), int(math.sqrt(wh))
        x = x.view(n, h, w, c)
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(x.shape[0], -1, x.shape[-1])
        return x
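
# Worked example for pixel_shuffle with the default scale_factor of 0.5: a ViT output of
# shape (n, 576, c) (a 24x24 patch grid) is rearranged to (n, 12, 12, 4 * c) and flattened
# to (n, 144, 4 * c), i.e. 4x fewer visual tokens at 4x the channel width. This is why the
# LayerNorm/Linear at the start of self.mlp expect
# vit_hidden_size * int(1 / 0.5) ** 2 = 4 * vit_hidden_size input features.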
LLAVA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
class TarsierPreTrainedModel(PreTrainedModel):
    config_class = LlavaConfig
    base_model_prefix = "llm"
    supports_gradient_checkpointing = True  # TODO: support latest gc
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = False
    _supports_cache_class = True  # TODO: support different cache
    _supports_static_cache = True

    def _init_weights(self, module):
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )
        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    @property
    def _no_split_modules(self):
        return self.language_model._no_split_modules + self.vision_tower._no_split_modules
class TarsierForConditionalGeneration(TarsierPreTrainedModel, GenerationMixin):
    def __init__(self, config: LlavaConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config.vision_config, trust_remote_code=True)

        if config.text_config.model_type == 'qwen2':
            self.language_model = Qwen2ForCausalLM(config.text_config)
        elif config.text_config.model_type == 'qwen2_vl':
            self.language_model = Qwen2VLForCausalLM(config.text_config)
        elif config.text_config.model_type == 'llama':
            self.language_model = LlamaForCausalLM(config.text_config)
        else:
            raise ValueError(f'{config.text_config.model_type} not supported!')

        if config.projection_head == 'Pixel_Shuffle':
            self.multi_modal_projector = PixelShuffleMultiModalProjector(config)
        elif config.projection_head == 'MLP':
            self.multi_modal_projector = LlavaMultiModalProjector(config)
        elif config.projection_head == 'auto_map':
            repo_id, class_ref = config.auto_map['ProjectionLayer'].split("--")
            model_class = get_class_from_dynamic_module(class_ref, repo_id)
            self.multi_modal_projector = model_class(config)
        elif config.projection_head is None:
            self.multi_modal_projector = lambda x, *args, **kwargs: x

        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        return model_embeds
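
    # forward() merges text and vision inputs as follows: input_ids are embedded by the
    # language model, pixel_values are encoded by the vision tower and mapped to the LM
    # hidden size by multi_modal_projector, and the projected visual tokens are scattered
    # into the positions marked by config.image_token_index. When use_rmpad is set, padded
    # sequences are flattened with _unpad_input and the language model is also given
    # use_rmpad / cu_seqlens, which the custom backbones under models/ are expected to accept.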
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        num_images: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_rmpad: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You must specify input_ids")

        bsz, max_seq_len = input_ids.shape[0], input_ids.shape[1]
        if max_seq_len > 1:
            special_image_mask = input_ids == self.config.image_token_index
            print(f'[{input_ids.device}] num_images: {num_images.tolist()} num_image_tokens: {special_image_mask.sum(-1).tolist()}', flush=True)

        if position_ids is None:
            if 'Qwen2VLForCausalLM' in self.language_model.__class__.__name__:
                position_ids = self.language_model.get_rope_index(input_ids, image_grid_thw, attention_mask)  # [bsz, seqlen, 3]
            else:
                position_ids = attention_mask.long().cumsum(-1) - 1  # [bsz, seqlen]
                position_ids.masked_fill_(attention_mask == 0, 1)

        if use_rmpad:
            input_ids, input_ids_indices, cu_seqlens, _ = _unpad_input(input_ids, attention_mask)  # [bsz, seqlen] -> [1, seqlen]
            position_ids, _, _, _ = _unpad_input(position_ids, attention_mask)
            input_ids, position_ids = input_ids.unsqueeze(0), position_ids.unsqueeze(0)
        else:
            input_ids_indices, cu_seqlens = None, None

        inputs_embeds = self.get_input_embeddings()(input_ids)  # [1, seqlen, dim]

        image_features = None
        if pixel_values is not None:  # training / first step in generation
            if 'Qwen2VLForCausalLM' in self.language_model.__class__.__name__:
                pixel_values = pixel_values.type(self.vision_tower.get_dtype())
                image_features = self.vision_tower(pixel_values, image_grid_thw)
            else:
                image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
                image_features = self.multi_modal_projector(
                    image_outputs.hidden_states,
                    self.get_input_embeddings(),
                )
            special_image_mask = input_ids == self.config.image_token_index
            if special_image_mask.sum() > 0:
                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(
                    special_image_mask.unsqueeze(-1).expand_as(inputs_embeds),
                    image_features
                )
            else:
                # No image tokens in this batch: add a zero contribution from the image
                # features so the vision tower still receives (zero) gradients.
                inputs_embeds = image_features.sum(dim=(0, 1)) * 0. + inputs_embeds

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_rmpad=use_rmpad,
            cu_seqlens=cu_seqlens,
        )
        logits = outputs[0]

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            if use_rmpad:
                labels = labels.view(-1)[input_ids_indices.long()]
                shift_labels = torch.cat((labels[1:], labels.new_ones((1)) * -100))
                shift_labels.requires_grad = False
                lbl_seq_lens = (cu_seqlens[1:] - 1).long()
                shift_labels[lbl_seq_lens] = -100  # do not predict across sequence boundaries
                loss = loss_fct(logits.squeeze(0), shift_labels)
            else:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)
        elif use_rmpad:
            # During training we skip re-padding the logits to save GPU memory;
            # only pad them back when no labels are provided (e.g. at inference).
            logits = _pad_input(logits.squeeze(0), input_ids_indices, bsz, max_seq_len)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            position_ids=position_ids,
        )
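
    # During generation, pixel_values and image_grid_thw are only forwarded on the prefill
    # step (cache_position[0] == 0); later steps feed just the uncached token ids and advance
    # position_ids from the values returned in LlavaCausalLMOutputWithPast.position_ids
    # (propagated by _update_model_kwargs_for_generation below).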
    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        cache_position=None,
        use_cache=True,
        pixel_values=None,
        image_grid_thw=None,
        **kwargs,
    ):
        if past_key_values is not None:
            past_length = past_key_values.get_seq_length()
            input_ids = input_ids[:, past_length:]

        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
        if kwargs.get('num_images') is not None:
            model_inputs['num_images'] = kwargs['num_images']

        if cache_position[0] == 0:
            # If we're in cached decoding stage, pixel values should be None because input ids
            # do not contain special image tokens anymore.
            # Otherwise we need pixel values to be passed to the model.
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_grid_thw"] = image_grid_thw
        else:
            model_inputs['position_ids'] = position_ids[:, -1, ...].unsqueeze(1).to(device=input_ids.device) + 1

        return model_inputs
    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            num_new_tokens=num_new_tokens,
        )
        if getattr(outputs, "position_ids", None) is not None:
            model_kwargs["position_ids"] = outputs.position_ids
        return model_kwargs
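

# A minimal usage sketch (illustrative only; "path/to/tarsier-checkpoint" is a placeholder,
# and input_ids / attention_mask / pixel_values / num_images are assumed to come from the
# accompanying processor and collator, which are not defined in this file):
#
#     model = TarsierForConditionalGeneration.from_pretrained(
#         "path/to/tarsier-checkpoint", torch_dtype=torch.bfloat16, device_map="cuda"
#     )
#     generated = model.generate(
#         input_ids=input_ids,
#         attention_mask=attention_mask,
#         pixel_values=pixel_values,
#         num_images=num_images,
#         max_new_tokens=128,
#     )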