amaye15 committed on
Commit
c8eff26
·
verified ·
1 Parent(s): 29d8c52

Update modeling_aimv2.py

Browse files
Files changed (1) hide show
  1. modeling_aimv2.py +62 -1
modeling_aimv2.py CHANGED
@@ -222,7 +222,7 @@ class AIMv2Model(AIMv2PretrainedModel):
222
  hidden_states=hidden_states,
223
  )
224
 
225
-
226
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
227
  def __init__(self, config: AIMv2Config):
228
  super().__init__(config)
@@ -306,3 +306,64 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
306
  hidden_states=outputs.hidden_states,
307
  # attentions=outputs.attentions,
308
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  hidden_states=hidden_states,
223
  )
224
 
225
+ '''
226
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
227
  def __init__(self, config: AIMv2Config):
228
  super().__init__(config)
 
306
  hidden_states=outputs.hidden_states,
307
  # attentions=outputs.attentions,
308
  )
309
+ '''
310
+
311
+
312
class AIMv2ForImageClassification(AIMv2PretrainedModel):
    """AIMv2 backbone with a classification head on top.

    The head is a linear layer from ``config.hidden_size`` to
    ``config.num_labels``. When ``config.num_labels`` is not positive the
    head is an identity, so the selected backbone features are returned
    unchanged as the "logits".
    """

    def __init__(self, config: AIMv2Config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.aimv2 = AIMv2Model(config)

        # Classifier head; pass-through when no labels are configured.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        """Run the backbone, classify, and optionally compute a loss.

        Args:
            pixel_values: Input passed positionally to the AIMv2 backbone.
            head_mask: Forwarded to the backbone as its ``mask`` keyword.
                NOTE(review): confirm the backbone interprets ``mask`` the
                way callers of ``head_mask`` expect.
            labels: Optional targets. When given, a cross-entropy loss is
                computed against the logits; no other loss type is used.
            output_hidden_states: Forwarded to the backbone unchanged.
            return_dict: Whether to return an ``ImageClassifierOutput``;
                defaults to ``self.config.use_return_dict`` when ``None``.

        Returns:
            An ``ImageClassifierOutput`` carrying ``loss``, ``logits`` and
            the backbone's ``hidden_states`` when ``return_dict`` is true.
            Otherwise a tuple ``(logits, *backbone_outputs[1:])``, prefixed
            with ``loss`` when labels were supplied.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_out = self.aimv2(
            pixel_values,
            mask=head_mask,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Classify from position 0 along dim 1 of the first backbone output.
        # NOTE(review): presumably a class/summary token sits at index 0;
        # verify against the backbone's embedding layout.
        features = backbone_out[0][:, 0, :]
        logits = self.classifier(features)

        loss = None
        if labels is not None:
            # Keep targets on the same device as the logits.
            targets = labels.to(logits.device)

            # Always use cross-entropy loss
            criterion = CrossEntropyLoss()
            loss = criterion(logits.view(-1, self.num_labels), targets.view(-1))

        if return_dict:
            return ImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=backbone_out.hidden_states,
            )

        # Tuple output: logits followed by any extra backbone outputs.
        tail = (logits,) + backbone_out[1:]
        if loss is None:
            return tail
        return (loss,) + tail