amaye15 commited on
Commit
1ce0391
·
verified ·
1 Parent(s): cdfccc4

Update modeling_aimv2.py

Browse files
Files changed (1) hide show
  1. modeling_aimv2.py +21 -0
modeling_aimv2.py CHANGED
@@ -315,10 +315,14 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
315
 
316
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
317
  def __init__(self, config: AIMv2Config):
 
318
  super().__init__(config)
319
 
320
  self.num_labels = config.num_labels
 
 
321
  self.aimv2 = AIMv2Model(config)
 
322
 
323
  # Classifier head
324
  self.classifier = (
@@ -326,9 +330,11 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
326
  if config.num_labels > 0
327
  else nn.Identity()
328
  )
 
329
 
330
  # Initialize weights and apply final processing
331
  self.post_init()
 
332
 
333
  def forward(
334
  self,
@@ -338,33 +344,48 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
338
  output_hidden_states: Optional[bool] = None,
339
  return_dict: Optional[bool] = None,
340
  ) -> Union[tuple, ImageClassifierOutput]:
 
341
 
342
  return_dict = (
343
  return_dict if return_dict is not None else self.config.use_return_dict
344
  )
 
345
 
346
  # Call base model
 
347
  outputs = self.aimv2(
348
  pixel_values,
349
  mask=head_mask,
350
  output_hidden_states=output_hidden_states,
351
  return_dict=return_dict,
352
  )
 
 
353
  sequence_output = outputs[0]
 
 
354
  # Classifier head
355
  logits = self.classifier(sequence_output[:, 0, :])
 
356
 
357
  loss = None
358
  if labels is not None:
 
 
359
  labels = labels.to(logits.device)
 
 
360
  # Always use cross-entropy loss
361
  loss_fct = CrossEntropyLoss()
362
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
363
 
364
  if not return_dict:
365
  output = (logits,) + outputs[1:]
 
366
  return ((loss,) + output) if loss is not None else output
367
 
 
368
  return ImageClassifierOutput(
369
  loss=loss,
370
  logits=logits,
 
315
 
316
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
317
  def __init__(self, config: AIMv2Config):
318
+ print("Initializing AIMv2ForImageClassification")
319
  super().__init__(config)
320
 
321
  self.num_labels = config.num_labels
322
+ print(f"Number of labels: {self.num_labels}")
323
+
324
  self.aimv2 = AIMv2Model(config)
325
+ print("Initialized AIMv2 base model")
326
 
327
  # Classifier head
328
  self.classifier = (
 
330
  if config.num_labels > 0
331
  else nn.Identity()
332
  )
333
+ print(f"Initialized classifier: {self.classifier}")
334
 
335
  # Initialize weights and apply final processing
336
  self.post_init()
337
+ print("Weights initialized and final processing applied")
338
 
339
  def forward(
340
  self,
 
344
  output_hidden_states: Optional[bool] = None,
345
  return_dict: Optional[bool] = None,
346
  ) -> Union[tuple, ImageClassifierOutput]:
347
+ print("Forward pass started")
348
 
349
  return_dict = (
350
  return_dict if return_dict is not None else self.config.use_return_dict
351
  )
352
+ print(f"return_dict: {return_dict}")
353
 
354
  # Call base model
355
+ print("Calling AIMv2 base model")
356
  outputs = self.aimv2(
357
  pixel_values,
358
  mask=head_mask,
359
  output_hidden_states=output_hidden_states,
360
  return_dict=return_dict,
361
  )
362
+ print(f"AIMv2 outputs received: {outputs}")
363
+
364
  sequence_output = outputs[0]
365
+ print(f"Shape of sequence_output: {sequence_output.shape}")
366
+
367
  # Classifier head
368
  logits = self.classifier(sequence_output[:, 0, :])
369
+ print(f"Logits calculated: {logits.shape}")
370
 
371
  loss = None
372
  if labels is not None:
373
+ print(labels)
374
+ print(f"Labels provided: {labels.shape}")
375
  labels = labels.to(logits.device)
376
+ print(f"Labels moved to device: {labels.device}")
377
+
378
  # Always use cross-entropy loss
379
  loss_fct = CrossEntropyLoss()
380
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
381
+ print(f"Loss calculated: {loss.item()}")
382
 
383
  if not return_dict:
384
  output = (logits,) + outputs[1:]
385
+ print("Output without return_dict")
386
  return ((loss,) + output) if loss is not None else output
387
 
388
+ print("Returning ImageClassifierOutput")
389
  return ImageClassifierOutput(
390
  loss=loss,
391
  logits=logits,