Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -43,7 +43,7 @@ dog_breeds = ["Afghan_Hound", "African_Hunting_Dog", "Airedale", "American_Staff
|
|
43 |
"Miniature_Poodle", "Miniature_Schnauzer", "Otterhound", "Papillon", "Pug", "Redbone",
|
44 |
"Schipperke", "Silky_Terrier", "Soft-Coated_Wheaten_Terrier", "Standard_Poodle",
|
45 |
"Standard_Schnauzer", "Toy_Poodle", "Toy_Terrier", "Vizsla", "Whippet",
|
46 |
-
"Wire-Haired_Fox_Terrier"
|
47 |
|
48 |
class MultiHeadAttention(nn.Module):
|
49 |
|
@@ -100,11 +100,11 @@ class BaseModel(nn.Module):
|
|
100 |
return logits, attended_features
|
101 |
|
102 |
|
103 |
-
num_classes =
|
104 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
105 |
model = BaseModel(num_classes=num_classes, device=device)
|
106 |
|
107 |
-
checkpoint = torch.load('
|
108 |
model.load_state_dict(checkpoint['model_state_dict'])
|
109 |
|
110 |
# evaluation mode
|
@@ -207,40 +207,67 @@ async def process_single_dog(image):
|
|
207 |
|
208 |
# Case 1: Low confidence - unclear image or breed not in dataset
|
209 |
if top1_prob < 0.15:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
initial_state = {
|
211 |
-
"explanation":
|
212 |
"image": None,
|
213 |
"is_multi_dog": False
|
214 |
}
|
215 |
-
return
|
216 |
|
217 |
breed = topk_breeds[0]
|
218 |
|
219 |
# Case 2: High confidence - single breed result
|
220 |
if top1_prob >= 0.45:
|
221 |
description = get_dog_description(breed)
|
222 |
-
formatted_description =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
initial_state = {
|
224 |
-
"explanation":
|
225 |
"image": image,
|
226 |
"is_multi_dog": False
|
227 |
}
|
228 |
-
return
|
229 |
|
230 |
# Case 3: Medium confidence - show top 3 breeds with relative probabilities
|
231 |
else:
|
232 |
-
|
233 |
for i, (breed, prob) in enumerate(zip(topk_breeds, relative_probs)):
|
234 |
description = get_dog_description(breed)
|
235 |
-
formatted_description =
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
|
238 |
initial_state = {
|
239 |
-
"explanation":
|
240 |
"image": image,
|
241 |
"is_multi_dog": False
|
242 |
}
|
243 |
-
return
|
244 |
|
245 |
|
246 |
async def predict(image):
|
|
|
43 |
"Miniature_Poodle", "Miniature_Schnauzer", "Otterhound", "Papillon", "Pug", "Redbone",
|
44 |
"Schipperke", "Silky_Terrier", "Soft-Coated_Wheaten_Terrier", "Standard_Poodle",
|
45 |
"Standard_Schnauzer", "Toy_Poodle", "Toy_Terrier", "Vizsla", "Whippet",
|
46 |
+
"Wire-Haired_Fox_Terrier"]
|
47 |
|
48 |
class MultiHeadAttention(nn.Module):
|
49 |
|
|
|
100 |
return logits, attended_features
|
101 |
|
102 |
|
103 |
+
num_classes = 120
|
104 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
105 |
model = BaseModel(num_classes=num_classes, device=device)
|
106 |
|
107 |
+
checkpoint = torch.load('best_model_81_dog.pth', map_location=torch.device('cpu'))
|
108 |
model.load_state_dict(checkpoint['model_state_dict'])
|
109 |
|
110 |
# evaluation mode
|
|
|
207 |
|
208 |
# Case 1: Low confidence - unclear image or breed not in dataset
|
209 |
if top1_prob < 0.15:
|
210 |
+
error_message = '''
|
211 |
+
<div class="dog-info-card">
|
212 |
+
<div class="breed-info">
|
213 |
+
<p class="warning-message">
|
214 |
+
<span class="icon">⚠️</span>
|
215 |
+
The image is unclear or the breed is not in the dataset. Please upload a clearer image of a dog.
|
216 |
+
</p>
|
217 |
+
</div>
|
218 |
+
</div>
|
219 |
+
'''
|
220 |
initial_state = {
|
221 |
+
"explanation": error_message,
|
222 |
"image": None,
|
223 |
"is_multi_dog": False
|
224 |
}
|
225 |
+
return error_message, None, initial_state
|
226 |
|
227 |
breed = topk_breeds[0]
|
228 |
|
229 |
# Case 2: High confidence - single breed result
|
230 |
if top1_prob >= 0.45:
|
231 |
description = get_dog_description(breed)
|
232 |
+
formatted_description = format_description_html(description, breed) # 使用 format_description_html
|
233 |
+
html_content = f'''
|
234 |
+
<div class="dog-info-card">
|
235 |
+
<div class="breed-info">
|
236 |
+
{formatted_description}
|
237 |
+
</div>
|
238 |
+
</div>
|
239 |
+
'''
|
240 |
initial_state = {
|
241 |
+
"explanation": html_content,
|
242 |
"image": image,
|
243 |
"is_multi_dog": False
|
244 |
}
|
245 |
+
return html_content, image, initial_state
|
246 |
|
247 |
# Case 3: Medium confidence - show top 3 breeds with relative probabilities
|
248 |
else:
|
249 |
+
breeds_html = ""
|
250 |
for i, (breed, prob) in enumerate(zip(topk_breeds, relative_probs)):
|
251 |
description = get_dog_description(breed)
|
252 |
+
formatted_description = format_description_html(description, breed) # 使用 format_description_html
|
253 |
+
breeds_html += f'''
|
254 |
+
<div class="dog-info-card">
|
255 |
+
<div class="breed-info">
|
256 |
+
<div class="breed-header">
|
257 |
+
<span class="breed-name">Breed {i+1}: {breed}</span>
|
258 |
+
<span class="confidence-badge">Confidence: {prob}</span>
|
259 |
+
</div>
|
260 |
+
{formatted_description}
|
261 |
+
</div>
|
262 |
+
</div>
|
263 |
+
'''
|
264 |
|
265 |
initial_state = {
|
266 |
+
"explanation": breeds_html,
|
267 |
"image": image,
|
268 |
"is_multi_dog": False
|
269 |
}
|
270 |
+
return breeds_html, image, initial_state
|
271 |
|
272 |
|
273 |
async def predict(image):
|