kaushikbar committed on
Commit
0df795e
·
1 Parent(s): 895e936

min dall-e

Browse files
Files changed (2) hide show
  1. app.py +206 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import gradio
3
+ import subprocess
4
+ from PIL import Image
5
+ import torch, torch.backends.cudnn, torch.backends.cuda
6
+ from min_dalle import MinDalle
7
+ from emoji import demojize
8
+ import string
9
+
10
def filename_from_text(text: str) -> str:
    """Turn a free-form prompt into a short, hyphen-separated file name.

    Emoji are first rewritten by ``demojize`` (called with empty
    delimiters). The text is then lower-cased, stripped of every
    character that is not an ASCII lowercase letter or a space, cut to
    64 characters, and its remaining words are joined with ``-``.

    Returns ``'blank'`` when no characters survive the filtering.
    """
    demojized = demojize(text, delimiters=['', ''])
    ascii_only = demojized.lower().encode('ascii', errors='ignore').decode()

    allowed_chars = string.ascii_lowercase + ' '
    kept = ''.join(filter(allowed_chars.__contains__, ascii_only))

    # Truncate before splitting, so the 64-character limit applies to the
    # filtered text rather than to the hyphenated result.
    words = kept[:64].strip().split()
    slug = '-'.join(words)
    if not slug:
        return 'blank'
    return slug
19
+
20
def log_gpu_memory():
    """Print a timestamped ``nvidia-smi`` report to stdout.

    This is a diagnostic aid only. If ``nvidia-smi`` cannot be started or
    exits with a non-zero status, the failure is written into the log line
    instead of being raised, so that a logging problem never aborts the
    caller (for example after an image has already been generated).
    """
    try:
        report = subprocess.check_output('nvidia-smi').decode('utf-8')
    except (OSError, subprocess.CalledProcessError) as err:
        # OSError covers a missing or non-executable binary;
        # CalledProcessError covers a non-zero exit status.
        report = 'unavailable ({})'.format(err)
    print("Date:{}, GPU memory:{}".format(str(datetime.datetime.now()), report))
22
+
23
# Build the model once at import time; run_model() uses this module-level
# instance. GPU memory is logged immediately before and after construction.
log_gpu_memory()

model = MinDalle(is_mega=True, is_reusable=True, device='cuda', dtype=torch.float32)

log_gpu_memory()
33
+
34
def run_model(
    text: str,
    grid_size: int,
    is_seamless: bool,
    save_as_png: bool,
    temperature: float,
    supercondition: str,
    top_k: str
) -> str:
    """Generate an image for ``text`` and return the path of the saved file.

    Args:
        text: Prompt passed to the model; also used to derive the file name.
        grid_size: Grid size, converted with ``int`` and required to be 1..5.
        is_seamless: Forwarded to the model as ``is_seamless``.
        save_as_png: Save as ``.png`` when true, otherwise as ``.jpg``.
        temperature: Sampling temperature, converted with ``float`` and
            required to be greater than 1e-6.
        supercondition: Converted with ``float`` and forwarded as
            ``supercondition_factor``.
        top_k: Converted with ``int`` and required to be 1..16384.

    Returns:
        The relative path of the image file that was written.

    Raises:
        ValueError: If any numeric setting cannot be parsed or is out of
            range. ``ValueError`` is a subclass of ``Exception``, so callers
            that caught the previously raised ``Exception`` keep working.
    """
    torch.set_grad_enabled(False)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

    print("Date:{}".format(str(datetime.datetime.now())))
    print('text:', text)
    print('grid_size:', grid_size)
    print('is_seamless:', is_seamless)
    print('temperature:', temperature)
    print('supercondition:', supercondition)
    print('top_k:', top_k)

    # Validate with explicit checks rather than ``assert``: assertions are
    # removed under ``python -O``, which would silently disable the limits.
    temperature_error = 'Temperature must be a positive nonzero number'
    try:
        temperature = float(temperature)
    except (TypeError, ValueError) as err:
        raise ValueError(temperature_error) from err
    if not temperature > 1e-6:  # written this way so NaN is rejected too
        raise ValueError(temperature_error)

    grid_size_error = 'Grid size must be between 1 and 5'
    try:
        grid_size = int(grid_size)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(grid_size_error) from err
    if not 1 <= grid_size <= 5:
        raise ValueError(grid_size_error)

    top_k_error = 'Top k must be between 1 and 16384'
    try:
        top_k = int(top_k)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(top_k_error) from err
    if not 1 <= top_k <= 16384:
        raise ValueError(top_k_error)

    # Parse this before the expensive generation call so that a bad value
    # fails fast with a readable message.
    try:
        supercondition_factor = float(supercondition)
    except (TypeError, ValueError) as err:
        raise ValueError('Super condition must be a number') from err

    with torch.no_grad():
        image = model.generate_image(
            text=text,
            seed=-1,
            grid_size=grid_size,
            is_seamless=bool(is_seamless),
            temperature=temperature,
            supercondition_factor=supercondition_factor,
            top_k=top_k,
            is_verbose=True
        )

    log_gpu_memory()

    ext = 'png' if bool(save_as_png) else 'jpg'
    filename = filename_from_text(text)
    image_path = '{}.{}'.format(filename, ext)
    image.save(image_path)

    return image_path
96
+
97
# Gradio UI: a prompt box, a generate button and the output image in the left
# column, generation settings in the right column, and one clickable example
# underneath. Clicking the button calls run_model() with the current settings.
demo = gradio.Blocks(analytics_enabled=True)

with demo:
    with gradio.Row():
        with gradio.Column():
            input_text = gradio.Textbox(
                label='Input Text',
                value='Moai statue giving a TED Talk',
                lines=3
            )
            run_button = gradio.Button(value='Generate Image').style(full_width=True)
            # The output component must exist because run_button.click()
            # below lists it in `outputs`. No initial `value` is set: the
            # preview image 'examples/moai-statue.jpg' is not added by this
            # commit.
            output_image = gradio.Image(
                label='Output Image',
                type='file',
                interactive=False
            )

        with gradio.Column():
            gradio.Markdown('## Settings')
            with gradio.Row():
                grid_size = gradio.Slider(
                    label='Grid Size',
                    value=5,
                    minimum=1,
                    maximum=5,
                    step=1
                )
                save_as_png = gradio.Checkbox(
                    label='Output PNG',
                    value=False
                )
                is_seamless = gradio.Checkbox(
                    label='Seamless',
                    value=False
                )
            gradio.Markdown('#### Advanced')
            with gradio.Row():
                temperature = gradio.Number(
                    label='Temperature',
                    value=1
                )
                top_k = gradio.Dropdown(
                    label='Top-k',
                    choices=[str(2 ** i) for i in range(15)],
                    value='128'
                )
                supercondition = gradio.Dropdown(
                    label='Super Condition',
                    choices=[str(2 ** i) for i in range(2, 7)],
                    value='16'
                )

            gradio.Markdown(
                """
                ####
                - **Input Text**: For long prompts, only the first 64 text tokens will be used to generate the image.
                - **Grid Size**: Size of the image grid. 3x3 takes about 15 seconds.
                - **Seamless**: Tile images in image token space instead of pixel space.
                - **Temperature**: High temperature increases the probability of sampling low scoring image tokens.
                - **Top-k**: Each image token is sampled from the top-k scoring tokens.
                - **Super Condition**: Higher values can result in better agreement with the text.
                """
            )

    # Each example row supplies values for `inputs` below, in order
    # (prompt, grid size). The disabled rows carry a third element, a preview
    # image path, which matches the disabled `output_image` input.
    gradio.Examples(
        examples=[
            #['Rusty Iron Man suit found abandoned in the woods being reclaimed by nature', 3, 'examples/rusty-iron-man.jpg'],
            #['Moai statue giving a TED Talk', 5, 'examples/moai-statue.jpg'],
            #['Court sketch of Godzilla on trial', 5, 'examples/godzilla-trial.jpg'],
            #['lofi nuclear war to relax and study to', 5, 'examples/lofi-nuclear-war.jpg'],
            #['Karl Marx slimed at Kids Choice Awards', 4, 'examples/marx-slimed.jpg'],
            #['Scientists trying to rhyme orange with banana', 4, 'examples/scientists-rhyme.jpg'],
            #['Jesus turning water into wine on Americas Got Talent', 5, 'examples/jesus-talent.jpg'],
            #['Elmo in a street riot throwing a Molotov cocktail, hyperrealistic', 5, 'examples/elmo-riot.jpg'],
            #['Trail cam footage of gollum eating watermelon', 4, 'examples/gollum.jpg'],
            #['Funeral at Whole Foods', 4, 'examples/funeral-whole-foods.jpg'],
            #['Singularity, hyperrealism', 5, 'examples/singularity.jpg'],
            #['Astronaut riding a horse hyperrealistic', 5, 'examples/astronaut-horse.jpg'],
            ['Astronaut riding a horse hyperrealistic', 1],
            #['An astronaut walking on Mars next to a Starship rocket, realistic', 5, 'examples/astronaut-mars.jpg'],
            #['Nuclear explosion broccoli', 4, 'examples/nuclear-broccoli.jpg'],
            #['Dali painting of WALL·E', 5, 'examples/dali-walle.jpg'],
            #['Cleopatra checking her iPhone', 4, 'examples/cleopatra-iphone.jpg'],
        ],
        inputs=[
            input_text,
            grid_size,
            #output_image
        ],
        examples_per_page=20
    )

    run_button.click(
        fn=run_model,
        inputs=[
            input_text,
            grid_size,
            is_seamless,
            save_as_png,
            temperature,
            supercondition,
            top_k
        ],
        outputs=[
            output_image
        ]
    )

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ min-dalle==0.4.6
2
+ emoji==1.7.0
3
+
4
+ #--find-links https://download.pytorch.org/whl/torch_stable.html
5
+ torch==1.12.1+cu116
6
+