# aa/ovis/model/configuration_ovis.py
# Commit b4942cf (root): "Ajout du module Ovis" ("Add the Ovis module"); file size 1.99 kB.
from typing import Union, Optional
from transformers import PretrainedConfig, AutoConfig
def _to_pretrained_config(name, config):
    """Return *config* as a ``PretrainedConfig``, building one from a dict if needed.

    Args:
        name: Name of the constructor argument being converted; used only in
            error messages.
        config: Either a ``PretrainedConfig`` (returned unchanged) or a dict
            holding a ``'model_type'`` key plus the fields of that config,
            which is resolved through ``AutoConfig.for_model``.

    Returns:
        A ``PretrainedConfig`` instance.

    Raises:
        TypeError: If *config* is neither a ``PretrainedConfig`` nor a dict.
        KeyError: If a dict *config* has no ``'model_type'`` key.
    """
    if isinstance(config, PretrainedConfig):
        return config
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(config, dict):
        raise TypeError(
            f"expect `{name}` to be instance of PretrainedConfig or dict, "
            f"but got {type(config)} type"
        )
    # Shallow copy so popping 'model_type' does not mutate the caller's dict
    # (the original code removed the key from the dict it was handed).
    fields = dict(config)
    if 'model_type' not in fields:
        raise KeyError(f"`{name}` dict must contain a 'model_type' key")
    model_type = fields.pop('model_type')
    return AutoConfig.for_model(model_type, **fields)


class OvisConfig(PretrainedConfig):
    """Configuration for the Ovis model.

    Bundles a language-model sub-config, a visual-tokenizer sub-config and a
    few Ovis-level settings.

    Args:
        llm_config: Sub-config for the language model, given either as a
            ``PretrainedConfig`` or as a dict containing a ``'model_type'``
            key (resolved via ``AutoConfig.for_model``). The caller's dict is
            not modified. When ``None``, the ``llm_config`` attribute is left
            unset.
        visual_tokenizer_config: Sub-config for the visual tokenizer; accepted
            in the same forms as ``llm_config``. When ``None``, the
            ``visual_tokenizer_config`` attribute is left unset.
        multimodal_max_length: Stored as ``self.multimodal_max_length``
            (presumably the maximum multimodal sequence length — confirm
            against the modeling code).
        hidden_size: Stored as ``self.hidden_size``.
        conversation_formatter_class: Stored as-is.
        llm_attn_implementation: Stored as-is.
        disable_tie_weight: Stored as-is.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        TypeError: If a sub-config is neither a ``PretrainedConfig`` nor a dict.
        KeyError: If a sub-config dict lacks a ``'model_type'`` key.
    """

    model_type = "ovis"

    def __init__(
        self,
        llm_config: Optional[Union[PretrainedConfig, dict]] = None,
        visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None,
        multimodal_max_length=8192,
        hidden_size=None,
        conversation_formatter_class=None,
        llm_attn_implementation=None,
        disable_tie_weight=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Sub-config attributes are only assigned when a value is provided,
        # matching the original behaviour.
        if llm_config is not None:
            self.llm_config = _to_pretrained_config("llm_config", llm_config)
        if visual_tokenizer_config is not None:
            self.visual_tokenizer_config = _to_pretrained_config(
                "visual_tokenizer_config", visual_tokenizer_config
            )
        self.multimodal_max_length = multimodal_max_length
        self.hidden_size = hidden_size
        self.conversation_formatter_class = conversation_formatter_class
        self.llm_attn_implementation = llm_attn_implementation
        self.disable_tie_weight = disable_tie_weight