import random

import torch
import nltk
from nltk.corpus import stopwords
from transformers import BertTokenizer, BertForMaskedLM
from transformers import RobertaTokenizer, RobertaForMaskedLM

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')
class MaskingProcessor:
    # def __init__(self, tokenizer, model):
    def __init__(self):
        # self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        # self.tokenizer = tokenizer
        # self.model = model
        self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
        self.model = BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
        # self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        # self.model = RobertaForMaskedLM.from_pretrained("roberta-base")
        self.stop_words = set(stopwords.words('english'))
    def remove_stopwords(self, words):
        """
        Remove stopwords from the given list of words.

        Args:
            words (list): List of words.

        Returns:
            list: List of non-stopwords.
        """
        return [word for word in words if word.lower() not in self.stop_words]
    def adjust_ngram_indices(self, original_words, common_ngrams):
        """
        Adjust indices of common n-grams after removing stopwords.

        Args:
            original_words (list): Original list of words.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            dict: Adjusted common n-grams with updated indices.
        """
        non_stop_words = self.remove_stopwords(original_words)
        original_to_non_stop = []
        non_stop_idx = 0
        for original_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                original_to_non_stop.append((original_idx, non_stop_idx))
                non_stop_idx += 1

        adjusted_ngrams = {}
        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
                    new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
                    adjusted_positions.append((new_start, new_end))
                except StopIteration:
                    continue  # Skip if indices cannot be mapped
            adjusted_ngrams[ngram] = adjusted_positions
        return adjusted_ngrams
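
    # Worked example (ours, for illustration only):
    #   original_words = ["The", "quick", "brown", "fox"], common_ngrams = {"brown fox": [(2, 3)]}
    #   After removing the stopword "The", the non-stopword list is ["quick", "brown", "fox"],
    #   so the original span (2, 3) is remapped to (1, 2) in non-stopword coordinates.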
    def mask_sentence_random(self, sentence, common_ngrams):
        """
        Randomly mask one non-stopword before the first common n-gram, one in each
        gap between consecutive n-grams, and one after the last n-gram, after
        removing stopwords.
        """
        # Split sentence into words
        original_words = sentence.split()

        # Handle punctuation at the end. NOTE: this assumes the trailing punctuation has
        # already been split into its own token (see process_sentences); the last token
        # is dropped here and the punctuation character is re-appended after masking.
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        print(f' ---- original_words : {original_words} ----- ')

        # Process words without punctuation
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Choose candidate positions (in non-stopword coordinates) to mask.
        # Sort so the before/between/after logic sees n-grams in sentence order.
        mask_indices = []
        ngram_positions = sorted(pos for positions in adjusted_ngrams.values() for pos in positions)
        if ngram_positions:
            # Mask a word before the first common n-gram
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start - 1)
                mask_indices.append(mask_index_before_ngram)

            # Mask words between common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    mask_indices.append(mask_index_between_ngrams)

            # Mask a word after the last common n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                mask_indices.append(mask_index_after_ngram)

        # Create mapping from non-stopword indices back to original indices
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        # Map mask indices and apply masks
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token  # '[MASK]' for BERT, '<mask>' for RoBERTa

        # Add back punctuation if it existed
        if has_punctuation:
            masked_words.append(punctuation)

        print(f' ***** masked_words at end : {masked_words} ***** ')
        print(f' ***** original_mask_indices : {original_mask_indices} ***** ')
        print(f' ***** TESTING : {" ".join(masked_words)} ***** ')
        return " ".join(masked_words), original_mask_indices
    def mask_sentence_pseudorandom(self, sentence, common_ngrams):
        """
        Same masking rules as mask_sentence_random, but with a fixed random seed
        so the chosen positions are reproducible across runs.
        """
        random.seed(3)
        return self.mask_sentence_random(sentence, common_ngrams)
    def calculate_word_entropy(self, sentence, word_position):
        """
        Calculate the entropy of the model's predictive distribution at a specific
        word position in the sentence.

        Args:
            sentence (str): The input sentence.
            word_position (int): Position of the word to calculate entropy for.

        Returns:
            float: Entropy value for the word.
        """
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)

        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        # Get probabilities for the masked position
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        # Calculate entropy: -sum(p * log(p))
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))
        return entropy.item()
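
    # Worked example (ours, for illustration only): for a peaked distribution such as
    # p = [0.7, 0.2, 0.1], -sum(p * ln p) is about 0.80 nats, while a uniform
    # distribution over 3 outcomes gives ln(3) ~ 1.10 nats. Higher entropy therefore
    # marks positions where the model is least certain, which is what
    # mask_sentence_entropy selects below.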
    def mask_sentence_entropy(self, sentence, common_ngrams):
        """
        Mask words based on entropy, following the same n-gram positioning rules as
        the random strategies: in each gap (before, between, and after the common
        n-grams) the highest-entropy candidate is masked.

        Args:
            sentence (str): Original sentence.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            tuple: (masked sentence, list of masked indices in the original sentence)
        """
        # Split sentence into words
        original_words = sentence.split()

        # Handle punctuation at the end (assumes it is already a separate token, see process_sentences)
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        # Process words without punctuation
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Create mappings between non-stopword indices and original indices
        non_stop_to_original = {}
        original_to_non_stop = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                original_to_non_stop[orig_idx] = non_stop_idx
                non_stop_idx += 1

        # Sort so the before/between/after logic sees n-grams in sentence order
        ngram_positions = sorted(pos for positions in adjusted_ngrams.values() for pos in positions)
        mask_indices = []
        if ngram_positions:
            # Handle words before the first n-gram
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                candidate_positions = range(0, first_ngram_start)
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Handle words between n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    candidate_positions = range(end_prev + 1, start_next)
                    entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                                 for pos in candidate_positions]
                    mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Handle words after the last n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

        # Map mask indices to original sentence positions and apply masks
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        # Add back punctuation if it existed
        if has_punctuation:
            masked_words.append(punctuation)
        return " ".join(masked_words), original_mask_indices
    def calculate_mask_logits(self, original_sentence, original_mask_indices):
        """
        Calculate logits for masked tokens in the sentence using the masked LM.

        Args:
            original_sentence (str): Original sentence without masks.
            original_mask_indices (list): List of indices to mask.

        Returns:
            dict: Masked token indices mapped to their top candidate words and logits.
        """
        print('==========================================================================================================')
        words = original_sentence.split()
        print(f' ##### calculate_mask_logits >> words : {words} ##### ')
        mask_logits = {}
        for idx in original_mask_indices:
            # Create a copy of words and mask the current position
            print(f' ---- idx : {idx} ----- ')
            masked_words = words.copy()
            masked_words[idx] = self.tokenizer.mask_token  # '[MASK]' for BERT, '<mask>' for RoBERTa
            masked_sentence = " ".join(masked_words)
            print(f' ---- masked_sentence : {masked_sentence} ----- ')

            # Calculate logits for the current mask
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            # Extract logits for the masked position
            mask_logits_tensor = logits[0, mask_token_index, :]

            # Get more candidates than needed, since subword pieces are filtered out below
            top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)

            # Convert token IDs to words and filter out subword tokens
            top_tokens = []
            top_logits = []
            seen_words = set()  # To keep track of unique words
            for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
                token = self.tokenizer.convert_ids_to_tokens(token_id.item())
                # Skip subword continuation tokens (they start with ##)
                if token.startswith('##'):
                    continue
                # Convert the token to a proper word
                word = self.tokenizer.convert_tokens_to_string([token]).strip()
                # Only add if it's a new, non-empty word
                if word and word not in seen_words:
                    seen_words.add(word)
                    top_tokens.append(word)
                    top_logits.append(logit.item())
                # Stop once we have 50 unique complete words
                if len(top_tokens) == 50:
                    break
            # print(f' ---- top_tokens : {top_tokens} ----- ')

            # Store results
            mask_logits[idx] = {
                "tokens": top_tokens,
                "logits": top_logits
            }
        return mask_logits
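
    # Worked example (ours, for illustration only): with a WordPiece vocabulary the raw
    # top-k list for a masked position might contain pieces such as ['dog', '##s', 'puppy'];
    # the filtering loop above keeps 'dog' and 'puppy' and drops '##s', so "tokens" only
    # ever holds complete words.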
    # Earlier version of calculate_mask_logits, kept for reference (returns raw WordPiece
    # tokens without subword filtering):
    # def calculate_mask_logits(self, original_sentence, original_mask_indices):
    #     """
    #     Calculate logits for masked tokens in the sentence using BERT.
    #     Args:
    #         original_sentence (str): Original sentence without masks
    #         original_mask_indices (list): List of indices to mask
    #     Returns:
    #         dict: Masked token indices and their logits
    #     """
    #     words = original_sentence.split()
    #     print(f' ##### calculate_mask_logits >> words : {words} ##### ')
    #     mask_logits = {}
    #     for idx in original_mask_indices:
    #         # Create a copy of words and mask the current position
    #         print(f' ---- idx : {idx} ----- ')
    #         masked_words = words.copy()
    #         print(f' ---- words : {masked_words} ----- ')
    #         # masked_words[idx] = self.tokenizer.mask_token
    #         masked_words[idx] = '[MASK]'
    #         print(f' ---- masked_words : {masked_words} ----- ')
    #         masked_sentence = " ".join(masked_words)
    #         print(f' ---- masked_sentence : {masked_sentence} ----- ')
    #         # Calculate logits for the current mask
    #         input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
    #         mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
    #         with torch.no_grad():
    #             outputs = self.model(input_ids)
    #             logits = outputs.logits
    #         # Extract logits for the masked position
    #         mask_logits_tensor = logits[0, mask_token_index, :]
    #         # Get top 50 logits and corresponding tokens
    #         top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 50, dim=-1)
    #         # Convert token IDs to words
    #         top_tokens = [self.tokenizer.convert_ids_to_tokens(token_id.item()) for token_id in top_mask_indices[0]]
    #         print(f' ---- top_tokens : {top_tokens} ----- ')
    #         # Store results
    #         mask_logits[idx] = {
    #             "tokens": top_tokens,
    #             "logits": top_mask_logits.tolist()
    #         }
    #     return mask_logits
    def process_sentences(self, sentences, result_dict, method="random"):
        """
        Process sentences and calculate logits for the masked tokens.

        Args:
            sentences (list): Sentences to process (kept for API compatibility;
                the keys of result_dict drive the processing).
            result_dict (dict): Maps each sentence to its common n-grams and indices.
            method (str): Masking strategy: "random", "pseudorandom", or "entropy".

        Returns:
            dict: For each sentence, the masked sentence and the mask logits.
        """
        results = {}
        for sentence, ngrams in result_dict.items():
            # Split punctuation from the last word before processing
            words = sentence.split()
            last_word = words[-1]
            if any(last_word.endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
                # Split the last word and its punctuation
                words[-1] = last_word[:-1]
                punctuation = last_word[-1]
                # Rejoin with a space before the punctuation so it is treated as a separate token
                processed_sentence = " ".join(words) + " " + punctuation
            else:
                processed_sentence = sentence

            if method == "random":
                masked_sentence, original_mask_indices = self.mask_sentence_random(processed_sentence, ngrams)
            elif method == "pseudorandom":
                masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(processed_sentence, ngrams)
            else:  # entropy
                masked_sentence, original_mask_indices = self.mask_sentence_entropy(processed_sentence, ngrams)

            logits = self.calculate_mask_logits(processed_sentence, original_mask_indices)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }
        return results
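

# A minimal sketch (ours, not part of the original pipeline) showing how the RoBERTa
# configuration commented out in __init__ could be plugged in. Because the masking
# methods read self.tokenizer.mask_token / mask_token_id rather than hardcoding
# '[MASK]', only the constructor needs to change. Note, however, that the '##'-prefix
# subword filter in calculate_mask_logits is WordPiece-specific and would need
# adapting for RoBERTa's byte-level BPE pieces.
class RobertaMaskingProcessor(MaskingProcessor):
    def __init__(self):
        self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        self.model = RobertaForMaskedLM.from_pretrained("roberta-base")
        self.stop_words = set(stopwords.words('english'))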
if __name__ == "__main__":
    # !!! Works in both cases, regardless of whether the stopword is removed or not
    sentences = [
        "The quick brown fox jumps over small cat the lazy dog everyday again and again .",
        # "A speedy brown fox jumps over a lazy dog.",
        # "A swift brown fox leaps over the lethargic dog."
    ]
    result_dict = {
        'The quick brown fox jumps over small cat the lazy dog everyday again and again .': {
            'brown fox': [(2, 3)], 'cat': [(7, 7)], 'dog': [(10, 10)]
        },
        # 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        # 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    }

    processor = MaskingProcessor()
    # results_random = processor.process_sentences(sentences, result_dict)
    results_entropy = processor.process_sentences(sentences, result_dict, method="random")

    '''
    results structure:
    results = {
        "The quick brown fox jumps over the lazy dog everyday.": {  # Original sentence as key
            "masked_sentence": str,   # The sentence with [MASK] tokens
            "mask_logits": {          # Dictionary of mask positions and their predictions
                1: {                  # Position of the mask in the sentence
                    "tokens": list,   # Top 50 predicted words
                    "logits": list    # Corresponding raw logits (not probabilities)
                },
                7: {
                    "tokens": list,
                    "logits": list
                },
                10: {
                    "tokens": list,
                    "logits": list
                }
            }
        }
    }
    '''

    # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)
    for sentence, output in results_entropy.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")
        # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
        # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
        # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
        # print('--------------------------------')
        # for mask_idx, logits in output["mask_logits"].items():
        #     print(f"Logits for [MASK] at position {mask_idx}:")
        #     print(f' logits : {logits[:5]}')  # List of logits for all vocabulary tokens
        #     print(f' len(logits) : {len(logits)}')
# ------------------------------------------------------------------------------------------------
# Earlier version of mask_sentence_random, kept for reference (no end-of-sentence
# punctuation handling):
# def mask_sentence_random(self, sentence, common_ngrams):
#     """
#     Mask words in the sentence based on the specified rules after removing stopwords.
#     """
#     original_words = sentence.split()
#     # print(f' ---- original_words : {original_words} ----- ')
#     non_stop_words = self.remove_stopwords(original_words)
#     # print(f' ---- non_stop_words : {non_stop_words} ----- ')
#     adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
#     # print(f' ---- common_ngrams : {common_ngrams} ----- ')
#     # print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')
#     mask_indices = []
#     # Extract n-gram positions in non-stop words
#     ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
#     # Mask a word before the first common n-gram
#     if ngram_positions:
#         # print(f' ---- ngram_positions : {ngram_positions} ----- ')
#         first_ngram_start = ngram_positions[0][0]
#         # print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
#         if first_ngram_start > 0:
#             mask_index_before_ngram = random.randint(0, first_ngram_start - 1)
#             # print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
#             mask_indices.append(mask_index_before_ngram)
#         # Mask words between common n-grams
#         for i in range(len(ngram_positions) - 1):
#             end_prev = ngram_positions[i][1]
#             # print(f' ---- end_prev : {end_prev} ----- ')
#             start_next = ngram_positions[i + 1][0]
#             # print(f' ---- start_next : {start_next} ----- ')
#             if start_next > end_prev + 1:
#                 mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
#                 # print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
#                 mask_indices.append(mask_index_between_ngrams)
#         # Mask a word after the last common n-gram
#         last_ngram_end = ngram_positions[-1][1]
#         if last_ngram_end < len(non_stop_words) - 1:
#             # print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
#             mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
#             # print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
#             mask_indices.append(mask_index_after_ngram)
#     # Create mapping from non-stop words to original indices
#     non_stop_to_original = {}
#     non_stop_idx = 0
#     for orig_idx, word in enumerate(original_words):
#         if word.lower() not in self.stop_words:
#             non_stop_to_original[non_stop_idx] = orig_idx
#             non_stop_idx += 1
#     # Map mask indices from non-stop word positions to original positions
#     # print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
#     original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
#     # print(f' ---- original_mask_indices : {original_mask_indices} ----- ')
#     # Apply masks to the original sentence
#     masked_words = original_words.copy()
#     for idx in original_mask_indices:
#         masked_words[idx] = self.tokenizer.mask_token
#     return " ".join(masked_words), original_mask_indices