DawnC commited on
Commit
3d323ba
1 Parent(s): d96e417

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -90
app.py CHANGED
@@ -3,15 +3,15 @@ import numpy as np
3
  import torch
4
  import torch.nn as nn
5
  import gradio as gr
6
- from dataclasses import dataclass
7
  from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
8
  from torchvision.ops import nms, box_iou
9
  import torch.nn.functional as F
10
  from torchvision import transforms
11
  from PIL import Image, ImageDraw, ImageFont, ImageFilter
12
- from dog_database import get_dog_description
13
  from breed_health_info import breed_health_info
14
  from breed_noise_info import breed_noise_info
 
15
  from scoring_calculation_system import UserPreferences
16
  from recommendation_html_format import format_recommendation_html, get_breed_recommendations
17
  from history_manager import UserHistoryManager
@@ -42,19 +42,19 @@ model_yolo = YOLO('yolov8l.pt')
42
  history_manager = UserHistoryManager()
43
 
44
  dog_breeds = ["Afghan_Hound", "African_Hunting_Dog", "Airedale", "American_Staffordshire_Terrier",
45
- "Appenzeller", "Australian_Terrier", "Bedlington_Terrier", "Bernese_Mountain_Dog",
46
  "Blenheim_Spaniel", "Border_Collie", "Border_Terrier", "Boston_Bull", "Bouvier_Des_Flandres",
47
  "Brabancon_Griffon", "Brittany_Spaniel", "Cardigan", "Chesapeake_Bay_Retriever",
48
- "Chihuahua", "Dandie_Dinmont", "Doberman", "English_Foxhound", "English_Setter",
49
  "English_Springer", "EntleBucher", "Eskimo_Dog", "French_Bulldog", "German_Shepherd",
50
  "German_Short-Haired_Pointer", "Gordon_Setter", "Great_Dane", "Great_Pyrenees",
51
- "Greater_Swiss_Mountain_Dog", "Ibizan_Hound", "Irish_Setter", "Irish_Terrier",
52
  "Irish_Water_Spaniel", "Irish_Wolfhound", "Italian_Greyhound", "Japanese_Spaniel",
53
  "Kerry_Blue_Terrier", "Labrador_Retriever", "Lakeland_Terrier", "Leonberg", "Lhasa",
54
  "Maltese_Dog", "Mexican_Hairless", "Newfoundland", "Norfolk_Terrier", "Norwegian_Elkhound",
55
  "Norwich_Terrier", "Old_English_Sheepdog", "Pekinese", "Pembroke", "Pomeranian",
56
  "Rhodesian_Ridgeback", "Rottweiler", "Saint_Bernard", "Saluki", "Samoyed",
57
- "Scotch_Terrier", "Scottish_Deerhound", "Sealyham_Terrier", "Shetland_Sheepdog",
58
  "Shih-Tzu", "Siberian_Husky", "Staffordshire_Bullterrier", "Sussex_Spaniel",
59
  "Tibetan_Mastiff", "Tibetan_Terrier", "Walker_Hound", "Weimaraner",
60
  "Welsh_Springer_Spaniel", "West_Highland_White_Terrier", "Yorkshire_Terrier",
@@ -68,6 +68,7 @@ dog_breeds = ["Afghan_Hound", "African_Hunting_Dog", "Airedale", "American_Staff
68
  "Standard_Schnauzer", "Toy_Poodle", "Toy_Terrier", "Vizsla", "Whippet",
69
  "Wire-Haired_Fox_Terrier"]
70
 
 
71
  class MultiHeadAttention(nn.Module):
72
 
73
  def __init__(self, in_dim, num_heads=8):
@@ -122,15 +123,19 @@ class BaseModel(nn.Module):
122
  logits = self.classifier(attended_features)
123
  return logits, attended_features
124
 
125
-
126
- num_classes = 120
127
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
128
- model = BaseModel(num_classes=num_classes, device=device)
129
 
130
- checkpoint = torch.load('best_model_81_dog.pth', map_location=torch.device('cpu'))
131
- model.load_state_dict(checkpoint['model_state_dict'])
132
 
133
- # evaluation mode
 
 
 
 
 
134
  model.eval()
135
 
136
  # Image preprocessing function
@@ -149,24 +154,38 @@ def preprocess_image(image):
149
  return transform(image).unsqueeze(0)
150
 
151
  async def predict_single_dog(image):
152
- image_tensor = preprocess_image(image)
 
 
 
 
 
 
 
 
153
  with torch.no_grad():
154
- output = model(image_tensor)
155
- logits = output[0] if isinstance(output, tuple) else output
156
- probabilities = F.softmax(logits, dim=1)
157
- topk_probs, topk_indices = torch.topk(probabilities, k=3)
158
- top1_prob = topk_probs[0][0].item()
159
- topk_breeds = [dog_breeds[idx.item()] for idx in topk_indices[0]]
160
-
161
- # Calculate relative probabilities for display
162
- raw_probs = [prob.item() for prob in topk_probs[0]]
163
- sum_probs = sum(raw_probs)
164
- relative_probs = [f"{(prob/sum_probs * 100):.2f}%" for prob in raw_probs]
165
-
166
- return top1_prob, topk_breeds, relative_probs
167
-
168
-
169
- async def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.45):
 
 
 
 
 
 
170
  results = model_yolo(image, conf=conf_threshold, iou=iou_threshold)[0]
171
  dogs = []
172
  boxes = []
@@ -193,7 +212,6 @@ async def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.45):
193
 
194
  return dogs
195
 
196
-
197
  def non_max_suppression(boxes, iou_threshold):
198
  keep = []
199
  boxes = sorted(boxes, key=lambda x: x[1], reverse=True)
@@ -218,52 +236,6 @@ def calculate_iou(box1, box2):
218
  return iou
219
 
220
 
221
- async def process_single_dog(image):
222
- """Process a single dog image and return breed predictions and HTML output."""
223
- top1_prob, topk_breeds, relative_probs = await predict_single_dog(image)
224
-
225
- # Case 1: Low confidence - unclear image or breed not in dataset
226
- if top1_prob < 0.2:
227
- error_message = format_warning_html(
228
- 'The image is unclear or the breed is not in the dataset. Please upload a clearer image of a dog.'
229
- )
230
- initial_state = {
231
- "explanation": error_message,
232
- "image": None,
233
- "is_multi_dog": False
234
- }
235
- return error_message, None, initial_state
236
-
237
- breed = topk_breeds[0]
238
-
239
- # Case 2: High confidence - single breed result
240
- if top1_prob >= 0.45:
241
- description = get_dog_description(breed)
242
- html_content = format_single_dog_result(breed, description)
243
- initial_state = {
244
- "explanation": html_content,
245
- "image": image,
246
- "is_multi_dog": False
247
- }
248
- return html_content, image, initial_state
249
-
250
- # Case 3: Medium confidence - show top 3 breeds with relative probabilities
251
- description = get_dog_description(breed)
252
- breeds_html = format_multiple_breeds_result(
253
- topk_breeds=topk_breeds,
254
- relative_probs=relative_probs,
255
- color='#34C759', # 使用單狗顏色
256
- index=1, # 因為是單狗處理,所以index為1
257
- get_dog_description=get_dog_description
258
- )
259
-
260
- initial_state = {
261
- "explanation": breeds_html,
262
- "image": image,
263
- "is_multi_dog": False
264
- }
265
- return breeds_html, image, initial_state
266
-
267
 
268
  def create_breed_comparison(breed1: str, breed2: str) -> dict:
269
  breed1_info = get_dog_description(breed1)
@@ -353,21 +325,46 @@ async def predict(image):
353
  top1_prob, topk_breeds, relative_probs = await predict_single_dog(cropped_image)
354
  combined_confidence = detection_confidence * top1_prob
355
 
356
- # Format results based on confidence
357
- if combined_confidence < 0.2:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  dogs_info += format_error_message(color, i+1)
359
- elif top1_prob >= 0.45:
360
- breed = topk_breeds[0]
361
- description = get_dog_description(breed)
362
- dogs_info += format_single_dog_result(breed, description, color)
363
- else:
364
- dogs_info += format_multiple_breeds_result(
365
- topk_breeds,
366
- relative_probs,
367
- color,
368
- i+1,
369
- get_dog_description
370
- )
371
 
372
  # Wrap final HTML output
373
  html_output = format_multi_dog_container(dogs_info)
@@ -422,6 +419,7 @@ def show_details_html(choice, previous_output, initial_state):
422
  def main():
423
  with gr.Blocks(css=get_css_styles()) as iface:
424
  # Header HTML
 
425
  gr.HTML("""
426
  <header style='text-align: center; padding: 20px; margin-bottom: 20px;'>
427
  <h1 style='font-size: 2.5em; margin-bottom: 10px; color: #2D3748;'>
@@ -467,6 +465,7 @@ def main():
467
  history_component=history_component
468
  )
469
 
 
470
  # 4. 最後創建歷史記錄標籤頁
471
  create_history_tab(history_component)
472
 
 
3
  import torch
4
  import torch.nn as nn
5
  import gradio as gr
6
+ import time
7
  from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
8
  from torchvision.ops import nms, box_iou
9
  import torch.nn.functional as F
10
  from torchvision import transforms
11
  from PIL import Image, ImageDraw, ImageFont, ImageFilter
 
12
  from breed_health_info import breed_health_info
13
  from breed_noise_info import breed_noise_info
14
+ from dog_database import get_dog_description
15
  from scoring_calculation_system import UserPreferences
16
  from recommendation_html_format import format_recommendation_html, get_breed_recommendations
17
  from history_manager import UserHistoryManager
 
42
  history_manager = UserHistoryManager()
43
 
44
  dog_breeds = ["Afghan_Hound", "African_Hunting_Dog", "Airedale", "American_Staffordshire_Terrier",
45
+ "Appenzeller", "Australian_Terrier", "Bedlington_Terrier", "Bernese_Mountain_Dog", "Bichon_Frise",
46
  "Blenheim_Spaniel", "Border_Collie", "Border_Terrier", "Boston_Bull", "Bouvier_Des_Flandres",
47
  "Brabancon_Griffon", "Brittany_Spaniel", "Cardigan", "Chesapeake_Bay_Retriever",
48
+ "Chihuahua", "Dachshund", "Dandie_Dinmont", "Doberman", "English_Foxhound", "English_Setter",
49
  "English_Springer", "EntleBucher", "Eskimo_Dog", "French_Bulldog", "German_Shepherd",
50
  "German_Short-Haired_Pointer", "Gordon_Setter", "Great_Dane", "Great_Pyrenees",
51
+ "Greater_Swiss_Mountain_Dog","Havanese", "Ibizan_Hound", "Irish_Setter", "Irish_Terrier",
52
  "Irish_Water_Spaniel", "Irish_Wolfhound", "Italian_Greyhound", "Japanese_Spaniel",
53
  "Kerry_Blue_Terrier", "Labrador_Retriever", "Lakeland_Terrier", "Leonberg", "Lhasa",
54
  "Maltese_Dog", "Mexican_Hairless", "Newfoundland", "Norfolk_Terrier", "Norwegian_Elkhound",
55
  "Norwich_Terrier", "Old_English_Sheepdog", "Pekinese", "Pembroke", "Pomeranian",
56
  "Rhodesian_Ridgeback", "Rottweiler", "Saint_Bernard", "Saluki", "Samoyed",
57
+ "Scotch_Terrier", "Scottish_Deerhound", "Sealyham_Terrier", "Shetland_Sheepdog", "Shiba_Inu",
58
  "Shih-Tzu", "Siberian_Husky", "Staffordshire_Bullterrier", "Sussex_Spaniel",
59
  "Tibetan_Mastiff", "Tibetan_Terrier", "Walker_Hound", "Weimaraner",
60
  "Welsh_Springer_Spaniel", "West_Highland_White_Terrier", "Yorkshire_Terrier",
 
68
  "Standard_Schnauzer", "Toy_Poodle", "Toy_Terrier", "Vizsla", "Whippet",
69
  "Wire-Haired_Fox_Terrier"]
70
 
71
+
72
  class MultiHeadAttention(nn.Module):
73
 
74
  def __init__(self, in_dim, num_heads=8):
 
123
  logits = self.classifier(attended_features)
124
  return logits, attended_features
125
 
126
+ # Initialize model
127
+ num_classes = len(dog_breeds)
128
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
129
 
130
+ # Initialize base model
131
+ model = BaseModel(num_classes=num_classes, device=device).to(device)
132
 
133
+ # Load model path
134
+ model_path = "124_best_model_dog.pth"
135
+ checkpoint = torch.load(model_path, map_location=device)
136
+
137
+ # Load model state
138
+ model.load_state_dict(checkpoint["base_model"], strict=False)
139
  model.eval()
140
 
141
  # Image preprocessing function
 
154
  return transform(image).unsqueeze(0)
155
 
156
  async def predict_single_dog(image):
157
+ """
158
+ Predicts the dog breed using only the classifier.
159
+ Args:
160
+ image: PIL Image or numpy array
161
+ Returns:
162
+ tuple: (top1_prob, topk_breeds, relative_probs)
163
+ """
164
+ image_tensor = preprocess_image(image).to(device)
165
+
166
  with torch.no_grad():
167
+ # Get model outputs (只使用logits,不需要features)
168
+ logits = model(image_tensor)[0] # 如果model仍返回tuple,取第一個元素
169
+ probs = F.softmax(logits, dim=1)
170
+
171
+ # Classifier prediction
172
+ top5_prob, top5_idx = torch.topk(probs, k=5)
173
+ breeds = [dog_breeds[idx.item()] for idx in top5_idx[0]]
174
+ probabilities = [prob.item() for prob in top5_prob[0]]
175
+
176
+ # Calculate relative probabilities
177
+ sum_probs = sum(probabilities[:3]) # 只取前三個來計算相對概率
178
+ relative_probs = [f"{(prob/sum_probs * 100):.2f}%" for prob in probabilities[:3]]
179
+
180
+ # Debug output
181
+ print("\nClassifier Predictions:")
182
+ for breed, prob in zip(breeds[:5], probabilities[:5]):
183
+ print(f"{breed}: {prob:.4f}")
184
+
185
+ return probabilities[0], breeds[:3], relative_probs
186
+
187
+
188
+ async def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
189
  results = model_yolo(image, conf=conf_threshold, iou=iou_threshold)[0]
190
  dogs = []
191
  boxes = []
 
212
 
213
  return dogs
214
 
 
215
  def non_max_suppression(boxes, iou_threshold):
216
  keep = []
217
  boxes = sorted(boxes, key=lambda x: x[1], reverse=True)
 
236
  return iou
237
 
238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
  def create_breed_comparison(breed1: str, breed2: str) -> dict:
241
  breed1_info = get_dog_description(breed1)
 
325
  top1_prob, topk_breeds, relative_probs = await predict_single_dog(cropped_image)
326
  combined_confidence = detection_confidence * top1_prob
327
 
328
+ # Format results based on confidence with error handling
329
+ try:
330
+ if combined_confidence < 0.2:
331
+ dogs_info += format_error_message(color, i+1)
332
+ elif top1_prob >= 0.45:
333
+ breed = topk_breeds[0]
334
+ description = get_dog_description(breed)
335
+ # Handle missing breed description
336
+ if description is None:
337
+ # 如果沒有描述,創建一個基本描述
338
+ description = {
339
+ "Name": breed,
340
+ "Size": "Unknown",
341
+ "Exercise Needs": "Unknown",
342
+ "Grooming Needs": "Unknown",
343
+ "Care Level": "Unknown",
344
+ "Good with Children": "Unknown",
345
+ "Description": f"Identified as {breed.replace('_', ' ')}"
346
+ }
347
+ dogs_info += format_single_dog_result(breed, description, color)
348
+ else:
349
+ # 修改format_multiple_breeds_result的調用,包含錯誤處理
350
+ dogs_info += format_multiple_breeds_result(
351
+ topk_breeds,
352
+ relative_probs,
353
+ color,
354
+ i+1,
355
+ lambda breed: get_dog_description(breed) or {
356
+ "Name": breed,
357
+ "Size": "Unknown",
358
+ "Exercise Needs": "Unknown",
359
+ "Grooming Needs": "Unknown",
360
+ "Care Level": "Unknown",
361
+ "Good with Children": "Unknown",
362
+ "Description": f"Identified as {breed.replace('_', ' ')}"
363
+ }
364
+ )
365
+ except Exception as e:
366
+ print(f"Error formatting results for dog {i+1}: {str(e)}")
367
  dogs_info += format_error_message(color, i+1)
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
  # Wrap final HTML output
370
  html_output = format_multi_dog_container(dogs_info)
 
419
  def main():
420
  with gr.Blocks(css=get_css_styles()) as iface:
421
  # Header HTML
422
+
423
  gr.HTML("""
424
  <header style='text-align: center; padding: 20px; margin-bottom: 20px;'>
425
  <h1 style='font-size: 2.5em; margin-bottom: 10px; color: #2D3748;'>
 
465
  history_component=history_component
466
  )
467
 
468
+
469
  # 4. 最後創建歷史記錄標籤頁
470
  create_history_tab(history_component)
471