# butterfly-gan / app.py
# ceyda's picture
# Update app.py
# e3c61c8
from distutils.command.build import build
import streamlit as st # HF spaces at v1.2.0
from demo import load_model,generate,get_dataset,embed,make_meme
import streamlit.components.v1 as components
import io
import os
# Resolve the compiled frontend of the custom Streamlit component relative to
# this file, then register the component under the name "release_butterflies".
root_dir = os.path.dirname(os.path.abspath(__file__))
build_dir = os.path.join(root_dir, "custom_component/frontend/build")
_component_func = components.declare_component(
    "release_butterflies",
    path=build_dir,
)
def release_butterflies(name, key=None):
    """Render the "release_butterflies" custom component and return its value.

    Args:
        name: forwarded to the frontend as the ``name`` argument.
        key: optional Streamlit widget key.

    Returns:
        Whatever the component reports back; ``0`` is passed as the default.
    """
    return _component_func(name=name, key=key, default=0)
# Page chrome: main title plus the sidebar tagline and logo.
st.title("ButterflyGAN")
st.sidebar.subheader("This butterfly does not exist! ")
st.sidebar.image("assets/logo.png", width=200)
@st.experimental_singleton
def load_model_intocache(model_name, model_version):
    """Load the GAN through ``load_model``, cached by ``st.experimental_singleton``.

    Args:
        model_name: model identifier, e.g. 'ceyda/butterfly_512_base'.
        model_version: revision passed straight through to ``load_model``.

    Returns:
        The object returned by ``load_model``.
    """
    return load_model(model_name, model_version)
@st.experimental_singleton
def load_dataset():
    """Return the dataset from ``get_dataset``, cached by ``st.experimental_singleton``."""
    return get_dataset()
@st.experimental_singleton
def load_variables():
    """Read the two latent-walk code snippets shown in the app.

    Wrapped in ``st.experimental_singleton`` so the files are not opened and
    read over and over.

    Returns:
        A ``(latent_walk_code, latent_walk_code_music)`` tuple holding the
        text of ``latent_walk.py`` and ``latent_walk_music.py``.

    Raises:
        OSError: if either snippet file cannot be opened.
    """
    snippet_paths = (
        "assets/code_snippets/latent_walk.py",
        "assets/code_snippets/latent_walk_music.py",
    )
    snippets = []
    for path in snippet_paths:
        # Context manager closes the handle deterministically; the snippets are
        # Python sources, so they are decoded as UTF-8 regardless of locale.
        with open(path, encoding="utf-8") as snippet_file:
            snippets.append(snippet_file.read())
    latent_walk_code, latent_walk_code_music = snippets
    return latent_walk_code, latent_walk_code_music
def img2download(image):
    """Encode *image* as JPEG in memory and return the encoded bytes.

    Args:
        image: any object exposing ``save(file_obj, format=...)``.

    Returns:
        bytes: everything ``image.save`` wrote to the in-memory buffer.
    """
    with io.BytesIO() as buffer:
        image.save(buffer, format="JPEG")
        return buffer.getvalue()
# Model checkpoint shown in the demo, pinned to a specific revision hash.
model_name = 'ceyda/butterfly_cropped_uniq1K_512'
model_version = '57d36a15546909557d9f967f47713236c8288838'
# model_version=None

# Heavy resources come from the singleton-cached loaders defined above.
model = load_model_intocache(model_name, model_version)
dataset = load_dataset()
latent_walk_code, latent_walk_code_music = load_variables()
# Sidebar menu labels. The same strings are compared against the radio
# selection to pick the page, so each label must stay unique.
# The emoji prefixes were mis-decoded UTF-8 (mojibake) and are restored here.
generate_menu = "🦋 Make butterflies"
latent_walk_menu = "🎧 Take a latent walk"
make_meme_menu = "🐦 Make a meme"
mosaic_menu = "👀 See the mosaic"
fun_menu = "🙌 Release the butterflies"
# Sidebar navigation: the selected label decides which page is rendered below.
screen = st.sidebar.radio(
    "Pick a destination",
    [generate_menu, latent_walk_menu, make_meme_menu, mosaic_menu, fun_menu],
)

if screen == generate_menu:
    # --- Page: generate a batch of butterflies and look up nearest neighbours.
    batch_size = 4  # generate 4 butterflies
    col_num = 4

    def run():
        """Generate a new batch and store it in the session state under 'ims'."""
        with st.spinner("Generating..."):
            ims = generate(model, batch_size)
            st.session_state['ims'] = ims

    # First render of this session: no batch stored yet, so create one.
    if 'ims' not in st.session_state:
        st.session_state['ims'] = None
        run()

    ims = st.session_state["ims"]
    st.write(
        "Light-GAN model trained on 1000 butterfly images taken from the Smithsonian Museum collection. \n "
        "Based on [paper:](https://openreview.net/forum?id=1Fqg133qRaI) *Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis*"
    )
    runb = st.button("Generate", on_click=run, help="generated on the fly maybe slow")

    if ims is not None:
        cols = st.columns(col_num)
        # One "Find Nearest" button per generated image; sized from the batch
        # actually stored in the session rather than from batch_size.
        picks = []
        for j, im in enumerate(ims):
            column = cols[j % col_num]
            column.image(im, use_column_width=True)
            picks.append(column.button("Find Nearest", key="pick_" + str(j)))

        if any(picks):
            # st.write("Nearest butterflies:")
            for i, pick in enumerate(picks):
                if not pick:
                    continue
                scores, retrieved_examples = dataset.get_nearest_examples(
                    'beit_embeddings', embed(ims[i]), k=5
                )
                # Wrap the index exactly like the image loop above so the
                # neighbours land under their source image and never index
                # past the last column when batch_size > col_num.
                target_col = cols[i % col_num]
                for r in retrieved_examples["image"]:
                    target_col.image(r, use_column_width=True)
            # NOTE(review): nesting reconstructed from a flattened source;
            # this caption is assumed to belong to the `if any(picks)` branch.
            st.write("Nearest neighbors found in the training set according to L2 distance on 'microsoft/beit-base-patch16-224' embeddings")

    # NOTE(review): assumed to be shown on every render of this page — confirm.
    st.write(f"Latent dimension: {model.latent_dim}, image size:{model.image_size}")

elif screen == latent_walk_menu:
    # --- Page: pre-rendered latent-walk videos plus the code that made them.
    st.write("Take a latent walk :musical_note: with cute butterflies")

    cols = st.columns(3)
    cols[0].caption("A regular walk (no music)")
    cols[0].video("assets/latent_walks/regular_walk.mp4")

    cols[1].caption("Walk with music :butterfly:")
    cols[1].video("assets/latent_walks/walk_happyrock.mp4")

    cols[2].caption("Walk with music :butterfly:")
    cols[2].video("assets/latent_walks/walk_cute.mp4")

    st.caption("Royalty Free Music from Bensound")
    st.write("🎧Did those butterflies seem to be dancing to the music?!Here is the secret:")
    with st.expander("See the Code Snippets"):
        st.write("A regular latent walk:")
        st.code(latent_walk_code, language='python')
        st.write(":musical_note: latent walk with music:")
        st.code(latent_walk_code_music, language='python')

elif screen == make_meme_menu:
    # --- Page: turn one generated butterfly into an "Is this a pigeon?" meme.
    if "pigeon" not in st.session_state:
        st.session_state['pigeon'] = generate(model, 1)[0]

    def get_pigeon():
        """Replace the butterfly stored under 'pigeon' with a newly generated one."""
        st.session_state['pigeon'] = generate(model, 1)[0]

    cols = st.columns(2)
    cols[0].button("change pigeon", on_click=get_pigeon)
    no_bg = cols[1].checkbox("Remove background?", True, help="Remove the background from pigeon")
    show_text = cols[1].checkbox("Show text?", True)

    meme_text = st.text_input("Enter text", "Is this a pigeon?")

    meme = make_meme(
        st.session_state['pigeon'],
        text=meme_text,
        show_text=show_text,
        remove_background=no_bg,
    )
    st.image(meme)

    coly = st.columns(2)
    coly[0].download_button("Download", img2download(meme), mime="image/jpeg")
    coly[1].write("Made a cool one? [Share](https://twitter.com/intent/tweet?text=Check%20out%20the%20demo%20for%20Butterfly%20GAN%20%F0%9F%A6%8Bhttps%3A//huggingface.co/spaces/huggan/butterfly-gan%0Amade%20by%20%40ceyda_cinarel%20%26%20%40johnowhitaker%20) on Twitter")

elif screen == mosaic_menu:
    # --- Page: side-by-side mosaics of training data and generated samples.
    # The magnifying-glass emoji in both captions was mojibake; restored.
    cols = st.columns(2)

    cols[0].markdown("These are all the butterflies in our [training set](https://huggingface.co/huggan/smithsonian_butterflies_subset)")
    cols[0].image("assets/train_data_mosaic_lowres.jpg")
    cols[0].write("🔎 view the high-res version [here](https://www.easyzoom.com/imageaccess/0c77e0e716f14ea7bc235447e5a4c397)")

    cols[1].markdown("These are the butterflies our model generated.")
    cols[1].image("assets/gen_mosaic_lowres.jpg")
    cols[1].write("🔎 view the high-res version [here](https://www.easyzoom.com/imageaccess/cbb04e81106c4c54a9d9f9dbfb236eab)")

elif screen == fun_menu:
    # --- Page: the custom "release the butterflies" component.
    cols = st.columns([1, 2])
    cols[0].write("While working on this project")
    cols[0].image("assets/butterflies_everywhere.jpg")
    with cols[1]:
        release_butterflies("Hello World")
# footer stuff
# Link project repo( scripts etc )
# Credits
_FOOTER_CAPTIONS = (
    "[Model](https://huggingface.co/ceyda/butterfly_cropped_uniq1K_512) & [Dataset](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) used",
    "Made during the [huggan](https://github.com/huggingface/community-events) hackathon",
    "Contributors:",
    "[Ceyda Cinarel](https://github.com/cceyda) & [Jonathan Whitaker](https://datasciencecastnet.home.blog/)",
)
# Each entry becomes its own sidebar caption, in the order listed.
for _footer_caption in _FOOTER_CAPTIONS:
    st.sidebar.caption(_footer_caption)
## Feel free to add more & change stuff ^