DawnC commited on
Commit
45a406b
·
1 Parent(s): 2721825

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -13
app.py CHANGED
@@ -625,30 +625,21 @@ def preprocess_image(image):
625
 
626
  @device_handler
627
  async def predict_single_dog(image):
628
- """
629
- Predicts the dog breed using only the classifier.
630
- Args:
631
- image: PIL Image or numpy array
632
- Returns:
633
- tuple: (top1_prob, topk_breeds, relative_probs)
634
- """
635
  image_tensor = preprocess_image(image)
636
 
637
  with torch.no_grad():
638
- # Get model outputs (只使用logits,不需要features)
639
- logits = model(image_tensor)[0] # 如果model仍返回tuple,取第一個元素
640
  probs = F.softmax(logits, dim=1)
641
 
642
- # Classifier prediction
643
  top5_prob, top5_idx = torch.topk(probs, k=5)
644
  breeds = [dog_breeds[idx.item()] for idx in top5_idx[0]]
645
  probabilities = [prob.item() for prob in top5_prob[0]]
646
 
647
- # Calculate relative probabilities
648
- sum_probs = sum(probabilities[:3]) # 只取前三個來計算相對概率
649
  relative_probs = [f"{(prob/sum_probs * 100):.2f}%" for prob in probabilities[:3]]
650
 
651
- # Debug output
652
  print("\nClassifier Predictions:")
653
  for breed, prob in zip(breeds[:5], probabilities[:5]):
654
  print(f"{breed}: {prob:.4f}")
 
625
 
626
  @device_handler
627
  async def predict_single_dog(image):
 
 
 
 
 
 
 
628
  image_tensor = preprocess_image(image)
629
 
630
  with torch.no_grad():
631
+ outputs = model(image_tensor)
632
+ logits = outputs[0] if isinstance(outputs, tuple) else outputs
633
  probs = F.softmax(logits, dim=1)
634
 
635
+ # 其餘代碼保持不變
636
  top5_prob, top5_idx = torch.topk(probs, k=5)
637
  breeds = [dog_breeds[idx.item()] for idx in top5_idx[0]]
638
  probabilities = [prob.item() for prob in top5_prob[0]]
639
 
640
+ sum_probs = sum(probabilities[:3])
 
641
  relative_probs = [f"{(prob/sum_probs * 100):.2f}%" for prob in probabilities[:3]]
642
 
 
643
  print("\nClassifier Predictions:")
644
  for breed, prob in zip(breeds[:5], probabilities[:5]):
645
  print(f"{breed}: {prob:.4f}")