hysts commited on
Commit
e2d3c58
1 Parent(s): 0fb46e4
Files changed (2) hide show
  1. app.py +1 -2
  2. model.py +14 -2
app.py CHANGED
@@ -68,8 +68,7 @@ def main():
68
  with gr.Row():
69
  label_image = gr.Image(label='Label Image',
70
  type='numpy',
71
- elem_id='label-image',
72
- interactive=False)
73
  with gr.Row():
74
  shape_text = gr.Textbox(
75
  label='Shape Description',
68
  with gr.Row():
69
  label_image = gr.Image(label='Label Image',
70
  type='numpy',
71
+ elem_id='label-image')
 
72
  with gr.Row():
73
  shape_text = gr.Textbox(
74
  label='Shape Description',
model.py CHANGED
@@ -64,6 +64,7 @@ class Model:
64
  self.config['device'] = device
65
  self._download_models()
66
  self.model = SampleFromPoseModel(self.config)
 
67
 
68
  def _load_config(self) -> dict:
69
  path = 'Text2Human/configs/sample_from_pose.yml'
@@ -96,11 +97,20 @@ class Model:
96
  return data
97
 
98
  @staticmethod
99
- def process_mask(mask: torch.Tensor) -> torch.Tensor:
 
 
 
 
 
 
 
 
100
  seg_map = np.full(mask.shape[:-1], -1)
101
  for index, color in enumerate(COLOR_LIST):
102
  seg_map[np.sum(mask == color, axis=2) == 3] = index
103
- assert (seg_map != -1).all()
 
104
  return seg_map
105
 
106
  @staticmethod
@@ -140,6 +150,8 @@ class Model:
140
  logger.debug(f'{sample_steps=}')
141
  mask = label_image.copy()
142
  seg_map = self.process_mask(mask)
 
 
143
  self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
144
  0).to(self.model.device)
145
  self.model.generate_quantized_segm()
64
  self.config['device'] = device
65
  self._download_models()
66
  self.model = SampleFromPoseModel(self.config)
67
+ self.model.batch_size = 1
68
 
69
  def _load_config(self) -> dict:
70
  path = 'Text2Human/configs/sample_from_pose.yml'
97
  return data
98
 
99
  @staticmethod
100
+ def process_mask(mask: np.ndarray) -> np.ndarray:
101
+ logger.debug(f'{mask.shape=}')
102
+ if mask.shape != (512, 256, 3):
103
+ return None
104
+ colors = np.unique(mask.reshape(-1, 3), axis=0)
105
+ colors = set(map(tuple, colors.tolist()))
106
+ logger.debug(f'{colors=}')
107
+ logger.debug(f'{colors - set(COLOR_LIST)=}')
108
+
109
  seg_map = np.full(mask.shape[:-1], -1)
110
  for index, color in enumerate(COLOR_LIST):
111
  seg_map[np.sum(mask == color, axis=2) == 3] = index
112
+ if not (seg_map != -1).all():
113
+ return None
114
  return seg_map
115
 
116
  @staticmethod
150
  logger.debug(f'{sample_steps=}')
151
  mask = label_image.copy()
152
  seg_map = self.process_mask(mask)
153
+ if seg_map is None:
154
+ return
155
  self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
156
  0).to(self.model.device)
157
  self.model.generate_quantized_segm()