|
from typing import Any, Optional

from transformers import AutoConfig, LlamaConfig, PretrainedConfig
|
|
|
|
|
class VisionConfig(PretrainedConfig):
    """Configuration class identified by the model type ``"vision_model"``.

    Declares no fields of its own: every keyword argument given to the
    constructor is forwarded unchanged to the parent ``__init__``.
    """

    model_type: str = "vision_model"

    def __init__(self, **kwargs: Any) -> None:
        """Create the config, delegating all ``kwargs`` to the base class."""
        super().__init__(**kwargs)
|
|
|
|
|
class ConnectorConfig(PretrainedConfig):
    """Configuration class identified by the model type ``"connector"``.

    Declares no fields of its own: every keyword argument given to the
    constructor is forwarded unchanged to the parent ``__init__``.
    """

    model_type: str = "connector"

    def __init__(self, **kwargs: Any) -> None:
        """Create the config, delegating all ``kwargs`` to the base class."""
        super().__init__(**kwargs)
|
|
|
|
|
class VLMConfig(LlamaConfig):
    """``LlamaConfig`` extended with a vision and a connector sub-config.

    The two sub-configs are built in the constructor and stored on the
    instance as ``vision_config`` and ``connector_config``. All remaining
    keyword arguments are forwarded to ``LlamaConfig.__init__``.
    """

    model_type: str = "vlm"

    # Attribute name of each nested config -> class used to build it.
    _sub_config_classes: dict[str, type[PretrainedConfig]] = {
        "vision_config": VisionConfig,
        "connector_config": ConnectorConfig,
    }

    def __init__(
        self,
        vision_config_args: Optional[dict[str, Any]] = None,
        connector_config_args: Optional[dict[str, Any]] = None,
        lazy_load: bool = False,
        **kwargs: Any,
    ):
        """Build the sub-configs, then initialise the Llama part.

        Args:
            vision_config_args: Keyword arguments for ``VisionConfig``. Used
                only when ``kwargs`` carries no non-``None`` ``vision_config``.
            connector_config_args: Keyword arguments for ``ConnectorConfig``.
                Used only when ``kwargs`` carries no non-``None``
                ``connector_config``.
            lazy_load: Flag stored on the config as ``self.lazy_load``; it is
                not interpreted anywhere in this class.
            **kwargs: May contain ``vision_config`` / ``connector_config``
                (a dict of keyword arguments or an already-built config
                instance), which take precedence over the ``*_args``
                parameters. Everything else goes to ``LlamaConfig.__init__``.

        Raises:
            TypeError: If a sub-config source is neither ``None``, a mapping
                of keyword arguments, nor an instance of the expected class.
        """
        # Pop the sub-config entries so they are not forwarded to the parent
        # constructor below. An explicit ``None`` falls back to the matching
        # ``*_args`` parameter instead of silently discarding it.
        vision_source = kwargs.pop("vision_config", None)
        if vision_source is None:
            vision_source = vision_config_args

        connector_source = kwargs.pop("connector_config", None)
        if connector_source is None:
            connector_source = connector_config_args

        self.vision_config: VisionConfig = self._build_sub_config(
            VisionConfig, vision_source
        )
        self.connector_config: ConnectorConfig = self._build_sub_config(
            ConnectorConfig, connector_source
        )
        self.lazy_load: bool = lazy_load

        super().__init__(**kwargs)

    @staticmethod
    def _build_sub_config(config_class: type, source: Any) -> Any:
        """Return ``source`` as an instance of ``config_class``.

        An existing instance is reused as-is (it cannot be ``**``-unpacked);
        ``None`` or an empty mapping yields a default-constructed config; any
        other mapping is expanded into keyword arguments.
        """
        if isinstance(source, config_class):
            return source
        return config_class(**(source or {}))
|
|
|
# Make the "vlm" model type resolvable through AutoConfig; the string must
# match VLMConfig.model_type.
AutoConfig.register(model_type="vlm", config=VLMConfig)