Upload 17 files
- .gitattributes +2 -9
- .gitignore +1 -0
- .gitmodules +3 -0
- .pre-commit-config.yaml +36 -0
- .style.yapf +5 -0
- README.md +6 -5
- app.py +140 -0
- model.py +145 -0
- patch +169 -0
- pose_images/000.png +0 -0
- pose_images/001.png +0 -0
- pose_images/002.png +0 -0
- pose_images/003.png +0 -0
- pose_images/004.png +0 -0
- pose_images/005.png +0 -0
- requirements.txt +12 -0
- style.css +16 -0
.gitattributes
CHANGED
@@ -1,35 +1,28 @@
+*.png filter=lfs diff=lfs merge=lfs -text
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
-*.
+*.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
+pretrained_models
.gitmodules
ADDED
@@ -0,0 +1,3 @@
+[submodule "Text2Human"]
+	path = Text2Human
+	url = https://github.com/yumingj/Text2Human
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,36 @@
exclude: ^(Text2Human|patch)
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.2.0
    hooks:
      - id: check-executables-have-shebangs
      - id: check-json
      - id: check-merge-conflict
      - id: check-shebang-scripts-are-executable
      - id: check-toml
      - id: check-yaml
      - id: double-quote-string-fixer
      - id: end-of-file-fixer
      - id: mixed-line-ending
        args: ['--fix=lf']
      - id: requirements-txt-fixer
      - id: trailing-whitespace
  - repo: https://github.com/myint/docformatter
    rev: v1.4
    hooks:
      - id: docformatter
        args: ['--in-place']
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.991
    hooks:
      - id: mypy
        args: ['--ignore-missing-imports']
  - repo: https://github.com/google/yapf
    rev: v0.32.0
    hooks:
      - id: yapf
        args: ['--parallel', '--in-place']
.style.yapf
ADDED
@@ -0,0 +1,5 @@
[style]
based_on_style = pep8
blank_line_before_nested_class_or_def = false
spaces_before_comment = 2
split_before_logical_operator = true
README.md
CHANGED
@@ -1,12 +1,13 @@
 ---
-title:
-emoji:
+title: Text2Human
+emoji: 🏃
 colorFrom: purple
-colorTo:
+colorTo: gray
 sdk: gradio
-sdk_version: 3.
+sdk_version: 3.36.1
 app_file: app.py
 pinned: false
+suggested_hardware: t4-small
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py
ADDED
@@ -0,0 +1,140 @@
#!/usr/bin/env python

from __future__ import annotations

import os
import pathlib
import random
import shlex
import subprocess

import gradio as gr
import numpy as np

if os.getenv('SYSTEM') == 'spaces':
    import mim

    mim.uninstall('mmcv-full', confirm_yes=True)
    mim.install('mmcv-full==1.5.2', is_yes=True)

    with open('patch') as f:
        subprocess.run(shlex.split('patch -p1'), cwd='Text2Human', stdin=f)

from model import Model

DESCRIPTION = '''# [Text2Human](https://github.com/yumingj/Text2Human)

You can modify sample steps and seeds. By varying seeds, you can sample different human images under the same pose, shape description, and texture description. The larger the sample steps, the better quality of the generated images. (The default value of sample steps is 256 in the original repo.)

Label image generation step can be skipped. However, in that case, the input label image must be 512x256 in size and must contain only the specified colors.
'''

MAX_SEED = np.iinfo(np.int32).max


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


model = Model()

with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(label='Input Pose Image',
                                       type='pil',
                                       elem_id='input-image')
                pose_data = gr.State()
            with gr.Row():
                paths = sorted(pathlib.Path('pose_images').glob('*.png'))
                gr.Examples(examples=[[path.as_posix()] for path in paths],
                            inputs=input_image)

            with gr.Row():
                shape_text = gr.Textbox(
                    label='Shape Description',
                    placeholder=
                    '''<gender>, <sleeve length>, <length of lower clothing>, <outer clothing type>, <other accessories1>, ...
Note: The outer clothing type and accessories can be omitted.''')
            with gr.Row():
                gr.Examples(
                    examples=[['man, sleeveless T-shirt, long pants'],
                              ['woman, short-sleeve T-shirt, short jeans']],
                    inputs=shape_text)
            with gr.Row():
                generate_label_button = gr.Button('Generate Label Image')

        with gr.Column():
            with gr.Row():
                label_image = gr.Image(label='Label Image',
                                       type='numpy',
                                       elem_id='label-image')

            with gr.Row():
                texture_text = gr.Textbox(
                    label='Texture Description',
                    placeholder=
                    '''<upper clothing texture>, <lower clothing texture>, <outer clothing texture>
Note: Currently, only 5 types of textures are supported, i.e., pure color, stripe/spline, plaid/lattice, floral, denim.'''
                )
            with gr.Row():
                gr.Examples(examples=[
                    ['pure color, denim'],
                    ['floral, stripe'],
                ],
                            inputs=texture_text)
            with gr.Row():
                sample_steps = gr.Slider(label='Sample Steps',
                                         minimum=10,
                                         maximum=300,
                                         step=1,
                                         value=256)
            with gr.Row():
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=MAX_SEED,
                                 step=1,
                                 value=0)
                randomize_seed = gr.Checkbox(label='Randomize seed',
                                             value=True)
            with gr.Row():
                generate_human_button = gr.Button('Generate Human')

        with gr.Column():
            with gr.Row():
                result = gr.Image(label='Result',
                                  type='numpy',
                                  elem_id='result-image')

    input_image.change(
        fn=model.process_pose_image,
        inputs=input_image,
        outputs=pose_data,
    )
    generate_label_button.click(
        fn=model.generate_label_image,
        inputs=[
            pose_data,
            shape_text,
        ],
        outputs=label_image,
    )
    generate_human_button.click(fn=randomize_seed_fn,
                                inputs=[seed, randomize_seed],
                                outputs=seed,
                                queue=False).then(
                                    fn=model.generate_human,
                                    inputs=[
                                        label_image,
                                        texture_text,
                                        sample_steps,
                                        seed,
                                    ],
                                    outputs=result,
                                )
demo.queue(max_size=10).launch()
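The 'Generate Human' handler above chains two events: the seed is resolved first (randomized or kept as-is, without queueing), and only then does generation run with that resolved value, so the UI always shows the seed that was actually used. A minimal sketch of the same pattern in isolation, assuming Gradio 3.x where .click(...) returns an event that supports .then(...); fake_generate is a hypothetical stand-in for model.generate_human, not part of this Space:

import random

import gradio as gr
import numpy as np

MAX_SEED = np.iinfo(np.int32).max


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    # Resolve the seed before generation so the slider reflects the value actually used.
    return random.randint(0, MAX_SEED) if randomize_seed else seed


def fake_generate(seed: int) -> str:
    # Hypothetical stand-in for the real generation call.
    return f'generated with seed {seed}'


with gr.Blocks() as demo:
    seed = gr.Slider(label='Seed', minimum=0, maximum=MAX_SEED, step=1, value=0)
    randomize_seed = gr.Checkbox(label='Randomize seed', value=True)
    out = gr.Textbox(label='Result')
    button = gr.Button('Generate')
    # First update the seed (unqueued), then run generation with the updated seed.
    button.click(fn=randomize_seed_fn,
                 inputs=[seed, randomize_seed],
                 outputs=seed,
                 queue=False).then(fn=fake_generate, inputs=seed, outputs=out)

demo.queue().launch()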
model.py
ADDED
@@ -0,0 +1,145 @@
from __future__ import annotations

import pathlib
import sys
import zipfile

import huggingface_hub
import numpy as np
import PIL.Image
import torch

sys.path.insert(0, 'Text2Human')

from models.sample_model import SampleFromPoseModel
from utils.language_utils import (generate_shape_attributes,
                                  generate_texture_attributes)
from utils.options import dict_to_nonedict, parse
from utils.util import set_random_seed

COLOR_LIST = [
    (0, 0, 0),
    (255, 250, 250),
    (220, 220, 220),
    (250, 235, 215),
    (255, 250, 205),
    (211, 211, 211),
    (70, 130, 180),
    (127, 255, 212),
    (0, 100, 0),
    (50, 205, 50),
    (255, 255, 0),
    (245, 222, 179),
    (255, 140, 0),
    (255, 0, 0),
    (16, 78, 139),
    (144, 238, 144),
    (50, 205, 174),
    (50, 155, 250),
    (160, 140, 88),
    (213, 140, 88),
    (90, 140, 90),
    (185, 210, 205),
    (130, 165, 180),
    (225, 141, 151),
]


class Model:
    def __init__(self):
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.config = self._load_config()
        self.config['device'] = device.type
        self._download_models()
        self.model = SampleFromPoseModel(self.config)
        self.model.batch_size = 1

    def _load_config(self) -> dict:
        path = 'Text2Human/configs/sample_from_pose.yml'
        config = parse(path, is_train=False)
        config = dict_to_nonedict(config)
        return config

    def _download_models(self) -> None:
        model_dir = pathlib.Path('pretrained_models')
        if model_dir.exists():
            return
        path = huggingface_hub.hf_hub_download('yumingj/Text2Human_SSHQ',
                                               'pretrained_models.zip')
        model_dir.mkdir()
        with zipfile.ZipFile(path) as f:
            f.extractall(model_dir)

    @staticmethod
    def preprocess_pose_image(image: PIL.Image.Image) -> torch.Tensor:
        image = np.array(
            image.resize(
                size=(256, 512),
                resample=PIL.Image.Resampling.LANCZOS))[:, :, 2:].transpose(
                    2, 0, 1).astype(np.float32)
        image = image / 12. - 1
        data = torch.from_numpy(image).unsqueeze(1)
        return data

    @staticmethod
    def process_mask(mask: np.ndarray) -> np.ndarray:
        if mask.shape != (512, 256, 3):
            return None
        seg_map = np.full(mask.shape[:-1], -1)
        for index, color in enumerate(COLOR_LIST):
            seg_map[np.sum(mask == color, axis=2) == 3] = index
        if not (seg_map != -1).all():
            return None
        return seg_map

    @staticmethod
    def postprocess(result: torch.Tensor) -> np.ndarray:
        result = result.permute(0, 2, 3, 1)
        result = result.detach().cpu().numpy()
        result = result * 255
        result = np.asarray(result[0, :, :, :], dtype=np.uint8)
        return result

    def process_pose_image(self, pose_image: PIL.Image.Image) -> torch.Tensor:
        if pose_image is None:
            return
        data = self.preprocess_pose_image(pose_image)
        self.model.feed_pose_data(data)
        return data

    def generate_label_image(self, pose_data: torch.Tensor,
                             shape_text: str) -> np.ndarray:
        if pose_data is None:
            return
        self.model.feed_pose_data(pose_data)
        shape_attributes = generate_shape_attributes(shape_text)
        shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
        self.model.feed_shape_attributes(shape_attributes)
        self.model.generate_parsing_map()
        self.model.generate_quantized_segm()
        colored_segm = self.model.palette_result(self.model.segm[0].cpu())
        return colored_segm

    def generate_human(self, label_image: np.ndarray, texture_text: str,
                       sample_steps: int, seed: int) -> np.ndarray:
        if label_image is None:
            return
        mask = label_image.copy()
        seg_map = self.process_mask(mask)
        if seg_map is None:
            return
        self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
            0).to(self.model.device)
        self.model.generate_quantized_segm()

        set_random_seed(seed)

        texture_attributes = generate_texture_attributes(texture_text)
        texture_attributes = torch.LongTensor(texture_attributes)
        self.model.feed_texture_attributes(texture_attributes)
        self.model.generate_texture_map()

        self.model.sample_steps = sample_steps
        out = self.model.sample_and_refine()
        res = self.postprocess(out)
        return res
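The Gradio callbacks in app.py call Model in a fixed order: process_pose_image, then generate_label_image, then generate_human. A minimal usage sketch of that order, assuming the Text2Human submodule is checked out, the patch has been applied, and the pretrained weights can be downloaded; the output path result.png is only illustrative:

import PIL.Image

from model import Model

model = Model()

# 1. Feed a pose image; the returned tensor is what app.py keeps in gr.State.
pose_image = PIL.Image.open('pose_images/000.png').convert('RGB')  # mirror Gradio's type='pil' RGB input
pose_data = model.process_pose_image(pose_image)

# 2. Generate a 512x256 label (parsing) image from a shape description.
label_image = model.generate_label_image(pose_data,
                                         'man, sleeveless T-shirt, long pants')

# 3. Generate the final human image from the label image and a texture description.
result = model.generate_human(label_image, 'pure color, denim',
                              sample_steps=10, seed=0)
PIL.Image.fromarray(result).save('result.png')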
patch
ADDED
@@ -0,0 +1,169 @@
diff --git a/models/hierarchy_inference_model.py b/models/hierarchy_inference_model.py
index 3116307..5de661d 100644
--- a/models/hierarchy_inference_model.py
+++ b/models/hierarchy_inference_model.py
@@ -21,7 +21,7 @@ class VQGANTextureAwareSpatialHierarchyInferenceModel():

     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']

         self.top_encoder = Encoder(
diff --git a/models/hierarchy_vqgan_model.py b/models/hierarchy_vqgan_model.py
index 4b0d657..0bf4712 100644
--- a/models/hierarchy_vqgan_model.py
+++ b/models/hierarchy_vqgan_model.py
@@ -20,7 +20,7 @@ class HierarchyVQSpatialTextureAwareModel():

     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.top_encoder = Encoder(
             ch=opt['top_ch'],
             num_res_blocks=opt['top_num_res_blocks'],
diff --git a/models/parsing_gen_model.py b/models/parsing_gen_model.py
index 9440345..15a1ecb 100644
--- a/models/parsing_gen_model.py
+++ b/models/parsing_gen_model.py
@@ -22,7 +22,7 @@ class ParsingGenModel():

     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']

         self.attr_embedder = ShapeAttrEmbedding(
diff --git a/models/sample_model.py b/models/sample_model.py
index 4c60e3f..5265cd0 100644
--- a/models/sample_model.py
+++ b/models/sample_model.py
@@ -23,7 +23,7 @@ class BaseSampleModel():

     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])

         # hierarchical VQVAE
         self.decoder = Decoder(
@@ -123,7 +123,7 @@ class BaseSampleModel():

     def load_top_pretrain_models(self):
         # load pretrained vqgan
-        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
+        top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device)

         self.decoder.load_state_dict(
             top_vae_checkpoint['decoder'], strict=True)
@@ -137,7 +137,7 @@ class BaseSampleModel():
         self.top_post_quant_conv.eval()

     def load_bot_pretrain_network(self):
-        checkpoint = torch.load(self.opt['bot_vae_path'])
+        checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device)
         self.bot_decoder_res.load_state_dict(
             checkpoint['bot_decoder_res'], strict=True)
         self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
@@ -153,7 +153,7 @@ class BaseSampleModel():

     def load_pretrained_segm_token(self):
         # load pretrained vqgan for segmentation mask
-        segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
+        segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device)
         self.segm_encoder.load_state_dict(
             segm_token_checkpoint['encoder'], strict=True)
         self.segm_quantizer.load_state_dict(
@@ -166,7 +166,7 @@ class BaseSampleModel():
         self.segm_quant_conv.eval()

     def load_index_pred_network(self):
-        checkpoint = torch.load(self.opt['pretrained_index_network'])
+        checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device)
         self.index_pred_guidance_encoder.load_state_dict(
             checkpoint['guidance_encoder'], strict=True)
         self.index_pred_decoder.load_state_dict(
@@ -176,7 +176,7 @@ class BaseSampleModel():
         self.index_pred_decoder.eval()

     def load_sampler_pretrained_network(self):
-        checkpoint = torch.load(self.opt['pretrained_sampler'])
+        checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device)
         self.sampler_fn.load_state_dict(checkpoint, strict=True)
         self.sampler_fn.eval()

@@ -397,7 +397,7 @@ class SampleFromPoseModel(BaseSampleModel):
                     [185, 210, 205], [130, 165, 180], [225, 141, 151]]

     def load_shape_generation_models(self):
-        checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
+        checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device)

         self.shape_attr_embedder.load_state_dict(
             checkpoint['embedder'], strict=True)
diff --git a/models/transformer_model.py b/models/transformer_model.py
index 7db0f3e..4523d17 100644
--- a/models/transformer_model.py
+++ b/models/transformer_model.py
@@ -21,7 +21,7 @@ class TransformerTextureAwareModel():

     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']

         # VQVAE for image
@@ -317,10 +317,10 @@ class TransformerTextureAwareModel():
     def sample_fn(self, temp=1.0, sample_steps=None):
         self._denoise_fn.eval()

-        b, device = self.image.size(0), 'cuda'
+        b = self.image.size(0)
         x_t = torch.ones(
-            (b, np.prod(self.shape)), device=device).long() * self.mask_id
-        unmasked = torch.zeros_like(x_t, device=device).bool()
+            (b, np.prod(self.shape)), device=self.device).long() * self.mask_id
+        unmasked = torch.zeros_like(x_t, device=self.device).bool()
         sample_steps = list(range(1, sample_steps + 1))

         texture_mask_flatten = self.texture_tokens.view(-1)
@@ -336,11 +336,11 @@ class TransformerTextureAwareModel():

         for t in reversed(sample_steps):
             print(f'Sample timestep {t:4d}', end='\r')
-            t = torch.full((b, ), t, device=device, dtype=torch.long)
+            t = torch.full((b, ), t, device=self.device, dtype=torch.long)

             # where to unmask
             changes = torch.rand(
-                x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
+                x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
             # don't unmask somewhere already unmasked
             changes = torch.bitwise_xor(changes,
                                         torch.bitwise_and(changes, unmasked))
diff --git a/models/vqgan_model.py b/models/vqgan_model.py
index 13a2e70..9c840f1 100644
--- a/models/vqgan_model.py
+++ b/models/vqgan_model.py
@@ -20,7 +20,7 @@ class VQModel():
     def __init__(self, opt):
         super().__init__()
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.encoder = Encoder(
             ch=opt['ch'],
             num_res_blocks=opt['num_res_blocks'],
@@ -390,7 +390,7 @@ class VQImageSegmTextureModel(VQImageModel):

     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.encoder = Encoder(
             ch=opt['ch'],
             num_res_blocks=opt['num_res_blocks'],
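The patch replaces every hard-coded torch.device('cuda') in the Text2Human submodule with the device passed in through the config, and adds map_location to each torch.load call so the pretrained checkpoints also load on CPU-only machines. A minimal sketch of that loading pattern in isolation; ckpt.pth is a placeholder path, not a file in this repo:

import torch

# Pick CUDA when available, otherwise fall back to CPU (same choice as model.py).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Without map_location, a checkpoint saved on GPU tries to deserialize onto CUDA
# and fails on a CPU-only machine; map_location retargets the tensors first.
checkpoint = torch.load('ckpt.pth', map_location=device)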
pose_images/000.png
ADDED
pose_images/001.png
ADDED
pose_images/002.png
ADDED
pose_images/003.png
ADDED
pose_images/004.png
ADDED
pose_images/005.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,12 @@
einops==0.6.1
lpips==0.1.4
mmcv-full==1.5.2
mmsegmentation==0.24.1
numpy==1.23.5
openmim==0.1.5
Pillow==9.5.0
sentence-transformers==2.2.2
tokenizers==0.13.3
torch==1.11.0
torchvision==0.12.0
transformers==4.30.2
style.css
ADDED
@@ -0,0 +1,16 @@
h1 {
  text-align: center;
}
#input-image {
  max-height: 300px;
}
#label-image {
  height: 300px;
}
#result-image {
  height: 300px;
}
img#visitor-badge {
  display: block;
  margin: auto;
}