Spaces:
Runtime error
Runtime error
espejelomar
commited on
Commit
•
48fb736
1
Parent(s):
51d63fe
first draft
Browse files- README.md +8 -8
- app.py +178 -0
- assets/logo.png +0 -0
- packages.txt +1 -0
- requirements.txt +5 -0
- utils.py +107 -0
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
-
sdk_version: 1.
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
-
license:
|
11 |
---
|
12 |
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces
|
|
|
1 |
---
|
2 |
+
title: Butterfly GAN
|
3 |
+
emoji: 🦋
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: yellow
|
6 |
sdk: streamlit
|
7 |
+
sdk_version: 1.2.0
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
|
app.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from distutils.command.build import build
|
2 |
+
import streamlit as st # HF spaces at v1.2.0
|
3 |
+
# from utils import load_model,generate,get_dataset,embed,make_meme
|
4 |
+
from utils import load_model,generate
|
5 |
+
# import streamlit.components.v1 as components
|
6 |
+
# import io
|
7 |
+
# import os
|
8 |
+
|
9 |
+
# root_dir=os.path.dirname(os.path.abspath(__file__))
|
10 |
+
# build_dir = os.path.join(root_dir, "custom_component/frontend/build")
|
11 |
+
# _component_func = components.declare_component("release_butterflies", path=build_dir)
|
12 |
+
# def release_butterflies(name, key=None):
|
13 |
+
# component_value = _component_func(name=name, key=key, default=0)
|
14 |
+
# return component_value
|
15 |
+
|
16 |
+
|
17 |
+
## Configuración de nuestro demo
|
18 |
+
st.title("ButterflyGAN")
|
19 |
+
st.write("Light-GAN model trained on 1000 butterfly images taken from the Smithsonian Museum collection. \n \
|
20 |
+
Based on [paper:](https://openreview.net/forum?id=1Fqg133qRaI) *Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis*")
|
21 |
+
|
22 |
+
st.sidebar.subheader("This butterfly does not exist! ")
|
23 |
+
st.sidebar.image("assets/logo.png", width=200)
|
24 |
+
st.sidebar.caption(f"[Model](https://huggingface.co/ceyda/butterfly_cropped_uniq1K_512) & [Dataset](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) used")
|
25 |
+
st.sidebar.caption(f"Made during the [huggan](https://github.com/huggingface/community-events) hackathon")
|
26 |
+
st.sidebar.caption(f"Contributors:")
|
27 |
+
st.sidebar.caption(f"[Ceyda Cinarel](https://github.com/cceyda) & [Jonathan Whitaker](https://datasciencecastnet.home.blog/)")
|
28 |
+
|
29 |
+
|
30 |
+
## Cargamos modelo
|
31 |
+
repo_id = 'ceyda/butterfly_cropped_uniq1K_512'
|
32 |
+
version_modelo ='57d36a15546909557d9f967f47713236c8288838'
|
33 |
+
modelo_gan = load_model(repo_id, version_modelo)
|
34 |
+
|
35 |
+
# @st.experimental_singleton
|
36 |
+
# def load_model_intocache(model_name,model_version):
|
37 |
+
# # model_name='ceyda/butterfly_512_base'
|
38 |
+
# gan = load_model(model_name,model_version)
|
39 |
+
# return gan
|
40 |
+
|
41 |
+
# @st.experimental_singleton
|
42 |
+
# def load_dataset():
|
43 |
+
# dataset=get_dataset()
|
44 |
+
# return dataset
|
45 |
+
|
46 |
+
# @st.experimental_singleton
|
47 |
+
# def load_variables():# Don't want to open read files over and over. not sure if it makes a diff
|
48 |
+
# latent_walk_code=open("assets/code_snippets/latent_walk.py").read()
|
49 |
+
# latent_walk_code_music=open("assets/code_snippets/latent_walk_music.py").read()
|
50 |
+
# return latent_walk_code,latent_walk_code_music
|
51 |
+
|
52 |
+
# def img2download(image):
|
53 |
+
# imgByteArr = io.BytesIO()
|
54 |
+
# image.save(imgByteArr, format="JPEG")
|
55 |
+
# imgByteArr = imgByteArr.getvalue()
|
56 |
+
# return imgByteArr
|
57 |
+
|
58 |
+
# model_name='ceyda/butterfly_cropped_uniq1K_512'
|
59 |
+
# model_version='57d36a15546909557d9f967f47713236c8288838'
|
60 |
+
# model_version=None
|
61 |
+
|
62 |
+
|
63 |
+
# model=load_model_intocache(model_name,model_version)
|
64 |
+
# dataset=loadk_dataset()
|
65 |
+
# latent_walk_code, latent_walk_code_music=load_variables()
|
66 |
+
|
67 |
+
# generate_menu="🦋 Make butterflies"
|
68 |
+
# latent_walk_menu="🎧 Take a latent walk"
|
69 |
+
# make_meme_menu="🐦 Make a meme"
|
70 |
+
# mosaic_menu="👀 See the mosaic"
|
71 |
+
# fun_menu="🙌 Release the butterflies"
|
72 |
+
|
73 |
+
# screen = st.sidebar.radio("Pick a destination",[generate_menu,latent_walk_menu,make_meme_menu,mosaic_menu,fun_menu])
|
74 |
+
|
75 |
+
## Generamos 4 mariposas
|
76 |
+
n_mariposas =4
|
77 |
+
|
78 |
+
## Función que genera mariposas y lo guarda como un estado de la sesión
|
79 |
+
def corre():
|
80 |
+
with st.spinner("Generando, espera un poco..."):
|
81 |
+
ims=generate(modelo_gan,n_mariposas)
|
82 |
+
st.session_state['ims'] = ims
|
83 |
+
|
84 |
+
## Si no hay una imagen generada entonces generala
|
85 |
+
if 'ims' not in st.session_state:
|
86 |
+
st.session_state['ims'] = None
|
87 |
+
corre()
|
88 |
+
|
89 |
+
## ims contiene las imágenes generadas
|
90 |
+
ims=st.session_state["ims"]
|
91 |
+
|
92 |
+
## Si la usuaria da click en el botón entonces corremos la función genera()
|
93 |
+
runb=st.button("Genera mariposas por favor", on_click=corre ,help="generated on the fly maybe slow")
|
94 |
+
|
95 |
+
if ims is not None:
|
96 |
+
cols=st.columns(n_mariposas)
|
97 |
+
# picks=[False]*n_mariposas
|
98 |
+
for j,im in enumerate(ims):
|
99 |
+
i=j%n_mariposas
|
100 |
+
cols[i].image(im, use_column_width=True)
|
101 |
+
# picks[j]=cols[i].button("Find Nearest",key="pick_"+str(j))
|
102 |
+
|
103 |
+
# if any(picks):
|
104 |
+
# # st.write("Nearest butterflies:")
|
105 |
+
# for i,pick in enumerate(picks):
|
106 |
+
# if pick:
|
107 |
+
# scores, retrieved_examples=dataset.get_nearest_examples('beit_embeddings', embed(ims[i]), k=5)
|
108 |
+
# for r in retrieved_examples["image"]:
|
109 |
+
# cols[i].image(r, use_column_width=True)
|
110 |
+
# st.write("Nearest neighbors found in the training set according to L2 distance on 'microsoft/beit-base-patch16-224' embeddings")
|
111 |
+
# st.write(f"Latent dimension: {model.latent_dim}, image size:{model.image_size}")
|
112 |
+
|
113 |
+
# elif screen == latent_walk_menu:
|
114 |
+
|
115 |
+
# st.write("Take a latent walk :musical_note: with cute butterflies")
|
116 |
+
|
117 |
+
# cols=st.columns(3)
|
118 |
+
|
119 |
+
# cols[0].caption("A regular walk (no music)")
|
120 |
+
# cols[0].video("assets/latent_walks/regular_walk.mp4")
|
121 |
+
|
122 |
+
# cols[1].caption("Walk with music :butterfly:")
|
123 |
+
# cols[1].video("assets/latent_walks/walk_happyrock.mp4")
|
124 |
+
# cols[2].caption("Walk with music :butterfly:")
|
125 |
+
# cols[2].video("assets/latent_walks/walk_cute.mp4")
|
126 |
+
|
127 |
+
# st.caption("Royalty Free Music from Bensound")
|
128 |
+
# st.write("🎧Did those butterflies seem to be dancing to the music?!Here is the secret:")
|
129 |
+
# with st.expander("See the Code Snippets"):
|
130 |
+
# st.write("A regular latent walk:")
|
131 |
+
# st.code(latent_walk_code, language='python')
|
132 |
+
# st.write(":musical_note: latent walk with music:")
|
133 |
+
# st.code(latent_walk_code_music, language='python')
|
134 |
+
|
135 |
+
|
136 |
+
# elif screen == make_meme_menu:
|
137 |
+
# if "pigeon" not in st.session_state:
|
138 |
+
# st.session_state['pigeon'] = generate(model,1)[0]
|
139 |
+
|
140 |
+
# def get_pigeon():
|
141 |
+
# st.session_state['pigeon'] = generate(model,1)[0]
|
142 |
+
|
143 |
+
# cols= st.columns(2)
|
144 |
+
# cols[0].button("change pigeon",on_click=get_pigeon)
|
145 |
+
# no_bg=cols[1].checkbox("Remove background?",True,help="Remove the background from pigeon")
|
146 |
+
# show_text=cols[1].checkbox("Show text?",True)
|
147 |
+
|
148 |
+
# meme_text=st.text_input("Enter text","Is this a pigeon?")
|
149 |
+
|
150 |
+
|
151 |
+
# meme=make_meme(st.session_state['pigeon'],text=meme_text,show_text=show_text,remove_background=no_bg)
|
152 |
+
# st.image(meme)
|
153 |
+
# coly=st.columns(2)
|
154 |
+
# coly[0].download_button("Download", img2download(meme),mime="image/jpeg")
|
155 |
+
# coly[1].write("Made a cool one? [Share](https://twitter.com/intent/tweet?text=Check%20out%20the%20demo%20for%20Butterfly%20GAN%20%F0%9F%A6%8Bhttps%3A//huggingface.co/spaces/huggan/butterfly-gan%0Amade%20by%20%40ceyda_cinarel%20%26%20%40johnowhitaker%20) on Twitter")
|
156 |
+
|
157 |
+
|
158 |
+
# elif screen == mosaic_menu:
|
159 |
+
# cols=st.columns(2)
|
160 |
+
# cols[0].markdown("These are all the butterflies in our [training set](https://huggingface.co/huggan/smithsonian_butterflies_subset)")
|
161 |
+
# cols[0].image("assets/train_data_mosaic_lowres.jpg")
|
162 |
+
# cols[0].write("🔎 view the high-res version [here](https://www.easyzoom.com/imageaccess/0c77e0e716f14ea7bc235447e5a4c397)")
|
163 |
+
|
164 |
+
# cols[1].markdown("These are the butterflies our model generated.")
|
165 |
+
# cols[1].image("assets/gen_mosaic_lowres.jpg")
|
166 |
+
# cols[1].write("🔎 view the high-res version [here](https://www.easyzoom.com/imageaccess/cbb04e81106c4c54a9d9f9dbfb236eab)")
|
167 |
+
|
168 |
+
# elif screen == fun_menu:
|
169 |
+
|
170 |
+
# cols=st.columns([1,2])
|
171 |
+
# cols[0].write("While working on this project")
|
172 |
+
# cols[0].image("assets/butterflies_everywhere.jpg")
|
173 |
+
|
174 |
+
# with cols[1]:
|
175 |
+
# release_butterflies("Hello World")
|
176 |
+
|
177 |
+
|
178 |
+
## Feel free to add more & change stuff ^
|
assets/logo.png
ADDED
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
libgl1
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/huggingface/community-events.git@3fea10c5d5a50c69f509e34cd580fe9139905d04#egg=huggan
|
2 |
+
transformers
|
3 |
+
faiss-cpu
|
4 |
+
paddlehub
|
5 |
+
paddlepaddle
|
utils.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
|
3 |
+
from datasets import load_dataset
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
import paddlehub as hub
|
7 |
+
import random
|
8 |
+
from PIL import ImageDraw,ImageFont
|
9 |
+
|
10 |
+
import streamlit as st
|
11 |
+
|
12 |
+
@st.experimental_singleton
|
13 |
+
def load_bg_model():
|
14 |
+
bg_model = hub.Module(name='U2NetP', directory='assets/models/')
|
15 |
+
return bg_model
|
16 |
+
|
17 |
+
|
18 |
+
bg_model = load_bg_model()
|
19 |
+
def remove_bg(img):
|
20 |
+
result = bg_model.Segmentation(
|
21 |
+
images=[np.array(img)[:,:,::-1]],
|
22 |
+
paths=None,
|
23 |
+
batch_size=1,
|
24 |
+
input_size=320,
|
25 |
+
output_dir=None,
|
26 |
+
visualization=False)
|
27 |
+
output = result[0]
|
28 |
+
mask=Image.fromarray(output['mask'])
|
29 |
+
front=Image.fromarray(output['front'][:,:,::-1]).convert("RGBA")
|
30 |
+
front.putalpha(mask)
|
31 |
+
return front
|
32 |
+
|
33 |
+
meme_template=Image.open("./assets/pigeon_meme.jpg").convert("RGBA")
|
34 |
+
def make_meme(pigeon,text="Is this a pigeon?",show_text=True,remove_background=True):
|
35 |
+
|
36 |
+
meme=meme_template.copy()
|
37 |
+
approx_butterfly_center=(850,30)
|
38 |
+
|
39 |
+
if remove_background:
|
40 |
+
pigeon=remove_bg(pigeon)
|
41 |
+
|
42 |
+
else:
|
43 |
+
pigeon=Image.fromarray(pigeon).convert("RGBA")
|
44 |
+
|
45 |
+
random_rotate=random.randint(-30,30)
|
46 |
+
random_size=random.randint(150,200)
|
47 |
+
pigeon=pigeon.resize((random_size,random_size)).rotate(random_rotate,expand=True)
|
48 |
+
|
49 |
+
meme.alpha_composite(pigeon, approx_butterfly_center)
|
50 |
+
|
51 |
+
#ref: https://blog.lipsumarium.com/caption-memes-in-python/
|
52 |
+
def drawTextWithOutline(text, x, y):
|
53 |
+
draw.text((x-2, y-2), text,(0,0,0),font=font)
|
54 |
+
draw.text((x+2, y-2), text,(0,0,0),font=font)
|
55 |
+
draw.text((x+2, y+2), text,(0,0,0),font=font)
|
56 |
+
draw.text((x-2, y+2), text,(0,0,0),font=font)
|
57 |
+
draw.text((x, y), text, (255,255,255), font=font)
|
58 |
+
|
59 |
+
if show_text:
|
60 |
+
draw = ImageDraw.Draw(meme)
|
61 |
+
font_size=52
|
62 |
+
font = ImageFont.truetype("assets/impact.ttf", font_size)
|
63 |
+
w, h = draw.textsize(text, font) # measure the size the text will take
|
64 |
+
drawTextWithOutline(text, meme.width/2 - w/2, meme.height - font_size*2)
|
65 |
+
meme = meme.convert("RGB")
|
66 |
+
return meme
|
67 |
+
|
68 |
+
def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
|
69 |
+
dataset=load_dataset(dataset_name)
|
70 |
+
dataset=dataset.sort("sim_score")
|
71 |
+
return dataset["train"]
|
72 |
+
|
73 |
+
from transformers import BeitFeatureExtractor, BeitForImageClassification
|
74 |
+
emb_feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
|
75 |
+
emb_model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
|
76 |
+
def embed(images):
|
77 |
+
inputs = emb_feature_extractor(images=images, return_tensors="pt")
|
78 |
+
outputs = emb_model(**inputs,output_hidden_states= True)
|
79 |
+
last_hidden=outputs.hidden_states[-1]
|
80 |
+
pooler=emb_model.base_model.pooler
|
81 |
+
final_emb=pooler(last_hidden).detach().numpy()
|
82 |
+
return final_emb
|
83 |
+
|
84 |
+
def build_index():
|
85 |
+
dataset=get_train_data()
|
86 |
+
ds_with_embeddings = dataset.map(lambda x: {"beit_embeddings":embed(x["image"])},batched=True,batch_size=20)
|
87 |
+
ds_with_embeddings.add_faiss_index(column='beit_embeddings')
|
88 |
+
ds_with_embeddings.save_faiss_index('beit_embeddings', 'beit_index.faiss')
|
89 |
+
|
90 |
+
def get_dataset():
|
91 |
+
dataset=get_train_data()
|
92 |
+
dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
|
93 |
+
return dataset
|
94 |
+
|
95 |
+
def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512',model_version=None):
|
96 |
+
gan = LightweightGAN.from_pretrained(model_name,version=model_version)
|
97 |
+
gan.eval()
|
98 |
+
return gan
|
99 |
+
|
100 |
+
def generate(gan,batch_size=1):
|
101 |
+
with torch.no_grad():
|
102 |
+
ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)*255
|
103 |
+
ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8)
|
104 |
+
return ims
|
105 |
+
|
106 |
+
def interpolate():
|
107 |
+
pass
|