sashavor commited on
Commit
78cdedf
·
1 Parent(s): 9b28e54
Files changed (1) hide show
  1. app.py +10 -29
app.py CHANGED
@@ -1,10 +1,8 @@
1
-
2
-
3
  import pickle
4
 
5
  import gradio as gr
6
  from datasets import load_dataset
7
- from transformers import AutoModel
8
 
9
 
10
  seed = 42
@@ -18,38 +16,21 @@ feature_extractor = AutoFeatureExtractor.from_pretrained("abhishek/autotrain-but
18
  model = AutoModel.from_pretrained("abhishek/autotrain-butterflies-new-17716425")
19
 
20
  # Candidate images.
21
- dataset = load_dataset("huggan/inat_butterflies_top10k")
22
- candidate_dataset = dataset["train"]
23
 
24
 
25
  def query(image, top_k):
26
  inputs = feature_extractor(image, return_tensors="pt")
27
  model_output = model(**inputs)
28
- embedding = model_output.pooler_output
29
  results = index.query(embedding)
30
-
31
- # Should be a list of string file paths for gr.Gallery to work
32
- images = []
33
- # List of labels for each image in the gallery
34
- labels = []
35
-
36
- candidates = []
37
-
38
- for idx, r in enumerate(sorted(results, key=results.get, reverse=True)):
39
- if idx == top_k:
40
- break
41
- image_id, label = r.split("_")[0], r.split("_")[1]
42
- candidates.append(candidate_dataset[int(image_id)]["image"])
43
- labels.append(f"Label: {label}")
44
-
45
- for i, candidate in enumerate(candidates):
46
- filename = f"similar_{i}.png"
47
- candidate.save(filename)
48
- images.append(filename)
49
-
50
- # The gallery component can be a list of tuples, where the first element is a path to a file
51
- # and the second element is an optional caption for that image
52
- return list(zip(images, labels))
53
 
54
 
55
  title = "Find my Butterfly 🦋"
 
 
 
1
  import pickle
2
 
3
  import gradio as gr
4
  from datasets import load_dataset
5
+ from transformers import AutoModel, AutoFeatureExtractor
6
 
7
 
8
  seed = 42
 
16
  model = AutoModel.from_pretrained("abhishek/autotrain-butterflies-new-17716425")
17
 
18
  # Candidate images.
19
+ dataset = load_dataset("sasha/butterflies_names_multiple")
20
+ ds = dataset["train"]
21
 
22
 
23
  def query(image, top_k):
24
  inputs = feature_extractor(image, return_tensors="pt")
25
  model_output = model(**inputs)
26
+ embedding = model_output.pooler_output.detach()
27
  results = index.query(embedding)
28
+ images=[]
29
+ for i in results[0].tolist():
30
+ print(i)
31
+ print(type(i))
32
+ images.append(ds.select(i)["image"])
33
+ return images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  title = "Find my Butterfly 🦋"