# aiisc-watermarking-modelv3 / masking_methods_trial.py
import abc
import random
from typing import List

import nltk
import torch
from nltk.corpus import stopwords
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline

from vocabulary_split import split_vocabulary, filter_logits

nltk.download('stopwords', quiet=True)

# Load tokenizer and model for masked language model
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
# Get the permissible half of the vocabulary (deterministic for a fixed seed)
permissible, _ = split_vocabulary(seed=42)
permissible_values = set(permissible.values())  # set membership is O(1)
permissible_indices = torch.tensor([i in permissible_values for i in range(len(tokenizer))])
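# `filter_logits` comes from the local vocabulary_split module, which is not
# shown here. It is assumed to behave roughly like the commented sketch below:
# logits at impermissible vocabulary positions are pushed to -inf so they can
# never win a softmax or topk. The real implementation may differ.
#
#     def filter_logits(logits, permissible_indices):
#         filtered = logits.clone()
#         filtered[..., ~permissible_indices] = float('-inf')
#         return filtered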
def get_logits_for_mask(model, tokenizer, sentence):
    """Return the MLM logits at every [MASK] position, shape [num_masks, vocab_size]."""
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    with torch.no_grad():
        outputs = model(**inputs)
    # Keep the 2-D [num_masks, vocab_size] shape (no squeeze) so a single mask
    # and multiple masks are handled uniformly by callers.
    return outputs.logits[0, mask_token_index, :]
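# Example (hypothetical input):
#     logits = get_logits_for_mask(model, tokenizer, "Paris is the [MASK] of France.")
#     logits.shape  ->  torch.Size([1, vocab_size]); one row per [MASK] token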
# Abstract masking strategy
class MaskingStrategy(abc.ABC):
    @abc.abstractmethod
    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        """Given a list of words, return the indices of the words to mask."""
# Specific masking strategies
class RandomNonStopwordMasking(MaskingStrategy):
    """Mask up to num_masks uniformly sampled non-stopword positions."""

    def __init__(self, num_masks: int = 1):
        self.num_masks = num_masks
        self.stop_words = set(stopwords.words('english'))

    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        non_stop_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        if not non_stop_indices:
            return []
        num_masks = min(self.num_masks, len(non_stop_indices))
        return random.sample(non_stop_indices, num_masks)
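# Example: with the default English stopword list, index 0 ("The") is never
# chosen, so the call below returns a random subset of {1, 2, 3}:
#     RandomNonStopwordMasking(num_masks=2).select_words_to_mask(
#         "The quick brown fox".split())  ->  e.g. [3, 1]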
class HighEntropyMasking(MaskingStrategy):
    """Mask the positions whose top-5 predictive distribution has the highest entropy."""

    def __init__(self, num_masks: int = 1):
        self.num_masks = num_masks
        self.stop_words = set(stopwords.words('english'))  # computed once, not per call

    def select_words_to_mask(self, words: List[str], sentence: str, model, tokenizer,
                             permissible_indices, **kwargs) -> List[int]:
        candidate_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        if not candidate_indices:
            return []
        entropy_scores = {}
        for idx in candidate_indices:
            masked_sentence = ' '.join(words[:idx] + ['[MASK]'] + words[idx + 1:])
            logits = get_logits_for_mask(model, tokenizer, masked_sentence).squeeze(0)
            filtered_logits = filter_logits(logits, permissible_indices)
            probs = torch.softmax(filtered_logits, dim=-1)
            top_5_probs = probs.topk(5).values
            # Shannon entropy H = -sum(p * log p) over the top-5 candidates;
            # the epsilon guards against log(0).
            entropy = -torch.sum(top_5_probs * torch.log(top_5_probs + 1e-10)).item()
            entropy_scores[idx] = entropy
        # Select the num_masks indices with the highest entropy
        sorted_indices = sorted(entropy_scores, key=entropy_scores.get, reverse=True)
        return sorted_indices[:self.num_masks]
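# Entropy sanity check on a toy top-5 distribution (no model needed):
#     probs = torch.tensor([0.5, 0.2, 0.15, 0.1, 0.05])
#     -(probs * torch.log(probs)).sum()  ->  ~1.33 nats
# A flatter distribution scores higher, so the positions where the model is
# least certain are the ones selected for masking.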
class PseudoRandomNonStopwordMasking(MaskingStrategy):
    """Like RandomNonStopwordMasking, but deterministic for a fixed seed."""

    def __init__(self, num_masks: int = 1, seed: int = 10):
        self.num_masks = num_masks
        self.seed = seed
        self.stop_words = set(stopwords.words('english'))

    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        non_stop_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        if not non_stop_indices:
            return []
        # Use a private, per-call RNG so repeated calls are reproducible
        # without reseeding Python's global random state.
        rng = random.Random(self.seed)
        num_masks = min(self.num_masks, len(non_stop_indices))
        return rng.sample(non_stop_indices, num_masks)
class CompositeMaskingStrategy(MaskingStrategy):
    """Run several strategies and take the union of their selections."""

    def __init__(self, strategies: List[MaskingStrategy]):
        self.strategies = strategies

    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        selected_indices = []
        for strategy in self.strategies:
            # Every strategy accepts **kwargs, so the extra context needed by
            # HighEntropyMasking (sentence, model, ...) can be forwarded blindly.
            selected_indices.extend(strategy.select_words_to_mask(words, **kwargs))
        return list(set(selected_indices))  # de-duplicate
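# Example: combining a random picker with the entropy picker. The keyword
# arguments are only consumed by HighEntropyMasking; the others ignore them:
#     composite = CompositeMaskingStrategy(
#         [RandomNonStopwordMasking(1), HighEntropyMasking(1)])
#     composite.select_words_to_mask(words, sentence=sentence, model=model,
#                                    tokenizer=tokenizer,
#                                    permissible_indices=permissible_indices)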
# Refactored mask_between_lcs function
def mask_between_lcs(sentence, lcs_points, masking_strategy: MaskingStrategy, model, tokenizer, permissible_indices):
    words = sentence.split()
    masked_indices = []

    # Define segments of words between consecutive LCS points
    segments = []
    previous = 0
    for point in lcs_points:
        if point > previous:
            segments.append((previous, point))
        previous = point + 1
    if previous < len(words):
        segments.append((previous, len(words)))

    # Collect the indices to mask from each segment. All strategies accept
    # **kwargs, so the extra context is passed uniformly; strategies that do
    # not need it simply ignore it.
    for start, end in segments:
        segment_words = words[start:end]
        selected = masking_strategy.select_words_to_mask(
            segment_words, sentence=sentence, model=model,
            tokenizer=tokenizer, permissible_indices=permissible_indices)
        # Shift segment-relative indices to sentence-relative ones
        for idx in selected:
            masked_idx = start + idx
            if masked_idx not in masked_indices:
                masked_indices.append(masked_idx)

    # Sort so that masked_indices[i] lines up with the i-th [MASK] token in
    # the sentence (and hence with row i of the logits below).
    masked_indices.sort()

    # Apply masking
    for idx in masked_indices:
        words[idx] = '[MASK]'
    masked_sentence = ' '.join(words)
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)  # [num_masks, vocab_size]

    # Process each masked position
    top_words_list = []
    logits_list = []
    for i, idx in enumerate(masked_indices):
        filtered_logits_i = filter_logits(logits[i], permissible_indices)
        logits_list.append(filtered_logits_i.tolist())
        top_5_indices = filtered_logits_i.topk(5).indices.tolist()
        top_words = [tokenizer.decode([token_id]) for token_id in top_5_indices]
        top_words_list.append(top_words)

    return masked_sentence, logits_list, top_words_list
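# Worked example of the segmentation above: for a 9-word sentence and
# lcs_points = [2, 5, 8], the LCS words themselves are never masked and the
# segments are (0, 2), (3, 5) and (6, 8); there is no tail segment because
# previous == len(words) after the last point.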
# Example usage
if __name__ == "__main__":
    # Example sentence and LCS points
    sentence = "This is a sample sentence with some LCS points"
    lcs_points = [2, 5, 8]  # word indices of the LCS points

    # Initialize masking strategies
    random_non_stopword_strategy = RandomNonStopwordMasking(num_masks=1)
    high_entropy_strategy = HighEntropyMasking(num_masks=1)
    pseudo_random_strategy = PseudoRandomNonStopwordMasking(num_masks=1, seed=10)
    composite_strategy = CompositeMaskingStrategy([
        RandomNonStopwordMasking(num_masks=1),
        HighEntropyMasking(num_masks=1)
    ])

    # Choose a strategy (any of the initialized strategies works)
    chosen_strategy = composite_strategy

    # Apply masking
    masked_sentence, logits_list, top_words_list = mask_between_lcs(
        sentence,
        lcs_points,
        masking_strategy=chosen_strategy,
        model=model,
        tokenizer=tokenizer,
        permissible_indices=permissible_indices
    )

    print("Masked Sentence:", masked_sentence)
    for idx, top_words in enumerate(top_words_list):
        print(f"Top words for mask {idx + 1}:", top_words)