Spaces:
Running
on
Zero
Running
on
Zero
File size: 26,320 Bytes
ac6279d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 |
import librosa
import torch
import torch.nn.functional as F
import math
import numpy as np
from typing import List, Tuple, Dict
from dataclasses import dataclass
from typing import List, Optional
from transformers.models.whisper.processing_whisper import WhisperProcessor
from ..dataset.chatml_dataset import ChatMLDatasetSample, RankedChatMLDatasetSampleTuple
from ..model.utils import build_delay_pattern_mask
def _ceil_to_nearest(n, round_to):
    """Round ``n`` up to the closest multiple of ``round_to``.

    Values that are already a multiple of ``round_to`` are returned unchanged.
    """
    # Adding ``round_to - 1`` before the floor division turns it into a ceiling division.
    num_groups = (n + round_to - 1) // round_to
    return num_groups * round_to
@dataclass
class HiggsAudioBatchInput:
    """One collated batch, as produced by the Higgs-Audio sample collators.

    Text tensors are padded per sample; audio code tensors hold all segments of
    the batch concatenated one after another, with ``*_start`` giving the offset
    at which each segment begins.
    """

    # ---- Text stream ----
    input_ids: torch.LongTensor  # shape (bsz, seq_len).
    attention_mask: torch.Tensor  # shape (bsz, seq_len).

    # ---- Whisper features of the audio-in waveforms ----
    audio_features: Optional[torch.Tensor]  # shape (num_audio_in, feature_dim, max_mel_seq_len).
    audio_feature_attention_mask: Optional[torch.Tensor]  # shape (num_audio_in, max_mel_seq_len).

    # ---- Audio-out codes ----
    audio_out_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    audio_out_ids_start: Optional[torch.LongTensor]  # shape (num_audio_out,)
    # audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It is used to
    # recover the group location in a batch for an audio segment.
    # Currently, we concatenate audio segments to handle variadic audio segment lengths.
    # However, in the alignment stage, we need the location information.
    # For example,
    #   audio_out_ids_start = [0, 2, 4, 8]; the first two audio segments come from the same
    #   sample in a batch, and the other two come from different samples.
    #   This is a batch of 3 samples, so the group location is:
    #   audio_out_ids_start_group_loc = [0, 0, 1, 2]
    # shape (num_audio_out,): the batch index of the sample each audio-out segment belongs to.
    audio_out_ids_start_group_loc: Optional[torch.LongTensor]

    # ---- Audio-in codes ----
    audio_in_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_in_total_length)
    audio_in_ids_start: Optional[torch.LongTensor]  # shape (num_audio_in,)

    # ---- Training targets ----
    label_ids: Optional[torch.LongTensor]  # shape (bsz, seq_len)
    label_audio_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)

    # NOTE: the collators in this file fill this with a float32 tensor holding the
    # rewards of the samples that carry one, not with a single Python float.
    reward: Optional[float] = None
class HiggsAudioSampleCollator:
    """Sample collator for Higgs-Audio model.

    Converts a list of ``ChatMLDatasetSample`` into one ``HiggsAudioBatchInput``:
    text tokens are padded to a common length, the audio codes of all segments
    are concatenated along dim 1, and (when ``encode_whisper_embed`` is set)
    whisper features are extracted for every audio-in waveform chunk.

    Args:
        whisper_processor (WhisperProcessor): The whisper processor. Only its ``feature_extractor``
            is used, and only when ``encode_whisper_embed`` is True.
        audio_in_token_id (int): The token id for audio-in.
        audio_out_token_id (int): The token id for audio-out.
        pad_token_id (int): The token id for padding.
        audio_stream_bos_id (int): The token id for audio-stream beginning of sentence.
        audio_stream_eos_id (int): The token id for audio-stream end of sentence.
        round_to (int): The padded sequence length is rounded up to a multiple of this value.
        pad_left (bool): Whether to pad left.
        encode_whisper_embed (bool): Whether to extract whisper features for audio-in waveforms.
            This also enables splitting of long audio-in waveforms into chunks.
        return_audio_in_tokens (bool): Whether to return audio-in tokens.
        audio_num_codebooks (int, optional): If given, only the first ``audio_num_codebooks``
            codebooks of every audio code tensor are kept.
        use_delay_pattern (bool): Whether to use delay pattern.
        disable_audio_codes_transform (bool): If True, audio codes are used as-is, i.e. no
            audio-stream bos and eos tokens are added around them.
        chunk_size_seconds (int): The chunk size in seconds.
        add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks.
            The value is stored on the instance; the code in this class does not read it.
        mask_audio_out_token_label (bool): Whether to always mask the label associated with <|AUDIO_OUT|> token. Since we will always have `<|AUDIO_OUT|>` after `<|audio_bos|>`, we can safely mask <|AUDIO_OUT|>.
    """

    def __init__(
        self,
        whisper_processor: WhisperProcessor,
        audio_in_token_id,
        audio_out_token_id,
        pad_token_id,
        audio_stream_bos_id,
        audio_stream_eos_id,
        round_to=8,
        pad_left=False,
        encode_whisper_embed=True,
        return_audio_in_tokens=True,
        audio_num_codebooks=None,
        use_delay_pattern=False,
        disable_audio_codes_transform=False,
        chunk_size_seconds=30,  # Maximum duration for each chunk
        add_new_bos_eos_for_long_chunk=True,
        mask_audio_out_token_label=True,
    ):
        self.whisper_processor = whisper_processor
        self.round_to = round_to
        self.pad_left = pad_left
        self.audio_in_token_id = audio_in_token_id
        self.audio_out_token_id = audio_out_token_id
        self.audio_stream_bos_id = audio_stream_bos_id
        self.audio_stream_eos_id = audio_stream_eos_id
        self.pad_token_id = pad_token_id
        self.encode_whisper_embed = encode_whisper_embed
        self.return_audio_in_tokens = return_audio_in_tokens
        self.audio_num_codebooks = audio_num_codebooks
        self.use_delay_pattern = use_delay_pattern
        if encode_whisper_embed:
            # The chunk size is only meaningful when waveforms are fed to whisper.
            self.chunk_size_seconds = chunk_size_seconds
            self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate)
        else:
            self.chunk_size_seconds = None
            self.chunk_size_samples = None
        self.disable_audio_codes_transform = disable_audio_codes_transform
        self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
        self.mask_audio_out_token_label = mask_audio_out_token_label

    def _process_and_duplicate_audio_tokens(
        self,
        input_ids: torch.Tensor,
        audio_idx: int,
        wv: torch.Tensor,
        sr: int,
        labels: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], int]:
        """Process long audio and duplicate corresponding audio tokens.

        Args:
            input_ids: Input token ids
            audio_idx: Index of the audio token in the sequence
            wv: Audio waveform
            sr: Sample rate (not used by this method)
            labels: Optional label ids to be duplicated alongside input ids

        Returns:
            Tuple of:
                - New input ids with duplicated audio tokens
                - New label ids (if labels were provided) or None
                - Number of chunks created (at least 1)
        """
        num_chunks = math.ceil(len(wv) / self.chunk_size_samples)
        if num_chunks <= 1:
            # The waveform fits in a single chunk: nothing to duplicate.
            return input_ids, labels, 1

        def _repeat_audio_triplet(seq: torch.Tensor) -> torch.Tensor:
            # The three tokens around the audio position: <|audio_bos|><|AUDIO|><|audio_eos|>
            triplet = seq[audio_idx - 1 : audio_idx + 2]
            return torch.cat([seq[: audio_idx - 1], triplet.repeat(num_chunks), seq[audio_idx + 2 :]])

        new_input_ids = _repeat_audio_triplet(input_ids)
        new_labels = _repeat_audio_triplet(labels) if labels is not None else None
        return new_input_ids, new_labels, num_chunks

    def _expand_long_audio(self, sample: ChatMLDatasetSample, return_labels: bool) -> ChatMLDatasetSample:
        """Return a copy of ``sample`` in which every audio-in waveform fits in one chunk.

        Each audio-in waveform is resampled to the whisper sampling rate and cut into pieces of
        at most ``chunk_size_samples`` samples; the audio token triplet in the text stream is
        repeated once per piece.
        """
        # TODO(?) The implementation here can be optimized.
        target_sr = self.whisper_processor.feature_extractor.sampling_rate
        audio_in_indices = torch.where(sample.input_ids == self.audio_in_token_id)[0]

        modified_input_ids = sample.input_ids
        modified_labels = sample.label_ids if return_labels else None
        waveform_chunks = []
        waveform_starts = []
        sample_rates = []
        offset = 0  # Track position changes from duplicating tokens
        curr_wv_offset = 0

        for idx, audio_idx in enumerate(audio_in_indices):
            # Use idx since we want the original audio index
            wv, sr = sample.get_wv(idx)
            if sr != target_sr:
                resampled_wv = librosa.resample(wv.cpu().numpy(), orig_sr=sr, target_sr=target_sr)
            else:
                resampled_wv = wv.cpu().numpy()
            wv = torch.tensor(resampled_wv, device=wv.device)
            sr = target_sr

            # Earlier duplications shifted this token to the right by ``offset``.
            modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens(
                modified_input_ids, audio_idx + offset, wv, sr, modified_labels
            )

            for chunk_idx in range(num_chunks):
                chunk_start = chunk_idx * self.chunk_size_samples
                chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv))
                chunk_wv = wv[chunk_start:chunk_end]
                waveform_chunks.append(chunk_wv)
                waveform_starts.append(curr_wv_offset)
                curr_wv_offset += len(chunk_wv)
                sample_rates.append(sr)

            offset += (num_chunks - 1) * 3  # Each new chunk adds 3 more tokens

        # NOTE(review): ``reward`` is not forwarded to the rebuilt sample, so rewards look like
        # they are dropped whenever encode_whisper_embed is True -- confirm against the
        # definition of ChatMLDatasetSample.
        return ChatMLDatasetSample(
            input_ids=modified_input_ids,
            label_ids=modified_labels if return_labels else sample.label_ids,
            audio_ids_concat=sample.audio_ids_concat,
            audio_ids_start=sample.audio_ids_start,
            audio_waveforms_concat=torch.cat(waveform_chunks) if waveform_chunks else sample.audio_waveforms_concat,
            audio_waveforms_start=torch.tensor(waveform_starts, dtype=torch.long)
            if waveform_starts
            else sample.audio_waveforms_start,
            audio_sample_rate=torch.tensor(sample_rates) if sample_rates else sample.audio_sample_rate,
            audio_speaker_indices=torch.tensor([]),
            # FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat.
            audio_label_ids_concat=sample.audio_label_ids_concat,
        )

    def _gather_audio_segments(
        self, processed_batch: List[ChatMLDatasetSample], return_labels: bool
    ) -> Dict[str, object]:
        """Collect the per-segment audio data and per-sample text labels of a batch.

        Returns a dict with the keys:
            - ``audio_in_wv``: list of numpy waveform chunks (filled only if encode_whisper_embed)
            - ``audio_in_ids``: list of audio-in code tensors (filled only if return_audio_in_tokens)
            - ``audio_out_ids``: list of audio-out code tensors
            - ``audio_out_ids_group_loc``: batch index of the owning sample, one per audio-out segment
            - ``audio_out_label_ids``: list of audio-out label code tensors, or None if no sample
              has ``audio_label_ids_concat``
            - ``audio_out_no_train_flag``: bool tensor, one entry per audio-out segment
              (None if ``return_labels`` is False)
            - ``rewards``: rewards of the samples whose ``reward`` is not None
            - ``text_label_ids``: per-sample text labels (empty if ``return_labels`` is False)
        """
        audio_in_wv_l = []
        audio_in_ids_l = []
        audio_out_ids_l = []
        audio_out_ids_group_loc_l = []
        audio_out_label_ids_l = None
        reward_l = []
        text_label_ids_l = []
        no_train_flags = []  # Whether the audio-out data should be trained on or not.

        for i, sample in enumerate(processed_batch):
            audio_in_mask = sample.input_ids == self.audio_in_token_id
            audio_out_mask = sample.input_ids == self.audio_out_token_id
            any_audio_mask = audio_in_mask ^ audio_out_mask
            # Number the audio tokens 0, 1, 2, ... in order of appearance (audio-in and
            # audio-out share one counter).
            audio_ids = torch.ones_like(sample.input_ids)
            audio_ids[any_audio_mask] = torch.cumsum(audio_ids[any_audio_mask], 0) - 1
            audio_in_ids = audio_ids[audio_in_mask]
            audio_out_ids = audio_ids[audio_out_mask]

            if return_labels:
                # A negative label on the <|AUDIO_OUT|> token marks the segment as "do not train".
                no_train_flags.append(sample.label_ids[audio_out_mask] < 0)
                sample_label_ids = sample.label_ids
                if self.mask_audio_out_token_label:
                    # Mask on a copy. Writing into the sample itself would make every audio-out
                    # segment look like "do not train" the next time the sample is collated.
                    sample_label_ids = sample_label_ids.clone()
                    sample_label_ids[audio_out_mask] = -100
                text_label_ids_l.append(sample_label_ids)

            if self.return_audio_in_tokens:
                audio_in_ids_l.extend(
                    sample.get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids
                )

            audio_out_ids_l.extend(sample.get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids)
            # One group location per audio-out segment, so that the result lines up
            # element-by-element with audio_out_ids_start.
            audio_out_ids_group_loc_l.extend([i] * len(audio_out_ids))

            if sample.reward is not None:
                reward_l.append(sample.reward)

            if sample.audio_label_ids_concat is not None:
                if audio_out_label_ids_l is None:
                    audio_out_label_ids_l = []
                audio_out_label_ids_l.extend(
                    sample.get_audio_codes_labels(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids
                )

            if self.encode_whisper_embed:
                for idx in audio_in_ids:
                    wv, _ = sample.get_wv(idx)
                    wv_np = wv.cpu().numpy()
                    # Split long audio into chunks
                    total_samples = len(wv_np)
                    for chunk_start in range(0, total_samples, self.chunk_size_samples):
                        chunk_end = min(chunk_start + self.chunk_size_samples, total_samples)
                        audio_in_wv_l.append(wv_np[chunk_start:chunk_end])

        return {
            "audio_in_wv": audio_in_wv_l,
            "audio_in_ids": audio_in_ids_l,
            "audio_out_ids": audio_out_ids_l,
            "audio_out_ids_group_loc": audio_out_ids_group_loc_l,
            "audio_out_label_ids": audio_out_label_ids_l,
            "audio_out_no_train_flag": torch.cat(no_train_flags, dim=0) if return_labels else None,
            "rewards": reward_l,
            "text_label_ids": text_label_ids_l,
        }

    def _extract_whisper_features(self, audio_in_wv_l) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run the whisper feature extractor over all waveform chunks.

        Returns ``(audio_features, audio_feature_attention_mask)``. With no chunks, both are
        empty tensors when ``encode_whisper_embed`` is True and None otherwise.
        """
        if len(audio_in_wv_l) > 0:
            feature_extractor = self.whisper_processor.feature_extractor
            feature_ret = feature_extractor(
                audio_in_wv_l,
                sampling_rate=feature_extractor.sampling_rate,
                return_attention_mask=True,
                padding="max_length",
            )
            return (
                torch.from_numpy(feature_ret["input_features"]),
                torch.from_numpy(feature_ret["attention_mask"]),
            )
        if not self.encode_whisper_embed:
            return None, None
        feature_extractor = self.whisper_processor.feature_extractor
        audio_features = torch.zeros(
            (0, feature_extractor.feature_size, feature_extractor.nb_max_frames),
            dtype=torch.float32,
        )
        audio_feature_attention_mask = torch.zeros((0, feature_extractor.nb_max_frames), dtype=torch.int32)
        return audio_features, audio_feature_attention_mask

    def _add_stream_bos_eos(self, codes: torch.Tensor, first_id: int) -> torch.Tensor:
        """Prepend a column of ``first_id`` and append a column of the audio-stream eos id."""
        num_rows = codes.shape[0]
        head = torch.full((num_rows, 1), first_id, dtype=torch.long)
        tail = torch.full((num_rows, 1), self.audio_stream_eos_id, dtype=torch.long)
        return torch.cat([head, codes, tail], dim=1)

    @staticmethod
    def _segment_starts(segments: List[torch.Tensor]) -> torch.Tensor:
        """Return the offset along dim 1 at which each segment starts once all are concatenated."""
        return torch.cumsum(torch.tensor([0] + [seg.shape[1] for seg in segments[:-1]]), dim=0)

    def _collate_audio_in(self, audio_in_ids_l: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Concatenate all audio-in code segments; returns ``(audio_in_ids, audio_in_ids_start)``."""
        if len(audio_in_ids_l) == 0:
            return torch.zeros((0, 0), dtype=torch.long), torch.zeros(0, dtype=torch.long)

        new_audio_in_ids_l = []
        for ele in audio_in_ids_l:
            if self.disable_audio_codes_transform:
                # Do not add audio-stream-bos or eos tokens.
                # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
                audio_codes = ele
            else:
                audio_codes = self._add_stream_bos_eos(ele, self.audio_stream_bos_id)
            if self.use_delay_pattern:
                audio_codes = build_delay_pattern_mask(
                    audio_codes.unsqueeze(0),
                    bos_token_id=self.audio_stream_bos_id,
                    pad_token_id=self.audio_stream_eos_id,
                )[0].squeeze(0)
            new_audio_in_ids_l.append(audio_codes)

        audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long()
        return audio_in_ids, self._segment_starts(new_audio_in_ids_l)

    def _collate_audio_out(
        self,
        audio_out_ids_l: List[torch.Tensor],
        audio_out_label_ids_l: Optional[List[torch.Tensor]],
        audio_out_no_train_flag: Optional[torch.Tensor],
        return_labels: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Concatenate all audio-out code segments and build their labels.

        Returns ``(audio_out_ids, audio_out_ids_start, label_audio_ids)``; ``label_audio_ids``
        is None when ``return_labels`` is False.
        """
        if len(audio_out_ids_l) == 0:
            label_audio_ids = torch.zeros((0, 0), dtype=torch.long) if return_labels else None
            return torch.zeros((0, 0), dtype=torch.long), torch.zeros(0, dtype=torch.long), label_audio_ids

        new_audio_out_ids_l = []
        label_audio_ids_l = []
        for idx, ele in enumerate(audio_out_ids_l):
            label_audio_ids = None
            if self.disable_audio_codes_transform:
                # Do not add audio-stream-bos or eos tokens.
                # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
                audio_codes = ele
                if return_labels:
                    label_audio_ids = audio_out_label_ids_l[idx]
            else:
                audio_codes = self._add_stream_bos_eos(ele, self.audio_stream_bos_id)
                if return_labels:
                    # The bos position is never a training target; the eos position is.
                    label_audio_ids = self._add_stream_bos_eos(ele, -100)

            if self.use_delay_pattern:
                audio_codes = build_delay_pattern_mask(
                    audio_codes.unsqueeze(0),
                    bos_token_id=self.audio_stream_bos_id,
                    pad_token_id=self.audio_stream_eos_id,
                )[0].squeeze(0)
                if return_labels:
                    label_audio_ids = build_delay_pattern_mask(
                        label_audio_ids.unsqueeze(0),
                        bos_token_id=-100,
                        pad_token_id=-100,
                    )[0].squeeze(0)

            new_audio_out_ids_l.append(audio_codes)
            if return_labels:
                if audio_out_no_train_flag[idx]:
                    # Build a fresh tensor instead of filling in place: with
                    # disable_audio_codes_transform the labels may alias data owned by the sample.
                    label_audio_ids = torch.full_like(label_audio_ids, -100)
                label_audio_ids_l.append(label_audio_ids)

        audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
        label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long() if return_labels else None
        return audio_out_ids, self._segment_starts(new_audio_out_ids_l), label_audio_ids

    def _pad_to(self, seq: torch.Tensor, target_len: int, value) -> torch.Tensor:
        """Pad the 1-D ``seq`` with ``value`` up to ``target_len``, on the side chosen by ``pad_left``."""
        pad_len = target_len - len(seq)
        padding = (pad_len, 0) if self.pad_left else (0, pad_len)
        return F.pad(seq, padding, value=value)

    def __call__(self, batch: List[ChatMLDatasetSample]):
        """Collate the input data with support for long audio processing.

        Args:
            batch: The samples to collate. The samples themselves are not modified.

        Returns:
            A ``HiggsAudioBatchInput``. ``label_ids`` and ``label_audio_ids`` are None when no
            sample in ``batch`` has ``label_ids``.
        """
        return_labels = not all(ele.label_ids is None for ele in batch)

        if self.encode_whisper_embed:
            # Process each sample in the batch to handle long audio
            processed_batch = [self._expand_long_audio(sample, return_labels) for sample in batch]
        else:
            processed_batch = batch

        # Get the max sequence length based on processed batch
        max_seq_length = _ceil_to_nearest(max(len(sample.input_ids) for sample in processed_batch), self.round_to)

        segments = self._gather_audio_segments(processed_batch, return_labels)
        audio_features, audio_feature_attention_mask = self._extract_whisper_features(segments["audio_in_wv"])
        audio_in_ids, audio_in_ids_start = self._collate_audio_in(segments["audio_in_ids"])
        audio_out_ids, audio_out_ids_start, label_audio_ids = self._collate_audio_out(
            segments["audio_out_ids"],
            segments["audio_out_label_ids"],
            segments["audio_out_no_train_flag"],
            return_labels,
        )
        audio_out_ids_start_group_loc = None
        if len(segments["audio_out_ids"]) > 0:
            audio_out_ids_start_group_loc = torch.tensor(segments["audio_out_ids_group_loc"], dtype=torch.long)
        reward = torch.tensor(segments["rewards"], dtype=torch.float32)

        # Handle padding for input ids and attention mask
        input_ids = torch.stack(
            [self._pad_to(ele.input_ids, max_seq_length, self.pad_token_id) for ele in processed_batch]
        )
        attention_mask = torch.stack(
            [self._pad_to(torch.ones_like(ele.input_ids), max_seq_length, 0) for ele in processed_batch]
        )
        label_ids = None
        if return_labels:
            label_ids = torch.stack(
                [self._pad_to(sample_labels, max_seq_length, -100) for sample_labels in segments["text_label_ids"]]
            )

        if not self.return_audio_in_tokens:
            audio_in_ids = None
            audio_in_ids_start = None

        # Apply audio_num_codebooks limit if specified
        if self.audio_num_codebooks is not None:
            if audio_in_ids is not None:
                audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
            if audio_out_ids is not None:
                audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
            if label_audio_ids is not None:
                label_audio_ids = label_audio_ids[: self.audio_num_codebooks]

        return HiggsAudioBatchInput(
            input_ids=input_ids,
            attention_mask=attention_mask,
            audio_features=audio_features,
            audio_feature_attention_mask=audio_feature_attention_mask,
            audio_out_ids=audio_out_ids,
            audio_out_ids_start=audio_out_ids_start,
            audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
            audio_in_ids=audio_in_ids,
            audio_in_ids_start=audio_in_ids_start,
            label_ids=label_ids,
            label_audio_ids=label_audio_ids,
            reward=reward,
        )
class HiggsAudioDPOSamplesCollator(HiggsAudioSampleCollator):
    """Collator for ranked samples: collates the best and worst sample of every tuple.

    The resulting batch holds all max-score samples first, followed by all
    min-score samples in the same order.
    """

    def __init__(self, *args, **kwargs):
        """Accept exactly the arguments of ``HiggsAudioSampleCollator``."""
        super().__init__(*args, **kwargs)

    def __call__(self, batch: List[RankedChatMLDatasetSampleTuple]) -> HiggsAudioBatchInput:
        """Flatten the ranked tuples and collate them with the parent collator."""
        # Query max then min for each tuple, in batch order.
        best_worst_pairs = [(ranked.max_score_sample(), ranked.min_score_sample()) for ranked in batch]
        flattened = [best for best, _ in best_worst_pairs]
        flattened += [worst for _, worst in best_worst_pairs]
        return super().__call__(batch=flattened)
|