Spaces:
Sleeping
Sleeping
Yahia battach
commited on
Commit
•
016de46
1
Parent(s):
7272ff8
edit app.py
Browse files
app.py
CHANGED
@@ -129,6 +129,53 @@ def format_name(taxon, common):
|
|
129 |
return f"{taxon} ({common})"
|
130 |
|
131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
@torch.no_grad()
|
133 |
def open_domain_classification(img, rank: int, return_all=False):
|
134 |
"""
|
@@ -136,7 +183,6 @@ def open_domain_classification(img, rank: int, return_all=False):
|
|
136 |
If targeting a higher rank than species, then this function predicts among all
|
137 |
species, then sums up species-level probabilities for the given rank.
|
138 |
"""
|
139 |
-
|
140 |
logger.info(f"Starting open domain classification for rank: {rank}")
|
141 |
img = preprocess_img(img).to(device)
|
142 |
img_features = model.encode_image(img.unsqueeze(0))
|
@@ -148,15 +194,13 @@ def open_domain_classification(img, rank: int, return_all=False):
|
|
148 |
if rank + 1 == len(ranks):
|
149 |
topk = probs.topk(k)
|
150 |
prediction_dict = {
|
151 |
-
format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
|
152 |
}
|
153 |
logger.info(f"Top K predictions: {prediction_dict}")
|
154 |
-
|
155 |
-
logger.info(f"Top prediction name: {top_prediction_name}")
|
156 |
-
sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
|
157 |
if return_all:
|
158 |
-
return prediction_dict,
|
159 |
-
return prediction_dict
|
160 |
|
161 |
output = collections.defaultdict(float)
|
162 |
for i in torch.nonzero(probs > min_prob).squeeze():
|
@@ -165,18 +209,11 @@ def open_domain_classification(img, rank: int, return_all=False):
|
|
165 |
topk_names = heapq.nlargest(k, output, key=output.get)
|
166 |
prediction_dict = {name: output[name] for name in topk_names}
|
167 |
logger.info(f"Top K names for output: {topk_names}")
|
168 |
-
|
169 |
-
|
170 |
-
top_prediction_name = topk_names[0]
|
171 |
-
logger.info(f"Top prediction name: {top_prediction_name}")
|
172 |
-
sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
|
173 |
-
logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
|
174 |
-
|
175 |
if return_all:
|
176 |
-
return prediction_dict,
|
177 |
return prediction_dict
|
178 |
|
179 |
-
|
180 |
def change_output(choice):
|
181 |
return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
|
182 |
|
@@ -310,12 +347,19 @@ if __name__ == "__main__":
|
|
310 |
fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
|
311 |
)
|
312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
open_domain_btn.click(
|
314 |
-
fn=lambda img, rank: open_domain_classification(img, rank, return_all=
|
315 |
inputs=[img_input, rank_dropdown],
|
316 |
outputs=[open_domain_output],
|
317 |
)
|
318 |
|
|
|
319 |
zero_shot_btn.click(
|
320 |
fn=zero_shot_classification,
|
321 |
inputs=[img_input_zs, classes_txt],
|
|
|
129 |
return f"{taxon} ({common})"
|
130 |
|
131 |
|
132 |
+
# @torch.no_grad()
|
133 |
+
# def open_domain_classification(img, rank: int, return_all=False):
|
134 |
+
# """
|
135 |
+
# Predicts from the entire tree of life.
|
136 |
+
# If targeting a higher rank than species, then this function predicts among all
|
137 |
+
# species, then sums up species-level probabilities for the given rank.
|
138 |
+
# """
|
139 |
+
|
140 |
+
# logger.info(f"Starting open domain classification for rank: {rank}")
|
141 |
+
# img = preprocess_img(img).to(device)
|
142 |
+
# img_features = model.encode_image(img.unsqueeze(0))
|
143 |
+
# img_features = F.normalize(img_features, dim=-1)
|
144 |
+
|
145 |
+
# logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
|
146 |
+
# probs = F.softmax(logits, dim=0)
|
147 |
+
|
148 |
+
# if rank + 1 == len(ranks):
|
149 |
+
# topk = probs.topk(k)
|
150 |
+
# prediction_dict = {
|
151 |
+
# format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
|
152 |
+
# }
|
153 |
+
# logger.info(f"Top K predictions: {prediction_dict}")
|
154 |
+
# top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
|
155 |
+
# logger.info(f"Top prediction name: {top_prediction_name}")
|
156 |
+
# sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
|
157 |
+
# if return_all:
|
158 |
+
# return prediction_dict, sample_img, taxon_url
|
159 |
+
# return prediction_dict
|
160 |
+
|
161 |
+
# output = collections.defaultdict(float)
|
162 |
+
# for i in torch.nonzero(probs > min_prob).squeeze():
|
163 |
+
# output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
|
164 |
+
|
165 |
+
# topk_names = heapq.nlargest(k, output, key=output.get)
|
166 |
+
# prediction_dict = {name: output[name] for name in topk_names}
|
167 |
+
# logger.info(f"Top K names for output: {topk_names}")
|
168 |
+
# logger.info(f"Prediction dictionary: {prediction_dict}")
|
169 |
+
|
170 |
+
# top_prediction_name = topk_names[0]
|
171 |
+
# logger.info(f"Top prediction name: {top_prediction_name}")
|
172 |
+
# sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
|
173 |
+
# logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
|
174 |
+
|
175 |
+
# if return_all:
|
176 |
+
# return prediction_dict, sample_img, taxon_url
|
177 |
+
# return prediction_dict
|
178 |
+
|
179 |
@torch.no_grad()
|
180 |
def open_domain_classification(img, rank: int, return_all=False):
|
181 |
"""
|
|
|
183 |
If targeting a higher rank than species, then this function predicts among all
|
184 |
species, then sums up species-level probabilities for the given rank.
|
185 |
"""
|
|
|
186 |
logger.info(f"Starting open domain classification for rank: {rank}")
|
187 |
img = preprocess_img(img).to(device)
|
188 |
img_features = model.encode_image(img.unsqueeze(0))
|
|
|
194 |
if rank + 1 == len(ranks):
|
195 |
topk = probs.topk(k)
|
196 |
prediction_dict = {
|
197 |
+
format_name(*txt_names[i]): prob.item() for i, prob in zip(topk.indices, topk.values)
|
198 |
}
|
199 |
logger.info(f"Top K predictions: {prediction_dict}")
|
200 |
+
|
|
|
|
|
201 |
if return_all:
|
202 |
+
return prediction_dict, None, None # Return dummy None values for unused parts
|
203 |
+
return prediction_dict # Only return the dictionary for the Label component
|
204 |
|
205 |
output = collections.defaultdict(float)
|
206 |
for i in torch.nonzero(probs > min_prob).squeeze():
|
|
|
209 |
topk_names = heapq.nlargest(k, output, key=output.get)
|
210 |
prediction_dict = {name: output[name] for name in topk_names}
|
211 |
logger.info(f"Top K names for output: {topk_names}")
|
212 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
if return_all:
|
214 |
+
return prediction_dict, None, None
|
215 |
return prediction_dict
|
216 |
|
|
|
217 |
def change_output(choice):
|
218 |
return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
|
219 |
|
|
|
347 |
fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
|
348 |
)
|
349 |
|
350 |
+
# open_domain_btn.click(
|
351 |
+
# fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
|
352 |
+
# inputs=[img_input, rank_dropdown],
|
353 |
+
# outputs=[open_domain_output],
|
354 |
+
# )
|
355 |
+
|
356 |
open_domain_btn.click(
|
357 |
+
fn=lambda img, rank: open_domain_classification(img, rank, return_all=False),
|
358 |
inputs=[img_input, rank_dropdown],
|
359 |
outputs=[open_domain_output],
|
360 |
)
|
361 |
|
362 |
+
|
363 |
zero_shot_btn.click(
|
364 |
fn=zero_shot_classification,
|
365 |
inputs=[img_input_zs, classes_txt],
|