import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
import random
from nltk.corpus import stopwords
import nltk
from vocabulary_split import split_vocabulary, filter_logits
# Load tokenizer and model for the masked language model
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)

# Get the permissible vocabulary and build a boolean mask over token ids
permissible, _ = split_vocabulary(seed=42)
permissible_indices = torch.tensor([i in permissible.values() for i in range(len(tokenizer))])
# Ensure NLTK resources are downloaded, then initialize stop words
nltk.download('stopwords', quiet=True)
nltk.download('punkt', quiet=True)
nltk.download('averaged_perceptron_tagger', quiet=True)
nltk.download('maxent_ne_chunker', quiet=True)
nltk.download('words', quiet=True)
stop_words = set(stopwords.words('english'))

def get_logits_for_mask(sentence):
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Logits at the [MASK] position(s): 1-D for a single mask, 2-D (one row per mask) otherwise
    return logits[0, mask_token_index, :].squeeze()

def mask_word(sentence, word):
    masked_sentence = sentence.replace(word, '[MASK]', 1)
    logits = get_logits_for_mask(masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    # Decode the top-5 candidate tokens for the masked position
    words = [tokenizer.decode([i]) for i in filtered_logits.topk(5).indices.tolist()]
    return masked_sentence, filtered_logits.tolist(), words

def mask_non_stopword(sentence, pseudo_random=False):
    non_stop_words = [word for word in sentence.split() if word.lower() not in stop_words]
    if not non_stop_words:
        return sentence, None, None
    if pseudo_random:
        random.seed(10)  # Fixed seed for pseudo-randomness
    word_to_mask = random.choice(non_stop_words)
    return mask_word(sentence, word_to_mask)

def mask_between_lcs(sentence, lcs_points):
    words = sentence.split()
    masked_indices = []

    # Mask one word before the first LCS point
    if lcs_points and lcs_points[0] > 0:
        idx = random.randint(0, lcs_points[0] - 1)
        words[idx] = '[MASK]'
        masked_indices.append(idx)

    # Mask one word between each pair of consecutive LCS points
    for i in range(len(lcs_points) - 1):
        start, end = lcs_points[i], lcs_points[i + 1]
        if end - start > 1:
            mask_index = random.randint(start + 1, end - 1)
            words[mask_index] = '[MASK]'
            masked_indices.append(mask_index)

    # Mask one word after the last LCS point
    if lcs_points and lcs_points[-1] < len(words) - 1:
        idx = random.randint(lcs_points[-1] + 1, len(words) - 1)
        words[idx] = '[MASK]'
        masked_indices.append(idx)

    masked_sentence = ' '.join(words)
    logits = get_logits_for_mask(masked_sentence)
    # get_logits_for_mask returns one row per [MASK] token, in order of appearance;
    # restore the leading dimension that squeeze() drops when there is only one mask
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    logits_list, top_words_list = [], []
    # masked_indices is built in ascending order, so the i-th index corresponds to the i-th mask
    for mask_pos in range(len(masked_indices)):
        filtered_logits = filter_logits(logits[mask_pos], permissible_indices)
        logits_list.append(filtered_logits.tolist())
        top_words = [tokenizer.decode([i]) for i in filtered_logits.topk(5).indices.tolist()]
        top_words_list.append(top_words)
    return masked_sentence, logits_list, top_words_list

def high_entropy_words(sentence, non_melting_points):
    non_melting_words = {word.lower() for _, point in non_melting_points for word in point.split()}
    candidate_words = [word for word in sentence.split() if word.lower() not in stop_words and word.lower() not in non_melting_words]
    if not candidate_words:
        return sentence, None, None

    max_entropy, max_entropy_word, max_logits = -float('inf'), None, None
    for word in candidate_words:
        masked_sentence = sentence.replace(word, '[MASK]', 1)
        logits = get_logits_for_mask(masked_sentence)
        filtered_logits = filter_logits(logits, permissible_indices)
        # Entropy over the top-5 candidate probabilities; the epsilon avoids log(0)
        probs = torch.softmax(filtered_logits, dim=-1)
        top_5_probs = probs.topk(5).values
        entropy = -torch.sum(top_5_probs * torch.log(top_5_probs + 1e-10))
        if entropy > max_entropy:
            max_entropy, max_entropy_word, max_logits = entropy, word, filtered_logits

    if max_entropy_word is None:
        return sentence, None, None

    masked_sentence = sentence.replace(max_entropy_word, '[MASK]', 1)
    words = [tokenizer.decode([i]) for i in max_logits.topk(5).indices.tolist()]
    return masked_sentence, max_logits.tolist(), words

def mask_by_pos(sentence, pos_to_mask=('NOUN', 'VERB', 'ADJ')):
    words = nltk.word_tokenize(sentence)
    pos_tags = nltk.pos_tag(words)
    # pos_tag returns Penn Treebank tags (NN*, VB*, JJ*); map their prefixes onto the
    # coarse categories used in pos_to_mask
    penn_to_coarse = {'NN': 'NOUN', 'VB': 'VERB', 'JJ': 'ADJ'}
    maskable_words = [word for word, pos in pos_tags if penn_to_coarse.get(pos[:2]) in pos_to_mask]
    if not maskable_words:
        return sentence, None, None
    word_to_mask = random.choice(maskable_words)
    return mask_word(sentence, word_to_mask)

def mask_named_entity(sentence):
    words = nltk.word_tokenize(sentence)
    pos_tags = nltk.pos_tag(words)
    named_entities = nltk.ne_chunk(pos_tags)
    # ne_chunk returns a tree whose named-entity chunks are nltk.Tree subtrees;
    # plain (word, tag) tuples are not entities, so collect words from the subtrees only
    maskable_words = [word for chunk in named_entities if isinstance(chunk, nltk.Tree) for word, _ in chunk.leaves()]
    if not maskable_words:
        return sentence, None, None
    word_to_mask = random.choice(maskable_words)
    return mask_word(sentence, word_to_mask)