diff --git a/src/gan_control/inference/controller.py b/src/gan_control/inference/controller.py
index ee464ba..d1907dd 100644
--- a/src/gan_control/inference/controller.py
+++ b/src/gan_control/inference/controller.py
@@ -13,9 +13,9 @@ _log = get_logger(__name__)
 class Controller(Inference):
-    def __init__(self, controller_dir):
+    def __init__(self, controller_dir, device):
         _log.info('Init Controller class...')
-        super(Controller, self).__init__(os.path.join(controller_dir, 'generator'))
+        super(Controller, self).__init__(os.path.join(controller_dir, 'generator'), device)
         self.fc_controls = {}
         self.config_controls = {}
         for sub_group_name in self.batch_utils.sub_group_names:
@@ -29,21 +29,21 @@ class Controller(Inference):
     @torch.no_grad()
     def gen_batch_by_controls(self, batch_size=1, latent=None, normalize=True, input_is_latent=False, static_noise=True, **kwargs):
         if latent is None:
-            latent = torch.randn(batch_size, self.config.model_config['latent_size'], device='cuda')
+            latent = torch.randn(batch_size, self.config.model_config['latent_size'], device=self.device)
         latent = latent.clone()
         if input_is_latent:
             latent_w = latent
         else:
             if isinstance(self.model, torch.nn.DataParallel):
-                latent_w = self.model.module.style(latent.cuda())
+                latent_w = self.model.module.style(latent.to(self.device))
             else:
-                latent_w = self.model.style(latent.cuda())
+                latent_w = self.model.style(latent.to(self.device))
         for group_key in kwargs.keys():
             if self.check_if_group_has_control(group_key):
                 if group_key == 'expression' and kwargs[group_key].shape[1] == 8:
-                    group_w_latent = self.fc_controls['expression_q'](kwargs[group_key].cuda().float())
+                    group_w_latent = self.fc_controls['expression_q'](kwargs[group_key].to(self.device).float())
                 else:
-                    group_w_latent = self.fc_controls[group_key](kwargs[group_key].cuda().float())
+                    group_w_latent = self.fc_controls[group_key](kwargs[group_key].to(self.device).float())
                 latent_w = self.insert_group_w_latent(latent_w, group_w_latent, group_key)
         injection_noise = None
         if static_noise:
@@ -101,12 +101,12 @@ class Controller(Inference):
         ckpt_path = ckpt_list[-1]
         ckpt_iter = ckpt_path.split('.')[0]
         config = read_json(config_path, return_obj=True)
-        ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path))
+        ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path), map_location=self.device)
         group_chunk = self.batch_utils.place_in_latent_dict[sub_group_name if sub_group_name is not 'expression_q' else 'expression']
         group_latent_size = group_chunk[1] - group_chunk[0]
         _log.info('Init %s Controller...' % sub_group_name)
-        controller = FcStack(config.model_config['lr_mlp'], config.model_config['n_mlp'], config.model_config['in_dim'], config.model_config['mid_dim'], group_latent_size).cuda()
+        controller = FcStack(config.model_config['lr_mlp'], config.model_config['n_mlp'], config.model_config['in_dim'], config.model_config['mid_dim'], group_latent_size).to(self.device)
         controller.print()
         _log.info('Loading Controller: %s, ckpt iter %s' % (controller_dir_path, ckpt_iter))
diff --git a/src/gan_control/inference/inference.py b/src/gan_control/inference/inference.py
index e6ccedb..4393bb7 100644
--- a/src/gan_control/inference/inference.py
+++ b/src/gan_control/inference/inference.py
@@ -15,10 +15,11 @@ _log = get_logger(__name__)
 class Inference():
-    def __init__(self, model_dir):
+    def __init__(self, model_dir, device):
         _log.info('Init inference class...')
         self.model_dir = model_dir
-        self.model, self.batch_utils, self.config, self.ckpt_iter = self.retrieve_model(model_dir)
+        self.device = device
+        self.model, self.batch_utils, self.config, self.ckpt_iter = self.retrieve_model(model_dir, device)
         self.noise = None
         self.reset_noise()
         self.mean_w_latent = None
@@ -28,7 +29,7 @@ class Inference():
         _log.info('Calc mean_w_latents...')
         mean_latent_w_list = []
         for i in range(100):
-            latent_z = torch.randn(1000, self.config.model_config['latent_size'], device='cuda')
+            latent_z = torch.randn(1000, self.config.model_config['latent_size'], device=self.device)
             if isinstance(self.model, torch.nn.DataParallel):
                 latent_w = self.model.module.style(latent_z).cpu()
             else:
@@ -41,9 +42,9 @@ class Inference():
     def reset_noise(self):
         if isinstance(self.model, torch.nn.DataParallel):
-            self.noise = self.model.module.make_noise(device='cuda')
+            self.noise = self.model.module.make_noise(device=self.device)
         else:
-            self.noise = self.model.make_noise(device='cuda')
+            self.noise = self.model.make_noise(device=self.device)

     @staticmethod
     def expend_noise(noise, batch_size):
@@ -56,14 +57,14 @@ class Inference():
             self.calc_mean_w_latents()
         injection_noise = None
         if latent is None:
-            latent = torch.randn(batch_size, self.config.model_config['latent_size'], device='cuda')
+            latent = torch.randn(batch_size, self.config.model_config['latent_size'], device=self.device)
         elif input_is_latent:
-            latent = latent.cuda()
+            latent = latent.to(self.device)
         for group_key in kwargs.keys():
             if group_key not in self.batch_utils.sub_group_names:
                 raise ValueError('group_key: %s not in sub_group_names %s' % (group_key, str(self.batch_utils.sub_group_names)))
             if isinstance(kwargs[group_key], str) and kwargs[group_key] == 'random':
-                group_latent_w = self.model.style(torch.randn(latent.shape[0], self.config.model_config['latent_size'], device='cuda'))
+                group_latent_w = self.model.style(torch.randn(latent.shape[0], self.config.model_config['latent_size'], device=self.device))
                 group_latent_w = group_latent_w[:, self.batch_utils.place_in_latent_dict[group_key][0], self.batch_utils.place_in_latent_dict[group_key][0]]
                 latent[:, self.batch_utils.place_in_latent_dict[group_key][0], self.batch_utils.place_in_latent_dict[group_key][0]] = group_latent_w
         if static_noise:
@@ -82,11 +83,11 @@ class Inference():
                 latent[:, place_in_latent[0]: place_in_latent[1]] = \
                     truncation * (latent[:, place_in_latent[0]: place_in_latent[1]] - torch.cat(
                         [self.mean_w_latents[key].clone().unsqueeze(0) for _ in range(latent.shape[0])], dim=0
-                    ).cuda()) + torch.cat(
+                    ).to(self.device)) + torch.cat(
                         [self.mean_w_latents[key].clone().unsqueeze(0) for _ in range(latent.shape[0])], dim=0
-                    ).cuda()
+                    ).to(self.device)
-        tensor, latent_w = self.model([latent.cuda()], return_latents=True, input_is_latent=input_is_latent, noise=injection_noise)
+        tensor, latent_w = self.model([latent.to(self.device)], return_latents=True, input_is_latent=input_is_latent, noise=injection_noise)
         if normalize:
             tensor = tensor.mul(0.5).add(0.5).clamp(min=0., max=1.).cpu()
         return tensor, latent, latent_w
@@ -107,7 +108,7 @@ class Inference():
         return grid_image

     @staticmethod
-    def retrieve_model(model_dir):
+    def retrieve_model(model_dir, device):
         config_path = os.path.join(model_dir, 'args.json')
         _log.info('Retrieve config from %s' % config_path)
@@ -117,7 +118,7 @@ class Inference():
         ckpt_path = ckpt_list[-1]
         ckpt_iter = ckpt_path.split('.')[0]
         config = read_json(config_path, return_obj=True)
-        ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path))
+        ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path), map_location=device)
         batch_utils = None
         if not config.model_config['vanilla']:
@@ -140,7 +141,7 @@ class Inference():
             fc_config=None if config.model_config['vanilla'] else batch_utils.get_fc_config(),
             conv_transpose=config.model_config['conv_transpose'],
             noise_mode=config.model_config['g_noise_mode']
-        ).cuda()
+        ).to(device)
         _log.info('Loading Model: %s, ckpt iter %s' % (model_dir, ckpt_iter))
         model.load_state_dict(ckpt['g_ema'])
         model = torch.nn.DataParallel(model)
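
For context, a minimal usage sketch of the new `device`-aware constructors. The import path, checkpoint directory, expression values, and unpacked return names below are illustrative assumptions, not part of this diff:

```python
import torch

# Assumes the package is importable as `gan_control` (e.g. with `src/` on PYTHONPATH).
from gan_control.inference.controller import Controller

# Pick the device explicitly instead of the previously hard-coded 'cuda';
# falls back to CPU when no GPU is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Placeholder path to a trained controller directory.
controller = Controller('path/to/controller_dir', device)

# Control kwargs are moved to `self.device` inside gen_batch_by_controls;
# an 8-dim expression code is routed through the 'expression_q' FC controller.
expression = torch.zeros(4, 8)  # illustrative control values only

# Return values assumed to follow the base Inference.gen_batch convention
# (image tensor, z latent, w latent).
tensor, latent, latent_w = controller.gen_batch_by_controls(batch_size=4, expression=expression)
```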