import onnxruntime as ort
import numpy as np
import torch
import time
import argparse
from typing import Set, Optional
from .model import ByteTokenizer

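# Strings whose tokens should interrupt DRY sequence matching; they are
# converted to byte-level token IDs before use (see the example entry
# point at the bottom of this file).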
sequence_breaker_strings = ["\n", ":", '"', "*", "<", ">", "|"]


class DRYLogitsProcessor:
    """
    Don't Repeat Yourself (DRY) Logits Processor that penalizes repetitive sequences.
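
    A token that would extend a repeated suffix of the context is penalized by
    ``multiplier * base ** (match_length - allowed_length)``. For example, with
    the defaults (multiplier=0.5, base=2.0, allowed_length=1), a match of
    length 4 subtracts 0.5 * 2 ** 3 = 4.0 from that token's logit.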
    """

    def __init__(
        self,
        multiplier: float = 0.5,
        base: float = 2.0,
        allowed_length: int = 1,
        sequence_breakers: Optional[Set[int]] = None,
        range: int = 512,
    ):
        """
        Args:
            multiplier: Base penalty multiplier
            base: Exponential base for penalty calculation
            allowed_length: Length of sequence that's allowed to repeat without penalty
            sequence_breakers: Set of token IDs that should break sequence matching
            range: Number of previous tokens to consider for repetition checking
        """
        self.multiplier = multiplier
        self.base = base
        self.allowed_length = allowed_length
        self.sequence_breakers = sequence_breakers or set()
        self.range = range

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """
        Apply DRY penalty to logits.

        Args:
            input_ids: Array of shape (batch_size, seq_len)
            scores: Array of shape (vocab_size,) with logits

        Returns:
            Modified scores with penalties applied
        """
        if self.range > 0:
            input_ids = input_ids[:, -self.range :]

        # Convert to torch tensors for easier manipulation. Note that
        # torch.from_numpy shares memory with the underlying NumPy array,
        # so penalties are applied to `scores` in place.
        input_tensor = torch.from_numpy(input_ids)
        scores_tensor = torch.from_numpy(scores)

        for input_ids_row in input_tensor:
            # Raw integer must be extracted here to check for set membership
            last_token = input_ids_row[-1].item()

            if last_token in self.sequence_breakers:
                continue

            # Exclude the last token as it always matches
            match_indices = (input_ids_row[:-1] == last_token).nonzero(as_tuple=False)

            # Stores the maximum matching sequence length for each next token
            match_lengths = {}

            for i in match_indices.squeeze(1):
                i = i.item()
                if i + 1 >= len(input_ids_row):
                    continue

                next_token = input_ids_row[i + 1].item()

                if next_token in self.sequence_breakers:
                    continue

                # We have already found that `last_token` matches at this index,
                # so the match is at least of length 1.
                match_length = 1

                # Extend the match backwards as far as possible
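                # (compare the tokens ending at index i against the current
                # suffix of the row, stepping one token further back each pass)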
                while True:
                    j = i - match_length
                    if j < 0:
                        break  # Start of input reached

                    if match_length + 1 > len(input_ids_row):
                        break  # Suffix index would fall outside the row

                    previous_token = input_ids_row[-(match_length + 1)].item()
                    if input_ids_row[j] != previous_token:
                        break  # Start of match reached

                    if previous_token in self.sequence_breakers:
                        break  # Sequence-breaking token reached

                    match_length += 1

                # Update the maximum match length for this next token
                if match_length >= match_lengths.get(next_token, 0):
                    match_lengths[next_token] = match_length

            # Apply penalties
            for token, match_length in match_lengths.items():
                if match_length >= self.allowed_length:
                    penalty = self.multiplier * (
                        self.base ** (match_length - self.allowed_length)
                    )
                    scores_tensor[token] -= penalty

        return scores_tensor.numpy()


def generate_text(
    session,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=0.8,
    top_k=25,  # There are only 256 bytes total
    stop_sequences=None,
    dry_multiplier: float = 0.0,  # Set to 0 to disable DRY by default
    dry_base: float = 2.0,
    dry_allowed_length: int = 20,  # 20 since this is byte level.
    dry_sequence_breakers: Optional[Set[int]] = None,
    dry_range: int = 512,
):
    """Generate text using an ONNX model with DRY sampling and stop sequences."""
    input_ids_list = tokenizer.encode(prompt.encode("utf-8"), add_special_tokens=False)
    input_ids = np.array([input_ids_list], dtype=np.int64)

    generated_token_ids = []
    start_time = time.time()

    # Build the DRY processor once up front; it is stateless, so there is no
    # need to reconstruct it on every decoding step.
    dry_processor = None
    if dry_multiplier > 0:
        dry_processor = DRYLogitsProcessor(
            multiplier=dry_multiplier,
            base=dry_base,
            allowed_length=dry_allowed_length,
            sequence_breakers=dry_sequence_breakers,
            range=dry_range,
        )

    for _ in range(max_new_tokens):
        seq_len = input_ids.shape[1]

        # Create an additive causal mask for the current sequence length:
        # 0.0 where attention is permitted, -inf strictly above the diagonal,
        # so each position attends only to itself and earlier positions.
        causal_mask = np.triu(np.ones((1, seq_len, seq_len), dtype=np.bool_), k=1)
        attn_mask = np.zeros((1, seq_len, seq_len), dtype=np.float32)
        attn_mask[causal_mask] = -np.inf

        ort_inputs = {"input_ids": input_ids, "attn_mask": attn_mask}

        try:
            ort_outs = session.run(None, ort_inputs)
        except Exception as e:
            print(f"ONNX Runtime Error: {e}")
            # Bail out gracefully instead of raising mid-generation.
            return "[ONNX Error]", 0

        logits = ort_outs[0][0, -1, :]

        # Apply the DRY penalty if it was enabled above
        if dry_processor is not None:
            logits = dry_processor(input_ids, logits)

        # Apply temperature scaling
        logits = logits / temperature

        # Apply top-k filtering
        if top_k > 0:
            top_k = min(top_k, logits.shape[-1])
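            # argsort is ascending, so everything except the last top_k
            # entries (the indices of the top_k largest logits) is removed.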
            indices_to_remove = logits.argsort()[:-top_k]
            logits[indices_to_remove] = -float("inf")

        # Sample from the distribution
        probs = torch.softmax(torch.from_numpy(logits), dim=-1).numpy()
        next_token_id = np.random.choice(len(probs), p=probs)

        if next_token_id == tokenizer.im_end_id:
            break

        input_ids = np.append(input_ids, [[next_token_id]], axis=1)
        generated_token_ids.append(next_token_id)

        if stop_sequences:
            current_output = tokenizer.decode(np.array(generated_token_ids))
            stop_generation = False
            for seq in stop_sequences:
                if current_output.endswith(seq):
                    stop_generation = True
                    # Remove the stop sequence from the generated text. Tokens
                    # are bytes, so trim by the stop sequence's UTF-8 byte
                    # length rather than its character length.
                    generated_token_ids = generated_token_ids[
                        : -len(seq.encode("utf-8"))
                    ]
                    current_output = tokenizer.decode(np.array(generated_token_ids))
                    break
            if stop_generation:
                break

    final_text = tokenizer.decode(np.array(generated_token_ids))
    tps = len(generated_token_ids) / (time.time() - start_time)
    return final_text, tps
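

# ---------------------------------------------------------------------------
# Example entry point. This is an illustrative sketch, not a guaranteed
# interface: the default model path and the no-argument ByteTokenizer()
# constructor are assumptions. Because of the relative import above, run it
# as a module, e.g. `python -m <package>.generate`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate text with an ONNX model")
    parser.add_argument(
        "--model",
        default="model.onnx",  # Assumed default path; adjust to your export
        help="Path to the exported ONNX model",
    )
    parser.add_argument("--prompt", default="Once upon a time")
    parser.add_argument("--max-new-tokens", type=int, default=100)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument(
        "--dry-multiplier",
        type=float,
        default=0.0,
        help="Set > 0 to enable the DRY penalty",
    )
    args = parser.parse_args()

    session = ort.InferenceSession(args.model)
    tokenizer = ByteTokenizer()  # Assumed to construct without arguments

    # Convert the breaker strings to byte-level token IDs for the DRY penalty.
    breaker_ids: Set[int] = set()
    for s in sequence_breaker_strings:
        breaker_ids.update(
            tokenizer.encode(s.encode("utf-8"), add_special_tokens=False)
        )

    text, tps = generate_text(
        session,
        tokenizer,
        args.prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        dry_multiplier=args.dry_multiplier,
        dry_sequence_breakers=breaker_ids,
    )
    print(text)
    print(f"[{tps:.1f} tokens/s]")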