liuhaotian commited on
Commit
e6da15b
1 Parent(s): 99cdea0
Files changed (1) hide show
  1. app.py +56 -18
app.py CHANGED
@@ -7,7 +7,9 @@ import json
7
  import numpy as np
8
  from PIL import Image, ImageDraw, ImageFont
9
  from functools import partial
 
10
  import math
 
11
 
12
  from gradio import processing_utils
13
  from typing import Optional
@@ -42,20 +44,56 @@ def ckpt_load_helper(modality, is_inpaint, is_style, common_instances=None):
42
  return loaded_model_list, common_instances
43
 
44
 
45
- loaded_model_list, common_instances = ckpt_load_helper(
46
- 'gligen-generation-text-box',
47
- is_inpaint=False, is_style=False, common_instances=None
48
- )
49
-
50
- loaded_model_list_inpaint = ckpt_load_helper(
51
- 'gligen-inpainting-text-box',
52
- is_inpaint=True, is_style=False, common_instances=common_instances
53
- )[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- loaded_model_list_style = ckpt_load_helper(
56
- 'gligen-generation-text-image-box',
57
- is_inpaint=False, is_style=True, common_instances=common_instances
58
- )[0]
59
 
60
 
61
  def load_clip_model():
@@ -143,7 +181,7 @@ def inference(task, language_instruction, grounding_instruction, inpainting_boxe
143
  image_list = [placeholder_image] * len(phrase_list) # placeholder input for visual prompt, which is disabled
144
 
145
  batch_size = int(batch_size)
146
- if not 1 <= batch_size <= 2:
147
  batch_size = 2
148
 
149
  if style_image == None:
@@ -183,13 +221,13 @@ def inference(task, language_instruction, grounding_instruction, inpainting_boxe
183
  with torch.autocast(device_type='cuda', dtype=torch.float16):
184
  if task == 'Grounded Generation':
185
  if style_image == None:
186
- return grounded_generation_box(loaded_model_list, instruction, *args, **kwargs)
187
  else:
188
- return grounded_generation_box(loaded_model_list_style, instruction, *args, **kwargs)
189
  elif task == 'Grounded Inpainting':
190
  assert image is not None
191
  instruction['input_image'] = image.convert("RGB")
192
- return grounded_generation_box(loaded_model_list_inpaint, instruction, *args, **kwargs)
193
 
194
 
195
  def draw_box(boxes=[], texts=[], img=None):
@@ -498,7 +536,7 @@ with Blocks(
498
  with gr.Column():
499
  alpha_sample = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)")
500
  guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
501
- batch_size = gr.Slider(minimum=1, maximum=2, step=1, value=2, label="Number of Samples")
502
  append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption")
503
  use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False)
504
  with gr.Row():
 
7
  import numpy as np
8
  from PIL import Image, ImageDraw, ImageFont
9
  from functools import partial
10
+ from collections import Counter
11
  import math
12
+ import gc
13
 
14
  from gradio import processing_utils
15
  from typing import Optional
 
44
  return loaded_model_list, common_instances
45
 
46
 
47
+ class Instance:
48
+ def __init__(self, capacity = 2):
49
+ self.model_type = 'base'
50
+ self.loaded_model_list = {}
51
+ self.counter = Counter()
52
+ self.counter['base'] = 0
53
+ self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
54
+ 'gligen-generation-text-box',
55
+ is_inpaint=False, is_style=False, common_instances=None
56
+ )
57
+ self.capacity = capacity
58
+
59
+ def get_model(self, model_type):
60
+ if model_type in self.loaded_model_list:
61
+ self.counter[model_type] += 1
62
+ print(self.counter)
63
+ return self.loaded_model_list[model_type]
64
+
65
+ if self.capacity == len(self.loaded_model_list):
66
+ least_used_type = self.counter.most_common()[-1][0]
67
+ del self.loaded_model_list[least_used_type]
68
+ del self.counter[least_used_type]
69
+ gc.collect()
70
+ torch.cuda.empty_cache()
71
+
72
+ self.counter[model_type] = 1
73
+ self.loaded_model_list[model_type] = self._get_model(model_type)
74
+ print(self.counter)
75
+ return self.loaded_model_list[model_type]
76
+
77
+ def _get_model(self, model_type):
78
+ if model_type == 'base':
79
+ return ckpt_load_helper(
80
+ 'gligen-generation-text-box',
81
+ is_inpaint=False, is_style=False, common_instances=self.common_instances
82
+ )[0]
83
+ elif model_type == 'inpaint':
84
+ return ckpt_load_helper(
85
+ 'gligen-inpainting-text-box',
86
+ is_inpaint=True, is_style=False, common_instances=self.common_instances
87
+ )[0]
88
+ elif model_type == 'style':
89
+ return ckpt_load_helper(
90
+ 'gligen-generation-text-image-box',
91
+ is_inpaint=False, is_style=True, common_instances=self.common_instances
92
+ )[0]
93
+
94
+ assert False
95
 
96
+ instance = Instance()
 
 
 
97
 
98
 
99
  def load_clip_model():
 
181
  image_list = [placeholder_image] * len(phrase_list) # placeholder input for visual prompt, which is disabled
182
 
183
  batch_size = int(batch_size)
184
+ if not 1 <= batch_size <= 4:
185
  batch_size = 2
186
 
187
  if style_image == None:
 
221
  with torch.autocast(device_type='cuda', dtype=torch.float16):
222
  if task == 'Grounded Generation':
223
  if style_image == None:
224
+ return grounded_generation_box(instance.get_model('base'), instruction, *args, **kwargs)
225
  else:
226
+ return grounded_generation_box(instance.get_model('style'), instruction, *args, **kwargs)
227
  elif task == 'Grounded Inpainting':
228
  assert image is not None
229
  instruction['input_image'] = image.convert("RGB")
230
+ return grounded_generation_box(instance.get_model('inpaint'), instruction, *args, **kwargs)
231
 
232
 
233
  def draw_box(boxes=[], texts=[], img=None):
 
536
  with gr.Column():
537
  alpha_sample = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)")
538
  guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
539
+ batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=2, label="Number of Samples")
540
  append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption")
541
  use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False)
542
  with gr.Row():