"""Split a BERT vocabulary into permissible / non-permissible buckets and
filter masked-LM logits so sampling only draws from the permissible set."""

import random

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")


def split_vocabulary(seed=42):
    """Randomly partition the tokenizer vocabulary into two buckets.

    Args:
        seed: RNG seed, so the split is reproducible across runs.

    Returns:
        (permissible, non_permissible): two dicts mapping token string ->
        vocab index; each token lands in exactly one bucket with 50%
        probability.
    """
    # Use the module-level tokenizer. The original re-loaded both the
    # tokenizer and a masked-LM model here; the model was never used and
    # the reload is expensive.
    vocab = tokenizer.get_vocab().items()

    # Local RNG: same sequence as random.seed(seed) + random.random(),
    # but without clobbering the process-global random state.
    rng = random.Random(seed)

    permissible = {}
    non_permissible = {}
    for token, index in vocab:
        if rng.random() < 0.5:  # 50% chance of being permissible
            permissible[token] = index
        else:
            non_permissible[token] = index
    return permissible, non_permissible


def get_logits_for_mask(model, tokenizer, sentence):
    """Return the masked-LM logits at each [MASK] position in `sentence`.

    Args:
        model: a masked-LM model returning `.logits` of shape
            (batch, seq_len, vocab_size).
        tokenizer: the matching tokenizer (provides `mask_token_id`).
        sentence: input text containing at least one mask token.

    Returns:
        A 1-D (vocab_size,) tensor when the sentence has exactly one mask,
        otherwise a 2-D (num_masks, vocab_size) tensor.

    Raises:
        ValueError: if the sentence contains no mask token (the original
            silently returned an empty tensor, which corrupted downstream
            filtering).
    """
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_positions = torch.where(
        inputs["input_ids"] == tokenizer.mask_token_id
    )[1]
    if mask_positions.numel() == 0:
        raise ValueError(
            f"sentence contains no {tokenizer.mask_token!r} token"
        )
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model(**inputs)
    mask_token_logits = outputs.logits[0, mask_positions, :]
    return mask_token_logits.squeeze()


def filter_logits(logits, permissible_indices):
    """Mask out non-permissible vocabulary entries with -inf.

    Args:
        logits: tensor of vocab logits; squeezed to 1-D if higher-rank.
        permissible_indices: boolean tensor, True where the token may be
            sampled.

    Returns:
        A new tensor (the input is cloned, not mutated) where every
        non-permissible position is float('-inf').
    """
    filtered_logits = logits.clone()
    if filtered_logits.dim() > 1:
        filtered_logits = filtered_logits.squeeze()
    if filtered_logits.shape != permissible_indices.shape:
        # Defensive truncation kept from the original; a mismatch normally
        # indicates a tokenizer/model vocab-size disagreement — TODO confirm
        # upstream rather than relying on silent truncation.
        permissible_indices = permissible_indices[:filtered_logits.shape[0]]
    filtered_logits[~permissible_indices] = float('-inf')
    return filtered_logits


# Usage example
permissible, non_permissible = split_vocabulary(seed=42)
# Build the boolean mask via a set: `i in dict.values()` is a linear scan
# per token, making the original construction O(vocab^2); the set lookup
# makes it O(vocab).
permissible_ids = set(permissible.values())
permissible_indices = torch.tensor(
    [i in permissible_ids for i in range(len(tokenizer))]
)

# When sampling:
sentence = "The [MASK] is bright today."
logits = get_logits_for_mask(model, tokenizer, sentence)
filtered_logits = filter_logits(logits, permissible_indices)
# Use filtered_logits for sampling