Spaces:
Running
Running
diff --git a/src/gan_control/inference/controller.py b/src/gan_control/inference/controller.py | |
index ee464ba..d1907dd 100644 | |
--- a/src/gan_control/inference/controller.py | |
+++ b/src/gan_control/inference/controller.py | |
class Controller(Inference): | |
- def __init__(self, controller_dir): | |
+ def __init__(self, controller_dir, device): | |
_log.info('Init Controller class...') | |
- super(Controller, self).__init__(os.path.join(controller_dir, 'generator')) | |
+ super(Controller, self).__init__(os.path.join(controller_dir, 'generator'), device) | |
self.fc_controls = {} | |
self.config_controls = {} | |
for sub_group_name in self.batch_utils.sub_group_names: | |
def gen_batch_by_controls(self, batch_size=1, latent=None, normalize=True, input_is_latent=False, static_noise=True, **kwargs): | |
if latent is None: | |
- latent = torch.randn(batch_size, self.config.model_config['latent_size'], device='cuda') | |
+ latent = torch.randn(batch_size, self.config.model_config['latent_size'], device=self.device) | |
latent = latent.clone() | |
if input_is_latent: | |
latent_w = latent | |
else: | |
if isinstance(self.model, torch.nn.DataParallel): | |
- latent_w = self.model.module.style(latent.cuda()) | |
+ latent_w = self.model.module.style(latent.to(self.device)) | |
else: | |
- latent_w = self.model.style(latent.cuda()) | |
+ latent_w = self.model.style(latent.to(self.device)) | |
for group_key in kwargs.keys(): | |
if self.check_if_group_has_control(group_key): | |
if group_key == 'expression' and kwargs[group_key].shape[1] == 8: | |
- group_w_latent = self.fc_controls['expression_q'](kwargs[group_key].cuda().float()) | |
+ group_w_latent = self.fc_controls['expression_q'](kwargs[group_key].to(self.device).float()) | |
else: | |
- group_w_latent = self.fc_controls[group_key](kwargs[group_key].cuda().float()) | |
+ group_w_latent = self.fc_controls[group_key](kwargs[group_key].to(self.device).float()) | |
latent_w = self.insert_group_w_latent(latent_w, group_w_latent, group_key) | |
injection_noise = None | |
if static_noise: | |
ckpt_path = ckpt_list[-1] | |
ckpt_iter = ckpt_path.split('.')[0] | |
config = read_json(config_path, return_obj=True) | |
- ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path)) | |
+ ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path), map_location=self.device) | |
group_chunk = self.batch_utils.place_in_latent_dict[sub_group_name if sub_group_name is not 'expression_q' else 'expression'] | |
group_latent_size = group_chunk[1] - group_chunk[0] | |
_log.info('Init %s Controller...' % sub_group_name) | |
- controller = FcStack(config.model_config['lr_mlp'], config.model_config['n_mlp'], config.model_config['in_dim'], config.model_config['mid_dim'], group_latent_size).cuda() | |
+ controller = FcStack(config.model_config['lr_mlp'], config.model_config['n_mlp'], config.model_config['in_dim'], config.model_config['mid_dim'], group_latent_size).to(self.device) | |
controller.print() | |
_log.info('Loading Controller: %s, ckpt iter %s' % (controller_dir_path, ckpt_iter)) | |
diff --git a/src/gan_control/inference/inference.py b/src/gan_control/inference/inference.py | |
index e6ccedb..4393bb7 100644 | |
--- a/src/gan_control/inference/inference.py | |
+++ b/src/gan_control/inference/inference.py | |
class Inference(): | |
- def __init__(self, model_dir): | |
+ def __init__(self, model_dir, device): | |
_log.info('Init inference class...') | |
self.model_dir = model_dir | |
- self.model, self.batch_utils, self.config, self.ckpt_iter = self.retrieve_model(model_dir) | |
+ self.device = device | |
+ self.model, self.batch_utils, self.config, self.ckpt_iter = self.retrieve_model(model_dir, device) | |
self.noise = None | |
self.reset_noise() | |
self.mean_w_latent = None | |
_log.info('Calc mean_w_latents...') | |
mean_latent_w_list = [] | |
for i in range(100): | |
- latent_z = torch.randn(1000, self.config.model_config['latent_size'], device='cuda') | |
+ latent_z = torch.randn(1000, self.config.model_config['latent_size'], device=self.device) | |
if isinstance(self.model, torch.nn.DataParallel): | |
latent_w = self.model.module.style(latent_z).cpu() | |
else: | |
def reset_noise(self): | |
if isinstance(self.model, torch.nn.DataParallel): | |
- self.noise = self.model.module.make_noise(device='cuda') | |
+ self.noise = self.model.module.make_noise(device=self.device) | |
else: | |
- self.noise = self.model.make_noise(device='cuda') | |
+ self.noise = self.model.make_noise(device=self.device) | |
def expend_noise(noise, batch_size): | |
self.calc_mean_w_latents() | |
injection_noise = None | |
if latent is None: | |
- latent = torch.randn(batch_size, self.config.model_config['latent_size'], device='cuda') | |
+ latent = torch.randn(batch_size, self.config.model_config['latent_size'], device=self.device) | |
elif input_is_latent: | |
- latent = latent.cuda() | |
+ latent = latent.to(self.device) | |
for group_key in kwargs.keys(): | |
if group_key not in self.batch_utils.sub_group_names: | |
raise ValueError('group_key: %s not in sub_group_names %s' % (group_key, str(self.batch_utils.sub_group_names))) | |
if isinstance(kwargs[group_key], str) and kwargs[group_key] == 'random': | |
- group_latent_w = self.model.style(torch.randn(latent.shape[0], self.config.model_config['latent_size'], device='cuda')) | |
+ group_latent_w = self.model.style(torch.randn(latent.shape[0], self.config.model_config['latent_size'], device=self.device)) | |
group_latent_w = group_latent_w[:, self.batch_utils.place_in_latent_dict[group_key][0], self.batch_utils.place_in_latent_dict[group_key][0]] | |
latent[:, self.batch_utils.place_in_latent_dict[group_key][0], self.batch_utils.place_in_latent_dict[group_key][0]] = group_latent_w | |
if static_noise: | |
latent[:, place_in_latent[0]: place_in_latent[1]] = \ | |
truncation * (latent[:, place_in_latent[0]: place_in_latent[1]] - torch.cat( | |
[self.mean_w_latents[key].clone().unsqueeze(0) for _ in range(latent.shape[0])], dim=0 | |
- ).cuda()) + torch.cat( | |
+ ).to(self.device)) + torch.cat( | |
[self.mean_w_latents[key].clone().unsqueeze(0) for _ in range(latent.shape[0])], dim=0 | |
- ).cuda() | |
+ ).to(self.device) | |
- tensor, latent_w = self.model([latent.cuda()], return_latents=True, input_is_latent=input_is_latent, noise=injection_noise) | |
+ tensor, latent_w = self.model([latent.to(self.device)], return_latents=True, input_is_latent=input_is_latent, noise=injection_noise) | |
if normalize: | |
tensor = tensor.mul(0.5).add(0.5).clamp(min=0., max=1.).cpu() | |
return tensor, latent, latent_w | |
return grid_image | |
- def retrieve_model(model_dir): | |
+ def retrieve_model(model_dir, device): | |
config_path = os.path.join(model_dir, 'args.json') | |
_log.info('Retrieve config from %s' % config_path) | |
ckpt_path = ckpt_list[-1] | |
ckpt_iter = ckpt_path.split('.')[0] | |
config = read_json(config_path, return_obj=True) | |
- ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path)) | |
+ ckpt = torch.load(os.path.join(checkpoints_path, ckpt_path), map_location=device) | |
batch_utils = None | |
if not config.model_config['vanilla']: | |
fc_config=None if config.model_config['vanilla'] else batch_utils.get_fc_config(), | |
conv_transpose=config.model_config['conv_transpose'], | |
noise_mode=config.model_config['g_noise_mode'] | |
- ).cuda() | |
+ ).to(device) | |
_log.info('Loading Model: %s, ckpt iter %s' % (model_dir, ckpt_iter)) | |
model.load_state_dict(ckpt['g_ema']) | |
model = torch.nn.DataParallel(model) | |