diff --git a/models/hierarchy_inference_model.py b/models/hierarchy_inference_model.py
index 3116307..5de661d 100644
--- a/models/hierarchy_inference_model.py
+++ b/models/hierarchy_inference_model.py
@@ -21,7 +21,7 @@ class VQGANTextureAwareSpatialHierarchyInferenceModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
 
         self.top_encoder = Encoder(
diff --git a/models/hierarchy_vqgan_model.py b/models/hierarchy_vqgan_model.py
index 4b0d657..0bf4712 100644
--- a/models/hierarchy_vqgan_model.py
+++ b/models/hierarchy_vqgan_model.py
@@ -20,7 +20,7 @@ class HierarchyVQSpatialTextureAwareModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.top_encoder = Encoder(
             ch=opt['top_ch'],
             num_res_blocks=opt['top_num_res_blocks'],
diff --git a/models/parsing_gen_model.py b/models/parsing_gen_model.py
index 9440345..15a1ecb 100644
--- a/models/parsing_gen_model.py
+++ b/models/parsing_gen_model.py
@@ -22,7 +22,7 @@ class ParsingGenModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
 
         self.attr_embedder = ShapeAttrEmbedding(
diff --git a/models/sample_model.py b/models/sample_model.py
index 4c60e3f..5265cd0 100644
--- a/models/sample_model.py
+++ b/models/sample_model.py
@@ -23,7 +23,7 @@ class BaseSampleModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
 
         # hierarchical VQVAE
         self.decoder = Decoder(
@@ -123,7 +123,7 @@ class BaseSampleModel():
 
     def load_top_pretrain_models(self):
         # load pretrained vqgan
-        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
+        top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device)
         self.decoder.load_state_dict(
             top_vae_checkpoint['decoder'], strict=True)
@@ -137,7 +137,7 @@ class BaseSampleModel():
         self.top_post_quant_conv.eval()
 
     def load_bot_pretrain_network(self):
-        checkpoint = torch.load(self.opt['bot_vae_path'])
+        checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device)
         self.bot_decoder_res.load_state_dict(
             checkpoint['bot_decoder_res'], strict=True)
         self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
@@ -153,7 +153,7 @@ class BaseSampleModel():
 
     def load_pretrained_segm_token(self):
         # load pretrained vqgan for segmentation mask
-        segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
+        segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device)
         self.segm_encoder.load_state_dict(
             segm_token_checkpoint['encoder'], strict=True)
         self.segm_quantizer.load_state_dict(
@@ -166,7 +166,7 @@ class BaseSampleModel():
         self.segm_quant_conv.eval()
 
     def load_index_pred_network(self):
-        checkpoint = torch.load(self.opt['pretrained_index_network'])
+        checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device)
         self.index_pred_guidance_encoder.load_state_dict(
             checkpoint['guidance_encoder'], strict=True)
         self.index_pred_decoder.load_state_dict(
@@ -176,7 +176,7 @@ class BaseSampleModel():
         self.index_pred_decoder.eval()
 
     def load_sampler_pretrained_network(self):
-        checkpoint = torch.load(self.opt['pretrained_sampler'])
+        checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device)
         self.sampler_fn.load_state_dict(checkpoint, strict=True)
         self.sampler_fn.eval()
@@ -397,7 +397,7 @@ class SampleFromPoseModel(BaseSampleModel):
                    [185, 210, 205], [130, 165, 180], [225, 141, 151]]
 
     def load_shape_generation_models(self):
-        checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
+        checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device)
         self.shape_attr_embedder.load_state_dict(
             checkpoint['embedder'], strict=True)
diff --git a/models/transformer_model.py b/models/transformer_model.py
index 7db0f3e..4523d17 100644
--- a/models/transformer_model.py
+++ b/models/transformer_model.py
@@ -21,7 +21,7 @@ class TransformerTextureAwareModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
 
         # VQVAE for image
@@ -317,10 +317,10 @@ class TransformerTextureAwareModel():
     def sample_fn(self, temp=1.0, sample_steps=None):
         self._denoise_fn.eval()
 
-        b, device = self.image.size(0), 'cuda'
+        b = self.image.size(0)
         x_t = torch.ones(
-            (b, np.prod(self.shape)), device=device).long() * self.mask_id
-        unmasked = torch.zeros_like(x_t, device=device).bool()
+            (b, np.prod(self.shape)), device=self.device).long() * self.mask_id
+        unmasked = torch.zeros_like(x_t, device=self.device).bool()
         sample_steps = list(range(1, sample_steps + 1))
 
         texture_mask_flatten = self.texture_tokens.view(-1)
@@ -336,11 +336,11 @@ class TransformerTextureAwareModel():
 
         for t in reversed(sample_steps):
             print(f'Sample timestep {t:4d}', end='\r')
-            t = torch.full((b, ), t, device=device, dtype=torch.long)
+            t = torch.full((b, ), t, device=self.device, dtype=torch.long)
 
             # where to unmask
             changes = torch.rand(
-                x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
+                x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
             # don't unmask somewhere already unmasked
             changes = torch.bitwise_xor(changes,
                                         torch.bitwise_and(changes, unmasked))
diff --git a/models/vqgan_model.py b/models/vqgan_model.py
index 13a2e70..9c840f1 100644
--- a/models/vqgan_model.py
+++ b/models/vqgan_model.py
@@ -20,7 +20,7 @@ class VQModel():
     def __init__(self, opt):
         super().__init__()
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.encoder = Encoder(
             ch=opt['ch'],
             num_res_blocks=opt['num_res_blocks'],
@@ -390,7 +390,7 @@ class VQImageSegmTextureModel(VQImageModel):
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.encoder = Encoder(
             ch=opt['ch'],
             num_res_blocks=opt['num_res_blocks'],
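
For reference, a minimal sketch of how a caller might populate `opt['device']` so the models above fall back to CPU when CUDA is unavailable. The `resolve_device` helper and the `'auto'` convention are illustrative assumptions, not code from this repository:

```python
import torch

def resolve_device(requested: str = 'auto') -> str:
    """Return a torch device string, falling back to CPU when CUDA is absent."""
    if requested == 'auto':
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    return requested

# Mirrors the pattern introduced in the diff: each model reads opt['device']
# instead of hardcoding torch.device('cuda').
opt = {'device': resolve_device()}
device = torch.device(opt['device'])
```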
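Likewise, a self-contained sketch of why `map_location` is threaded through every `torch.load` call: without it, tensors are restored to the device they were saved from (typically `cuda:0`), which raises a `RuntimeError` on CPU-only hosts. The checkpoint path and contents below are placeholders:

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Save a toy state dict, then reload it onto the configured device.
torch.save({'weight': torch.randn(2, 2)}, '/tmp/toy_ckpt.pth')
checkpoint = torch.load('/tmp/toy_ckpt.pth', map_location=device)
print(checkpoint['weight'].device)
```

`map_location` also accepts a string such as `'cpu'` or a callable; passing the model's own `torch.device`, as the diff does, keeps checkpoint placement consistent with `opt['device']`.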