# clip-vit-l-224-patch14-datacomp-image-classification / clip_for_image_classification.py
# Uploaded by Thouph, commit 60c6529 ("Upload 3 files")
from transformers import CLIPVisionConfig, FlaxCLIPVisionPreTrainedModel
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
import jax.numpy as jnp
from flax import linen as nn
import jax
from transformers.modeling_flax_outputs import FlaxSequenceClassifierOutput
class FlaxCLIPForImageClassificationModule(nn.Module):
    """CLIP vision encoder followed by a single linear classification head.

    In Flax ``setup``-style modules, the attribute names used for submodules
    become their parameter-tree keys. The encoder therefore lives under
    ``vit`` and the head under ``classifier``. Keep these names stable so
    saved checkpoints still load.

    Attributes:
        config: Vision configuration. Read for ``num_labels``,
            ``initializer_range`` and ``use_return_dict``.
        dtype: Computation dtype forwarded to both submodules.
    """

    config: CLIPVisionConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Head kernel: truncated-normal variance scaling with
        # scale = initializer_range ** 2, computed from the kernel's fan-in.
        head_kernel_init = jax.nn.initializers.variance_scaling(
            scale=self.config.initializer_range ** 2,
            mode="fan_in",
            distribution="truncated_normal",
        )
        self.vit = FlaxCLIPVisionModule(config=self.config, dtype=self.dtype)
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
            kernel_init=head_kernel_init,
        )

    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the vision encoder and classify from its first output token.

        Args:
            pixel_values: Image batch, passed to the vision encoder unchanged.
                Its layout is whatever ``FlaxCLIPVisionModule`` expects; it is
                not checked here.
            deterministic: Passed to the vision encoder unchanged.
            output_attentions: Passed to the vision encoder unchanged.
            output_hidden_states: Passed to the vision encoder unchanged.
            return_dict: Selects the output container. ``None`` falls back
                to ``config.use_return_dict``.

        Returns:
            When ``return_dict`` is truthy, a ``FlaxSequenceClassifierOutput``
            holding ``logits`` plus the encoder's ``hidden_states`` and
            ``attentions``. Otherwise, the tuple
            ``(logits, *encoder_outputs[2:])``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.vit(
            pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 0 of the encoder output is the last hidden state in upstream
        # transformers. [:, 0, :] takes the first token of each sequence
        # (presumably the class-embedding token of a (batch, tokens, hidden)
        # array; confirm).
        # NOTE(review): upstream FlaxCLIPVisionModule applies post_layernorm
        # only to its index-1 pooled output. This head therefore appears to
        # see the first token *without* that norm. Changing it would alter
        # logits for existing checkpoints, so verify before touching.
        first_token = encoder_outputs[0][:, 0, :]
        logits = self.classifier(first_token)

        if return_dict:
            return FlaxSequenceClassifierOutput(
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: skip the encoder's first two entries (index 0 was
        # consumed above; index 1 is unused here) and keep any that follow.
        return (logits, *encoder_outputs[2:])
class FlaxCLIPForImageClassification(FlaxCLIPVisionPreTrainedModel):
    """Pretrained-model wrapper around ``FlaxCLIPForImageClassificationModule``.

    All behaviour (construction, weight handling, ``__call__``) comes from
    ``FlaxCLIPVisionPreTrainedModel``. This subclass only chooses which
    Flax module the base class should build.
    """

    # Flax module class the inherited machinery is expected to instantiate.
    module_class = FlaxCLIPForImageClassificationModule