File size: 2,350 Bytes
b762e56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86e64e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from dataloader import get_loader
from models.model_main import ModelMain
from models.transformers import denumericalize
from options import get_parser_main_model
from data_utils.svg_utils import render
from models.util_funcs import svg2img, cal_iou

# Testing: evaluation only -- reports average losses on the test split; no outputs are saved.

def test_main_model(opts):
    """Evaluate a trained ``ModelMain`` checkpoint on the test split and report mean losses.

    Builds the ``'test'`` loader from ``opts``, loads the ``'model'`` entry of the
    checkpoint at ``opts.model_path`` into a fresh ``ModelMain``, runs the model with
    ``mode='val'`` on every batch without gradients, and prints the per-batch mean of
    the image and SVG losses.

    Args:
        opts: Parsed options from ``get_parser_main_model()``. This function reads
            ``data_root``, ``img_size``, ``language``, ``char_num``, ``max_seq_len``,
            ``dim_seq``, ``batch_size`` and ``model_path``; ``ModelMain(opts)`` may
            read more.

    Returns:
        dict: ``{'img': {'l1', 'vggpt'}, 'svg': {'total', 'cmd', 'args', 'aux'}}``
        mapping each loss name to its mean over test batches, as Python floats.

    Raises:
        ValueError: If the test loader yields no batches.
    """
    test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num,
                             opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')

    model_main = ModelMain(opts)
    # Load onto the CPU first so a checkpoint saved from another GPU still loads;
    # the weights are moved to the GPU right after load_state_dict.
    checkpoint = torch.load(opts.model_path, map_location='cpu')
    model_main.load_state_dict(checkpoint['model'])
    model_main.cuda()
    model_main.eval()  # Testing mode

    # Running sums of each loss term read from the model's loss dict, per category.
    loss_val = {
        'img': {'l1': 0.0, 'vggpt': 0.0},
        'svg': {'total': 0.0, 'cmd': 0.0, 'args': 0.0, 'aux': 0.0},
    }

    num_batches = 0
    with torch.no_grad():
        for val_data in test_loader:
            for key in val_data:
                val_data[key] = val_data[key].cuda()
            _, loss_dict_val = model_main(val_data, mode='val')
            for loss_cat, losses in loss_val.items():
                for key in losses:
                    # float() turns each loss (presumably a 0-dim tensor, possibly on
                    # the GPU -- TODO confirm) into a plain Python number.
                    losses[key] += float(loss_dict_val[loss_cat][key])
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test loader yielded no batches; cannot average losses")

    # Mean over batches (not samples): a smaller final batch is weighted equally.
    for losses in loss_val.values():
        for key in losses:
            losses[key] /= num_batches

    val_msg = ", ".join([
        f"Val loss img l1: {loss_val['img']['l1']: .6f}",
        f"Val loss img pt: {loss_val['img']['vggpt']: .6f}",
        f"Val loss total: {loss_val['svg']['total']: .6f}",
        f"Val loss cmd: {loss_val['svg']['cmd']: .6f}",
        f"Val loss args: {loss_val['svg']['args']: .6f}",
    ])

    print(val_msg)
    print(f"l1: {loss_val['img']['l1']: .6f}, pt: {loss_val['img']['vggpt']: .6f}")
    return loss_val

def main():
    """Parse command-line options and run the test-split evaluation."""
    opts = get_parser_main_model().parse_args()
    # The experiment name is suffixed with the model name, i.e. "<name_exp>_<model_name>".
    # opts is mutated in place, so ModelMain(opts) also sees the suffixed name.
    opts.name_exp = opts.name_exp + '_' + opts.model_name
    print(f"Testing on experiment {opts.name_exp}...")
    test_main_model(opts)


if __name__ == "__main__":
    main()