# PTI/evaluation/experiment_setting_creator.py
# Author: ucalyptus — commit "simp" (revision 2d7efb8)
import glob
import os
from configs import global_config, paths_config, hyperparameters
from scripts.latent_creators.sg2_plus_latent_creator import SG2PlusLatentCreator
from scripts.latent_creators.e4e_latent_creator import E4ELatentCreator
from scripts.run_pti import run_PTI
import pickle
import torch
from utils.models_utils import toogle_grad, load_old_G
class ExperimentRunner:
    """Run one evaluation experiment over the images in the input data directory.

    On construction the runner lists the entries of
    ``paths_config.input_data_path`` and loads the original generator via
    ``load_old_G``. ``run_experiment`` then optionally runs PTI
    (``scripts.run_pti.run_PTI``) and/or the SG2-plus and e4e latent creators.
    """

    def __init__(self, run_id=''):
        """Collect the input paths and load the original generator.

        Args:
            run_id: Identifier of the run. ``run_experiment`` replaces it with
                the id returned by ``run_PTI`` when called with ``run_pt=True``.
        """
        # Sort so the image order is deterministic: raw glob order is
        # arbitrary and depends on the filesystem.
        self.images_paths = self._list_input_paths(paths_config.input_data_path)
        # Targets come from the same directory as the inputs. Use a separate
        # list so that mutating one list never silently changes the other.
        self.target_paths = list(self.images_paths)
        self.run_id = run_id
        self.sampled_ws = None  # not assigned anywhere else in this class
        self.old_G = load_old_G()
        # NOTE(review): presumably disables gradient tracking on old_G
        # (toogle_grad(..., False)); confirm in utils.models_utils.
        toogle_grad(self.old_G, False)

    @staticmethod
    def _list_input_paths(directory):
        """Return the non-hidden entries directly inside *directory*, sorted.

        ``*`` in glob does not match names that start with '.', and it matches
        both files and sub-directories.
        """
        return sorted(glob.glob(f'{directory}/*'))

    def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False):
        """Run the selected experiment stages and return the current run id.

        Args:
            run_pt: If true, call ``run_PTI`` and store the run id it returns.
            create_other_latents: If true, run the SG2-plus latent creator and
                then the e4e latent creator.
            use_multi_id_training: Forwarded to ``run_PTI``.
            use_wandb: Forwarded to ``run_PTI`` and to both latent creators.

        Returns:
            ``self.run_id``. This is the value from ``run_PTI`` when
            ``run_pt`` is true; otherwise it is the id given to the constructor.
        """
        if run_pt:
            self.run_id = run_PTI(self.run_id, use_wandb=use_wandb,
                                  use_multi_id_training=use_multi_id_training)
        if create_other_latents:
            sg2_plus_latent_creator = SG2PlusLatentCreator(use_wandb=use_wandb)
            sg2_plus_latent_creator.create_latents()
            e4e_latent_creator = E4ELatentCreator(use_wandb=use_wandb)
            e4e_latent_creator.create_latents()
        # Release cached, unused GPU memory held by PyTorch's caching
        # allocator. This is a no-op when CUDA was never initialised.
        torch.cuda.empty_cache()
        return self.run_id
if __name__ == '__main__':
    # Number GPUs by PCI bus order rather than CUDA's default "fastest first".
    # The ids in CUDA_VISIBLE_DEVICES then refer to a stable physical order.
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices

    # Default experiment: run PTI only, with single-id training and no
    # extra latent creators.
    ExperimentRunner().run_experiment(
        run_pt=True,
        create_other_latents=False,
        use_multi_id_training=False,
    )