import math

import librosa
import torch
import torch.nn.functional as F

from dataclasses import dataclass
from typing import List, Optional, Tuple
from transformers.models.whisper.processing_whisper import WhisperProcessor

from ..dataset.chatml_dataset import ChatMLDatasetSample, RankedChatMLDatasetSampleTuple
from ..model.utils import build_delay_pattern_mask


def _ceil_to_nearest(n, round_to):
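    """Round `n` up to the nearest multiple of `round_to`."""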
    return (n + round_to - 1) // round_to * round_to


@dataclass
class HiggsAudioBatchInput:
    input_ids: torch.LongTensor  # shape (bsz, seq_len).
    attention_mask: torch.Tensor  # shape (bsz, seq_len).
    audio_features: Optional[torch.Tensor]  # shape (num_audio_in, feature_dim, max_mel_seq_len).
    audio_feature_attention_mask: Optional[torch.Tensor]  # shape (num_audio_in, max_mel_seq_len).
    audio_out_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    audio_out_ids_start: Optional[torch.LongTensor]  # shape (num_audio_out,)
    # audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It recovers, for each
    # audio segment, the location of its source sample in the batch. We concatenate audio segments
    # along dim 0 to handle variable segment lengths, which loses the per-sample grouping that the
    # alignment stage needs. For example, with
    #   audio_out_ids_start = [0, 2, 4, 8]
    # where the first two segments come from the same sample and the other two come from two
    # different samples (a batch of 3 samples), the group locations are:
    #   audio_out_ids_start_group_loc = [0, 0, 1, 2]
    audio_out_ids_start_group_loc: Optional[
        torch.LongTensor
    ]  # shape (num_audio_out,), the batch index of the sample each audio segment belongs to
    audio_in_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_in_total_length)
    audio_in_ids_start: Optional[torch.LongTensor]  # shape (num_audio_in,)
    label_ids: Optional[torch.LongTensor]  # shape (bsz, seq_len)
    label_audio_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    reward: Optional[torch.Tensor] = None  # shape (num_rewards,), rewards from samples that carry one


class HiggsAudioSampleCollator:
    """Sample collator for Higgs-Audio model.

    Args:
        whisper_processor (WhisperProcessor): The whisper processor.
        audio_in_token_id (int): The token id for audio-in.
        audio_out_token_id (int): The token id for audio-out.
        pad_token_id (int): The token id for padding.
        audio_stream_bos_id (int): The token id for audio-stream beginning of sentence.
        audio_stream_eos_id (int): The token id for audio-stream end of sentence.
        round_to (int): Pad the sequence length up to a multiple of this value.
        pad_left (bool): Whether to pad on the left (otherwise pad on the right).
        encode_whisper_embed (bool): Whether to extract Whisper features for audio inputs.
        return_audio_in_tokens (bool): Whether to return audio-in tokens.
        audio_num_codebooks (int): Number of codebooks to keep from the audio codes. None keeps all.
        use_delay_pattern (bool): Whether to apply the delay pattern to the audio codes.
        disable_audio_codes_transform (bool): Whether to skip adding bos and eos tokens to audio codes.
        chunk_size_seconds (int): The maximum duration of each audio chunk in seconds.
        add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks.
        mask_audio_out_token_label (bool): Whether to always mask the label associated with the <|AUDIO_OUT|> token. Since <|AUDIO_OUT|> always comes after <|audio_bos|>, it is safe to mask.
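
    Example (a minimal usage sketch; the token ids below are hypothetical
    placeholders, not values shipped with this repo):

        >>> from transformers import WhisperProcessor
        >>> processor = WhisperProcessor.from_pretrained("openai/whisper-base")
        >>> collator = HiggsAudioSampleCollator(
        ...     whisper_processor=processor,
        ...     audio_in_token_id=128101,
        ...     audio_out_token_id=128102,
        ...     pad_token_id=0,
        ...     audio_stream_bos_id=1024,
        ...     audio_stream_eos_id=1025,
        ... )
        >>> batch = collator(samples)  # samples: List[ChatMLDatasetSample]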

    """

    def __init__(
        self,
        whisper_processor: WhisperProcessor,
        audio_in_token_id,
        audio_out_token_id,
        pad_token_id,
        audio_stream_bos_id,
        audio_stream_eos_id,
        round_to=8,
        pad_left=False,
        encode_whisper_embed=True,
        return_audio_in_tokens=True,
        audio_num_codebooks=None,
        use_delay_pattern=False,
        disable_audio_codes_transform=False,
        chunk_size_seconds=30,  # Maximum duration for each chunk
        add_new_bos_eos_for_long_chunk=True,
        mask_audio_out_token_label=True,
    ):
        self.whisper_processor = whisper_processor
        self.round_to = round_to
        self.pad_left = pad_left
        self.audio_in_token_id = audio_in_token_id
        self.audio_out_token_id = audio_out_token_id
        self.audio_stream_bos_id = audio_stream_bos_id
        self.audio_stream_eos_id = audio_stream_eos_id
        self.pad_token_id = pad_token_id
        self.encode_whisper_embed = encode_whisper_embed
        self.return_audio_in_tokens = return_audio_in_tokens
        self.audio_num_codebooks = audio_num_codebooks
        self.use_delay_pattern = use_delay_pattern
        if encode_whisper_embed:
            self.chunk_size_seconds = chunk_size_seconds
            self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate)
        else:
            self.chunk_size_seconds = None
            self.chunk_size_samples = None
        self.disable_audio_codes_transform = disable_audio_codes_transform
        self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
        self.mask_audio_out_token_label = mask_audio_out_token_label

    def _process_and_duplicate_audio_tokens(
        self,
        input_ids: torch.Tensor,
        audio_idx: int,
        wv: torch.Tensor,
        sr: int,
        labels: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Process long audio and duplicate corresponding audio tokens.

        Args:
            input_ids: Input token ids
            audio_idx: Index of the audio token in the sequence
            wv: Audio waveform
            sr: Sample rate
            labels: Optional label ids to be duplicated alongside input ids

        Returns:
            Tuple of:
                - New input ids with duplicated audio tokens
                - New label ids (if labels were provided) or None
                - Number of chunks created
        """
        # Calculate number of chunks needed
        total_samples = len(wv)
        num_chunks = math.ceil(total_samples / self.chunk_size_samples)

        if num_chunks <= 1:
            return input_ids, labels, 1

        # Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|>
        audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]
        # Duplicate sequence for each chunk
        duplicated_sequence = audio_token_seq.repeat(num_chunks)

        # Create new input_ids with duplicated tokens
        new_input_ids = torch.cat(
            [
                input_ids[: audio_idx - 1],
                duplicated_sequence,
                input_ids[audio_idx + 2 :],
            ]
        )

        # If labels are provided, duplicate them as well
        new_labels = None
        if labels is not None:
            label_seq = labels[audio_idx - 1 : audio_idx + 2]
            duplicated_labels = label_seq.repeat(num_chunks)
            new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])

        return new_input_ids, new_labels, num_chunks

    def __call__(self, batch: List[ChatMLDatasetSample]):
        """Collate the input data with support for long audio processing."""

        label_ids = None
        label_audio_ids = None
        return_labels = any(ele.label_ids is not None for ele in batch)

        if self.encode_whisper_embed:
            # Process each sample in the batch to handle long audio
            # TODO(?) The implementation here can be optimized.
            processed_batch = []
            for i in range(len(batch)):
                sample = batch[i]
                audio_in_mask = sample.input_ids == self.audio_in_token_id
                audio_in_indices = torch.where(audio_in_mask)[0]

                # Process each audio token and duplicate if needed
                modified_input_ids = sample.input_ids
                modified_labels = sample.label_ids if return_labels else None
                modified_waveforms_concat = []
                modified_waveforms_start = []
                modified_sample_rate = []
                offset = 0  # Track position changes from duplicating tokens
                curr_wv_offset = 0

                # Process input audio tokens
                for idx, audio_idx in enumerate(audio_in_indices):
                    # Get the audio for this token
                    wv, sr = sample.get_wv(idx)  # Use idx since we want the original audio index
                    if sr != self.whisper_processor.feature_extractor.sampling_rate:
                        resampled_wv = librosa.resample(
                            wv.cpu().numpy(),
                            orig_sr=sr,
                            target_sr=self.whisper_processor.feature_extractor.sampling_rate,
                        )
                    else:
                        resampled_wv = wv.cpu().numpy()
                    wv = torch.tensor(resampled_wv, device=wv.device)
                    sr = self.whisper_processor.feature_extractor.sampling_rate

                    # Process and duplicate tokens if necessary
                    token_pos = audio_idx + offset
                    modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens(
                        modified_input_ids, token_pos, wv, sr, modified_labels
                    )

                    # Update audio data
                    for chunk_idx in range(num_chunks):
                        chunk_start = chunk_idx * self.chunk_size_samples
                        chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv))
                        chunk_wv = wv[chunk_start:chunk_end]
                        modified_waveforms_concat.append(chunk_wv)
                        modified_waveforms_start.append(curr_wv_offset)
                        curr_wv_offset += len(chunk_wv)
                        modified_sample_rate.append(sr)

                    # Update offset for next iteration
                    offset += (num_chunks - 1) * 3  # Each new chunk adds 3 more tokens

                # Create new sample with modified tokens and audio data
                processed_sample = ChatMLDatasetSample(
                    input_ids=modified_input_ids,
                    label_ids=modified_labels if return_labels else sample.label_ids,
                    audio_ids_concat=sample.audio_ids_concat,
                    audio_ids_start=sample.audio_ids_start,
                    audio_waveforms_concat=torch.cat(modified_waveforms_concat)
                    if modified_waveforms_concat
                    else sample.audio_waveforms_concat,
                    audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long)
                    if modified_waveforms_start
                    else sample.audio_waveforms_start,
                    audio_sample_rate=torch.tensor(modified_sample_rate)
                    if modified_sample_rate
                    else sample.audio_sample_rate,
                    audio_speaker_indices=torch.tensor([]),
                    # FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat.
                    audio_label_ids_concat=sample.audio_label_ids_concat,
                )
                # audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0])
                # assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}"
                processed_batch.append(processed_sample)
        else:
            processed_batch = batch

        # Get the max sequence length based on processed batch
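        # Rounding up to a multiple of `round_to` (8 by default) keeps padded lengths
        # at sizes GPU kernels tend to handle efficiently.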
        max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)

        # Get the ids for audio-in and audio-out for each batch
        audio_in_wv_l = []
        audio_in_ids_l = []
        audio_out_ids_l = []
        audio_out_ids_group_loc_l = []
        audio_in_label_ids_l = None
        audio_out_label_ids_l = None
        reward_l = []

        if return_labels:
            audio_out_no_train_flag = []  # Whether the audio-out data should be trained on or not.

        # Process the audio inputs and outputs
        for i in range(len(processed_batch)):
            audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
            audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
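            # Enumerate audio tokens (in and out together) in order of appearance:
            # masked positions receive indices 0, 1, 2, ..., which index into the
            # sample's stored audio segments.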
            audio_ids = torch.ones_like(processed_batch[i].input_ids)
            audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
            audio_in_ids = audio_ids[audio_in_mask]
            audio_out_ids = audio_ids[audio_out_mask]

            if return_labels:
                audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
                if self.mask_audio_out_token_label:
                    processed_batch[i].label_ids[audio_out_mask] = -100

            # Process audio inputs
            if self.return_audio_in_tokens:
                audio_in_ids_l.extend(
                    [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
                )
                if processed_batch[i].audio_label_ids_concat is not None:
                    if audio_in_label_ids_l is None:
                        audio_in_label_ids_l = []
                    audio_in_label_ids_l.extend(
                        [
                            processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
                            for idx in audio_in_ids
                        ]
                    )

            audio_out_ids_l.extend(
                [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
            )
            # One group-location entry per audio-out segment, matching audio_out_ids_start.
            audio_out_ids_group_loc_l.extend([i] * len(audio_out_ids))
            if processed_batch[i].reward is not None:
                reward_l.append(processed_batch[i].reward)

            if processed_batch[i].audio_label_ids_concat is not None:
                if audio_out_label_ids_l is None:
                    audio_out_label_ids_l = []
                audio_out_label_ids_l.extend(
                    [
                        processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
                        for idx in audio_out_ids
                    ]
                )

            if self.encode_whisper_embed:
                for idx in audio_in_ids:
                    wv, sr = processed_batch[i].get_wv(idx)
                    resampled_wv = wv.cpu().numpy()
                    # Split long audio into chunks
                    total_samples = len(resampled_wv)
                    for chunk_start in range(0, total_samples, self.chunk_size_samples):
                        chunk_end = min(chunk_start + self.chunk_size_samples, total_samples)
                        chunk = resampled_wv[chunk_start:chunk_end]
                        audio_in_wv_l.append(chunk)
            # assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \
            #     f"Assertion failed: Mismatch in number of audios. " \
            #     f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}."

        if return_labels:
            audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)

        # Process all audio features
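        # Whisper's feature extractor pads every chunk to its fixed 30 s window
        # (`nb_max_frames` log-mel frames) and returns an attention mask marking
        # which frames correspond to real audio rather than padding.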
        if len(audio_in_wv_l) > 0:
            feature_ret = self.whisper_processor.feature_extractor(
                audio_in_wv_l,
                sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
                return_attention_mask=True,
                padding="max_length",
            )
            audio_features = torch.from_numpy(feature_ret["input_features"])
            audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"])
        else:
            if self.encode_whisper_embed:
                audio_features = torch.zeros(
                    (
                        0,
                        self.whisper_processor.feature_extractor.feature_size,
                        self.whisper_processor.feature_extractor.nb_max_frames,
                    ),
                    dtype=torch.float32,
                )
                audio_feature_attention_mask = torch.zeros(
                    (0, self.whisper_processor.feature_extractor.nb_max_frames),
                    dtype=torch.int32,
                )
            else:
                audio_features = None
                audio_feature_attention_mask = None

        # Process audio input tokens
        if len(audio_in_ids_l) > 0:
            # Append audio-stream-bos and eos tokens
            new_audio_in_ids_l = []
            for ele in audio_in_ids_l:
                if self.disable_audio_codes_transform:
                    # Do not add audio-stream-bos or eos tokens.
                    # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
                    audio_codes = ele
                else:
                    audio_codes = torch.cat(
                        [
                            torch.full(
                                (ele.shape[0], 1),
                                self.audio_stream_bos_id,
                                dtype=torch.long,
                            ),
                            ele,
                            torch.full(
                                (ele.shape[0], 1),
                                self.audio_stream_eos_id,
                                dtype=torch.long,
                            ),
                        ],
                        dim=1,
                    )
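                    # The delay pattern (as in MusicGen-style codebook interleaving)
                    # shifts codebook k right by k steps, padding the head with bos and
                    # the tail with eos, so that codebook k at a given position is
                    # generated after codebooks < k.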
                    if self.use_delay_pattern:
                        audio_codes = build_delay_pattern_mask(
                            audio_codes.unsqueeze(0),
                            bos_token_id=self.audio_stream_bos_id,
                            pad_token_id=self.audio_stream_eos_id,
                        )[0].squeeze(0)
                new_audio_in_ids_l.append(audio_codes)
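            # Concatenate all audio-in segments along the time axis and record each
            # segment's start offset via an exclusive cumulative sum of segment lengths.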
            audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long()
            audio_in_ids_start = torch.cumsum(
                torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]),
                dim=0,
            )
        else:
            audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
            audio_in_ids_start = torch.zeros(0, dtype=torch.long)

        # Process audio output tokens
        audio_out_ids_start_group_loc = None
        if len(audio_out_ids_l) > 0:
            new_audio_out_ids_l = []
            label_audio_ids_l = []
            for idx, ele in enumerate(audio_out_ids_l):
                if self.disable_audio_codes_transform:
                    # Do not add audio-stream-bos or eos tokens.
                    # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
                    audio_codes = ele
                    if return_labels:
                        label_audio_ids = audio_out_label_ids_l[idx]
                else:
                    audio_codes = torch.cat(
                        [
                            torch.full(
                                (ele.shape[0], 1),
                                self.audio_stream_bos_id,
                                dtype=torch.long,
                            ),
                            ele,
                            torch.full(
                                (ele.shape[0], 1),
                                self.audio_stream_eos_id,
                                dtype=torch.long,
                            ),
                        ],
                        dim=1,
                    )
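                    # In the labels, the bos position is -100 (PyTorch cross-entropy's
                    # default ignore_index) so it contributes no loss; the eos is kept
                    # as a real target so the model learns when to stop.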
                    if return_labels:
                        label_audio_ids = torch.cat(
                            [
                                torch.full((ele.shape[0], 1), -100, dtype=torch.long),
                                ele,
                                torch.full(
                                    (ele.shape[0], 1),
                                    self.audio_stream_eos_id,
                                    dtype=torch.long,
                                ),
                            ],
                            dim=1,
                        )
                    if self.use_delay_pattern:
                        audio_codes = build_delay_pattern_mask(
                            audio_codes.unsqueeze(0),
                            bos_token_id=self.audio_stream_bos_id,
                            pad_token_id=self.audio_stream_eos_id,
                        )[0].squeeze(0)
                        if return_labels:
                            label_audio_ids = build_delay_pattern_mask(
                                label_audio_ids.unsqueeze(0),
                                bos_token_id=-100,
                                pad_token_id=-100,
                            )[0].squeeze(0)
                new_audio_out_ids_l.append(audio_codes)

                if return_labels:
                    if audio_out_no_train_flag[idx]:
                        label_audio_ids[:] = -100
                    label_audio_ids_l.append(label_audio_ids)

            audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
            if return_labels:
                label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
            audio_out_ids_start = torch.cumsum(
                torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]),
                dim=0,
            )
            audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
        else:
            audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
            audio_out_ids_start = torch.zeros(0, dtype=torch.long)
            if return_labels:
                label_audio_ids = torch.zeros((0, 0), dtype=torch.long)

        reward = torch.tensor(reward_l, dtype=torch.float32)

        # Handle padding for input ids and attention mask
        if self.pad_left:
            input_ids = torch.stack(
                [
                    F.pad(
                        ele.input_ids,
                        (max_seq_length - len(ele.input_ids), 0),
                        value=self.pad_token_id,
                    )
                    for ele in processed_batch
                ]
            )
            if return_labels:
                label_ids = torch.stack(
                    [
                        F.pad(
                            ele.label_ids,
                            (max_seq_length - len(ele.label_ids), 0),
                            value=-100,
                        )
                        for ele in processed_batch
                    ]
                )
            attention_mask = torch.stack(
                [
                    F.pad(
                        torch.ones_like(ele.input_ids),
                        (max_seq_length - len(ele.input_ids), 0),
                        value=0,
                    )
                    for ele in processed_batch
                ]
            )
        else:
            input_ids = torch.stack(
                [
                    F.pad(
                        ele.input_ids,
                        (0, max_seq_length - len(ele.input_ids)),
                        value=self.pad_token_id,
                    )
                    for ele in processed_batch
                ]
            )
            if return_labels:
                label_ids = torch.stack(
                    [
                        F.pad(
                            ele.label_ids,
                            (0, max_seq_length - len(ele.label_ids)),
                            value=-100,
                        )
                        for ele in processed_batch
                    ]
                )
            attention_mask = torch.stack(
                [
                    F.pad(
                        torch.ones_like(ele.input_ids),
                        (0, max_seq_length - len(ele.input_ids)),
                        value=0,
                    )
                    for ele in processed_batch
                ]
            )

        if not self.return_audio_in_tokens:
            audio_in_ids = None
            audio_in_ids_start = None

        # Apply audio_num_codebooks limit if specified
        if self.audio_num_codebooks is not None:
            if audio_in_ids is not None:
                audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
            if audio_out_ids is not None:
                audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
            if label_audio_ids is not None:
                label_audio_ids = label_audio_ids[: self.audio_num_codebooks]

        return HiggsAudioBatchInput(
            input_ids=input_ids,
            attention_mask=attention_mask,
            audio_features=audio_features,
            audio_feature_attention_mask=audio_feature_attention_mask,
            audio_out_ids=audio_out_ids,
            audio_out_ids_start=audio_out_ids_start,
            audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
            audio_in_ids=audio_in_ids,
            audio_in_ids_start=audio_in_ids_start,
            label_ids=label_ids,
            label_audio_ids=label_audio_ids,
            reward=reward,
        )


class HiggsAudioDPOSamplesCollator(HiggsAudioSampleCollator):
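    """Collator for DPO-style training.

    Flattens each RankedChatMLDatasetSampleTuple into its highest- and lowest-scored
    samples and collates them as a single batch: all chosen samples first, followed
    by all rejected samples.
    """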
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, batch: List[RankedChatMLDatasetSampleTuple]) -> HiggsAudioBatchInput:
        # flatten ranked chatml samples
        chosen = []
        rejected = []

        for sample in batch:
            chosen.append(sample.max_score_sample())
            rejected.append(sample.min_score_sample())

        merged = chosen + rejected
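        # The collated batch is ordered [chosen_0, ..., chosen_{B-1}, rejected_0, ..., rejected_{B-1}],
        # so downstream DPO code can split it in half to pair each chosen sample with its rejection.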

        return super().__call__(batch=merged)