import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    AutoModelForCausalLM,
    AutoModel,
    SiglipImageProcessor,
)

from .configuration_doubutsu_next import DoubutsuNextConfig
from .utils import slice_anyres_image


class ProjectionModule(nn.Module):
    """Two-layer MLP that projects vision features into the text embedding space."""

    def __init__(self, mm_hidden_size=1152, hidden_size=1536):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        return self.model(x)


class DoubutsuNext(PreTrainedModel):
    config_class = DoubutsuNextConfig

    def __init__(self, config):
        super().__init__(config)
        self.vision_model = AutoModel.from_config(self.config.vision_config)
        self.text_model = AutoModelForCausalLM.from_config(self.config.text_config)
        self.processor = SiglipImageProcessor()
        self.mm_projector = ProjectionModule(
            mm_hidden_size=config.vision_config.hidden_size,
            hidden_size=config.text_config.hidden_size,
        )

    @property
    def device(self):
        return self.text_model.device

    def encode_image(self, image):
        """Slice the image into anyres patches and encode each with the vision tower."""
        image_patches = slice_anyres_image(image)
        encoded_patches = []
        for patch in image_patches:
            patch = patch.convert("RGB")
            processed_patch = self.processor(
                images=patch,
                return_tensors="pt",
                do_resize=True,
                size={"height": 378, "width": 378},
            )["pixel_values"].to(
                device=self.vision_model.device, dtype=self.vision_model.dtype
            )
            with torch.no_grad():
                # Use the penultimate layer's hidden states as patch features
                encoded_patch = self.vision_model(
                    processed_patch, output_hidden_states=True
                ).hidden_states[-2]
            encoded_patches.append(encoded_patch)

        # Concatenate patch features along the sequence dimension
        return torch.cat(encoded_patches, dim=1)

    def input_embeds(self, prompt, image_embeds, tokenizer):
        """Build input embeddings laid out as [BOS] + projected image tokens + prompt tokens."""

        def _tokenize(txt):
            return tokenizer(
                txt, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(self.device)

        text_emb = self.text_model.get_input_embeddings()
        embeds = []

        tokenized_prompt = _tokenize(prompt)

        # Ensure exactly one BOS token at the start of the sequence
        if tokenizer.bos_token_id is not None:
            if tokenized_prompt[0][0] == tokenizer.bos_token_id:
                tokenized_prompt = tokenized_prompt[:, 1:]  # Remove existing BOS
            embeds.append(
                text_emb(torch.tensor([[tokenizer.bos_token_id]], device=self.device))
            )

        # Add image embeds, projected into the text embedding space
        projected_image_embeds = self.mm_projector(image_embeds.to(self.device))
        embeds.append(projected_image_embeds)

        # Add text embeds
        embeds.append(text_emb(tokenized_prompt))

        return torch.cat(embeds, dim=1)

    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

    def generate(
        self,
        image_embeds,
        prompt,
        tokenizer,
        max_new_tokens=128,
        temperature=0.1,
        **kwargs,
    ):
        generate_config = {
            "eos_token_id": tokenizer.eos_token_id,
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.pad_token_id,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            **kwargs,
        }

        with torch.no_grad():
            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)

            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds,
                do_sample=True,
                **generate_config,
            )

        # When called with inputs_embeds, generate() returns only the new tokens
        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    def answer_question(self, image, question, tokenizer, **kwargs):
        image_embeds = self.encode_image(image)

        chat = [
            {
                "role": "system",
                "content": "You are a helpful AI assistant that can see images and answer questions about them.",
            },
            {"role": "user", "content": question},
        ]
        prompt = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        # Generate the answer
        with torch.no_grad():
            output = self.generate(
                image_embeds=image_embeds,
                prompt=prompt,
                tokenizer=tokenizer,
                **kwargs,
            )[0]

        # Clean and return the answer
        cleaned_answer = output.strip()
        return cleaned_answer
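

# Minimal usage sketch (illustrative only, not part of the model definition).
# The checkpoint id below is a placeholder assumption; substitute the repo id
# that actually ships this class, along with its compatible tokenizer.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoTokenizer

    repo_id = "qresearch/doubutsu-next"  # hypothetical checkpoint id
    model = DoubutsuNext.from_pretrained(repo_id).to("cuda", dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    image = Image.open("example.jpg")  # any local image; converted to RGB internally
    print(model.answer_question(image, "What is shown in this image?", tokenizer))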