elias3446 committed
Commit 30a5522 • 1 Parent(s): 6297d7c

Upload 17 files

.gitattributes CHANGED
@@ -1,35 +1,28 @@
+ *.png filter=lfs diff=lfs merge=lfs -text
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
  *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ pretrained_models
.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "Text2Human"]
+ 	path = Text2Human
+ 	url = https://github.com/yumingj/Text2Human
.pre-commit-config.yaml ADDED
@@ -0,0 +1,36 @@
+ exclude: ^(Text2Human|patch)
+ repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+   rev: v4.2.0
+   hooks:
+   - id: check-executables-have-shebangs
+   - id: check-json
+   - id: check-merge-conflict
+   - id: check-shebang-scripts-are-executable
+   - id: check-toml
+   - id: check-yaml
+   - id: double-quote-string-fixer
+   - id: end-of-file-fixer
+   - id: mixed-line-ending
+     args: ['--fix=lf']
+   - id: requirements-txt-fixer
+   - id: trailing-whitespace
+ - repo: https://github.com/myint/docformatter
+   rev: v1.4
+   hooks:
+   - id: docformatter
+     args: ['--in-place']
+ - repo: https://github.com/pycqa/isort
+   rev: 5.12.0
+   hooks:
+   - id: isort
+ - repo: https://github.com/pre-commit/mirrors-mypy
+   rev: v0.991
+   hooks:
+   - id: mypy
+     args: ['--ignore-missing-imports']
+ - repo: https://github.com/google/yapf
+   rev: v0.32.0
+   hooks:
+   - id: yapf
+     args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
+ [style]
+ based_on_style = pep8
+ blank_line_before_nested_class_or_def = false
+ spaces_before_comment = 2
+ split_before_logical_operator = true
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: Outfits
- emoji: 📈
+ title: Text2Human
+ emoji: 🏃
  colorFrom: purple
- colorTo: blue
+ colorTo: gray
  sdk: gradio
- sdk_version: 3.41.2
+ sdk_version: 3.36.1
  app_file: app.py
  pinned: false
+ suggested_hardware: t4-small
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,140 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import os
+ import pathlib
+ import random
+ import shlex
+ import subprocess
+
+ import gradio as gr
+ import numpy as np
+
+ if os.getenv('SYSTEM') == 'spaces':
+     import mim
+
+     mim.uninstall('mmcv-full', confirm_yes=True)
+     mim.install('mmcv-full==1.5.2', is_yes=True)
+
+     with open('patch') as f:
+         subprocess.run(shlex.split('patch -p1'), cwd='Text2Human', stdin=f)
+
+ from model import Model
+
+ DESCRIPTION = '''# [Text2Human](https://github.com/yumingj/Text2Human)
+
+ You can modify sample steps and seeds. By varying seeds, you can sample different human images under the same pose, shape description, and texture description. The larger the sample steps, the better quality of the generated images. (The default value of sample steps is 256 in the original repo.)
+
+ Label image generation step can be skipped. However, in that case, the input label image must be 512x256 in size and must contain only the specified colors.
+ '''
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+
+ model = Model()
+
+ with gr.Blocks(css='style.css') as demo:
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 input_image = gr.Image(label='Input Pose Image',
+                                        type='pil',
+                                        elem_id='input-image')
+                 pose_data = gr.State()
+             with gr.Row():
+                 paths = sorted(pathlib.Path('pose_images').glob('*.png'))
+                 gr.Examples(examples=[[path.as_posix()] for path in paths],
+                             inputs=input_image)
+
+             with gr.Row():
+                 shape_text = gr.Textbox(
+                     label='Shape Description',
+                     placeholder=
+                     '''<gender>, <sleeve length>, <length of lower clothing>, <outer clothing type>, <other accessories1>, ...
+ Note: The outer clothing type and accessories can be omitted.''')
+             with gr.Row():
+                 gr.Examples(
+                     examples=[['man, sleeveless T-shirt, long pants'],
+                               ['woman, short-sleeve T-shirt, short jeans']],
+                     inputs=shape_text)
+             with gr.Row():
+                 generate_label_button = gr.Button('Generate Label Image')
+
+         with gr.Column():
+             with gr.Row():
+                 label_image = gr.Image(label='Label Image',
+                                        type='numpy',
+                                        elem_id='label-image')
+
+             with gr.Row():
+                 texture_text = gr.Textbox(
+                     label='Texture Description',
+                     placeholder=
+                     '''<upper clothing texture>, <lower clothing texture>, <outer clothing texture>
+ Note: Currently, only 5 types of textures are supported, i.e., pure color, stripe/spline, plaid/lattice, floral, denim.'''
+                 )
+             with gr.Row():
+                 gr.Examples(examples=[
+                     ['pure color, denim'],
+                     ['floral, stripe'],
+                 ],
+                             inputs=texture_text)
+             with gr.Row():
+                 sample_steps = gr.Slider(label='Sample Steps',
+                                          minimum=10,
+                                          maximum=300,
+                                          step=1,
+                                          value=256)
+             with gr.Row():
+                 seed = gr.Slider(label='Seed',
+                                  minimum=0,
+                                  maximum=MAX_SEED,
+                                  step=1,
+                                  value=0)
+                 randomize_seed = gr.Checkbox(label='Randomize seed',
+                                              value=True)
+             with gr.Row():
+                 generate_human_button = gr.Button('Generate Human')
+
+         with gr.Column():
+             with gr.Row():
+                 result = gr.Image(label='Result',
+                                   type='numpy',
+                                   elem_id='result-image')
+
+     input_image.change(
+         fn=model.process_pose_image,
+         inputs=input_image,
+         outputs=pose_data,
+     )
+     generate_label_button.click(
+         fn=model.generate_label_image,
+         inputs=[
+             pose_data,
+             shape_text,
+         ],
+         outputs=label_image,
+     )
+     generate_human_button.click(fn=randomize_seed_fn,
+                                 inputs=[seed, randomize_seed],
+                                 outputs=seed,
+                                 queue=False).then(
+                                     fn=model.generate_human,
+                                     inputs=[
+                                         label_image,
+                                         texture_text,
+                                         sample_steps,
+                                         seed,
+                                     ],
+                                     outputs=result,
+     )
+     demo.queue(max_size=10).launch()
model.py ADDED
@@ -0,0 +1,145 @@
+ from __future__ import annotations
+
+ import pathlib
+ import sys
+ import zipfile
+
+ import huggingface_hub
+ import numpy as np
+ import PIL.Image
+ import torch
+
+ sys.path.insert(0, 'Text2Human')
+
+ from models.sample_model import SampleFromPoseModel
+ from utils.language_utils import (generate_shape_attributes,
+                                   generate_texture_attributes)
+ from utils.options import dict_to_nonedict, parse
+ from utils.util import set_random_seed
+
+ COLOR_LIST = [
+     (0, 0, 0),
+     (255, 250, 250),
+     (220, 220, 220),
+     (250, 235, 215),
+     (255, 250, 205),
+     (211, 211, 211),
+     (70, 130, 180),
+     (127, 255, 212),
+     (0, 100, 0),
+     (50, 205, 50),
+     (255, 255, 0),
+     (245, 222, 179),
+     (255, 140, 0),
+     (255, 0, 0),
+     (16, 78, 139),
+     (144, 238, 144),
+     (50, 205, 174),
+     (50, 155, 250),
+     (160, 140, 88),
+     (213, 140, 88),
+     (90, 140, 90),
+     (185, 210, 205),
+     (130, 165, 180),
+     (225, 141, 151),
+ ]
+
+
+ class Model:
+     def __init__(self):
+         device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+         self.config = self._load_config()
+         self.config['device'] = device.type
+         self._download_models()
+         self.model = SampleFromPoseModel(self.config)
+         self.model.batch_size = 1
+
+     def _load_config(self) -> dict:
+         path = 'Text2Human/configs/sample_from_pose.yml'
+         config = parse(path, is_train=False)
+         config = dict_to_nonedict(config)
+         return config
+
+     def _download_models(self) -> None:
+         model_dir = pathlib.Path('pretrained_models')
+         if model_dir.exists():
+             return
+         path = huggingface_hub.hf_hub_download('yumingj/Text2Human_SSHQ',
+                                                'pretrained_models.zip')
+         model_dir.mkdir()
+         with zipfile.ZipFile(path) as f:
+             f.extractall(model_dir)
+
+     @staticmethod
+     def preprocess_pose_image(image: PIL.Image.Image) -> torch.Tensor:
+         image = np.array(
+             image.resize(
+                 size=(256, 512),
+                 resample=PIL.Image.Resampling.LANCZOS))[:, :, 2:].transpose(
+                     2, 0, 1).astype(np.float32)
+         image = image / 12. - 1
+         data = torch.from_numpy(image).unsqueeze(1)
+         return data
+
+     @staticmethod
+     def process_mask(mask: np.ndarray) -> np.ndarray:
+         if mask.shape != (512, 256, 3):
+             return None
+         seg_map = np.full(mask.shape[:-1], -1)
+         for index, color in enumerate(COLOR_LIST):
+             seg_map[np.sum(mask == color, axis=2) == 3] = index
+         if not (seg_map != -1).all():
+             return None
+         return seg_map
+
+     @staticmethod
+     def postprocess(result: torch.Tensor) -> np.ndarray:
+         result = result.permute(0, 2, 3, 1)
+         result = result.detach().cpu().numpy()
+         result = result * 255
+         result = np.asarray(result[0, :, :, :], dtype=np.uint8)
+         return result
+
+     def process_pose_image(self, pose_image: PIL.Image.Image) -> torch.Tensor:
+         if pose_image is None:
+             return
+         data = self.preprocess_pose_image(pose_image)
+         self.model.feed_pose_data(data)
+         return data
+
+     def generate_label_image(self, pose_data: torch.Tensor,
+                              shape_text: str) -> np.ndarray:
+         if pose_data is None:
+             return
+         self.model.feed_pose_data(pose_data)
+         shape_attributes = generate_shape_attributes(shape_text)
+         shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
+         self.model.feed_shape_attributes(shape_attributes)
+         self.model.generate_parsing_map()
+         self.model.generate_quantized_segm()
+         colored_segm = self.model.palette_result(self.model.segm[0].cpu())
+         return colored_segm
+
+     def generate_human(self, label_image: np.ndarray, texture_text: str,
+                        sample_steps: int, seed: int) -> np.ndarray:
+         if label_image is None:
+             return
+         mask = label_image.copy()
+         seg_map = self.process_mask(mask)
+         if seg_map is None:
+             return
+         self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
+             0).to(self.model.device)
+         self.model.generate_quantized_segm()
+
+         set_random_seed(seed)
+
+         texture_attributes = generate_texture_attributes(texture_text)
+         texture_attributes = torch.LongTensor(texture_attributes)
+         self.model.feed_texture_attributes(texture_attributes)
+         self.model.generate_texture_map()
+
+         self.model.sample_steps = sample_steps
+         out = self.model.sample_and_refine()
+         res = self.postprocess(out)
+         return res
patch ADDED
@@ -0,0 +1,169 @@
+ diff --git a/models/hierarchy_inference_model.py b/models/hierarchy_inference_model.py
+ index 3116307..5de661d 100644
+ --- a/models/hierarchy_inference_model.py
+ +++ b/models/hierarchy_inference_model.py
+ @@ -21,7 +21,7 @@ class VQGANTextureAwareSpatialHierarchyInferenceModel():
+
+      def __init__(self, opt):
+          self.opt = opt
+ -        self.device = torch.device('cuda')
+ +        self.device = torch.device(opt['device'])
+          self.is_train = opt['is_train']
+
+          self.top_encoder = Encoder(
+ diff --git a/models/hierarchy_vqgan_model.py b/models/hierarchy_vqgan_model.py
+ index 4b0d657..0bf4712 100644
+ --- a/models/hierarchy_vqgan_model.py
+ +++ b/models/hierarchy_vqgan_model.py
+ @@ -20,7 +20,7 @@ class HierarchyVQSpatialTextureAwareModel():
+
+      def __init__(self, opt):
+          self.opt = opt
+ -        self.device = torch.device('cuda')
+ +        self.device = torch.device(opt['device'])
+          self.top_encoder = Encoder(
+              ch=opt['top_ch'],
+              num_res_blocks=opt['top_num_res_blocks'],
+ diff --git a/models/parsing_gen_model.py b/models/parsing_gen_model.py
+ index 9440345..15a1ecb 100644
+ --- a/models/parsing_gen_model.py
+ +++ b/models/parsing_gen_model.py
+ @@ -22,7 +22,7 @@ class ParsingGenModel():
+
+      def __init__(self, opt):
+          self.opt = opt
+ -        self.device = torch.device('cuda')
+ +        self.device = torch.device(opt['device'])
+          self.is_train = opt['is_train']
+
+          self.attr_embedder = ShapeAttrEmbedding(
+ diff --git a/models/sample_model.py b/models/sample_model.py
+ index 4c60e3f..5265cd0 100644
+ --- a/models/sample_model.py
+ +++ b/models/sample_model.py
+ @@ -23,7 +23,7 @@ class BaseSampleModel():
+
+      def __init__(self, opt):
+          self.opt = opt
+ -        self.device = torch.device('cuda')
+ +        self.device = torch.device(opt['device'])
+
+          # hierarchical VQVAE
+          self.decoder = Decoder(
+ @@ -123,7 +123,7 @@ class BaseSampleModel():
+
+      def load_top_pretrain_models(self):
+          # load pretrained vqgan
+ -        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
+ +        top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device)
+
+          self.decoder.load_state_dict(
+              top_vae_checkpoint['decoder'], strict=True)
+ @@ -137,7 +137,7 @@ class BaseSampleModel():
+          self.top_post_quant_conv.eval()
+
+      def load_bot_pretrain_network(self):
+ -        checkpoint = torch.load(self.opt['bot_vae_path'])
+ +        checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device)
+          self.bot_decoder_res.load_state_dict(
+              checkpoint['bot_decoder_res'], strict=True)
+          self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
+ @@ -153,7 +153,7 @@ class BaseSampleModel():
+
+      def load_pretrained_segm_token(self):
+          # load pretrained vqgan for segmentation mask
+ -        segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
+ +        segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device)
+          self.segm_encoder.load_state_dict(
+              segm_token_checkpoint['encoder'], strict=True)
+          self.segm_quantizer.load_state_dict(
+ @@ -166,7 +166,7 @@ class BaseSampleModel():
+          self.segm_quant_conv.eval()
+
+      def load_index_pred_network(self):
+ -        checkpoint = torch.load(self.opt['pretrained_index_network'])
+ +        checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device)
+          self.index_pred_guidance_encoder.load_state_dict(
+              checkpoint['guidance_encoder'], strict=True)
+          self.index_pred_decoder.load_state_dict(
+ @@ -176,7 +176,7 @@ class BaseSampleModel():
+          self.index_pred_decoder.eval()
+
+      def load_sampler_pretrained_network(self):
+ -        checkpoint = torch.load(self.opt['pretrained_sampler'])
+ +        checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device)
+          self.sampler_fn.load_state_dict(checkpoint, strict=True)
+          self.sampler_fn.eval()
+
+ @@ -397,7 +397,7 @@ class SampleFromPoseModel(BaseSampleModel):
+                      [185, 210, 205], [130, 165, 180], [225, 141, 151]]
+
+      def load_shape_generation_models(self):
+ -        checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
+ +        checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device)
+
+          self.shape_attr_embedder.load_state_dict(
+              checkpoint['embedder'], strict=True)
+ diff --git a/models/transformer_model.py b/models/transformer_model.py
+ index 7db0f3e..4523d17 100644
+ --- a/models/transformer_model.py
+ +++ b/models/transformer_model.py
+ @@ -21,7 +21,7 @@ class TransformerTextureAwareModel():
+
+      def __init__(self, opt):
+          self.opt = opt
+ -        self.device = torch.device('cuda')
+ +        self.device = torch.device(opt['device'])
+          self.is_train = opt['is_train']
+
+          # VQVAE for image
+ @@ -317,10 +317,10 @@ class TransformerTextureAwareModel():
+      def sample_fn(self, temp=1.0, sample_steps=None):
+          self._denoise_fn.eval()
+
+ -        b, device = self.image.size(0), 'cuda'
+ +        b = self.image.size(0)
+          x_t = torch.ones(
+ -            (b, np.prod(self.shape)), device=device).long() * self.mask_id
+ -        unmasked = torch.zeros_like(x_t, device=device).bool()
+ +            (b, np.prod(self.shape)), device=self.device).long() * self.mask_id
+ +        unmasked = torch.zeros_like(x_t, device=self.device).bool()
+          sample_steps = list(range(1, sample_steps + 1))
+
+          texture_mask_flatten = self.texture_tokens.view(-1)
+ @@ -336,11 +336,11 @@ class TransformerTextureAwareModel():
+
+          for t in reversed(sample_steps):
+              print(f'Sample timestep {t:4d}', end='\r')
+ -            t = torch.full((b, ), t, device=device, dtype=torch.long)
+ +            t = torch.full((b, ), t, device=self.device, dtype=torch.long)
+
+              # where to unmask
+              changes = torch.rand(
+ -                x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
+ +                x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
+              # don't unmask somewhere already unmasked
+              changes = torch.bitwise_xor(changes,
+                                          torch.bitwise_and(changes, unmasked))
+ diff --git a/models/vqgan_model.py b/models/vqgan_model.py
+ index 13a2e70..9c840f1 100644
+ --- a/models/vqgan_model.py
+ +++ b/models/vqgan_model.py
+ @@ -20,7 +20,7 @@ class VQModel():
+      def __init__(self, opt):
+          super().__init__()
+          self.opt = opt
+ -        self.device = torch.device('cuda')
+ +        self.device = torch.device(opt['device'])
+          self.encoder = Encoder(
+              ch=opt['ch'],
+              num_res_blocks=opt['num_res_blocks'],
+ @@ -390,7 +390,7 @@ class VQImageSegmTextureModel(VQImageModel):
+
+      def __init__(self, opt):
+          self.opt = opt
+ -        self.device = torch.device('cuda')
+ +        self.device = torch.device(opt['device'])
+          self.encoder = Encoder(
+              ch=opt['ch'],
+              num_res_blocks=opt['num_res_blocks'],
pose_images/000.png ADDED
pose_images/001.png ADDED
pose_images/002.png ADDED
pose_images/003.png ADDED
pose_images/004.png ADDED
pose_images/005.png ADDED
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ einops==0.6.1
+ lpips==0.1.4
+ mmcv-full==1.5.2
+ mmsegmentation==0.24.1
+ numpy==1.23.5
+ openmim==0.1.5
+ Pillow==9.5.0
+ sentence-transformers==2.2.2
+ tokenizers==0.13.3
+ torch==1.11.0
+ torchvision==0.12.0
+ transformers==4.30.2
style.css ADDED
@@ -0,0 +1,16 @@
+ h1 {
+   text-align: center;
+ }
+ #input-image {
+   max-height: 300px;
+ }
+ #label-image {
+   height: 300px;
+ }
+ #result-image {
+   height: 300px;
+ }
+ img#visitor-badge {
+   display: block;
+   margin: auto;
+ }