DawnC commited on
Commit
7921180
1 Parent(s): 20aee3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -21
app.py CHANGED
@@ -143,6 +143,40 @@ def preprocess_image(image):
143
  def get_akc_breeds_link():
144
  return "https://www.akc.org/dog-breeds/"
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  def predict(image):
147
  try:
148
  image_tensor = preprocess_image(image)
@@ -152,30 +186,52 @@ def predict(image):
152
  logits = output[0]
153
  else:
154
  logits = output
155
- _, predicted = torch.max(logits, 1) # predicted is the max value's index on dim=1
156
- breed = dog_breeds[predicted.item()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
- description = get_dog_description(breed)
159
  akc_link = get_akc_breeds_link()
160
-
161
- if isinstance(description, dict):
162
- description_str = "\n\n".join([f"**{key}**: {value}" for key, value in description.items()])
163
- else:
164
- description_str = description
165
-
166
- # Add AKC link as an option
167
- description_str += f"\n\n**Want to learn more about dog breeds?** [Visit the AKC dog breeds page]({akc_link}) and search for {breed} to find detailed information."
168
-
169
- # Add disclaimer
170
- disclaimer = ("\n\n*Disclaimer: The external link provided leads to the American Kennel Club (AKC) dog breeds page. "
171
- "You may need to search for the specific breed on that page. "
172
- "I am not responsible for the content on external sites. "
173
- "Please refer to the AKC's terms of use and privacy policy.*")
174
- description_str += disclaimer
175
-
176
  return description_str
177
- except Exception as e:
178
- return f"An error occurred: {e}"
179
 
180
 
181
  iface = gr.Interface(
 
143
  def get_akc_breeds_link():
144
  return "https://www.akc.org/dog-breeds/"
145
 
146
+ # def predict(image):
147
+ # try:
148
+ # image_tensor = preprocess_image(image)
149
+ # with torch.no_grad():
150
+ # output = model(image_tensor)
151
+ # if isinstance(output, tuple):
152
+ # logits = output[0]
153
+ # else:
154
+ # logits = output
155
+ # _, predicted = torch.max(logits, 1) # predicted is the max value's index on dim=1
156
+ # breed = dog_breeds[predicted.item()]
157
+
158
+ # description = get_dog_description(breed)
159
+ # akc_link = get_akc_breeds_link()
160
+
161
+ # if isinstance(description, dict):
162
+ # description_str = "\n\n".join([f"**{key}**: {value}" for key, value in description.items()])
163
+ # else:
164
+ # description_str = description
165
+
166
+ # # Add AKC link as an option
167
+ # description_str += f"\n\n**Want to learn more about dog breeds?** [Visit the AKC dog breeds page]({akc_link}) and search for {breed} to find detailed information."
168
+
169
+ # # Add disclaimer
170
+ # disclaimer = ("\n\n*Disclaimer: The external link provided leads to the American Kennel Club (AKC) dog breeds page. "
171
+ # "You may need to search for the specific breed on that page. "
172
+ # "I am not responsible for the content on external sites. "
173
+ # "Please refer to the AKC's terms of use and privacy policy.*")
174
+ # description_str += disclaimer
175
+
176
+ # return description_str
177
+ # except Exception as e:
178
+ # return f"An error occurred: {e}"
179
+
180
  def predict(image):
181
  try:
182
  image_tensor = preprocess_image(image)
 
186
  logits = output[0]
187
  else:
188
  logits = output
189
+
190
+ # 計算預測的概率分佈
191
+ probabilities = F.softmax(logits, dim=1)
192
+
193
+ # 取得最高的預測分數以及對應的品種
194
+ top_confidence, top_index = torch.max(probabilities, 1)
195
+ top_confidence = top_confidence.item() # 轉成 Python 數值
196
+ top_breed = dog_breeds[top_index.item()]
197
+
198
+ # 如果最高預測分數大於等於 60%,直接返回該品種的資訊
199
+ if top_confidence >= 0.60:
200
+ description = get_dog_description(top_breed)
201
+ akc_link = get_akc_breeds_link()
202
+ description_str = f"**Breed**: {top_breed}\n\n**Description**: {description}\n"
203
+ description_str += f"\n\n**Want to learn more about dog breeds?** [Visit the AKC dog breeds page]({akc_link}) and search for {top_breed}."
204
+ return description_str
205
+
206
+ # 如果預測分數小於 60%,返回 Top-3 預測並讓用戶選擇
207
+ else:
208
+ top3_confidences, top3_indices = torch.topk(probabilities, 3, dim=1)
209
+ top3_breeds = [dog_breeds[idx] for idx in top3_indices.squeeze().tolist()]
210
+ top3_confidences = top3_confidences.squeeze().tolist()
211
+
212
+ return {
213
+ "top3_breeds": top3_breeds,
214
+ "top3_confidences": [f"{conf * 100:.2f}%" for conf in top3_confidences],
215
+ "selected_breed": None,
216
+ "message": "The confidence score is low. Please select the correct breed from the options or select 'None of the above' if none are correct."
217
+ }
218
+
219
+ except Exception as e:
220
+ return f"An error occurred: {e}"
221
+
222
+ # 處理用戶選擇的結果
223
+ def handle_user_selection(top3_breeds, selected_breed):
224
+ if selected_breed in top3_breeds:
225
+ breed_index = top3_breeds.index(selected_breed)
226
+ description = get_dog_description(selected_breed)
227
 
 
228
  akc_link = get_akc_breeds_link()
229
+ description_str = f"**Breed**: {selected_breed}\n\n**Description**: {description}\n"
230
+ description_str += f"\n\n**Want to learn more about dog breeds?** [Visit the AKC dog breeds page]({akc_link}) and search for {selected_breed}."
231
+
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  return description_str
233
+ else:
234
+ return "Sorry, the breed could not be identified. Please try uploading a clearer image or another breed."
235
 
236
 
237
  iface = gr.Interface(