import argparse
import time
from typing import Optional, Set

import numpy as np
import onnxruntime as ort
import torch

from .model import ByteTokenizer

# Strings whose tokens should interrupt DRY sequence matching (see DRYLogitsProcessor).
sequence_breaker_strings = ["\n", ":", '"', "*", "<", ">", "|"]
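
# Example (sketch): the DRY processor expects *token IDs*, not strings. With the
# byte-level ByteTokenizer used below, one plausible way to build the breaker set is
# to encode each string and keep its final token ID. The no-argument ByteTokenizer()
# constructor is an assumption; the encode() signature matches its use in
# generate_text() further down.
#
#     tokenizer = ByteTokenizer()
#     sequence_breaker_ids = {
#         tokenizer.encode(s.encode("utf-8"), add_special_tokens=False)[-1]
#         for s in sequence_breaker_strings
#     }
#     # sequence_breaker_ids can then be passed as `dry_sequence_breakers`.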


class DRYLogitsProcessor:
    """
    Don't Repeat Yourself (DRY) logits processor that penalizes repetitive sequences.

    A token that would extend a repeated sequence of length `match_length` is penalized
    by `multiplier * base ** (match_length - allowed_length)`.
    """

    def __init__(
        self,
        multiplier: float = 0.5,
        base: float = 2.0,
        allowed_length: int = 1,
        sequence_breakers: Optional[Set[int]] = None,
        range: int = 512,
    ):
        """
        Args:
            multiplier: Base penalty multiplier
            base: Exponential base for penalty calculation
            allowed_length: Length of sequence that's allowed to repeat without penalty
            sequence_breakers: Set of token IDs that should break sequence matching
            range: Number of previous tokens to consider for repetition checking
        """
        self.multiplier = multiplier
        self.base = base
        self.allowed_length = allowed_length
        self.sequence_breakers = sequence_breakers or set()
        self.range = range

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """
        Apply the DRY penalty to logits.

        Args:
            input_ids: Array of shape (batch_size, seq_len) with the context token IDs
            scores: Array of shape (vocab_size,) with next-token logits

        Returns:
            Modified scores with penalties applied
        """
        if self.range > 0:
            input_ids = input_ids[:, -self.range :]

        # Convert to torch tensors for easier manipulation
        input_tensor = torch.from_numpy(input_ids)
        scores_tensor = torch.from_numpy(scores)

        for input_ids_row in input_tensor:
            # Raw integer must be extracted here to check for set membership
            last_token = input_ids_row[-1].item()
            if last_token in self.sequence_breakers:
                continue

            # Exclude the last token as it always matches
            match_indices = (input_ids_row[:-1] == last_token).nonzero(as_tuple=False)

            # Stores the maximum matching sequence length for each next token
            match_lengths = {}

            for i in match_indices.squeeze(1):
                i = i.item()
                if i + 1 >= len(input_ids_row):
                    continue
                next_token = input_ids_row[i + 1].item()
                if next_token in self.sequence_breakers:
                    continue

                # We have already found that `last_token` matches at this index,
                # so the match is at least of length 1.
                match_length = 1

                # Extend the match backwards as far as possible
                while True:
                    j = i - match_length
                    if j < 0:
                        break  # Start of input reached
                    if match_length + 1 > len(input_ids_row):
                        break  # Whole context already matched
                    previous_token = input_ids_row[-(match_length + 1)].item()
                    if input_ids_row[j] != previous_token:
                        break  # Start of match reached
                    if previous_token in self.sequence_breakers:
                        break  # Sequence-breaking token reached
                    match_length += 1

                # Update the maximum match length for this next token
                if match_length >= match_lengths.get(next_token, 0):
                    match_lengths[next_token] = match_length

            # Apply penalties
            for token, match_length in match_lengths.items():
                if match_length >= self.allowed_length:
                    penalty = self.multiplier * (
                        self.base ** (match_length - self.allowed_length)
                    )
                    scores_tensor[token] -= penalty

        return scores_tensor.numpy()
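
# Example (sketch): applying the processor to a dummy context. Token IDs and vocab
# size below are illustrative only; the processor itself is model-agnostic.
#
#     processor = DRYLogitsProcessor(
#         multiplier=0.8, base=1.75, allowed_length=2, sequence_breakers={10}, range=512
#     )
#     context = np.array([[5, 6, 7, 5, 6, 7, 5, 6]], dtype=np.int64)  # repeating 5 6 7
#     logits = np.zeros(256, dtype=np.float32)
#     penalized = processor(context, logits)
#     # Token 7, which would extend the "5 6 7" repetition, now carries a negative
#     # logit of roughly -0.8 * 1.75 ** (5 - 2); all other logits are untouched.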


def generate_text(
    session,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=0.8,
    top_k=25,  # There are only 256 bytes total
    stop_sequences=None,
    dry_multiplier: float = 0.0,  # Set to 0 to disable DRY by default
    dry_base: float = 2.0,
    dry_allowed_length: int = 20,  # 20 since this is byte level.
    dry_sequence_breakers: Optional[Set[int]] = None,
    dry_range: int = 512,
):
    """Generate text using an ONNX model with DRY sampling and stop sequences."""
    input_ids_list = tokenizer.encode(prompt.encode("utf-8"), add_special_tokens=False)
    input_ids = np.array([input_ids_list], dtype=np.int64)
    generated_token_ids = []
    start_time = time.time()

    for _ in range(max_new_tokens):
        seq_len = input_ids.shape[1]

        # Create a causal mask for the current sequence length.
        causal_mask = np.triu(np.ones((1, seq_len, seq_len), dtype=np.bool_), k=1)
        attn_mask = np.zeros((1, seq_len, seq_len), dtype=np.float32)
        attn_mask[causal_mask] = -np.inf

        ort_inputs = {"input_ids": input_ids, "attn_mask": attn_mask}
        try:
            ort_outs = session.run(None, ort_inputs)
        except Exception as e:
            print(f"ONNX Runtime Error: {e}")
            # Potentially return or handle the error gracefully
            return "[ONNX Error]", 0

        logits = ort_outs[0][0, -1, :]

        # Apply DRY penalty if enabled
        if dry_multiplier > 0:
            dry_processor = DRYLogitsProcessor(
                multiplier=dry_multiplier,
                base=dry_base,
                allowed_length=dry_allowed_length,
                sequence_breakers=dry_sequence_breakers,
                range=dry_range,
            )
            logits = dry_processor(input_ids, logits)

        # Apply temperature scaling
        logits = logits / temperature

        # Apply top-k filtering
        if top_k > 0:
            top_k = min(top_k, logits.shape[-1])
            indices_to_remove = logits.argsort()[:-top_k]
            logits[indices_to_remove] = -float("inf")

        # Sample from the distribution
        probs = torch.softmax(torch.from_numpy(logits), dim=-1).numpy()
        next_token_id = np.random.choice(len(probs), p=probs)

        if next_token_id == tokenizer.im_end_id:
            break

        input_ids = np.append(input_ids, [[next_token_id]], axis=1)
        generated_token_ids.append(next_token_id)

        if stop_sequences:
            current_output = tokenizer.decode(np.array(generated_token_ids))
            stop_generation = False
            for seq in stop_sequences:
                if current_output.endswith(seq):
                    stop_generation = True
                    # Remove the stop sequence from the generated text
                    # (one token per byte, so trim by the byte length of the stop string)
                    generated_token_ids = generated_token_ids[: -len(seq.encode("utf-8"))]
                    current_output = tokenizer.decode(np.array(generated_token_ids))
                    break
            if stop_generation:
                break

    final_text = tokenizer.decode(np.array(generated_token_ids))
    tps = len(generated_token_ids) / (time.time() - start_time)
    return final_text, tps
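
# Example (sketch): calling generate_text() with an ONNX Runtime session. The model
# path and the no-argument ByteTokenizer() constructor are assumptions for
# illustration; adjust them to the actual checkpoint and tokenizer setup.
#
#     session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
#     tokenizer = ByteTokenizer()
#     text, tps = generate_text(
#         session,
#         tokenizer,
#         prompt="Once upon a time",
#         max_new_tokens=64,
#         dry_multiplier=0.8,  # enable DRY
#         dry_sequence_breakers={
#             tokenizer.encode(s.encode("utf-8"), add_special_tokens=False)[-1]
#             for s in sequence_breaker_strings
#         },
#     )
#     print(f"{text!r} ({tps:.1f} tokens/s)")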