Spaces:
Running
Running
import random | |
from transformers import AutoTokenizer, AutoModelForMaskedLM | |
import torch | |
# Checkpoint shared by the tokenizer and the masked-LM model below; keeping it
# in one place ensures the two are always loaded from the same checkpoint.
MODEL_NAME = "bert-base-uncased"

# Loaded once at import time and reused by the usage code further down.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
def split_vocabulary(seed=42):
    """Randomly partition the ``bert-base-uncased`` vocabulary into two buckets.

    Every token is assigned by an independent fair coin flip to either the
    permissible or the non-permissible bucket.

    Args:
        seed: Seed handed to ``random.seed``. Note that this reseeds the
            module-level ``random`` generator as a side effect.

    Returns:
        A ``(permissible, non_permissible)`` tuple of dicts, each mapping a
        token string to its vocabulary index. Every token of the vocabulary
        appears in exactly one of the two dicts.
    """
    # Only the tokenizer is needed to enumerate the vocabulary; loading the
    # masked-LM weights here would be wasted work on every call.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Walk the vocabulary in index order so that the n-th coin flip always
    # lands on the same token. get_vocab() does not document an ordering for
    # the dict it returns, so the order is fixed explicitly here.
    vocab = sorted(tokenizer.get_vocab().items(), key=lambda item: item[1])

    # Initialize the random number generator
    random.seed(seed)

    # Split the vocabulary into permissible and non-permissible buckets
    permissible = {}
    non_permissible = {}
    for word, index in vocab:
        if random.random() < 0.5:  # 50% chance of being permissible
            permissible[word] = index
        else:
            non_permissible[word] = index
    return permissible, non_permissible
def get_logits_for_mask(model, tokenizer, sentence):
    """Return the model's vocabulary logits at the mask position(s) of a sentence.

    Args:
        model: Masked-LM model. It is called as ``model(**inputs)`` and its
            result must expose a ``logits`` attribute indexed as
            ``[batch, position, vocabulary]``.
        tokenizer: Callable used as ``tokenizer(sentence, return_tensors="pt")``
            that returns a mapping with an ``"input_ids"`` tensor, and that
            exposes ``mask_token_id``.
        sentence: Text containing at least one mask token.

    Returns:
        A 1-D tensor of vocabulary logits when the sentence contains exactly
        one mask token; a 2-D tensor with one row per mask token (in sentence
        order) when it contains several.

    Raises:
        ValueError: If the tokenized sentence contains no mask token.
    """
    inputs = tokenizer(sentence, return_tensors="pt")

    # torch.where on the 2-D id matrix yields (row, column) index tensors;
    # the column tensor holds the positions of the mask tokens.
    mask_token_index = torch.where(
        inputs["input_ids"] == tokenizer.mask_token_id
    )[1]
    if mask_token_index.numel() == 0:
        # Without this check the function would quietly return an empty
        # tensor and the failure would only surface later, during sampling.
        raise ValueError("sentence does not contain the tokenizer's mask token")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    mask_token_logits = outputs.logits[0, mask_token_index, :]
    # Drop only the mask axis (and only when there is a single mask), so the
    # vocabulary axis is never touched.
    return mask_token_logits.squeeze(0)
def filter_logits(logits, permissible_indices):
    """Return a copy of ``logits`` with non-permissible vocabulary entries set to -inf.

    Args:
        logits: Tensor whose last axis is the vocabulary axis. Either a 1-D
            vector of logits or a 2-D tensor with one row per mask position.
            Size-1 axes are squeezed away, so a ``(1, vocab)`` input yields a
            ``(vocab,)`` result. The input tensor itself is not modified.
        permissible_indices: 1-D boolean mask over vocabulary ids, ``True``
            where the id may be sampled. Anything ``torch.as_tensor`` can
            convert to a boolean tensor is accepted.

    Returns:
        A new tensor in which every position whose mask entry is ``False``
        holds ``-inf``; all other values are unchanged.
    """
    filtered_logits = logits.clone()
    if filtered_logits.dim() > 1:
        filtered_logits = filtered_logits.squeeze()

    allowed = torch.as_tensor(
        permissible_indices, dtype=torch.bool, device=filtered_logits.device
    )

    # Align the mask with the vocabulary (last) axis, never with the row
    # axis: with several mask positions the first axis counts rows.
    vocab_size = filtered_logits.shape[-1]
    mask_size = allowed.shape[0]
    if mask_size > vocab_size:
        allowed = allowed[:vocab_size]
    elif mask_size < vocab_size:
        # Ids the mask says nothing about are treated as non-permissible.
        padding = torch.zeros(
            vocab_size - mask_size, dtype=torch.bool, device=allowed.device
        )
        allowed = torch.cat([allowed, padding])

    # The 1-D mask broadcasts across rows, so each row is filtered alike.
    filtered_logits.masked_fill_(~allowed, float("-inf"))
    return filtered_logits
# Usage example
permissible, non_permissible = split_vocabulary(seed=42)

# Boolean mask over vocabulary ids: True where the id is permissible.
# Membership is tested against a set built once; testing against
# permissible.values() directly would rescan all values for every id.
permissible_ids = set(permissible.values())
permissible_indices = torch.tensor(
    [i in permissible_ids for i in range(len(tokenizer))]
)

# When sampling:
sentence = "The [MASK] is bright today."
logits = get_logits_for_mask(model, tokenizer, sentence)
filtered_logits = filter_logits(logits, permissible_indices)
# Use filtered_logits for sampling