liuhaotian commited on
Commit
42dc9ff
1 Parent(s): 4f97a73
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -56,21 +56,24 @@ class Instance:
56
  self.model_type = 'base'
57
  self.loaded_model_list = {}
58
  self.counter = Counter()
59
- self.counter['base'] = 0
60
  self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
61
  'gligen-generation-text-box',
62
  is_inpaint=False, is_style=False, common_instances=None
63
  )
64
  self.capacity = capacity
65
 
66
- def _log(self, batch_size, instruction, phrase_list):
 
 
67
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
68
- print(f'[{current_time}] ' + str(dict(self.counter)), f'samples: {batch_size}', f'prompt: {instruction}', f'phrases: {phrase_list}', sep=', ')
 
 
69
 
70
  def get_model(self, model_type, batch_size, instruction, phrase_list):
71
  if model_type in self.loaded_model_list:
72
- self.counter[model_type] += 1
73
- self._log(batch_size, instruction, phrase_list)
74
  return self.loaded_model_list[model_type]
75
 
76
  if self.capacity == len(self.loaded_model_list):
@@ -80,9 +83,8 @@ class Instance:
80
  gc.collect()
81
  torch.cuda.empty_cache()
82
 
83
- self.counter[model_type] = 1
84
  self.loaded_model_list[model_type] = self._get_model(model_type)
85
- self._log(batch_size, instruction, phrase_list)
86
  return self.loaded_model_list[model_type]
87
 
88
  def _get_model(self, model_type):
@@ -299,7 +301,8 @@ def generate(task, language_instruction, grounding_texts, sketch_pad,
299
  if len(boxes) != len(grounding_texts):
300
  if len(boxes) < len(grounding_texts):
301
  raise ValueError("""The number of boxes should be equal to the number of grounding objects.
302
- Number of boxes drawn: {}, number of grounding tokens: {}""".format(len(boxes), len(grounding_texts)))
 
303
  grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
304
 
305
  boxes = (np.asarray(boxes) / 512).tolist()
 
56
  self.model_type = 'base'
57
  self.loaded_model_list = {}
58
  self.counter = Counter()
59
+ self.global_counter = Counter()
60
  self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
61
  'gligen-generation-text-box',
62
  is_inpaint=False, is_style=False, common_instances=None
63
  )
64
  self.capacity = capacity
65
 
66
+ def _log(self, model_type, batch_size, instruction, phrase_list):
67
+ self.counter[model_type] += 1
68
+ self.global_counter[model_type] += 1
69
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
70
+ print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
71
+ current_time, dict(self.counter), dict(self.global_counter), batch_size, instruction, phrase_list
72
+ ))
73
 
74
  def get_model(self, model_type, batch_size, instruction, phrase_list):
75
  if model_type in self.loaded_model_list:
76
+ self._log(model_type, batch_size, instruction, phrase_list)
 
77
  return self.loaded_model_list[model_type]
78
 
79
  if self.capacity == len(self.loaded_model_list):
 
83
  gc.collect()
84
  torch.cuda.empty_cache()
85
 
 
86
  self.loaded_model_list[model_type] = self._get_model(model_type)
87
+ self._log(model_type, batch_size, instruction, phrase_list)
88
  return self.loaded_model_list[model_type]
89
 
90
  def _get_model(self, model_type):
 
301
  if len(boxes) != len(grounding_texts):
302
  if len(boxes) < len(grounding_texts):
303
  raise ValueError("""The number of boxes should be equal to the number of grounding objects.
304
+ Number of boxes drawn: {}, number of grounding tokens: {}.
305
+ Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
306
  grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
307
 
308
  boxes = (np.asarray(boxes) / 512).tolist()