|
from transformers import PreTrainedModel, AutoModelForCausalLM |
|
import torch |
|
import open_clip |
|
from typing import List, Optional, Tuple, Union |
|
from .utils import check_embedding_fns |
|
from .vlm import InstructPerceiverResampler, KosmosInstruct |
|
from .configuration_xgenmm import XGenMMVisionEncoderConfig, XGenMMVisionTokenizerConfig, XGenMMConfig |
|
|
|
class XGenMMVisionEncoder(PreTrainedModel):
    """`PreTrainedModel` wrapper around an OpenCLIP model used as the image encoder.

    The constructor accepts exactly one architecture name,
    ``ViT-H-14-378-quickgelu``; any other ``config.model_name`` is rejected
    with a ``ValueError``.
    """

    config_class = XGenMMVisionEncoderConfig
    main_input_name = "pixel_values"

    def __init__(self, config: XGenMMVisionEncoderConfig):
        """Build the OpenCLIP model described by ``config``.

        Args:
            config: Supplies ``model_name`` and ``force_image_size``.

        Raises:
            ValueError: If ``config.model_name`` is not the supported model.
        """
        super().__init__(config)

        requested_name = config.model_name
        if requested_name != 'ViT-H-14-378-quickgelu':
            raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")

        # The factory returns a 3-tuple; only the first element (the model)
        # is kept, the other two are discarded.
        clip_model, _unused_a, _unused_b = open_clip.create_model_and_transforms(
            model_name=requested_name,
            force_image_size=config.force_image_size,
        )
        self.model = clip_model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return the result of the wrapped model's ``encode_image`` on ``pixel_values``."""
        image_features = self.model.encode_image(pixel_values)
        return image_features
|
|
|
|
|
|
|
class XGenMMVisionTokenizer(PreTrainedModel):
    """`PreTrainedModel` wrapper around an ``InstructPerceiverResampler``.

    The resampler is stored on ``self.model`` and ``forward`` delegates to it.
    """

    config_class = XGenMMVisionTokenizerConfig

    def __init__(self, config: XGenMMVisionTokenizerConfig):
        """Instantiate the resampler from the fields of ``config``.

        Args:
            config: Supplies ``lang_embedding_dim`` (used for both ``dim_llm``
                and ``dim_inner``), ``vis_feature_dim``, ``num_vis_tokens``
                and ``repeat_latents``.
        """
        super().__init__(config)

        resampler_kwargs = {
            "dim_llm": config.lang_embedding_dim,
            "dim": config.vis_feature_dim,
            "dim_inner": config.lang_embedding_dim,
            "num_latents": config.num_vis_tokens,
            "repeat_latents": config.repeat_latents,
        }
        self.model = InstructPerceiverResampler(**resampler_kwargs)

    def forward(self,
                vision_features: torch.Tensor,
                vision_attn_masks: torch.Tensor):
        """Pass the vision features and their attention masks to the resampler."""
        resampled = self.model(vision_features, vision_attn_masks)
        return resampled
|
|
|
|
|
class XGenMMModelForConditionalGeneration(PreTrainedModel):
    """Top-level model that assembles a vision encoder, a vision tokenizer and
    a causal language model into a single ``KosmosInstruct`` module
    (``self.vlm``).
    """

    config_class = XGenMMConfig

    def __init__(self, config: XGenMMConfig):
        """Construct all sub-modules from ``config`` and wire them together.

        Args:
            config: Composite config exposing ``vision_encoder_config``,
                ``vision_tokenizer_config`` and ``text_config``.
        """
        super().__init__(config)

        # --- vision encoder -------------------------------------------------
        # Only the `.visual` sub-module of the OpenCLIP model is kept, with
        # its `output_tokens` flag switched on.
        clip_model = XGenMMVisionEncoder(config.vision_encoder_config).model
        vision_encoder = clip_model.visual
        vision_encoder.output_tokens = True

        # --- language model -------------------------------------------------
        language_model = AutoModelForCausalLM.from_config(config.text_config)
        check_embedding_fns(language_model)

        # Re-export the language model's tied-weight keys under the
        # "language_model." prefix.
        tied_keys = language_model._tied_weights_keys
        if tied_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in tied_keys]

        # --- vision tokenizer -----------------------------------------------
        # The tokenizer's language embedding dimension must match the width of
        # the language model's input embeddings; the config is corrected in
        # place when it does not.
        tokenizer_config = config.vision_tokenizer_config
        llm_embedding_dim = language_model.get_input_embeddings().weight.shape[1]
        if tokenizer_config.lang_embedding_dim != llm_embedding_dim:
            overwrite = llm_embedding_dim
            tokenizer_config.lang_embedding_dim = overwrite
            print(f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.")

        vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model

        # --- combined model -------------------------------------------------
        text_config = config.text_config
        encoder_config = config.vision_encoder_config
        self.vlm = KosmosInstruct(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=language_model,
            initial_tokenizer_len=text_config.initial_tokenizer_len,
            pad_token_id=text_config.pad_token_id,
            image_aspect_ratio=encoder_config.image_aspect_ratio,
            anyres_patch_sampling=encoder_config.anyres_patch_sampling,
        )

        self.post_init()

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """Switch the wrapped model to eval mode and delegate to its ``generate``.

        Args:
            pixel_values: Forwarded as ``vision_x``.
            input_ids: Forwarded as ``lang_x``.
            attention_mask: Forwarded unchanged.
            **generate_kwargs: Extra keyword arguments passed straight through.

        Returns:
            Whatever ``self.vlm.generate`` returns.
        """
        self.vlm = self.vlm.eval()
        generated = self.vlm.generate(
            vision_x=pixel_values,
            lang_x=input_ids,
            attention_mask=attention_mask,
            **generate_kwargs,
        )
        return generated

    def update_special_tokens(self, tokenizer):
        """Register the model's special tokens with ``tokenizer`` and sync ids.

        Adds the values of ``self.vlm.special_tokens`` to the tokenizer as
        additional special tokens, sets the language model config's
        ``vocab_size`` to ``len(tokenizer)``, and hands the resulting
        token -> id mapping to ``self.vlm.set_special_token_ids``.

        Returns:
            The same ``tokenizer`` object that was passed in.
        """
        extra_tokens = list(self.vlm.special_tokens.values())
        tokenizer.add_special_tokens({"additional_special_tokens": extra_tokens})

        self.vlm.lang_model.config.vocab_size = len(tokenizer)

        token_ids = {}
        for token in self.vlm.special_tokens.values():
            token_ids[token] = tokenizer.convert_tokens_to_ids(token)
        self.vlm.set_special_token_ids(token_ids)

        return tokenizer
|
|
|
|