File size: 1,914 Bytes
9fbe234
b0b9e1f
9fbe234
b0b9e1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fbe234
 
 
 
 
b0b9e1f
 
9fbe234
b0b9e1f
 
 
 
9fbe234
 
 
 
 
 
b0b9e1f
 
9fbe234
b0b9e1f
9fbe234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0b9e1f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import re
import streamlit as st # HF spaces at v1.2.0
from demo import load_model,generate,get_dataset,embed

# TODO:
# - Add a short markdown README-style project introduction.
#
# Project setup:
#
# git clone https://github.com/huggingface/community-events.git
# cd community-events
# pip install .
 
# Page header rendered at the top of the app on every script run.
st.title("ButterflyGAN")
st.write("## This butterfly does not exist! ")
st.write("Demo prep still in progress!!")

@st.experimental_singleton
def load_model_intocache(model_name):
    """Load the GAN checkpoint ``model_name`` via ``demo.load_model``.

    Wrapped in ``st.experimental_singleton`` so the returned model object is
    cached and reused across Streamlit reruns instead of being reloaded.
    Another checkpoint name previously tried here: 'ceyda/butterfly_512_base'.
    """
    return load_model(model_name)

@st.experimental_singleton
def load_dataset():
    """Return the dataset used for the "Find Nearest" lookups.

    Wrapped in ``st.experimental_singleton`` so ``demo.get_dataset`` runs once
    and the same object is reused across Streamlit reruns.
    """
    return get_dataset()

# Checkpoint to serve; both loaders below are cached singletons.
model_name = 'ceyda/butterfly_cropped_uniq1K_512'
model = load_model_intocache(model_name)
dataset = load_dataset()

st.write(f"Model {model_name} is loaded")
st.write(f"Latent dimension: {model.latent_dim}, Image size:{model.image_size}")

# The generated batch lives in session_state so it persists between reruns;
# seed it with None until the first "Generate" click.
if 'ims' not in st.session_state:
    st.session_state['ims'] = None
ims = st.session_state["ims"]
batch_size = 4  # generate 4 butterflies


def run():
    """Generate a new batch of ``batch_size`` images and store it in session state.

    Registered below as the "Generate" button's ``on_click`` callback; the new
    batch is written to ``st.session_state['ims']`` for the rerun to display.
    """
    with st.spinner("Generating..."):
        ims = generate(model, batch_size)
        st.session_state['ims'] = ims


runb = st.button("Generate", on_click=run)

# Default to "nothing picked" so the code is safe before any images exist.
# Previously ``picks`` was only bound inside the branch below, so the later
# ``any(picks)`` raised NameError on the first page load.
picks = [False] * batch_size
if ims is not None:
    cols = st.columns(batch_size)
    for i, im in enumerate(ims):
        cols[i].image(im)
        picks[i] = cols[i].button("Find Nearest", key="pick_" + str(i))

    # For each butterfly whose button was clicked, show the k=5 closest
    # dataset images (by 'beit_embeddings') in the same column.
    for i, pick in enumerate(picks):
        if pick:
            scores, retrieved_examples = dataset.get_nearest_examples(
                'beit_embeddings', embed(ims[i]), k=5
            )
            for r in retrieved_examples["image"]:
                cols[i].image(r)