Ceyda Cinarel commited on
Commit
a3be375
1 Parent(s): da76319

change to small dataset

Browse files
Files changed (2) hide show
  1. app.py +7 -7
  2. demo.py +1 -4
app.py CHANGED
@@ -56,12 +56,12 @@ if ims is not None:
56
  # for r in retrieved_examples["image"]:
57
  # st.image(r)
58
 
59
- if any(picks):
60
- # st.write("Nearest butterflies:")
61
- for i,pick in enumerate(picks):
62
- if pick:
63
- scores, retrieved_examples=dataset.get_nearest_examples('beit_embeddings', embed(ims[i]), k=5)
64
- for r in retrieved_examples["image"]:
65
- cols[i].image(r)
66
 
67
 
 
56
  # for r in retrieved_examples["image"]:
57
  # st.image(r)
58
 
59
+ if any(picks):
60
+ # st.write("Nearest butterflies:")
61
+ for i,pick in enumerate(picks):
62
+ if pick:
63
+ scores, retrieved_examples=dataset.get_nearest_examples('beit_embeddings', embed(ims[i]), k=5)
64
+ for r in retrieved_examples["image"]:
65
+ cols[i].image(r)
66
 
67
 
demo.py CHANGED
@@ -2,12 +2,9 @@ import torch
2
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
3
  from datasets import load_dataset
4
 
5
- def get_train_data(dataset_name="ceyda/smithsonian_butterflies_transparent_cropped",data_limit=1000):
6
  dataset=load_dataset(dataset_name)
7
  dataset=dataset.sort("sim_score")
8
- score_thresh = dataset["train"][data_limit]['sim_score']
9
- dataset = dataset.filter(lambda x: x['sim_score'] < score_thresh)
10
- dataset = dataset.map(lambda x: {'image' : x['image'].convert("RGB")})
11
  return dataset["train"]
12
 
13
  from transformers import BeitFeatureExtractor, BeitForImageClassification
 
2
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
3
  from datasets import load_dataset
4
 
5
+ def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
6
  dataset=load_dataset(dataset_name)
7
  dataset=dataset.sort("sim_score")
 
 
 
8
  return dataset["train"]
9
 
10
  from transformers import BeitFeatureExtractor, BeitForImageClassification