@@ -21,7 +21,7 @@ class VQGANTextureAwareSpatialHierarchyInferenceModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
 
         self.top_encoder = Encoder(

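The hunks in this patch replace the hard-coded `torch.device('cuda')` with a device read from the options dict. The diff does not show where `opt['device']` is set; a minimal sketch, assuming it is a plain string chosen at startup (the fallback logic here is an illustration, not part of the change):

    import torch

    # Assumed option setup: prefer CUDA when present, otherwise fall back to CPU,
    # so torch.device(opt['device']) also works on machines without a GPU.
    opt = {'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
    device = torch.device(opt['device'])
    print(device)  # device(type='cuda') with a GPU, device(type='cpu') without
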
@@ -20,7 +20,7 @@ class HierarchyVQSpatialTextureAwareModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.top_encoder = Encoder(
             ch=opt['top_ch'],
             num_res_blocks=opt['top_num_res_blocks'],

@@ -22,7 +22,7 @@ class ParsingGenModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
 
         self.attr_embedder = ShapeAttrEmbedding(

@@ -23,7 +23,7 @@ class BaseSampleModel():
 
     def __init__(self, opt):
        self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
 
        # hierarchical VQVAE
        self.decoder = Decoder(
@@ -123,7 +123,7 @@ class BaseSampleModel():
 
     def load_top_pretrain_models(self):
         # load pretrained vqgan
-        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
+        top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device)
 
         self.decoder.load_state_dict(
             top_vae_checkpoint['decoder'], strict=True)
@@ -137,7 +137,7 @@ class BaseSampleModel():
         self.top_post_quant_conv.eval()
 
     def load_bot_pretrain_network(self):
-        checkpoint = torch.load(self.opt['bot_vae_path'])
+        checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device)
         self.bot_decoder_res.load_state_dict(
             checkpoint['bot_decoder_res'], strict=True)
         self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
@@ -153,7 +153,7 @@ class BaseSampleModel():
 
     def load_pretrained_segm_token(self):
         # load pretrained vqgan for segmentation mask
-        segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
+        segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device)
         self.segm_encoder.load_state_dict(
             segm_token_checkpoint['encoder'], strict=True)
         self.segm_quantizer.load_state_dict(
@@ -166,7 +166,7 @@ class BaseSampleModel():
         self.segm_quant_conv.eval()
 
     def load_index_pred_network(self):
-        checkpoint = torch.load(self.opt['pretrained_index_network'])
+        checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device)
         self.index_pred_guidance_encoder.load_state_dict(
             checkpoint['guidance_encoder'], strict=True)
         self.index_pred_decoder.load_state_dict(
@@ -176,7 +176,7 @@ class BaseSampleModel():
         self.index_pred_decoder.eval()
 
     def load_sampler_pretrained_network(self):
-        checkpoint = torch.load(self.opt['pretrained_sampler'])
+        checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device)
         self.sampler_fn.load_state_dict(checkpoint, strict=True)
         self.sampler_fn.eval()
 
@@ -397,7 +397,7 @@ class SampleFromPoseModel(BaseSampleModel):
                         [185, 210, 205], [130, 165, 180], [225, 141, 151]]
 
     def load_shape_generation_models(self):
-        checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
+        checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device)
 
         self.shape_attr_embedder.load_state_dict(
             checkpoint['embedder'], strict=True)

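Every `torch.load` in the sampling model now passes `map_location=self.device`, so checkpoints saved from a GPU run can be loaded onto whatever device was configured, including CPU-only hosts. A small self-contained illustration of the behaviour (the file path is made up for the example):

    import torch

    # Save a tiny state dict, then reload it onto an explicit target device.
    state = {'weight': torch.randn(2, 2)}
    torch.save(state, '/tmp/demo_ckpt.pth')

    device = torch.device('cpu')  # stands in for self.device
    checkpoint = torch.load('/tmp/demo_ckpt.pth', map_location=device)
    print(checkpoint['weight'].device)  # cpu, regardless of where it was saved
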
@@ -21,7 +21,7 @@ class TransformerTextureAwareModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
 
         # VQVAE for image
@@ -317,10 +317,10 @@ class TransformerTextureAwareModel():
     def sample_fn(self, temp=1.0, sample_steps=None):
         self._denoise_fn.eval()
 
-        b, device = self.image.size(0), 'cuda'
+        b = self.image.size(0)
         x_t = torch.ones(
-            (b, np.prod(self.shape)), device=device).long() * self.mask_id
-        unmasked = torch.zeros_like(x_t, device=device).bool()
+            (b, np.prod(self.shape)), device=self.device).long() * self.mask_id
+        unmasked = torch.zeros_like(x_t, device=self.device).bool()
         sample_steps = list(range(1, sample_steps + 1))
 
         texture_mask_flatten = self.texture_tokens.view(-1)
@@ -336,11 +336,11 @@ class TransformerTextureAwareModel():
 
         for t in reversed(sample_steps):
             print(f'Sample timestep {t:4d}', end='\r')
-            t = torch.full((b, ), t, device=device, dtype=torch.long)
+            t = torch.full((b, ), t, device=self.device, dtype=torch.long)
 
             # where to unmask
             changes = torch.rand(
-                x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
+                x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
             # don't unmask somewhere already unmasked
             changes = torch.bitwise_xor(changes,
                                         torch.bitwise_and(changes, unmasked))

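In `sample_fn`, the local `device = 'cuda'` is removed and every tensor created during sampling is allocated directly on `self.device`. A stripped-down sketch of the same unmasking step with made-up sizes, only to show that all tensors land on the one configured device:

    import torch

    device = torch.device('cpu')        # stands in for self.device
    b, n, mask_id, t_val = 2, 16, 0, 4  # made-up batch size, token count, timestep

    x_t = torch.ones((b, n), device=device).long() * mask_id
    unmasked = torch.zeros_like(x_t, device=device).bool()

    t = torch.full((b, ), t_val, device=device, dtype=torch.long)
    # unmask each still-masked position with probability 1/t
    changes = torch.rand(x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
    changes = torch.bitwise_xor(changes, torch.bitwise_and(changes, unmasked))
    print(changes.device)  # same device as every other tensor above
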
@@ -20,7 +20,7 @@ class VQModel():
     def __init__(self, opt):
         super().__init__()
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.encoder = Encoder(
             ch=opt['ch'],
             num_res_blocks=opt['num_res_blocks'],
@@ -390,7 +390,7 @@ class VQImageSegmTextureModel(VQImageModel):
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.encoder = Encoder(
             ch=opt['ch'],
             num_res_blocks=opt['num_res_blocks'],