koclip / koclip /model.py
jaketae's picture
feature: add streamlit backbone
f1d50b1
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
from transformers.utils import logging
from .config import HybridCLIPConfig
# Module-level logger created through transformers' logging utilities (L27 import).
logger = logging.get_logger(__name__)
class FlaxHybridCLIPModule(nn.Module):
    """Flax module pairing an arbitrary text encoder with a vision encoder, CLIP-style.

    Both encoder module classes are looked up in ``FLAX_MODEL_MAPPING`` by the class of
    ``config.text_config`` / ``config.vision_config``. If the vision config class has no
    entry, ``FlaxCLIPVisionModel`` is used instead. The second output of each encoder is
    mapped to ``config.projection_dim`` by a bias-free dense layer. The projections are
    compared with a learned temperature (``exp(logit_scale)``).
    """

    config: HybridCLIPConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        text_cfg = cfg.text_config
        vision_cfg = cfg.vision_config

        self.projection_dim = cfg.projection_dim
        self.text_embed_dim = text_cfg.hidden_size
        self.vision_embed_dim = vision_cfg.hidden_size

        # Resolve the encoder implementations from their config classes.
        text_impl = FLAX_MODEL_MAPPING[text_cfg.__class__]
        vision_impl = FLAX_MODEL_MAPPING.get(vision_cfg.__class__, FlaxCLIPVisionModel)
        self.text_model = text_impl.module_class(text_cfg, dtype=self.dtype)
        self.vision_model = vision_impl.module_class(vision_cfg, dtype=self.dtype)

        def make_projection():
            # Bias-free linear map into the shared embedding space.
            return nn.Dense(
                self.projection_dim,
                dtype=self.dtype,
                kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
                use_bias=False,
            )

        self.visual_projection = make_projection()
        self.text_projection = make_projection()
        # Scalar temperature, stored in log space and exponentiated in __call__.
        self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])

    def __call__(
        self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode both modalities and return image-text similarity logits.

        Returns a ``FlaxCLIPOutput`` when ``return_dict`` (defaulting to
        ``config.return_dict``) is truthy. Otherwise returns the tuple
        ``(logits_per_image, logits_per_text, text_embeds, image_embeds,
        text_outputs, vision_outputs)``.
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        # Options shared verbatim by both encoders.
        common = dict(
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        vision_outputs = self.vision_model(pixel_values=pixel_values, **common)
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            **common,
        )

        def to_unit_length(vectors):
            return vectors / jnp.linalg.norm(vectors, axis=-1, keepdims=True)

        # Index 1 of each encoder output is used as the pooled representation.
        image_embeds = to_unit_length(self.visual_projection(vision_outputs[1]))
        text_embeds = to_unit_length(self.text_projection(text_outputs[1]))

        # Scaled cosine similarity between every text and every image in the batch.
        temperature = jnp.exp(self.logit_scale)
        logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * temperature
        logits_per_image = logits_per_text.T

        if return_dict:
            return FlaxCLIPOutput(
                logits_per_image=logits_per_image,
                logits_per_text=logits_per_text,
                text_embeds=text_embeds,
                image_embeds=image_embeds,
                text_model_output=text_outputs,
                vision_model_output=vision_outputs,
            )
        return (
            logits_per_image,
            logits_per_text,
            text_embeds,
            image_embeds,
            text_outputs,
            vision_outputs,
        )
class FlaxHybridCLIP(FlaxPreTrainedModel):
    """Pretrained-model wrapper around :class:`FlaxHybridCLIPModule`.

    Provides:

    - parameter initialization from dummy inputs;
    - a ``__call__`` that fills in default text inputs;
    - per-tower feature-extraction helpers;
    - :meth:`from_text_vision_pretrained`, which assembles a model from separately
      pretrained text and vision checkpoints.
    """

    config_class = HybridCLIPConfig
    module_class = FlaxHybridCLIPModule

    def __init__(
        self,
        config: HybridCLIPConfig,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        """Build the module and hand it to the base class for initialization.

        Args:
            config: Hybrid text/vision configuration.
            input_shape: ``(text_shape, image_shape)`` used for the dummy inputs in
                :meth:`init_weights`. Defaults to one token and one channels-last image,
                i.e. ``((1, 1), (1, image_size, image_size, 3))``.
            seed: Forwarded to the base class.
            dtype: Computation dtype, forwarded to the module and the base class.
            **kwargs: Extra keyword arguments forwarded to the module constructor.
        """
        if input_shape is None:
            image_size = config.vision_config.image_size
            input_shape = ((1, 1), (1, image_size, image_size, 3))
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(
            config, module, input_shape=input_shape, seed=seed, dtype=dtype
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialize the module's parameters by running ``module.init`` on dummy inputs.

        Args:
            rng: PRNG key. It is split into the ``"params"`` and ``"dropout"`` streams and
                is also used directly to draw the dummy pixel values.
            input_shape: ``(text_shape, image_shape)`` for the dummy token and pixel inputs.

        Returns:
            The ``"params"`` collection produced by ``self.module.init``.
        """
        text_shape = input_shape[0]
        image_shape = input_shape[1]

        input_ids = jnp.zeros(text_shape, dtype="i4")
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), text_shape
        )
        # All-zero segment ids: the same default that __call__/get_text_features use.
        # Index 0 is valid even for text encoders with type_vocab_size == 1.
        token_type_ids = jnp.zeros_like(input_ids)
        attention_mask = jnp.ones_like(input_ids)
        pixel_values = jax.random.normal(rng, image_shape)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        return self.module.init(
            rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids
        )["params"]

    def __call__(
        self,
        input_ids,
        pixel_values,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run both towers and return CLIP-style similarity logits.

        Args:
            input_ids: Token ids, e.g. of shape ``(batch_size, sequence_length)``.
            pixel_values: Channels-last images of shape
                ``(batch_size, height, width, num_channels)``. The module is applied
                directly, so no layout transpose happens here.
            attention_mask: Defaults to all ones.
            position_ids: Defaults to ``arange(sequence_length)`` broadcast to
                ``input_ids.shape``.
            token_type_ids: Defaults to all zeros.
            params: Parameter tree to use. Falls back to ``self.params`` when falsy.
            dropout_rng: PRNG key registered as the ``"dropout"`` stream.
            train: The module is applied with ``deterministic=not train``.
            output_attentions / output_hidden_states / return_dict: Default to the
                corresponding config values.

        Returns:
            The output of :class:`FlaxHybridCLIPModule`: a ``FlaxCLIPOutput`` when
            ``return_dict`` is truthy, otherwise a tuple.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        if position_ids is None:
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
            )
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Only register a dropout stream when the caller supplies one.
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(pixel_values, dtype=jnp.float32),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

    def get_text_features(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        dropout_rng: jax.random.PRNGKey = None,
        train=False,
        params: dict = None,
    ):
        r"""
        Args:
            input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
                :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
                for details.

                `What are input IDs? <../glossary.html#input-ids>`__
            attention_mask: Defaults to all ones.
            position_ids: Defaults to ``arange(sequence_length)`` broadcast to ``input_ids.shape``.
            token_type_ids: Defaults to all zeros.
            dropout_rng: PRNG key registered as the ``"dropout"`` stream.
            train: The text encoder is run with ``deterministic=not train``.
            params: Parameter tree to use. Falls back to ``self.params`` when falsy.

        Returns:
            text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The text embeddings
            obtained by applying the projection layer to the pooled output of text model.
        """
        if position_ids is None:
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
            )
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Only register a dropout stream when the caller supplies one.
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _get_features(
            module,
            input_ids,
            attention_mask,
            position_ids,
            token_type_ids,
            deterministic,
        ):
            text_outputs = module.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
                deterministic=deterministic,
            )
            pooled_output = text_outputs[1]
            # Projected but not normalized (normalization only happens in __call__).
            return module.text_projection(pooled_output)

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            not train,
            method=_get_features,
            rngs=rngs,
        )

    def get_image_features(
        self,
        pixel_values,
        dropout_rng: jax.random.PRNGKey = None,
        train=False,
        params: dict = None,
    ):
        r"""
        Args:
            pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, height, width, num_channels)`):
                Channels-last pixel values (the same layout as the default initialization input
                ``(1, image_size, image_size, 3)``). The vision module is applied directly, so
                arrays produced in channels-first layout (e.g. by
                :class:`~transformers.ImageFeatureExtractionMixin` subclasses) must be transposed first.
            dropout_rng: PRNG key registered as the ``"dropout"`` stream.
            train: The vision encoder is run with ``deterministic=not train``.
            params: Parameter tree to use. Falls back to ``self.params`` when falsy.

        Returns:
            image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The image embeddings
            obtained by applying the projection layer to the pooled output of vision model.
        """
        # Only register a dropout stream when the caller supplies one.
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _get_features(module, pixel_values, deterministic):
            vision_outputs = module.vision_model(
                pixel_values=pixel_values, deterministic=deterministic
            )
            pooled_output = vision_outputs[1]  # pooled_output
            # Projected but not normalized (normalization only happens in __call__).
            return module.visual_projection(pooled_output)

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(pixel_values, dtype=jnp.float32),
            not train,
            method=_get_features,
            rngs=rngs,
        )

    @classmethod
    def from_text_vision_pretrained(
        cls,
        text_model_name_or_path: str = None,
        vision_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxPreTrainedModel:
        """
        Params:
            text_model_name_or_path (:obj: `str`, `optional`):
                Information necessary to initiate the text model. Can be either:

                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
                      this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                      as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
                      a Flax model using the provided conversion scripts and loading the Flax model afterwards.

            vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
                Information necessary to initiate the vision model. Can be either:

                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
                      this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                      as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
                      a Flax model using the provided conversion scripts and loading the Flax model afterwards.

            model_args (remaining positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                :obj:`output_attentions=True`).

                - To update the text configuration, use the prefix `text_` for each configuration parameter.
                - To update the vision configuration, use the prefix `vision_` for each configuration parameter.
                - To update the parent model configuration, do not use a prefix for each configuration parameter.
                - An already-loaded encoder can be passed as ``text_model=...`` / ``vision_model=...``
                  instead of a name or path.

                Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.

        Example::

            >>> from koclip.model import FlaxHybridCLIP  # import path assumed from this file's location
            >>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
            >>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
            >>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
            >>> # saving model after fine-tuning
            >>> model.save_pretrained("./bert-clip")
            >>> # load fine-tuned model
            >>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
        """
        # Route prefixed kwargs to the matching encoder; whatever remains configures the
        # hybrid model itself.
        kwargs_text = cls._pop_prefixed_kwargs(kwargs, "text_")
        kwargs_vision = cls._pop_prefixed_kwargs(kwargs, "vision_")

        # Load and initialize the text and vision model
        text_model = cls._load_pretrained_encoder(
            text_model_name_or_path, model_args, kwargs_text, "text"
        )
        vision_model = cls._load_pretrained_encoder(
            vision_model_name_or_path, model_args, kwargs_vision, "vision"
        )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = HybridCLIPConfig.from_text_vision_configs(
            text_model.config, vision_model.config, **kwargs
        )

        # init model
        model = cls(config, *model_args, dtype=dtype, **kwargs)

        # Read the model type from the loaded encoder itself, so this works whether
        # the vision model/config was passed in by the caller or auto-loaded.
        if vision_model.config.model_type == "clip":
            # Full CLIP checkpoint: reuse its vision tower and its pretrained visual projection.
            model.params["vision_model"]["vision_model"] = vision_model.params[
                "vision_model"
            ]
            model.params["visual_projection"]["kernel"] = vision_model.params[
                "visual_projection"
            ]["kernel"]
        else:
            model.params["vision_model"] = vision_model.params
        model.params["text_model"] = text_model.params
        return model

    @staticmethod
    def _pop_prefixed_kwargs(kwargs: dict, prefix: str) -> dict:
        """Remove every ``prefix``-prefixed key from ``kwargs``; return them with the prefix stripped."""
        matched = [key for key in kwargs if key.startswith(prefix)]
        return {key[len(prefix) :]: kwargs.pop(key) for key in matched}

    @staticmethod
    def _load_pretrained_encoder(name_or_path, model_args, encoder_kwargs: dict, kind: str):
        """Return ``encoder_kwargs["model"]`` if given, else load ``name_or_path`` via ``FlaxAutoModel``.

        ``encoder_kwargs`` is mutated: ``"model"`` is popped, and ``"config"`` is filled
        from ``AutoConfig`` when absent. ``kind`` (``"text"``/``"vision"``) is only used in
        the assertion message.
        """
        model = encoder_kwargs.pop("model", None)
        if model is not None:
            return model
        assert (
            name_or_path is not None
        ), f"If `model` is not defined as an argument, a `{kind}_model_name_or_path` has to be defined"
        from transformers import AutoConfig, FlaxAutoModel

        if "config" not in encoder_kwargs:
            encoder_kwargs["config"] = AutoConfig.from_pretrained(name_or_path)
        return FlaxAutoModel.from_pretrained(name_or_path, *model_args, **encoder_kwargs)