butterfly-gan / app.py
Ceyda Cinarel
make demo prettier half way there
47cfe13
raw
history blame
4.31 kB
from pydoc import ModuleScanner
import re
import streamlit as st # HF spaces at v1.2.0
from demo import load_model,generate,get_dataset,embed,make_meme
from PIL import Image
import numpy as np
# TODO: add a short markdown readme / project intro to the landing page.

# Sidebar branding: tagline followed by the project logo.
sidebar = st.sidebar
sidebar.subheader("This butterfly does not exist! ")
sidebar.image("assets/logo.png", width=200)

# Main-area title and work-in-progress notice.
st.header("ButterflyGAN")
st.write("Demo prep still in progress!! Come back later")
@st.experimental_singleton
def load_model_intocache(model_name, model_version):
    """Load the GAN through ``demo.load_model`` once and reuse it on reruns.

    ``st.experimental_singleton`` keeps the returned object cached, so
    Streamlit's top-to-bottom script reruns do not reload the checkpoint.

    Args:
        model_name: checkpoint identifier forwarded to ``load_model``
            (e.g. ``'ceyda/butterfly_512_base'``).
        model_version: revision forwarded to ``load_model``; the caller in
            this file currently passes ``None``.

    Returns:
        Whatever ``demo.load_model`` returns, unchanged.
    """
    return load_model(model_name, model_version)
@st.experimental_singleton
def load_dataset():
    """Fetch the dataset via ``demo.get_dataset`` once and reuse it on reruns.

    The returned object is queried elsewhere in this app with
    ``get_nearest_examples('beit_embeddings', ...)`` for the
    "Find Nearest" feature.
    """
    return get_dataset()
# Checkpoint used by the demo.
model_name = 'ceyda/butterfly_cropped_uniq1K_512'
# Pin a specific revision here once one is chosen, e.g.
# model_version = '0edac54b81958b82ce9fd5c1f688c33ac8e4f223'
# NOTE(review): None presumably means "latest" -- confirm in demo.load_model.
model_version = None  # TBD

# Both loaders are wrapped in st.experimental_singleton, so reruns reuse them.
model = load_model_intocache(model_name, model_version)
dataset = load_dataset()

# Sidebar navigation labels. The page dispatch compares `screen` against these
# variables (not literals), so the label text can change freely.
# The butterfly and eyes emoji were previously stored as mis-decoded UTF-8
# bytes and rendered as garbage in the sidebar.
generate_menu = "🦋 Make butterflies"
latent_walk_menu = "🎧 Take a latent walk"
make_meme_menu = "🐦 Make a meme"
mosaic_menu = "👀 See the mosaic"
screen = st.sidebar.radio(
    "Pick a destination",
    [generate_menu, latent_walk_menu, make_meme_menu, mosaic_menu],
)
# Render whichever page was selected in the sidebar radio.
if screen == generate_menu:
    batch_size = 4  # butterflies generated per click
    col_num = 4  # images are laid out in a grid with this many columns

    def run():
        """Generate a fresh batch and keep it in session state across reruns."""
        with st.spinner("Generating..."):
            ims = generate(model, batch_size)
        st.session_state['ims'] = ims

    # First visit in this session: produce an initial batch.
    if 'ims' not in st.session_state:
        st.session_state['ims'] = None
        run()
    ims = st.session_state["ims"]
    # on_click callbacks run before the script re-executes, so the new batch
    # is what gets displayed on the rerun triggered by this button.
    st.button("Generate", on_click=run)

    if ims is not None:
        cols = st.columns(col_num)
        # Size by what generate() actually returned, not by batch_size.
        picks = [False] * len(ims)
        for j, im in enumerate(ims):
            col = cols[j % col_num]
            col.image(im)
            picks[j] = col.button("Find Nearest", key="pick_" + str(j))
        # TODO: per-image "What is this?" button that turns the butterfly into
        # a meme (make_meme), as on the meme page.
        if any(picks):
            # st.write("Nearest butterflies:")
            for i, pick in enumerate(picks):
                if not pick:
                    continue
                _scores, retrieved_examples = dataset.get_nearest_examples(
                    'beit_embeddings', embed(ims[i]), k=5
                )
                # Draw the neighbours under the picked image: the column index
                # must wrap exactly like the layout loop above, otherwise
                # batch_size > col_num raises IndexError.
                target_col = cols[i % col_num]
                for r in retrieved_examples["image"]:
                    target_col.image(r)
    st.write(f"Latent dimension: {model.latent_dim}, Image size:{model.image_size}")
elif screen == latent_walk_menu:
    st.write("Take a latent walk :musical_note:")
    cols = st.columns(3)
    cols[0].video("assets/latent_walks/regular_walk.mp4")
    cols[0].caption("Regular walk")
    cols[1].video("assets/latent_walks/walk_happyrock.mp4")
    cols[1].caption("walk with music :butterfly:")
    cols[2].video("assets/latent_walks/walk_cute.mp4")
    cols[2].caption(":musical_note: walk with cute butterflies")
    # Music credit, shown under the middle ("walk with music") column.
    cols[1].caption("Royalty Free Music from Bensound")
elif screen == make_meme_menu:
    # NOTE(review): Streamlit re-executes the script on every widget change,
    # so editing the checkbox/text below also generates a brand-new butterfly.
    # Consider caching the image in st.session_state if that is unwanted.
    im = generate(model, 1)[0]
    no_bg = st.checkbox("Remove background?", True)
    meme_text = st.text_input("Meme text", "Is this a pigeon?")
    meme = make_meme(im, text=meme_text, show_text=True, remove_background=no_bg)
    st.image(meme)
elif screen == mosaic_menu:
    st.markdown("Todo add explanation about data")
    st.image("assets/training_data_lowres.png")
# --- Sidebar footer ---------------------------------------------------------
# The model link is built from model_name so the credit always points at the
# checkpoint actually loaded (renders identically for the current model).
st.sidebar.caption(
    f"[Model](https://huggingface.co/{model_name}) & "
    "[Dataset](https://huggingface.co/huggan/smithsonian_butterflies_subset) used"
)
# TODO: link the project repo (scripts etc.)
# Credits
st.sidebar.caption("Made during the [huggan](https://github.com/huggingface/community-events) hackathon")
st.sidebar.caption("Contributors:")
st.sidebar.caption("[Ceyda Cinarel](https://huggingface.co/ceyda) & [Jonathan Whitaker](https://datasciencecastnet.home.blog/)")
## Feel free to add more & change stuff ^