apollonia-7b / mm_connector /configuration_connector.py
nisten's picture
Add files using upload-large-folder tool
deb6397 verified
raw
history blame
1.31 kB
import torch
import torch.nn as nn
from typing import Dict, List, Union
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
import torch.nn.functional as F
import json, os
class ConnectorConfig(PretrainedConfig):
    """Configuration for the ``"mm_connector"`` model type.

    Every constructor argument below is stored verbatim as an attribute of the
    same name; any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        vision_hidden_size: List of vision-side hidden sizes (presumably one
            entry per vision encoder -- TODO confirm). ``None`` (the default)
            is replaced by a fresh empty list.
        text_hidden_size: Hidden size on the text side (assumed to be the
            language model's hidden size -- TODO confirm). Defaults to 0.
        num_patches: Number of patches (exact role not visible here).
            Defaults to 24.
        rms_norm_eps: Epsilon, presumably for an RMS-norm layer in the
            connector. Defaults to 1e-4.
        token_input_shape: Shape-like list describing the token input
            (semantics not visible here -- TODO confirm). ``None`` (the
            default) is replaced by a fresh empty list.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "mm_connector"

    def __init__(
        self,
        vision_hidden_size: Union[List[int], None] = None,
        text_hidden_size: int = 0,
        num_patches: int = 24,
        rms_norm_eps: float = 1e-4,
        token_input_shape: Union[List[int], None] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # A literal ``[]`` default would be evaluated once and shared by every
        # instance created without the argument, so build a new list per call.
        self.vision_hidden_size = [] if vision_hidden_size is None else vision_hidden_size
        self.text_hidden_size = text_hidden_size
        self.num_patches = num_patches
        self.rms_norm_eps = rms_norm_eps
        self.token_input_shape = [] if token_input_shape is None else token_input_shape

    @classmethod
    def load_config(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "ConnectorConfig":
        """Build a ``ConnectorConfig`` from a local JSON config.

        Args:
            pretrained_model_name_or_path: Path to a JSON config file, or to a
                local directory containing ``config.json``. Only local paths
                are supported; nothing is downloaded.
            **kwargs: Passed through ``get_config_from_json`` to
                ``PretrainedConfig.from_dict``.

        Returns:
            The config built by ``cls.from_dict`` from the parsed JSON.

        Raises:
            FileNotFoundError: If the resolved config file does not exist.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # Inherited helper from PretrainedConfig; presumably normalizes
        # auth-token kwargs before they reach from_dict.
        cls._set_token_in_kwargs(kwargs)
        config_path = pretrained_model_name_or_path
        # Accept a model directory as the parameter name suggests, in addition
        # to a direct path to the JSON file.
        if os.path.isdir(config_path):
            config_path = os.path.join(config_path, "config.json")
        config_dict, kwargs = cls.get_config_from_json(config_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_from_json(cls, config_file, **kwargs):
        """Parse ``config_file`` as JSON.

        Args:
            config_file: Path to the JSON file.
            **kwargs: Returned untouched, mirroring the ``(dict, kwargs)``
                pair shape that ``load_config`` unpacks.

        Returns:
            A ``(config_data, kwargs)`` tuple, where ``config_data`` is the
            object decoded from the file.
        """
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(config_file, "r", encoding="utf-8") as file:
            config_data = json.load(file)
        return config_data, kwargs