Spaces: Running on Zero
Update app.py
app.py CHANGED
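In short, this commit swaps the breed classifier's backbone from EfficientNet V2 M to ConvNeXt Base. Instead of reading the feature dimension off the old classifier head, the new `__init__` probes it with a dummy forward pass; `forward` gains global average pooling for multi-dimensional feature maps plus a docstring; and the checkpoint loader now points at `ConvNextBase_best_model_dog.pth`.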
@@ -6,7 +6,7 @@ import gradio as gr
 import time
 import traceback
 import spaces
-from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
+from torchvision.models import convnext_base, ConvNeXt_Base_Weights
 from torchvision.ops import nms, box_iou
 import torch.nn.functional as F
 from torchvision import transforms
@@ -98,29 +98,61 @@ class MultiHeadAttention(nn.Module):
         return out
 
 class BaseModel(nn.Module):
+
     def __init__(self, num_classes, device='cuda' if torch.cuda.is_available() else 'cpu'):
         super().__init__()
         self.device = device
-        self.backbone = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
-        self.feature_dim = self.backbone.classifier[1].in_features
-        self.backbone.classifier = nn.Identity()
 
+        # 1. Initialize the backbone
+        self.backbone = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1)
+        self.backbone.classifier = nn.Identity()  # remove the original classifier
+
+        # 2. Determine the actual feature dimension with test data
+        with torch.no_grad():  # no gradient computation needed
+            dummy_input = torch.randn(1, 3, 224, 224)  # create a sample input
+            features = self.backbone(dummy_input)
+            if len(features.shape) > 2:  # if the features are multi-dimensional
+                features = features.mean([-2, -1])  # apply global average pooling
+            self.feature_dim = features.shape[1]  # get the correct feature dimension
+
+        print(f"Feature Dim: {self.feature_dim}")  # helps with debugging
+
+        # 3. Set up the multi-head attention layer
         self.num_heads = max(1, min(8, self.feature_dim // 64))
         self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)
 
+        # 4. Set up the classifier
         self.classifier = nn.Sequential(
             nn.LayerNorm(self.feature_dim),
             nn.Dropout(0.3),
             nn.Linear(self.feature_dim, num_classes)
         )
 
-        self.to(device)
-
     def forward(self, x):
+        """
+        Forward pass of the model.
+        Args:
+            x (Tensor): input image tensor of shape [batch_size, channels, height, width]
+        Returns:
+            Tuple[Tensor, Tensor]: classification logits and attention features
+        """
         x = x.to(self.device)
+
+        # 1. Extract base features
         features = self.backbone(x)
+
+        # 2. Handle the feature dimensions
+        if len(features.shape) > 2:
+            # If the features have shape [batch_size, channels, height, width],
+            # convert them to [batch_size, channels]
+            features = features.mean([-2, -1])  # global average pooling
+
+        # 3. Apply the attention mechanism
         attended_features = self.attention(features)
+
+        # 4. Final classification
         logits = self.classifier(attended_features)
+
         return logits, attended_features
 
 
@@ -179,7 +211,7 @@ class ModelManager:
         ).to(self.device)
 
         checkpoint = torch.load(
-            '
+            'ConvNextBase_best_model_dog.pth',
             map_location=self.device  # make sure the checkpoint is loaded onto the correct device
         )
         self._breed_model.load_state_dict(checkpoint['base_model'], strict=False)
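The pattern worth noting in the new `__init__`: the old code read the feature dimension off the EfficientNet head (`self.backbone.classifier[1].in_features`), which breaks for backbones with a different classifier layout, while the new code probes it with a throwaway forward pass, so the head adapts to whatever the backbone emits. A minimal standalone sketch of that probe, using the same model and input size as the diff:

import torch
import torch.nn as nn
from torchvision.models import convnext_base, ConvNeXt_Base_Weights

backbone = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1)
backbone.classifier = nn.Identity()  # drop the stock classification head

with torch.no_grad():
    # ConvNeXt's forward still applies its internal avgpool, so the
    # output here is [1, 1024, 1, 1] rather than a flat vector
    feats = backbone(torch.randn(1, 3, 224, 224))
    if feats.dim() > 2:
        feats = feats.mean([-2, -1])  # global average pooling -> [1, 1024]
    feature_dim = feats.shape[1]

print(feature_dim)  # 1024 for convnext_base

One side effect: assigning `nn.Identity()` to `classifier` also strips the `LayerNorm2d` that torchvision's ConvNeXt keeps inside its classifier `Sequential`; the `nn.LayerNorm(self.feature_dim)` at the front of the custom head effectively takes over that role.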
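For completeness, a hedged sketch of how the rewritten pieces plug together. Only the checkpoint filename and the 'base_model' key come from the diff; num_classes=120 and the input batch are illustrative placeholders:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BaseModel(num_classes=120).to(device)  # mirrors ModelManager's .to(self.device)

checkpoint = torch.load(
    'ConvNextBase_best_model_dog.pth',  # filename from the diff
    map_location=device,                # load tensors onto the right device
)
model.load_state_dict(checkpoint['base_model'], strict=False)
model.eval()

with torch.no_grad():
    batch = torch.randn(2, 3, 224, 224, device=device)
    logits, attended = model(batch)  # forward returns (logits, attended_features)
    print(logits.shape)              # torch.Size([2, 120])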