Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -591,13 +591,14 @@ class BaseModel(nn.Module):
|
|
591 |
|
592 |
class ModelManager:
|
593 |
"""
|
594 |
-
模型管理器:負責AI
|
595 |
-
|
596 |
"""
|
597 |
_instance = None
|
598 |
_initialized = False
|
599 |
_yolo_model = None
|
600 |
_breed_model = None
|
|
|
601 |
|
602 |
def __new__(cls):
|
603 |
if cls._instance is None:
|
@@ -607,8 +608,20 @@ class ModelManager:
|
|
607 |
def __init__(self):
|
608 |
# 避免重複初始化
|
609 |
if not ModelManager._initialized:
|
|
|
|
|
610 |
ModelManager._initialized = True
|
611 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
612 |
@property
|
613 |
def yolo_model(self):
|
614 |
"""
|
@@ -623,18 +636,23 @@ class ModelManager:
|
|
623 |
def breed_model(self):
|
624 |
"""
|
625 |
延遲初始化品種分類模型
|
626 |
-
|
627 |
"""
|
628 |
if self._breed_model is None:
|
629 |
-
self._breed_model = BaseModel(
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
|
|
|
|
|
|
|
|
635 |
self._breed_model.eval()
|
636 |
return self._breed_model
|
637 |
|
|
|
638 |
model_manager = ModelManager()
|
639 |
|
640 |
|
@@ -663,7 +681,7 @@ def predict_single_dog(image):
|
|
663 |
tuple: (top1_prob, topk_breeds, relative_probs)
|
664 |
"""
|
665 |
|
666 |
-
image_tensor = preprocess_image(image).to(device)
|
667 |
|
668 |
with torch.no_grad():
|
669 |
# Get model outputs (只使用logits,不需要features)
|
|
|
591 |
|
592 |
class ModelManager:
|
593 |
"""
|
594 |
+
模型管理器:負責AI模型的初始化、設備管理和資源控制
|
595 |
+
使用單例模式確保整個應用程序中只有一個實例
|
596 |
"""
|
597 |
_instance = None
|
598 |
_initialized = False
|
599 |
_yolo_model = None
|
600 |
_breed_model = None
|
601 |
+
_device = None
|
602 |
|
603 |
def __new__(cls):
|
604 |
if cls._instance is None:
|
|
|
608 |
def __init__(self):
|
609 |
# 避免重複初始化
|
610 |
if not ModelManager._initialized:
|
611 |
+
# 初始化設備,這會在第一次創建實例時執行
|
612 |
+
self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
613 |
ModelManager._initialized = True
|
614 |
|
615 |
+
@property
|
616 |
+
def device(self):
|
617 |
+
"""
|
618 |
+
提供對設備的訪問
|
619 |
+
確保在需要時設備已經被初始化
|
620 |
+
"""
|
621 |
+
if self._device is None:
|
622 |
+
self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
623 |
+
return self._device
|
624 |
+
|
625 |
@property
|
626 |
def yolo_model(self):
|
627 |
"""
|
|
|
636 |
def breed_model(self):
|
637 |
"""
|
638 |
延遲初始化品種分類模型
|
639 |
+
只有在第一次使用時才會創建實例並移動到正確的設備上
|
640 |
"""
|
641 |
if self._breed_model is None:
|
642 |
+
self._breed_model = BaseModel(
|
643 |
+
num_classes=len(dog_breeds),
|
644 |
+
device=self.device # 使用我們的device屬性
|
645 |
+
).to(self.device)
|
646 |
+
|
647 |
+
checkpoint = torch.load(
|
648 |
+
'124_best_model_dog.pth',
|
649 |
+
map_location=self.device # 確保checkpoint加載到正確的設備
|
650 |
+
)
|
651 |
+
self._breed_model.load_state_dict(checkpoint['base_model'], strict=False)
|
652 |
self._breed_model.eval()
|
653 |
return self._breed_model
|
654 |
|
655 |
+
|
656 |
model_manager = ModelManager()
|
657 |
|
658 |
|
|
|
681 |
tuple: (top1_prob, topk_breeds, relative_probs)
|
682 |
"""
|
683 |
|
684 |
+
image_tensor = preprocess_image(image).to(model_manager.device)
|
685 |
|
686 |
with torch.no_grad():
|
687 |
# Get model outputs (只使用logits,不需要features)
|