from typing import Union, Optional

from transformers import PretrainedConfig, AutoConfig

from .visual_tokenizer import ClipVisualTokenizerConfig


class OvisConfig(PretrainedConfig):
    """Configuration for the Ovis multimodal model: bundles the backbone LLM config,
    the visual tokenizer config, and multimodal sequence settings."""

    model_type = "ovis"

    def __init__(self,
                 llm_config: Optional[Union[PretrainedConfig, dict]] = None,
                 visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None,
                 multimodal_max_length=2048,
                 hidden_size=None,
                 conversation_formatter_class=None,
                 **kwargs):
        super().__init__(**kwargs)
        if llm_config is not None:
            assert isinstance(llm_config, (PretrainedConfig, dict)), \
                f"expected `llm_config` to be a PretrainedConfig or dict, but got {type(llm_config)}"
            # A plain dict (e.g. deserialized from JSON) is rebuilt into the matching
            # config class through its `model_type` entry.
            if not isinstance(llm_config, PretrainedConfig):
                model_type = llm_config.pop('model_type')
                llm_config = AutoConfig.for_model(model_type, **llm_config)
        self.llm_config = llm_config
        if visual_tokenizer_config is not None:
            assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \
                f"expected `visual_tokenizer_config` to be a PretrainedConfig or dict, but got {type(visual_tokenizer_config)}"
            # Same treatment: a dict is rebuilt into its config class via `model_type`.
            if not isinstance(visual_tokenizer_config, PretrainedConfig):
                model_type = visual_tokenizer_config.pop('model_type')
                visual_tokenizer_config = AutoConfig.for_model(model_type, **visual_tokenizer_config)
        self.visual_tokenizer_config = visual_tokenizer_config
        self.multimodal_max_length = multimodal_max_length
        self.hidden_size = hidden_size
        self.conversation_formatter_class = conversation_formatter_class
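

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): one way to assemble an
# OvisConfig. Sub-configs may be passed either as PretrainedConfig instances or
# as plain dicts carrying a `model_type` key; dicts are rebuilt via
# AutoConfig.for_model. The "llama" backbone and all parameter values below are
# assumptions for demonstration, and ClipVisualTokenizerConfig is assumed to be
# constructible with its defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    llm_config = {"model_type": "llama", "hidden_size": 4096, "num_hidden_layers": 32}
    visual_tokenizer_config = ClipVisualTokenizerConfig()

    config = OvisConfig(
        llm_config=llm_config,
        visual_tokenizer_config=visual_tokenizer_config,
        multimodal_max_length=2048,
    )
    print(config.llm_config.model_type, config.multimodal_max_length)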