Nadine Rueegg
initial commit for barc
7629b39
raw history blame
No virus
1.66 kB
# Modified from:
# https://github.com/anibali/pytorch-stacked-hourglass
# https://github.com/bearpaw/pytorch-pose
import torch
from torch.nn import Conv2d, ModuleList
def _resized_conv(conv, in_channels, out_channels):
    """Build a fresh Conv2d with new channel counts but the layout of ``conv``.

    Kernel size, stride, padding, dilation, padding mode and the presence of a
    bias term are taken from ``conv``; the result is moved to the device and
    dtype of ``conv.weight``. Parameters are freshly initialised.
    """
    new_conv = Conv2d(
        in_channels,
        out_channels,
        conv.kernel_size,
        conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
    )
    return new_conv.to(conv.weight.device, conv.weight.dtype)


def change_hg_outputs(model, indices):
    """Change the output classes of the model.

    Every convolution in ``model.score`` is replaced by one producing
    ``len(indices)`` output channels, and every convolution in ``model.score_``
    is replaced by one consuming ``len(indices)`` input channels. Both
    attributes are reassigned to new ``ModuleList`` objects (the model is
    modified in place; nothing is returned).

    Args:
        model: The model to modify. It must expose ``score`` and ``score_`` as
            iterables of ``Conv2d`` layers.
        indices: An array of indices describing the new model outputs. For example, [3, 4, None]
            will modify the model to have 3 outputs, the first two of which have parameters
            copied from the fourth and fifth outputs of the original model. Outputs whose
            entry is ``None`` keep the fresh default initialisation.
    """
    new_n_outputs = len(indices)
    # (new channel, old channel) pairs for the outputs that inherit parameters.
    copied = [(new, old) for new, old in enumerate(indices) if old is not None]

    # The in-place writes into freshly created parameters must not be tracked
    # by autograd.
    with torch.no_grad():
        new_score = ModuleList()
        for conv in model.score:
            new_conv = _resized_conv(conv, conv.in_channels, new_n_outputs)
            for new, old in copied:
                new_conv.weight[new] = conv.weight[old]
                # Heads built with bias=False have no bias to carry over.
                if conv.bias is not None:
                    new_conv.bias[new] = conv.bias[old]
            new_score.append(new_conv)
        model.score = new_score

        new_score_ = ModuleList()
        for conv in model.score_:
            new_conv = _resized_conv(conv, new_n_outputs, conv.out_channels)
            for new, old in copied:
                # Here the class channels are the *input* axis of the weight.
                new_conv.weight[:, new] = conv.weight[:, old]
            # The output channels are unchanged, so the existing bias
            # Parameter object is reused as-is (shared, not copied).
            new_conv.bias = conv.bias
            new_score_.append(new_conv)
        model.score_ = new_score_