sashavor commited on
Commit
0d69242
1 Parent(s): c45624f

adding new index and changing app a bit

Browse files
Files changed (3) hide show
  1. app.py +8 -8
  2. index_768.pickle +3 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -8,23 +8,23 @@ from transformers import AutoModel, AutoFeatureExtractor
8
  seed = 42
9
 
10
  # Only runs once when the script is first run.
11
- with open("index.pickle", "rb") as handle:
12
  index = pickle.load(handle)
13
 
14
  # Load model for computing embeddings.
15
- feature_extractor = AutoFeatureExtractor.from_pretrained("abhishek/autotrain-butterflies-new-17716425")
16
- model = AutoModel.from_pretrained("abhishek/autotrain-butterflies-new-17716425")
17
 
18
  # Candidate images.
19
- dataset = load_dataset("sasha/butterflies_names_multiple")
20
  ds = dataset["train"]
21
 
22
 
23
- def query(image, top_k):
24
  inputs = feature_extractor(image, return_tensors="pt")
25
  model_output = model(**inputs)
26
  embedding = model_output.pooler_output.detach()
27
- results = index.query(embedding)
28
  inx = results[0][0].tolist()
29
  images = ds.select(inx)["image"]
30
  return images
@@ -37,8 +37,8 @@ description = "This Space demos an image similarity system. You can refer to [th
37
  # Not sure what the best for this demo is.
38
  gr.Interface(
39
  query,
40
- inputs=[gr.Image(type="pil"), gr.Slider(value=5, minimum=1, maximum=10, step=1)],
41
- outputs=gr.Gallery().style(grid=[3], height="auto"),
42
  # Filenames denote the integer labels. Know here: https://hf.co/datasets/beans
43
  title=title,
44
  description=description,
 
8
  seed = 42
9
 
10
  # Only runs once when the script is first run.
11
+ with open("index_768.pickle", "rb") as handle:
12
  index = pickle.load(handle)
13
 
14
  # Load model for computing embeddings.
15
+ feature_extractor = AutoFeatureExtractor.from_pretrained("sasha/autotrain-butterfly-similarity-2490576840")
16
+ model = AutoModel.from_pretrained("sasha/autotrain-butterfly-similarity-2490576840")
17
 
18
  # Candidate images.
19
+ dataset = load_dataset("sasha/butterflies_10k_names_multiple")
20
  ds = dataset["train"]
21
 
22
 
23
+ def query(image, top_k=4):
24
  inputs = feature_extractor(image, return_tensors="pt")
25
  model_output = model(**inputs)
26
  embedding = model_output.pooler_output.detach()
27
+ results = index.query(embedding, k=top_k)
28
  inx = results[0][0].tolist()
29
  images = ds.select(inx)["image"]
30
  return images
 
37
  # Not sure what the best for this demo is.
38
  gr.Interface(
39
  query,
40
+ inputs=[gr.Image(type="pil")],
41
+ outputs=gr.Gallery().style(grid=[2], height="auto"),
42
  # Filenames denote the integer labels. Know here: https://hf.co/datasets/beans
43
  title=title,
44
  description=description,
index_768.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eccd83bb743f6de5eaf05886f948d95daceeede60805ad01cdba0baddd1a60cc
3
+ size 53317256
requirements.txt CHANGED
@@ -2,4 +2,5 @@ transformers==4.25.1
2
  datasets==2.7.1
3
  numpy==1.21.6
4
  torch==1.12.1
5
- torchvision
 
 
2
  datasets==2.7.1
3
  numpy==1.21.6
4
  torch==1.12.1
5
+ torchvision
6
+ pynndescent