# xgen-mm-phi3-mini-instruct-r-v1 / modeling_xgenmm.py
# UncleFish's picture
# update inference code to support transformers==4.41.1
# 854508f
import warnings
from typing import List, Optional, Tuple, Union

import open_clip
import torch
from transformers import PreTrainedModel, AutoModelForCausalLM

from .utils import check_embedding_fns
from .vlm import InstructPerceiverResampler, KosmosInstruct
from .configuration_xgenmm import XGenMMVisionEncoderConfig, XGenMMVisionTokenizerConfig, XGenMMConfig
class XGenMMVisionEncoder(PreTrainedModel):
    """`PreTrainedModel` wrapper around an OpenCLIP model used as image encoder.

    Only the 'ViT-H-14-378-quickgelu' architecture is accepted; any other
    ``config.model_name`` is rejected when the encoder is constructed.
    """

    main_input_name = "pixel_values"
    config_class = XGenMMVisionEncoderConfig

    def __init__(self, config: XGenMMVisionEncoderConfig):
        """Create the OpenCLIP model described by ``config``.

        Args:
            config: supplies ``model_name`` and ``force_image_size``, both of
                which are forwarded to ``open_clip.create_model_and_transforms``.

        Raises:
            ValueError: if ``config.model_name`` is not the supported model.
        """
        super().__init__(config)

        # Guard clause: a single architecture is supported for now.
        is_supported = config.model_name == 'ViT-H-14-378-quickgelu'
        if not is_supported:
            raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")

        created = open_clip.create_model_and_transforms(
            model_name=config.model_name,
            force_image_size=config.force_image_size,
        )
        # The factory yields a 3-tuple; only its first element (the model)
        # is kept, the remaining two are discarded.
        self.model, _, _ = created

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode ``pixel_values`` with the wrapped model's ``encode_image``.

        NOTE(review): the input is presumably a 4D (bs, c, h, w) tensor, but
        no shape check is enforced here -- confirm against callers.
        """
        image_features = self.model.encode_image(pixel_values)
        return image_features
# vision tokenizer
class XGenMMVisionTokenizer(PreTrainedModel):
    """`PreTrainedModel` wrapper around an `InstructPerceiverResampler`."""

    config_class = XGenMMVisionTokenizerConfig

    def __init__(self, config: XGenMMVisionTokenizerConfig):
        """Instantiate the resampler from the tokenizer config.

        ``config.lang_embedding_dim`` is used for both ``dim_llm`` and
        ``dim_inner``; ``config.vis_feature_dim``, ``config.num_vis_tokens``
        and ``config.repeat_latents`` map to ``dim``, ``num_latents`` and
        ``repeat_latents`` respectively.
        """
        super().__init__(config)
        resampler_kwargs = dict(
            dim_llm=config.lang_embedding_dim,
            dim=config.vis_feature_dim,
            dim_inner=config.lang_embedding_dim,
            num_latents=config.num_vis_tokens,
            repeat_latents=config.repeat_latents,
        )
        self.model = InstructPerceiverResampler(**resampler_kwargs)

    def forward(
        self,
        vision_features: torch.Tensor,
        vision_attn_masks: torch.Tensor,
    ):
        """Pass the features and their attention masks to the resampler.

        Returns whatever the wrapped ``InstructPerceiverResampler`` returns.
        """
        resampled = self.model(vision_features, vision_attn_masks)
        return resampled
# XGenMM model
class XGenMMModelForConditionalGeneration(PreTrainedModel):
    """Vision-language model assembling a vision encoder, a vision tokenizer
    and a causal language model into a single `KosmosInstruct` module.

    The assembled module is stored as ``self.vlm``.
    """

    config_class = XGenMMConfig

    def __init__(self, config: XGenMMConfig):
        """Build all sub-models from ``config`` and wire them into ``self.vlm``.

        Args:
            config: composite config providing ``vision_encoder_config``,
                ``vision_tokenizer_config`` and ``text_config``.

        Side effects:
            If the vision tokenizer's ``lang_embedding_dim`` differs from the
            language model's input-embedding width, the config value is
            overwritten in place and a ``UserWarning`` is issued.
        """
        super().__init__(config)

        # Vision encoder: keep only the visual tower of the OpenCLIP model and
        # turn on its `output_tokens` flag.
        vision_encoder = XGenMMVisionEncoder(config.vision_encoder_config).model
        vision_encoder.visual.output_tokens = True
        vision_encoder = vision_encoder.visual

        # Language model, created from config.
        language_model = AutoModelForCausalLM.from_config(config.text_config)
        check_embedding_fns(language_model)

        # Mirror the base model's tied-weight keys under the attribute path
        # prefix "language_model.".
        if language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [
                f"language_model.{k}" for k in language_model._tied_weights_keys
            ]

        # Vision tokenizer: its language embedding width must match the
        # language model's input embeddings, so the config is corrected when
        # the two disagree.
        tokenizer_config = config.vision_tokenizer_config
        lang_embed_dim = language_model.get_input_embeddings().weight.shape[1]
        if tokenizer_config.lang_embedding_dim != lang_embed_dim:
            overwrite = lang_embed_dim
            tokenizer_config.lang_embedding_dim = overwrite
            # Emitted through the `warnings` machinery (not stdout) so callers
            # can filter, capture or silence it.
            warnings.warn(
                f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.",
                stacklevel=2,
            )
        vision_tokenizer = XGenMMVisionTokenizer(tokenizer_config).model

        self.vlm = KosmosInstruct(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=language_model,
            initial_tokenizer_len=config.text_config.initial_tokenizer_len,
            pad_token_id=config.text_config.pad_token_id,
            image_aspect_ratio=config.vision_encoder_config.image_aspect_ratio,
            anyres_patch_sampling=config.vision_encoder_config.anyres_patch_sampling,
        )

        # Initialize weights and apply final processing
        self.post_init()

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate with the wrapped VLM, with gradients disabled.

        Args:
            pixel_values: image input, forwarded as ``vision_x``.
            input_ids: prompt token ids, forwarded as ``lang_x``.
            attention_mask: forwarded unchanged.
            **generate_kwargs: extra keyword arguments passed through to
                ``self.vlm.generate``.

        Returns:
            The result of ``self.vlm.generate``.
        """
        # Switch to eval mode in place; re-assigning the submodule on every
        # call is unnecessary (assumes `vlm` is an nn.Module, whose `eval()`
        # returns the module itself).
        self.vlm.eval()
        return self.vlm.generate(
            vision_x=pixel_values,
            lang_x=input_ids,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

    def update_special_tokens(self, tokenizer):
        """Register the VLM's special tokens with ``tokenizer`` and sync ids.

        The tokens in ``self.vlm.special_tokens`` are added to the tokenizer
        as additional special tokens, the language model config's
        ``vocab_size`` is set to ``len(tokenizer)``, and the resulting
        token -> id mapping is handed to ``self.vlm.set_special_token_ids``.

        NOTE(review): only ``config.vocab_size`` is updated here; whether the
        embedding matrices are resized elsewhere is not visible in this file.

        Returns:
            The same ``tokenizer`` object, after modification.
        """
        tokenizer.add_special_tokens(
            {"additional_special_tokens": list(self.vlm.special_tokens.values())}
        )
        self.vlm.lang_model.config.vocab_size = len(tokenizer)
        self.vlm.set_special_token_ids(
            {
                v: tokenizer.convert_tokens_to_ids(v)
                for v in self.vlm.special_tokens.values()
            }
        )
        return tokenizer