File size: 5,516 Bytes
ab01e4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
# Change the numbers when you want to test with specific gpus
# os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'
import torch
from STTNet import STTNet
import torch.nn.functional as F
from Utils.Datasets import get_data_loader
from Utils.Utils import make_numpy_img, inv_normalize_img, encode_onehot_to_mask, get_metrics, Logger
import matplotlib.pyplot as plt
import numpy as np
from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys.

    ``torch.nn.DataParallel`` registers the wrapped network under the attribute
    ``module``, so checkpoints saved from a wrapped model carry that prefix on
    every key. Only a *leading* prefix is stripped; a ``module.`` that appears
    in the middle of a key (e.g. a sub-module that is itself named ``module``)
    is left untouched. The input mapping is not modified.

    Args:
        state_dict: mapping from parameter name (str) to value.

    Returns:
        OrderedDict with the same values and order, keyed by the cleaned names.
    """
    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def output_stem(img_name):
    """Return the file name of ``img_name`` without directories or its last extension.

    Unlike ``img_name.split('.')[0]`` this keeps inner dots, so ``tile.v2.tif``
    maps to ``tile.v2`` instead of colliding with every other ``tile.*`` image.
    """
    return os.path.splitext(os.path.basename(img_name))[0]


if __name__ == '__main__':
    # Evaluation entry point: restore a trained STTNet checkpoint, run it on the
    # validation and/or test split, save side-by-side visualisations and log
    # IoU / OA / F1 for each enabled split.
    model_infos = {
        # vgg16_bn, resnet50, resnet18
        'backbone': 'resnet50',
        'pretrained': True,
        'out_keys': ['block4'],
        'in_channel': 3,
        'n_classes': 2,
        'top_k_s': 64,
        'top_k_c': 16,
        'encoder_pos': True,
        'decoder_pos': True,
        'model_pattern': ['X', 'A', 'S', 'C'],

        'log_path': 'Results',
        'NUM_WORKERS': 0,
        # if you need the validation process.
        'IS_VAL': True,
        'VAL_BATCH_SIZE': 4,
        'VAL_DATASET': 'Tools/generate_dep_info/val_data.csv',
        # if you need the test process.
        'IS_TEST': True,
        'TEST_DATASET': 'Tools/generate_dep_info/test_data.csv',
        'IMG_SIZE': [512, 512],
        'PHASE': 'seg',

        # INRIA Dataset
        'PRIOR_MEAN': [0.40672500537632994, 0.42829032416229895, 0.39331840468605667],
        'PRIOR_STD': [0.029498464618176873, 0.027740088491668233, 0.028246722411879095],
        # # # WHU Dataset
        # 'PRIOR_MEAN': [0.4352682576428411, 0.44523221318154493, 0.41307610541534784],
        # 'PRIOR_STD': [0.026973196780331585, 0.026424642808887323, 0.02791246590291434],

        # load state dict path
        'load_checkpoint_path': r'E:\BuildingExtractionDataset\INRIA_ckpt_latest.pt',
    }

    # One output folder per enabled split; the visualisations are written there.
    if model_infos['IS_VAL']:
        os.makedirs(model_infos['log_path'] + '/val', exist_ok=True)
    if model_infos['IS_TEST']:
        os.makedirs(model_infos['log_path'] + '/test', exist_ok=True)
    logger = Logger(model_infos['log_path'] + '/log.log')

    data_loaders = get_data_loader(model_infos, test_mode=True)
    model = STTNet(**model_infos)

    logger.write(f'load checkpoint from {model_infos["load_checkpoint_path"]}\n')
    state_dict = torch.load(model_infos['load_checkpoint_path'], map_location='cpu')
    # The model is still unwrapped at this point, so keys saved from a
    # DataParallel model must lose their leading 'module.' before loading.
    # A key mismatch is raised by load_state_dict instead of being retried
    # with the very same dictionary.
    model.load_state_dict(strip_module_prefix(state_dict['model_state_dict']))

    model = model.cuda()
    device_ids = list(range(torch.cuda.device_count()))
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)
        logger.write(f'Use GPUs: {device_ids}\n')
    else:
        logger.write('Use GPUs: 1\n')

    # Inference only: switch to eval mode once, before any split is processed.
    model.eval()

    phase_switches = (('val', model_infos['IS_VAL']), ('test', model_infos['IS_TEST']))
    for pattern, is_pattern in phase_switches:
        if not is_pattern:
            continue

        # pred: logits, tensor, nBatch * nClass * W * H
        # target: labels, tensor, nBatch * nClass * W * H
        collect_result = {'pred': [], 'target': []}
        for batch_id, batch in enumerate(data_loaders[pattern]):
            # Only the images are needed on the GPU; labels are used on the
            # CPU exclusively (metrics and visualisation).
            img_batch = batch['img'].cuda()
            label_batch = batch['label'].detach().cpu()
            img_names = batch['img_name']
            collect_result['target'].append(label_batch)

            # inference
            with torch.no_grad():
                logits, _ = model(img_batch)
            collect_result['pred'].append(logits.detach().cpu())

            # Every test batch is saved; for validation only the batches with
            # batch_id % 5 == 1 are saved.
            if pattern != 'test' and batch_id % 5 != 1:
                continue

            # Class index per pixel, scaled so that class 1 becomes 255.
            pred_label = torch.argmax(logits, 1) * 255

            # output the segmentation result
            for img_id in range(pred_label.size(0)):
                img = make_numpy_img(
                    inv_normalize_img(
                        img_batch[img_id].detach().cpu(),
                        model_infos['PRIOR_MEAN'],
                        model_infos['PRIOR_STD'],
                    )
                )
                target = make_numpy_img(encode_onehot_to_mask(label_batch[img_id])) * 255
                pred = make_numpy_img(pred_label[img_id].detach().cpu())

                # Stack image, ground truth and prediction along axis 0 and
                # bring everything into [0, 1] for plt.imsave.
                vis = np.concatenate([img / 255., target / 255., pred / 255.], axis=0)
                vis = np.clip(vis, a_min=0, a_max=1)
                file_name = os.path.join(
                    model_infos['log_path'], pattern, f'{output_stem(img_names[img_id])}.png')
                plt.imsave(file_name, vis)

        # torch.cat raises on an empty list, so skip the metrics when the
        # loader produced no batch at all.
        if not collect_result['pred']:
            logger.write(f'{pattern}: no batches found, metrics skipped\n')
            continue

        collect_result['pred'] = torch.cat(collect_result['pred'], dim=0)
        collect_result['target'] = torch.cat(collect_result['target'], dim=0)
        IoU, OA, F1_score = get_metrics('seg', **collect_result)
        logger.write(f'{pattern}: Iou:{IoU[-1]:.4f} OA:{OA[-1]:.4f} F1:{F1_score[-1]:.4f}\n')