File size: 2,842 Bytes
6477265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
""" 
Utility functions.
"""
import os
import numpy as np

def generate_tr_val_te_subject_ids(subject_list, val_subject_id):
    """Split subjects into training, validation and test groups for one CV fold.

    The validation subject is ``subject_list[val_subject_id]`` and the test
    subject is the entry immediately before it (``val_subject_id - 1``). For
    ``val_subject_id == 0`` Python's negative indexing wraps around, so the
    last subject becomes the test subject. All remaining subjects form the
    training set.

    Args:
        subject_list: Sequence of subject identifiers. It is NOT modified.
        val_subject_id: Index of the validation subject in ``subject_list``.

    Returns:
        Tuple ``(tr_subjects, val_subject, te_subject)``: ``tr_subjects`` is a
        new list; ``val_subject`` and ``te_subject`` are single elements taken
        from ``subject_list``.

    Raises:
        IndexError: If ``val_subject_id`` is out of range.
        ValueError: If only one subject is available, since validation and
            test subjects would then be the same entry.
    """
    # Work on a copy so the caller's list (e.g. config["subject_list"]) is
    # not shrunk as a side effect; repeated calls must see the full list.
    subjects = list(subject_list)
    val_subject = subjects[val_subject_id]
    te_subject = subjects[val_subject_id - 1]
    if len(subjects) < 2:
        raise ValueError(
            "need at least two subjects to pick distinct validation and test subjects")
    subjects.remove(val_subject)
    subjects.remove(te_subject)
    return subjects, val_subject, te_subject

def _dotless_entries(path):
    """Return the names in ``path`` that contain no '.', treated here as sub-folders."""
    return [name for name in os.listdir(path) if '.' not in name]


def generate_data_ids(data_dir, subject_list):
    """Collect paired (cluttered input, clean target) file paths for some subjects.

    Walks ``data_dir/<vendor>/<view>/<subject>/``. At each of those three
    levels, only entries whose names contain no '.' are descended into.

    For every subject in ``subject_list``:
      * Every file in the clutter folders ``data_NFClt``, ``data_NFRvbClt``
        and ``data_RvbClt`` that are present becomes an input id.
      * Each input id is paired with the same output id: the first non-``.DS*``
        entry that ``os.listdir`` returns for the subject's ``data_org`` folder.
    Entries containing ``.DS`` (e.g. macOS ``.DS_Store``) are skipped in both
    places.

    Args:
        data_dir: Root data directory.
        subject_list: Iterable of subject folder names to include. A single
            subject name given as a ``str`` is accepted and treated as a
            one-element list.

    Returns:
        ``(in_ids, out_ids)``: two lists of equal length, where ``out_ids[i]``
        is the target path for ``in_ids[i]``.

    Raises:
        OSError: If an expected directory (e.g. ``data_org``) is missing.
        IndexError: If a selected subject's ``data_org`` folder holds no
            usable entry.
    """
    if isinstance(subject_list, str):
        # A bare name would otherwise be matched by substring ("S1" in "S10").
        subject_list = [subject_list]
    wanted = set(subject_list)
    clutter_names = ('data_NFClt', 'data_NFRvbClt', 'data_RvbClt')

    in_ids, out_ids = [], []
    for vendor in _dotless_entries(data_dir):
        vendor_dir = os.path.join(data_dir, vendor)
        for view in _dotless_entries(vendor_dir):
            view_dir = os.path.join(vendor_dir, view)
            for subject in _dotless_entries(view_dir):
                if subject not in wanted:
                    continue
                subject_dir = os.path.join(view_dir, subject)
                org_data_dir = os.path.join(subject_dir, 'data_org')
                org_files = [name for name in os.listdir(org_data_dir) if '.DS' not in name]
                org_data_id = os.path.join(org_data_dir, org_files[0])
                for clutter in os.listdir(subject_dir):
                    if clutter not in clutter_names:
                        continue
                    clutter_dir = os.path.join(subject_dir, clutter)
                    clutter_ids = [os.path.join(clutter_dir, name)
                                   for name in os.listdir(clutter_dir) if '.DS' not in name]
                    in_ids += clutter_ids
                    # Size the targets from the FILTERED inputs so the pairs stay aligned.
                    out_ids += [org_data_id] * len(clutter_ids)
    return in_ids, out_ids

def id_preparation(config):
    """Build the input/output id lists for the configured cross-validation fold.

    The subject split is derived from ``config["subject_list"]`` and
    ``config["CV"]["val_subject_id"]``. Ids are collected under
    ``config["paths"]["data_path"]``.

    Args:
        config: Nested mapping. It must provide ``"subject_list"``,
            ``"CV" -> "val_subject_id"``, ``"tr_phase"`` and
            ``"paths" -> "data_path"``. ``config["subject_list"]`` is copied
            before splitting, so it is left unchanged.

    Returns:
        If ``config["tr_phase"]`` is truthy:
            ``(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject)``.
        Otherwise:
            ``(in_ids_te, out_ids_te, te_subject, val_subject)``. Note that the
            subject order is reversed relative to the training branch.
    """
    tr_subjects, val_subject, te_subject = generate_tr_val_te_subject_ids(
        # Pass a copy: the splitter may remove entries from the list it receives.
        subject_list=list(config["subject_list"]),
        val_subject_id=config["CV"]["val_subject_id"])
    data_path = config["paths"]["data_path"]
    if config["tr_phase"]:
        in_ids_tr, out_ids_tr = generate_data_ids(data_path, tr_subjects)
        # Wrap the single subject in a list: generate_data_ids tests membership
        # with `in`, which on a bare string is a substring match ("S1" in "S10").
        in_ids_val, out_ids_val = generate_data_ids(data_path, [val_subject])
        return in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject
    in_ids_te, out_ids_te = generate_data_ids(data_path, [te_subject])
    return in_ids_te, out_ids_te, te_subject, val_subject

def create_weight_dir(val_subject, te_subject, config):
    """Ensure the weight directory for a (validation, test) subject pair exists.

    The directory is ``<save_path>/Weights/ValTeIDs_<val_subject>_<te_subject>``,
    where ``save_path`` is ``config["paths"]["save_path"]``. Missing parent
    directories are created as needed.

    Args:
        val_subject: Validation subject identifier (formatted into the name).
        te_subject: Test subject identifier (formatted into the name).
        config: Mapping providing ``config["paths"]["save_path"]``.

    Returns:
        The path of the (now existing) weight directory.

    Raises:
        FileExistsError: If a non-directory already occupies that path.
    """
    weight_dir = os.path.join(config["paths"]["save_path"],
                              "Weights", f"ValTeIDs_{val_subject}_{te_subject}")
    # exist_ok replaces the separate exists() check, which could race with
    # another process creating the same directory in between.
    os.makedirs(weight_dir, exist_ok=True)
    return weight_dir