hysts (HF staff) committed
Commit: fa7b0cc
Parent: 16ea01f
Files changed (6):
  1. .pre-commit-config.yaml +60 -35
  2. .style.yapf +0 -5
  3. README.md +1 -1
  4. app.py +59 -66
  5. model.py +16 -21
  6. style.css +1 -4
.pre-commit-config.yaml CHANGED
@@ -1,36 +1,61 @@
-exclude: ^(Text2Human|patch)
+exclude: ^patch
 repos:
-- repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.2.0
-  hooks:
-  - id: check-executables-have-shebangs
-  - id: check-json
-  - id: check-merge-conflict
-  - id: check-shebang-scripts-are-executable
-  - id: check-toml
-  - id: check-yaml
-  - id: double-quote-string-fixer
-  - id: end-of-file-fixer
-  - id: mixed-line-ending
-    args: ['--fix=lf']
-  - id: requirements-txt-fixer
-  - id: trailing-whitespace
-- repo: https://github.com/myint/docformatter
-  rev: v1.4
-  hooks:
-  - id: docformatter
-    args: ['--in-place']
-- repo: https://github.com/pycqa/isort
-  rev: 5.12.0
-  hooks:
-  - id: isort
-- repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v0.991
-  hooks:
-  - id: mypy
-    args: ['--ignore-missing-imports']
-- repo: https://github.com/google/yapf
-  rev: v0.32.0
-  hooks:
-  - id: yapf
-    args: ['--parallel', '--in-place']
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.6.0
+    hooks:
+      - id: check-executables-have-shebangs
+      - id: check-json
+      - id: check-merge-conflict
+      - id: check-shebang-scripts-are-executable
+      - id: check-toml
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: mixed-line-ending
+        args: ["--fix=lf"]
+      - id: requirements-txt-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/myint/docformatter
+    rev: v1.7.5
+    hooks:
+      - id: docformatter
+        args: ["--in-place"]
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.10.0
+    hooks:
+      - id: mypy
+        args: ["--ignore-missing-imports"]
+        additional_dependencies:
+          [
+            "types-python-slugify",
+            "types-requests",
+            "types-PyYAML",
+            "types-pytz",
+          ]
+  - repo: https://github.com/psf/black
+    rev: 24.4.2
+    hooks:
+      - id: black
+        language_version: python3.10
+        args: ["--line-length", "119"]
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.7.1
+    hooks:
+      - id: nbstripout
+        args:
+          [
+            "--extra-keys",
+            "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+          ]
+  - repo: https://github.com/nbQA-dev/nbQA
+    rev: 1.8.5
+    hooks:
+      - id: nbqa-black
+      - id: nbqa-pyupgrade
+        args: ["--py37-plus"]
+      - id: nbqa-isort
+        args: ["--float-to-top"]
.style.yapf DELETED
@@ -1,5 +0,0 @@
-[style]
-based_on_style = pep8
-blank_line_before_nested_class_or_def = false
-spaces_before_comment = 2
-split_before_logical_operator = true
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🏃
 colorFrom: purple
 colorTo: gray
 sdk: gradio
-sdk_version: 3.36.1
+sdk_version: 4.36.1
 app_file: app.py
 pinned: false
 suggested_hardware: t4-small
app.py CHANGED
@@ -8,26 +8,30 @@ import random
 import shlex
 import subprocess
 
-import gradio as gr
-import numpy as np
-
-if os.getenv('SYSTEM') == 'spaces':
+if os.getenv("SYSTEM") == "spaces":
+    subprocess.run(shlex.split("pip install click==7.1.2"))
+    subprocess.run(shlex.split("pip install typer==0.9.4"))
+
     import mim
 
-    mim.uninstall('mmcv-full', confirm_yes=True)
-    mim.install('mmcv-full==1.5.2', is_yes=True)
+    mim.uninstall("mmcv-full", confirm_yes=True)
+    mim.install("mmcv-full==1.5.2", is_yes=True)
 
-    with open('patch') as f:
-        subprocess.run(shlex.split('patch -p1'), cwd='Text2Human', stdin=f)
+    with open("patch") as f:
+        subprocess.run(shlex.split("patch -p1"), cwd="Text2Human", stdin=f)
+
+
+import gradio as gr
+import numpy as np
 
 from model import Model
 
-DESCRIPTION = '''# [Text2Human](https://github.com/yumingj/Text2Human)
+DESCRIPTION = """# [Text2Human](https://github.com/yumingj/Text2Human)
 
 You can modify sample steps and seeds. By varying seeds, you can sample different human images under the same pose, shape description, and texture description. The larger the sample steps, the better quality of the generated images. (The default value of sample steps is 256 in the original repo.)
 
 Label image generation step can be skipped. However, in that case, the input label image must be 512x256 in size and must contain only the specified colors.
-'''
+"""
 
 MAX_SEED = np.iinfo(np.int32).max
@@ -40,76 +44,61 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 
 model = Model()
 
-with gr.Blocks(css='style.css') as demo:
+with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
 
     with gr.Row():
         with gr.Column():
             with gr.Row():
-                input_image = gr.Image(label='Input Pose Image',
-                                       type='pil',
-                                       elem_id='input-image')
+                input_image = gr.Image(label="Input Pose Image", type="pil", elem_id="input-image")
             pose_data = gr.State()
             with gr.Row():
-                paths = sorted(pathlib.Path('pose_images').glob('*.png'))
-                gr.Examples(examples=[[path.as_posix()] for path in paths],
-                            inputs=input_image)
+                paths = sorted(pathlib.Path("pose_images").glob("*.png"))
+                gr.Examples(examples=[[path.as_posix()] for path in paths], inputs=input_image)
 
             with gr.Row():
                 shape_text = gr.Textbox(
-                    label='Shape Description',
-                    placeholder=
-                    '''<gender>, <sleeve length>, <length of lower clothing>, <outer clothing type>, <other accessories1>, ...
-Note: The outer clothing type and accessories can be omitted.''')
+                    label="Shape Description",
+                    placeholder="""<gender>, <sleeve length>, <length of lower clothing>, <outer clothing type>, <other accessories1>, ...
+Note: The outer clothing type and accessories can be omitted.""",
+                )
             with gr.Row():
                 gr.Examples(
-                    examples=[['man, sleeveless T-shirt, long pants'],
-                              ['woman, short-sleeve T-shirt, short jeans']],
-                    inputs=shape_text)
+                    examples=[["man, sleeveless T-shirt, long pants"], ["woman, short-sleeve T-shirt, short jeans"]],
+                    inputs=shape_text,
+                )
             with gr.Row():
-                generate_label_button = gr.Button('Generate Label Image')
+                generate_label_button = gr.Button("Generate Label Image")
 
         with gr.Column():
            with gr.Row():
-                label_image = gr.Image(label='Label Image',
-                                       type='numpy',
-                                       elem_id='label-image')
+                label_image = gr.Image(label="Label Image", type="numpy", elem_id="label-image")
 
            with gr.Row():
                 texture_text = gr.Textbox(
-                    label='Texture Description',
-                    placeholder=
-                    '''<upper clothing texture>, <lower clothing texture>, <outer clothing texture>
-Note: Currently, only 5 types of textures are supported, i.e., pure color, stripe/spline, plaid/lattice, floral, denim.'''
+                    label="Texture Description",
+                    placeholder="""<upper clothing texture>, <lower clothing texture>, <outer clothing texture>
+Note: Currently, only 5 types of textures are supported, i.e., pure color, stripe/spline, plaid/lattice, floral, denim.""",
                 )
            with gr.Row():
-                gr.Examples(examples=[
-                    ['pure color, denim'],
-                    ['floral, stripe'],
-                ],
-                            inputs=texture_text)
+                gr.Examples(
+                    examples=[
+                        ["pure color, denim"],
+                        ["floral, stripe"],
+                    ],
+                    inputs=texture_text,
+                )
            with gr.Row():
-                sample_steps = gr.Slider(label='Sample Steps',
-                                         minimum=10,
-                                         maximum=300,
-                                         step=1,
-                                         value=256)
+                sample_steps = gr.Slider(label="Sample Steps", minimum=10, maximum=300, step=1, value=256)
            with gr.Row():
-                seed = gr.Slider(label='Seed',
-                                 minimum=0,
-                                 maximum=MAX_SEED,
-                                 step=1,
-                                 value=0)
-                randomize_seed = gr.Checkbox(label='Randomize seed',
-                                             value=True)
+                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
-                generate_human_button = gr.Button('Generate Human')
+                generate_human_button = gr.Button("Generate Human")
 
        with gr.Column():
            with gr.Row():
-                result = gr.Image(label='Result',
-                                  type='numpy',
-                                  elem_id='result-image')
+                result = gr.Image(label="Result", type="numpy", elem_id="result-image")
 
    input_image.change(
        fn=model.process_pose_image,
@@ -124,17 +113,21 @@ Note: Currently, only 5 types of textures are supported, i.e., pure color, strip
         ],
         outputs=label_image,
     )
-    generate_human_button.click(fn=randomize_seed_fn,
-                                inputs=[seed, randomize_seed],
-                                outputs=seed,
-                                queue=False).then(
-                                    fn=model.generate_human,
-                                    inputs=[
-                                        label_image,
-                                        texture_text,
-                                        sample_steps,
-                                        seed,
-                                    ],
-                                    outputs=result,
-                                )
-demo.queue(max_size=10).launch()
+    generate_human_button.click(
+        fn=randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+    ).then(
+        fn=model.generate_human,
+        inputs=[
+            label_image,
+            texture_text,
+            sample_steps,
+            seed,
+        ],
+        outputs=result,
+    )
+
+if __name__ == "__main__":
+    demo.queue(max_size=10).launch()
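
The hunk header above references randomize_seed_fn(seed: int, randomize_seed: bool) -> int, whose body lies outside the diff. A minimal sketch of what such a helper typically looks like, reusing the MAX_SEED constant defined in app.py; the body shown here is an assumption for illustration, not part of this commit:

import random

import numpy as np

MAX_SEED = np.iinfo(np.int32).max


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    # If the "Randomize seed" checkbox is ticked, draw a fresh seed in [0, MAX_SEED];
    # otherwise return the user-selected seed unchanged.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

In the UI wiring above, this runs first (queue=False) so the Seed slider is updated before generate_human receives its value.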
model.py CHANGED
@@ -9,11 +9,10 @@ import numpy as np
 import PIL.Image
 import torch
 
-sys.path.insert(0, 'Text2Human')
+sys.path.insert(0, "Text2Human")
 
 from models.sample_model import SampleFromPoseModel
-from utils.language_utils import (generate_shape_attributes,
-                                  generate_texture_attributes)
+from utils.language_utils import generate_shape_attributes, generate_texture_attributes
 from utils.options import dict_to_nonedict, parse
 from utils.util import set_random_seed
 
@@ -47,37 +46,36 @@ COLOR_LIST = [
 
 class Model:
     def __init__(self):
-        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.config = self._load_config()
-        self.config['device'] = device.type
+        self.config["device"] = device.type
         self._download_models()
         self.model = SampleFromPoseModel(self.config)
         self.model.batch_size = 1
 
     def _load_config(self) -> dict:
-        path = 'Text2Human/configs/sample_from_pose.yml'
+        path = "Text2Human/configs/sample_from_pose.yml"
         config = parse(path, is_train=False)
         config = dict_to_nonedict(config)
         return config
 
     def _download_models(self) -> None:
-        model_dir = pathlib.Path('pretrained_models')
+        model_dir = pathlib.Path("pretrained_models")
         if model_dir.exists():
             return
-        path = huggingface_hub.hf_hub_download('yumingj/Text2Human_SSHQ',
-                                               'pretrained_models.zip')
+        path = huggingface_hub.hf_hub_download("yumingj/Text2Human_SSHQ", "pretrained_models.zip")
         model_dir.mkdir()
         with zipfile.ZipFile(path) as f:
             f.extractall(model_dir)
 
     @staticmethod
     def preprocess_pose_image(image: PIL.Image.Image) -> torch.Tensor:
-        image = np.array(
-            image.resize(
-                size=(256, 512),
-                resample=PIL.Image.Resampling.LANCZOS))[:, :, 2:].transpose(
-                    2, 0, 1).astype(np.float32)
-        image = image / 12. - 1
+        image = (
+            np.array(image.resize(size=(256, 512), resample=PIL.Image.Resampling.LANCZOS))[:, :, 2:]
+            .transpose(2, 0, 1)
+            .astype(np.float32)
+        )
+        image = image / 12.0 - 1
         data = torch.from_numpy(image).unsqueeze(1)
         return data
 
@@ -107,8 +105,7 @@ class Model:
         self.model.feed_pose_data(data)
         return data
 
-    def generate_label_image(self, pose_data: torch.Tensor,
-                             shape_text: str) -> np.ndarray:
+    def generate_label_image(self, pose_data: torch.Tensor, shape_text: str) -> np.ndarray:
         if pose_data is None:
             return
         self.model.feed_pose_data(pose_data)
@@ -120,16 +117,14 @@ class Model:
         colored_segm = self.model.palette_result(self.model.segm[0].cpu())
         return colored_segm
 
-    def generate_human(self, label_image: np.ndarray, texture_text: str,
-                       sample_steps: int, seed: int) -> np.ndarray:
+    def generate_human(self, label_image: np.ndarray, texture_text: str, sample_steps: int, seed: int) -> np.ndarray:
         if label_image is None:
             return
         mask = label_image.copy()
         seg_map = self.process_mask(mask)
         if seg_map is None:
             return
-        self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
-            0).to(self.model.device)
+        self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(0).to(self.model.device)
         self.model.generate_quantized_segm()
 
         set_random_seed(seed)
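
Together with app.py, the Model methods above form a three-step pipeline: process the pose image, generate a label (segmentation) image from a shape description, then render the human image from a texture description. A hedged usage sketch outside the Gradio UI; the pose image path is illustrative, the example strings come from app.py's gr.Examples, and the method bodies are only partially visible in this diff:

import PIL.Image

from model import Model

model = Model()

# Illustrative path; app.py lists the bundled examples from pose_images/*.png.
pose_image = PIL.Image.open("pose_images/000.png")

# Step 1: preprocess and feed the pose image (app.py keeps this tensor in gr.State).
pose_data = model.process_pose_image(pose_image)

# Step 2: turn a shape description into a label image.
label_image = model.generate_label_image(pose_data, "man, sleeveless T-shirt, long pants")

# Step 3: render the final human image from the label image and a texture description.
result = model.generate_human(label_image, "pure color, denim", 256, 0)  # sample_steps=256, seed=0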
style.css CHANGED
@@ -1,5 +1,6 @@
 h1 {
   text-align: center;
+  display: block;
 }
 #input-image {
   max-height: 300px;
@@ -10,7 +11,3 @@ h1 {
 #result-image {
   height: 300px;
 }
-img#visitor-badge {
-  display: block;
-  margin: auto;
-}