Text2Human / patch
hysts
Add files
b85284b
raw
history blame
7.11 kB
diff --git a/models/hierarchy_inference_model.py b/models/hierarchy_inference_model.py
index 3116307..5de661d 100644
--- a/models/hierarchy_inference_model.py
+++ b/models/hierarchy_inference_model.py
@@ -21,7 +21,7 @@ class VQGANTextureAwareSpatialHierarchyInferenceModel():
def __init__(self, opt):
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
self.is_train = opt['is_train']
self.top_encoder = Encoder(
diff --git a/models/hierarchy_vqgan_model.py b/models/hierarchy_vqgan_model.py
index 4b0d657..0bf4712 100644
--- a/models/hierarchy_vqgan_model.py
+++ b/models/hierarchy_vqgan_model.py
@@ -20,7 +20,7 @@ class HierarchyVQSpatialTextureAwareModel():
def __init__(self, opt):
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
self.top_encoder = Encoder(
ch=opt['top_ch'],
num_res_blocks=opt['top_num_res_blocks'],
diff --git a/models/parsing_gen_model.py b/models/parsing_gen_model.py
index 9440345..15a1ecb 100644
--- a/models/parsing_gen_model.py
+++ b/models/parsing_gen_model.py
@@ -22,7 +22,7 @@ class ParsingGenModel():
def __init__(self, opt):
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
self.is_train = opt['is_train']
self.attr_embedder = ShapeAttrEmbedding(
diff --git a/models/sample_model.py b/models/sample_model.py
index 4c60e3f..5265cd0 100644
--- a/models/sample_model.py
+++ b/models/sample_model.py
@@ -23,7 +23,7 @@ class BaseSampleModel():
def __init__(self, opt):
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
# hierarchical VQVAE
self.decoder = Decoder(
@@ -123,7 +123,7 @@ class BaseSampleModel():
def load_top_pretrain_models(self):
# load pretrained vqgan
- top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device)
self.decoder.load_state_dict(
top_vae_checkpoint['decoder'], strict=True)
@@ -137,7 +137,7 @@ class BaseSampleModel():
self.top_post_quant_conv.eval()
def load_bot_pretrain_network(self):
- checkpoint = torch.load(self.opt['bot_vae_path'])
+ checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device)
self.bot_decoder_res.load_state_dict(
checkpoint['bot_decoder_res'], strict=True)
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
@@ -153,7 +153,7 @@ class BaseSampleModel():
def load_pretrained_segm_token(self):
# load pretrained vqgan for segmentation mask
- segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
+ segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device)
self.segm_encoder.load_state_dict(
segm_token_checkpoint['encoder'], strict=True)
self.segm_quantizer.load_state_dict(
@@ -166,7 +166,7 @@ class BaseSampleModel():
self.segm_quant_conv.eval()
def load_index_pred_network(self):
- checkpoint = torch.load(self.opt['pretrained_index_network'])
+ checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device)
self.index_pred_guidance_encoder.load_state_dict(
checkpoint['guidance_encoder'], strict=True)
self.index_pred_decoder.load_state_dict(
@@ -176,7 +176,7 @@ class BaseSampleModel():
self.index_pred_decoder.eval()
def load_sampler_pretrained_network(self):
- checkpoint = torch.load(self.opt['pretrained_sampler'])
+ checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device)
self.sampler_fn.load_state_dict(checkpoint, strict=True)
self.sampler_fn.eval()
@@ -397,7 +397,7 @@ class SampleFromPoseModel(BaseSampleModel):
[185, 210, 205], [130, 165, 180], [225, 141, 151]]
def load_shape_generation_models(self):
- checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
+ checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device)
self.shape_attr_embedder.load_state_dict(
checkpoint['embedder'], strict=True)
diff --git a/models/transformer_model.py b/models/transformer_model.py
index 7db0f3e..4523d17 100644
--- a/models/transformer_model.py
+++ b/models/transformer_model.py
@@ -21,7 +21,7 @@ class TransformerTextureAwareModel():
def __init__(self, opt):
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
self.is_train = opt['is_train']
# VQVAE for image
@@ -317,10 +317,10 @@ class TransformerTextureAwareModel():
def sample_fn(self, temp=1.0, sample_steps=None):
self._denoise_fn.eval()
- b, device = self.image.size(0), 'cuda'
+ b = self.image.size(0)
x_t = torch.ones(
- (b, np.prod(self.shape)), device=device).long() * self.mask_id
- unmasked = torch.zeros_like(x_t, device=device).bool()
+ (b, np.prod(self.shape)), device=self.device).long() * self.mask_id
+ unmasked = torch.zeros_like(x_t, device=self.device).bool()
sample_steps = list(range(1, sample_steps + 1))
texture_mask_flatten = self.texture_tokens.view(-1)
@@ -336,11 +336,11 @@ class TransformerTextureAwareModel():
for t in reversed(sample_steps):
print(f'Sample timestep {t:4d}', end='\r')
- t = torch.full((b, ), t, device=device, dtype=torch.long)
+ t = torch.full((b, ), t, device=self.device, dtype=torch.long)
# where to unmask
changes = torch.rand(
- x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
+ x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
# don't unmask somewhere already unmasked
changes = torch.bitwise_xor(changes,
torch.bitwise_and(changes, unmasked))
diff --git a/models/vqgan_model.py b/models/vqgan_model.py
index 13a2e70..9c840f1 100644
--- a/models/vqgan_model.py
+++ b/models/vqgan_model.py
@@ -20,7 +20,7 @@ class VQModel():
def __init__(self, opt):
super().__init__()
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
self.encoder = Encoder(
ch=opt['ch'],
num_res_blocks=opt['num_res_blocks'],
@@ -390,7 +390,7 @@ class VQImageSegmTextureModel(VQImageModel):
def __init__(self, opt):
self.opt = opt
- self.device = torch.device('cuda')
+ self.device = torch.device(opt['device'])
self.encoder = Encoder(
ch=opt['ch'],
num_res_blocks=opt['num_res_blocks'],