geninhu commited on
Commit
6baf171
1 Parent(s): 4bb924d

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -140
app.py DELETED
@@ -1,140 +0,0 @@
1
- import pandas as pd
2
- import numpy as np
3
- import streamlit as st
4
-
5
- from models import Generator, Discriminrator
6
- from utils import image_to_base64
7
- import torch
8
- import torchvision.transforms as T
9
- from torchvision.utils import make_grid
10
- from PIL import Image
11
-
12
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
-
14
-
15
- model_name = {
16
- "aurora": 'huggan/fastgan-few-shot-aurora-bs8',
17
- "painting": 'huggan/fastgan-few-shot-painting-bs8',
18
- "shell": 'huggan/fastgan-few-shot-shells',
19
- "fauvism": 'huggan/fastgan-few-shot-fauvism-still-life',
20
- }
21
-
22
- #@st.cache(allow_output_mutation=True)
23
- def load_generator(model_name_or_path):
24
- generator = Generator(in_channels=256, out_channels=3)
25
- generator = generator.from_pretrained(model_name_or_path, in_channels=256, out_channels=3)
26
- _ = generator.to('cuda')
27
- _ = generator.eval()
28
-
29
- return generator
30
-
31
- def _denormalize(input: torch.Tensor) -> torch.Tensor:
32
- return (input * 127.5) + 127.5
33
-
34
-
35
- def generate_images(generator, number_imgs):
36
- noise = torch.zeros(number_imgs, 256, 1, 1, device='cuda').normal_(0.0, 1.0)
37
- with torch.no_grad():
38
- gan_images, _ = generator(noise)
39
-
40
- gan_images = _denormalize(gan_images.detach()).cpu()
41
- gan_images = make_grid(gan_images, nrow=number_imgs, normalize=True)
42
- gan_images = gan_images.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
43
- gan_images = Image.fromarray(gan_images)
44
- return gan_images
45
-
46
-
47
- def main():
48
-
49
- st.set_page_config(
50
- page_title="FastGAN Generator",
51
- page_icon="🖥️",
52
- layout="wide",
53
- initial_sidebar_state="expanded"
54
- )
55
-
56
- # st.sidebar.markdown(
57
- # """
58
- # <style>
59
- # .aligncenter {
60
- # text-align: center;
61
- # }
62
- # </style>
63
- # <p class="aligncenter">
64
- # <img src="https://e7.pngegg.com/pngimages/510/121/png-clipart-machine-learning-deep-learning-artificial-intelligence-algorithm-machine-learning-angle-text.png"/>
65
- # </p>
66
- # """,
67
- # unsafe_allow_html=True,
68
- # )
69
- st.sidebar.markdown(
70
- """
71
- ___
72
- <p style='text-align: center'>
73
- FastGAN is an few-shot GAN model that generates images of several types!
74
- </p>
75
- <p style='text-align: center'>
76
- Model training and Space creation by
77
- <br/>
78
- <a href="https://huggingface.co/vumichien" target="_blank">Chien Vu</a> | <a href="https://huggingface.co/geninhu" target="_blank">Nhu Hoang</a>
79
- <br/>
80
- </p>
81
-
82
- <p style='text-align: center'>
83
- <a href="https://github.com/silentz/Towards-Faster-And-Stabilized-GAN-Training-For-High-Fidelity-Few-Shot-Image-Synthesis" target="_blank">based on FastGAN model</a> | <a href="https://arxiv.org/abs/2101.04775" target="_blank">Article</a>
84
- </p>
85
- """,
86
- unsafe_allow_html=True,
87
- )
88
-
89
- st.header("Welcome to FastGAN")
90
-
91
- col1, col2, col3, col4 = st.columns([3,3,3,3])
92
- with col1:
93
- st.markdown('Fauvism GAN [model](https://huggingface.co/huggan/fastgan-few-shot-fauvism-still-life)', unsafe_allow_html=True)
94
- st.image('fauvism.png', width=300)
95
-
96
- with col2:
97
- st.markdown('Aurora GAN [model](https://huggingface.co/huggan/fastgan-few-shot-aurora-bs8)', unsafe_allow_html=True)
98
- st.image('aurora.png', width=300)
99
-
100
- with col3:
101
- st.markdown('Painting GAN [model](https://huggingface.co/huggan/fastgan-few-shot-painting-bs8)', unsafe_allow_html=True)
102
- st.image('painting.png', width=300)
103
- with col4:
104
- st.markdown('Shell GAN [model](https://huggingface.co/huggan/fastgan-few-shot-shells)', unsafe_allow_html=True)
105
- st.image('shell.png', width=300)
106
-
107
- # Choose generator
108
- col11, col12, col13 = st.columns([4,4,2])
109
- with col11:
110
- st.markdown('Choose type of image to generate', unsafe_allow_html=True)
111
- img_type = st.selectbox("", index=0, options=["shell", "aurora", "painting", "fauvism"])
112
-
113
- with col12:
114
- number_imgs = st.number_input('How many images you want to generate ?', min_value=1, max_value=5)
115
- if number_imgs is None:
116
- st.write('Invalid number ! Please insert number of images to generate !')
117
- raise ValueError('Invalid number ! Please insert number of images to generate !')
118
- with col13:
119
- generate_button = st.button('Get Image!')
120
-
121
- # row2 = st.columns([10])
122
- # with row2:
123
- if generate_button:
124
- st.markdown("""
125
- <small><i>Predictions may take up to 1mn under high load. Please stand by.</i></small>
126
- """,
127
- unsafe_allow_html=True,)
128
- generator = load_generator(model_name[img_type])
129
- gan_images = generate_images(generator, number_imgs)
130
- # margin = 0.1 # for better position of zoom in arrow
131
- # n_columns = 2
132
- # cols = st.columns([1] + [margin, 1] * (n_columns - 1))
133
- # for i, img in enumerate(gan_images):
134
- # cols[(i % n_columns) * 2].image(img)
135
-
136
- st.image(gan_images, width=200*number_imgs)
137
-
138
-
139
- if __name__ == '__main__':
140
- main()