hysts commited on
Commit
e89b524
1 Parent(s): f77ad3f
Files changed (2) hide show
  1. app.py +10 -5
  2. model.py +9 -6
app.py CHANGED
@@ -57,6 +57,7 @@ def main():
57
  input_image = gr.Image(label='Input Pose Image',
58
  type='pil',
59
  elem_id='input-image')
 
60
  with gr.Row():
61
  paths = sorted(pathlib.Path('pose_images').glob('*.png'))
62
  example_images = gr.Dataset(components=[input_image],
@@ -112,18 +113,22 @@ Note: Currently, only 5 types of textures are supported, i.e., pure color, strip
112
  gr.Markdown(FOOTER)
113
 
114
  input_image.change(fn=model.process_pose_image,
115
- inputs=[input_image],
116
- outputs=None)
117
  generate_label_button.click(fn=model.generate_label_image,
118
- inputs=[shape_text],
119
- outputs=[label_image])
 
 
 
120
  generate_human_button.click(fn=model.generate_human,
121
  inputs=[
 
122
  texture_text,
123
  sample_steps,
124
  seed,
125
  ],
126
- outputs=[result])
127
  example_images.click(fn=set_example_image,
128
  inputs=example_images,
129
  outputs=example_images.components)
57
  input_image = gr.Image(label='Input Pose Image',
58
  type='pil',
59
  elem_id='input-image')
60
+ pose_data = gr.Variable()
61
  with gr.Row():
62
  paths = sorted(pathlib.Path('pose_images').glob('*.png'))
63
  example_images = gr.Dataset(components=[input_image],
113
  gr.Markdown(FOOTER)
114
 
115
  input_image.change(fn=model.process_pose_image,
116
+ inputs=input_image,
117
+ outputs=pose_data)
118
  generate_label_button.click(fn=model.generate_label_image,
119
+ inputs=[
120
+ pose_data,
121
+ shape_text,
122
+ ],
123
+ outputs=label_image)
124
  generate_human_button.click(fn=model.generate_human,
125
  inputs=[
126
+ label_image,
127
  texture_text,
128
  sample_steps,
129
  seed,
130
  ],
131
+ outputs=result)
132
  example_images.click(fn=set_example_image,
133
  inputs=example_images,
134
  outputs=example_images.components)
model.py CHANGED
@@ -98,29 +98,32 @@ class Model:
98
  result = np.asarray(result[0, :, :, :], dtype=np.uint8)
99
  return result
100
 
101
- def process_pose_image(self, pose_image: PIL.Image.Image) -> None:
102
  if pose_image is None:
103
  return
104
  data = self.preprocess_pose_image(pose_image)
105
  self.model.feed_pose_data(data)
 
106
 
107
- def generate_label_image(self, shape_text: str) -> np.ndarray:
 
 
108
  shape_attributes = generate_shape_attributes(shape_text)
109
  shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
110
  self.model.feed_shape_attributes(shape_attributes)
111
  self.model.generate_parsing_map()
112
  self.model.generate_quantized_segm()
113
  colored_segm = self.model.palette_result(self.model.segm[0].cpu())
 
114
 
115
- mask = colored_segm.copy()
 
 
116
  seg_map = self.process_mask(mask)
117
  self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
118
  0).to(self.model.device)
119
  self.model.generate_quantized_segm()
120
- return colored_segm
121
 
122
- def generate_human(self, texture_text: str, sample_steps: int,
123
- seed: int) -> np.ndarray:
124
  set_random_seed(seed)
125
 
126
  texture_attributes = generate_texture_attributes(texture_text)
98
  result = np.asarray(result[0, :, :, :], dtype=np.uint8)
99
  return result
100
 
101
+ def process_pose_image(self, pose_image: PIL.Image.Image) -> torch.Tensor:
102
  if pose_image is None:
103
  return
104
  data = self.preprocess_pose_image(pose_image)
105
  self.model.feed_pose_data(data)
106
+ return data
107
 
108
+ def generate_label_image(self, pose_data: torch.Tensor,
109
+ shape_text: str) -> np.ndarray:
110
+ self.model.feed_pose_data(pose_data)
111
  shape_attributes = generate_shape_attributes(shape_text)
112
  shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
113
  self.model.feed_shape_attributes(shape_attributes)
114
  self.model.generate_parsing_map()
115
  self.model.generate_quantized_segm()
116
  colored_segm = self.model.palette_result(self.model.segm[0].cpu())
117
+ return colored_segm
118
 
119
+ def generate_human(self, label_image: np.ndarray, texture_text: str,
120
+ sample_steps: int, seed: int) -> np.ndarray:
121
+ mask = label_image.copy()
122
  seg_map = self.process_mask(mask)
123
  self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
124
  0).to(self.model.device)
125
  self.model.generate_quantized_segm()
 
126
 
 
 
127
  set_random_seed(seed)
128
 
129
  texture_attributes = generate_texture_attributes(texture_text)