jgyasu committed on
Commit
4506e19
1 Parent(s): 7efd422

Upload folder using huggingface_hub

Files changed (2)
  1. masking_methods_trial.py +188 -0
  2. requirements.txt +2 -1
masking_methods_trial.py ADDED
@@ -0,0 +1,188 @@
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
import random
from nltk.corpus import stopwords
import nltk
nltk.download('stopwords')
from vocabulary_split import split_vocabulary, filter_logits
import abc
from typing import List

# Load tokenizer and model for masked language modeling
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)  # unused below; kept for interactive checks

# Get the permissible half of the vocabulary and build a boolean mask over
# token ids. Membership is tested against a set so this is O(vocab_size).
permissible, _ = split_vocabulary(seed=42)
permissible_ids = set(permissible.values())
permissible_indices = torch.tensor([i in permissible_ids for i in range(len(tokenizer))])
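
# NOTE: `split_vocabulary` and `filter_logits` come from the local
# vocabulary_split module, which is not part of this commit. The code below
# assumes `split_vocabulary(seed)` returns a (token -> id) dict covering a
# deterministic half of the vocabulary, and that `filter_logits(logits, mask)`
# suppresses logits where the boolean mask is False; a sketch of that assumed
# interface is given after this file.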

def get_logits_for_mask(model, tokenizer, sentence):
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

    with torch.no_grad():
        outputs = model(**inputs)

    # Shape [num_masks, vocab_size]: one row of logits per [MASK] token.
    # Deliberately not squeezed -- squeezing a single-mask result to 1-D would
    # break the per-mask row indexing in mask_between_lcs below.
    return outputs.logits[0, mask_token_index, :]

# Abstract masking strategy
class MaskingStrategy(abc.ABC):
    @abc.abstractmethod
    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        """Given a list of words, return the indices of words to mask."""
        pass

# Concrete masking strategies
class RandomNonStopwordMasking(MaskingStrategy):
    def __init__(self, num_masks: int = 1):
        self.num_masks = num_masks
        self.stop_words = set(stopwords.words('english'))

    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        non_stop_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        if not non_stop_indices:
            return []
        num_masks = min(self.num_masks, len(non_stop_indices))
        return random.sample(non_stop_indices, num_masks)

class HighEntropyMasking(MaskingStrategy):
    def __init__(self, num_masks: int = 1):
        self.num_masks = num_masks
        # Cache the stopword set instead of rebuilding it on every call
        self.stop_words = set(stopwords.words('english'))

    def select_words_to_mask(self, words: List[str], sentence: str, model, tokenizer, permissible_indices, **kwargs) -> List[int]:
        candidate_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        if not candidate_indices:
            return []

        entropy_scores = {}
        for idx in candidate_indices:
            masked_sentence = ' '.join(words[:idx] + ['[MASK]'] + words[idx+1:])
            # Exactly one [MASK] here, so take its single row of logits
            logits = get_logits_for_mask(model, tokenizer, masked_sentence)[0]
            filtered_logits = filter_logits(logits, permissible_indices)
            probs = torch.softmax(filtered_logits, dim=-1)
            top_5_probs = probs.topk(5).values
            entropy = -torch.sum(top_5_probs * torch.log(top_5_probs + 1e-10)).item()
            entropy_scores[idx] = entropy

        # Select the num_masks indices with the highest entropy
        sorted_indices = sorted(entropy_scores, key=entropy_scores.get, reverse=True)
        return sorted_indices[:self.num_masks]
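
# Note on HighEntropyMasking: the entropy is computed over only the top-5
# probabilities, which do not sum to 1, so it is a truncated approximation of
# the full distribution's entropy. For example, top-5 probs
# [0.5, 0.2, 0.15, 0.1, 0.05] give H = -sum(p * ln p) ~= 1.33 nats, while a
# flatter [0.2] * 5 gives ~= 1.61 nats; the flatter, more uncertain position
# wins and is chosen for masking.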

class PseudoRandomNonStopwordMasking(MaskingStrategy):
    def __init__(self, num_masks: int = 1, seed: int = 10):
        self.num_masks = num_masks
        self.seed = seed
        self.stop_words = set(stopwords.words('english'))

    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        non_stop_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        if not non_stop_indices:
            return []
        # Use a local generator seeded per call so selection is reproducible
        # without clobbering the global random state
        rng = random.Random(self.seed)
        num_masks = min(self.num_masks, len(non_stop_indices))
        return rng.sample(non_stop_indices, num_masks)

class CompositeMaskingStrategy(MaskingStrategy):
    def __init__(self, strategies: List[MaskingStrategy]):
        self.strategies = strategies

    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        # Every strategy accepts **kwargs, so the context HighEntropyMasking
        # needs (sentence, model, tokenizer, permissible_indices) can be
        # forwarded uniformly; the other strategies simply ignore it.
        selected_indices = []
        for strategy in self.strategies:
            selected_indices.extend(strategy.select_words_to_mask(words, **kwargs))
        return sorted(set(selected_indices))  # Deduplicate, keep a stable order

def mask_between_lcs(sentence, lcs_points, masking_strategy: MaskingStrategy, model, tokenizer, permissible_indices):
    """Mask words in the segments between LCS points using the given strategy."""
    words = sentence.split()
    masked_indices = []

    # Define segments between consecutive LCS points; the LCS words themselves
    # are never masked
    segments = []
    previous = 0
    for point in lcs_points:
        if point > previous:
            segments.append((previous, point))
        previous = point + 1
    if previous < len(words):
        segments.append((previous, len(words)))
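
    # Worked example: for the 9-word sentence in __main__ with
    # lcs_points = [2, 5, 8], this yields segments [(0, 2), (3, 5), (6, 8)];
    # words 2, 5, and 8 are the LCS anchors and stay unmasked.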

    # Collect the indices to mask from each segment, passing the full context
    # so any strategy (including composites) can use it
    for start, end in segments:
        segment_words = words[start:end]
        selected = masking_strategy.select_words_to_mask(
            segment_words,
            sentence=sentence,
            model=model,
            tokenizer=tokenizer,
            permissible_indices=permissible_indices,
        )

        # Convert segment-relative indices to sentence-level indices
        for idx in selected:
            masked_idx = start + idx
            if masked_idx not in masked_indices:
                masked_indices.append(masked_idx)

    # The rows returned by get_logits_for_mask follow the left-to-right order
    # of [MASK] tokens, so sort the indices to keep them aligned with the rows
    masked_indices.sort()

    # Apply masking
    for idx in masked_indices:
        words[idx] = '[MASK]'

    masked_sentence = ' '.join(words)
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)  # [num_masks, vocab_size]

    # Process each masked token: filtered logits plus the top-5 candidate words
    top_words_list = []
    logits_list = []
    for i in range(len(masked_indices)):
        filtered_logits_i = filter_logits(logits[i], permissible_indices)
        logits_list.append(filtered_logits_i.tolist())
        top_5_indices = filtered_logits_i.topk(5).indices.tolist()
        top_words = [tokenizer.decode([token_id]) for token_id in top_5_indices]
        top_words_list.append(top_words)

    return masked_sentence, logits_list, top_words_list
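
# Note: decoding a single id can return a subword piece (a "##"-prefixed
# WordPiece in this vocabulary), so entries in top_words_list are candidate
# tokens rather than guaranteed whole words.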

# Example usage
if __name__ == "__main__":
    # Example sentence and LCS points
    sentence = "This is a sample sentence with some LCS points"
    lcs_points = [2, 5, 8]  # Indices of LCS points

    # Initialize masking strategies
    random_non_stopword_strategy = RandomNonStopwordMasking(num_masks=1)
    high_entropy_strategy = HighEntropyMasking(num_masks=1)
    pseudo_random_strategy = PseudoRandomNonStopwordMasking(num_masks=1, seed=10)
    composite_strategy = CompositeMaskingStrategy([
        RandomNonStopwordMasking(num_masks=1),
        HighEntropyMasking(num_masks=1)
    ])

    # Choose a strategy (any of the above works here)
    chosen_strategy = composite_strategy

    # Apply masking
    masked_sentence, logits_list, top_words_list = mask_between_lcs(
        sentence,
        lcs_points,
        masking_strategy=chosen_strategy,
        model=model,
        tokenizer=tokenizer,
        permissible_indices=permissible_indices
    )

    print("Masked Sentence:", masked_sentence)
    for idx, top_words in enumerate(top_words_list):
        print(f"Top words for mask {idx+1}:", top_words)
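
For reference, the script imports `split_vocabulary` and `filter_logits` from a local vocabulary_split module that is not included in this commit. Below is a minimal sketch of the assumed interface, not the module's actual implementation: the 50/50 split, the token -> id dict shape, the dummy token names, and the -inf masking are all assumptions made only so the calls above resolve against something concrete.

import random
import torch

def split_vocabulary(seed=42, vocab_size=28996):
    # Hypothetical sketch: deterministically shuffle token ids and split them
    # in half. 28996 matches the bert-large-cased vocabulary; the real module
    # may derive the size from the tokenizer instead.
    ids = list(range(vocab_size))
    random.Random(seed).shuffle(ids)
    half = vocab_size // 2
    permissible = {f"token_{i}": i for i in ids[:half]}  # token -> id
    forbidden = {f"token_{i}": i for i in ids[half:]}
    return permissible, forbidden

def filter_logits(logits, permissible_mask):
    # Hypothetical sketch: set non-permissible logits to -inf so softmax gives
    # them zero probability and topk never selects them.
    return logits.masked_fill(~permissible_mask, float("-inf"))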

requirements.txt CHANGED
@@ -17,4 +17,5 @@ graphviz==0.20.3
 gradio==4.29.0
 openai
 python-dotenv
-scikit-learn
+scikit-learn
+sentence-transformers