geninhu commited on
Commit
5e5faf7
1 Parent(s): 3f10a30

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -0
app.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import streamlit as st
4
+
5
+ from models import Generator, Discriminrator
6
+ from utils import image_to_base64
7
+ import torch
8
+ import torchvision.transforms as T
9
+ from torchvision.utils import make_grid
10
+ from PIL import Image
11
+
12
+ from streamlit_lottie import st_lottie
13
+ import requests
14
+
15
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
16
+
17
+
18
+ model_name = {
19
+ "aurora": 'huggan/fastgan-few-shot-aurora',
20
+ "painting": 'huggan/fastgan-few-shot-painting',
21
+ "shell": 'huggan/fastgan-few-shot-shells',
22
+ "fauvism": 'huggan/fastgan-few-shot-fauvism-still-life',
23
+ }
24
+
25
+ #@st.cache(allow_output_mutation=True)
26
+ def load_generator(model_name_or_path):
27
+ generator = Generator(in_channels=256, out_channels=3)
28
+ generator = generator.from_pretrained(model_name_or_path, in_channels=256, out_channels=3)
29
+ _ = generator.to('cuda')
30
+ _ = generator.eval()
31
+
32
+ return generator
33
+
34
+ def _denormalize(input: torch.Tensor) -> torch.Tensor:
35
+ return (input * 127.5) + 127.5
36
+
37
+
38
+ def generate_images(generator, number_imgs):
39
+ noise = torch.zeros(number_imgs, 256, 1, 1, device='cuda').normal_(0.0, 1.0)
40
+ with torch.no_grad():
41
+ gan_images, _ = generator(noise)
42
+
43
+ gan_images = _denormalize(gan_images.detach()).cpu()
44
+ gan_images = [i for i in gan_images]
45
+ gan_images = [make_grid(i, nrow=1, normalize=True) for i in gan_images]
46
+ gan_images = [i.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() for i in gan_images]
47
+ gan_images = [Image.fromarray(i) for i in gan_images]
48
+ return gan_images
49
+
50
+ def load_lottieurl(url: str):
51
+ r = requests.get(url)
52
+ if r.status_code != 200:
53
+ return None
54
+ return r.json()
55
+
56
+
57
+ def main():
58
+
59
+ st.set_page_config(
60
+ page_title="FastGAN Generator",
61
+ page_icon="🖥️",
62
+ layout="wide",
63
+ initial_sidebar_state="expanded"
64
+ )
65
+
66
+ lottie_penguin = load_lottieurl('https://assets7.lottiefiles.com/packages/lf20_mm4bsl3l.json')
67
+
68
+ with st.sidebar:
69
+ st_lottie(lottie_penguin, height=200)
70
+ st.sidebar.markdown(
71
+ """
72
+ ___
73
+ <p style='text-align: center'>
74
+ FastGAN is an few-shot GAN model that generates images of several types!
75
+ </p>
76
+ <p style='text-align: center'>
77
+ Model training and Space creation by
78
+ <br/>
79
+ <a href="https://huggingface.co/vumichien" target="_blank">Chien Vu</a> | <a href="https://huggingface.co/geninhu" target="_blank">Nhu Hoang</a>
80
+ <br/>
81
+ </p>
82
+
83
+ <p style='text-align: center'>
84
+ based on
85
+ <br/>
86
+ <a href="https://github.com/silentz/Towards-Faster-And-Stabilized-GAN-Training-For-High-Fidelity-Few-Shot-Image-Synthesis" target="_blank">FastGAN model</a> | <a href="https://arxiv.org/abs/2101.04775" target="_blank">Article</a>
87
+ </p>
88
+ """,
89
+ unsafe_allow_html=True,
90
+ )
91
+
92
+ st.header("Welcome to FastGAN")
93
+
94
+ col1, col2, col3, col4 = st.columns([3,3,3,3])
95
+ with col1:
96
+ st.markdown('Fauvism GAN [model](https://huggingface.co/huggan/fastgan-few-shot-fauvism-still-life)', unsafe_allow_html=True)
97
+ st.image('fauvism.png', width=250)
98
+
99
+ with col2:
100
+ st.markdown('Aurora GAN [model](https://huggingface.co/huggan/fastgan-few-shot-aurora-bs8)', unsafe_allow_html=True)
101
+ st.image('aurora.png', width=250)
102
+
103
+ with col3:
104
+ st.markdown('Painting GAN [model](https://huggingface.co/huggan/fastgan-few-shot-painting-bs8)', unsafe_allow_html=True)
105
+ st.image('painting.png', width=250)
106
+ with col4:
107
+ st.markdown('Shell GAN [model](https://huggingface.co/huggan/fastgan-few-shot-shells)', unsafe_allow_html=True)
108
+ st.image('shell.png', width=250)
109
+
110
+ st.markdown('___')
111
+ if st.checkbox('Click if you want to create one of your own !'):
112
+
113
+ col11, col12, col13 = st.columns([3,3,3])
114
+ with col11:
115
+ img_type = st.selectbox("Choose type of image to generate", index=0, options=["aurora", "painting", "fauvism", "shell"])
116
+ # with col12:
117
+ number_imgs = st.slider('How many images you want to generate ?', min_value=1, max_value=5)
118
+ if number_imgs is None:
119
+ st.write('Invalid number ! Please insert number of images to generate !')
120
+ raise ValueError('Invalid number ! Please insert number of images to generate !')
121
+
122
+ generate_button = st.button('Get Image!')
123
+ if generate_button:
124
+ st.markdown("""
125
+ <small><i>Predictions may take up to 1 minute under high load.</i></small>
126
+ </br>
127
+ <small><i>Please stand by.</i></small>
128
+ """,
129
+ unsafe_allow_html=True,)
130
+
131
+ if generate_button:
132
+ generator = load_generator(model_name[img_type])
133
+ gan_images = generate_images(generator, number_imgs)
134
+ with col12:
135
+ st.image(gan_images[0], width=400)
136
+ if len(gan_images) > 1:
137
+ with col13:
138
+ if len(gan_images) <= 2:
139
+ st.image(gan_images[1], width=400)
140
+ else:
141
+ st.image(gan_images[1:], width=200)
142
+
143
+
144
+
145
+ if __name__ == '__main__':
146
+ main()