from transformers import CLIPVisionConfig, FlaxCLIPVisionPreTrainedModel
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
import jax.numpy as jnp
from flax import linen as nn
import jax
from transformers.modeling_flax_outputs import FlaxSequenceClassifierOutput


class FlaxCLIPForImageClassificationModule(nn.Module):
    """Flax module: a linear classification head on top of the CLIP vision
    transformer, classifying from the hidden state at sequence position 0
    (presumably the class-token embedding — confirm against the backbone).
    """

    config: CLIPVisionConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Backbone vision transformer. Attribute name `vit` is kept as-is:
        # it fixes the parameter-tree key, so renaming would break checkpoints.
        self.vit = FlaxCLIPVisionModule(config=self.config, dtype=self.dtype)
        # Linear head projecting the position-0 hidden state to num_labels.
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range ** 2, "fan_in", "truncated_normal"
            ),
        )

    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the backbone on `pixel_values` and return classification logits.

        Returns a FlaxSequenceClassifierOutput when `return_dict` resolves to
        True, otherwise a tuple of (logits, *extra backbone outputs).
        """
        # Guard-style resolution of the return-format flag.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.vit(
            pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # backbone_outputs[0] is the hidden-state sequence; take position 0.
        cls_embedding = backbone_outputs[0][:, 0, :]
        logits = self.classifier(cls_embedding)

        if not return_dict:
            # Skip index 1 of the backbone tuple (presumably the pooled
            # output — verify) so the tuple mirrors the dict-style return,
            # which exposes only hidden_states and attentions alongside logits.
            return (logits,) + backbone_outputs[2:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )


class FlaxCLIPForImageClassification(FlaxCLIPVisionPreTrainedModel):
    """Pretrained-model wrapper binding the classification module above."""

    module_class = FlaxCLIPForImageClassificationModule