import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import pipeline
import random
from nltk.corpus import stopwords
import math
from vocabulary_split import split_vocabulary, filter_logits

# Load the tokenizer and masked language model (BERT large, whole-word masking)
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)

# Get permissible vocabulary and build a boolean mask over the tokenizer's token ids
permissible, _ = split_vocabulary(seed=42)
permissible_ids = set(permissible.values())  # set membership keeps the loop below linear
permissible_indices = torch.tensor([i in permissible_ids for i in range(len(tokenizer))])

def get_logits_for_mask(model, tokenizer, sentence):
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    mask_token_logits = logits[0, mask_token_index, :]
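    # Note: squeeze() collapses a single-mask result to a 1-D tensor of vocabulary
    # logits; with multiple [MASK] tokens the result stays 2-D ([num_masks, vocab_size])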
    return mask_token_logits.squeeze()

def mask_non_stopword(sentence):
    stop_words = set(stopwords.words('english'))
    words = sentence.split()
    non_stop_words = [word for word in words if word.lower() not in stop_words]
    if not non_stop_words:
        return sentence, None, None
    word_to_mask = random.choice(non_stop_words)
    masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]
    return masked_sentence, filtered_logits.tolist(), words
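
# Illustrative usage of mask_non_stopword (commented-out sketch; the sentence is
# arbitrary and assumes the NLTK stopwords corpus has been downloaded):
# masked, logit_list, top5 = mask_non_stopword("The quick brown fox jumps over the lazy dog")
# print(masked)   # e.g. "The quick brown fox jumps over the lazy [MASK]"
# print(top5)     # the five highest-scoring permissible replacements for the mask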

def mask_non_stopword_pseudorandom(sentence):
    stop_words = set(stopwords.words('english'))
    words = sentence.split()
    non_stop_words = [word for word in words if word.lower() not in stop_words]
    if not non_stop_words:
        return sentence, None, None
    rng = random.Random(10)  # Fixed seed for pseudo-randomness, without reseeding the global RNG
    word_to_mask = rng.choice(non_stop_words)
    masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]
    return masked_sentence, filtered_logits.tolist(), words

# New function: mask words between LCS points
def mask_between_lcs(sentence, lcs_points):
    words = sentence.split()
    masked_indices = []

    # Mask between first word and first LCS point
    if lcs_points and lcs_points[0] > 0:
        idx = random.randint(0, lcs_points[0]-1)
        words[idx] = '[MASK]'
        masked_indices.append(idx)
    
    # Mask between LCS points
    for i in range(len(lcs_points) - 1):
        start, end = lcs_points[i], lcs_points[i+1]
        if end - start > 1:
            mask_index = random.randint(start + 1, end - 1)
            words[mask_index] = '[MASK]'
            masked_indices.append(mask_index)
    
    # Mask between last LCS point and last word
    if lcs_points and lcs_points[-1] < len(words) - 1:
        idx = random.randint(lcs_points[-1]+1, len(words)-1)
        words[idx] = '[MASK]'
        masked_indices.append(idx)
    
    masked_sentence = ' '.join(words)
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)

    # get_logits_for_mask squeezes a single mask down to 1-D; restore the
    # [num_masks, vocab_size] shape so indexing by mask position stays valid
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    # Process each masked position separately; mask order in the sentence matches
    # masked_indices, which were appended left to right
    top_words_list = []
    logits_list = []
    for i in range(len(masked_indices)):
        filtered_logits_i = filter_logits(logits[i], permissible_indices)
        logits_list.append(filtered_logits_i.tolist())
        top_5_indices = filtered_logits_i.topk(5).indices.tolist()
        top_words = [tokenizer.decode([idx]) for idx in top_5_indices]
        top_words_list.append(top_words)
    
    return masked_sentence, logits_list, top_words_list


def high_entropy_words(sentence, non_melting_points):
    stop_words = set(stopwords.words('english'))
    words = sentence.split()

    non_melting_words = set()
    for _, point in non_melting_points:
        non_melting_words.update(point.lower().split())

    candidate_words = [word for word in words if word.lower() not in stop_words and word.lower() not in non_melting_words]

    if not candidate_words:
        return sentence, None, None

    max_entropy = -float('inf')
    max_entropy_word = None
    max_logits = None

    for word in candidate_words:
        masked_sentence = sentence.replace(word, '[MASK]', 1)
        logits = get_logits_for_mask(model, tokenizer, masked_sentence)
        filtered_logits = filter_logits(logits, permissible_indices)
        
        # Entropy over the top-5 prediction probabilities (a spread heuristic);
        # the small epsilon guards against log(0) for fully filtered-out tokens
        probs = torch.softmax(filtered_logits, dim=-1)
        top_5_probs = probs.topk(5).values
        entropy = -torch.sum(top_5_probs * torch.log(top_5_probs + 1e-12))

        if entropy > max_entropy:
            max_entropy = entropy
            max_entropy_word = word
            max_logits = filtered_logits

    if max_entropy_word is None:
        return sentence, None, None

    masked_sentence = sentence.replace(max_entropy_word, '[MASK]', 1)
    words = [tokenizer.decode([i]) for i in max_logits.argsort()[-5:]]
    return masked_sentence, max_logits.tolist(), words
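
# Illustrative usage of high_entropy_words (commented-out sketch; the
# (position, phrase) pair format below is inferred from how non_melting_points
# is unpacked above and is only an assumption):
# non_melting = [(1, "quick brown"), (7, "lazy dog")]
# masked, logit_list, top5 = high_entropy_words("The quick brown fox jumps over the lazy dog", non_melting)
# print(masked)   # the sentence with its highest-entropy candidate word replaced by [MASK]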

# New function: mask based on part of speech
def mask_by_pos(sentence, pos_to_mask=('NOUN', 'VERB', 'ADJ')):
    import nltk
    nltk.download('punkt', quiet=True)
    nltk.download('averaged_perceptron_tagger', quiet=True)
    nltk.download('universal_tagset', quiet=True)

    words = nltk.word_tokenize(sentence)
    # Tag with the universal tagset so the tags match the NOUN/VERB/ADJ names above
    # (the default Penn Treebank tags such as 'NN'/'VB'/'JJ' would never match)
    pos_tags = nltk.pos_tag(words, tagset='universal')

    maskable_words = [word for word, pos in pos_tags if pos in pos_to_mask]
    
    if not maskable_words:
        return sentence, None, None
    
    word_to_mask = random.choice(maskable_words)
    masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
    
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]
    
    return masked_sentence, filtered_logits.tolist(), words
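
# Illustrative usage of mask_by_pos (commented-out sketch; tags follow the NLTK
# universal tagset, so e.g. ('NOUN',) restricts masking to nouns):
# masked, logit_list, top5 = mask_by_pos("The quick brown fox jumps over the lazy dog", pos_to_mask=('NOUN',))
# print(masked)   # e.g. "The quick brown [MASK] jumps over the lazy dog"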

# New function: mask named entities
def mask_named_entity(sentence):
    import nltk
    nltk.download('punkt', quiet=True)
    nltk.download('averaged_perceptron_tagger', quiet=True)
    nltk.download('maxent_ne_chunker', quiet=True)
    nltk.download('words', quiet=True)

    words = nltk.word_tokenize(sentence)
    pos_tags = nltk.pos_tag(words)
    named_entities = nltk.ne_chunk(pos_tags)

    # ne_chunk returns a Tree whose named-entity chunks are subtrees; plain
    # (word, tag) leaves are not entities, so collect words from the subtrees only
    maskable_words = [word
                      for subtree in named_entities
                      if isinstance(subtree, nltk.Tree)
                      for word, _ in subtree.leaves()]
    
    if not maskable_words:
        return sentence, None, None
    
    word_to_mask = random.choice(maskable_words)
    masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
    
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]
    
    return masked_sentence, filtered_logits.tolist(), words
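
# Illustrative usage of mask_named_entity (commented-out sketch; which words are
# treated as entities depends entirely on the NLTK chunker):
# masked, logit_list, top5 = mask_named_entity("Barack Obama visited Paris in 2015")
# print(masked)   # e.g. "[MASK] Obama visited Paris in 2015"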


# sentence = "This is a sample sentence with some LCS points"
# lcs_points = [2, 5, 8]  # Indices of LCS points
# masked_sentence, logits_list, top_words_list = mask_between_lcs(sentence, lcs_points)

# print("Masked Sentence:", masked_sentence)
# for idx, top_words in enumerate(top_words_list):
#     print(f"Top words for mask {idx+1}:", top_words)