espejelomar commited on
Commit
48fb736
1 Parent(s): 51d63fe

first draft

Browse files
Files changed (6) hide show
  1. README.md +8 -8
  2. app.py +178 -0
  3. assets/logo.png +0 -0
  4. packages.txt +1 -0
  5. requirements.txt +5 -0
  6. utils.py +107 -0
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: Platzi Curso Streamlit Butterfly Gan
3
- emoji: 📉
4
- colorFrom: green
5
- colorTo: indigo
6
  sdk: streamlit
7
- sdk_version: 1.10.0
8
  app_file: app.py
9
- pinned: false
10
- license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Butterfly GAN
3
+ emoji: 🦋
4
+ colorFrom: blue
5
+ colorTo: yellow
6
  sdk: streamlit
7
+ sdk_version: 1.2.0
8
  app_file: app.py
9
+ pinned: true
10
+ license: apache-2.0
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from distutils.command.build import build
2
+ import streamlit as st # HF spaces at v1.2.0
3
+ # from utils import load_model,generate,get_dataset,embed,make_meme
4
+ from utils import load_model,generate
5
+ # import streamlit.components.v1 as components
6
+ # import io
7
+ # import os
8
+
9
+ # root_dir=os.path.dirname(os.path.abspath(__file__))
10
+ # build_dir = os.path.join(root_dir, "custom_component/frontend/build")
11
+ # _component_func = components.declare_component("release_butterflies", path=build_dir)
12
+ # def release_butterflies(name, key=None):
13
+ # component_value = _component_func(name=name, key=key, default=0)
14
+ # return component_value
15
+
16
+
17
+ ## Configuración de nuestro demo
18
+ st.title("ButterflyGAN")
19
+ st.write("Light-GAN model trained on 1000 butterfly images taken from the Smithsonian Museum collection. \n \
20
+ Based on [paper:](https://openreview.net/forum?id=1Fqg133qRaI) *Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis*")
21
+
22
+ st.sidebar.subheader("This butterfly does not exist! ")
23
+ st.sidebar.image("assets/logo.png", width=200)
24
+ st.sidebar.caption(f"[Model](https://huggingface.co/ceyda/butterfly_cropped_uniq1K_512) & [Dataset](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) used")
25
+ st.sidebar.caption(f"Made during the [huggan](https://github.com/huggingface/community-events) hackathon")
26
+ st.sidebar.caption(f"Contributors:")
27
+ st.sidebar.caption(f"[Ceyda Cinarel](https://github.com/cceyda) & [Jonathan Whitaker](https://datasciencecastnet.home.blog/)")
28
+
29
+
30
+ ## Cargamos modelo
31
+ repo_id = 'ceyda/butterfly_cropped_uniq1K_512'
32
+ version_modelo ='57d36a15546909557d9f967f47713236c8288838'
33
+ modelo_gan = load_model(repo_id, version_modelo)
34
+
35
+ # @st.experimental_singleton
36
+ # def load_model_intocache(model_name,model_version):
37
+ # # model_name='ceyda/butterfly_512_base'
38
+ # gan = load_model(model_name,model_version)
39
+ # return gan
40
+
41
+ # @st.experimental_singleton
42
+ # def load_dataset():
43
+ # dataset=get_dataset()
44
+ # return dataset
45
+
46
+ # @st.experimental_singleton
47
+ # def load_variables():# Don't want to open read files over and over. not sure if it makes a diff
48
+ # latent_walk_code=open("assets/code_snippets/latent_walk.py").read()
49
+ # latent_walk_code_music=open("assets/code_snippets/latent_walk_music.py").read()
50
+ # return latent_walk_code,latent_walk_code_music
51
+
52
+ # def img2download(image):
53
+ # imgByteArr = io.BytesIO()
54
+ # image.save(imgByteArr, format="JPEG")
55
+ # imgByteArr = imgByteArr.getvalue()
56
+ # return imgByteArr
57
+
58
+ # model_name='ceyda/butterfly_cropped_uniq1K_512'
59
+ # model_version='57d36a15546909557d9f967f47713236c8288838'
60
+ # model_version=None
61
+
62
+
63
+ # model=load_model_intocache(model_name,model_version)
64
+ # dataset=loadk_dataset()
65
+ # latent_walk_code, latent_walk_code_music=load_variables()
66
+
67
+ # generate_menu="🦋 Make butterflies"
68
+ # latent_walk_menu="🎧 Take a latent walk"
69
+ # make_meme_menu="🐦 Make a meme"
70
+ # mosaic_menu="👀 See the mosaic"
71
+ # fun_menu="🙌 Release the butterflies"
72
+
73
+ # screen = st.sidebar.radio("Pick a destination",[generate_menu,latent_walk_menu,make_meme_menu,mosaic_menu,fun_menu])
74
+
75
+ ## Generamos 4 mariposas
76
+ n_mariposas =4
77
+
78
+ ## Función que genera mariposas y lo guarda como un estado de la sesión
79
+ def corre():
80
+ with st.spinner("Generando, espera un poco..."):
81
+ ims=generate(modelo_gan,n_mariposas)
82
+ st.session_state['ims'] = ims
83
+
84
+ ## Si no hay una imagen generada entonces generala
85
+ if 'ims' not in st.session_state:
86
+ st.session_state['ims'] = None
87
+ corre()
88
+
89
+ ## ims contiene las imágenes generadas
90
+ ims=st.session_state["ims"]
91
+
92
+ ## Si la usuaria da click en el botón entonces corremos la función genera()
93
+ runb=st.button("Genera mariposas por favor", on_click=corre ,help="generated on the fly maybe slow")
94
+
95
+ if ims is not None:
96
+ cols=st.columns(n_mariposas)
97
+ # picks=[False]*n_mariposas
98
+ for j,im in enumerate(ims):
99
+ i=j%n_mariposas
100
+ cols[i].image(im, use_column_width=True)
101
+ # picks[j]=cols[i].button("Find Nearest",key="pick_"+str(j))
102
+
103
+ # if any(picks):
104
+ # # st.write("Nearest butterflies:")
105
+ # for i,pick in enumerate(picks):
106
+ # if pick:
107
+ # scores, retrieved_examples=dataset.get_nearest_examples('beit_embeddings', embed(ims[i]), k=5)
108
+ # for r in retrieved_examples["image"]:
109
+ # cols[i].image(r, use_column_width=True)
110
+ # st.write("Nearest neighbors found in the training set according to L2 distance on 'microsoft/beit-base-patch16-224' embeddings")
111
+ # st.write(f"Latent dimension: {model.latent_dim}, image size:{model.image_size}")
112
+
113
+ # elif screen == latent_walk_menu:
114
+
115
+ # st.write("Take a latent walk :musical_note: with cute butterflies")
116
+
117
+ # cols=st.columns(3)
118
+
119
+ # cols[0].caption("A regular walk (no music)")
120
+ # cols[0].video("assets/latent_walks/regular_walk.mp4")
121
+
122
+ # cols[1].caption("Walk with music :butterfly:")
123
+ # cols[1].video("assets/latent_walks/walk_happyrock.mp4")
124
+ # cols[2].caption("Walk with music :butterfly:")
125
+ # cols[2].video("assets/latent_walks/walk_cute.mp4")
126
+
127
+ # st.caption("Royalty Free Music from Bensound")
128
+ # st.write("🎧Did those butterflies seem to be dancing to the music?!Here is the secret:")
129
+ # with st.expander("See the Code Snippets"):
130
+ # st.write("A regular latent walk:")
131
+ # st.code(latent_walk_code, language='python')
132
+ # st.write(":musical_note: latent walk with music:")
133
+ # st.code(latent_walk_code_music, language='python')
134
+
135
+
136
+ # elif screen == make_meme_menu:
137
+ # if "pigeon" not in st.session_state:
138
+ # st.session_state['pigeon'] = generate(model,1)[0]
139
+
140
+ # def get_pigeon():
141
+ # st.session_state['pigeon'] = generate(model,1)[0]
142
+
143
+ # cols= st.columns(2)
144
+ # cols[0].button("change pigeon",on_click=get_pigeon)
145
+ # no_bg=cols[1].checkbox("Remove background?",True,help="Remove the background from pigeon")
146
+ # show_text=cols[1].checkbox("Show text?",True)
147
+
148
+ # meme_text=st.text_input("Enter text","Is this a pigeon?")
149
+
150
+
151
+ # meme=make_meme(st.session_state['pigeon'],text=meme_text,show_text=show_text,remove_background=no_bg)
152
+ # st.image(meme)
153
+ # coly=st.columns(2)
154
+ # coly[0].download_button("Download", img2download(meme),mime="image/jpeg")
155
+ # coly[1].write("Made a cool one? [Share](https://twitter.com/intent/tweet?text=Check%20out%20the%20demo%20for%20Butterfly%20GAN%20%F0%9F%A6%8Bhttps%3A//huggingface.co/spaces/huggan/butterfly-gan%0Amade%20by%20%40ceyda_cinarel%20%26%20%40johnowhitaker%20) on Twitter")
156
+
157
+
158
+ # elif screen == mosaic_menu:
159
+ # cols=st.columns(2)
160
+ # cols[0].markdown("These are all the butterflies in our [training set](https://huggingface.co/huggan/smithsonian_butterflies_subset)")
161
+ # cols[0].image("assets/train_data_mosaic_lowres.jpg")
162
+ # cols[0].write("🔎 view the high-res version [here](https://www.easyzoom.com/imageaccess/0c77e0e716f14ea7bc235447e5a4c397)")
163
+
164
+ # cols[1].markdown("These are the butterflies our model generated.")
165
+ # cols[1].image("assets/gen_mosaic_lowres.jpg")
166
+ # cols[1].write("🔎 view the high-res version [here](https://www.easyzoom.com/imageaccess/cbb04e81106c4c54a9d9f9dbfb236eab)")
167
+
168
+ # elif screen == fun_menu:
169
+
170
+ # cols=st.columns([1,2])
171
+ # cols[0].write("While working on this project")
172
+ # cols[0].image("assets/butterflies_everywhere.jpg")
173
+
174
+ # with cols[1]:
175
+ # release_butterflies("Hello World")
176
+
177
+
178
+ ## Feel free to add more & change stuff ^
assets/logo.png ADDED
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libgl1
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/community-events.git@3fea10c5d5a50c69f509e34cd580fe9139905d04#egg=huggan
2
+ transformers
3
+ faiss-cpu
4
+ paddlehub
5
+ paddlepaddle
utils.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
3
+ from datasets import load_dataset
4
+ from PIL import Image
5
+ import numpy as np
6
+ import paddlehub as hub
7
+ import random
8
+ from PIL import ImageDraw,ImageFont
9
+
10
+ import streamlit as st
11
+
12
+ @st.experimental_singleton
13
+ def load_bg_model():
14
+ bg_model = hub.Module(name='U2NetP', directory='assets/models/')
15
+ return bg_model
16
+
17
+
18
+ bg_model = load_bg_model()
19
+ def remove_bg(img):
20
+ result = bg_model.Segmentation(
21
+ images=[np.array(img)[:,:,::-1]],
22
+ paths=None,
23
+ batch_size=1,
24
+ input_size=320,
25
+ output_dir=None,
26
+ visualization=False)
27
+ output = result[0]
28
+ mask=Image.fromarray(output['mask'])
29
+ front=Image.fromarray(output['front'][:,:,::-1]).convert("RGBA")
30
+ front.putalpha(mask)
31
+ return front
32
+
33
+ meme_template=Image.open("./assets/pigeon_meme.jpg").convert("RGBA")
34
+ def make_meme(pigeon,text="Is this a pigeon?",show_text=True,remove_background=True):
35
+
36
+ meme=meme_template.copy()
37
+ approx_butterfly_center=(850,30)
38
+
39
+ if remove_background:
40
+ pigeon=remove_bg(pigeon)
41
+
42
+ else:
43
+ pigeon=Image.fromarray(pigeon).convert("RGBA")
44
+
45
+ random_rotate=random.randint(-30,30)
46
+ random_size=random.randint(150,200)
47
+ pigeon=pigeon.resize((random_size,random_size)).rotate(random_rotate,expand=True)
48
+
49
+ meme.alpha_composite(pigeon, approx_butterfly_center)
50
+
51
+ #ref: https://blog.lipsumarium.com/caption-memes-in-python/
52
+ def drawTextWithOutline(text, x, y):
53
+ draw.text((x-2, y-2), text,(0,0,0),font=font)
54
+ draw.text((x+2, y-2), text,(0,0,0),font=font)
55
+ draw.text((x+2, y+2), text,(0,0,0),font=font)
56
+ draw.text((x-2, y+2), text,(0,0,0),font=font)
57
+ draw.text((x, y), text, (255,255,255), font=font)
58
+
59
+ if show_text:
60
+ draw = ImageDraw.Draw(meme)
61
+ font_size=52
62
+ font = ImageFont.truetype("assets/impact.ttf", font_size)
63
+ w, h = draw.textsize(text, font) # measure the size the text will take
64
+ drawTextWithOutline(text, meme.width/2 - w/2, meme.height - font_size*2)
65
+ meme = meme.convert("RGB")
66
+ return meme
67
+
68
+ def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
69
+ dataset=load_dataset(dataset_name)
70
+ dataset=dataset.sort("sim_score")
71
+ return dataset["train"]
72
+
73
+ from transformers import BeitFeatureExtractor, BeitForImageClassification
74
+ emb_feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
75
+ emb_model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
76
+ def embed(images):
77
+ inputs = emb_feature_extractor(images=images, return_tensors="pt")
78
+ outputs = emb_model(**inputs,output_hidden_states= True)
79
+ last_hidden=outputs.hidden_states[-1]
80
+ pooler=emb_model.base_model.pooler
81
+ final_emb=pooler(last_hidden).detach().numpy()
82
+ return final_emb
83
+
84
+ def build_index():
85
+ dataset=get_train_data()
86
+ ds_with_embeddings = dataset.map(lambda x: {"beit_embeddings":embed(x["image"])},batched=True,batch_size=20)
87
+ ds_with_embeddings.add_faiss_index(column='beit_embeddings')
88
+ ds_with_embeddings.save_faiss_index('beit_embeddings', 'beit_index.faiss')
89
+
90
+ def get_dataset():
91
+ dataset=get_train_data()
92
+ dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
93
+ return dataset
94
+
95
+ def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512',model_version=None):
96
+ gan = LightweightGAN.from_pretrained(model_name,version=model_version)
97
+ gan.eval()
98
+ return gan
99
+
100
+ def generate(gan,batch_size=1):
101
+ with torch.no_grad():
102
+ ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)*255
103
+ ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8)
104
+ return ims
105
+
106
+ def interpolate():
107
+ pass