|
from transformers import PretrainedConfig |
|
from transformers import logging |
|
from transformers import CONFIG_MAPPING |
|
|
|
# Module-level logger from transformers' logging utilities; XGenMMConfig
# uses it to report when a sub-configuration falls back to default values.
logger = logging.get_logger(__name__)
|
|
|
class XGenMMVisionEncoderConfig(PretrainedConfig):
    """Configuration for the XGen-MM vision encoder.

    Attributes:
        model_name: Identifier of the vision backbone
            (default ``'ViT-H-14-378-quickgelu'``).
        force_image_size: Image size stored for the encoder (default 378);
            presumably the square input resolution in pixels -- confirm
            against the modeling code.

    Any other keyword arguments are handed to ``PretrainedConfig``.
    """

    model_type = "xgenmm_vision_encoder"

    def __init__(self,
                 model_name: str = 'ViT-H-14-378-quickgelu',
                 force_image_size: int = 378,
                 **kwargs):
        # Store the encoder-specific fields first, then let the base class
        # consume whatever generic options remain in ``kwargs``.
        self.model_name, self.force_image_size = model_name, force_image_size
        super().__init__(**kwargs)
|
|
|
|
|
class XGenMMVisionTokenizerConfig(PretrainedConfig):
    """Configuration for the XGen-MM vision tokenizer.

    The values below are stored verbatim; how they are consumed is decided by
    the modeling code, which is not part of this file. Judging by the names
    (confirm against the model): ``vis_feature_dim`` is the width of incoming
    vision features, ``lang_embedding_dim`` the language-model embedding
    width, and ``num_vis_tokens`` the number of visual tokens produced.

    Defaults: ``vis_feature_dim=1280``, ``lang_embedding_dim=3072``,
    ``num_vis_tokens=128``, ``image_aspect_ratio='anyres'``,
    ``repeat_latents=False``. Remaining keyword arguments are forwarded to
    ``PretrainedConfig``.
    """

    model_type = "xgenmm_vision_tokenizer"

    def __init__(self,
                 vis_feature_dim: int = 1280,
                 lang_embedding_dim: int = 3072,
                 num_vis_tokens: int = 128,
                 image_aspect_ratio: str = 'anyres',
                 repeat_latents: bool = False,
                 **kwargs):
        # Assign the tokenizer-specific fields in declaration order before
        # the base class processes the generic options.
        own_fields = dict(
            vis_feature_dim=vis_feature_dim,
            lang_embedding_dim=lang_embedding_dim,
            num_vis_tokens=num_vis_tokens,
            image_aspect_ratio=image_aspect_ratio,
            repeat_latents=repeat_latents,
        )
        for field_name, field_value in own_fields.items():
            setattr(self, field_name, field_value)

        super().__init__(**kwargs)
|
|
|
|
|
class XGenMMConfig(PretrainedConfig):
    """Top-level configuration for the XGen-MM multimodal model.

    Holds three sub-configurations:

    * ``vision_encoder_config`` -- an :class:`XGenMMVisionEncoderConfig`.
    * ``vision_tokenizer_config`` -- an :class:`XGenMMVisionTokenizerConfig`.
    * ``text_config`` -- the language-model config class registered in
      ``CONFIG_MAPPING`` under ``text_config["model_type"]`` (``"phi3"`` when
      that key is absent).
    """

    model_type = "xgenmm"

    @staticmethod
    def _default_text_config() -> dict:
        """Return a fresh dict of default (Phi-3) text-config keyword arguments."""
        # NOTE(review): the original literal listed 'pad_token_id',
        # 'bos_token_id' and 'eos_token_id' twice. Python keeps the LAST value
        # of a duplicated key, so the effective pad_token_id was 32000, not the
        # 32011 written first. With initial_tokenizer_len = 32012, 32011 would
        # be the highest token id, so 32011 may have been the intent. The
        # effective value (32000) is kept to preserve behaviour -- confirm.
        return {
            'initial_tokenizer_len': 32012,
            'pad_token_id': 32000,
            'bos_token_id': 1,
            'eos_token_id': 32000,
            'vocab_size': 32064,
            'hidden_size': 3072,
            'intermediate_size': 8192,
            'num_hidden_layers': 32,
            'num_attention_heads': 32,
            'num_key_value_heads': 32,
            'resid_pdrop': 0.0,
            'embd_pdrop': 0.0,
            'attention_dropout': 0.0,
            'hidden_act': 'silu',
            'max_position_embeddings': 4096,
            'original_max_position_embeddings': 4096,
            'initializer_range': 0.02,
            'rms_norm_eps': 1e-05,
            'use_cache': True,
            'rope_theta': 10000.0,
            'rope_scaling': None,
            'sliding_window': 2047,
            'return_dict': True,
            'output_hidden_states': False,
            'output_attentions': False,
            'torchscript': False,
            'torch_dtype': 'bfloat16',
            'use_bfloat16': False,
            'tf_legacy_loss': False,
            'pruned_heads': {},
            'tie_word_embeddings': False,
            'chunk_size_feed_forward': 0,
            'is_encoder_decoder': False,
            'is_decoder': False,
            'cross_attention_hidden_size': None,
            'add_cross_attention': False,
            'tie_encoder_decoder': False,
            'max_length': 20,
            'min_length': 0,
            'do_sample': False,
            'early_stopping': False,
            'num_beams': 1,
            'num_beam_groups': 1,
            'diversity_penalty': 0.0,
            'temperature': 1.0,
            'top_k': 50,
            'top_p': 1.0,
            'typical_p': 1.0,
            'repetition_penalty': 1.0,
            'length_penalty': 1.0,
            'no_repeat_ngram_size': 0,
            'encoder_no_repeat_ngram_size': 0,
            'bad_words_ids': None,
            'num_return_sequences': 1,
            'output_scores': False,
            'return_dict_in_generate': False,
            'forced_bos_token_id': None,
            'forced_eos_token_id': None,
            'remove_invalid_values': False,
            'exponential_decay_length_penalty': None,
            'suppress_tokens': None,
            'begin_suppress_tokens': None,
            'finetuning_task': None,
            'id2label': {0: 'LABEL_0', 1: 'LABEL_1'},
            'label2id': {'LABEL_0': 0, 'LABEL_1': 1},
            'tokenizer_class': None,
            'prefix': None,
            'sep_token_id': None,
            'decoder_start_token_id': None,
            'task_specific_params': None,
            'problem_type': None,
            'model_type': 'phi3',
        }

    def __init__(self,
                 vision_encoder_config: dict = None,
                 vision_tokenizer_config: dict = None,
                 text_config: dict = None,
                 **kwargs):
        """Build the composite configuration.

        Args:
            vision_encoder_config: Keyword arguments for
                :class:`XGenMMVisionEncoderConfig`, or an already-built
                instance of it. ``None`` selects
                ``{'image_aspect_ratio': 'anyres', 'anyres_patch_sampling': True}``.
            vision_tokenizer_config: Keyword arguments for
                :class:`XGenMMVisionTokenizerConfig`, or an already-built
                instance of it. ``None`` uses that class's defaults.
            text_config: Keyword arguments for the language-model config, or
                an already-built ``PretrainedConfig``. ``None`` selects the
                Phi-3 defaults from :meth:`_default_text_config`.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If the text config's ``to_dict()`` output lacks
                ``initial_tokenizer_len`` or ``pad_token_id``.
        """
        if vision_encoder_config is None:
            vision_encoder_config = {'image_aspect_ratio': 'anyres', 'anyres_patch_sampling': True}
            logger.info("vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values.")

        if vision_tokenizer_config is None:
            vision_tokenizer_config = {}
            logger.info("vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values.")

        if text_config is None:
            text_config = self._default_text_config()
            logger.info("text_config is None. Initializing the text config with default values (`Phi3Config`).")

        # Sub-configs arrive as plain dicts when loaded from a serialized
        # config; ready-made config objects are accepted as-is as well.
        if isinstance(vision_encoder_config, XGenMMVisionEncoderConfig):
            self.vision_encoder_config = vision_encoder_config
        else:
            self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config)

        if isinstance(vision_tokenizer_config, XGenMMVisionTokenizerConfig):
            self.vision_tokenizer_config = vision_tokenizer_config
        else:
            self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(**vision_tokenizer_config)

        if isinstance(text_config, PretrainedConfig):
            self.text_config = text_config
        else:
            # Resolve the concrete config class from its registered model type.
            text_model_type = text_config.get("model_type", "phi3")
            self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        # to_dict() deep-copies the config, so serialize once for all checks.
        text_config_dict = self.text_config.to_dict()
        for key in ('initial_tokenizer_len', 'pad_token_id'):
            # NOTE(review): PretrainedConfig normally always defines
            # pad_token_id (possibly as None), so this membership test may
            # never fire for it -- confirm whether a None check was intended.
            if key not in text_config_dict:
                raise ValueError(f"The key `{key}` is missing in the text_config.")

        super().__init__(**kwargs)

    @classmethod
    def from_vision_encoder_vision_tokenizer_text_configs(
            cls,
            vision_encoder_config: XGenMMVisionEncoderConfig,
            vision_tokenizer_config: XGenMMVisionTokenizerConfig,
            text_config: PretrainedConfig,
            **kwargs):
        """Build an :class:`XGenMMConfig` from three sub-config instances.

        Each sub-config is serialized with ``to_dict()`` and re-built by
        ``cls(...)``, so the returned config holds its own copies rather than
        the objects passed in. Extra ``kwargs`` are forwarded to ``cls``.
        """
        return cls(
            vision_encoder_config=vision_encoder_config.to_dict(),
            vision_tokenizer_config=vision_tokenizer_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )
|
|