liuyizhang
add transformers_4_35_0
1ce5e18
# coding=utf-8
# Copyright 2021 The OpenAI Team Authors, The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple, Union
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import ModelOutput, add_start_docstrings, logging
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
logger = logging.get_logger(__name__)
CLIP_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
CLIP_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
CLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
CLIP_INPUTS_DOCSTRING = r"""
Args:
input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@flax.struct.dataclass
class FlaxCLIPTextModelOutput(ModelOutput):
"""
Base class for text model's outputs that also contains a pooling of the last hidden states.
Args:
text_embeds (`jnp.ndarray` of shape `(batch_size, output_dim`):
The text embeddings obtained by applying the projection layer to the pooled output of
[`FlaxCLIPTextModel`].
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
text_embeds: jnp.ndarray = None
last_hidden_state: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxCLIPOutput(ModelOutput):
"""
Args:
logits_per_image:(`jnp.ndarray` of shape `(image_batch_size, text_batch_size)`):
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
similarity scores.
logits_per_text:(`jnp.ndarray` of shape `(text_batch_size, image_batch_size)`):
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
similarity scores.
text_embeds(`jnp.ndarray` of shape `(batch_size, output_dim`):
The text embeddings obtained by applying the projection layer to the pooled output of
[`FlaxCLIPTextModel`].
image_embeds(`jnp.ndarray` of shape `(batch_size, output_dim`):
The image embeddings obtained by applying the projection layer to the pooled output of
[`FlaxCLIPVisionModel`].
text_model_output(`FlaxBaseModelOutputWithPooling`):
The output of the [`FlaxCLIPTextModel`].
vision_model_output(`FlaxBaseModelOutputWithPooling`):
The output of the [`FlaxCLIPVisionModel`].
"""
logits_per_image: jnp.ndarray = None
logits_per_text: jnp.ndarray = None
text_embeds: jnp.ndarray = None
image_embeds: jnp.ndarray = None
text_model_output: FlaxBaseModelOutputWithPooling = None
vision_model_output: FlaxBaseModelOutputWithPooling = None
def to_tuple(self) -> Tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
class FlaxCLIPVisionEmbeddings(nn.Module):
config: CLIPVisionConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
embed_dim = self.config.hidden_size
image_size = self.config.image_size
patch_size = self.config.patch_size
self.class_embedding = self.param("class_embedding", jax.nn.initializers.normal(stddev=0.02), (embed_dim,))
self.patch_embedding = nn.Conv(
embed_dim,
kernel_size=(patch_size, patch_size),
strides=(patch_size, patch_size),
padding="VALID",
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(),
)
self.num_patches = (image_size // patch_size) ** 2
num_positions = self.num_patches + 1
self.position_embedding = nn.Embed(num_positions, embed_dim, embedding_init=jax.nn.initializers.normal())
self.position_ids = jnp.expand_dims(jnp.arange(0, num_positions, dtype="i4"), axis=0)
def __call__(self, pixel_values):
patch_embeds = self.patch_embedding(pixel_values)
batch_size, height, width, channels = patch_embeds.shape
patch_embeds = jnp.reshape(patch_embeds, (batch_size, height * width, channels))
class_embeds = jnp.expand_dims(self.class_embedding, axis=(0, 1))
class_embeds = jnp.tile(class_embeds, (batch_size, 1, 1))
embeddings = jnp.concatenate([class_embeds, patch_embeds], axis=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
class FlaxCLIPTextEmbeddings(nn.Module):
config: CLIPTextConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
embed_dim = self.config.hidden_size
self.token_embedding = nn.Embed(self.config.vocab_size, embed_dim, embedding_init=jax.nn.initializers.normal())
self.position_embedding = nn.Embed(
self.config.max_position_embeddings, embed_dim, embedding_init=jax.nn.initializers.normal()
)
self.position_ids = jnp.expand_dims(
jnp.arange(0, self.config.max_position_embeddings, dtype="i4"), axis=(0, 1)
)
def __call__(self, input_ids, position_ids):
input_embeds = self.token_embedding(input_ids.astype("i4"))
position_embeds = self.position_embedding(position_ids.astype("i4"))
embeddings = input_embeds + position_embeds
return embeddings
class FlaxCLIPAttention(nn.Module):
config: Union[CLIPTextConfig, CLIPVisionConfig]
dtype: jnp.dtype = jnp.float32
def setup(self):
self.embed_dim = self.config.hidden_size
self.num_heads = self.config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = self.config.attention_dropout
self.k_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
self.v_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
self.q_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
self.out_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
self.causal = isinstance(self.config, CLIPTextConfig)
if self.causal:
self.causal_mask = make_causal_mask(jnp.ones((1, self.config.max_position_embeddings), dtype="i4"))
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
def __call__(
self,
hidden_states,
attention_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
):
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = self._split_heads(query)
key = self._split_heads(key)
value = self._split_heads(value)
causal_attention_mask = None
if self.causal:
query_length, key_length = query.shape[1], key.shape[1]
causal_attention_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]
if attention_mask is not None and causal_attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_mask = combine_masks(attention_mask, causal_attention_mask, dtype="i4")
elif causal_attention_mask is not None:
attention_mask = causal_attention_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
if attention_mask is not None:
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.dropout > 0.0:
dropout_rng = self.make_rng("dropout")
attn_weights = dot_product_attention_weights(
query,
key,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
attn_output = self._merge_heads(attn_output)
attn_output = self.out_proj(attn_output)
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
class FlaxCLIPMLP(nn.Module):
config: Union[CLIPTextConfig, CLIPVisionConfig]
dtype: jnp.dtype = jnp.float32
def setup(self):
self.activation_fn = ACT2FN[self.config.hidden_act]
self.fc1 = nn.Dense(
self.config.intermediate_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.01),
)
self.fc2 = nn.Dense(self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
def __call__(self, hidden_states):
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class FlaxCLIPEncoderLayer(nn.Module):
config: Union[CLIPTextConfig, CLIPVisionConfig]
dtype: jnp.dtype = jnp.float32
def setup(self):
self.self_attn = FlaxCLIPAttention(self.config, dtype=self.dtype)
self.layer_norm1 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.mlp = FlaxCLIPMLP(self.config, dtype=self.dtype)
self.layer_norm2 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
attn_outputs = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
)
hidden_states = attn_outputs[0]
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += attn_outputs[1:]
return outputs
class FlaxCLIPLayerCollection(nn.Module):
config: Union[CLIPTextConfig, CLIPVisionConfig]
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = [
FlaxCLIPEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
attention_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states,)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class FlaxCLIPEncoder(nn.Module):
config: Union[CLIPTextConfig, CLIPVisionConfig]
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = FlaxCLIPLayerCollection(self.config, dtype=self.dtype)
def __call__(
self,
inputs_embeds,
attention_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.layers(
hidden_states=inputs_embeds,
attention_mask=attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class FlaxCLIPTextTransformer(nn.Module):
config: CLIPTextConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.embeddings = FlaxCLIPTextEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
# For `pooled_output` computation
self.eos_token_id = self.config.eos_token_id
def __call__(
self,
input_ids,
attention_mask,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
if self.eos_token_id == 2:
# The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
# ------------------------------------------------------------
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the EOS embedding (eos_token_id is the highest number in each sequence)
pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)]
else:
# (no need to cast from bool to int after comparing to `eos_token_id`)
pooled_output = last_hidden_state[
jnp.arange(last_hidden_state.shape[0]), (input_ids == self.eos_token_id).argmax(axis=-1)
]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return FlaxBaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class FlaxCLIPVisionTransformer(nn.Module):
config: CLIPVisionConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.embeddings = FlaxCLIPVisionEmbeddings(self.config, dtype=self.dtype)
self.pre_layrnorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype)
self.post_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(
self,
pixel_values=None,
deterministic: bool = True,
output_attentions=None,
output_hidden_states=None,
return_dict: bool = True,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return FlaxBaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
config_class = CLIPTextConfig
module_class: nn.Module = None
def __init__(
self,
config: CLIPTextConfig,
input_shape=(1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
input_ids = jnp.zeros(input_shape, dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
attention_mask = jnp.ones_like(input_ids)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(rngs, input_ids, attention_mask, position_ids)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def __call__(
self,
input_ids,
attention_mask=None,
position_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
module_class: nn.Module = None
def __init__(
self,
config: CLIPVisionConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
if input_shape is None:
input_shape = (1, config.image_size, config.image_size, 3)
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
pixel_values = jax.random.normal(rng, input_shape)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(rngs, pixel_values)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def __call__(
self,
pixel_values,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(pixel_values, dtype=jnp.float32),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
config_class = CLIPConfig
module_class: nn.Module = None
def __init__(
self,
config: CLIPConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
if input_shape is None:
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
input_ids = jnp.zeros(input_shape[0], dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
attention_mask = jnp.ones_like(input_ids)
pixel_values = jax.random.normal(rng, input_shape[1])
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def __call__(
self,
input_ids,
pixel_values,
attention_mask=None,
position_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(pixel_values, dtype=jnp.float32),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
def get_text_features(
self,
input_ids,
attention_mask=None,
position_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train=False,
):
r"""
Args:
input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
Returns:
text_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The text embeddings obtained by applying
the projection layer to the pooled output of [`FlaxCLIPTextModel`].
Examples:
```python
>>> from transformers import AutoTokenizer, FlaxCLIPModel
>>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
>>> text_features = model.get_text_features(**inputs)
```"""
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _get_features(module, input_ids, attention_mask, position_ids, deterministic):
text_outputs = module.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
)
pooled_output = text_outputs[1]
text_features = module.text_projection(pooled_output)
return text_features
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
method=_get_features,
rngs=rngs,
)
def get_image_features(
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
):
r"""
Args:
pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
Returns:
image_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The image embeddings obtained by
applying the projection layer to the pooled output of [`FlaxCLIPVisionModel`]
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, FlaxCLIPModel
>>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="np")
>>> image_features = model.get_image_features(**inputs)
```"""
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _get_features(module, pixel_values, deterministic):
vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
pooled_output = vision_outputs[1] # pooled_output
image_features = module.visual_projection(pooled_output)
return image_features
return self.module.apply(
{"params": params or self.params},
jnp.array(pixel_values, dtype=jnp.float32),
not train,
method=_get_features,
rngs=rngs,
)
class FlaxCLIPTextModule(nn.Module):
config: CLIPTextConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class FlaxCLIPTextModel(FlaxCLIPTextPreTrainedModel):
module_class = FlaxCLIPTextModule
FLAX_CLIP_TEXT_MODEL_DOCSTRING = """
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, FlaxCLIPTextModel
>>> model = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooler_output = outputs.pooler_output # pooled (EOS token) states
```
"""
overwrite_call_docstring(FlaxCLIPTextModel, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_DOCSTRING)
append_replace_return_docstrings(
FlaxCLIPTextModel, output_type=FlaxBaseModelOutputWithPooling, config_class=CLIPTextConfig
)
class FlaxCLIPTextModelWithProjectionModule(nn.Module):
config: CLIPTextConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)
self.text_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = text_outputs[1]
text_embeds = self.text_projection(pooled_output)
if not return_dict:
return (text_embeds, text_outputs[0]) + text_outputs[2:]
return FlaxCLIPTextModelOutput(
text_embeds=text_embeds,
last_hidden_state=text_outputs.last_hidden_state,
hidden_states=text_outputs.hidden_states,
attentions=text_outputs.attentions,
)
class FlaxCLIPTextModelWithProjection(FlaxCLIPTextPreTrainedModel):
module_class = FlaxCLIPTextModelWithProjectionModule
FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING = """
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, FlaxCLIPTextModelWithProjection
>>> model = FlaxCLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
>>> outputs = model(**inputs)
>>> text_embeds = outputs.text_embeds
```
"""
overwrite_call_docstring(
FlaxCLIPTextModelWithProjection, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRING
)
append_replace_return_docstrings(
FlaxCLIPTextModelWithProjection, output_type=FlaxCLIPTextModelOutput, config_class=CLIPTextConfig
)
class FlaxCLIPVisionModule(nn.Module):
config: CLIPVisionConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.vision_model = FlaxCLIPVisionTransformer(self.config, dtype=self.dtype)
def __call__(
self,
pixel_values,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.vision_model(
pixel_values=pixel_values,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class FlaxCLIPVisionModel(FlaxCLIPVisionPreTrainedModel):
module_class = FlaxCLIPVisionModule
FLAX_CLIP_VISION_MODEL_DOCSTRING = """
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, FlaxCLIPVisionModel
>>> model = FlaxCLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooler_output = outputs.pooler_output # pooled CLS states
```
"""
overwrite_call_docstring(FlaxCLIPVisionModel, CLIP_VISION_INPUTS_DOCSTRING + FLAX_CLIP_VISION_MODEL_DOCSTRING)
append_replace_return_docstrings(
FlaxCLIPVisionModel, output_type=FlaxBaseModelOutputWithPooling, config_class=CLIPVisionConfig
)
class FlaxCLIPModule(nn.Module):
config: CLIPConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
text_config = self.config.text_config
vision_config = self.config.vision_config
self.projection_dim = self.config.projection_dim
self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size
self.text_model = FlaxCLIPTextTransformer(text_config, dtype=self.dtype)
self.vision_model = FlaxCLIPVisionTransformer(vision_config, dtype=self.dtype)
self.visual_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.text_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.logit_scale = self.param(
"logit_scale", lambda _, shape: jnp.ones(shape) * self.config.logit_scale_init_value, []
)
def __call__(
self,
input_ids=None,
pixel_values=None,
attention_mask=None,
position_ids=None,
deterministic: bool = True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeds = vision_outputs[1]
image_embeds = self.visual_projection(image_embeds)
text_embeds = text_outputs[1]
text_embeds = self.text_projection(text_embeds)
# normalized features
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
# cosine similarity as logits
logit_scale = jnp.exp(self.logit_scale)
logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
logits_per_image = logits_per_text.T
if not return_dict:
return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
return FlaxCLIPOutput(
logits_per_image=logits_per_image,
logits_per_text=logits_per_text,
text_embeds=text_embeds,
image_embeds=image_embeds,
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
@add_start_docstrings(CLIP_START_DOCSTRING)
class FlaxCLIPModel(FlaxCLIPPreTrainedModel):
module_class = FlaxCLIPModule
FLAX_CLIP_MODEL_DOCSTRING = """
Returns:
Example:
```python
>>> import jax
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, FlaxCLIPModel
>>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="np", padding=True
... )
>>> outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
>>> probs = jax.nn.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities
```
"""
overwrite_call_docstring(FlaxCLIPModel, CLIP_INPUTS_DOCSTRING + FLAX_CLIP_MODEL_DOCSTRING)
append_replace_return_docstrings(FlaxCLIPModel, output_type=FlaxCLIPOutput, config_class=CLIPConfig)