"""
Subclasses VisionTextDualEncoderModel to customize the text model's pooler.
"""
from typing import Optional
import torch
from transformers import AutoModel, VisionTextDualEncoderModel
from .configuration_custom_clip import CustomCLIPConfig, get_text_model_pooler


# @add_start_docstrings(CUSTOM_CLIP_START_DOCSTRING)
class CustomCLIPModel(VisionTextDualEncoderModel):
    config_class = CustomCLIPConfig

    # Class-level defaults, resolved once from the config class's declared
    # defaults. Note these hold a pooler *class* and its kwargs, not an instance.
    DEFAULT_TEXT_MODEL_POOLER_TYPE: type[torch.nn.Module] = get_text_model_pooler(
        CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_STR
    )
    DEFAULT_TEXT_MODEL_POOLER_KWARGS: dict = (
        CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_KWARGS
    )

    def __init__(
        self, config: Optional[CustomCLIPConfig.__base__] = None, *args, **kwargs
    ):
        # Accept either a CustomCLIPConfig or an instance of its base config
        # class; `from_base` upgrades the latter. A None config is passed
        # through so the parent can assemble one from its own kwargs.
        config = config if config is None else CustomCLIPConfig.from_base(config)
        # Original author's note: the `super` call is, surprisingly, not
        # strictly necessary, possibly due to the implementation of
        # CustomCLIPConfig.__init__.
        super().__init__(config, *args, **kwargs)
        # Swap in the custom text pooler: fall back to the class-level defaults
        # when no config was given, otherwise build the pooler named in the config.
        self.text_model.pooler = (
            self.DEFAULT_TEXT_MODEL_POOLER_TYPE(
                **self.DEFAULT_TEXT_MODEL_POOLER_KWARGS
            )
            if config is None
            else get_text_model_pooler(config.text_model_pooler)(
                **config.text_model_pooler_kwargs
            )
        )


AutoModel.register(CustomCLIPConfig, CustomCLIPModel)
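

# Minimal usage sketch. Assumptions: the repo id below is hypothetical, and
# the accompanying configuration_custom_clip.py plus weights ship alongside
# this file, so that loading with trust_remote_code=True resolves AutoModel
# to the CustomCLIPModel registered above.
if __name__ == "__main__":
    model = AutoModel.from_pretrained(
        "bsyx001/japanese-clip-vit-h-14-bert-wider",  # hypothetical repo id
        trust_remote_code=True,
    )
    print(type(model).__name__)  # expected: CustomCLIPModel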