Spaces:
Runtime error
Runtime error
Ceyda Cinarel
commited on
Commit
·
b0b9e1f
1
Parent(s):
4e918e8
Add demo start
Browse files- app.py +41 -0
- demo.py +28 -0
- packages.txt +0 -0
- requirements.txt +1 -0
app.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st # HF spaces at v1.2.0
|
2 |
+
from demo import load_model,generate
|
3 |
+
|
4 |
+
# TODOs
|
5 |
+
# Add markdown short readme project intro
|
6 |
+
# project setup:
|
7 |
+
|
8 |
+
# git clone https://github.com/huggingface/community-events.git
|
9 |
+
# cd community-events
|
10 |
+
# pip install .
|
11 |
+
|
12 |
+
st.title("ButterflyGAN")
|
13 |
+
st.write("## This butterfly does not exist! ")
|
14 |
+
st.write("Demo prep still in progress!!")
|
15 |
+
|
16 |
+
@st.experimental_singleton
|
17 |
+
def load_model_intocache(model_name):
|
18 |
+
|
19 |
+
# model_name='ceyda/butterfly_512_base'
|
20 |
+
gan = load_model(model_name)
|
21 |
+
|
22 |
+
return gan
|
23 |
+
|
24 |
+
model_name='ceyda/butterfly_cropped_uniq1K_512'
|
25 |
+
model=load_model_intocache(model_name)
|
26 |
+
|
27 |
+
st.write(f"Model {model_name} is loaded")
|
28 |
+
st.write(f"Latent dimension: {model.latent_dim}, Image size:{model.image_size}")
|
29 |
+
|
30 |
+
run=st.button("Generate")
|
31 |
+
if run:
|
32 |
+
with st.spinner("Generating..."):
|
33 |
+
|
34 |
+
batch_size=4 #generate 4 butterflies
|
35 |
+
ims=generate(model,batch_size)
|
36 |
+
|
37 |
+
cols=st.columns(batch_size)
|
38 |
+
for i,im in enumerate(ims):
|
39 |
+
cols[i].image(im)
|
40 |
+
|
41 |
+
|
demo.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
|
3 |
+
from datasets import load_dataset
|
4 |
+
|
5 |
+
def get_train_data(dataset_name="ceyda/smithsonian_butterflies_transparent_cropped",data_limit=1000):
|
6 |
+
dataset=load_dataset(dataset_name)
|
7 |
+
dataset=dataset.sort("sim_score")
|
8 |
+
score_thresh = dataset["train"][data_limit]['sim_score']
|
9 |
+
dataset = dataset.filter(lambda x: x['sim_score'] < score_thresh)
|
10 |
+
|
11 |
+
dataset = dataset.map(lambda x: x.convert("RGB"))
|
12 |
+
return dataset["train"]
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512'):
|
17 |
+
gan = LightweightGAN.from_pretrained(model_name)
|
18 |
+
gan.eval();
|
19 |
+
return gan
|
20 |
+
|
21 |
+
def generate(gan,batch_size=1):
|
22 |
+
with torch.no_grad():
|
23 |
+
ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)
|
24 |
+
ims = ims.permute(0,2,3,1).detach().cpu().numpy()
|
25 |
+
return ims
|
26 |
+
|
27 |
+
def interpolate():
|
28 |
+
pass
|
packages.txt
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
git+https://github.com/huggingface/community-events.git@3fea10c5d5a50c69f509e34cd580fe9139905d04#egg=huggan
|