# Modified from:
#   https://github.com/anibali/pytorch-stacked-hourglass 
#   https://github.com/bearpaw/pytorch-pose

import torch
from stacked_hourglass.utils.evaluation import final_preds_untransformed
from stacked_hourglass.utils.imfit import fit, calculate_fit_contain_output_area
from stacked_hourglass.utils.transforms import color_normalize, fliplr, flip_back


def _check_batched(images):
    if isinstance(images, (tuple, list)):
        return True
    if images.ndimension() == 4:
        return True
    return False


class HumanPosePredictor:
    def __init__(self, model, device=None, data_info=None, input_shape=None):
        """Helper class for predicting 2D human pose joint locations.

        Args:
            model: The model for generating joint heatmaps.
            device: The computational device to use for inference (defaults to CUDA if
                    available, otherwise CPU).
            data_info: Specifications of the data (e.g. ``Mpii.DATA_INFO``). Must be
                       provided; a ``ValueError`` is raised otherwise.
            input_shape: The input dimensions of the model (height, width).
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        device = torch.device(device)
        model.to(device)
        self.model = model
        self.device = device

        if data_info is None:
            # Unlike the upstream implementations, no default (e.g. Mpii.DATA_INFO) is assumed here.
            raise ValueError('data_info must be provided')
        else:
            self.data_info = data_info

        # Input shape ordering: H, W
        if input_shape is None:
            self.input_shape = (256, 256)
        elif isinstance(input_shape, int):
            self.input_shape = (input_shape, input_shape)
        else:
            self.input_shape = input_shape

    def do_forward(self, input_tensor):
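        """Run the model on a prepared input batch in eval mode with gradients disabled."""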
        self.model.eval()
        with torch.no_grad():
            output = self.model(input_tensor)
        return output

    def prepare_image(self, image):
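        """Convert an image to a colour-normalized float tensor matching the model's input shape."""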
        was_fixed_point = not image.is_floating_point()
        image = torch.empty_like(image, dtype=torch.float32).copy_(image)
        if was_fixed_point:
            image /= 255.0
        if image.shape[-2:] != self.input_shape:
            image = fit(image, self.input_shape, fit_mode='contain')
        image = color_normalize(image, self.data_info.rgb_mean, self.data_info.rgb_stddev)
        return image

    def estimate_heatmaps(self, images, flip=False):
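        """Compute joint heatmaps for a single image or a batch of images.

        If ``flip`` is True, heatmaps predicted from horizontally flipped inputs are
        averaged in as well.
        """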
        is_batched = _check_batched(images)
        raw_images = images if is_batched else images.unsqueeze(0)
        input_tensor = torch.empty((len(raw_images), 3, *self.input_shape),
                                   device=self.device, dtype=torch.float32)
        for i, raw_image in enumerate(raw_images):
            input_tensor[i] = self.prepare_image(raw_image)
        heatmaps = self.do_forward(input_tensor)[-1].cpu()
        if flip:
            flip_input = fliplr(input_tensor)
            flip_heatmaps = self.do_forward(flip_input)[-1].cpu()
            heatmaps += flip_back(flip_heatmaps, self.data_info.hflip_indices)
            heatmaps /= 2
        if is_batched:
            return heatmaps
        else:
            return heatmaps[0]

    def estimate_joints(self, images, flip=False):
        """Estimate human joint locations from input images.

        Images are expected to be centred on a human subject and scaled reasonably.

        Args:
            images: The images to estimate joint locations for. Can be a single image or a list
                    of images.
            flip (bool): If set to true, evaluates on flipped versions of the images as well and
                         averages the results.

        Returns:
            The predicted human joint locations in image pixel space.
        """
        is_batched = _check_batched(images)
        raw_images = images if is_batched else images.unsqueeze(0)
        heatmaps = self.estimate_heatmaps(raw_images, flip=flip).cpu()
        # final_preds_untransformed expects the resolution argument as (width, height), whereas the
        # heatmaps are stored as (height, width), hence the reversed shape below.
        coords = final_preds_untransformed(heatmaps, heatmaps.shape[-2:][::-1])
        # Rescale coords to pixel space of specified images.
        for i, image in enumerate(raw_images):
            # When mapping back to the original image space we need to compensate for having
            # used fit_mode='contain' when preparing the images for inference.
            y_off, x_off, height, width = calculate_fit_contain_output_area(*image.shape[-2:], *self.input_shape)
            coords[i, :, 1] *= self.input_shape[-2] / heatmaps.shape[-2]
            coords[i, :, 1] -= y_off
            coords[i, :, 1] *= image.shape[-2] / height
            coords[i, :, 0] *= self.input_shape[-1] / heatmaps.shape[-1]
            coords[i, :, 0] -= x_off
            coords[i, :, 0] *= image.shape[-1] / width
        if is_batched:
            return coords
        else:
            return coords[0]
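

if __name__ == '__main__':
    # Minimal usage sketch. The ``hg2`` constructor and ``Mpii.DATA_INFO`` come from the
    # upstream repositories referenced at the top of this file and may live under different
    # module paths in this fork; 'example.jpg' is a placeholder input image.
    from PIL import Image
    from torchvision import transforms
    from stacked_hourglass import hg2
    from stacked_hourglass.datasets.mpii import Mpii

    model = hg2(pretrained=True)
    predictor = HumanPosePredictor(model, device='cpu', data_info=Mpii.DATA_INFO)
    # ToTensor yields a CxHxW float tensor in [0, 1]; prepare_image handles resizing and
    # colour normalization internally.
    image = transforms.ToTensor()(Image.open('example.jpg'))
    joints = predictor.estimate_joints(image, flip=True)  # (num_joints, 2) tensor of (x, y) pixel coords
    print(joints)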