# DAT-Byte-Demo/inference/onnx_inference.py

import onnxruntime as ort
import numpy as np
import torch
import time
import argparse
from typing import Set, Optional
from .model import ByteTokenizer

# Characters that should break DRY sequence matching. These are strings; they must
# be mapped to their byte token IDs before being passed as `sequence_breakers` /
# `dry_sequence_breakers` below.
sequence_breaker_strings = ["\n", ":", '"', "*", "<", ">", "|"]


class DRYLogitsProcessor:
"""
Don't Repeat Yourself (DRY) Logits Processor that penalizes repetitive sequences.
"""
def __init__(
self,
multiplier: float = 0.5,
base: float = 2.0,
allowed_length: int = 1,
sequence_breakers: Optional[Set[int]] = None,
range: int = 512,
):
"""
Args:
multiplier: Base penalty multiplier
base: Exponential base for penalty calculation
allowed_length: Length of sequence that's allowed to repeat without penalty
sequence_breakers: Set of token IDs that should break sequence matching
range: Number of previous tokens to consider for repetition checking
"""
self.multiplier = multiplier
self.base = base
self.allowed_length = allowed_length
self.sequence_breakers = sequence_breakers or set()
self.range = range
def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
"""
Apply DRY penalty to logits.
Args:
input_ids: Array of shape (batch_size, seq_len)
scores: Array of shape (vocab_size,) with logits
Returns:
Modified scores with penalties applied
"""
if self.range > 0:
input_ids = input_ids[:, -self.range :]
# Convert to torch tensors for easier manipulation
input_tensor = torch.from_numpy(input_ids)
scores_tensor = torch.from_numpy(scores)
for input_ids_row in input_tensor:
# Raw integer must be extracted here to check for set membership
last_token = input_ids_row[-1].item()
if last_token in self.sequence_breakers:
continue
# Exclude the last token as it always matches
match_indices = (input_ids_row[:-1] == last_token).nonzero(as_tuple=False)
# Stores the maximum matching sequence length for each next token
match_lengths = {}
for i in match_indices.squeeze(1):
i = i.item()
if i + 1 >= len(input_ids_row):
continue
next_token = input_ids_row[i + 1].item()
if next_token in self.sequence_breakers:
continue
# We have already found that `last_token` matches at this index,
# so the match is at least of length 1.
match_length = 1
# Extend the match backwards as far as possible
while True:
j = i - match_length
if j < 0:
break # Start of input reached
                    if match_length + 1 > len(input_ids_row):
                        break  # Match already spans the whole window; nothing left to compare
previous_token = input_ids_row[-(match_length + 1)].item()
if input_ids_row[j] != previous_token:
break # Start of match reached
if previous_token in self.sequence_breakers:
break # Sequence-breaking token reached
match_length += 1
# Update the maximum match length for this next token
if match_length >= match_lengths.get(next_token, 0):
match_lengths[next_token] = match_length
# Apply penalties
for token, match_length in match_lengths.items():
if match_length >= self.allowed_length:
penalty = self.multiplier * (
self.base ** (match_length - self.allowed_length)
)
scores_tensor[token] -= penalty
return scores_tensor.numpy()
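

# Illustrative sketch (an assumption, not part of the inference path): a tiny,
# self-contained demo of how the DRY penalty above behaves on a repeated byte
# sequence. The byte IDs and parameter values below are chosen purely for
# demonstration.
def _dry_penalty_demo() -> np.ndarray:
    processor = DRYLogitsProcessor(multiplier=0.5, base=2.0, allowed_length=1)
    # "ababa" as raw byte IDs: the model has already produced "ab" twice, so the
    # continuation b"b" (ID 98) should be discouraged.
    ids = np.array([[97, 98, 97, 98, 97]], dtype=np.int64)
    scores = np.zeros(256, dtype=np.float32)
    penalized = processor(ids, scores)
    # The longest match ending at the last token is 3 ("aba"), so the penalty is
    # 0.5 * 2 ** (3 - 1) = 2.0 and penalized[98] == -2.0; all other logits stay 0.
    return penalized
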
def generate_text(
session,
tokenizer,
prompt,
max_new_tokens=100,
temperature=0.8,
top_k=25, # There are only 256 bytes total
stop_sequences=None,
dry_multiplier: float = 0.0, # Set to 0 to disable DRY by default
dry_base: float = 2.0,
dry_allowed_length: int = 20, # 20 since this is byte level.
dry_sequence_breakers: Optional[Set[int]] = None,
dry_range: int = 512,
):
"""Generate text using an ONNX model with DRY sampling and stop sequences."""
input_ids_list = tokenizer.encode(prompt.encode("utf-8"), add_special_tokens=False)
input_ids = np.array([input_ids_list], dtype=np.int64)
generated_token_ids = []
start_time = time.time()
for _ in range(max_new_tokens):
seq_len = input_ids.shape[1]
# Create a causal mask for the current sequence length.
causal_mask = np.triu(np.ones((1, seq_len, seq_len), dtype=np.bool_), k=1)
attn_mask = np.zeros((1, seq_len, seq_len), dtype=np.float32)
attn_mask[causal_mask] = -np.inf
ort_inputs = {"input_ids": input_ids, "attn_mask": attn_mask}
try:
ort_outs = session.run(None, ort_inputs)
except Exception as e:
print(f"ONNX Runtime Error: {e}")
# Potentially return or handle the error gracefully
return "[ONNX Error]", 0
logits = ort_outs[0][0, -1, :]
# Apply DRY penalty if enabled
if dry_multiplier > 0:
dry_processor = DRYLogitsProcessor(
multiplier=dry_multiplier,
base=dry_base,
allowed_length=dry_allowed_length,
sequence_breakers=dry_sequence_breakers,
range=dry_range,
)
logits = dry_processor(input_ids, logits)
# Apply temperature scaling
logits = logits / temperature
# Apply top-k filtering
if top_k > 0:
top_k = min(top_k, logits.shape[-1])
            # argsort is ascending, so everything except the top_k largest logits is masked
            indices_to_remove = logits.argsort()[:-top_k]
logits[indices_to_remove] = -float("inf")
# Sample from the distribution
probs = torch.softmax(torch.from_numpy(logits), dim=-1).numpy()
next_token_id = np.random.choice(len(probs), p=probs)
if next_token_id == tokenizer.im_end_id:
break
input_ids = np.append(input_ids, [[next_token_id]], axis=1)
generated_token_ids.append(next_token_id)
if stop_sequences:
current_output = tokenizer.decode(np.array(generated_token_ids))
stop_generation = False
for seq in stop_sequences:
if current_output.endswith(seq):
stop_generation = True
                    # Remove the stop sequence from the generated text. Tokens are raw
                    # bytes, so trim by the stop sequence's UTF-8 byte length.
                    generated_token_ids = generated_token_ids[: -len(seq.encode("utf-8"))]
                    current_output = tokenizer.decode(np.array(generated_token_ids))
break
if stop_generation:
break
final_text = tokenizer.decode(np.array(generated_token_ids))
tps = len(generated_token_ids) / (time.time() - start_time)
return final_text, tps
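

# Minimal usage sketch (assumptions, not part of the module's public surface):
# how generate_text would be driven end to end. The ONNX file name, the
# no-argument ByteTokenizer() construction, and the prompt are placeholders for
# illustration only.
def _example_usage() -> None:
    session = ort.InferenceSession("dat_byte.onnx", providers=["CPUExecutionProvider"])
    tokenizer = ByteTokenizer()  # assumed no-argument constructor
    text, tps = generate_text(
        session,
        tokenizer,
        prompt="Once upon a time",
        max_new_tokens=64,
        temperature=0.8,
        top_k=25,
        stop_sequences=["\n\n"],
        dry_multiplier=0.8,  # any value > 0 enables the DRY penalty
    )
    print(f"{text}\n[{tps:.1f} tokens/s]")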