DawnC commited on
Commit
c2d5142
·
1 Parent(s): 20b4434

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -10
app.py CHANGED
@@ -587,15 +587,55 @@ class BaseModel(nn.Module):
587
  attended_features = self.attention(features)
588
  logits = self.classifier(attended_features)
589
  return logits, attended_features
590
-
591
 
592
- model_yolo = YOLO('yolov8l.pt')
593
- num_classes = len(dog_breeds)
594
- model = BaseModel(num_classes=num_classes)
595
- model_path = '124_best_model_dog.pth'
596
- checkpoint = torch.load(model_path)
597
- model.load_state_dict(checkpoint['base_model'], strict=False)
598
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599
 
600
 
601
  # Image preprocessing function
@@ -627,7 +667,7 @@ async def predict_single_dog(image):
627
 
628
  with torch.no_grad():
629
  # Get model outputs (只使用logits,不需要features)
630
- logits = model(image_tensor)[0] # 如果model仍返回tuple,取第一個元素
631
  probs = F.softmax(logits, dim=1)
632
 
633
  # Classifier prediction
@@ -649,7 +689,9 @@ async def predict_single_dog(image):
649
  @spaces.GPU
650
  async def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
651
 
652
- results = model_yolo(image, conf=conf_threshold, iou=iou_threshold)[0]
 
 
653
  dogs = []
654
  boxes = []
655
  for box in results.boxes:
 
587
  attended_features = self.attention(features)
588
  logits = self.classifier(attended_features)
589
  return logits, attended_features
 
590
 
591
+
592
+ class ModelManager:
593
+ """
594
+ 模型管理器:負責AI模型的初始化和管理
595
+ 使用單例模式確保只有一個實例在管理所有模型
596
+ """
597
+ _instance = None
598
+ _initialized = False
599
+ _yolo_model = None
600
+ _breed_model = None
601
+
602
+ def __new__(cls):
603
+ if cls._instance is None:
604
+ cls._instance = super().__new__(cls)
605
+ return cls._instance
606
+
607
+ def __init__(self):
608
+ # 避免重複初始化
609
+ if not ModelManager._initialized:
610
+ ModelManager._initialized = True
611
+
612
+ @property
613
+ def yolo_model(self):
614
+ """
615
+ 延遲初始化YOLO模型
616
+ 只有在第一次使用時才會創建實例
617
+ """
618
+ if self._yolo_model is None:
619
+ self._yolo_model = YOLO('yolov8l.pt')
620
+ return self._yolo_model
621
+
622
+ @property
623
+ def breed_model(self):
624
+ """
625
+ 延遲初始化品種分類模型
626
+ 只有在第一次使用時才會創建實例
627
+ """
628
+ if self._breed_model is None:
629
+ self._breed_model = BaseModel(num_classes=len(dog_breeds),
630
+ device=device).to(device)
631
+ checkpoint = torch.load('124_best_model_dog.pth',
632
+ map_location=device)
633
+ self._breed_model.load_state_dict(checkpoint['base_model'],
634
+ strict=False)
635
+ self._breed_model.eval()
636
+ return self._breed_model
637
+
638
+ model_manager = ModelManager()
639
 
640
 
641
  # Image preprocessing function
 
667
 
668
  with torch.no_grad():
669
  # Get model outputs (只使用logits,不需要features)
670
+ logits = model_manager.breed_model(image_tensor)[0] # 如果model仍返回tuple,取第一個元素
671
  probs = F.softmax(logits, dim=1)
672
 
673
  # Classifier prediction
 
689
  @spaces.GPU
690
  async def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
691
 
692
+ results = model_manager.yolo_model(image, conf=conf_threshold,
693
+ iou=iou_threshold)[0]
694
+
695
  dogs = []
696
  boxes = []
697
  for box in results.boxes: