Ceyda Cinarel commited on
Commit
47cfe13
1 Parent(s): cb5f8d1

make demo prettier half way there

Browse files
.gitattributes CHANGED
@@ -27,3 +27,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
29
  *.faiss filter=lfs diff=lfs merge=lfs -text
 
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
29
  *.faiss filter=lfs diff=lfs merge=lfs -text
30
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,22 +1,25 @@
 
1
  import re
2
  import streamlit as st # HF spaces at v1.2.0
3
- from demo import load_model,generate,get_dataset,embed
4
-
 
5
  # TODOs
6
  # Add markdown short readme project intro
7
 
8
 
9
- st.sidebar.image("assets/logo.png", use_column_width=True)
 
 
10
  st.header("ButterflyGAN")
11
- st.caption("This butterfly does not exist! ")
12
- st.write("Demo prep still in progress!!")
13
 
14
 
15
  @st.experimental_singleton
16
- def load_model_intocache(model_name):
17
-
18
  # model_name='ceyda/butterfly_512_base'
19
- gan = load_model(model_name)
20
  return gan
21
 
22
  @st.experimental_singleton
@@ -25,33 +28,46 @@ def load_dataset():
25
  return dataset
26
 
27
  model_name='ceyda/butterfly_cropped_uniq1K_512'
28
- model=load_model_intocache(model_name)
 
 
29
  dataset=load_dataset()
30
 
31
- screen = st.sidebar.radio("Pick a destination",["Make butterflies","Take a latent walk", "See the data mosaic"])
 
 
 
32
 
33
- if screen == "Make butterflies":
34
 
 
35
 
36
-
37
- if 'ims' not in st.session_state:
38
- st.session_state['ims'] = None
39
-
40
- ims=st.session_state["ims"]
41
  batch_size=4 #generate 4 butterflies
 
42
  def run():
43
  with st.spinner("Generating..."):
44
  ims=generate(model,batch_size)
45
  st.session_state['ims'] = ims
46
-
 
 
 
 
47
  runb=st.button("Generate", on_click=run)
48
  if ims is not None:
49
- cols=st.columns(batch_size)
50
  picks=[False]*batch_size
51
- for i,im in enumerate(ims):
 
52
  cols[i].image(im)
53
- picks[i]=cols[i].button("Find Nearest",key="pick_"+str(i))
54
- # if picks[i]:
 
 
 
 
 
 
55
  # scores, retrieved_examples=dataset.get_nearest_examples('beit_embeddings', embed(im), k=5)
56
  # for r in retrieved_examples["image"]:
57
  # st.image(r)
@@ -66,13 +82,40 @@ if screen == "Make butterflies":
66
 
67
  st.write(f"Latent dimension: {model.latent_dim}, Image size:{model.image_size}")
68
 
69
- elif screen == "Take a latent walk":
70
- st.write("Take a latent walk")
 
 
 
 
 
 
 
 
 
 
71
 
72
- elif screen == "Input data mosaic":
 
 
 
 
 
 
 
 
 
73
  st.markdown("Todo add explanation about data")
74
  st.image("assets/training_data_lowres.png")
75
 
76
 
77
  # footer stuff
78
- st.sidebar.info(f"Model {model_name} is loaded")
 
 
 
 
 
 
 
 
1
+ from pydoc import ModuleScanner
2
  import re
3
  import streamlit as st # HF spaces at v1.2.0
4
+ from demo import load_model,generate,get_dataset,embed,make_meme
5
+ from PIL import Image
6
+ import numpy as np
7
  # TODOs
8
  # Add markdown short readme project intro
9
 
10
 
11
+ st.sidebar.subheader("This butterfly does not exist! ")
12
+ st.sidebar.image("assets/logo.png", width=200)
13
+
14
  st.header("ButterflyGAN")
15
+
16
+ st.write("Demo prep still in progress!! Come back later")
17
 
18
 
19
  @st.experimental_singleton
20
+ def load_model_intocache(model_name,model_version):
 
21
  # model_name='ceyda/butterfly_512_base'
22
+ gan = load_model(model_name,model_version)
23
  return gan
24
 
25
  @st.experimental_singleton
28
  return dataset
29
 
30
  model_name='ceyda/butterfly_cropped_uniq1K_512'
31
+ # model_version='0edac54b81958b82ce9fd5c1f688c33ac8e4f223'
32
+ model_version=None ##TBD
33
+ model=load_model_intocache(model_name,model_version)
34
  dataset=load_dataset()
35
 
36
+ generate_menu="🦋 Make butterflies"
37
+ latent_walk_menu="🎧 Take a latent walk"
38
+ make_meme_menu="🐦 Make a meme"
39
+ mosaic_menu="👀 See the mosaic"
40
 
41
+ screen = st.sidebar.radio("Pick a destination",[generate_menu,latent_walk_menu,make_meme_menu,mosaic_menu])
42
 
43
+ if screen == generate_menu:
44
 
 
 
 
 
 
45
  batch_size=4 #generate 4 butterflies
46
+ col_num=4
47
  def run():
48
  with st.spinner("Generating..."):
49
  ims=generate(model,batch_size)
50
  st.session_state['ims'] = ims
51
+ if 'ims' not in st.session_state:
52
+ st.session_state['ims'] = None
53
+ run()
54
+ ims=st.session_state["ims"]
55
+
56
  runb=st.button("Generate", on_click=run)
57
  if ims is not None:
58
+ cols=st.columns(col_num)
59
  picks=[False]*batch_size
60
+ for j,im in enumerate(ims):
61
+ i=j%col_num
62
  cols[i].image(im)
63
+ picks[j]=cols[i].button("Find Nearest",key="pick_"+str(j))
64
+ # meme_it=cols[i].button("What is this?",key="meme_"+str(j))
65
+ # if meme_it:
66
+ # no_bg=st.checkbox("Remove background?",True)
67
+ # meme_text=st.text_input("Meme text","Is this a pigeon?")
68
+ # meme=make_meme(im,text=meme_text,show_text=True,remove_background=no_bg)
69
+ # st.image(meme)
70
+ # if picks[j]:
71
  # scores, retrieved_examples=dataset.get_nearest_examples('beit_embeddings', embed(im), k=5)
72
  # for r in retrieved_examples["image"]:
73
  # st.image(r)
82
 
83
  st.write(f"Latent dimension: {model.latent_dim}, Image size:{model.image_size}")
84
 
85
+ elif screen == latent_walk_menu:
86
+ st.write("Take a latent walk :musical_note:")
87
+
88
+ cols=st.columns(3)
89
+
90
+ cols[0].video("assets/latent_walks/regular_walk.mp4")
91
+ cols[0].caption("Regular walk")
92
+ cols[1].video("assets/latent_walks/walk_happyrock.mp4")
93
+ cols[1].caption("walk with music :butterfly:")
94
+ cols[2].video("assets/latent_walks/walk_cute.mp4")
95
+ cols[2].caption(":musical_note: walk with cute butterflies")
96
+ cols[1].caption("Royalty Free Music from Bensound")
97
 
98
+
99
+ elif screen == make_meme_menu:
100
+ im = generate(model,1)[0]
101
+ no_bg=st.checkbox("Remove background?",True)
102
+ meme_text=st.text_input("Meme text","Is this a pigeon?")
103
+ meme=make_meme(im,text=meme_text,show_text=True,remove_background=no_bg)
104
+ st.image(meme)
105
+
106
+
107
+ elif screen == mosaic_menu:
108
  st.markdown("Todo add explanation about data")
109
  st.image("assets/training_data_lowres.png")
110
 
111
 
112
  # footer stuff
113
+ st.sidebar.caption(f"[Model](https://huggingface.co/ceyda/butterfly_cropped_uniq1K_512) & [Dataset](https://huggingface.co/huggan/smithsonian_butterflies_subset) used")
114
+ # Link project repo( scripts etc )
115
+
116
+ # Credits
117
+ st.sidebar.caption(f"Made during the [huggan](https://github.com/huggingface/community-events) hackathon")
118
+ st.sidebar.caption(f"Contributors:")
119
+ st.sidebar.caption(f"[Ceyda Cinarel](https://huggingface.co/ceyda) & [Jonathan Whitaker](https://datasciencecastnet.home.blog/)")
120
+
121
+ ## Feel free to add more & change stuff ^
assets/impact.ttf ADDED
Binary file (47.6 kB). View file
assets/latent_walks/regular_walk.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbf4a07057e77a05e3aa2acc5c219425f46758f09535fee44a0e6e48363d5078
3
+ size 1736391
assets/latent_walks/walk_cute.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:babff5726bfd81353959587c84ea8dab4d485c1853850b0119abc7a23ed12e11
3
+ size 7637184
assets/latent_walks/walk_happyrock.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5219340b5fe3e509f02e83a0f0c972bd8f0ecd76df4019cedb0abe373b0fb5e8
3
+ size 6594393
assets/mosaic_bg.png ADDED
assets/outputs/example_output.jpg ADDED
assets/outputs/output2.jpg ADDED
assets/pigeon_meme.jpg ADDED
assets/pigeon_meme_orig.jpg ADDED
demo.py CHANGED
@@ -1,6 +1,67 @@
1
  import torch
2
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
3
  from datasets import load_dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
6
  dataset=load_dataset(dataset_name)
@@ -8,13 +69,13 @@ def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
8
  return dataset["train"]
9
 
10
  from transformers import BeitFeatureExtractor, BeitForImageClassification
11
- feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
12
- model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
13
  def embed(images):
14
- inputs = feature_extractor(images=images, return_tensors="pt")
15
- outputs = model(**inputs,output_hidden_states= True)
16
  last_hidden=outputs.hidden_states[-1]
17
- pooler=model.base_model.pooler
18
  final_emb=pooler(last_hidden).detach().numpy()
19
  return final_emb
20
 
@@ -29,15 +90,15 @@ def get_dataset():
29
  dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
30
  return dataset
31
 
32
- def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512'):
33
- gan = LightweightGAN.from_pretrained(model_name)
34
  gan.eval()
35
  return gan
36
 
37
  def generate(gan,batch_size=1):
38
  with torch.no_grad():
39
- ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)
40
- ims = ims.permute(0,2,3,1).detach().cpu().numpy()
41
  return ims
42
 
43
  def interpolate():
1
  import torch
2
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
3
  from datasets import load_dataset
4
+ from PIL import Image
5
+ import numpy as np
6
+ import paddlehub as hub
7
+ import random
8
+ from PIL import ImageDraw,ImageFont
9
+
10
+ import streamlit as st
11
+
12
+ @st.experimental_singleton
13
+ def load_bg_model():
14
+ bg_model = hub.Module(name='U2NetP', directory='assets/models/')
15
+ return bg_model
16
+
17
+
18
+ bg_model = load_bg_model()
19
+ def remove_bg(img):
20
+ result = bg_model.Segmentation(
21
+ images=[np.array(img)[:,:,::-1]],
22
+ paths=None,
23
+ batch_size=1,
24
+ input_size=320,
25
+ output_dir=None,
26
+ visualization=False)
27
+ output = result[0]
28
+ mask=Image.fromarray(output['mask'])
29
+ front=Image.fromarray(output['front'][:,:,::-1]).convert("RGBA")
30
+ front.putalpha(mask)
31
+ return front
32
+
33
+ meme_template=Image.open("./assets/pigeon_meme.jpg").convert("RGBA")
34
+ def make_meme(pigeon,text="Is this a pigeon?",show_text=True,remove_background=True):
35
+
36
+ meme=meme_template.copy()
37
+ approx_butterfly_center=(850,30)
38
+
39
+ if remove_background:
40
+ pigeon=remove_bg(pigeon)
41
+ meme=meme.convert("RGBA")
42
+
43
+ random_rotate=random.randint(-30,30)
44
+ random_size=random.randint(150,200)
45
+ pigeon=pigeon.resize((random_size,random_size)).rotate(random_rotate,expand=True)
46
+
47
+ meme.alpha_composite(pigeon, approx_butterfly_center)
48
+
49
+ #ref: https://blog.lipsumarium.com/caption-memes-in-python/
50
+ def drawTextWithOutline(text, x, y):
51
+ draw.text((x-2, y-2), text,(0,0,0),font=font)
52
+ draw.text((x+2, y-2), text,(0,0,0),font=font)
53
+ draw.text((x+2, y+2), text,(0,0,0),font=font)
54
+ draw.text((x-2, y+2), text,(0,0,0),font=font)
55
+ draw.text((x, y), text, (255,255,255), font=font)
56
+
57
+ if show_text:
58
+ draw = ImageDraw.Draw(meme)
59
+ font_size=52
60
+ font = ImageFont.truetype("assets/impact.ttf", font_size)
61
+ w, h = draw.textsize(text, font) # measure the size the text will take
62
+ drawTextWithOutline(text, meme.width/2 - w/2, meme.height - font_size*2)
63
+ meme = meme.convert("RGB")
64
+ return meme
65
 
66
  def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
67
  dataset=load_dataset(dataset_name)
69
  return dataset["train"]
70
 
71
  from transformers import BeitFeatureExtractor, BeitForImageClassification
72
+ emb_feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
73
+ emb_model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
74
  def embed(images):
75
+ inputs = emb_feature_extractor(images=images, return_tensors="pt")
76
+ outputs = emb_model(**inputs,output_hidden_states= True)
77
  last_hidden=outputs.hidden_states[-1]
78
+ pooler=emb_model.base_model.pooler
79
  final_emb=pooler(last_hidden).detach().numpy()
80
  return final_emb
81
 
90
  dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
91
  return dataset
92
 
93
+ def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512',model_version="95a9596a1e47e2419c9bd5252d809eecb14fdcf4"):
94
+ gan = LightweightGAN.from_pretrained(model_name,version=model_version)
95
  gan.eval()
96
  return gan
97
 
98
  def generate(gan,batch_size=1):
99
  with torch.no_grad():
100
+ ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)*255
101
+ ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8)
102
  return ims
103
 
104
  def interpolate():
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  git+https://github.com/huggingface/community-events.git@3fea10c5d5a50c69f509e34cd580fe9139905d04#egg=huggan
2
  transformers
3
- faiss-cpu
 
 
1
  git+https://github.com/huggingface/community-events.git@3fea10c5d5a50c69f509e34cd580fe9139905d04#egg=huggan
2
  transformers
3
+ faiss-cpu
4
+ paddlehub
5
+ paddlepaddle