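"""Sample replacement tokens for the [MASK] positions produced by MaskingProcessor."""
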
import torch
import random
import logging

from tqdm import tqdm

from utils.masking_methods import MaskingProcessor

# Configure logging to suppress INFO-level messages on the console.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class SamplingProcessor:
    def __init__(self, tokenizer):
        """
        Initialize the SamplingProcessor.

        Args:
            tokenizer: BERT tokenizer instance
        """
        self.tokenizer = tokenizer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tqdm.write(f"[SamplingProcessor] Initialized on device: {self.device}")

    def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
        """
        Sample a token for each mask in the sentence using the specified sampling technique.

        Args:
            mask_logits_dict (dict): Dictionary mapping mask positions to their logits and candidate tokens
            masked_sentence (str): Sentence with [MASK] tokens
            sampling_technique (str): One of "temperature", "greedy", "inverse_transform",
                "exponential_minimum", or "tournament"
            temperature (float): Temperature parameter for sampling

        Returns:
            str: Sentence with sampled tokens replacing masks
        """
tqdm.write(f"[SamplingProcessor] Sampling tokens for: {masked_sentence}") | |
print(f"[SamplingProcessor] Sampling tokens for: {masked_sentence}") | |
words = masked_sentence.split() | |
print(f"words: {words}") | |
# Convert positions and logits to sorted list to process masks in order | |
mask_positions = sorted(mask_logits_dict.keys()) | |
print(f"mask_positions: {mask_positions}") | |

        for mask_pos in mask_positions:
            mask_data = mask_logits_dict[mask_pos]
            # Move the candidate logits onto the configured device (GPU if available).
            mask_logits = torch.tensor(mask_data['logits'], device=self.device)
            candidate_tokens = mask_data['tokens']

            try:
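                # Inverse-transform sampling: draw u ~ Uniform(0, 1) and take the
                # first index whose cumulative probability reaches u.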
if sampling_technique == "inverse_transform": | |
probs = torch.softmax(mask_logits / temperature, dim=-1) | |
cumulative_probs = torch.cumsum(probs, dim=-1) | |
random_prob = random.random() | |
sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item() | |
elif sampling_technique == "exponential_minimum": | |
probs = torch.softmax(mask_logits / temperature, dim=-1) | |
exp_probs = torch.exp(-torch.log(probs)) | |
random_probs = torch.rand_like(exp_probs) | |
sampled_index = torch.argmax(random_probs * exp_probs).item() | |
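                # Temperature sampling: dividing logits by the temperature sharpens
                # (T < 1) or flattens (T > 1) the distribution before drawing.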
elif sampling_technique == "temperature": | |
mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8) | |
probs = torch.softmax(mask_logits / temperature, dim=-1) | |
if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)): | |
raise ValueError("The computed probabilities contain NaN or inf values.") | |
probs = torch.max(probs, torch.tensor(1e-8).to(self.device)) | |
probs = probs / torch.sum(probs) | |
probs = probs.flatten() | |
if probs.size(0) > 1: | |
sampled_index = torch.multinomial(probs, 1).item() | |
else: | |
sampled_index = torch.argmax(probs).item() | |
                elif sampling_technique == 'greedy':
                    # Greedy decoding: always take the highest-logit candidate.
                    sampled_index = torch.argmax(mask_logits).item()
                elif sampling_technique == 'tournament':
                    # Apply temperature and get probabilities
                    probs = torch.softmax(mask_logits / temperature, dim=-1)
                    # Number of candidates to select for the tournament
                    num_candidates = min(5, len(candidate_tokens))
                    # Sample candidates based on their probabilities
                    if probs.size(0) > num_candidates:
                        candidate_indices = torch.multinomial(probs, num_candidates, replacement=False)
                    else:
                        # If we have fewer tokens than the number of candidates,
                        # just use all available tokens
                        candidate_indices = torch.arange(probs.size(0))

                    # Run tournament rounds
                    while candidate_indices.size(0) > 1:
                        next_round = []
                        # Process pairs of candidates
                        for i in range(0, candidate_indices.size(0), 2):
                            # If we have an odd number of candidates, the last one gets a bye
                            if i + 1 >= candidate_indices.size(0):
                                next_round.append(candidate_indices[i].item())
                                continue
                            candidate1, candidate2 = candidate_indices[i], candidate_indices[i + 1]
                            prob1, prob2 = probs[candidate1], probs[candidate2]
                            # Winner determined by probability comparison
                            winner = candidate1 if prob1 > prob2 else candidate2
                            next_round.append(winner.item())
                        # Update candidates for next round
                        candidate_indices = torch.tensor(next_round, device=self.device)

                    # The remaining candidate is our winner
                    sampled_index = candidate_indices[0].item()
                else:
                    raise ValueError(f"Unknown sampling technique: {sampling_technique}")

                # Use the sampled index to get the corresponding token
                sampled_token = candidate_tokens[sampled_index]
                # Remove '##' if it's a subword token
                sampled_token = sampled_token.replace('##', '')
                words[mask_pos] = sampled_token
                logger.info(f"Sampled token '{sampled_token}' for mask position {mask_pos}.")

            except Exception as e:
                logger.error(f"Error sampling for position {mask_pos}: {str(e)}")
                continue
sampled_sentence = " ".join(words) | |
tqdm.write(f"[SamplingProcessor] Sampled sentence: {sampled_sentence}") | |
return sampled_sentence | |

    def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
        """
        Process all masked sentences in the results dictionary.

        Args:
            results_dict (dict): Dictionary containing masked sentences and their logits
            sampling_technique (str): Sampling method to use
            temperature (float): Temperature parameter for sampling

        Returns:
            dict: Dictionary containing original, masked, and sampled sentences
        """
tqdm.write("[SamplingProcessor] Starting sampling for masked sentences.") | |
processed_results = {} | |
# Wrap the iteration over each original sentence with tqdm | |
for original_sentence, data in tqdm(results_dict.items(), desc="Sampling Masked Sentences"): | |
masked_sentence = data["masked_sentence"] | |
mask_logits = data["mask_logits"] | |
sampled_sentence = self.sample_tokens(mask_logits, | |
masked_sentence, | |
sampling_technique, | |
temperature) | |
processed_results[original_sentence] = { | |
"masked_sentence": masked_sentence, | |
"sampled_sentence": sampled_sentence | |
} | |
logger.info(f"Processed sampling for sentence: {original_sentence}") | |
tqdm.write("[SamplingProcessor] Completed sampling for all sentences.") | |
return processed_results | |
if __name__ == "__main__": | |
sentences = [ | |
"The quick brown fox jumps over the lazy dog everyday.", | |
"A speedy brown fox jumps over a lazy dog.", | |
"A swift brown fox leaps over the lethargic dog." | |
] | |
result_dict = { | |
'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}, | |
'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}, | |
'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]} | |
} | |
# First, mask the sentences | |
masking_processor = MaskingProcessor() | |
masking_results = masking_processor.process_sentences(sentences, result_dict) | |
# Then, sample replacements for the masks | |
sampling_processor = SamplingProcessor(masking_processor.tokenizer) | |
# Try different sampling techniques | |
sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum", "Tournament"] | |

    for technique in sampling_techniques:
        logger.info(f"Sampling using technique: {technique}")
        sampled_results = sampling_processor.process_masked_sentences(
            masking_results,
            sampling_technique=technique,
            temperature=1.0
        )
        '''
        Shape of sampled_results:
        {
            "original_sentence_1": {
                "masked_sentence": "sentence with [MASK] tokens",
                "sampled_sentence": "sentence with sampled tokens"
            },
            "original_sentence_2": {
                "masked_sentence": "sentence with [MASK] tokens",
                "sampled_sentence": "sentence with sampled tokens"
            },
            # ... and so on for each input sentence
        }
        '''
        for original_sentence, result in sampled_results.items():
            logger.info(f"Original: {original_sentence}")
            logger.info(f"Masked: {result['masked_sentence']}")
            logger.info(f"Sampled: {result['sampled_sentence']}")
            logger.info("---")