Spaces:
Runtime error
Runtime error
| diff --git a/models/hierarchy_inference_model.py b/models/hierarchy_inference_model.py | |
| index 3116307..5de661d 100644 | |
| --- a/models/hierarchy_inference_model.py | |
| +++ b/models/hierarchy_inference_model.py | |
| class VQGANTextureAwareSpatialHierarchyInferenceModel(): | |
| def __init__(self, opt): | |
| self.opt = opt | |
| - self.device = torch.device('cuda') | |
| + self.device = torch.device(opt['device']) | |
| self.is_train = opt['is_train'] | |
| self.top_encoder = Encoder( | |
| diff --git a/models/hierarchy_vqgan_model.py b/models/hierarchy_vqgan_model.py | |
| index 4b0d657..0bf4712 100644 | |
| --- a/models/hierarchy_vqgan_model.py | |
| +++ b/models/hierarchy_vqgan_model.py | |
| class HierarchyVQSpatialTextureAwareModel(): | |
| def __init__(self, opt): | |
| self.opt = opt | |
| - self.device = torch.device('cuda') | |
| + self.device = torch.device(opt['device']) | |
| self.top_encoder = Encoder( | |
| ch=opt['top_ch'], | |
| num_res_blocks=opt['top_num_res_blocks'], | |
| diff --git a/models/parsing_gen_model.py b/models/parsing_gen_model.py | |
| index 9440345..15a1ecb 100644 | |
| --- a/models/parsing_gen_model.py | |
| +++ b/models/parsing_gen_model.py | |
| class ParsingGenModel(): | |
| def __init__(self, opt): | |
| self.opt = opt | |
| - self.device = torch.device('cuda') | |
| + self.device = torch.device(opt['device']) | |
| self.is_train = opt['is_train'] | |
| self.attr_embedder = ShapeAttrEmbedding( | |
| diff --git a/models/sample_model.py b/models/sample_model.py | |
| index 4c60e3f..5265cd0 100644 | |
| --- a/models/sample_model.py | |
| +++ b/models/sample_model.py | |
| class BaseSampleModel(): | |
| def __init__(self, opt): | |
| self.opt = opt | |
| - self.device = torch.device('cuda') | |
| + self.device = torch.device(opt['device']) | |
| # hierarchical VQVAE | |
| self.decoder = Decoder( | |
| class BaseSampleModel(): | |
| def load_top_pretrain_models(self): | |
| # load pretrained vqgan | |
| - top_vae_checkpoint = torch.load(self.opt['top_vae_path']) | |
| + top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device) | |
| self.decoder.load_state_dict( | |
| top_vae_checkpoint['decoder'], strict=True) | |
| class BaseSampleModel(): | |
| self.top_post_quant_conv.eval() | |
| def load_bot_pretrain_network(self): | |
| - checkpoint = torch.load(self.opt['bot_vae_path']) | |
| + checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device) | |
| self.bot_decoder_res.load_state_dict( | |
| checkpoint['bot_decoder_res'], strict=True) | |
| self.decoder.load_state_dict(checkpoint['decoder'], strict=True) | |
| class BaseSampleModel(): | |
| def load_pretrained_segm_token(self): | |
| # load pretrained vqgan for segmentation mask | |
| - segm_token_checkpoint = torch.load(self.opt['segm_token_path']) | |
| + segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device) | |
| self.segm_encoder.load_state_dict( | |
| segm_token_checkpoint['encoder'], strict=True) | |
| self.segm_quantizer.load_state_dict( | |
| class BaseSampleModel(): | |
| self.segm_quant_conv.eval() | |
| def load_index_pred_network(self): | |
| - checkpoint = torch.load(self.opt['pretrained_index_network']) | |
| + checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device) | |
| self.index_pred_guidance_encoder.load_state_dict( | |
| checkpoint['guidance_encoder'], strict=True) | |
| self.index_pred_decoder.load_state_dict( | |
| class BaseSampleModel(): | |
| self.index_pred_decoder.eval() | |
| def load_sampler_pretrained_network(self): | |
| - checkpoint = torch.load(self.opt['pretrained_sampler']) | |
| + checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device) | |
| self.sampler_fn.load_state_dict(checkpoint, strict=True) | |
| self.sampler_fn.eval() | |
| class SampleFromPoseModel(BaseSampleModel): | |
| [185, 210, 205], [130, 165, 180], [225, 141, 151]] | |
| def load_shape_generation_models(self): | |
| - checkpoint = torch.load(self.opt['pretrained_parsing_gen']) | |
| + checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device) | |
| self.shape_attr_embedder.load_state_dict( | |
| checkpoint['embedder'], strict=True) | |
| diff --git a/models/transformer_model.py b/models/transformer_model.py | |
| index 7db0f3e..4523d17 100644 | |
| --- a/models/transformer_model.py | |
| +++ b/models/transformer_model.py | |
| class TransformerTextureAwareModel(): | |
| def __init__(self, opt): | |
| self.opt = opt | |
| - self.device = torch.device('cuda') | |
| + self.device = torch.device(opt['device']) | |
| self.is_train = opt['is_train'] | |
| # VQVAE for image | |
| class TransformerTextureAwareModel(): | |
| def sample_fn(self, temp=1.0, sample_steps=None): | |
| self._denoise_fn.eval() | |
| - b, device = self.image.size(0), 'cuda' | |
| + b = self.image.size(0) | |
| x_t = torch.ones( | |
| - (b, np.prod(self.shape)), device=device).long() * self.mask_id | |
| - unmasked = torch.zeros_like(x_t, device=device).bool() | |
| + (b, np.prod(self.shape)), device=self.device).long() * self.mask_id | |
| + unmasked = torch.zeros_like(x_t, device=self.device).bool() | |
| sample_steps = list(range(1, sample_steps + 1)) | |
| texture_mask_flatten = self.texture_tokens.view(-1) | |
| class TransformerTextureAwareModel(): | |
| for t in reversed(sample_steps): | |
| print(f'Sample timestep {t:4d}', end='\r') | |
| - t = torch.full((b, ), t, device=device, dtype=torch.long) | |
| + t = torch.full((b, ), t, device=self.device, dtype=torch.long) | |
| # where to unmask | |
| changes = torch.rand( | |
| - x_t.shape, device=device) < 1 / t.float().unsqueeze(-1) | |
| + x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1) | |
| # don't unmask somewhere already unmasked | |
| changes = torch.bitwise_xor(changes, | |
| torch.bitwise_and(changes, unmasked)) | |
| diff --git a/models/vqgan_model.py b/models/vqgan_model.py | |
| index 13a2e70..9c840f1 100644 | |
| --- a/models/vqgan_model.py | |
| +++ b/models/vqgan_model.py | |
| class VQModel(): | |
| def __init__(self, opt): | |
| super().__init__() | |
| self.opt = opt | |
| - self.device = torch.device('cuda') | |
| + self.device = torch.device(opt['device']) | |
| self.encoder = Encoder( | |
| ch=opt['ch'], | |
| num_res_blocks=opt['num_res_blocks'], | |
| class VQImageSegmTextureModel(VQImageModel): | |
| def __init__(self, opt): | |
| self.opt = opt | |
| - self.device = torch.device('cuda') | |
| + self.device = torch.device(opt['device']) | |
| self.encoder = Encoder( | |
| ch=opt['ch'], | |
| num_res_blocks=opt['num_res_blocks'], | |