Spaces:
Running
Running
hysts
commited on
Commit
•
e2d3c58
1
Parent(s):
0fb46e4
debug
Browse files
app.py
CHANGED
@@ -68,8 +68,7 @@ def main():
|
|
68 |
with gr.Row():
|
69 |
label_image = gr.Image(label='Label Image',
|
70 |
type='numpy',
|
71 |
-
elem_id='label-image'
|
72 |
-
interactive=False)
|
73 |
with gr.Row():
|
74 |
shape_text = gr.Textbox(
|
75 |
label='Shape Description',
|
68 |
with gr.Row():
|
69 |
label_image = gr.Image(label='Label Image',
|
70 |
type='numpy',
|
71 |
+
elem_id='label-image')
|
|
|
72 |
with gr.Row():
|
73 |
shape_text = gr.Textbox(
|
74 |
label='Shape Description',
|
model.py
CHANGED
@@ -64,6 +64,7 @@ class Model:
|
|
64 |
self.config['device'] = device
|
65 |
self._download_models()
|
66 |
self.model = SampleFromPoseModel(self.config)
|
|
|
67 |
|
68 |
def _load_config(self) -> dict:
|
69 |
path = 'Text2Human/configs/sample_from_pose.yml'
|
@@ -96,11 +97,20 @@ class Model:
|
|
96 |
return data
|
97 |
|
98 |
@staticmethod
|
99 |
-
def process_mask(mask:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
seg_map = np.full(mask.shape[:-1], -1)
|
101 |
for index, color in enumerate(COLOR_LIST):
|
102 |
seg_map[np.sum(mask == color, axis=2) == 3] = index
|
103 |
-
|
|
|
104 |
return seg_map
|
105 |
|
106 |
@staticmethod
|
@@ -140,6 +150,8 @@ class Model:
|
|
140 |
logger.debug(f'{sample_steps=}')
|
141 |
mask = label_image.copy()
|
142 |
seg_map = self.process_mask(mask)
|
|
|
|
|
143 |
self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
|
144 |
0).to(self.model.device)
|
145 |
self.model.generate_quantized_segm()
|
64 |
self.config['device'] = device
|
65 |
self._download_models()
|
66 |
self.model = SampleFromPoseModel(self.config)
|
67 |
+
self.model.batch_size = 1
|
68 |
|
69 |
def _load_config(self) -> dict:
|
70 |
path = 'Text2Human/configs/sample_from_pose.yml'
|
97 |
return data
|
98 |
|
99 |
@staticmethod
|
100 |
+
def process_mask(mask: np.ndarray) -> np.ndarray:
|
101 |
+
logger.debug(f'{mask.shape=}')
|
102 |
+
if mask.shape != (512, 256, 3):
|
103 |
+
return None
|
104 |
+
colors = np.unique(mask.reshape(-1, 3), axis=0)
|
105 |
+
colors = set(map(tuple, colors.tolist()))
|
106 |
+
logger.debug(f'{colors=}')
|
107 |
+
logger.debug(f'{colors - set(COLOR_LIST)=}')
|
108 |
+
|
109 |
seg_map = np.full(mask.shape[:-1], -1)
|
110 |
for index, color in enumerate(COLOR_LIST):
|
111 |
seg_map[np.sum(mask == color, axis=2) == 3] = index
|
112 |
+
if not (seg_map != -1).all():
|
113 |
+
return None
|
114 |
return seg_map
|
115 |
|
116 |
@staticmethod
|
150 |
logger.debug(f'{sample_steps=}')
|
151 |
mask = label_image.copy()
|
152 |
seg_map = self.process_mask(mask)
|
153 |
+
if seg_map is None:
|
154 |
+
return
|
155 |
self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
|
156 |
0).to(self.model.device)
|
157 |
self.model.generate_quantized_segm()
|