File size: 2,820 Bytes
3060b7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53


from test_fatezero import *
from glob import glob
import copy

@click.command()
@click.option("--edit_config", type=str,    default="config/supp/style/0313_style_edit_warp_640.yaml")
@click.option("--dataset_config", type=str, default="data/supp_edit_dataset/dataset_prompt.yaml")
def run(edit_config, dataset_config):
    """Run every p2p configuration of an edit config on every dataset sample.

    For each sample in the dataset config (visited in sorted key order) and
    each entry of ``validation_sample_logger_config.p2p_config`` in the edit
    config, a fresh copy of the edit config is specialised for that pair and
    passed to ``test``.

    Args:
        edit_config: Path to the edit configuration YAML file.
        dataset_config: Path to the dataset YAML file. Each top-level key is
            one data sample providing ``prompt`` and ``target`` and,
            optionally, ``eq_params``.
    """
    base_edit_cfg = OmegaConf.load(edit_config)
    dataset_cfg = OmegaConf.load(dataset_config)

    # Go through all data samples.
    data_sample_list = sorted(dataset_cfg.keys())
    print(f'Datasample to evaluate: {data_sample_list}')
    dataset_time_string = get_time_string()

    # Loop-invariant pieces of the output directory name.
    logdir_prefix = edit_config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')
    # splitext drops the extension whatever its length (.yaml or .yml),
    # unlike a fixed [:-5] slice.
    dataset_name = os.path.splitext(os.path.basename(dataset_config))[0]

    for data_sample in data_sample_list:
        print(f'Evaluate {data_sample}')
        sample_cfg = dataset_cfg[data_sample]

        for p2p_config_index, p2p_config in base_edit_cfg['validation_sample_logger_config']['p2p_config'].items():
            edit_config_now = copy.deepcopy(base_edit_cfg)

            # The sample entry becomes the train dataset, minus the keys that
            # are only consumed here (edit targets and per-sample eq_params).
            edit_config_now['train_dataset'] = copy.deepcopy(sample_cfg)
            edit_config_now['train_dataset'].pop('target')
            if 'eq_params' in edit_config_now['train_dataset']:
                edit_config_now['train_dataset'].pop('eq_params')

            # Source prompt first, followed by every target prompt.
            prompts = [sample_cfg['prompt']] + OmegaConf.to_object(sample_cfg['target'])
            edit_config_now['validation_sample_logger_config']['prompts'] = prompts

            # One p2p config per prompt. Each slot gets its own deep copy so
            # that writing eq_params never mutates the shared base config;
            # otherwise eq_params of one sample would leak into later samples
            # that do not define any.
            p2p_config_now = {}
            for prompt_index in range(len(prompts)):
                per_prompt_cfg = copy.deepcopy(p2p_config)
                if 'eq_params' in sample_cfg:
                    per_prompt_cfg['eq_params'] = copy.deepcopy(sample_cfg['eq_params'])
                p2p_config_now[prompt_index] = per_prompt_cfg

            edit_config_now['validation_sample_logger_config']['p2p_config'] = p2p_config_now
            edit_config_now['validation_sample_logger_config']['source_prompt'] = sample_cfg['prompt']

            logdir = (
                f'{logdir_prefix}_config_{p2p_config_index}'
                f'_{dataset_name}_{dataset_time_string}/{data_sample}'
            )
            edit_config_now['logdir'] = logdir
            print(f'Saving at {logdir}')

            test(config=edit_config, **edit_config_now)


# Script entry point: when executed directly (not imported), invoke the
# click command, which reads --edit_config / --dataset_config from the CLI.
if __name__ == "__main__":
    run()