import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    AutoModelForCausalLM,
    AutoModel,
    SiglipImageProcessor,
)

from .configuration_llamavision import LlamavisionConfig


class ProjectionModule(nn.Module):
    def __init__(self, mm_hidden_size=1152, hidden_size=4096):
        super().__init__()

        # Two-layer MLP that maps vision features into the language model's
        # embedding space
        self.model = nn.Sequential(
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        return self.model(x)


class Llamavision(PreTrainedModel):
    config_class = LlamavisionConfig

    def __init__(self, config):
        super().__init__(config)
        self.text_model = AutoModelForCausalLM.from_config(config.text_config)
        self.vision_model = AutoModel.from_config(config.vision_config)
        self.processor = SiglipImageProcessor()
        self.mm_projector = ProjectionModule()

    @property
    def device(self):
        return self.text_model.device

    def tokenizer_image_token(
        self, prompt, tokenizer, image_token_index=-200, return_tensors=None
    ):
        # Tokenize the text around the "<image>" placeholder and splice the
        # sentinel image_token_index (-200) in between the chunks.
        prompt_chunks = [
            tokenizer(chunk).input_ids for chunk in prompt.split("<image>")
        ]

        def insert_separator(X, sep):
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        input_ids = []
        offset = 0
        if (
            len(prompt_chunks) > 0
            and len(prompt_chunks[0]) > 0
            and prompt_chunks[0][0] == tokenizer.bos_token_id
        ):
            offset = 1
            input_ids.append(prompt_chunks[0][0])

        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])

        return torch.tensor(input_ids, dtype=torch.long)

    def process_tensors(self, input_ids, image_features, embedding_layer):
        # Find the position of the image placeholder id (-200) in input_ids
        split_index = (input_ids == -200).nonzero(as_tuple=True)[1][0]

        # Split input_ids at that position, excluding the placeholder itself
        input_ids_1 = input_ids[:, :split_index]
        input_ids_2 = input_ids[:, split_index + 1 :]

        # Convert the text token ids to embeddings
        embeddings_1 = embedding_layer(input_ids_1)
        embeddings_2 = embedding_layer(input_ids_2)

        device = image_features.device
        token_embeddings_part1 = embeddings_1.to(device)
        token_embeddings_part2 = embeddings_2.to(device)

        # Concatenate the text embeddings around the projected image features
        concatenated_embeddings = torch.cat(
            [token_embeddings_part1, image_features, token_embeddings_part2], dim=1
        )

        # Attention mask covering the full concatenated sequence
        attention_mask = torch.ones(
            concatenated_embeddings.shape[:2], dtype=torch.long, device=device
        )
        return concatenated_embeddings, attention_mask

    def answer_question(self, image, question, tokenizer, **kwargs):
        # Prepend the image placeholder so tokenizer_image_token knows where
        # to splice in the image features
        question = "<image>" + question

        prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        input_ids = (
            self.tokenizer_image_token(prompt, tokenizer, -200, return_tensors="pt")
            .unsqueeze(0)
            .to(self.device)
        )

        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

        with torch.inference_mode():
            image_inputs = self.processor(
                images=[image],
                return_tensors="pt",
                do_resize=True,
                size={"height": 384, "width": 384},
            )

            image_inputs = image_inputs["pixel_values"].to(
                device=self.device, dtype=self.dtype
            )

            image_forward_outs = self.vision_model(
                image_inputs,
                output_hidden_states=True,
            )

            # Use the penultimate hidden layer of the vision tower as features
            image_features = image_forward_outs.hidden_states[-2]

            projected_embeddings = self.mm_projector(image_features).to(self.device)

            embedding_layer = self.text_model.get_input_embeddings()
            # text_embeddings = embedding_layer(input_ids)

            new_embeds, attn_mask = self.process_tensors(
                input_ids, projected_embeddings, embedding_layer
            )
            attn_mask = attn_mask.to(self.device)
            new_embeds = new_embeds.to(self.device)

        answer = self.text_model.generate(
            inputs_embeds=new_embeds,
            attention_mask=attn_mask,
            eos_token_id=terminators,
            temperature=0.2,
            do_sample=True,
            **kwargs,
        )[0]

        return answer
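

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# driving Llamavision.answer_question. The checkpoint id, image path, and
# generation arguments below are placeholders/assumptions, not values taken
# from this repository; substitute the real checkpoint and files.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoTokenizer

    checkpoint = "your-org/llamavision-checkpoint"  # placeholder repo id
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = Llamavision.from_pretrained(checkpoint, torch_dtype=torch.float16).to("cuda")

    image = Image.open("example.jpg")  # placeholder image path
    output_ids = model.answer_question(
        image, "Describe this image in one sentence.", tokenizer, max_new_tokens=128
    )
    print(tokenizer.decode(output_ids, skip_special_tokens=True))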