|
from transformers import PretrainedConfig |
|
from typing import List |
|
import json |
|
|
|
from peft import PeftConfig |
|
from .configuration_vision import Idefics2VisionConfig |
|
from .configuration_internlm2 import InternLM2Config |
|
from .configuration_projector import ProjectorConfig |
|
from .configuration_connector import Idefics2ConnectorConfig |
|
from .image_processor import Idefics2ImageProcessor |
|
from .configuration_downsampler import DownsamplerConfig |
|
|
|
class WeMMConfig(PretrainedConfig):
    """Composite configuration for the WeMM model (``model_type = "wemm_hf"``).

    Bundles the sub-configurations of the individual components (vision
    tower, language model, projector, connector, downsampler) together with
    a few pass-through settings.

    Every optional argument is stored as an attribute of the same name
    **only when it is not ``None``**; ``do_image_splitting`` is always
    stored.

    Args:
        vision_config: Mapping of keyword arguments for
            ``Idefics2VisionConfig``, or an ``Idefics2VisionConfig`` instance.
        text_config: Mapping of keyword arguments for ``InternLM2Config``,
            or an ``InternLM2Config`` instance.
        projector_config: Mapping of keyword arguments for
            ``ProjectorConfig``, or a ``ProjectorConfig`` instance.
        connector_config: Mapping of keyword arguments for
            ``Idefics2ConnectorConfig``, or an ``Idefics2ConnectorConfig``
            instance.
        adapter_path: Stored unchanged.
        image_processor: Stored unchanged.
        do_image_splitting: Stored unchanged; defaults to ``False``.
        spliter_emb_config: Stored unchanged.
        downsampler_config: Mapping of keyword arguments for
            ``DownsamplerConfig``, or a ``DownsamplerConfig`` instance.
        tokenizer_config: Stored unchanged.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "wemm_hf"

    @staticmethod
    def _as_config(config_cls, value):
        """Return ``value`` if it already is a ``config_cls``, else build one from it."""
        if isinstance(value, config_cls):
            return value
        # Anything else is treated as a mapping of constructor keyword arguments.
        return config_cls(**value)

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        projector_config=None,
        connector_config=None,
        adapter_path=None,
        image_processor=None,
        do_image_splitting=False,
        spliter_emb_config=None,
        downsampler_config=None,
        tokenizer_config=None,
        **kwargs
    ):
        if vision_config is not None:
            self.vision_config = self._as_config(Idefics2VisionConfig, vision_config)

        if text_config is not None:
            self.text_config = self._as_config(InternLM2Config, text_config)

        if projector_config is not None:
            self.projector_config = self._as_config(ProjectorConfig, projector_config)

        if connector_config is not None:
            self.connector_config = self._as_config(
                Idefics2ConnectorConfig, connector_config
            )

        if image_processor is not None:
            self.image_processor = image_processor

        if adapter_path is not None:
            self.adapter_path = adapter_path

        self.do_image_splitting = do_image_splitting

        if spliter_emb_config is not None:
            self.spliter_emb_config = spliter_emb_config

        if downsampler_config is not None:
            self.downsampler_config = self._as_config(
                DownsamplerConfig, downsampler_config
            )

        if tokenizer_config is not None:
            self.tokenizer_config = tokenizer_config

        # Remaining keyword arguments are handled by the base class.
        super().__init__(**kwargs)
|
|
|
if __name__ == "__main__":
    # Manual smoke test: load a WeMM config and print its connector sub-config.
    import sys

    # The first command-line argument overrides the default config location.
    default_config_path = "/mnt/csp/mmvision/home/feipengma/projects/wemm_evaluation/WeMM/config.json"
    wemm_config_path = sys.argv[1] if len(sys.argv) > 1 else default_config_path

    wemm_config = WeMMConfig.from_pretrained(wemm_config_path)
    print(wemm_config.connector_config)