Ricercar committed on
Commit 10c79ab
1 Parent(s): 5857783

prepare for archive

app.py CHANGED
@@ -1,5 +1,6 @@
 import os
 
+import argparse
 import gradio as gr
 import numpy as np
 import torch
@@ -31,12 +32,19 @@ LORA_TRIGGER_WORD = {
     'shinkai_makoto': ['shinkai makoto', 'kimi no na wa.', 'tenki no ko', 'kotonoha no niwa'],
 }
 
+METADATA_TO_SHOW = ['inv_model', 'spl_model', 'lora', 'lora_scale', 'inv_steps', 'spl_steps', 'pos_prompt', 'alpha', 'neg_prompt', 'beta', 'omega']
+
 
 class WebApp():
     def __init__(self, debug_mode=False):
+        if torch.cuda.is_available():
+            self.device = "cuda"
+        else:
+            self.device = "cpu"
+
         self.args_base = {
             "seed": 42,
-            "device": "cuda",
+            "device": self.device,
             "output_dir": "output_demo",
             "caption_model_name": "blip-large",
             "clip_model_name": "ViT-L-14/openai",
@@ -60,7 +68,6 @@ class WebApp():
         self.args_input = {} # for gr.components only
         self.gr_loras = list(LORA_TRIGGER_WORD.keys())
 
-        # fun fact: google analytics doesn't work in this space currently
        self.gtag = os.environ.get('GTag')
 
        self.ga_script = f"""
@@ -80,13 +87,13 @@ class WebApp():
         # self._preload_pipeline()
 
         self.debug_mode = debug_mode # turn off clip interrogator when debugging for faster building speed
-        if not self.debug_mode:
+        if not self.debug_mode and self.device=="cuda":
             self.init_interrogator()
 
 
     def init_interrogator(self):
         cache_path = os.environ.get('HF_HOME')
-        print(f"Intended cache dir: {cache_path}")
+        # print(f"Intended cache dir: {cache_path}")
         config = Config()
         config.cache_path = cache_path
         config.clip_model_path = cache_path
@@ -96,7 +103,7 @@ class WebApp():
         self.ci.config.chunk_size = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
         self.ci.config.flavor_intermediate_count = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
 
-        print(f"HF cache dir: {file_utils.default_cache_path}")
+        # print(f"HF cache dir: {file_utils.default_cache_path}")
 
     def _preload_pipeline(self):
         for model in BASE_MODEL.values():
@@ -114,10 +121,10 @@ class WebApp():
             <h1 >Diffusion Cocktail 🍸: Fused Generation from Diffusion Models</h1>
             <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;">
                 <a class="flex-item" href="https://arxiv.org/abs/2312.08873" target="_blank">
-                    <img src="https://img.shields.io/badge/arXiv-paper-darkred.svg" alt="arXiv Paper">
+                    <img src="https://img.shields.io/badge/arXiv-Paper-darkred.svg" alt="arXiv Paper">
                 </a>
                 <a class="flex-item" href="https://MAPS-research.github.io/Ditail" target="_blank">
-                    <img src="https://img.shields.io/badge/Project_Page-Diffusion_Cocktail-yellow.svg" alt="Project Page">
+                    <img src="https://img.shields.io/badge/Website-Ditail-yellow.svg" alt="Project Page">
                 </a>
                 <a class="flex-item" href="https://github.com/MAPS-research/Ditail" target="_blank">
                     <img src="https://img.shields.io/badge/Github-Code-green.svg" alt="GitHub Code">
@@ -127,7 +134,20 @@ class WebApp():
             </div>
             """
         )
-
+
+
+    def device_requirements(self):
+        gr.Markdown(
+            """
+            <center>
+            <h2>
+            Attention: The demo doesn't work in this space running on CPU only. \
+            Please duplicate and upgrade to a private "T4 medium" GPU.
+            </h2>
+            </center>
+            """
+        )
+        gr.DuplicateButton(size='lg', scale=1, variant='primary')
 
     def get_image(self):
         self.args_input['img'] = gr.Image(label='content image', type='pil', show_share_button=False, elem_classes="input_image")
@@ -142,7 +162,7 @@ class WebApp():
 
 
     def _interrogate_image(self, image, generate_prompt):
-        if hasattr(self, 'ci') and generate_prompt:
+        if hasattr(self, 'ci') and image is not None and generate_prompt:
             return self.ci.interrogate_fast(image).split(',')[0].replace('arafed', '')
         else:
             return ''
@@ -153,8 +173,8 @@ class WebApp():
 
     def get_lora(self, num_cols=3):
         self.args_input['lora'] = gr.State('none')
-        lora_gallery = gr.Gallery(label='target LoRA (optional)', columns=num_cols, value=[(os.path.join(self.args_base['lora_dir'], f"{lora}.jpeg"), lora) for lora in self.gr_loras], allow_preview=False, show_share_button=False, selected_index=0)
-        lora_gallery.select(self._update_lora_selection, inputs=[], outputs=[self.args_input['lora']])
+        self.lora_gallery = gr.Gallery(label='target LoRA (optional)', columns=num_cols, value=[(os.path.join(self.args_base['lora_dir'], f"{lora}.jpeg"), lora) for lora in self.gr_loras], allow_preview=False, show_share_button=False)
+        self.lora_gallery.select(self._update_lora_selection, inputs=[], outputs=[self.args_input['lora']])
 
     def _update_lora_selection(self, selected_state: gr.SelectData):
         return self.gr_loras[selected_state.index]
@@ -180,7 +200,7 @@ class WebApp():
 
     def run_ditail(self, *values):
         gr_args = self.args_base.copy()
-        print(self.args_input.keys())
+        # print(self.args_input.keys())
         for k, v in zip(list(self.args_input.keys()), values):
             gr_args[k] = v
         # quick fix for example
@@ -195,9 +215,9 @@ class WebApp():
         seed_everything(gr_args['seed'])
         ditail = DitailDemo(gr_args)
 
-        metadata_to_show = ['inv_model', 'spl_model', 'lora', 'lora_scale', 'inv_steps', 'spl_steps', 'pos_prompt', 'alpha', 'neg_prompt', 'beta', 'omega']
+
         args_to_show = {}
-        for key in metadata_to_show:
+        for key in METADATA_TO_SHOW:
             args_to_show[key] = gr_args[key]
 
         img = ditail.run_ditail()
@@ -207,8 +227,19 @@ class WebApp():
 
         return img, args_to_show
 
-    def run_example(self, img, prompt, inv_model, spl_model, lora):
-        return self.run_ditail(img, prompt, spl_model, gr.State(lora), inv_model)
+    # def run_example(self, img, prompt, inv_model, spl_model, lora):
+    #     return self.run_ditail(img, prompt, spl_model, gr.State(lora), inv_model)
+    def run_example(self, *values):
+        gr_args = self.args_base.copy()
+        for k, v in zip(['img', 'pos_prompt', 'inv_model', 'spl_model', 'lora'], values):
+            gr_args[k] = v
+        args_to_show = {}
+        for key in METADATA_TO_SHOW:
+            args_to_show[key] = gr_args[key]
+        img = os.path.join(os.path.dirname(__file__), "example", "Cocktail_impression.jpg")
+        # self.lora_gallery.selected_index = self.gr_loras.index(gr_args['lora'])
+        return img, args_to_show
+
 
     def show_credits(self):
         gr.Markdown(
@@ -224,6 +255,10 @@ class WebApp():
         with gr.Blocks(css='.input_image img {object-fit: contain;}', head=self.ga_script) as demo:
 
             self.title()
+
+            if self.device == "cpu":
+                self.device_requirements()
+
             with gr.Row():
                 self.get_image()
 
@@ -232,6 +267,8 @@ class WebApp():
                 self.get_base_model()
                 self.get_lora(num_cols=3)
                 submit_btn = gr.Button("Generate", variant='primary')
+                if self.device == 'cpu':
+                    submit_btn.variant='secondary'
 
                 with gr.Accordion("advanced options", open=False):
                     self.get_params()
@@ -250,12 +287,12 @@ class WebApp():
             with gr.Row():
                 cache_examples = not self.debug_mode
                 gr.Examples(
-                    examples=[[os.path.join(os.path.dirname(__file__), "example", "Lenna.png"), 'a woman called Lenna wearing a feathered hat', list(BASE_MODEL.keys())[1], list(BASE_MODEL.keys())[2], 'none']],
+                    examples=[[os.path.join(os.path.dirname(__file__), "example", "Cocktail.jpg"), 'a glass of a cocktail with a lime wedge on it', list(BASE_MODEL.keys())[1], list(BASE_MODEL.keys())[1], 'impressionism']],
                     inputs=[self.args_input['img'], self.args_input['pos_prompt'], self.args_input['inv_model'], self.args_input['spl_model'], gr.Textbox(label='LoRA', visible=False), ],
                     fn = self.run_example,
                     outputs=[output_image, metadata],
                     run_on_click=True,
-                    cache_examples=cache_examples,
+                    # cache_examples=cache_examples,
                 )
 
             self.show_credits()
@@ -264,7 +301,7 @@ class WebApp():
         return demo
 
 
-app = WebApp(debug_mode=False)
+app = WebApp(debug_mode=True)
 demo = app.ui()
 
 
ditail/src/ditail_demo.py CHANGED
@@ -72,11 +72,13 @@ class DitailDemo(nn.Module):
             padding='max_length',
             max_length=self.tokenizer.model_max_length
         )
-
+
     @torch.no_grad()
     def encode_image(self, image_pil):
         # image_pil = T.Resize(512)(img.convert('RGB'))
         image_pil = T.Resize(512)(image_pil)
+        width, height = image_pil.size
+
         image = T.ToTensor()(image_pil).unsqueeze(0).to(self.device)
         with torch.autocast(device_type=self.device, dtype=torch.float32):
             image = 2 * image - 1
example/Cocktail.jpg ADDED
example/Cocktail_impression.jpg ADDED