# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple

import torch
import torch.nn as nn
from mmengine.model import BaseModule, ModuleDict

from mmpretrain.registry import MODELS
from mmpretrain.structures import MultiTaskDataSample


def loss_convertor(loss_func, task_name):
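    """Wrap the ``loss`` method of a sub head for multi-task training.

    The wrapped function picks out the data samples annotated with
    ``task_name``, masks the inputs accordingly, and records the number of
    valid samples in ``mask_size``.
    """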

    def wrapped(inputs, data_samples, **kwargs):
        mask = torch.empty(len(data_samples), dtype=torch.bool)
        task_data_samples = []
        for i, data_sample in enumerate(data_samples):
            assert isinstance(data_sample, MultiTaskDataSample)
            sample_mask = task_name in data_sample
            mask[i] = sample_mask
            if sample_mask:
                task_data_samples.append(data_sample.get(task_name))

        if len(task_data_samples) == 0:
            # Return a zero loss that is connected to the computation graph,
            # so that loss.backward() still works when no sample in the
            # batch has annotations for this task.
            loss = (inputs[0] * 0).sum()
            return {'loss': loss, 'mask_size': torch.tensor(0.)}

        # Mask the inputs of the task
        def mask_inputs(inputs, mask):
            if isinstance(inputs, Sequence):
                return type(inputs)(
                    [mask_inputs(input, mask) for input in inputs])
            elif isinstance(inputs, torch.Tensor):
                return inputs[mask]

        masked_inputs = mask_inputs(inputs, mask)
        loss_output = loss_func(masked_inputs, task_data_samples, **kwargs)
        loss_output['mask_size'] = mask.sum().to(torch.float)
        return loss_output

    return wrapped


@MODELS.register_module()
class MultiTaskHead(BaseModule):
    """Multi task head.

    Args:
        task_heads (dict): Sub heads to use, the key will be use to rename the
            loss components.
        common_cfg (dict): The common settings for all heads. Defaults to an
            empty dict.
        init_cfg (dict, optional): The extra initialization settings.
            Defaults to None.
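
    Examples:
        A minimal construction sketch; the ``LinearClsHead`` configs and the
        shared ``in_channels`` below are illustrative values, not defaults.

        >>> from mmpretrain.models import MultiTaskHead
        >>> head = MultiTaskHead(
        ...     task_heads=dict(
        ...         task0=dict(type='LinearClsHead', num_classes=3),
        ...         task1=dict(type='LinearClsHead', num_classes=6),
        ...     ),
        ...     in_channels=128,
        ... )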
    """

    def __init__(self, task_heads, init_cfg=None, **kwargs):
        super(MultiTaskHead, self).__init__(init_cfg=init_cfg)

        assert isinstance(task_heads, dict), \
            'The `task_heads` argument should be a dict, whose keys are ' \
            'task names and values are configs of the head for each task.'

        self.task_heads = ModuleDict()

        for task_name, sub_head in task_heads.items():
            if not isinstance(sub_head, nn.Module):
                sub_head = MODELS.build(sub_head, default_args=kwargs)
            sub_head.loss = loss_convertor(sub_head.loss, task_name)
            self.task_heads[task_name] = sub_head

    def forward(self, feats):
        """The forward process."""
        return {
            task_name: head(feats)
            for task_name, head in self.task_heads.items()
        }

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[MultiTaskDataSample], **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
            data_samples (List[MultiTaskDataSample]): The annotation data of
                every sample.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: A dictionary of loss components. Each key is
                prefixed by its task name, e.g. ``task1_loss``.
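
        Examples:
            A minimal sketch, assuming ``head`` was built as in the class
            docstring example and each sample has annotations for both tasks.

            >>> import torch
            >>> from mmpretrain.structures import (DataSample,
            ...                                    MultiTaskDataSample)
            >>> feats = (torch.rand(2, 128), )
            >>> samples = []
            >>> for y0, y1 in [(0, 1), (2, 5)]:
            ...     sample = MultiTaskDataSample()
            ...     sample.set_field(DataSample().set_gt_label(y0), 'task0')
            ...     sample.set_field(DataSample().set_gt_label(y1), 'task1')
            ...     samples.append(sample)
            >>> losses = head.loss(feats, samples)
            >>> # expected keys: task0_loss, task0_mask_size, task1_loss, ...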
        """
        losses = dict()
        for task_name, head in self.task_heads.items():
            head_loss = head.loss(feats, data_samples, **kwargs)
            for k, v in head_loss.items():
                losses[f'{task_name}_{k}'] = v
        return losses

    def predict(
        self,
        feats: Tuple[torch.Tensor],
        data_samples: List[MultiTaskDataSample] = None
    ) -> List[MultiTaskDataSample]:
        """Inference without augmentation.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
            data_samples (List[MultiTaskDataSample], optional): The annotation
                data of every sample. If not None, set ``pred_label`` of
                the input data samples. Defaults to None.

        Returns:
            List[MultiTaskDataSample]: A list of data samples which contains
            the predicted results.
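
        Examples:
            A minimal sketch, reusing ``head`` and ``feats`` from the
            ``loss`` example above.

            >>> results = head.predict(feats)  # List[MultiTaskDataSample]
            >>> task0_pred = results[0].get('task0').pred_label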
        """
        predictions_dict = dict()

        for task_name, head in self.task_heads.items():
            task_samples = head.predict(feats)
            batch_size = len(task_samples)
            predictions_dict[task_name] = task_samples

        if data_samples is None:
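            # No data samples were given, so create empty containers to
            # hold the predictions of every task.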
            data_samples = [MultiTaskDataSample() for _ in range(batch_size)]

        for task_name, task_samples in predictions_dict.items():
            for data_sample, task_sample in zip(data_samples, task_samples):
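                # Record whether this sample has annotations for the task,
                # so that evaluation metrics can skip unlabeled samples.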
                task_sample.set_field(
                    task_name in data_sample.tasks,
                    'eval_mask',
                    field_type='metainfo')

                if task_name in data_sample.tasks:
                    data_sample.get(task_name).update(task_sample)
                else:
                    data_sample.set_field(task_sample, task_name)

        return data_samples