File size: 6,052 Bytes
e8aad19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a101a53
e8aad19
a101a53
e8aad19
a101a53
 
 
 
e8aad19
 
 
 
 
 
 
a101a53
 
e8aad19
 
 
 
 
 
 
 
 
 
 
 
 
a101a53
e8aad19
 
 
a101a53
e8aad19
 
 
 
a101a53
e8aad19
 
 
 
a101a53
e8aad19
 
 
 
a101a53
e8aad19
 
 
 
 
 
 
 
 
 
 
 
 
 
a101a53
 
e8aad19
a101a53
e8aad19
 
a101a53
e8aad19
 
 
a101a53
e8aad19
a101a53
 
 
 
e8aad19
 
 
 
 
 
a101a53
 
 
e8aad19
 
a101a53
 
 
e8aad19
 
 
 
a101a53
e8aad19
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from typing import Dict, List, Optional

import torch

from modules.module_customPllLabel import CustomPllLabel
from modules.module_pllScore import PllScore


class RankSents:
    """Rank candidate fill-ins for a masked position in a sentence.

    The caller marks the masked slot with a single ``*``. The class swaps it
    for the tokenizer's mask token, asks the masked language model for top
    single-token predictions (optionally filtering function words), and then
    scores each filled-in sentence with a pseudo-log-likelihood (PLL) score.
    """

    def __init__(
        self, 
        language_model, # LanguageModel class instance (project type)
        lang: str,
        errorManager    # ErrorManager class instance (project type)
    ) -> None:
        """Initialize tokenizer/model handles and language-specific filters.

        Args:
            language_model: Project wrapper exposing ``initTokenizer()`` and
                ``initModel()`` (presumably a Hugging Face MLM underneath —
                the code relies on ``mask_token`` and ``logits``).
            lang: Language code. ``"es"`` and ``"en"`` install article /
                preposition / conjunction filter lists; any other code gets
                empty filter lists.
            errorManager: Project error formatter with a ``process`` method.
        """
        self.tokenizer = language_model.initTokenizer()
        self.model = language_model.initModel()
        _ = self.model.eval()  # inference mode: disables dropout etc.

        self.Label = CustomPllLabel()
        self.pllScore = PllScore(
            language_model=language_model
        )
        self.softmax = torch.nn.Softmax(dim=-1)

        # Default to empty filters so getTopPredictions never hits an
        # AttributeError for an unsupported language (the original only
        # assigned these attributes for lang in {"es", "en"}).
        self.articles: List[str] = []
        self.prepositions: List[str] = []
        self.conjunctions: List[str] = []

        if lang == "es":
            self.articles = [
                'un','una','unos','unas','el','los','la','las','lo'
            ]
            self.prepositions = [
                'a','ante','bajo','cabe','con','contra','de','desde','en','entre','hacia','hasta','para','por','según','sin','so','sobre','tras','durante','mediante','vía','versus'
            ]
            self.conjunctions = [
                'y','o','ni','que','pero','si'
            ]

        elif lang == "en":
            self.articles = [
                'a','an', 'the'
            ]
            self.prepositions = [
                'above', 'across', 'against', 'along', 'among', 'around', 'at', 'before', 'behind', 'below', 'beneath', 'beside', 'between', 'by', 'down', 'from', 'in', 'into', 'near', 'of', 'off', 'on', 'to', 'toward', 'under', 'upon', 'with', 'within'
            ]
            self.conjunctions = [
                'and', 'or', 'but', 'that', 'if', 'whether'
            ]

        self.errorManager = errorManager

    def errorChecking(
        self, 
        sent: str
    ) -> str:
        """Validate the masked sentence; return a formatted error or "".

        Checks, in order: non-empty input, exactly one ``*`` mask, and that
        the tokenized sentence fits within the tokenizer's single-sentence
        limit. The raw error code list is handed to ``errorManager.process``
        for formatting; an empty message means the sentence is valid.
        """
        out_msj = ""
        if not sent:
            out_msj = ['RANKSENTS_NO_SENTENCE_PROVIDED']
        elif sent.count("*") > 1:
            out_msj = ['RANKSENTS_TOO_MANY_MASKS_IN_SENTENCE']
        elif sent.count("*") == 0:
            out_msj = ['RANKSENTS_NO_MASK_IN_SENTENCE']
        else:
            # Measure length with the real mask token substituted so the
            # count matches what encode_plus will later see.
            sent_len = len(self.tokenizer.encode(sent.replace("*", self.tokenizer.mask_token)))
            max_len = self.tokenizer.max_len_single_sentence
            if sent_len > max_len:
                out_msj = ['RANKSENTS_TOKENIZER_MAX_TOKENS_REACHED', max_len]

        return self.errorManager.process(out_msj)

    def getTopPredictions(
        self, 
        n: int,
        sent: str,
        banned_word_list: List[str], 
        exclude_articles: bool,
        exclude_prepositions: bool,
        exclude_conjunctions: bool,
    ) -> List[str]:
        """Return up to ``n`` single-token candidates for the ``*`` slot.

        Candidates are taken in descending model probability, skipping
        banned words, punctuation, subword pieces, special tokens, and
        (optionally) articles / prepositions / conjunctions.

        Args:
            n: Maximum number of predictions to return.
            sent: Sentence containing exactly one ``*`` mask marker.
            banned_word_list: Exact token strings to exclude.
            exclude_articles: Skip tokens found in ``self.articles``.
            exclude_prepositions: Skip tokens found in ``self.prepositions``.
            exclude_conjunctions: Skip tokens found in ``self.conjunctions``.
        """
        sent_masked = sent.replace("*", self.tokenizer.mask_token)
        inputs = self.tokenizer.encode_plus(
            sent_masked,
            add_special_tokens=True,
            return_tensors='pt',
            return_attention_mask=True,
            truncation=True
        )

        # Index of the single mask token. errorChecking guarantees exactly
        # one "*"; NOTE(review): if truncation ever drops the mask, .item()
        # raises — assumed unreachable after the length check.
        tk_position_mask = torch.where(
            inputs['input_ids'][0] == self.tokenizer.mask_token_id
        )[0].item()

        with torch.no_grad():
            logits = self.model(**inputs).logits
            outputs = torch.squeeze(self.softmax(logits), dim=0)

        probabilities = outputs[tk_position_mask]
        ranked_token_ids = torch.argsort(probabilities, descending=True)

        top_tks_pred: List[str] = []
        for tk_id in ranked_token_ids:
            if len(top_tks_pred) >= n:
                break

            tk_string = self.tokenizer.decode([tk_id])

            # Reject anything that is not a clean whole-word candidate.
            undesirable = (
                tk_string in banned_word_list
                or not tk_string.isalnum()                        # punctuation / symbols
                or tk_string.startswith("##")                     # WordPiece subword piece
                or tk_string in self.tokenizer.all_special_tokens
                or (exclude_articles and tk_string in self.articles)
                or (exclude_prepositions and tk_string in self.prepositions)
                or (exclude_conjunctions and tk_string in self.conjunctions)
            )

            if not undesirable:
                top_tks_pred.append(tk_string)

        return top_tks_pred

    def rank(self, 
        sent: str, 
        interest_word_list: Optional[List[str]] = None,
        banned_word_list: Optional[List[str]] = None,
        exclude_articles: bool=False, 
        exclude_prepositions: bool=False, 
        exclude_conjunctions: bool=False,
        n_predictions: int=5
    ) -> Dict[str, float]:
        """Score candidate words for the masked slot of ``sent``.

        If ``interest_word_list`` is empty/None, the top ``n_predictions``
        model predictions are used instead. Returns a dict mapping each
        filled-in sentence (candidate wrapped in ``<>``) to its PLL score.

        Raises:
            Exception: If ``errorChecking`` reports an invalid sentence.
        """
        err = self.errorChecking(sent)
        if err:
            raise Exception(err)

        # None defaults instead of mutable list defaults (original used []).
        if not interest_word_list:
            interest_word_list = self.getTopPredictions(
                n_predictions,
                sent,
                banned_word_list or [],
                exclude_articles,
                exclude_prepositions,
                exclude_conjunctions
            )

        # The original built two identical lists (sent_list and
        # sent_list2print); one suffices since both held the same strings.
        filled_sents = [
            sent.replace("*", "<" + word + ">") for word in interest_word_list
        ]

        return {filled: self.pllScore.compute(filled) for filled in filled_sents}