# Constrained-decoding exercises: "La Disparition" (generate without 'e') and
# the Toulouse sequence (never spell out "Toulouse").
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
# --- EXERCISE 1: La disparition (no 'e' or 'E') ---
# --- Logits processor to forbid specific tokens ---
class ForbidTokensLogitsProcessor(LogitsProcessor):
    """Force the logits of a fixed set of token ids to -inf so they can never be sampled."""

    def __init__(self, forbidden_token_ids):
        # Materialize once; torch fancy-indexing wants a concrete index list.
        self.forbidden_token_ids = list(forbidden_token_ids)

    def __call__(self, input_ids, scores):
        # Mask the banned ids across every row of the batch/beam.
        banned = self.forbidden_token_ids
        scores[:, banned] = float('-inf')
        return scores
class LaDisparition:
    """Generate text that never contains the letter 'e' or 'E', using model.generate()."""

    def __init__(self, model, tokenizer, debug=False):
        self.model = model
        self.tokenizer = tokenizer
        self.debug = debug
        # Ban every vocab entry whose decoded text contains an 'e'/'E' or any
        # non-ASCII character (blocks accented workarounds such as 'é').
        vocab_size = len(tokenizer.get_vocab())
        self.forbidden_token_ids = {
            tid
            for tid in range(vocab_size)
            if 'e' in tokenizer.decode([tid]).lower()
            or any(ord(ch) >= 128 for ch in tokenizer.decode([tid]))
        }
        self.processor = ForbidTokensLogitsProcessor(self.forbidden_token_ids)

    def __call__(self, prompt, max_tokens=30, beam_width=5):
        # Wrap the prompt in the model's chat template (adds the assistant turn marker).
        chat = [{"role": "user", "content": prompt}]
        input_ids = self.tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        n_prompt = input_ids.shape[1]
        # Deterministic beam search with the e-banning processor attached.
        generated = self.model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_tokens,
            num_beams=beam_width,
            logits_processor=[self.processor],
            do_sample=False
        )
        # Strip the prompt; return only the newly generated completion.
        completion_ids = generated[0][n_prompt:]
        return self.tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
# --- EXERCISE 2: The Toulouse Sequence ---
class ForbidToulousePrefixLogitsProcessor(LogitsProcessor):
    """
    Forbid any continuation that would spell a prefix of "Toulouse" of length
    at least 4 (case-insensitive).

    When generating, we store the largest suffix since whitespace.
    We mask out all tokens that if added would lead to a prefix of "Toulouse"
    of length at least 4.
    """
    def __init__(self, tokenizer):
        # Tokenizer is needed at every step to decode the running sequence
        # and each candidate continuation.
        self.tokenizer = tokenizer
        # Compared in lowercase for case-insensitivity.
        self.forbidden_word = "toulouse"
        # Only prefixes of this length or longer are banned ("Toul", "Toulo", ...).
        self.min_prefix_len = 4
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # NOTE(review): only batch row 0 is inspected and masked (input_ids[0],
        # scores[0, ...]) — this assumes greedy, batch-size-1 decoding; confirm
        # before reusing with beam search or batched generation.
        current_sequence_ids = input_ids[0]
        # Decode the current sequence to find the last word
        decoded_sequence = self.tokenizer.decode(current_sequence_ids)
        # Find the start of the last word (suffix since the last non-alphabetical character)
        last_separator_idx = -1
        for i in range(len(decoded_sequence) - 1, -1, -1):
            if not decoded_sequence[i].isalpha():
                last_separator_idx = i
                break
        if last_separator_idx != -1:
            current_word_prefix = decoded_sequence[last_separator_idx + 1:]
        else:
            # No separator anywhere: the whole decoded text is one word-in-progress.
            current_word_prefix = decoded_sequence
        # If the current word prefix is empty, we don't need to check anything yet
        if not current_word_prefix:
            return scores
        # print(f"Current word prefix: '{current_word_prefix}'")
        # Get the token IDs for the current word prefix to avoid re-tokenizing the whole sequence
        current_word_ids = self.tokenizer.encode(current_word_prefix, add_special_tokens=False)
        # Iterate over all possible next tokens.
        # NOTE(review): this decodes every vocab entry on every generation step
        # (O(vocab) tokenizer.decode calls) — correct but slow; consider caching
        # per-token decodings if this becomes a bottleneck.
        for token_id in range(scores.shape[1]):
            # Create a hypothetical next word by adding the candidate token
            hypothetical_word_ids = current_word_ids + [token_id]
            hypothetical_word = self.tokenizer.decode(hypothetical_word_ids)
            # Check if the hypothetical word is a forbidden prefix
            # We check against the lowercase version for case-insensitivity
            if len(hypothetical_word) >= self.min_prefix_len and \
                self.forbidden_word.startswith(hypothetical_word.lower()):
                scores[0, token_id] = float('-inf')
                # print(f"Forbidden prefix: '{hypothetical_word}'")
        return scores
class ToulouseSequence:
    """Generate text that never writes the word 'Toulouse', using model.generate()."""

    def __init__(self, model, tokenizer, debug=False):
        self.model = model
        self.tokenizer = tokenizer
        self.debug = debug
        # Prefix-masking processor: bans any >=4-character start of "Toulouse".
        self.processor = ForbidToulousePrefixLogitsProcessor(self.tokenizer)

    def __call__(self, prompt, max_tokens=100):
        # Wrap the prompt in the model's chat template (adds the assistant turn marker).
        chat = [{"role": "user", "content": prompt}]
        input_ids = self.tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        n_prompt = input_ids.shape[1]
        # Deterministic greedy decoding with the Toulouse-prefix ban attached.
        generated = self.model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_tokens,
            logits_processor=[self.processor],
            do_sample=False
        )
        # Strip the prompt; return only the newly generated completion.
        completion_ids = generated[0][n_prompt:]
        return self.tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
if __name__ == "__main__":
    # Local smoke test only — the evaluation server provides model and tokenizer.
    model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
    tok = AutoTokenizer.from_pretrained(model_id)
    lm = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, device_map="auto")
    print("Ex 1 (No 'e'):", LaDisparition(lm, tok)("Who are you?"))
    print("Ex 2 (No 'Toulouse'):", ToulouseSequence(lm, tok)("Where is Toulouse?"))