Commit 9fbe234
Parent(s): b0b9e1f
Ceyda Cinarel committed: add nearest neighbor

Files changed:
- .gitattributes +1 -0
- app.py +34 -8
- beit_index.faiss +3 -0
- demo.py +23 -4
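In short, this commit wires nearest-neighbor image retrieval into the butterfly GAN demo: demo.py gains a BEiT-based embed() plus FAISS index build/load helpers, app.py adds a per-image "Find Nearest" button backed by st.session_state, the prebuilt index is checked in as beit_index.faiss, and .gitattributes routes *.faiss files through Git LFS.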
.gitattributes CHANGED
@@ -26,3 +26,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.faiss filter=lfs diff=lfs merge=lfs -text
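This one-line addition tracks FAISS index files with Git LFS, like the zip/zstandard/tfevents patterns already listed; running `git lfs track "*.faiss"` from the repo root would append the same rule. It has to land before beit_index.faiss is committed, so that the index added below is stored as an LFS object rather than a regular git blob.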
app.py CHANGED
@@ -1,5 +1,6 @@
+import re
 import streamlit as st # HF spaces at v1.2.0
-from demo import load_model,generate
+from demo import load_model,generate,get_dataset,embed
 
 # TODOs
 # Add markdown short readme project intro
@@ -21,21 +22,46 @@ def load_model_intocache(model_name):
 
     return gan
 
+@st.experimental_singleton
+def load_dataset():
+    dataset=get_dataset()
+    return dataset
+
 model_name='ceyda/butterfly_cropped_uniq1K_512'
 model=load_model_intocache(model_name)
+dataset=load_dataset()
 
 st.write(f"Model {model_name} is loaded")
 st.write(f"Latent dimension: {model.latent_dim}, Image size:{model.image_size}")
 
-
-
+if 'ims' not in st.session_state:
+    st.session_state['ims'] = None
+
+ims=st.session_state["ims"]
+batch_size=4 #generate 4 butterflies
+def run():
     with st.spinner("Generating..."):
-
-        batch_size=4 #generate 4 butterflies
         ims=generate(model,batch_size)
+        st.session_state['ims'] = ims
 
-
-
-
+runb=st.button("Generate", on_click=run)
+if ims is not None:
+    cols=st.columns(batch_size)
+    picks=[False]*batch_size
+    for i,im in enumerate(ims):
+        cols[i].image(im)
+        picks[i]=cols[i].button("Find Nearest",key="pick_"+str(i))
+        # if picks[i]:
+        #     scores, retrieved_examples=dataset.get_nearest_examples('beit_embeddings', embed(im), k=5)
+        #     for r in retrieved_examples["image"]:
+        #         st.image(r)
+
+if any(picks):
+    # st.write("Nearest butterflies:")
+    for i,pick in enumerate(picks):
+        if pick:
+            scores, retrieved_examples=dataset.get_nearest_examples('beit_embeddings', embed(ims[i]), k=5)
+            for r in retrieved_examples["image"]:
+                cols[i].image(r)
 
 
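Note on the pattern above: Streamlit reruns the whole script on every widget interaction, so the generated batch is parked in st.session_state and run() is attached to the Generate button via on_click; the per-image "Find Nearest" buttons then trigger a rerun that still sees the cached images. One caveat visible in the diff: picks is only assigned inside the `if ims is not None:` branch, so `if any(picks):` would raise a NameError on a fresh session before the first generation. A minimal standalone sketch of the same pattern (the random-image generator is a stand-in for the GAN call, and picks is initialized unconditionally to avoid that caveat):

    import numpy as np
    import streamlit as st

    if "ims" not in st.session_state:
        st.session_state["ims"] = None  # persists across script reruns

    def run():
        # stand-in for generate(model, batch_size): four random RGB images
        st.session_state["ims"] = [np.random.rand(64, 64, 3) for _ in range(4)]

    st.button("Generate", on_click=run)  # on_click fires before the rerun renders
    picks = []
    if st.session_state["ims"] is not None:
        cols = st.columns(4)
        for i, im in enumerate(st.session_state["ims"]):
            cols[i].image(im)
            picks.append(cols[i].button("Find Nearest", key=f"pick_{i}"))
    # picks is always defined here, so this check is safe even before generating
    if any(picks):
        st.write("a pick was clicked on this rerun")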
beit_index.faiss ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d56496f69d06d78867ab39298a5354c0419056000824d82b06db343449c4518d
+size 3072045
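The three lines above are the Git LFS pointer that git stores in place of the binary; the actual index (3072045 bytes per the size field) lives in LFS storage and is materialized on checkout. Once pulled, the file can be sanity-checked directly with faiss — a minimal sketch, assuming the faiss-cpu package is installed:

    import faiss

    index = faiss.read_index("beit_index.faiss")
    # d = vector dimensionality, ntotal = number of indexed embeddings
    print(index.d, index.ntotal)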
demo.py CHANGED
@@ -7,15 +7,34 @@ def get_train_data(dataset_name="ceyda/smithsonian_butterflies_transparent_cropp
     dataset=dataset.sort("sim_score")
     score_thresh = dataset["train"][data_limit]['sim_score']
     dataset = dataset.filter(lambda x: x['sim_score'] < score_thresh)
-
-    dataset = dataset.map(lambda x: x.convert("RGB"))
+    dataset = dataset.map(lambda x: {'image' : x['image'].convert("RGB")})
     return dataset["train"]
+
+from transformers import BeitFeatureExtractor, BeitForImageClassification
+feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
+model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
+def embed(images):
+    inputs = feature_extractor(images=images, return_tensors="pt")
+    outputs = model(**inputs,output_hidden_states= True)
+    last_hidden=outputs.hidden_states[-1]
+    pooler=model.base_model.pooler
+    final_emb=pooler(last_hidden).detach().numpy()
+    return final_emb
 
-
+def build_index():
+    dataset=get_train_data()
+    ds_with_embeddings = dataset.map(lambda x: {"beit_embeddings":embed(x["image"])},batched=True,batch_size=20)
+    ds_with_embeddings.add_faiss_index(column='beit_embeddings')
+    ds_with_embeddings.save_faiss_index('beit_embeddings', 'beit_index.faiss')
+
+def get_dataset():
+    dataset=get_train_data()
+    dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
+    return dataset
 
 def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512'):
     gan = LightweightGAN.from_pretrained(model_name)
-    gan.eval()
+    gan.eval()
     return gan
 
 def generate(gan,batch_size=1):
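Taken together, demo.py now owns the whole retrieval pipeline: embed() pushes images through BEiT and keeps the pooler output as a fixed-size vector, build_index() (run once, offline; its output is presumably the committed beit_index.faiss) embeds the training set in batches of 20 and builds a FAISS index over the beit_embeddings column, and get_dataset() reloads the training data with that index attached. A minimal query-side sketch, assuming the LFS index has been pulled and generate() returns images in a format the BEiT feature extractor accepts:

    from demo import load_model, generate, get_dataset, embed

    dataset = get_dataset()              # training images + prebuilt FAISS index
    model = load_model()
    ims = generate(model, batch_size=1)  # one generated butterfly

    query = embed(ims[0])                # shape (1, 768): BEiT pooler embedding
    scores, retrieved = dataset.get_nearest_examples("beit_embeddings", query, k=5)
    for img in retrieved["image"]:       # PIL images from the training set
        img.show()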