Ceyda Cinarel commited on
Commit
b0b9e1f
·
1 Parent(s): 4e918e8

Add demo start

Browse files
Files changed (4) hide show
  1. app.py +41 -0
  2. demo.py +28 -0
  3. packages.txt +0 -0
  4. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st # HF spaces at v1.2.0
2
+ from demo import load_model,generate
3
+
4
+ # TODOs
5
+ # Add markdown short readme project intro
6
+ # project setup:
7
+
8
+ # git clone https://github.com/huggingface/community-events.git
9
+ # cd community-events
10
+ # pip install .
11
+
12
+ st.title("ButterflyGAN")
13
+ st.write("## This butterfly does not exist! ")
14
+ st.write("Demo prep still in progress!!")
15
+
16
+ @st.experimental_singleton
17
+ def load_model_intocache(model_name):
18
+
19
+ # model_name='ceyda/butterfly_512_base'
20
+ gan = load_model(model_name)
21
+
22
+ return gan
23
+
24
+ model_name='ceyda/butterfly_cropped_uniq1K_512'
25
+ model=load_model_intocache(model_name)
26
+
27
+ st.write(f"Model {model_name} is loaded")
28
+ st.write(f"Latent dimension: {model.latent_dim}, Image size:{model.image_size}")
29
+
30
+ run=st.button("Generate")
31
+ if run:
32
+ with st.spinner("Generating..."):
33
+
34
+ batch_size=4 #generate 4 butterflies
35
+ ims=generate(model,batch_size)
36
+
37
+ cols=st.columns(batch_size)
38
+ for i,im in enumerate(ims):
39
+ cols[i].image(im)
40
+
41
+
demo.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
3
+ from datasets import load_dataset
4
+
5
+ def get_train_data(dataset_name="ceyda/smithsonian_butterflies_transparent_cropped",data_limit=1000):
6
+ dataset=load_dataset(dataset_name)
7
+ dataset=dataset.sort("sim_score")
8
+ score_thresh = dataset["train"][data_limit]['sim_score']
9
+ dataset = dataset.filter(lambda x: x['sim_score'] < score_thresh)
10
+
11
+ dataset = dataset.map(lambda x: x.convert("RGB"))
12
+ return dataset["train"]
13
+
14
+
15
+
16
+ def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512'):
17
+ gan = LightweightGAN.from_pretrained(model_name)
18
+ gan.eval();
19
+ return gan
20
+
21
+ def generate(gan,batch_size=1):
22
+ with torch.no_grad():
23
+ ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)
24
+ ims = ims.permute(0,2,3,1).detach().cpu().numpy()
25
+ return ims
26
+
27
+ def interpolate():
28
+ pass
packages.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ git+https://github.com/huggingface/community-events.git@3fea10c5d5a50c69f509e34cd580fe9139905d04#egg=huggan