File size: 7,479 Bytes
ea6afa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import torch 
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import pipeline
import random
from nltk.corpus import stopwords
import nltk
nltk.download('stopwords')
import math
from vocabulary_split import split_vocabulary, filter_logits
import abc
from typing import List

# Load the tokenizer and the masked language model once, at import time, and
# wrap both in a fill-mask pipeline.
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)

# Get permissible vocabulary
permissible, _ = split_vocabulary(seed=42)
# Collect the permissible values into a set first: testing membership against
# `permissible.values()` directly is a linear scan repeated for every id.
# NOTE(review): assumes the values of `permissible` are hashable token ids
# (ints) -- confirm against vocabulary_split.split_vocabulary.
_permissible_ids = set(permissible.values())
# Boolean tensor with one entry per id in range(len(tokenizer)); True where the
# id is one of the permissible values.
permissible_indices = torch.tensor([i in _permissible_ids for i in range(len(tokenizer))])

def get_logits_for_mask(model, tokenizer, sentence):
    """
    Return the model's logits at every mask token of ``sentence``.

    The sentence is tokenized into PyTorch tensors, the model is run without
    gradient tracking, and the logits of the first batch entry are gathered at
    the positions whose input id equals ``tokenizer.mask_token_id``.

    The gathered tensor is squeezed before it is returned, so size-1
    dimensions are dropped: with exactly one mask token the result is a single
    vector instead of a one-row matrix.
    """
    encoded = tokenizer(sentence, return_tensors="pt")
    is_mask = encoded["input_ids"] == tokenizer.mask_token_id
    # torch.where yields one coordinate tensor per dimension; index 1 holds the
    # token positions along the sequence axis.
    mask_positions = torch.where(is_mask)[1]

    with torch.no_grad():
        all_logits = model(**encoded).logits

    return all_logits[0, mask_positions, :].squeeze()

# Abstract Masking Strategy
class MaskingStrategy(abc.ABC):
    """Interface for policies that decide which words of a word list to mask."""

    @abc.abstractmethod
    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        """
        Pick the words to mask.

        Implementations return the positions within ``words`` that should be
        masked. Any extra context a concrete strategy needs is passed through
        ``kwargs``.
        """
        return None

# Specific Masking Strategies
class RandomNonStopwordMasking(MaskingStrategy):
    """Mask up to ``num_masks`` randomly drawn words that are not English stop words."""

    def __init__(self, num_masks: int = 1):
        # NLTK's English stop-word list, held as a set for fast membership tests.
        self.stop_words = set(stopwords.words('english'))
        self.num_masks = num_masks

    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        """
        Draw positions of non-stop-words from ``words`` using the global RNG.

        Words are lower-cased before the stop-word check. At most
        ``self.num_masks`` positions are returned; the list is empty when
        every word is a stop word. Extra keyword arguments are ignored.
        """
        eligible = [
            position
            for position, word in enumerate(words)
            if word.lower() not in self.stop_words
        ]
        if eligible:
            return random.sample(eligible, min(self.num_masks, len(eligible)))
        return []

class HighEntropyMasking(MaskingStrategy):
    """
    Mask the non-stop-words whose masked-LM prediction is the most uncertain.

    Every candidate word is masked on its own, the model is queried for that
    slot, and the candidates are ranked by an entropy score computed from the
    five largest probabilities of the filtered distribution.
    """

    def __init__(self, num_masks: int = 1):
        self.num_masks = num_masks
        # Built once per instance, as the other strategies do, instead of
        # re-reading the NLTK stop-word list for every word that is examined.
        self.stop_words = set(stopwords.words('english'))

    def select_words_to_mask(self, words: List[str], sentence: str, model, tokenizer, permissible_indices) -> List[int]:
        """
        Return the positions in ``words`` with the highest entropy score.

        Args:
            words: Words to choose from; also the only context the model sees.
            sentence: Accepted for interface compatibility but not read here;
                each candidate is scored within ``words`` alone.
            model: Masked language model handed to ``get_logits_for_mask``.
            tokenizer: Tokenizer handed to ``get_logits_for_mask``.
            permissible_indices: Forwarded unchanged to ``filter_logits``.

        Returns:
            Up to ``self.num_masks`` positions, highest score first. Empty
            when every word is a stop word.
        """
        candidate_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        if not candidate_indices:
            return []

        entropy_scores = {}
        for idx in candidate_indices:
            # Mask only this candidate and ask the model about that one slot.
            masked_sentence = ' '.join(words[:idx] + ['[MASK]'] + words[idx+1:])
            logits = get_logits_for_mask(model, tokenizer, masked_sentence)
            filtered_logits = filter_logits(logits, permissible_indices)
            probs = torch.softmax(filtered_logits, dim=-1)
            # -sum(p * log p) over the five largest probabilities (they are
            # not renormalised); the epsilon guards against log(0).
            top_5_probs = probs.topk(5).values
            entropy_scores[idx] = -torch.sum(top_5_probs * torch.log(top_5_probs + 1e-10)).item()

        # Select top N indices with highest entropy
        sorted_indices = sorted(entropy_scores, key=entropy_scores.get, reverse=True)
        return sorted_indices[:self.num_masks]

class PseudoRandomNonStopwordMasking(MaskingStrategy):
    """
    Mask non-stop-words chosen by a seeded, repeatable random draw.

    The same ``words`` and the same ``seed`` always produce the same
    selection.
    """

    def __init__(self, num_masks: int = 1, seed: int = 10):
        self.num_masks = num_masks
        self.seed = seed
        self.stop_words = set(stopwords.words('english'))

    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        """
        Return up to ``self.num_masks`` positions of non-stop-words.

        Words are lower-cased before the stop-word check. The list is empty
        when every word is a stop word. Extra keyword arguments are ignored.
        """
        non_stop_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        if not non_stop_indices:
            return []
        num_masks = min(self.num_masks, len(non_stop_indices))
        # A private generator seeded per call yields the same draw as seeding
        # the module-level RNG would, without resetting the global random state
        # that other code in the process relies on.
        rng = random.Random(self.seed)
        return rng.sample(non_stop_indices, num_masks)

class CompositeMaskingStrategy(MaskingStrategy):
    """Combine several strategies and mask the union of what they select."""

    def __init__(self, strategies: List[MaskingStrategy]):
        self.strategies = strategies

    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        """
        Run every child strategy on ``words`` and merge the results.

        ``kwargs`` is forwarded to ``HighEntropyMasking`` children and to
        nested composites (which may contain one); every other child is
        called with ``words`` only.

        Returns:
            The distinct selected positions in ascending order.
        """
        selected_indices = set()
        for strategy in self.strategies:
            if isinstance(strategy, (HighEntropyMasking, CompositeMaskingStrategy)):
                selected = strategy.select_words_to_mask(words, **kwargs)
            else:
                selected = strategy.select_words_to_mask(words)
            selected_indices.update(selected)
        # Sorting makes the result order deterministic; iterating a set
        # directly would expose its arbitrary hash order.
        return sorted(selected_indices)

# Refactored mask_between_lcs function
def mask_between_lcs(sentence, lcs_points, masking_strategy: MaskingStrategy, model, tokenizer, permissible_indices):
    """
    Mask words in the gaps between LCS points and score each masked slot.

    The sentence is split on whitespace. The word positions in ``lcs_points``
    are never offered to the strategy; ``masking_strategy`` picks words to
    replace with ``[MASK]`` inside each run of words between them. The masked
    sentence is then passed through the model once.

    Args:
        sentence: Input text.
        lcs_points: Word positions to keep, consumed in the order given.
            NOTE(review): the segment logic looks like it expects ascending
            order -- confirm with callers.
        masking_strategy: Strategy that chooses positions inside each segment.
        model: Masked language model passed to ``get_logits_for_mask``.
        tokenizer: Tokenizer used for encoding and for decoding candidates.
        permissible_indices: Forwarded unchanged to ``filter_logits``.

    Returns:
        ``(masked_sentence, logits_list, top_words_list)``. Both lists have
        one entry per word masked here, taken from the rows that
        ``get_logits_for_mask`` returns in that order: the filtered logits as
        a plain list, and the five decoded tokens with the highest filtered
        logits.
    """
    words = sentence.split()

    # Define segments based on LCS points: (start, end) word ranges, end
    # exclusive, covering the words that lie between consecutive points.
    segments = []
    previous = 0
    for point in lcs_points:
        if point > previous:
            segments.append((previous, point))
        previous = point + 1
    if previous < len(words):
        segments.append((previous, len(words)))

    # HighEntropyMasking requires the model context, and a composite may
    # contain one, so both receive it; other strategies only get the words.
    needs_model_context = isinstance(masking_strategy, (HighEntropyMasking, CompositeMaskingStrategy))

    # Collect all indices to mask from each segment
    masked_indices = []
    for start, end in segments:
        segment_words = words[start:end]
        if needs_model_context:
            selected = masking_strategy.select_words_to_mask(
                segment_words,
                sentence=sentence,
                model=model,
                tokenizer=tokenizer,
                permissible_indices=permissible_indices,
            )
        else:
            selected = masking_strategy.select_words_to_mask(segment_words)

        # Adjust indices relative to the whole sentence
        for idx in selected:
            masked_idx = start + idx
            if masked_idx not in masked_indices:
                masked_indices.append(masked_idx)

    # Apply masking
    for idx in masked_indices:
        words[idx] = '[MASK]'

    masked_sentence = ' '.join(words)
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)

    # get_logits_for_mask squeezes its result, so a sentence with exactly one
    # mask comes back as a 1-D vector; restore the (mask, vocabulary) layout so
    # that row indexing below selects a whole distribution, not a scalar.
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    # Process each masked token
    top_words_list = []
    logits_list = []
    for row in range(len(masked_indices)):
        filtered_logits = filter_logits(logits[row], permissible_indices)
        logits_list.append(filtered_logits.tolist())
        top_5_indices = filtered_logits.topk(5).indices.tolist()
        top_words_list.append([tokenizer.decode([token_id]) for token_id in top_5_indices])

    return masked_sentence, logits_list, top_words_list

# Example Usage
if __name__ == "__main__":
    # Demo input: a sentence and the word positions treated as LCS points.
    sentence = "This is a sample sentence with some LCS points"
    lcs_points = [2, 5, 8]

    # One instance of every available strategy.
    random_non_stopword_strategy = RandomNonStopwordMasking(num_masks=1)
    high_entropy_strategy = HighEntropyMasking(num_masks=1)
    pseudo_random_strategy = PseudoRandomNonStopwordMasking(num_masks=1, seed=10)
    composite_strategy = CompositeMaskingStrategy(
        [RandomNonStopwordMasking(num_masks=1), HighEntropyMasking(num_masks=1)]
    )

    # Any of the strategies built above can be substituted here.
    chosen_strategy = composite_strategy

    # Mask the sentence and collect the candidate words for every mask.
    masked_sentence, logits_list, top_words_list = mask_between_lcs(
        sentence,
        lcs_points,
        masking_strategy=chosen_strategy,
        model=model,
        tokenizer=tokenizer,
        permissible_indices=permissible_indices,
    )

    print("Masked Sentence:", masked_sentence)
    for mask_number, candidates in enumerate(top_words_list, start=1):
        print(f"Top words for mask {mask_number}:", candidates)