Manli's picture
initial commit
0a0d29d
raw
history blame
4.65 kB
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel
import torch
import open_clip
from typing import List, Optional, Tuple, Union
from utils import check_embedding_fns
from vlm import PerceiverResampler, XGenMMPerceiver
from configuration_xgenmm import XGenMMVisionEncoderConfig, XGenMMVisionTokenizerConfig, XGenMMConfig
class XGenMMVisionEncoder(PreTrainedModel):
    """Vision encoder wrapper that loads a pretrained image backbone.

    The backbone is loaded with ``AutoModel.from_pretrained`` from
    ``config.model_name``. Only ``google/siglip-so400m-patch14-384`` is
    accepted; any other name raises ``ValueError``.
    """

    main_input_name = "pixel_values"
    config_class = XGenMMVisionEncoderConfig

    def __init__(self, config: XGenMMVisionEncoderConfig):
        """Validate ``config.model_name`` and load the pretrained backbone.

        Raises:
            ValueError: if ``config.model_name`` is not the supported SigLIP
                checkpoint.
        """
        super().__init__(config)
        if config.model_name != 'google/siglip-so400m-patch14-384':
            raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")
        self.model = AutoModel.from_pretrained(config.model_name)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode ``pixel_values`` into image features.

        The input shape is not validated here.
        NOTE(review): presumably (bs, c, h, w) -- confirm against callers.

        Returns:
            Whatever the backbone's image-feature method returns.
            NOTE(review): expected to be a tensor of pooled image features;
            confirm for the installed ``transformers`` version.
        """
        # ``encode_image`` is the open_clip-style entry point. Models loaded
        # through ``transformers.AutoModel`` (such as SigLIP) expose
        # ``get_image_features`` instead, so calling ``encode_image``
        # unconditionally fails with AttributeError. Prefer ``encode_image``
        # when present to stay compatible with backbones that provide it.
        if hasattr(self.model, "encode_image"):
            return self.model.encode_image(pixel_values)
        return self.model.get_image_features(pixel_values=pixel_values)
# vision tokenizer
class XGenMMVisionTokenizer(PreTrainedModel):
    """Vision tokenizer: a ``PerceiverResampler`` configured from the config.

    The resampler is stored on ``self.model`` and is built from three config
    fields: ``vis_feature_dim``, ``lang_embedding_dim`` and
    ``num_vis_tokens``.
    """

    config_class = XGenMMVisionTokenizerConfig

    def __init__(self, config: XGenMMVisionTokenizerConfig):
        """Build the underlying ``PerceiverResampler`` from ``config``."""
        super().__init__(config)
        resampler_kwargs = {
            "dim": config.vis_feature_dim,
            "dim_inner": config.lang_embedding_dim,
            "num_latents": config.num_vis_tokens,
        }
        self.model = PerceiverResampler(**resampler_kwargs)

    def forward(
        self,
        vision_features: torch.Tensor,
        vision_attn_masks: torch.Tensor,
    ):
        """Run the resampler on the vision features and their attention masks.

        Both arguments are forwarded positionally, in this order, to the
        wrapped ``PerceiverResampler``; its result is returned unchanged.
        """
        resampled = self.model(vision_features, vision_attn_masks)
        return resampled
# XGenMM model
class XGenMMModelForConditionalGeneration(PreTrainedModel):
    """XGenMM vision-language model wrapped as a ``PreTrainedModel``.

    Assembles a vision encoder, a vision tokenizer (perceiver resampler) and
    a causal language model into a single ``XGenMMPerceiver`` stored on
    ``self.vlm``.
    """

    config_class = XGenMMConfig

    def __init__(self, config: XGenMMConfig):
        """Build the sub-models described by ``config`` and combine them.

        Note that ``config.vision_tokenizer_config.lang_embedding_dim`` is
        overwritten in place (with a printed warning) when it disagrees with
        the language model's input-embedding width.
        """
        super().__init__(config)

        encoder_cfg = config.vision_encoder_config
        tokenizer_cfg = config.vision_tokenizer_config
        text_cfg = config.text_config

        # Vision encoder: only the ``vision_model`` sub-module of the
        # pretrained checkpoint is kept.
        pretrained_vision = AutoModel.from_pretrained(encoder_cfg.model_name)
        vision_encoder = pretrained_vision.vision_model

        # Language model: instantiated from the text config (from_config,
        # not from_pretrained).
        language_model = AutoModelForCausalLM.from_config(text_cfg)
        check_embedding_fns(language_model)

        # Re-prefix the base model's tied-weight keys so they refer to its
        # location inside this wrapper.
        tied_keys = language_model._tied_weights_keys
        if tied_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in tied_keys]

        # Vision tokenizer: its inner dimension must match the language
        # model's embedding width, so the config is corrected when it differs.
        lm_embed_dim = language_model.get_input_embeddings().weight.shape[1]
        if tokenizer_cfg.lang_embedding_dim != lm_embed_dim:
            overwrite = lm_embed_dim
            tokenizer_cfg.lang_embedding_dim = overwrite
            print(f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.")
        vision_tokenizer = XGenMMVisionTokenizer(tokenizer_cfg).model

        self.vlm = XGenMMPerceiver(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=language_model,
            initial_tokenizer_len=text_cfg.initial_tokenizer_len,
            pad_token_id=text_cfg.pad_token_id,
            image_aspect_ratio=encoder_cfg.image_aspect_ratio,
            anyres_patch_sampling=encoder_cfg.anyres_patch_sampling,
            anyres_grids=encoder_cfg.anyres_grids,
        )

        # Initialize weights and apply final processing
        self.post_init()

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate with the wrapped model in eval mode, without gradients.

        ``pixel_values`` and ``input_ids`` are handed to ``self.vlm.generate``
        as ``vision_x`` and ``lang_x``; ``attention_mask`` and any extra
        keyword arguments are forwarded unchanged. The result of that call is
        returned as-is.
        """
        self.vlm = self.vlm.eval()
        outputs = self.vlm.generate(
            vision_x=pixel_values,
            lang_x=input_ids,
            attention_mask=attention_mask,
            **generate_kwargs,
        )
        return outputs

    def update_special_tokens(self, tokenizer):
        """Register the model's special tokens with ``tokenizer``.

        Adds every value of ``self.vlm.special_tokens`` as an additional
        special token, sets the language model config's ``vocab_size`` to the
        new tokenizer length, and passes the token -> id mapping to
        ``self.vlm.set_special_token_ids``.

        Returns:
            The same ``tokenizer`` object that was passed in.
        """
        extra_tokens = list(self.vlm.special_tokens.values())
        tokenizer.add_special_tokens({"additional_special_tokens": extra_tokens})

        self.vlm.lang_model.config.vocab_size = len(tokenizer)

        token_to_id = {}
        for token in self.vlm.special_tokens.values():
            token_to_id[token] = tokenizer.convert_tokens_to_ids(token)
        self.vlm.set_special_token_ids(token_to_id)

        return tokenizer