File size: 1,662 Bytes
7629b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Modified from:
#   https://github.com/anibali/pytorch-stacked-hourglass 
#   https://github.com/bearpaw/pytorch-pose

import torch
from torch.nn import Conv2d, ModuleList


def _blank_conv_like(conv, in_channels, out_channels):
    """Build a fresh Conv2d mirroring ``conv``'s settings with new channel counts.

    Kernel size, stride, padding, dilation, padding mode, and the presence (or
    absence) of a bias term are taken from ``conv``. The returned layer lives on
    the same device and has the same dtype as ``conv.weight``; its parameters
    hold Conv2d's default initialisation.
    """
    fresh = Conv2d(
        in_channels,
        out_channels,
        conv.kernel_size,
        conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
    )
    return fresh.to(conv.weight.device, conv.weight.dtype)


def change_hg_outputs(model, indices):
    """Change the output classes of the model.

    Every convolution in ``model.score`` is replaced by one producing
    ``len(indices)`` output channels, and every convolution in ``model.score_``
    by one consuming ``len(indices)`` input channels. The model is modified in
    place; nothing is returned.

    Args:
        model: The model to modify. It must expose ``score`` and ``score_``
               attributes, each an iterable of ``Conv2d`` layers.
        indices: An array of indices describing the new model outputs. For example, [3, 4, None]
                 will modify the model to have 3 outputs, the first two of which have parameters
                 copied from the fourth and fifth outputs of the original model. Positions
                 given as ``None`` keep the freshly initialised parameters of the new layer.
    """
    new_n_outputs = len(indices)
    # (position in the new layer, position in the original layer) for every
    # output that should inherit trained parameters.
    copy_pairs = [
        (new_i, old_i) for new_i, old_i in enumerate(indices) if old_i is not None
    ]

    with torch.no_grad():
        new_score = ModuleList()
        for conv in model.score:
            new_conv = _blank_conv_like(conv, conv.in_channels, new_n_outputs)
            for new_i, old_i in copy_pairs:
                # Output classes are the leading (out_channels) axis here.
                new_conv.weight[new_i] = conv.weight[old_i]
                if conv.bias is not None:
                    new_conv.bias[new_i] = conv.bias[old_i]
            new_score.append(new_conv)

        new_score_ = ModuleList()
        for conv in model.score_:
            new_conv = _blank_conv_like(conv, new_n_outputs, conv.out_channels)
            for new_i, old_i in copy_pairs:
                # Output classes are the second (in_channels) axis here.
                new_conv.weight[:, new_i] = conv.weight[:, old_i]
            if conv.bias is not None:
                # Copy values instead of re-using the old Parameter object so the
                # new layer does not alias the layer it replaces.
                new_conv.bias.copy_(conv.bias)
            new_score_.append(new_conv)

        # Swap both lists in only after everything was built successfully, so a
        # failure above cannot leave the model half-modified.
        model.score = new_score
        model.score_ = new_score_