from typing import Optional

import torch
from torch import nn


class Sampler(nn.Module):
    """
    Optimized sampler implementation using vectorized operations instead of loops, significantly improving performance

    Performance optimizations:
    1. Using batch processing instead of sequence loops, reducing Python loop overhead
    2. Using PyTorch's vectorized operations (like torch.sort, torch.gather) for parallel computation
    3. Using mask operations to apply top-k filtering at once, avoiding per-sequence processing
    """
    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor, top_k: Optional[int] = None):
        """
        Perform sampling operation using vectorized method for top-k filtering

        Args:
            logits: Logits tensor with shape [batch_size, vocab_size]
            temperatures: Temperature parameters with shape [batch_size]
            top_k: Optional top-k value for filtering (uniform across all sequences); None or 0 disables filtering

        Returns:
            Sampled token IDs with shape [batch_size]
        """
        logits = logits.to(torch.float)
        greedy_tokens = logits.argmax(dim=-1)  # Greedy decoding result, used where temperature == 0
        # Apply temperature scaling; temperature == 0 produces inf/nan rows here, but those rows
        # are replaced by greedy_tokens in the final torch.where below
        logits.div_(temperatures.unsqueeze(dim=1))

        # Apply uniform top-k filtering if top_k is provided
        if top_k is not None and top_k > 0:
            vocab_size = logits.size(-1)

            # Batch sorting for all sequences at once
            sorted_logits, _ = torch.sort(logits, dim=-1, descending=True)

            # Threshold for each sequence is the k-th largest logit
            k_value = min(top_k, vocab_size)  # Ensure k doesn't exceed vocab size
            thresholds = sorted_logits[:, k_value - 1:k_value]  # Shape [batch_size, 1], broadcasts over logits

            # Keep logits greater than or equal to the threshold (ties may keep slightly more than
            # k tokens); everything else is pushed to negative infinity before the softmax
            mask = logits >= thresholds
            logits = torch.where(mask, logits, torch.tensor(float('-inf'), device=logits.device))

        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Exponential-race trick (equivalent to Gumbel-max): dividing the probabilities by i.i.d.
        # Exponential(1) noise and taking the argmax draws a sample from the categorical distribution
        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
        return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
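

# Illustrative usage sketch (an assumption for demonstration, not part of the original module):
# temperature 0 falls back to greedy decoding, while positive temperatures sample from the
# (optionally top-k filtered) distribution. Batch size, vocab size and top_k below are made up.
def _demo_sampler() -> torch.Tensor:
    sampler = Sampler()
    logits = torch.randn(4, 6562)                       # [batch_size, vocab_size], assumed sizes
    temperatures = torch.tensor([0.0, 0.7, 0.9, 1.0])   # row 0 decodes greedily
    return sampler(logits, temperatures, top_k=25)      # LongTensor of shape [4]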


class RasSampler(nn.Module):
    """
    Optimized Repetition Aware Sampling implementation

    Performance optimizations:
    1. Using vectorized nucleus sampling instead of loop implementation, improving sampling efficiency
    2. Using tensor operations to calculate repetition rate, reducing Python loop overhead
    3. Optimizing EOS handling logic, reducing unnecessary resampling
    4. Using PyTorch's vectorized operations for parallel computation
    5. Batch processing for all sequences, dramatically improving throughput
    6. Robust handling for sequences of any length, including empty sequences
    """
    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, decoded_tokens_list: list[list[int]],
                win_size: int = 10, tau_r: float = 0.1,
                top_p: float = 0.8, top_k: int = 25,
                eos_token: int = 6561, min_tokens: Optional[list[int]] = None):
        """
        Execute repetition-aware sampling using optimized vectorized operations with batch processing

        Args:
            logits: Input logits with shape [batch_size, vocab_size]
            decoded_tokens_list: List of decoded tokens, where each element is the token list for one batch item
            win_size: Window size for repetition detection (uniform across all batch items)
            tau_r: Repetition threshold (uniform across all batch items)
            top_p: Nucleus sampling probability threshold (uniform across all batch items)
            top_k: Nucleus sampling top-k threshold (uniform across all batch items)
            eos_token: End of sequence token ID (uniform across all batch items)
            min_tokens: List of minimum tokens to generate before allowing EOS, one per batch item
        Returns:
            Selected token IDs
        """
        batch_size = logits.size(0)
        device = logits.device

        # Set default values if not provided
        if min_tokens is None:
            min_tokens = [2] * batch_size

        # Ensure min_tokens list has the correct length
        assert len(min_tokens) == batch_size, f"min_tokens length {len(min_tokens)} != batch_size {batch_size}"

        # Never allow EOS as the very first token: mask it out for sequences
        # that have not decoded anything yet
        for i in range(batch_size):
            if i < len(decoded_tokens_list) and len(decoded_tokens_list[i]) == 0:
                logits[i, eos_token] = -float('inf')

        # 1. First, perform nucleus sampling for all sequences
        probs = torch.softmax(logits, dim=-1)

        # Use vectorized nucleus sampling for all sequences
        # This can be done in batch since top_p and top_k are uniform
        sorted_probs, sorted_indices = probs.sort(dim=-1, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Top-p mask: keep tokens whose cumulative probability stays within top_p
        # (the token that crosses the threshold is excluded; the top token is re-added below)
        top_p_mask = cumulative_probs <= top_p

        # Create top-k mask (first top_k positions are True)
        top_k_mask = torch.zeros_like(top_p_mask)
        top_k_mask[:, :top_k] = True

        # Combine masks
        mask = top_p_mask & top_k_mask

        # Ensure at least one token is selected per sequence
        first_token_mask = torch.zeros_like(mask)
        first_token_mask[:, 0] = True
        mask = mask | first_token_mask

        # Sample from the filtered distribution
        sample_probs = torch.where(mask, sorted_probs, torch.zeros_like(sorted_probs))
        sample_probs = sample_probs / sample_probs.sum(dim=-1, keepdim=True)

        # Sample indices from the filtered distribution
        sampled_indices = torch.multinomial(sample_probs, 1).squeeze(-1)
        top_ids = torch.gather(sorted_indices, -1, sampled_indices.unsqueeze(-1)).squeeze(-1)

        # 2. Check for repetitions and apply random sampling if needed
        # Extract recent tokens for each sequence, handling empty or short sequences
        recent_tokens_list = []
        for i in range(batch_size):
            # Handle index out of range or empty tokens
            if i < len(decoded_tokens_list):
                tokens = decoded_tokens_list[i]
                if len(tokens) > 0:
                    start_idx = max(0, len(tokens) - win_size)
                    recent_tokens_list.append(tokens[start_idx:])
                else:
                    recent_tokens_list.append([])  # Empty list for empty tokens
            else:
                recent_tokens_list.append([])  # Empty list for missing batch items

        # Check if we have any tokens to process for repetition detection
        if any(len(tokens) > 0 for tokens in recent_tokens_list):
            # Convert to padded tensor for batch processing
            max_recent_len = max(len(tokens) for tokens in recent_tokens_list)
            if max_recent_len > 0:  # Only proceed if we have tokens
                # Pad with -1 so padded slots can never match a real (non-negative) token id
                recent_tokens_tensor = torch.full((batch_size, max_recent_len), -1, dtype=torch.long, device=device)
                valid_positions_mask = torch.zeros((batch_size, max_recent_len), dtype=torch.bool, device=device)
                for i, tokens in enumerate(recent_tokens_list):
                    if len(tokens) > 0:
                        recent_tokens_tensor[i, -len(tokens):] = torch.tensor(tokens, device=device)
                        valid_positions_mask[i, -len(tokens):] = True

                # Count how often each sampled token appears in its recent window (vectorized)
                matches = (recent_tokens_tensor == top_ids.unsqueeze(-1)) & valid_positions_mask
                repetition_counts = matches.sum(dim=-1)

                # Repetition rate = count / window length, clamping to avoid division by zero
                recent_lengths = valid_positions_mask.sum(dim=-1).clamp(min=1)
                repetition_rates = repetition_counts.float() / recent_lengths.float()

                # Identify sequences needing random sampling
                need_random = repetition_rates >= tau_r

                # Apply random sampling where needed
                if need_random.any():
                    random_indices = torch.multinomial(probs[need_random], 1).squeeze(-1)
                    top_ids[need_random] = random_indices

        # 3. Handle EOS tokens
        # Create mask for sequences that should ignore EOS tokens
        ignore_eos_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
        for i in range(batch_size):
            if i < len(decoded_tokens_list):
                ignore_eos_mask[i] = len(decoded_tokens_list[i]) < min_tokens[i]
            else:
                ignore_eos_mask[i] = True  # Default to ignoring EOS for missing sequences

        is_eos_mask = top_ids == eos_token
        need_resample = ignore_eos_mask & is_eos_mask

        # Resample for sequences that need it
        if need_resample.any():
            max_trials = 100
            for attempt in range(max_trials):
                # Break if no more resampling needed
                if not need_resample.any():
                    break

                # Sample new tokens for sequences that need resampling
                new_samples = torch.multinomial(probs[need_resample], 1).squeeze(-1)

                # Update top_ids with new samples
                top_ids[need_resample] = new_samples

                # Update which sequences still need resampling
                is_eos_mask = top_ids == eos_token
                need_resample = ignore_eos_mask & is_eos_mask

            # If some sequences still sampled EOS after max_trials, fall back to the second most
            # likely token from the sorted order (or the most likely one when the vocabulary has a
            # single entry); note this fallback can still be EOS if EOS happens to rank second
            if need_resample.any():
                second_best_idx = 1 if sorted_indices.size(1) > 1 else 0
                top_ids[need_resample] = sorted_indices[need_resample, second_best_idx]

        return top_ids
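

# Illustrative usage sketch (assumed shapes and history, not part of the original module):
# RasSampler falls back to random sampling from the full distribution when the nucleus-sampled
# token already dominates the recent window, and resamples EOS until min_tokens is reached.
# The vocabulary size and token histories below are made up for demonstration only.
def _demo_ras_sampler() -> torch.Tensor:
    sampler = RasSampler()
    logits = torch.randn(2, 6562)          # [batch_size, vocab_size], assumed vocab size
    decoded_tokens_list = [
        [12, 12, 12, 12, 12],              # highly repetitive history, likely to trigger random sampling
        [],                                # empty history, so EOS is masked out for the first token
    ]
    return sampler(logits, decoded_tokens_list,
                   win_size=10, tau_r=0.1, top_p=0.8, top_k=25,
                   eos_token=6561, min_tokens=[2, 2])


if __name__ == "__main__":
    print("Sampler demo tokens:", _demo_sampler())
    print("RasSampler demo tokens:", _demo_ras_sampler())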