from collections import Counter, defaultdict
from typing import List, Optional

import torch
from torch import nn

from transformers import CONFIG_MAPPING, AutoModel, AutoModelForCausalLM
from transformers.activations import gelu
from transformers.modeling_utils import PreTrainedModel, PretrainedConfig


def most_frequent_element(tensor):
    """Return the number of occurrences of the most frequent element.

    Note: this returns the *count* (e.g. the maximum number of images attached
    to any single sample in the batch), not the element itself.
    """
    flattened_list = tensor.flatten().tolist()
    counter = Counter(flattened_list)
    most_common_count = counter.most_common(1)[0][1]
    return most_common_count


class MM_LLMs_Config(PretrainedConfig):
    model_type = 'mm_llms'
    is_composition = True

    def __init__(self, image_config=None, llm_config=None, vision_select_layer=None, **kwargs):
        self.image_config = image_config
        self.llm_config = llm_config
        self.vision_select_layer = vision_select_layer

        # Sub-configs may arrive as plain dicts (e.g. when loading from JSON);
        # rebuild them into their concrete config classes.
        if isinstance(self.image_config, dict):
            image_config["model_type"] = image_config.get("model_type", "clip")
            self.image_config = CONFIG_MAPPING[image_config["model_type"]](**image_config)
        if isinstance(self.llm_config, dict):
            llm_config["model_type"] = llm_config.get("model_type", "llama")
            self.llm_config = CONFIG_MAPPING[llm_config["model_type"]](**llm_config)

        super().__init__(**kwargs)


class LlavaMultiModalProjector(nn.Module):
    def __init__(self, in_hidden_size, out_hidden_size, conv_kernel=None, conv_stride=3):
        super().__init__()
        self.conv_kernel = conv_kernel
        if conv_kernel:
            # A strided Conv1d doubles as a projection that also downsamples
            # the vision sequence; it expects channels-first input,
            # i.e. (batch, in_hidden_size, seq_len).
            self.linear_1 = nn.Conv1d(
                in_hidden_size, out_hidden_size,
                kernel_size=conv_kernel, stride=conv_stride,
            )
        else:
            self.linear_1 = nn.Linear(in_hidden_size, out_hidden_size, bias=True)
        self.act = gelu
        self.linear_2 = nn.Linear(out_hidden_size, out_hidden_size, bias=True)

    def forward(self, image_features):
        # Linear path expects (batch, seq_len, in_hidden_size);
        # Conv path expects (batch, in_hidden_size, seq_len).
        hidden_states = self.linear_1(image_features)
        if self.conv_kernel:
            # (batch, out_hidden_size, new_seq_len) -> (batch, new_seq_len, out_hidden_size)
            hidden_states = hidden_states.transpose(1, 2).contiguous()
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
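
# A minimal shape-check sketch for LlavaMultiModalProjector (illustrative only;
# the 1024 -> 4096 widths below are assumed CLIP-ViT-L / LLaMA-7B style sizes,
# not values taken from any particular config in this file).
def _projector_shape_example():
    proj = LlavaMultiModalProjector(in_hidden_size=1024, out_hidden_size=4096)
    dummy_features = torch.randn(2, 257, 1024)  # (batch, vision_seq_len, vision_hidden)
    out = proj(dummy_features)
    # The Linear path keeps the sequence length and only changes the width.
    assert out.shape == (2, 257, 4096)
    return out.shape
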
class MM_LLMs(PreTrainedModel):
    config_class = MM_LLMs_Config
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True

    def __init__(self, config, flash_attention=False, dtype=torch.float32):
        super().__init__(config)
        self.config = config
        self.image_encoder = AutoModel.from_config(config.image_config)
        self.llm = AutoModelForCausalLM.from_config(
            config.llm_config,
            use_flash_attention_2=flash_attention,
            torch_dtype=dtype,
        )
        self.image_projector = LlavaMultiModalProjector(
            config.image_config.vision_config.hidden_size,
            config.llm_config.hidden_size,
        )

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            image_index: torch.LongTensor = None,
            audio_index: torch.LongTensor = None,
            image_starts: Optional[torch.LongTensor] = None,
            image_ends: Optional[torch.LongTensor] = None,
            audio_starts: Optional[torch.LongTensor] = None,
            audio_ends: Optional[torch.LongTensor] = None,
            images: Optional[torch.FloatTensor] = None,
            audios: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            use_cache: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **kwargs):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Cast pixel values to the vision encoder's dtype. Audio arguments are
        # accepted for interface compatibility, but this model defines no audio
        # encoder, so they are passed through untouched.
        images = images.type(self.image_encoder.dtype) if images is not None else None

        model_inputs = self.prepare_inputs_for_generation(
            input_ids=input_ids,
            image_index=image_index,
            audio_index=audio_index,
            image_starts=image_starts,
            image_ends=image_ends,
            audio_starts=audio_starts,
            audio_ends=audio_ends,
            images=images,
            audios=audios,
            attention_mask=attention_mask,
            labels=labels)

        outputs = self.llm(
            inputs_embeds=model_inputs['inputs_embeds'],
            attention_mask=model_inputs['attention_mask'],
            labels=model_inputs['labels'],
            return_dict=return_dict)

        return outputs

    def prepare_inputs_for_generation(
            self,
            input_ids,
            past_key_values=None,
            inputs_embeds=None,
            images=None,
            audios=None,
            audio_starts=None,
            audio_ends=None,
            image_starts=None,
            image_ends=None,
            attention_mask=None,
            labels=None,
            audio_index=None,
            image_index=None,
            **kwargs):
        image_features = self.encode_image(images) if images is not None else None

        embed_tokens = self.llm.model.embed_tokens
        text_embeddings = embed_tokens(input_ids)

        # Text-only batch: nothing to splice in.
        if image_features is None:
            return {
                "input_ids": input_ids,
                "inputs_embeds": text_embeddings,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "labels": labels,
            }

        batch_size, seq_len, embed_dim = text_embeddings.shape

        # Reserve enough room for the sample with the most images:
        # most_frequent_element returns the highest per-sample image count.
        max_count_image = most_frequent_element(image_index)
        seq_image = image_features.shape[1]
        new_len = seq_len + seq_image * max_count_image

        final_embedding = torch.zeros(
            batch_size, new_len, embed_dim,
            device=text_embeddings.device,
            dtype=text_embeddings.dtype,
        )
        final_embedding[:, :seq_len] = text_embeddings

        final_attention_mask = torch.zeros(
            batch_size, new_len,
            device=attention_mask.device,
            dtype=attention_mask.dtype,
        )
        final_attention_mask[:, :seq_len] = attention_mask

        if labels is not None:
            final_labels = torch.full(
                (batch_size, new_len), -100,
                device=labels.device, dtype=labels.dtype,
            )
            final_labels[:, :seq_len] = labels
        else:
            final_labels = None

        # Splice each image's features in right after its image-start token.
        image_id = int(image_starts[0])
        where_is = torch.where(input_ids == image_id)
        positions = defaultdict(int)  # per-sample offset added by earlier insertions
        b_image = 0

        for i in range(len(where_is[0])):
            b, k = where_is[0][i], where_is[1][i]
            int_b = int(b)
            int_k = int(k)
            f = image_features[b_image]
            b_image += 1

            c = torch.cat([
                final_embedding[b, :int_k + 1 + positions[int_b]],
                f,
                text_embeddings[b, int_k + 1:],
            ])
            final_embedding[b, :len(c)] = c
            final_attention_mask[b, :len(c)] = 1.0
            if labels is not None:
                # Image positions never contribute to the loss.
                ignore = torch.full((len(f),), -100, device=labels.device, dtype=labels.dtype)
                c_label = torch.cat([
                    final_labels[b, :int_k + 1 + positions[int_b]],
                    ignore,
                    labels[b, int_k + 1:],
                ])
                final_labels[b, :len(c)] = c_label
            positions[int_b] += len(f)

        model_inputs = {
            "input_ids": input_ids,
            "inputs_embeds": final_embedding,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": final_attention_mask,
            "labels": final_labels,
        }
        return model_inputs

    def encode_image(self, images):
        if self.config.vision_select_layer is not None:
            encoded = self.image_encoder.vision_model(images, output_hidden_states=True)
            encoded = encoded.hidden_states[self.config.vision_select_layer]
        else:
            encoded = self.image_encoder.vision_model(images)[0]
        image_features = self.image_projector(encoded)
        return image_features
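
# Usage sketch (an untested illustration, not part of the model code): build a
# MM_LLMs_Config from a default CLIP image config and a deliberately tiny
# LLaMA config, then run one multimodal forward pass with random weights and a
# dummy image. The image token id, tensor shapes, and config sizes below are
# assumptions chosen only to keep the example small; real use would load
# pretrained weights via from_pretrained and use the tokenizer's actual
# image-start token id.
if __name__ == "__main__":
    from transformers import CLIPConfig, LlamaConfig

    llm_config = LlamaConfig(
        hidden_size=512,
        intermediate_size=1024,
        num_hidden_layers=2,
        num_attention_heads=8,
        vocab_size=32000,
    )
    config = MM_LLMs_Config(
        image_config=CLIPConfig(),
        llm_config=llm_config,
        vision_select_layer=-2,
    )
    model = MM_LLMs(config)

    image_token_id = 31999  # assumed placeholder id for the image-start token
    input_ids = torch.tensor([[1, 5, image_token_id, 7, 2]])
    attention_mask = torch.ones_like(input_ids)
    images = torch.randn(1, 3, 224, 224)     # one dummy 224x224 RGB image
    image_index = torch.tensor([0])          # image 0 belongs to sample 0
    image_starts = torch.tensor([image_token_id])

    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        images=images,
        image_index=image_index,
        image_starts=image_starts,
        labels=input_ids.clone(),
    )
    print("loss:", outputs.loss)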