File size: 1,520 Bytes
2d7efb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import glob
import os
from configs import global_config, paths_config, hyperparameters
from scripts.latent_creators.sg2_plus_latent_creator import SG2PlusLatentCreator
from scripts.latent_creators.e4e_latent_creator import E4ELatentCreator
from scripts.run_pti import run_PTI
import pickle
import torch
from utils.models_utils import toogle_grad, load_old_G


class ExperimentRunner:
    """Drive a single experiment: optionally run PTI and build extra latents.

    Attributes set by the constructor:
        images_paths: result of globbing ``{paths_config.input_data_path}/*``.
        target_paths: a second, independent glob of the same pattern.
        run_id: identifier handed to (and replaced by the result of) ``run_PTI``.
        sampled_ws: initialised to ``None``; not used inside this class.
        old_G: object returned by ``load_old_G()``.
    """

    def __init__(self, run_id=''):
        """Collect the input paths and load the old generator.

        Args:
            run_id: initial run identifier; defaults to an empty string.
        """
        self.run_id = run_id
        self.sampled_ws = None

        # Both listings come from the same directory pattern, but each is
        # produced by its own glob call so they remain separate list objects.
        input_pattern = f'{paths_config.input_data_path}/*'
        self.images_paths = glob.glob(input_pattern)
        self.target_paths = glob.glob(input_pattern)

        generator = load_old_G()
        self.old_G = generator
        # NOTE(review): presumably switches gradient tracking off for the
        # loaded generator -- confirm against utils.models_utils.toogle_grad.
        toogle_grad(generator, False)

    def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False):
        """Execute the requested stages and return the current run id.

        Args:
            run_pt: when truthy, call ``run_PTI`` and store its return value
                as ``self.run_id``.
            create_other_latents: when truthy, instantiate
                ``SG2PlusLatentCreator`` and then ``E4ELatentCreator`` and call
                ``create_latents()`` on each, in that order.
            use_multi_id_training: forwarded to ``run_PTI``.
            use_wandb: forwarded to ``run_PTI`` and to both latent creators.

        Returns:
            ``self.run_id`` (the value produced by ``run_PTI`` if it ran,
            otherwise whatever was already stored).
        """
        if run_pt:
            updated_run_id = run_PTI(
                self.run_id,
                use_multi_id_training=use_multi_id_training,
                use_wandb=use_wandb,
            )
            self.run_id = updated_run_id

        if create_other_latents:
            sg2_plus_creator = SG2PlusLatentCreator(use_wandb=use_wandb)
            sg2_plus_creator.create_latents()

            e4e_creator = E4ELatentCreator(use_wandb=use_wandb)
            e4e_creator.create_latents()

        # Always executed, regardless of which stages ran above.
        torch.cuda.empty_cache()

        return self.run_id


if __name__ == '__main__':
    # Set the CUDA device-selection variables from the project configuration
    # before the runner is constructed.
    os.environ.update({
        'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
        'CUDA_VISIBLE_DEVICES': global_config.cuda_visible_devices,
    })

    # Default script behaviour: run PTI only, without the extra latent
    # creators and without multi-id training.
    runner = ExperimentRunner()
    runner.run_experiment(
        run_pt=True,
        create_other_latents=False,
        use_multi_id_training=False,
    )