File size: 4,729 Bytes
995430d
897efbf
995430d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
897efbf
995430d
 
 
897efbf
 
 
 
 
995430d
 
897efbf
995430d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
897efbf
 
995430d
 
 
897efbf
 
 
995430d
 
897efbf
 
995430d
897efbf
995430d
 
 
897efbf
995430d
897efbf
 
995430d
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import Optional, Tuple, Union
from abc import abstractmethod

import torch
from torch.nn import BCEWithLogitsLoss
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast


from .sensor_loc_reg import SENSOR_LOC_REGISTRY
from .sensor_loc_finder import SensorLocFinder

class MeasurementPredictorMixin(PreTrainedModel):
    """Adds per-sensor and aggregate binary probes on top of a base transformer.

    For every input sequence, the hidden states at ``n_sensors + 1`` positions
    (chosen by a sensor-location finder) are read from the base model's last
    hidden state. The first ``n_sensors`` positions are each fed to their own
    linear probe; the last position is fed to the aggregate probe. Every probe
    emits a single logit.

    The following attributes are read from ``config``: ``sensor_loc_type``,
    ``sensor_token``, ``n_sensors``, ``emb_dim``, ``sensors_weight`` and
    ``aggregate_weight``.

    ``init_sensor_loc_finder`` must be called with a tokenizer before
    ``forward`` is used.
    """

    def __init__(self, config):
        super().__init__(config)
        self.sensor_loc_type = config.sensor_loc_type
        self.sensor_token = config.sensor_token
        self.n_sensors = config.n_sensors
        # One independent linear probe (emb_dim -> 1 logit) per sensor.
        self.sensor_probes = torch.nn.ModuleList([
            torch.nn.Linear(config.emb_dim, 1) for _ in range(config.n_sensors)
        ])
        # Single probe for the aggregate measurement.
        self.aggregate_probe = torch.nn.Linear(config.emb_dim, 1)
        # Scalar multipliers applied to the two loss terms in `forward`.
        self.sensors_weight = config.sensors_weight
        self.aggregate_weight = config.aggregate_weight

        # Populated by `init_sensor_loc_finder`; stays None until then.
        self.find_sensor_locs: Optional[SensorLocFinder] = None

    @abstractmethod
    def set_pad_token(self, tokenizer: PreTrainedTokenizerBase):
        """Configure the padding token on ``tokenizer``; meant to be overridden.

        NOTE(review): ``@abstractmethod`` is only enforced for classes whose
        metaclass is ``ABCMeta`` -- confirm whether that holds here; otherwise
        this default silently does nothing.
        """
        pass

    def init_sensor_loc_finder(self, tokenizer: PreTrainedTokenizerBase):
        """Build the sensor-location finder registered under ``sensor_loc_type``.

        The finder is looked up in ``SENSOR_LOC_REGISTRY`` and constructed with
        the tokenizer, the configured sensor token and the number of sensors.
        """
        self.find_sensor_locs = SENSOR_LOC_REGISTRY[self.sensor_loc_type](
            tokenizer, sensor_token=self.sensor_token, n_sensors=self.n_sensors
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        Run the base model and apply the sensor and aggregate probes.

        labels (`torch.Tensor` of shape `(batch_size, n_sensors + 1)`, *optional*):
            Binary targets for the probes. Columns `[0, n_sensors)` are the per-sensor targets and the last
            column is the aggregate target. Integer/boolean labels are cast to the logits' dtype before the
            loss is computed; floating-point labels are used as given.

        Returns:
            A `SequenceClassifierOutputWithPast` (or a tuple when `return_dict` is false) whose `logits` have
            shape `(batch_size, n_sensors + 1)`: the sensor logits followed by the aggregate logit. When
            `labels` is given, `loss` is
            `sensors_weight * BCE(sensor_logits) + aggregate_weight * BCE(aggregate_logit)`.

        Raises:
            RuntimeError: if `init_sensor_loc_finder` has not been called.
            ValueError: if `input_ids` is `None` (sensor positions are located from the token ids).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.find_sensor_locs is None:
            raise RuntimeError(
                "Sensor location finder is not initialized; "
                "call `init_sensor_loc_finder(tokenizer)` before `forward`."
            )
        if input_ids is None:
            raise ValueError("`input_ids` is required to locate the sensor tokens.")

        # A ModelOutput when return_dict is true, a plain tuple otherwise.
        base_model_output = self.base_model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 is the last hidden state for both the tuple and the
        # ModelOutput form, so this also works when return_dict is false.
        last_hidden_state = base_model_output[0]

        # get sensor embeddings (including aggregate)
        sensor_locs = self.find_sensor_locs(input_ids)
        # Broadcast each location across the embedding dimension so `gather`
        # picks the full hidden vector at that sequence position.
        gather_index = sensor_locs.unsqueeze(-1).expand(-1, -1, self.config.emb_dim)
        sensor_embs = last_hidden_state.gather(1, gather_index)
        assert sensor_embs.shape == (input_ids.shape[0], self.n_sensors + 1, self.config.emb_dim), sensor_embs.shape

        # get sensor and aggregate logits
        sensor_logits = torch.concat(
            [probe(sensor_embs[:, i, :]) for i, probe in enumerate(self.sensor_probes)],
            dim=-1,
        )
        # The last located position feeds the aggregate probe.
        aggregate_logits = self.aggregate_probe(sensor_embs[:, -1, :])
        logits = torch.concat([sensor_logits, aggregate_logits], dim=-1)

        # compute loss
        loss = None
        if labels is not None:
            # BCEWithLogitsLoss needs floating-point targets; integer or bool
            # labels would otherwise raise a dtype error.
            if not labels.is_floating_point():
                labels = labels.to(logits.dtype)
            loss_fct = BCEWithLogitsLoss()
            sensor_loss = loss_fct(sensor_logits, labels[:, :self.n_sensors])
            aggregate_loss = loss_fct(aggregate_logits, labels[:, -1:])
            loss = sensor_loss * self.sensors_weight + aggregate_loss * self.aggregate_weight

        if not return_dict:
            output = (logits,) + base_model_output[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=base_model_output.past_key_values,
            hidden_states=base_model_output.hidden_states,
            attentions=base_model_output.attentions,
        )