File size: 3,346 Bytes
8f09a1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import copy
import json
import os

import pkg_resources
import torch

import spiga.data.loaders.dl_config as dl_cfg
import spiga.data.loaders.dataloader as dl
import spiga.inference.pretreatment as pretreat
from spiga.inference.framework import SPIGAFramework
from spiga.inference.config import ModelConfig


def main():
    """Command-line entry point for the results generator.

    Parses the CLI arguments, builds a ``SPIGAFramework`` from the
    ``ModelConfig`` of the selected database, and runs
    ``Tester.generate_results()`` with gradient tracking disabled.
    """
    import argparse

    supported_databases = ['wflw', '300wpublic', '300wprivate', "merlrav", "cofw68"]

    parser = argparse.ArgumentParser(description='Experiment results generator')
    parser.add_argument('database', type=str, choices=supported_databases,
                        help='Database name')
    parser.add_argument('-a', '--anns', type=str, default='test',
                        help='Annotations type: test, valid or train')
    parser.add_argument('--gpus', type=int, default=0, help='GPU Id')
    cli = parser.parse_args()

    # Model config -> framework -> tester, in that order.
    framework = SPIGAFramework(ModelConfig(cli.database), gpus=[cli.gpus])
    tester = Tester(framework, cli.database, anns_type=cli.anns)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        tester.generate_results()


class Tester:
    """Run a SPIGA model framework over a dataset split and save predictions as JSON.

    One record per dataloader batch is produced, shaped like ``data_struc``,
    and the list of records is written to ``file_out``.
    """

    def __init__(self, model_framework, database, anns_type='test'):
        """Build the evaluation dataloader and compute the output file path.

        Args:
            model_framework: Model wrapper (e.g. ``SPIGAFramework``). Its
                ``model_cfg`` supplies ``target_dist``, ``image_size`` and
                ``ftmap_size``. It must provide ``select_inputs``,
                ``net_forward`` and ``postreatment``.
            database: Dataset name passed to ``dl_cfg.AlignConfig``.
            anns_type: Annotation split, passed as ``mode`` to ``AlignConfig``.
        """

        # Parameters
        self.anns_type = anns_type
        self.database = database

        # Model initialization
        self.model_framework = model_framework

        # Dataloader. The empty aug_names list presumably disables data
        # augmentation, and shuffle=False keeps a fixed sample order. The
        # geometry settings are copied from the model config so the loader
        # output matches what the model was configured with.
        self.dl_eval = dl_cfg.AlignConfig(self.database, mode=self.anns_type)
        self.dl_eval.aug_names = []
        self.dl_eval.shuffle = False
        self.dl_eval.target_dist = self.model_framework.model_cfg.target_dist
        self.dl_eval.image_size = self.model_framework.model_cfg.image_size
        self.dl_eval.ftmap_size = self.model_framework.model_cfg.ftmap_size

        # generate_results() reads only item [0] of every batch, so the batch
        # size must remain 1 or samples would be silently dropped.
        self.batch_size = 1
        # NOTE(review): what debug=True does is defined inside dl.get_dataloader;
        # confirm it is intended for evaluation runs.
        self.test_data, _ = dl.get_dataloader(self.batch_size, self.dl_eval,
                                              pretreat=pretreat.NormalizeAndPermute(), debug=True)

        # Results: template for one output record. Every field is filled per sample.
        self.data_struc = {'imgpath': None, 'bbox': None, 'headpose': None,
                           'ids': None, 'landmarks': None, 'visible': None}
        self.result_path = pkg_resources.resource_filename('spiga', 'eval/results')
        self.result_file = '/results_%s_%s.json' % (self.database, self.anns_type)
        self.file_out = self.result_path + self.result_file

    def generate_results(self):
        """Run inference over ``self.test_data`` and write the records to ``self.file_out``.

        Each record holds the following, taken from the first (only) sample
        of each batch:
        - the local image path,
        - the raw bounding box,
        - the visibility flags,
        - the dataset landmark ids,
        - the landmarks and head pose returned by
          ``model_framework.postreatment``.

        The parent directory of ``file_out`` is created if it does not exist.
        Gradient tracking is not disabled here; ``main`` wraps this call in
        ``torch.no_grad()``.

        Returns:
            list[dict]: The records that were written to the JSON file.
        """
        data = []
        for step, batch in enumerate(self.test_data):
            print('Step: ', step)
            inputs = self.model_framework.select_inputs(batch)
            outputs_raw = self.model_framework.net_forward(inputs)
            # Postprocessing receives both the loader's bbox and the raw bbox.
            outputs = self.model_framework.postreatment(outputs_raw, batch['bbox'], batch['bbox_raw'])

            # Fill a fresh copy of the template; index 0 because batch_size == 1.
            data_dict = copy.deepcopy(self.data_struc)
            data_dict['imgpath'] = batch['imgpath_local'][0]
            data_dict['bbox'] = batch['bbox_raw'][0].numpy().tolist()
            data_dict['visible'] = batch['visible'][0].numpy().tolist()
            data_dict['ids'] = self.dl_eval.database.ldm_ids
            data_dict['landmarks'] = outputs['landmarks'][0]
            data_dict['headpose'] = outputs['headpose'][0]
            data.append(data_dict)

        # Save outputs. The results directory may not exist on disk; create it
        # instead of failing with FileNotFoundError after the whole inference
        # loop has already run.
        out_dir = os.path.dirname(self.file_out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(self.file_out, 'w') as outfile:
            json.dump(data, outfile)
        return data


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()