Spaces:
Running
Running
hysts
commited on
Commit
·
b85284b
1
Parent(s):
e4d1395
Add files
Browse files- .gitattributes +1 -0
- .gitignore +1 -0
- .gitmodules +3 -0
- .pre-commit-config.yaml +46 -0
- .style.yapf +5 -0
- Text2Human +1 -0
- app.py +157 -0
- model.py +134 -0
- patch +169 -0
- pose_images/000.png +3 -0
- pose_images/001.png +3 -0
- pose_images/002.png +3 -0
- pose_images/003.png +3 -0
- pose_images/004.png +3 -0
- pose_images/005.png +3 -0
- requirements.txt +11 -0
.gitattributes
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
2 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
3 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
4 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
pretrained_models
|
.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "Text2Human"]
|
2 |
+
path = Text2Human
|
3 |
+
url = https://github.com/yumingj/Text2Human
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
exclude: ^(Text2Human|patch)
|
2 |
+
repos:
|
3 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
4 |
+
rev: v4.2.0
|
5 |
+
hooks:
|
6 |
+
- id: check-executables-have-shebangs
|
7 |
+
- id: check-json
|
8 |
+
- id: check-merge-conflict
|
9 |
+
- id: check-shebang-scripts-are-executable
|
10 |
+
- id: check-toml
|
11 |
+
- id: check-yaml
|
12 |
+
- id: double-quote-string-fixer
|
13 |
+
- id: end-of-file-fixer
|
14 |
+
- id: mixed-line-ending
|
15 |
+
args: ['--fix=lf']
|
16 |
+
- id: requirements-txt-fixer
|
17 |
+
- id: trailing-whitespace
|
18 |
+
- repo: https://github.com/myint/docformatter
|
19 |
+
rev: v1.4
|
20 |
+
hooks:
|
21 |
+
- id: docformatter
|
22 |
+
args: ['--in-place']
|
23 |
+
- repo: https://github.com/pycqa/isort
|
24 |
+
rev: 5.10.1
|
25 |
+
hooks:
|
26 |
+
- id: isort
|
27 |
+
- repo: https://github.com/pre-commit/mirrors-mypy
|
28 |
+
rev: v0.812
|
29 |
+
hooks:
|
30 |
+
- id: mypy
|
31 |
+
args: ['--ignore-missing-imports']
|
32 |
+
- repo: https://github.com/google/yapf
|
33 |
+
rev: v0.32.0
|
34 |
+
hooks:
|
35 |
+
- id: yapf
|
36 |
+
args: ['--parallel', '--in-place']
|
37 |
+
- repo: https://github.com/kynan/nbstripout
|
38 |
+
rev: 0.5.0
|
39 |
+
hooks:
|
40 |
+
- id: nbstripout
|
41 |
+
args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
|
42 |
+
- repo: https://github.com/nbQA-dev/nbQA
|
43 |
+
rev: 1.3.1
|
44 |
+
hooks:
|
45 |
+
- id: nbqa-isort
|
46 |
+
- id: nbqa-yapf
|
.style.yapf
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[style]
|
2 |
+
based_on_style = pep8
|
3 |
+
blank_line_before_nested_class_or_def = false
|
4 |
+
spaces_before_comment = 2
|
5 |
+
split_before_logical_operator = true
|
Text2Human
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 6d38607df89651704000d0e6571bfc640d185a77
|
app.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
import pathlib
|
8 |
+
import subprocess
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
|
12 |
+
if os.getenv('SYSTEM') == 'spaces':
|
13 |
+
subprocess.call('pip uninstall -y mmcv-full'.split())
|
14 |
+
subprocess.call('pip install mmcv-full==1.5.2'.split())
|
15 |
+
subprocess.call('git apply ../patch'.split(), cwd='Text2Human')
|
16 |
+
|
17 |
+
from model import Model
|
18 |
+
|
19 |
+
|
20 |
+
def parse_args() -> argparse.Namespace:
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument('--device', type=str, default='cpu')
|
23 |
+
parser.add_argument('--theme', type=str)
|
24 |
+
parser.add_argument('--share', action='store_true')
|
25 |
+
parser.add_argument('--port', type=int)
|
26 |
+
parser.add_argument('--disable-queue',
|
27 |
+
dest='enable_queue',
|
28 |
+
action='store_false')
|
29 |
+
return parser.parse_args()
|
30 |
+
|
31 |
+
|
32 |
+
def set_example_image(example: list) -> dict:
|
33 |
+
return gr.Image.update(value=example[0])
|
34 |
+
|
35 |
+
|
36 |
+
def set_example_text(example: list) -> dict:
|
37 |
+
return gr.Textbox.update(value=example[0])
|
38 |
+
|
39 |
+
|
40 |
+
def main():
|
41 |
+
args = parse_args()
|
42 |
+
model = Model(args.device)
|
43 |
+
|
44 |
+
css = '''
|
45 |
+
h1#title {
|
46 |
+
text-align: center;
|
47 |
+
}
|
48 |
+
#input-image {
|
49 |
+
max-height: 300px;
|
50 |
+
}
|
51 |
+
#label-image {
|
52 |
+
height: 300px;
|
53 |
+
}
|
54 |
+
#result-image {
|
55 |
+
height: 300px;
|
56 |
+
}
|
57 |
+
'''
|
58 |
+
|
59 |
+
with gr.Blocks(theme=args.theme, css=css) as demo:
|
60 |
+
gr.Markdown('''<h1 id="title">Text2Human</h1>
|
61 |
+
|
62 |
+
This is an unofficial demo for <a href="https://github.com/yumingj/Text2Human">https://github.com/yumingj/Text2Human</a>.
|
63 |
+
''')
|
64 |
+
with gr.Row():
|
65 |
+
with gr.Column():
|
66 |
+
with gr.Row():
|
67 |
+
input_image = gr.Image(label='Input Pose Image',
|
68 |
+
type='pil',
|
69 |
+
elem_id='input-image')
|
70 |
+
with gr.Row():
|
71 |
+
paths = sorted(pathlib.Path('pose_images').glob('*.png'))
|
72 |
+
example_images = gr.Dataset(components=[input_image],
|
73 |
+
samples=[[path.as_posix()]
|
74 |
+
for path in paths])
|
75 |
+
|
76 |
+
with gr.Column():
|
77 |
+
with gr.Row():
|
78 |
+
label_image = gr.Image(label='Label Image',
|
79 |
+
type='numpy',
|
80 |
+
elem_id='label-image')
|
81 |
+
with gr.Row():
|
82 |
+
shape_text = gr.Textbox(
|
83 |
+
label='Shape Description',
|
84 |
+
placeholder=
|
85 |
+
'''<gender>, <sleeve length>, <length of lower clothing>, <outer clothing type>, <other accessories1>, ...
|
86 |
+
Note: The outer clothing type and accessories can be omitted.''')
|
87 |
+
with gr.Row():
|
88 |
+
shape_example_texts = gr.Dataset(
|
89 |
+
components=[shape_text],
|
90 |
+
samples=[['man, sleeveless T-shirt, long pants'],
|
91 |
+
['woman, short-sleeve T-shirt, short jeans']])
|
92 |
+
with gr.Row():
|
93 |
+
generate_label_button = gr.Button('Generate Label Image')
|
94 |
+
|
95 |
+
with gr.Column():
|
96 |
+
with gr.Row():
|
97 |
+
result = gr.Image(label='Result',
|
98 |
+
type='numpy',
|
99 |
+
elem_id='result-image')
|
100 |
+
with gr.Row():
|
101 |
+
texture_text = gr.Textbox(
|
102 |
+
label='Texture Description',
|
103 |
+
placeholder=
|
104 |
+
'''<upper clothing texture>, <lower clothing texture>, <outer clothing texture>
|
105 |
+
Note: Currently, only 5 types of textures are supported, i.e., pure color, stripe/spline, plaid/lattice, floral, denim.'''
|
106 |
+
)
|
107 |
+
with gr.Row():
|
108 |
+
texture_example_texts = gr.Dataset(
|
109 |
+
components=[texture_text],
|
110 |
+
samples=[['pure color, denim'], ['floral, stripe']])
|
111 |
+
with gr.Row():
|
112 |
+
sample_steps = gr.Slider(10,
|
113 |
+
300,
|
114 |
+
value=10,
|
115 |
+
step=10,
|
116 |
+
label='Sample Steps')
|
117 |
+
with gr.Row():
|
118 |
+
seed = gr.Slider(0, 1000000, value=0, step=1, label='Seed')
|
119 |
+
with gr.Row():
|
120 |
+
generate_human_button = gr.Button('Generate Human')
|
121 |
+
|
122 |
+
gr.Markdown(
|
123 |
+
'<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.text2human" alt="visitor badge"/></center>'
|
124 |
+
)
|
125 |
+
|
126 |
+
input_image.change(fn=model.process_pose_image,
|
127 |
+
inputs=[input_image],
|
128 |
+
outputs=None)
|
129 |
+
generate_label_button.click(fn=model.generate_label_image,
|
130 |
+
inputs=[shape_text],
|
131 |
+
outputs=[label_image])
|
132 |
+
generate_human_button.click(fn=model.generate_human,
|
133 |
+
inputs=[
|
134 |
+
texture_text,
|
135 |
+
sample_steps,
|
136 |
+
seed,
|
137 |
+
],
|
138 |
+
outputs=[result])
|
139 |
+
example_images.click(fn=set_example_image,
|
140 |
+
inputs=example_images,
|
141 |
+
outputs=example_images.components)
|
142 |
+
shape_example_texts.click(fn=set_example_text,
|
143 |
+
inputs=shape_example_texts,
|
144 |
+
outputs=shape_example_texts.components)
|
145 |
+
texture_example_texts.click(fn=set_example_text,
|
146 |
+
inputs=texture_example_texts,
|
147 |
+
outputs=texture_example_texts.components)
|
148 |
+
|
149 |
+
demo.launch(
|
150 |
+
enable_queue=args.enable_queue,
|
151 |
+
server_port=args.port,
|
152 |
+
share=args.share,
|
153 |
+
)
|
154 |
+
|
155 |
+
|
156 |
+
if __name__ == '__main__':
|
157 |
+
main()
|
model.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
import pathlib
|
5 |
+
import sys
|
6 |
+
import zipfile
|
7 |
+
|
8 |
+
import huggingface_hub
|
9 |
+
import numpy as np
|
10 |
+
import PIL.Image
|
11 |
+
import torch
|
12 |
+
|
13 |
+
sys.path.insert(0, 'Text2Human')
|
14 |
+
|
15 |
+
from models.sample_model import SampleFromPoseModel
|
16 |
+
from utils.language_utils import (generate_shape_attributes,
|
17 |
+
generate_texture_attributes)
|
18 |
+
from utils.options import dict_to_nonedict, parse
|
19 |
+
from utils.util import set_random_seed
|
20 |
+
|
21 |
+
COLOR_LIST = [
|
22 |
+
(0, 0, 0),
|
23 |
+
(255, 250, 250),
|
24 |
+
(220, 220, 220),
|
25 |
+
(250, 235, 215),
|
26 |
+
(255, 250, 205),
|
27 |
+
(211, 211, 211),
|
28 |
+
(70, 130, 180),
|
29 |
+
(127, 255, 212),
|
30 |
+
(0, 100, 0),
|
31 |
+
(50, 205, 50),
|
32 |
+
(255, 255, 0),
|
33 |
+
(245, 222, 179),
|
34 |
+
(255, 140, 0),
|
35 |
+
(255, 0, 0),
|
36 |
+
(16, 78, 139),
|
37 |
+
(144, 238, 144),
|
38 |
+
(50, 205, 174),
|
39 |
+
(50, 155, 250),
|
40 |
+
(160, 140, 88),
|
41 |
+
(213, 140, 88),
|
42 |
+
(90, 140, 90),
|
43 |
+
(185, 210, 205),
|
44 |
+
(130, 165, 180),
|
45 |
+
(225, 141, 151),
|
46 |
+
]
|
47 |
+
|
48 |
+
|
49 |
+
class Model:
|
50 |
+
def __init__(self, device: str):
|
51 |
+
self.config = self._load_config()
|
52 |
+
self.config['device'] = device
|
53 |
+
self._download_models()
|
54 |
+
self.model = SampleFromPoseModel(self.config)
|
55 |
+
|
56 |
+
def _load_config(self) -> dict:
|
57 |
+
path = 'Text2Human/configs/sample_from_pose.yml'
|
58 |
+
config = parse(path, is_train=False)
|
59 |
+
config = dict_to_nonedict(config)
|
60 |
+
return config
|
61 |
+
|
62 |
+
def _download_models(self) -> None:
|
63 |
+
model_dir = pathlib.Path('pretrained_models')
|
64 |
+
if model_dir.exists():
|
65 |
+
return
|
66 |
+
token = os.getenv('HF_TOKEN')
|
67 |
+
path = huggingface_hub.hf_hub_download('hysts/Text2Human',
|
68 |
+
'orig/pretrained_models.zip',
|
69 |
+
use_auth_token=token)
|
70 |
+
model_dir.mkdir()
|
71 |
+
with zipfile.ZipFile(path) as f:
|
72 |
+
f.extractall(model_dir)
|
73 |
+
|
74 |
+
@staticmethod
|
75 |
+
def preprocess_pose_image(image: PIL.Image.Image) -> torch.Tensor:
|
76 |
+
image = np.array(
|
77 |
+
image.resize(
|
78 |
+
size=(256, 512),
|
79 |
+
resample=PIL.Image.Resampling.LANCZOS))[:, :, 2:].transpose(
|
80 |
+
2, 0, 1).astype(np.float32)
|
81 |
+
image = image / 12. - 1
|
82 |
+
data = torch.from_numpy(image).unsqueeze(1)
|
83 |
+
return data
|
84 |
+
|
85 |
+
@staticmethod
|
86 |
+
def process_mask(mask: torch.Tensor) -> torch.Tensor:
|
87 |
+
seg_map = np.full(mask.shape[:-1], -1)
|
88 |
+
for index, color in enumerate(COLOR_LIST):
|
89 |
+
seg_map[np.sum(mask == color, axis=2) == 3] = index
|
90 |
+
assert (seg_map != -1).all()
|
91 |
+
return seg_map
|
92 |
+
|
93 |
+
@staticmethod
|
94 |
+
def postprocess(result: torch.Tensor) -> np.ndarray:
|
95 |
+
result = result.permute(0, 2, 3, 1)
|
96 |
+
result = result.detach().cpu().numpy()
|
97 |
+
result = result * 255
|
98 |
+
result = np.asarray(result[0, :, :, :], dtype=np.uint8)
|
99 |
+
return result
|
100 |
+
|
101 |
+
def process_pose_image(self, pose_image: PIL.Image.Image) -> None:
|
102 |
+
if pose_image is None:
|
103 |
+
return
|
104 |
+
data = self.preprocess_pose_image(pose_image)
|
105 |
+
self.model.feed_pose_data(data)
|
106 |
+
|
107 |
+
def generate_label_image(self, shape_text: str) -> np.ndarray:
|
108 |
+
shape_attributes = generate_shape_attributes(shape_text)
|
109 |
+
shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
|
110 |
+
self.model.feed_shape_attributes(shape_attributes)
|
111 |
+
self.model.generate_parsing_map()
|
112 |
+
self.model.generate_quantized_segm()
|
113 |
+
colored_segm = self.model.palette_result(self.model.segm[0].cpu())
|
114 |
+
|
115 |
+
mask = colored_segm.copy()
|
116 |
+
seg_map = self.process_mask(mask)
|
117 |
+
self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
|
118 |
+
0).to(self.model.device)
|
119 |
+
self.model.generate_quantized_segm()
|
120 |
+
return colored_segm
|
121 |
+
|
122 |
+
def generate_human(self, texture_text: str, sample_steps: int,
|
123 |
+
seed: int) -> np.ndarray:
|
124 |
+
set_random_seed(seed)
|
125 |
+
|
126 |
+
texture_attributes = generate_texture_attributes(texture_text)
|
127 |
+
texture_attributes = torch.LongTensor(texture_attributes)
|
128 |
+
self.model.feed_texture_attributes(texture_attributes)
|
129 |
+
self.model.generate_texture_map()
|
130 |
+
|
131 |
+
self.model.sample_steps = sample_steps
|
132 |
+
out = self.model.sample_and_refine()
|
133 |
+
res = self.postprocess(out)
|
134 |
+
return res
|
patch
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diff --git a/models/hierarchy_inference_model.py b/models/hierarchy_inference_model.py
|
2 |
+
index 3116307..5de661d 100644
|
3 |
+
--- a/models/hierarchy_inference_model.py
|
4 |
+
+++ b/models/hierarchy_inference_model.py
|
5 |
+
@@ -21,7 +21,7 @@ class VQGANTextureAwareSpatialHierarchyInferenceModel():
|
6 |
+
|
7 |
+
def __init__(self, opt):
|
8 |
+
self.opt = opt
|
9 |
+
- self.device = torch.device('cuda')
|
10 |
+
+ self.device = torch.device(opt['device'])
|
11 |
+
self.is_train = opt['is_train']
|
12 |
+
|
13 |
+
self.top_encoder = Encoder(
|
14 |
+
diff --git a/models/hierarchy_vqgan_model.py b/models/hierarchy_vqgan_model.py
|
15 |
+
index 4b0d657..0bf4712 100644
|
16 |
+
--- a/models/hierarchy_vqgan_model.py
|
17 |
+
+++ b/models/hierarchy_vqgan_model.py
|
18 |
+
@@ -20,7 +20,7 @@ class HierarchyVQSpatialTextureAwareModel():
|
19 |
+
|
20 |
+
def __init__(self, opt):
|
21 |
+
self.opt = opt
|
22 |
+
- self.device = torch.device('cuda')
|
23 |
+
+ self.device = torch.device(opt['device'])
|
24 |
+
self.top_encoder = Encoder(
|
25 |
+
ch=opt['top_ch'],
|
26 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
27 |
+
diff --git a/models/parsing_gen_model.py b/models/parsing_gen_model.py
|
28 |
+
index 9440345..15a1ecb 100644
|
29 |
+
--- a/models/parsing_gen_model.py
|
30 |
+
+++ b/models/parsing_gen_model.py
|
31 |
+
@@ -22,7 +22,7 @@ class ParsingGenModel():
|
32 |
+
|
33 |
+
def __init__(self, opt):
|
34 |
+
self.opt = opt
|
35 |
+
- self.device = torch.device('cuda')
|
36 |
+
+ self.device = torch.device(opt['device'])
|
37 |
+
self.is_train = opt['is_train']
|
38 |
+
|
39 |
+
self.attr_embedder = ShapeAttrEmbedding(
|
40 |
+
diff --git a/models/sample_model.py b/models/sample_model.py
|
41 |
+
index 4c60e3f..5265cd0 100644
|
42 |
+
--- a/models/sample_model.py
|
43 |
+
+++ b/models/sample_model.py
|
44 |
+
@@ -23,7 +23,7 @@ class BaseSampleModel():
|
45 |
+
|
46 |
+
def __init__(self, opt):
|
47 |
+
self.opt = opt
|
48 |
+
- self.device = torch.device('cuda')
|
49 |
+
+ self.device = torch.device(opt['device'])
|
50 |
+
|
51 |
+
# hierarchical VQVAE
|
52 |
+
self.decoder = Decoder(
|
53 |
+
@@ -123,7 +123,7 @@ class BaseSampleModel():
|
54 |
+
|
55 |
+
def load_top_pretrain_models(self):
|
56 |
+
# load pretrained vqgan
|
57 |
+
- top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
|
58 |
+
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device)
|
59 |
+
|
60 |
+
self.decoder.load_state_dict(
|
61 |
+
top_vae_checkpoint['decoder'], strict=True)
|
62 |
+
@@ -137,7 +137,7 @@ class BaseSampleModel():
|
63 |
+
self.top_post_quant_conv.eval()
|
64 |
+
|
65 |
+
def load_bot_pretrain_network(self):
|
66 |
+
- checkpoint = torch.load(self.opt['bot_vae_path'])
|
67 |
+
+ checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device)
|
68 |
+
self.bot_decoder_res.load_state_dict(
|
69 |
+
checkpoint['bot_decoder_res'], strict=True)
|
70 |
+
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
|
71 |
+
@@ -153,7 +153,7 @@ class BaseSampleModel():
|
72 |
+
|
73 |
+
def load_pretrained_segm_token(self):
|
74 |
+
# load pretrained vqgan for segmentation mask
|
75 |
+
- segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
|
76 |
+
+ segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device)
|
77 |
+
self.segm_encoder.load_state_dict(
|
78 |
+
segm_token_checkpoint['encoder'], strict=True)
|
79 |
+
self.segm_quantizer.load_state_dict(
|
80 |
+
@@ -166,7 +166,7 @@ class BaseSampleModel():
|
81 |
+
self.segm_quant_conv.eval()
|
82 |
+
|
83 |
+
def load_index_pred_network(self):
|
84 |
+
- checkpoint = torch.load(self.opt['pretrained_index_network'])
|
85 |
+
+ checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device)
|
86 |
+
self.index_pred_guidance_encoder.load_state_dict(
|
87 |
+
checkpoint['guidance_encoder'], strict=True)
|
88 |
+
self.index_pred_decoder.load_state_dict(
|
89 |
+
@@ -176,7 +176,7 @@ class BaseSampleModel():
|
90 |
+
self.index_pred_decoder.eval()
|
91 |
+
|
92 |
+
def load_sampler_pretrained_network(self):
|
93 |
+
- checkpoint = torch.load(self.opt['pretrained_sampler'])
|
94 |
+
+ checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device)
|
95 |
+
self.sampler_fn.load_state_dict(checkpoint, strict=True)
|
96 |
+
self.sampler_fn.eval()
|
97 |
+
|
98 |
+
@@ -397,7 +397,7 @@ class SampleFromPoseModel(BaseSampleModel):
|
99 |
+
[185, 210, 205], [130, 165, 180], [225, 141, 151]]
|
100 |
+
|
101 |
+
def load_shape_generation_models(self):
|
102 |
+
- checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
|
103 |
+
+ checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device)
|
104 |
+
|
105 |
+
self.shape_attr_embedder.load_state_dict(
|
106 |
+
checkpoint['embedder'], strict=True)
|
107 |
+
diff --git a/models/transformer_model.py b/models/transformer_model.py
|
108 |
+
index 7db0f3e..4523d17 100644
|
109 |
+
--- a/models/transformer_model.py
|
110 |
+
+++ b/models/transformer_model.py
|
111 |
+
@@ -21,7 +21,7 @@ class TransformerTextureAwareModel():
|
112 |
+
|
113 |
+
def __init__(self, opt):
|
114 |
+
self.opt = opt
|
115 |
+
- self.device = torch.device('cuda')
|
116 |
+
+ self.device = torch.device(opt['device'])
|
117 |
+
self.is_train = opt['is_train']
|
118 |
+
|
119 |
+
# VQVAE for image
|
120 |
+
@@ -317,10 +317,10 @@ class TransformerTextureAwareModel():
|
121 |
+
def sample_fn(self, temp=1.0, sample_steps=None):
|
122 |
+
self._denoise_fn.eval()
|
123 |
+
|
124 |
+
- b, device = self.image.size(0), 'cuda'
|
125 |
+
+ b = self.image.size(0)
|
126 |
+
x_t = torch.ones(
|
127 |
+
- (b, np.prod(self.shape)), device=device).long() * self.mask_id
|
128 |
+
- unmasked = torch.zeros_like(x_t, device=device).bool()
|
129 |
+
+ (b, np.prod(self.shape)), device=self.device).long() * self.mask_id
|
130 |
+
+ unmasked = torch.zeros_like(x_t, device=self.device).bool()
|
131 |
+
sample_steps = list(range(1, sample_steps + 1))
|
132 |
+
|
133 |
+
texture_mask_flatten = self.texture_tokens.view(-1)
|
134 |
+
@@ -336,11 +336,11 @@ class TransformerTextureAwareModel():
|
135 |
+
|
136 |
+
for t in reversed(sample_steps):
|
137 |
+
print(f'Sample timestep {t:4d}', end='\r')
|
138 |
+
- t = torch.full((b, ), t, device=device, dtype=torch.long)
|
139 |
+
+ t = torch.full((b, ), t, device=self.device, dtype=torch.long)
|
140 |
+
|
141 |
+
# where to unmask
|
142 |
+
changes = torch.rand(
|
143 |
+
- x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
|
144 |
+
+ x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
|
145 |
+
# don't unmask somewhere already unmasked
|
146 |
+
changes = torch.bitwise_xor(changes,
|
147 |
+
torch.bitwise_and(changes, unmasked))
|
148 |
+
diff --git a/models/vqgan_model.py b/models/vqgan_model.py
|
149 |
+
index 13a2e70..9c840f1 100644
|
150 |
+
--- a/models/vqgan_model.py
|
151 |
+
+++ b/models/vqgan_model.py
|
152 |
+
@@ -20,7 +20,7 @@ class VQModel():
|
153 |
+
def __init__(self, opt):
|
154 |
+
super().__init__()
|
155 |
+
self.opt = opt
|
156 |
+
- self.device = torch.device('cuda')
|
157 |
+
+ self.device = torch.device(opt['device'])
|
158 |
+
self.encoder = Encoder(
|
159 |
+
ch=opt['ch'],
|
160 |
+
num_res_blocks=opt['num_res_blocks'],
|
161 |
+
@@ -390,7 +390,7 @@ class VQImageSegmTextureModel(VQImageModel):
|
162 |
+
|
163 |
+
def __init__(self, opt):
|
164 |
+
self.opt = opt
|
165 |
+
- self.device = torch.device('cuda')
|
166 |
+
+ self.device = torch.device(opt['device'])
|
167 |
+
self.encoder = Encoder(
|
168 |
+
ch=opt['ch'],
|
169 |
+
num_res_blocks=opt['num_res_blocks'],
|
pose_images/000.png
ADDED
Git LFS Details
|
pose_images/001.png
ADDED
Git LFS Details
|
pose_images/002.png
ADDED
Git LFS Details
|
pose_images/003.png
ADDED
Git LFS Details
|
pose_images/004.png
ADDED
Git LFS Details
|
pose_images/005.png
ADDED
Git LFS Details
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
einops==0.4.1
|
2 |
+
lpips==0.1.4
|
3 |
+
mmcv-full==1.5.2
|
4 |
+
mmsegmentation==0.24.1
|
5 |
+
numpy==1.22.3
|
6 |
+
Pillow==9.1.1
|
7 |
+
sentence-transformers==2.2.0
|
8 |
+
tokenizers==0.12.1
|
9 |
+
torch==1.11.0
|
10 |
+
torchvision==0.12.0
|
11 |
+
transformers==4.19.2
|