from dataclasses import dataclass

import torch
import torch.nn as nn
from transformers import SiglipPreTrainedModel, SiglipVisionConfig, SiglipVisionModel
from transformers.utils import ModelOutput


@dataclass
class SiglipForImageClassifierOutput(ModelOutput):
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None


class SiglipForImageClassification(SiglipPreTrainedModel):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        # SigLIP vision backbone; its pooled output feeds the classifier head.
        self.siglip = SiglipVisionModel(config)

        # Classifier head (identity when no labels are configured)
        self.classifier = (
            nn.Linear(config.hidden_size, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> SiglipForImageClassifierOutput:
        outputs = self.siglip(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooler_output = outputs.pooler_output
        logits = self.classifier(pooler_output)

        # Cross-entropy loss when labels are provided (assumes single-label targets).
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SiglipForImageClassifierOutput(
            loss=loss,
            logits=logits,
            pooler_output=pooler_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
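

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the module): build a
    # deliberately tiny SiglipVisionConfig and run one forward/backward pass on
    # random pixels. The layer sizes and the 10-class setup below are made-up
    # values chosen for speed; for real use, start from a pretrained checkpoint
    # instead, e.g.
    # SiglipVisionConfig.from_pretrained("google/siglip-base-patch16-224", num_labels=10).
    config = SiglipVisionConfig(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        image_size=32,
        patch_size=16,
        num_labels=10,
    )
    model = SiglipForImageClassification(config)

    pixel_values = torch.randn(2, 3, config.image_size, config.image_size)
    labels = torch.tensor([1, 7])

    outputs = model(pixel_values=pixel_values, labels=labels)
    print(outputs.logits.shape)  # torch.Size([2, 10])
    print(outputs.loss)          # scalar cross-entropy loss
    outputs.loss.backward()      # gradients flow to both backbone and head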