import onnxruntime as ort
import numpy as np
import torch
import time
import argparse
from typing import Set, Optional

from .model import ByteTokenizer

sequence_breaker_strings = ["\n", ":", '"', "*", "<", ">", "|"]
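# The DRY processor below operates on token IDs, so these breaker strings have to
# be encoded before they are passed in. A minimal sketch (assumes a ByteTokenizer
# instance named `tokenizer`; this set is not built anywhere in this section):
#
#     dry_sequence_breakers = {
#         tid
#         for s in sequence_breaker_strings
#         for tid in tokenizer.encode(s.encode("utf-8"), add_special_tokens=False)
#     }
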
class DRYLogitsProcessor:
    """
    Don't Repeat Yourself (DRY) Logits Processor that penalizes repetitive sequences.
    """

    def __init__(
        self,
        multiplier: float = 0.5,
        base: float = 2.0,
        allowed_length: int = 1,
        sequence_breakers: Optional[Set[int]] = None,
        range: int = 512,
    ):
        """
        Args:
            multiplier: Base penalty multiplier
            base: Exponential base for penalty calculation
            allowed_length: Length of sequence that's allowed to repeat without penalty
            sequence_breakers: Set of token IDs that should break sequence matching
            range: Number of previous tokens to consider for repetition checking
        """
        self.multiplier = multiplier
        self.base = base
        self.allowed_length = allowed_length
        self.sequence_breakers = sequence_breakers or set()
        self.range = range
    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """
        Apply DRY penalty to logits.

        Args:
            input_ids: Array of shape (batch_size, seq_len)
            scores: Array of shape (vocab_size,) with logits

        Returns:
            Modified scores with penalties applied
        """
        if self.range > 0:
            input_ids = input_ids[:, -self.range :]

        # Convert to torch tensors for easier manipulation
        input_tensor = torch.from_numpy(input_ids)
        scores_tensor = torch.from_numpy(scores)
        for input_ids_row in input_tensor:
            # Raw integer must be extracted here to check for set membership
            last_token = input_ids_row[-1].item()
            if last_token in self.sequence_breakers:
                continue

            # Exclude the last token as it always matches
            match_indices = (input_ids_row[:-1] == last_token).nonzero(as_tuple=False)

            # Stores the maximum matching sequence length for each next token
            match_lengths = {}

            for i in match_indices.squeeze(1):
                i = i.item()
                if i + 1 >= len(input_ids_row):
                    continue

                next_token = input_ids_row[i + 1].item()
                if next_token in self.sequence_breakers:
                    continue

                # We have already found that `last_token` matches at this index,
                # so the match is at least of length 1.
                match_length = 1

                # Extend the match backwards as far as possible
                while True:
                    j = i - match_length
                    if j < 0:
                        break  # Start of input reached
                    if match_length + 1 > len(input_ids_row):
                        break  # Matched suffix already spans the whole row; cannot extend further
                    previous_token = input_ids_row[-(match_length + 1)].item()
                    if input_ids_row[j] != previous_token:
                        break  # Start of match reached
                    if previous_token in self.sequence_breakers:
                        break  # Sequence-breaking token reached
                    match_length += 1

                # Update the maximum match length for this next token
                if match_length >= match_lengths.get(next_token, 0):
                    match_lengths[next_token] = match_length

            # Apply penalties
            for token, match_length in match_lengths.items():
                if match_length >= self.allowed_length:
                    penalty = self.multiplier * (
                        self.base ** (match_length - self.allowed_length)
                    )
                    scores_tensor[token] -= penalty

        return scores_tensor.numpy()

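# Worked example of the penalty (illustrative numbers, not taken from the code
# above): with multiplier=0.5, base=2.0 and allowed_length=1, a candidate token
# that would extend a repeated run of length 3 has
#     0.5 * 2.0 ** (3 - 1) == 2.0
# subtracted from its logit, so the penalty grows exponentially with match length.
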
def generate_text(
    session,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=0.8,
    top_k=25,  # There are only 256 bytes total
    stop_sequences=None,
    dry_multiplier: float = 0.0,  # Set to 0 to disable DRY by default
    dry_base: float = 2.0,
    dry_allowed_length: int = 20,  # 20 since this is byte level.
    dry_sequence_breakers: Optional[Set[int]] = None,
    dry_range: int = 512,
):
    """Generate text using an ONNX model with DRY sampling and stop sequences."""
    input_ids_list = tokenizer.encode(prompt.encode("utf-8"), add_special_tokens=False)
    input_ids = np.array([input_ids_list], dtype=np.int64)
    generated_token_ids = []

    # Build the DRY processor once up front; it holds no per-step state.
    dry_processor = None
    if dry_multiplier > 0:
        dry_processor = DRYLogitsProcessor(
            multiplier=dry_multiplier,
            base=dry_base,
            allowed_length=dry_allowed_length,
            sequence_breakers=dry_sequence_breakers,
            range=dry_range,
        )

    start_time = time.time()
    for _ in range(max_new_tokens):
        seq_len = input_ids.shape[1]

        # Create a causal mask for the current sequence length.
        causal_mask = np.triu(np.ones((1, seq_len, seq_len), dtype=np.bool_), k=1)
        attn_mask = np.zeros((1, seq_len, seq_len), dtype=np.float32)
        attn_mask[causal_mask] = -np.inf

        ort_inputs = {"input_ids": input_ids, "attn_mask": attn_mask}
        try:
            ort_outs = session.run(None, ort_inputs)
        except Exception as e:
            print(f"ONNX Runtime Error: {e}")
            # Potentially return or handle the error gracefully
            return "[ONNX Error]", 0

        logits = ort_outs[0][0, -1, :]
        # Apply DRY penalty if enabled
        if dry_processor is not None:
            logits = dry_processor(input_ids, logits)
        # Apply temperature scaling
        logits = logits / temperature

        # Apply top-k filtering
        if top_k > 0:
            top_k = min(top_k, logits.shape[-1])
            indices_to_remove = logits.argsort()[:-top_k]
            logits[indices_to_remove] = -float("inf")

        # Sample from the distribution
        probs = torch.softmax(torch.from_numpy(logits), dim=-1).numpy()
        next_token_id = np.random.choice(len(probs), p=probs)

        if next_token_id == tokenizer.im_end_id:
            break

        input_ids = np.append(input_ids, [[next_token_id]], axis=1)
        generated_token_ids.append(next_token_id)
        if stop_sequences:
            current_output = tokenizer.decode(np.array(generated_token_ids))
            stop_generation = False
            for seq in stop_sequences:
                if current_output.endswith(seq):
                    stop_generation = True
                    # Remove the stop sequence from the generated text.
                    # Tokens are bytes, so strip by UTF-8 byte length rather than
                    # character count.
                    generated_token_ids = generated_token_ids[: -len(seq.encode("utf-8"))]
                    current_output = tokenizer.decode(np.array(generated_token_ids))
                    break
            if stop_generation:
                break
    final_text = tokenizer.decode(np.array(generated_token_ids))
    tps = len(generated_token_ids) / (time.time() - start_time)
    return final_text, tps

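# Usage sketch (hypothetical model path and constructor arguments; the actual CLI
# entry point, if this module has one, is outside this section):
#
#     session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
#     tokenizer = ByteTokenizer()
#     text, tps = generate_text(
#         session,
#         tokenizer,
#         "Once upon a time",
#         max_new_tokens=128,
#         dry_multiplier=0.5,
#     )
#     print(f"{text!r} ({tps:.2f} tokens/sec)")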