import argparse
import time
from typing import Optional, Set

import numpy as np
import onnxruntime as ort
import torch

from .model import ByteTokenizer

# Characters that should break DRY sequence matching; callers convert these to
# token IDs before passing them as `sequence_breakers` / `dry_sequence_breakers`.
sequence_breaker_strings = ["\n", ":", '"', "*", "<", ">", "|"]


class DRYLogitsProcessor:
    """
    Don't Repeat Yourself (DRY) logits processor that penalizes repetitive
    sequences. The penalty grows exponentially with the length of the repeated
    sequence: multiplier * base ** (match_length - allowed_length).
    """

    def __init__(
        self,
        multiplier: float = 0.5,
        base: float = 2.0,
        allowed_length: int = 1,
        sequence_breakers: Optional[Set[int]] = None,
        range: int = 512,
    ):
        """
        Args:
            multiplier: Base penalty multiplier
            base: Exponential base for penalty calculation
            allowed_length: Length of sequence that's allowed to repeat without penalty
            sequence_breakers: Set of token IDs that should break sequence matching
            range: Number of previous tokens to consider for repetition checking
        """
        self.multiplier = multiplier
        self.base = base
        self.allowed_length = allowed_length
        self.sequence_breakers = sequence_breakers or set()
        self.range = range

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """
        Apply DRY penalty to logits.

        Args:
            input_ids: Array of shape (batch_size, seq_len)
            scores: Array of shape (vocab_size,) with logits

        Returns:
            Modified scores with penalties applied
        """
        if self.range > 0:
            input_ids = input_ids[:, -self.range :]

        # Convert to torch tensors for easier manipulation
        input_tensor = torch.from_numpy(input_ids)
        scores_tensor = torch.from_numpy(scores)

        for input_ids_row in input_tensor:
            # Raw integer must be extracted here to check for set membership
            last_token = input_ids_row[-1].item()
            if last_token in self.sequence_breakers:
                continue

            # Exclude the last token as it always matches
            match_indices = (input_ids_row[:-1] == last_token).nonzero(as_tuple=False)

            # Stores the maximum matching sequence length for each next token
            match_lengths = {}

            for i in match_indices.squeeze(1):
                i = i.item()
                if i + 1 >= len(input_ids_row):
                    continue

                next_token = input_ids_row[i + 1].item()
                if next_token in self.sequence_breakers:
                    continue

                # We have already found that `last_token` matches at this index,
                # so the match is at least of length 1.
                match_length = 1

                # Extend the match backwards as far as possible
                while True:
                    j = i - match_length
                    if j < 0:
                        break  # Start of input reached
                    if match_length + 1 > len(input_ids_row):
                        break  # The whole visible row already matches
                    previous_token = input_ids_row[-(match_length + 1)].item()
                    if input_ids_row[j] != previous_token:
                        break  # Start of match reached
                    if previous_token in self.sequence_breakers:
                        break  # Sequence-breaking token reached
                    match_length += 1

                # Update the maximum match length for this next token
                if match_length >= match_lengths.get(next_token, 0):
                    match_lengths[next_token] = match_length

            # Apply penalties
            for token, match_length in match_lengths.items():
                if match_length >= self.allowed_length:
                    penalty = self.multiplier * (
                        self.base ** (match_length - self.allowed_length)
                    )
                    scores_tensor[token] -= penalty

        return scores_tensor.numpy()


def generate_text(
    session,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=0.8,
    top_k=25,  # There are only 256 byte values in total
    stop_sequences=None,
    dry_multiplier: float = 0.0,  # Set to 0 to disable DRY by default
    dry_base: float = 2.0,
    dry_allowed_length: int = 20,  # 20 since matching is byte-level
    dry_sequence_breakers: Optional[Set[int]] = None,
    dry_range: int = 512,
):
    """Generate text using an ONNX model with DRY sampling and stop sequences."""
    input_ids_list = tokenizer.encode(prompt.encode("utf-8"), add_special_tokens=False)
    input_ids = np.array([input_ids_list], dtype=np.int64)
    generated_token_ids = []
    start_time = time.time()

    for _ in range(max_new_tokens):
        seq_len = input_ids.shape[1]

        # Create a causal mask for the current sequence length.
        causal_mask = np.triu(np.ones((1, seq_len, seq_len), dtype=np.bool_), k=1)
        attn_mask = np.zeros((1, seq_len, seq_len), dtype=np.float32)
        attn_mask[causal_mask] = -np.inf

        ort_inputs = {"input_ids": input_ids, "attn_mask": attn_mask}

        try:
            ort_outs = session.run(None, ort_inputs)
        except Exception as e:
            print(f"ONNX Runtime Error: {e}")
            # Potentially return or handle the error gracefully
            return "[ONNX Error]", 0

        logits = ort_outs[0][0, -1, :]

        # Apply DRY penalty if enabled
        if dry_multiplier > 0:
            dry_processor = DRYLogitsProcessor(
                multiplier=dry_multiplier,
                base=dry_base,
                allowed_length=dry_allowed_length,
                sequence_breakers=dry_sequence_breakers,
                range=dry_range,
            )
            logits = dry_processor(input_ids, logits)

        # Apply temperature scaling
        logits = logits / temperature

        # Apply top-k filtering
        if top_k > 0:
            top_k = min(top_k, logits.shape[-1])
            indices_to_remove = logits.argsort()[:-top_k]
            logits[indices_to_remove] = -float("inf")

        # Sample from the distribution
        probs = torch.softmax(torch.from_numpy(logits), dim=-1).numpy()
        probs = probs.astype(np.float64)
        probs /= probs.sum()  # Renormalize so np.random.choice accepts the distribution
        next_token_id = np.random.choice(len(probs), p=probs)

        if next_token_id == tokenizer.im_end_id:
            break

        input_ids = np.append(input_ids, [[next_token_id]], axis=1)
        generated_token_ids.append(next_token_id)

        if stop_sequences:
            current_output = tokenizer.decode(np.array(generated_token_ids))
            stop_generation = False
            for seq in stop_sequences:
                if current_output.endswith(seq):
                    stop_generation = True
                    # Remove the stop sequence from the generated tokens; trim by its
                    # encoded length, since one character may span several byte tokens.
                    stop_len = len(
                        tokenizer.encode(seq.encode("utf-8"), add_special_tokens=False)
                    )
                    generated_token_ids = generated_token_ids[:-stop_len]
                    current_output = tokenizer.decode(np.array(generated_token_ids))
                    break
            if stop_generation:
                break

    final_text = tokenizer.decode(np.array(generated_token_ids))
    tps = len(generated_token_ids) / (time.time() - start_time)
    return final_text, tps
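

# Example usage (an illustrative sketch only; the model path, ByteTokenizer
# constructor arguments, and the DRY/stop settings below are assumptions, not
# part of this module):
#
#     session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
#     tokenizer = ByteTokenizer()
#
#     # Map the breaker strings to token IDs for the DRY processor; with a
#     # byte-level tokenizer each breaker encodes to a single token.
#     breaker_ids = {
#         tokenizer.encode(s.encode("utf-8"), add_special_tokens=False)[-1]
#         for s in sequence_breaker_strings
#     }
#
#     text, tps = generate_text(
#         session,
#         tokenizer,
#         "Once upon a time",
#         max_new_tokens=200,
#         dry_multiplier=0.8,
#         dry_sequence_breakers=breaker_ids,
#         stop_sequences=["\n\n"],
#     )
#     print(f"{text}\n[{tps:.1f} tokens/s]")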