hysts commited on
Commit
b85284b
1 Parent(s): e4d1395
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
1
+ *.png filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
1
+ pretrained_models
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ [submodule "Text2Human"]
2
+ path = Text2Human
3
+ url = https://github.com/yumingj/Text2Human
.pre-commit-config.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: ^(Text2Human|patch)
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.812
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ - repo: https://github.com/google/yapf
33
+ rev: v0.32.0
34
+ hooks:
35
+ - id: yapf
36
+ args: ['--parallel', '--in-place']
37
+ - repo: https://github.com/kynan/nbstripout
38
+ rev: 0.5.0
39
+ hooks:
40
+ - id: nbstripout
41
+ args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
+ - repo: https://github.com/nbQA-dev/nbQA
43
+ rev: 1.3.1
44
+ hooks:
45
+ - id: nbqa-isort
46
+ - id: nbqa-yapf
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
Text2Human ADDED
@@ -0,0 +1 @@
 
1
+ Subproject commit 6d38607df89651704000d0e6571bfc640d185a77
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import os
7
+ import pathlib
8
+ import subprocess
9
+
10
+ import gradio as gr
11
+
12
+ if os.getenv('SYSTEM') == 'spaces':
13
+ subprocess.call('pip uninstall -y mmcv-full'.split())
14
+ subprocess.call('pip install mmcv-full==1.5.2'.split())
15
+ subprocess.call('git apply ../patch'.split(), cwd='Text2Human')
16
+
17
+ from model import Model
18
+
19
+
20
+ def parse_args() -> argparse.Namespace:
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument('--device', type=str, default='cpu')
23
+ parser.add_argument('--theme', type=str)
24
+ parser.add_argument('--share', action='store_true')
25
+ parser.add_argument('--port', type=int)
26
+ parser.add_argument('--disable-queue',
27
+ dest='enable_queue',
28
+ action='store_false')
29
+ return parser.parse_args()
30
+
31
+
32
+ def set_example_image(example: list) -> dict:
33
+ return gr.Image.update(value=example[0])
34
+
35
+
36
+ def set_example_text(example: list) -> dict:
37
+ return gr.Textbox.update(value=example[0])
38
+
39
+
40
+ def main():
41
+ args = parse_args()
42
+ model = Model(args.device)
43
+
44
+ css = '''
45
+ h1#title {
46
+ text-align: center;
47
+ }
48
+ #input-image {
49
+ max-height: 300px;
50
+ }
51
+ #label-image {
52
+ height: 300px;
53
+ }
54
+ #result-image {
55
+ height: 300px;
56
+ }
57
+ '''
58
+
59
+ with gr.Blocks(theme=args.theme, css=css) as demo:
60
+ gr.Markdown('''<h1 id="title">Text2Human</h1>
61
+
62
+ This is an unofficial demo for <a href="https://github.com/yumingj/Text2Human">https://github.com/yumingj/Text2Human</a>.
63
+ ''')
64
+ with gr.Row():
65
+ with gr.Column():
66
+ with gr.Row():
67
+ input_image = gr.Image(label='Input Pose Image',
68
+ type='pil',
69
+ elem_id='input-image')
70
+ with gr.Row():
71
+ paths = sorted(pathlib.Path('pose_images').glob('*.png'))
72
+ example_images = gr.Dataset(components=[input_image],
73
+ samples=[[path.as_posix()]
74
+ for path in paths])
75
+
76
+ with gr.Column():
77
+ with gr.Row():
78
+ label_image = gr.Image(label='Label Image',
79
+ type='numpy',
80
+ elem_id='label-image')
81
+ with gr.Row():
82
+ shape_text = gr.Textbox(
83
+ label='Shape Description',
84
+ placeholder=
85
+ '''<gender>, <sleeve length>, <length of lower clothing>, <outer clothing type>, <other accessories1>, ...
86
+ Note: The outer clothing type and accessories can be omitted.''')
87
+ with gr.Row():
88
+ shape_example_texts = gr.Dataset(
89
+ components=[shape_text],
90
+ samples=[['man, sleeveless T-shirt, long pants'],
91
+ ['woman, short-sleeve T-shirt, short jeans']])
92
+ with gr.Row():
93
+ generate_label_button = gr.Button('Generate Label Image')
94
+
95
+ with gr.Column():
96
+ with gr.Row():
97
+ result = gr.Image(label='Result',
98
+ type='numpy',
99
+ elem_id='result-image')
100
+ with gr.Row():
101
+ texture_text = gr.Textbox(
102
+ label='Texture Description',
103
+ placeholder=
104
+ '''<upper clothing texture>, <lower clothing texture>, <outer clothing texture>
105
+ Note: Currently, only 5 types of textures are supported, i.e., pure color, stripe/spline, plaid/lattice, floral, denim.'''
106
+ )
107
+ with gr.Row():
108
+ texture_example_texts = gr.Dataset(
109
+ components=[texture_text],
110
+ samples=[['pure color, denim'], ['floral, stripe']])
111
+ with gr.Row():
112
+ sample_steps = gr.Slider(10,
113
+ 300,
114
+ value=10,
115
+ step=10,
116
+ label='Sample Steps')
117
+ with gr.Row():
118
+ seed = gr.Slider(0, 1000000, value=0, step=1, label='Seed')
119
+ with gr.Row():
120
+ generate_human_button = gr.Button('Generate Human')
121
+
122
+ gr.Markdown(
123
+ '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.text2human" alt="visitor badge"/></center>'
124
+ )
125
+
126
+ input_image.change(fn=model.process_pose_image,
127
+ inputs=[input_image],
128
+ outputs=None)
129
+ generate_label_button.click(fn=model.generate_label_image,
130
+ inputs=[shape_text],
131
+ outputs=[label_image])
132
+ generate_human_button.click(fn=model.generate_human,
133
+ inputs=[
134
+ texture_text,
135
+ sample_steps,
136
+ seed,
137
+ ],
138
+ outputs=[result])
139
+ example_images.click(fn=set_example_image,
140
+ inputs=example_images,
141
+ outputs=example_images.components)
142
+ shape_example_texts.click(fn=set_example_text,
143
+ inputs=shape_example_texts,
144
+ outputs=shape_example_texts.components)
145
+ texture_example_texts.click(fn=set_example_text,
146
+ inputs=texture_example_texts,
147
+ outputs=texture_example_texts.components)
148
+
149
+ demo.launch(
150
+ enable_queue=args.enable_queue,
151
+ server_port=args.port,
152
+ share=args.share,
153
+ )
154
+
155
+
156
+ if __name__ == '__main__':
157
+ main()
model.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pathlib
5
+ import sys
6
+ import zipfile
7
+
8
+ import huggingface_hub
9
+ import numpy as np
10
+ import PIL.Image
11
+ import torch
12
+
13
+ sys.path.insert(0, 'Text2Human')
14
+
15
+ from models.sample_model import SampleFromPoseModel
16
+ from utils.language_utils import (generate_shape_attributes,
17
+ generate_texture_attributes)
18
+ from utils.options import dict_to_nonedict, parse
19
+ from utils.util import set_random_seed
20
+
21
+ COLOR_LIST = [
22
+ (0, 0, 0),
23
+ (255, 250, 250),
24
+ (220, 220, 220),
25
+ (250, 235, 215),
26
+ (255, 250, 205),
27
+ (211, 211, 211),
28
+ (70, 130, 180),
29
+ (127, 255, 212),
30
+ (0, 100, 0),
31
+ (50, 205, 50),
32
+ (255, 255, 0),
33
+ (245, 222, 179),
34
+ (255, 140, 0),
35
+ (255, 0, 0),
36
+ (16, 78, 139),
37
+ (144, 238, 144),
38
+ (50, 205, 174),
39
+ (50, 155, 250),
40
+ (160, 140, 88),
41
+ (213, 140, 88),
42
+ (90, 140, 90),
43
+ (185, 210, 205),
44
+ (130, 165, 180),
45
+ (225, 141, 151),
46
+ ]
47
+
48
+
49
+ class Model:
50
+ def __init__(self, device: str):
51
+ self.config = self._load_config()
52
+ self.config['device'] = device
53
+ self._download_models()
54
+ self.model = SampleFromPoseModel(self.config)
55
+
56
+ def _load_config(self) -> dict:
57
+ path = 'Text2Human/configs/sample_from_pose.yml'
58
+ config = parse(path, is_train=False)
59
+ config = dict_to_nonedict(config)
60
+ return config
61
+
62
+ def _download_models(self) -> None:
63
+ model_dir = pathlib.Path('pretrained_models')
64
+ if model_dir.exists():
65
+ return
66
+ token = os.getenv('HF_TOKEN')
67
+ path = huggingface_hub.hf_hub_download('hysts/Text2Human',
68
+ 'orig/pretrained_models.zip',
69
+ use_auth_token=token)
70
+ model_dir.mkdir()
71
+ with zipfile.ZipFile(path) as f:
72
+ f.extractall(model_dir)
73
+
74
+ @staticmethod
75
+ def preprocess_pose_image(image: PIL.Image.Image) -> torch.Tensor:
76
+ image = np.array(
77
+ image.resize(
78
+ size=(256, 512),
79
+ resample=PIL.Image.Resampling.LANCZOS))[:, :, 2:].transpose(
80
+ 2, 0, 1).astype(np.float32)
81
+ image = image / 12. - 1
82
+ data = torch.from_numpy(image).unsqueeze(1)
83
+ return data
84
+
85
+ @staticmethod
86
+ def process_mask(mask: torch.Tensor) -> torch.Tensor:
87
+ seg_map = np.full(mask.shape[:-1], -1)
88
+ for index, color in enumerate(COLOR_LIST):
89
+ seg_map[np.sum(mask == color, axis=2) == 3] = index
90
+ assert (seg_map != -1).all()
91
+ return seg_map
92
+
93
+ @staticmethod
94
+ def postprocess(result: torch.Tensor) -> np.ndarray:
95
+ result = result.permute(0, 2, 3, 1)
96
+ result = result.detach().cpu().numpy()
97
+ result = result * 255
98
+ result = np.asarray(result[0, :, :, :], dtype=np.uint8)
99
+ return result
100
+
101
+ def process_pose_image(self, pose_image: PIL.Image.Image) -> None:
102
+ if pose_image is None:
103
+ return
104
+ data = self.preprocess_pose_image(pose_image)
105
+ self.model.feed_pose_data(data)
106
+
107
+ def generate_label_image(self, shape_text: str) -> np.ndarray:
108
+ shape_attributes = generate_shape_attributes(shape_text)
109
+ shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
110
+ self.model.feed_shape_attributes(shape_attributes)
111
+ self.model.generate_parsing_map()
112
+ self.model.generate_quantized_segm()
113
+ colored_segm = self.model.palette_result(self.model.segm[0].cpu())
114
+
115
+ mask = colored_segm.copy()
116
+ seg_map = self.process_mask(mask)
117
+ self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
118
+ 0).to(self.model.device)
119
+ self.model.generate_quantized_segm()
120
+ return colored_segm
121
+
122
+ def generate_human(self, texture_text: str, sample_steps: int,
123
+ seed: int) -> np.ndarray:
124
+ set_random_seed(seed)
125
+
126
+ texture_attributes = generate_texture_attributes(texture_text)
127
+ texture_attributes = torch.LongTensor(texture_attributes)
128
+ self.model.feed_texture_attributes(texture_attributes)
129
+ self.model.generate_texture_map()
130
+
131
+ self.model.sample_steps = sample_steps
132
+ out = self.model.sample_and_refine()
133
+ res = self.postprocess(out)
134
+ return res
patch ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/models/hierarchy_inference_model.py b/models/hierarchy_inference_model.py
2
+ index 3116307..5de661d 100644
3
+ --- a/models/hierarchy_inference_model.py
4
+ +++ b/models/hierarchy_inference_model.py
5
+ @@ -21,7 +21,7 @@ class VQGANTextureAwareSpatialHierarchyInferenceModel():
6
+
7
+ def __init__(self, opt):
8
+ self.opt = opt
9
+ - self.device = torch.device('cuda')
10
+ + self.device = torch.device(opt['device'])
11
+ self.is_train = opt['is_train']
12
+
13
+ self.top_encoder = Encoder(
14
+ diff --git a/models/hierarchy_vqgan_model.py b/models/hierarchy_vqgan_model.py
15
+ index 4b0d657..0bf4712 100644
16
+ --- a/models/hierarchy_vqgan_model.py
17
+ +++ b/models/hierarchy_vqgan_model.py
18
+ @@ -20,7 +20,7 @@ class HierarchyVQSpatialTextureAwareModel():
19
+
20
+ def __init__(self, opt):
21
+ self.opt = opt
22
+ - self.device = torch.device('cuda')
23
+ + self.device = torch.device(opt['device'])
24
+ self.top_encoder = Encoder(
25
+ ch=opt['top_ch'],
26
+ num_res_blocks=opt['top_num_res_blocks'],
27
+ diff --git a/models/parsing_gen_model.py b/models/parsing_gen_model.py
28
+ index 9440345..15a1ecb 100644
29
+ --- a/models/parsing_gen_model.py
30
+ +++ b/models/parsing_gen_model.py
31
+ @@ -22,7 +22,7 @@ class ParsingGenModel():
32
+
33
+ def __init__(self, opt):
34
+ self.opt = opt
35
+ - self.device = torch.device('cuda')
36
+ + self.device = torch.device(opt['device'])
37
+ self.is_train = opt['is_train']
38
+
39
+ self.attr_embedder = ShapeAttrEmbedding(
40
+ diff --git a/models/sample_model.py b/models/sample_model.py
41
+ index 4c60e3f..5265cd0 100644
42
+ --- a/models/sample_model.py
43
+ +++ b/models/sample_model.py
44
+ @@ -23,7 +23,7 @@ class BaseSampleModel():
45
+
46
+ def __init__(self, opt):
47
+ self.opt = opt
48
+ - self.device = torch.device('cuda')
49
+ + self.device = torch.device(opt['device'])
50
+
51
+ # hierarchical VQVAE
52
+ self.decoder = Decoder(
53
+ @@ -123,7 +123,7 @@ class BaseSampleModel():
54
+
55
+ def load_top_pretrain_models(self):
56
+ # load pretrained vqgan
57
+ - top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
58
+ + top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device)
59
+
60
+ self.decoder.load_state_dict(
61
+ top_vae_checkpoint['decoder'], strict=True)
62
+ @@ -137,7 +137,7 @@ class BaseSampleModel():
63
+ self.top_post_quant_conv.eval()
64
+
65
+ def load_bot_pretrain_network(self):
66
+ - checkpoint = torch.load(self.opt['bot_vae_path'])
67
+ + checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device)
68
+ self.bot_decoder_res.load_state_dict(
69
+ checkpoint['bot_decoder_res'], strict=True)
70
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
71
+ @@ -153,7 +153,7 @@ class BaseSampleModel():
72
+
73
+ def load_pretrained_segm_token(self):
74
+ # load pretrained vqgan for segmentation mask
75
+ - segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
76
+ + segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device)
77
+ self.segm_encoder.load_state_dict(
78
+ segm_token_checkpoint['encoder'], strict=True)
79
+ self.segm_quantizer.load_state_dict(
80
+ @@ -166,7 +166,7 @@ class BaseSampleModel():
81
+ self.segm_quant_conv.eval()
82
+
83
+ def load_index_pred_network(self):
84
+ - checkpoint = torch.load(self.opt['pretrained_index_network'])
85
+ + checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device)
86
+ self.index_pred_guidance_encoder.load_state_dict(
87
+ checkpoint['guidance_encoder'], strict=True)
88
+ self.index_pred_decoder.load_state_dict(
89
+ @@ -176,7 +176,7 @@ class BaseSampleModel():
90
+ self.index_pred_decoder.eval()
91
+
92
+ def load_sampler_pretrained_network(self):
93
+ - checkpoint = torch.load(self.opt['pretrained_sampler'])
94
+ + checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device)
95
+ self.sampler_fn.load_state_dict(checkpoint, strict=True)
96
+ self.sampler_fn.eval()
97
+
98
+ @@ -397,7 +397,7 @@ class SampleFromPoseModel(BaseSampleModel):
99
+ [185, 210, 205], [130, 165, 180], [225, 141, 151]]
100
+
101
+ def load_shape_generation_models(self):
102
+ - checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
103
+ + checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device)
104
+
105
+ self.shape_attr_embedder.load_state_dict(
106
+ checkpoint['embedder'], strict=True)
107
+ diff --git a/models/transformer_model.py b/models/transformer_model.py
108
+ index 7db0f3e..4523d17 100644
109
+ --- a/models/transformer_model.py
110
+ +++ b/models/transformer_model.py
111
+ @@ -21,7 +21,7 @@ class TransformerTextureAwareModel():
112
+
113
+ def __init__(self, opt):
114
+ self.opt = opt
115
+ - self.device = torch.device('cuda')
116
+ + self.device = torch.device(opt['device'])
117
+ self.is_train = opt['is_train']
118
+
119
+ # VQVAE for image
120
+ @@ -317,10 +317,10 @@ class TransformerTextureAwareModel():
121
+ def sample_fn(self, temp=1.0, sample_steps=None):
122
+ self._denoise_fn.eval()
123
+
124
+ - b, device = self.image.size(0), 'cuda'
125
+ + b = self.image.size(0)
126
+ x_t = torch.ones(
127
+ - (b, np.prod(self.shape)), device=device).long() * self.mask_id
128
+ - unmasked = torch.zeros_like(x_t, device=device).bool()
129
+ + (b, np.prod(self.shape)), device=self.device).long() * self.mask_id
130
+ + unmasked = torch.zeros_like(x_t, device=self.device).bool()
131
+ sample_steps = list(range(1, sample_steps + 1))
132
+
133
+ texture_mask_flatten = self.texture_tokens.view(-1)
134
+ @@ -336,11 +336,11 @@ class TransformerTextureAwareModel():
135
+
136
+ for t in reversed(sample_steps):
137
+ print(f'Sample timestep {t:4d}', end='\r')
138
+ - t = torch.full((b, ), t, device=device, dtype=torch.long)
139
+ + t = torch.full((b, ), t, device=self.device, dtype=torch.long)
140
+
141
+ # where to unmask
142
+ changes = torch.rand(
143
+ - x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
144
+ + x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
145
+ # don't unmask somewhere already unmasked
146
+ changes = torch.bitwise_xor(changes,
147
+ torch.bitwise_and(changes, unmasked))
148
+ diff --git a/models/vqgan_model.py b/models/vqgan_model.py
149
+ index 13a2e70..9c840f1 100644
150
+ --- a/models/vqgan_model.py
151
+ +++ b/models/vqgan_model.py
152
+ @@ -20,7 +20,7 @@ class VQModel():
153
+ def __init__(self, opt):
154
+ super().__init__()
155
+ self.opt = opt
156
+ - self.device = torch.device('cuda')
157
+ + self.device = torch.device(opt['device'])
158
+ self.encoder = Encoder(
159
+ ch=opt['ch'],
160
+ num_res_blocks=opt['num_res_blocks'],
161
+ @@ -390,7 +390,7 @@ class VQImageSegmTextureModel(VQImageModel):
162
+
163
+ def __init__(self, opt):
164
+ self.opt = opt
165
+ - self.device = torch.device('cuda')
166
+ + self.device = torch.device(opt['device'])
167
+ self.encoder = Encoder(
168
+ ch=opt['ch'],
169
+ num_res_blocks=opt['num_res_blocks'],
pose_images/000.png ADDED

Git LFS Details

  • SHA256: e109163ba1ebfe4c3323ac700e1e6dd9443d5d3cf7e468a3587de7fc40383fa8
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
pose_images/001.png ADDED

Git LFS Details

  • SHA256: 4656ad02618a7760a7214a1d494b73439f1a651df1ee9e0052b2417804614a56
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB
pose_images/002.png ADDED

Git LFS Details

  • SHA256: 9e493d8e9d17f601b47cf7124a91916c9370b5a6dc9b081749ca3116743e8b3f
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
pose_images/003.png ADDED

Git LFS Details

  • SHA256: bbdc5ba3553ed8d512061143db73beaf2adf13c55bc0fba291b5657e63ffbeb8
  • Pointer size: 130 Bytes
  • Size of remote file: 99 kB
pose_images/004.png ADDED

Git LFS Details

  • SHA256: 489a4c28711760b5c68f15b5bc94761c47c6f8fbb0fb307473736e6a08cf3991
  • Pointer size: 131 Bytes
  • Size of remote file: 149 kB
pose_images/005.png ADDED

Git LFS Details

  • SHA256: 03f08c831206f68beaa548c75272e932433f4ed65837698696805c92048e334c
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ einops==0.4.1
2
+ lpips==0.1.4
3
+ mmcv-full==1.5.2
4
+ mmsegmentation==0.24.1
5
+ numpy==1.22.3
6
+ Pillow==9.1.1
7
+ sentence-transformers==2.2.0
8
+ tokenizers==0.12.1
9
+ torch==1.11.0
10
+ torchvision==0.12.0
11
+ transformers==4.19.2