# Source: lipogram_private / forbidden_solution.py
# nathanael-fijalkow — "Improved logprob-based scoring" (4d8bbd9)
import torch
from transformers import LogitsProcessor, AutoModelForCausalLM, AutoTokenizer
# --- EXERCISE 1: La disparition (No 'e' or 'E') ---
# --- Logits Processor to forbid specific tokens ---
class ForbidTokensLogitsProcessor(LogitsProcessor):
    """Logits processor that sets forbidden token logits to -inf.

    The same ids are masked in every row of ``scores`` (every batch element
    and every beam). Ids outside ``scores``' last dimension are ignored
    instead of raising an ``IndexError``.
    """

    def __init__(self, forbidden_token_ids):
        # Sorted, de-duplicated list of token ids to ban at every step.
        self.forbidden_token_ids = sorted(set(forbidden_token_ids))
        # Index tensors built from `forbidden_token_ids`, keyed by
        # (device, vocabulary width). The (possibly very long) Python list is
        # then converted to a tensor once, not at every generation step.
        self._index_cache = {}

    def __call__(self, input_ids, scores):
        scores[:, self._forbidden_index(scores)] = float('-inf')
        return scores

    def _forbidden_index(self, scores):
        """Return a LongTensor of the forbidden ids that fit in `scores`' last dim."""
        width = scores.shape[-1]
        key = (scores.device, width)
        index = self._index_cache.get(key)
        if index is None:
            in_range = [i for i in self.forbidden_token_ids if 0 <= i < width]
            index = torch.tensor(in_range, dtype=torch.long, device=scores.device)
            self._index_cache[key] = index
        return index
class LaDisparition:
    """Generate text without ever using the letter 'e' or 'E' using model.generate().

    Every vocabulary token whose decoded text contains an 'e'/'E' or any
    non-ASCII character is banned at every step through a
    ForbidTokensLogitsProcessor. The end-of-sequence token(s) are exempt.
    They are special tokens, so ``skip_special_tokens=True`` removes them from
    the returned text. Banning them (``<|im_end|>``-style markers usually
    contain an 'e') would stop the model from ever finishing before
    ``max_tokens``.
    """

    def __init__(self, model, tokenizer, debug=False):
        self.model = model
        self.tokenizer = tokenizer
        self.debug = debug  # stored but not used by this class
        # Pre-calculate forbidden token IDs (tokens containing 'e', 'E', or
        # non-ASCII characters), leaving the stop token(s) usable.
        self.forbidden_token_ids = self._compute_forbidden_token_ids(
            tokenizer, exempt_ids=self._stop_token_ids(model, tokenizer)
        )
        self.processor = ForbidTokensLogitsProcessor(self.forbidden_token_ids)

    @staticmethod
    def _is_forbidden_text(text):
        """Return True if `text` contains 'e'/'E' or any non-ASCII character."""
        return "e" in text.lower() or not text.isascii()

    @classmethod
    def _compute_forbidden_token_ids(cls, tokenizer, exempt_ids=()):
        """Return the set of vocabulary ids whose decoded text is forbidden.

        Iterates the ids actually present in ``tokenizer.get_vocab()`` (rather
        than assuming they are ``0..len-1``) and skips ``exempt_ids``.
        """
        exempt = set(exempt_ids)
        return {
            token_id
            for token_id in tokenizer.get_vocab().values()
            if token_id not in exempt
            and cls._is_forbidden_text(tokenizer.decode([token_id]))
        }

    @staticmethod
    def _stop_token_ids(model, tokenizer):
        """Return the EOS id(s) from the tokenizer and the model's generation config."""
        generation_config = getattr(model, "generation_config", None)
        candidates = (
            getattr(tokenizer, "eos_token_id", None),
            getattr(generation_config, "eos_token_id", None),
        )
        stop_ids = set()
        for candidate in candidates:
            if candidate is None:
                continue
            if isinstance(candidate, int):
                stop_ids.add(candidate)
            else:  # a generation config may list several EOS ids
                stop_ids.update(candidate)
        return stop_ids

    def __call__(self, prompt, max_tokens=30, beam_width=5):
        """Answer `prompt` with beam search and return only the generated text."""
        messages = [{"role": "user", "content": prompt}]
        # Request a dict explicitly: it carries the attention mask, and we do
        # not depend on the library's default return type.
        inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)
        input_ids = inputs["input_ids"]
        prompt_length = input_ids.shape[1]
        outputs = self.model.generate(
            input_ids,
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_tokens,
            num_beams=beam_width,
            logits_processor=[self.processor],
            do_sample=False,
        )
        # Return only the generated part
        generated_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# --- EXERCISE 2: The Toulouse Sequence ---
class ForbidToulousePrefixLogitsProcessor(LogitsProcessor):
    """
    Forbid every word whose first ``min_prefix_len`` letters match "Toulouse".

    At each step we take the word currently being written (the trailing run
    of letters of the decoded sequence). For every candidate token we look at
    the words that token would touch:
    - the current word extended by the token's leading letters, and
    - every word that starts inside the token (e.g. after a leading space,
      as in " Toul").
    A token is masked if any of those words begins, case-insensitively, with
    the first ``min_prefix_len`` letters of ``forbidden_word``. In other words,
    adding it would lead to a prefix of "Toulouse" of length at least 4.

    Each row of ``input_ids``/``scores`` (batch element or beam) is handled
    independently. The decoded text of every vocabulary token is computed
    once at construction time. Each step then costs string comparisons
    instead of one ``decode`` call per vocabulary entry.
    """

    def __init__(self, tokenizer, forbidden_word="toulouse", min_prefix_len=4):
        if not 0 < min_prefix_len <= len(forbidden_word):
            raise ValueError("min_prefix_len must be between 1 and len(forbidden_word)")
        self.tokenizer = tokenizer
        self.forbidden_word = forbidden_word.lower()
        self.min_prefix_len = min_prefix_len
        # Every forbidden word starts with this lower-cased stem (e.g. "toul").
        self._stem = self.forbidden_word[:min_prefix_len]
        # Tokens banned whatever precedes them: a word starting inside the
        # token already begins with the stem.
        self._always_banned = []
        # (token_id, lower-cased leading letters) for tokens that would extend
        # the word currently being written. Only the first `min_prefix_len`
        # letters can matter, so the rest is dropped.
        self._extending = []
        for token_id, text in enumerate(self._decode_vocabulary(tokenizer)):
            leading, inner_words = self._split_letters(text)
            if any(word.lower().startswith(self._stem) for word in inner_words):
                self._always_banned.append(token_id)
            if leading:
                self._extending.append((token_id, leading.lower()[:min_prefix_len]))

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        width = scores.shape[-1]
        for row in range(input_ids.shape[0]):
            # Decode the current sequence to find the word being written.
            current_word = self._trailing_letters(self.tokenizer.decode(input_ids[row]))
            banned = [tid for tid in self._banned_after(current_word.lower()) if tid < width]
            if banned:
                scores[row, banned] = float('-inf')
        return scores

    def _banned_after(self, current_word):
        """Return the token ids to mask while `current_word` (lower-cased) is being written."""
        k = self.min_prefix_len
        banned = list(self._always_banned)
        if len(current_word) >= k:
            # The first k letters of the current word are already fixed.
            # Extending it is only forbidden if they already match the stem.
            if current_word.startswith(self._stem):
                banned.extend(token_id for token_id, _ in self._extending)
        else:
            banned.extend(
                token_id
                for token_id, leading in self._extending
                if (current_word + leading).startswith(self._stem)
            )
        return banned

    @staticmethod
    def _decode_vocabulary(tokenizer):
        """Return, for each token id, the text it contributes after other text.

        With some tokenizers (e.g. SentencePiece-based ones), decoding a token
        on its own may drop its leading space, which would hide a word
        boundary. So each token is decoded after a short anchor and the
        anchor's text is removed again.
        """
        anchor_ids = tokenizer.encode("a", add_special_tokens=False)
        anchor_text = tokenizer.decode(anchor_ids)
        texts = []
        for token_id in range(len(tokenizer)):
            text = tokenizer.decode(anchor_ids + [token_id])
            if text.startswith(anchor_text):
                text = text[len(anchor_text):]
            else:  # decoding altered the anchor; fall back to the bare token
                text = tokenizer.decode([token_id])
            texts.append(text)
        return texts

    @staticmethod
    def _split_letters(text):
        """Split `text` into its leading letters and the letter runs that start later."""
        leading_end = 0
        while leading_end < len(text) and text[leading_end].isalpha():
            leading_end += 1
        inner_words = []
        start = None
        # Scan one past the end so a run touching the end of `text` is closed.
        for index in range(leading_end, len(text) + 1):
            is_letter = index < len(text) and text[index].isalpha()
            if is_letter and start is None:
                start = index
            elif not is_letter and start is not None:
                inner_words.append(text[start:index])
                start = None
        return text[:leading_end], inner_words

    @staticmethod
    def _trailing_letters(text):
        """Return the run of alphabetic characters at the end of `text` (maybe empty)."""
        start = len(text)
        while start > 0 and text[start - 1].isalpha():
            start -= 1
        return text[start:]
class ToulouseSequence:
    """Generate text without ever using the word 'Toulouse' using model.generate().

    Generation runs with ``do_sample=False`` and is constrained by a
    ForbidToulousePrefixLogitsProcessor built from the tokenizer.
    """

    def __init__(self, model, tokenizer, debug=False):
        self.model = model
        self.tokenizer = tokenizer
        self.debug = debug  # stored but not used by this class
        # Masks tokens that would lead to a prefix of "Toulouse" of length >= 4.
        self.processor = ForbidToulousePrefixLogitsProcessor(self.tokenizer)

    def __call__(self, prompt, max_tokens=100):
        """Answer `prompt` and return only the generated text, stripped."""
        messages = [{"role": "user", "content": prompt}]
        # Request a dict explicitly: it carries the attention mask, and we do
        # not depend on the library's default return type.
        inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)
        input_ids = inputs["input_ids"]
        prompt_length = input_ids.shape[1]
        outputs = self.model.generate(
            input_ids,
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_tokens,
            logits_processor=[self.processor],
            do_sample=False,
        )
        # Return only the generated part
        generated_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
if __name__ == "__main__":
    # NOTE: This block is for testing only. The evaluation server provides model and tokenizer.
    # Load one shared model/tokenizer pair, then run each exercise in turn.
    MODEL_NAME = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, dtype=torch.float32, device_map="auto"
    )
    demos = (
        ("Ex 1 (No 'e'):", LaDisparition, "Who are you?"),
        ("Ex 2 (No 'Toulouse'):", ToulouseSequence, "Where is Toulouse?"),
    )
    for label, generator_cls, demo_prompt in demos:
        # Build each generator right before running it, as the original did.
        generator = generator_cls(model, tokenizer)
        print(label, generator(demo_prompt))