from typing import List, Union

import torch
import streamlit as st
import numpy as np
from numpy import ndarray
from transformers import (AlbertModel, AlbertTokenizer, BertModel,
                          BertTokenizer, DistilBertModel, DistilBertTokenizer,
                          PreTrainedModel, PreTrainedTokenizer, XLMModel,
                          XLMTokenizer, XLNetModel, XLNetTokenizer)

# Cache the pretrained weights across Streamlit reruns. allow_output_mutation=True
# skips hashing the returned torch model, which Streamlit cannot hash reliably.
@st.cache(allow_output_mutation=True)
def load_hf_model(base_model, model_name: str, device: torch.device) -> PreTrainedModel:
    model = base_model.from_pretrained(model_name, output_hidden_states=True).to(device)
    return model

class BertParent(object):
    """
    Base handler for BERT-style models: wraps a Hugging Face transformer and
    its tokenizer to produce sentence embeddings from hidden states.
    """

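    # Keyword models: keys are Hugging Face identifiers, values the matching
    # (model class, tokenizer class) pair used when no custom model is supplied.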
    MODELS = {
        'bert-base-uncased': (BertModel, BertTokenizer),
        'bert-large-uncased': (BertModel, BertTokenizer),
        'xlnet-base-cased': (XLNetModel, XLNetTokenizer),
        'xlm-mlm-enfr-1024': (XLMModel, XLMTokenizer),
        'distilbert-base-uncased': (DistilBertModel, DistilBertTokenizer),
        'albert-base-v1': (AlbertModel, AlbertTokenizer),
        'albert-large-v1': (AlbertModel, AlbertTokenizer)
    }

    def __init__(
        self,
        model: str,
        custom_model: PreTrainedModel = None,
        custom_tokenizer: PreTrainedTokenizer = None,
        gpu_id: int = 0,
    ):
        """
        :param model: String path or name of the model weights. If given a known
            keyword (see MODELS), the corresponding pretrained weights are fetched.
        :param custom_model: Optional custom pretrained model to use instead.
        :param custom_tokenizer: Optional custom tokenizer to use instead.
        :param gpu_id: Index of the CUDA device to run on, if one is available.
        """
        base_model, base_tokenizer = self.MODELS.get(model, (None, None))

        self.device = torch.device("cpu")
        if torch.cuda.is_available():
            assert (
                isinstance(gpu_id, int) and (0 <= gpu_id < torch.cuda.device_count())
            ), f"`gpu_id` must be an integer between 0 and {torch.cuda.device_count() - 1}, but got: {gpu_id}"

            self.device = torch.device(f"cuda:{gpu_id}")

        if custom_model:
            self.model = custom_model.to(self.device)
        else:
            self.model = load_hf_model(base_model, model, self.device)

        if custom_tokenizer:
            self.tokenizer = custom_tokenizer
        else:
            self.tokenizer = base_tokenizer.from_pretrained(model)

        self.model.eval()


    def tokenize_input(self, text: str) -> torch.Tensor:
        """
        Tokenizes the text input.
        :param text: Text to tokenize.
        :return: A 1 x seq_len tensor of token ids on the model's device.
        """
        # Note: tokenizer.tokenize does not insert special tokens such as
        # [CLS]/[SEP]; the raw wordpiece ids are used as-is.
        tokenized_text = self.tokenizer.tokenize(text)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        return torch.tensor([indexed_tokens]).to(self.device)

    def _pooled_handler(self, hidden: torch.Tensor,
                        reduce_option: str) -> torch.Tensor:
        """
        Handles torch tensor.
        :param hidden: The hidden torch tensor to process.
        :param reduce_option: The reduce option to use, such as mean, etc.
        :return: Returns a torch tensor.
        """

        if reduce_option == 'max':
            return hidden.max(dim=1)[0].squeeze()

        elif reduce_option == 'median':
            return hidden.median(dim=1)[0].squeeze()

        return hidden.mean(dim=1).squeeze()
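    # Example (illustrative): for bert-base-uncased, hidden has shape
    # [1, seq_len, 768]; 'max', 'median' and 'mean' each reduce over the
    # seq_len tokens, so the pooled result is a 768-dimensional vector.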

    def extract_embeddings(
        self,
        text: str,
        hidden: Union[List[int], int] = -2,
        reduce_option: str = 'mean',
        hidden_concat: bool = False,
    ) -> torch.Tensor:
        """
        Extracts the embeddings for the given text.
        :param text: The text to extract embeddings for.
        :param hidden: The hidden layer(s) to use for a readout handler.
        :param squeeze: If we should squeeze the outputs (required for some layers).
        :param reduce_option: How we should reduce the items.
        :param hidden_concat: Whether or not to concat multiple hidden layers.
        :return: A torch vector.
        """
        tokens_tensor = self.tokenize_input(text)
        pooled, hidden_states = self.model(tokens_tensor)[-2:]

        # deprecated temporary keyword functions.
        if reduce_option == 'concat_last_4':
            last_4 = [hidden_states[i] for i in (-1, -2, -3, -4)]
            cat_hidden_states = torch.cat(tuple(last_4), dim=-1)
            return torch.mean(cat_hidden_states, dim=1).squeeze()

        elif reduce_option == 'reduce_last_4':
            last_4 = [hidden_states[i] for i in (-1, -2, -3, -4)]
            return torch.cat(tuple(last_4), dim=1).mean(axis=1).squeeze()

        elif type(hidden) == int:
            hidden_s = hidden_states[hidden]
            return self._pooled_handler(hidden_s, reduce_option)

        elif hidden_concat:
            last_states = [hidden_states[i] for i in hidden]
            cat_hidden_states = torch.cat(tuple(last_states), dim=-1)
            return torch.mean(cat_hidden_states, dim=1).squeeze()

        last_states = [hidden_states[i] for i in hidden]
        hidden_s = torch.cat(tuple(last_states), dim=1)

        return self._pooled_handler(hidden_s, reduce_option)

    def create_matrix(
        self,
        content: List[str],
        hidden: Union[List[int], int] = -2,
        reduce_option: str = 'mean',
        hidden_concat: bool = False,
    ) -> ndarray:
        """
        Create matrix from the embeddings.
        :param content: The list of sentences.
        :param hidden: Which hidden layer to use.
        :param reduce_option: The reduce option to run.
        :param hidden_concat: Whether or not to concat multiple hidden layers.
        :return: A numpy array matrix of the given content.
        """

        return np.asarray([
            np.squeeze(self.extract_embeddings(
                t, hidden=hidden, reduce_option=reduce_option, hidden_concat=hidden_concat
            ).data.cpu().numpy()) for t in content
        ])

    def __call__(
        self,
        content: List[str],
        hidden: Union[List[int], int] = -2,
        reduce_option: str = 'mean',
        hidden_concat: bool = False,
    ) -> ndarray:
        """
        Convenience wrapper around create_matrix.
        :param content: The list of sentences.
        :param hidden: The hidden layer(s) to use.
        :param reduce_option: The reduce option to run.
        :param hidden_concat: Whether to concatenate multiple hidden layers.
        :return: A numpy matrix with one embedding row per sentence.
        """
        return self.create_matrix(content, hidden, reduce_option, hidden_concat)
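
# Minimal usage sketch (illustrative, not part of the original module). It
# builds a handler for one of the keyword models above and embeds two
# sentences; the model name and sentences are placeholder choices, and the
# pretrained weights are downloaded on first use.
if __name__ == "__main__":
    handler = BertParent('distilbert-base-uncased')
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "Extractive summarization selects salient sentences.",
    ]
    # One row per sentence; columns are the pooled hidden dimensions.
    matrix = handler(sentences, hidden=-2, reduce_option='mean')
    print(matrix.shape)  # e.g. (2, 768) for a base-sized model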