File size: 39,649 Bytes
5bbc9a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 |
import asyncio
import functools
import logging
import random
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Tuple, Union, AsyncGenerator, Dict, Any
from concurrent.futures import ThreadPoolExecutor
import librosa
import torch
import numpy as np
import torchaudio
import sounddevice as sd
import io
from torch import nn
from IPython.display import Audio, display
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, TokensPrompt, RequestOutput
from vllm.multimodal import MultiModalDataDict
from vllm.utils import Counter
from TTS.TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder # noqa
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler # noqa
from .xtts2_config import XTTSConfig, XTTSGPTConfig
from .tokenizer import XTTSTokenizerFast
from ..xtts2_gpt.xtts2_gpt_modeling import LearnedPositionEmbeddings
def wav_to_mel_cloning(
    wav,
    mel_norms_file="../experiments/clips_mel_norms.pth",
    mel_norms=None,
    device=torch.device("cpu"),
    n_fft=4096,
    hop_length=1024,
    win_length=4096,
    power=2,
    normalized=False,
    sample_rate=22050,
    f_min=0,
    f_max=8000,
    n_mels=80,
):
    """Compute a normalized log-mel spectrogram used for voice cloning.

    Args:
        wav: Waveform tensor; it is moved to ``device`` before the transform.
        mel_norms_file: Path to a saved normalizer tensor, read only when
            ``mel_norms`` is not given.
        mel_norms: Optional per-mel-bin normalizers. Assumed 1-D with ``n_mels``
            entries (it is reshaped with ``unsqueeze(0).unsqueeze(-1)``).
        device: Device on which the mel transform runs.
        n_fft, hop_length, win_length, power, normalized, sample_rate, f_min,
        f_max, n_mels: Forwarded to ``torchaudio.transforms.MelSpectrogram``.

    Returns:
        ``log(clamp(mel, min=1e-5))`` divided by ``mel_norms`` along the mel axis.
    """
    mel_stft = torchaudio.transforms.MelSpectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        power=power,
        normalized=normalized,
        sample_rate=sample_rate,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        norm="slaney",
    ).to(device)
    wav = wav.to(device)
    mel = mel_stft(wav)
    # Clamp before the log so silent frames do not turn into -inf.
    mel = torch.log(torch.clamp(mel, min=1e-5))
    if mel_norms is None:
        mel_norms = torch.load(mel_norms_file, map_location=device)
    # Callers may pass normalizers living on a different device than ``device``;
    # align them with the spectrogram so the division does not fail.
    mel_norms = mel_norms.to(mel.device)
    # (n_mels,) -> (1, n_mels, 1) so it broadcasts over the leading and time axes.
    mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
    return mel
def load_audio(audiopath, sampling_rate):
    """Load an audio file as a mono waveform at ``sampling_rate``.

    Multi-channel files are averaged down to a single channel, the signal is
    resampled only when the file's rate differs, and samples are clipped in
    place to the [-1, 1] range.

    Returns:
        Tensor of shape (1, num_samples).
    """
    waveform, source_rate = torchaudio.load(audiopath)
    # Down-mix anything that is not already single-channel.
    if waveform.size(0) != 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if source_rate != sampling_rate:
        waveform = torchaudio.functional.resample(waveform, source_rate, sampling_rate)
    # Guard against out-of-range samples produced by decoding/resampling.
    waveform.clip_(-1, 1)
    return waveform
@dataclass
class XTTSRequest:
    """Parameters of one XTTS synthesis request.

    ``text`` is either the complete text or an async generator of text
    fragments (the streaming entry point requires the latter).
    """

    # Identifier used to derive the vLLM request ids.
    request_id: str
    # Full text, or an async stream of text fragments.
    text: Union[AsyncGenerator[str, None], str]
    # Language code passed to the tokenizer.
    language: str
    # Path to the reference speaker audio file used for voice cloning.
    speaker_file: str
    # Streaming only: buffered-character threshold that triggers synthesis.
    generate_every_n_chars: Optional[int] = None
    # Sampling controls forwarded to vLLM.
    temperature: float = 0.75
    top_p: float = 0.85
    top_k: int = 50
    # Applied through a custom logits processor rather than vLLM's built-in penalty.
    repetition_penalty: float = 5.0
    length_penalty: float = 1.0
    do_sample: bool = True
    # Reference-audio limits, in seconds, used to compute the conditioning latents.
    max_ref_length: int = 60
    gpt_cond_len: int = 30
    gpt_cond_chunk_len: int = 4
import threading
class HiddenStatesCollector:
    """Lock-protected store of hidden-state tensors keyed by request id.

    A callback created with :meth:`bind_to_request` records tensors for one
    request; :meth:`get_hidden_states` later pops and concatenates them.
    """

    def __init__(self):
        # request_id -> list of recorded tensors, in arrival order.
        self.outputs = {}
        # Guards ``outputs``: recording and reading may happen on different threads.
        self.lock = threading.Lock()

    def __call__(self, outputs: Optional[torch.Tensor], request_id: str):
        """Save outputs for a specific request; ``None`` chunks are ignored.

        Storing ``None`` would make the later ``torch.cat`` fail, so it is dropped here.
        """
        if outputs is None:
            return
        with self.lock:
            self.outputs.setdefault(request_id, []).append(outputs)

    def get_hidden_states(self, request_id) -> Optional[torch.Tensor]:
        """Remove and return everything recorded for ``request_id``.

        Returns:
            The recorded tensors concatenated along dim 0, or ``None`` when
            nothing was recorded for the request.
        """
        with self.lock:
            chunks = self.outputs.pop(request_id, None)
        # The entry is already detached from ``outputs``, so concatenating outside
        # the lock cannot race with new recordings.
        if not chunks:
            return None
        return torch.cat(chunks, dim=0)

    def bind_to_request(self, request_id: str):
        """Return a callback that records into this collector under ``request_id``.

        The callback keeps the ``(outputs, request_id)`` call shape but ignores
        the second argument, always using the bound ``request_id``.
        """
        def bound_collector(outputs: Optional[torch.Tensor], _request_id: str = None):
            self(outputs, request_id)
        return bound_collector
class ExtendedSamplingParams(SamplingParams, kw_only=True):
    """``SamplingParams`` subclass that also carries a hidden-state collector.

    ``kw_only=True`` is passed as a class keyword so the new field, which has no
    default, can follow the base class's defaulted fields; it must therefore be
    passed by keyword. NOTE(review): this relies on ``SamplingParams`` being a
    msgspec ``Struct`` (which accepts ``kw_only``) — confirm against the
    installed vLLM version.
    """
    # Callback that receives hidden states for the request.
    # NOTE(review): callers in this file pass the function returned by
    # HiddenStatesCollector.bind_to_request, not a HiddenStatesCollector instance;
    # the annotation appears to be nominal only — confirm before relying on it.
    hidden_state_collector: HiddenStatesCollector  # New required field
class LogitsRepetitionPenalizer:
    """Logits processor that discourages tokens which were already generated.

    Positive logits of previously generated tokens are divided by the penalty
    and negative ones multiplied by it, so a penalty > 1 always lowers them.
    """

    def __init__(self, repetition_penalty: float):
        # 0 would divide by zero and negative values would flip logit signs, so
        # only strictly positive penalties are meaningful.
        if repetition_penalty <= 0:
            raise ValueError("Repetition penalty must be strictly positive")
        self.repetition_penalty = repetition_penalty

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        """Apply the penalty to ``logits`` in place and return them.

        Args:
            token_ids: Previously generated token ids (duplicates allowed).
            logits: Next-token logits indexed by token id; presumably 1-D
                (vocab_size,) as passed per sequence by vLLM.
        """
        # 1.0 is the identity penalty; with no history there is nothing to penalize.
        if self.repetition_penalty == 1.0 or not token_ids:
            return logits
        # Each token is penalized once however often it occurred, so de-duplicate
        # to avoid redundant gathers/scatters on long generations.
        seen = torch.unique(torch.as_tensor(token_ids, device=logits.device, dtype=torch.long))
        scores = logits[seen]
        logits[seen] = torch.where(
            scores > 0,
            scores / self.repetition_penalty,
            scores * self.repetition_penalty,
        )
        return logits
@dataclass
class XTTSOutput:
    """A synthesized audio segment with conversion, saving and playback helpers.

    ``wav`` normally holds a NumPy array of samples; the helpers also accept a
    ``torch.Tensor`` stored in that field.
    """

    request_id: str
    wav: np.ndarray
    sample_rate: int = 24000

    def to_tensor(self) -> torch.Tensor:
        """Return ``wav`` as a tensor (``torch.from_numpy`` shares memory with the array)."""
        if isinstance(self.wav, np.ndarray):
            return torch.from_numpy(self.wav)
        return self.wav

    def to_bytes(self, format: str = 'wav', sample_width: int = 2) -> bytes:
        """Encode the audio as bytes.

        Args:
            format: ``'wav'`` for a complete WAV file, ``'raw'`` for headerless
                samples in native byte order.
            sample_width: Bytes per sample: 1, 2 or 4. WAV output uses 8-bit
                unsigned PCM, 16-bit signed PCM or 32-bit float respectively;
                raw output uses signed 8/16/32-bit integers.

        Returns:
            Audio data as bytes.

        Raises:
            ValueError: If ``format`` or ``sample_width`` is unsupported.
        """
        if sample_width not in (1, 2, 4):
            raise ValueError(f"Unsupported sample width: {sample_width}")
        # Work on a CPU copy so tensors kept on an accelerator can be encoded too.
        wav_tensor = self.to_tensor().detach().cpu()
        # torchaudio expects (channels, frames).
        if wav_tensor.dim() == 1:
            wav_tensor = wav_tensor.unsqueeze(0)
        # Clip to the valid [-1, 1] range.
        wav_tensor = torch.clamp(wav_tensor, -1.0, 1.0)

        if format == 'wav':
            # 8-bit WAV is unsigned by specification; 4 bytes keeps the float encoding.
            encoding = {1: "PCM_U", 2: "PCM_S", 4: "PCM_F"}[sample_width]
            buffer = io.BytesIO()
            torchaudio.save(
                buffer,
                wav_tensor,
                self.sample_rate,
                format="wav",
                encoding=encoding,
                bits_per_sample=sample_width * 8,
            )
            return buffer.getvalue()
        if format == 'raw':
            full_scale, int_dtype = {
                1: (127, torch.int8),
                2: (32767, torch.int16),
                4: (2147483647, torch.int32),
            }[sample_width]
            # Scale in float64: in float32, 2147483647 rounds up to 2**31, so a
            # full-scale sample would overflow int32.
            scaled = wav_tensor.to(torch.float64) * full_scale
            return scaled.to(int_dtype).numpy().tobytes()
        raise ValueError(f"Unsupported format: {format}")

    def save(self,
             filename: Union[str, Path],
             sample_rate: Optional[int] = None,
             format: Optional[str] = None) -> None:
        """Save audio to a file.

        Args:
            filename: Output filename.
            sample_rate: Optional target rate; audio is resampled when it differs
                from ``self.sample_rate``.
            format: Optional format override (default: inferred from extension).
        """
        wav_tensor = self.to_tensor().detach().cpu()
        if wav_tensor.dim() == 1:
            wav_tensor = wav_tensor.unsqueeze(0)
        if sample_rate and sample_rate != self.sample_rate:
            wav_tensor = torchaudio.functional.resample(
                wav_tensor,
                orig_freq=self.sample_rate,
                new_freq=sample_rate,
            )
        else:
            sample_rate = self.sample_rate
        torchaudio.save(filename, wav_tensor, sample_rate, format=format)

    def resample(self, new_sample_rate: int) -> 'XTTSOutput':
        """Return a new ``XTTSOutput`` resampled to ``new_sample_rate``."""
        wav_tensor = self.to_tensor().detach()
        if wav_tensor.dim() == 1:
            wav_tensor = wav_tensor.unsqueeze(0)
        resampled = torchaudio.functional.resample(
            wav_tensor,
            orig_freq=self.sample_rate,
            new_freq=new_sample_rate,
        )
        return XTTSOutput(
            request_id=self.request_id,
            # .cpu() so a tensor kept on an accelerator can still become NumPy.
            wav=resampled.squeeze().cpu().numpy(),
            sample_rate=new_sample_rate,
        )

    def get_info(self) -> Tuple[int, int, float]:
        """Return ``(number of samples, sample rate, duration in seconds)``.

        Samples are counted along the last (time) axis, so (channels, samples)
        arrays report the per-channel length; a 0-d array counts as one sample.
        """
        shape = tuple(self.wav.shape)
        n_samples = int(shape[-1]) if shape else 1
        duration = n_samples / self.sample_rate
        return n_samples, self.sample_rate, duration

    @classmethod
    def from_tensor(cls, request_id: str, tensor: torch.Tensor, sample_rate: int = 24000) -> 'XTTSOutput':
        """Create an ``XTTSOutput`` from a tensor (squeezed and copied to a CPU array)."""
        return cls(
            request_id=request_id,
            wav=tensor.detach().squeeze().cpu().numpy(),
            sample_rate=sample_rate,
        )

    @classmethod
    def from_file(cls, request_id: str, filename: Union[str, Path]) -> 'XTTSOutput':
        """Create an ``XTTSOutput`` from an audio file, keeping its native sample rate."""
        wav_tensor, sample_rate = torchaudio.load(filename)
        return cls.from_tensor(request_id, wav_tensor, sample_rate)

    def play(self) -> None:
        """Play the audio through the default sound device and block until done.

        For use in regular Python scripts/applications.
        """
        if isinstance(self.wav, torch.Tensor):
            audio_data = self.wav.detach().cpu().numpy()
        else:
            audio_data = self.wav
        # sounddevice plays float32 samples in [-1, 1].
        if audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        audio_data = np.clip(audio_data, -1.0, 1.0)
        sd.play(audio_data, self.sample_rate)
        sd.wait()

    def display(self) -> "Optional[Audio]":
        """Show an audio player widget (Jupyter). Returns the widget, or ``None`` on failure."""
        try:
            audio_bytes = self.to_bytes(format='wav')
            audio_widget = Audio(audio_bytes, rate=self.sample_rate, autoplay=False)
            # Method names are not in scope inside method bodies, so this calls the
            # module-level IPython ``display`` function, not this method.
            display(audio_widget)
            return audio_widget
        except Exception as e:
            print(f"Could not display audio widget: {str(e)}")
            print("Try using .play() method instead")
            return None

    def preview(self) -> None:
        """Try the notebook widget first and fall back to sound-device playback."""
        try:
            if self.display() is None:
                self.play()
        except Exception as e:
            print(f"Error playing audio: {str(e)}")
class Xtts(nn.Module):
    """Async XTTS text-to-speech model whose GPT stage is served by vLLM.

    Speaker conditioning (speaker encoder, conditioning encoder/perceiver) and
    text embeddings are computed in this module; audio-token generation runs in
    a vLLM ``AsyncLLMEngine`` loaded from ``AstraMindAI/xtts2-gpt``; the HiFi-GAN
    decoder turns the GPT hidden states of the generated tokens into waveforms.
    """

    def __init__(self, hifi_config: XTTSConfig, gpt_config: XTTSGPTConfig, tensor_parallel_size: int = 1, **kwargs):
        """Build the local sub-modules and start the vLLM engine.

        Args:
            hifi_config: HiFi-GAN decoder configuration.
            gpt_config: GPT-stage configuration (token ids, sizes, options).
            tensor_parallel_size: Tensor-parallel degree for the vLLM engine.
            **kwargs: Accepted for compatibility and ignored.
        """
        super().__init__()
        self.hifi_config = hifi_config
        self.gpt_config = gpt_config
        self.mel_bos_token_id = gpt_config.start_audio_token
        self.mel_eos_token_id = gpt_config.stop_audio_token
        self.tp = tensor_parallel_size
        self.tokenizer = XTTSTokenizerFast.from_pretrained("AstraMindAI/xtts2-gpt")
        self.request_counter = Counter()
        # Thread pool used to run HiFi-GAN decoding off the event loop.
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.hidden_states_collector = HiddenStatesCollector()

        # Register buffer before creating modules (mel normalization statistics,
        # presumably overwritten by the checkpoint in from_pretrained).
        self.register_buffer("mel_stats", torch.ones(80))

        # Attribute names below must match the checkpoint keys.
        self.conditioning_encoder = ConditioningEncoder(
            gpt_config.audio_config.mel_channels,
            gpt_config.hidden_size,
            num_attn_heads=gpt_config.num_attention_heads,
        )
        self.text_embedding = nn.Embedding(
            gpt_config.number_text_tokens,
            gpt_config.hidden_size,
        )
        # NOTE(review): gated on max_audio_tokens rather than max_text_tokens,
        # presumably mirroring upstream XTTS — confirm before changing.
        self.text_pos_embedding = (
            LearnedPositionEmbeddings(
                gpt_config.max_text_tokens + 2,
                gpt_config.hidden_size,
                supports_pp=False,
            )
            if gpt_config.max_audio_tokens != -1
            else functools.partial(gpt_config.null_position_embeddings, dim=gpt_config.hidden_size)
        )
        if gpt_config.use_perceiver_resampler:
            self.conditioning_perceiver = PerceiverResampler(
                dim=gpt_config.hidden_size,
                depth=2,
                dim_context=gpt_config.hidden_size,
                num_latents=32,
                dim_head=64,
                heads=8,
                ff_mult=4,
                use_flash_attn=False,
            )

        self.hifigan_decoder = HifiDecoder(
            input_sample_rate=self.hifi_config.input_sample_rate,
            output_sample_rate=self.hifi_config.output_sample_rate,
            output_hop_length=self.hifi_config.output_hop_length,
            ar_mel_length_compression=self.hifi_config.gpt_code_stride_len,
            decoder_input_dim=self.hifi_config.decoder_input_dim,
            d_vector_dim=self.hifi_config.d_vector_dim,
            cond_d_vector_in_each_upsampling_layer=self.hifi_config.cond_d_vector_in_each_upsampling_layer,
        )

        # Kept for model loading purposes (checkpoint contains these weights).
        self.text_head = nn.Linear(gpt_config.hidden_size, gpt_config.number_text_tokens, bias=True)
        self.final_norm = nn.LayerNorm(gpt_config.hidden_size, eps=1e-5, bias=True)

        # Initialize the vLLM engine last, once all local modules exist.
        self.init_vllm_engine()

        # Limits concurrent conditioning-latent computations.
        self.max_concurrency = 10
        self.semaphore = asyncio.BoundedSemaphore(self.max_concurrency)

    def half(self):
        """Refuse half precision and return ``self`` unchanged.

        Downcasting is not permitted because it throws an error while padding.
        Returning ``self`` (as ``nn.Module.half`` does) keeps
        ``model = model.half()`` from replacing the model with ``None``.
        """
        return self

    def to(self, *args, **kwargs):
        """``nn.Module.to`` that rewrites float16/bfloat16 requests to float32.

        Covers ``dtype=`` as well as positional dtypes in any position (e.g.
        ``to("cuda", torch.float16)``); devices, tensors and other arguments pass
        through unchanged.
        """
        half_dtypes = (torch.float16, torch.bfloat16)
        if kwargs.get('dtype') in half_dtypes:
            kwargs['dtype'] = torch.float32
        # The isinstance check avoids comparing tensors against dtypes.
        args = tuple(
            torch.float32 if isinstance(arg, torch.dtype) and arg in half_dtypes else arg
            for arg in args
        )
        return super().to(*args, **kwargs)

    @property
    def device(self):
        """Device of the module's first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the module's first parameter."""
        return next(self.parameters()).dtype

    @staticmethod
    def get_memory_percentage(memory: int) -> float:
        """Express ``memory`` bytes as a fraction of the unreserved memory on CUDA device 0.

        Raises:
            RuntimeError: If the caching allocator already reserves the whole device.
        """
        total_memory = torch.cuda.get_device_properties(0).total_memory
        # memory_allocated() is part of memory_reserved() (tensors live inside the
        # allocator's reserved pool), so subtracting both would double count.
        available_memory = total_memory - torch.cuda.memory_reserved(0)
        if available_memory <= 0:
            raise RuntimeError("No unreserved memory left on CUDA device 0")
        return memory / available_memory

    def init_vllm_engine(self):
        """Create the vLLM ``AsyncLLMEngine`` that serves the GPT stage."""
        engine_args = AsyncEngineArgs(
            model="AstraMindAI/xtts2-gpt",
            tensor_parallel_size=self.tp,
            dtype="auto",
            disable_log_stats=True,
            max_model_len=self.gpt_config.max_text_tokens + self.gpt_config.max_audio_tokens,
            # NOTE(review): vLLM reads this as a fraction of *total* device memory,
            # while get_memory_percentage returns a fraction of unreserved memory.
            gpu_memory_utilization=self.get_memory_percentage(3 * 1024 ** 3),
            trust_remote_code=True,
            enforce_eager=True,
            limit_mm_per_prompt={"audio": 1},
            max_num_batched_tokens=7296,
        )
        self.llm_engine = AsyncLLMEngine.from_engine_args(engine_args)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: torch.dtype = torch.float32,
        device_map: Optional[str] = "auto",
        tensor_parallel_size: int = 1,
        **kwargs,
    ) -> "Xtts":
        """Load a pretrained XTTS model from a local directory or the HuggingFace Hub.

        Args:
            pretrained_model_name_or_path: Local directory, or Hub repo id when the
                path does not exist locally.
            torch_dtype: Requested dtype; half dtypes are turned into float32 by ``to``.
            device_map: Currently unused; the model is always moved to ``'cuda'``.
            tensor_parallel_size: Tensor-parallel degree for the vLLM engine.
            **kwargs: Forwarded to the constructor.
        """
        from huggingface_hub import hf_hub_download
        import json
        import os
        import safetensors.torch

        is_local = os.path.exists(pretrained_model_name_or_path)

        if is_local:
            config_file = os.path.join(pretrained_model_name_or_path, "config.json")
        else:
            config_file = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="config.json",
            )
        with open(config_file, 'r') as f:
            config = json.load(f)

        gpt_config = XTTSGPTConfig(**config['gpt_config'])
        hifi_config = XTTSConfig(**config)

        model = cls(
            hifi_config=hifi_config,
            gpt_config=gpt_config,
            tensor_parallel_size=tensor_parallel_size,
            **kwargs,
        )

        if is_local:
            hifigan_weights = os.path.join(pretrained_model_name_or_path, "xtts-v2.safetensors")
        else:
            hifigan_weights = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="xtts-v2.safetensors",
            )
        hifigan_state = safetensors.torch.load_file(hifigan_weights)
        model.load_state_dict(hifigan_state)

        model.config = config
        model = model.to(torch_dtype)
        model = model.to('cuda')
        return model

    @staticmethod
    def load_audio(audio_path: Union[str, Path], sampling_rate: int = 22050) -> torch.Tensor:
        """Load ``audio_path`` as mono audio at ``sampling_rate``, clipped to [-1, 1]."""
        # Class scope is not visible inside method bodies, so this resolves to the
        # module-level load_audio helper (same behaviour, single implementation).
        return load_audio(audio_path, sampling_rate)

    @torch.inference_mode()
    def get_speaker_embedding(self, audio, sr):
        """Return the HiFi-GAN speaker-encoder embedding of ``audio`` (resampled to 16 kHz).

        The L2-normalized embedding gets a trailing singleton dimension via
        ``unsqueeze(-1)``.
        """
        audio_16k = torchaudio.functional.resample(audio, sr, 16000)
        return (
            self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True)
            .unsqueeze(-1)
            .to(self.device)
        )

    @torch.inference_mode()
    def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
        """Compute the GPT conditioning latents from a reference waveform.

        Args:
            audio: Waveform indexed as (channels, samples) by the slicing below.
            sr: Sample rate of ``audio``; resampled to 22050 Hz when different.
            length: Seconds of audio to use; values ``<= 0`` keep everything.
            chunk_length: Chunk size in seconds for the perceiver-resampler path,
                whose style embeddings are averaged over chunks.

        Returns:
            The conditioning latents after ``transpose(1, 2)``.

        Raises:
            RuntimeError: If the perceiver path finds no chunk of at least 0.33 s.
        """
        if sr != 22050:
            audio = torchaudio.functional.resample(audio, sr, 22050)
        if length > 0:
            audio = audio[:, : 22050 * length]
        if self.gpt_config.use_perceiver_resampler:
            chunk_samples = 22050 * chunk_length
            style_embs = []
            for start in range(0, audio.shape[1], chunk_samples):
                audio_chunk = audio[:, start: start + chunk_samples]
                # Chunks shorter than a third of a second are ignored.
                if audio_chunk.size(-1) < 22050 * 0.33:
                    continue
                mel_chunk = wav_to_mel_cloning(
                    audio_chunk,
                    mel_norms=self.mel_stats.cpu(),
                    n_fft=2048,
                    hop_length=256,
                    win_length=1024,
                    power=2,
                    normalized=False,
                    sample_rate=22050,
                    f_min=0,
                    f_max=8000,
                    n_mels=80,
                )
                style_embs.append(self.get_style_emb(mel_chunk.to(self.device), None))
            if not style_embs:
                raise RuntimeError(
                    "Reference audio is too short: no chunk of at least 0.33 s to condition on"
                )
            cond_latent = torch.stack(style_embs).mean(dim=0)
        else:
            mel = wav_to_mel_cloning(
                audio,
                mel_norms=self.mel_stats.cpu(),
                n_fft=4096,
                hop_length=1024,
                win_length=4096,
                power=2,
                normalized=False,
                sample_rate=22050,
                f_min=0,
                f_max=8000,
                n_mels=80,
            )
            cond_latent = self.get_style_emb(mel.to(self.device))
        return cond_latent.transpose(1, 2)

    @torch.inference_mode()
    def get_conditioning_latents(
        self,
        audio_path,
        max_ref_length=30,
        gpt_cond_len=6,
        gpt_cond_chunk_len=6,
        librosa_trim_db=None,
        sound_norm_refs=False,
        load_sr=22050,
    ):
        """Compute GPT conditioning latents and a speaker embedding from reference audio.

        Args:
            audio_path: One path or a list of paths; multiple references are
                concatenated for the GPT latents and averaged for the speaker embedding.
            max_ref_length: Seconds kept from each reference file.
            gpt_cond_len: Seconds used for the GPT latents.
            gpt_cond_chunk_len: Chunk length (seconds) for the GPT latents.
            librosa_trim_db: If set, leading/trailing audio quieter than this many dB
                below peak is trimmed with ``librosa.effects.trim``.
            sound_norm_refs: If true, peak-normalize each reference to 0.75.
            load_sr: Sample rate used when loading the references.

        Returns:
            ``(gpt_cond_latents, speaker_embedding)``.
        """
        assert isinstance(audio_path, str) or isinstance(audio_path, list), "audio_path must be a string or a list."
        audio_paths = audio_path if isinstance(audio_path, list) else [audio_path]

        speaker_embeddings = []
        audios = []
        for file_path in audio_paths:
            audio = load_audio(file_path, load_sr)
            audio = audio[:, : load_sr * max_ref_length].to(self.device).to(self.dtype)
            if sound_norm_refs:
                peak = torch.abs(audio).max()
                # An all-zero clip would otherwise become NaN.
                if peak > 0:
                    audio = (audio / peak) * 0.75
            if librosa_trim_db is not None:
                # librosa operates on NumPy arrays, so round-trip through the CPU.
                trimmed, _ = librosa.effects.trim(audio.cpu().numpy(), top_db=librosa_trim_db)
                audio = torch.from_numpy(trimmed).to(device=self.device, dtype=self.dtype)

            speaker_embeddings.append(self.get_speaker_embedding(audio, load_sr))
            audios.append(audio)

        full_audio = torch.cat(audios, dim=-1)
        gpt_cond_latents = self.get_gpt_cond_latents(
            full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len
        )
        speaker_embedding = torch.stack(speaker_embeddings).mean(dim=0)
        return gpt_cond_latents, speaker_embedding

    def get_style_emb(self, cond_input: torch.Tensor, return_latent: bool = False) -> torch.Tensor:
        """Encode mel spectrograms into conditioning embeddings.

        When ``return_latent`` is truthy the input is returned with an extra
        dimension at position 1 instead of being encoded.
        """
        if return_latent:
            return cond_input.unsqueeze(1)
        if cond_input.ndim == 4:
            cond_input = cond_input.squeeze(1)
        conds = self.conditioning_encoder(cond_input)
        if hasattr(self, 'conditioning_perceiver'):
            conds = self.conditioning_perceiver(conds.permute(0, 2, 1)).transpose(1, 2)
        return conds

    async def prepare_text_tokens_async(self, text: str, language: str, split_text=False) \
            -> Tuple[List[Union[int, List[int]]], List[torch.Tensor]]:
        """Tokenize and embed ``text``.

        Returns:
            ``(placeholder_token_ids, text_embeddings)``: one list of ``1``s (one per
            text position including BOS/EOS) and one embedding tensor per piece when
            ``split_text`` is true; otherwise a single placeholder list and a
            one-element embedding list.
        """
        async def elaborate_tokens(text_tokens: List[int]) -> torch.Tensor:
            # Mutates the list in place (BOS/EOS added) and returns it as (1, seq_len).
            text_tokens.insert(0, self.tokenizer.bos_token_id)
            text_tokens.append(self.tokenizer.eos_token_id)
            return torch.tensor(text_tokens).unsqueeze(0).to(self.text_embedding.weight.device)

        async def embed_tokens(text_tokens: Union[torch.Tensor, List[torch.Tensor]]) -> List[torch.Tensor]:
            pieces = text_tokens if isinstance(text_tokens, list) else [text_tokens]
            return [self.text_embedding(piece) + self.text_pos_embedding(piece) for piece in pieces]

        fake_tokens_for_audio_generation = []
        if split_text:
            text_tokens = self.tokenizer.batch_encode_with_split(text, lang=[language])
            for idx, text_token in enumerate(text_tokens):
                text_tokens[idx] = await elaborate_tokens(text_token)
                # text_token was extended in place, so its length includes BOS/EOS.
                fake_tokens_for_audio_generation.append([1] * len(text_token))
        else:
            text_tokens = self.tokenizer.batch_encode(text, lang=[language])
            text_tokens = await elaborate_tokens(text_tokens)
            # text_tokens is now a (1, seq_len) tensor, so len() would be 1; count the
            # sequence positions (BOS/EOS included) exactly like the split branch.
            fake_tokens_for_audio_generation = [1] * text_tokens.shape[-1]
        return fake_tokens_for_audio_generation, await embed_tokens(text_tokens)

    async def prepare_inputs_async(self, text: str, language: str, speaker_file: Union[str, Path],
                                   max_ref_length: int, gpt_cond_len: int, gpt_cond_chunk_len: int, split_text: bool) \
            -> Tuple[List[List[int]], List[torch.Tensor], torch.Tensor]:
        """Tokenize ``text`` and prepend the speaker conditioning latents to each piece.

        Returns:
            ``(placeholder_token_ids, conditioned_embeddings, speaker_embeddings)``,
            the embeddings cast to the vLLM model dtype.
        """
        text_tokens, text_embeddings = await self.prepare_text_tokens_async(text, language, split_text)
        gpt_cond_latent, speaker_embeddings = await self.get_conditioning_latents_async(
            speaker_file,
            max_ref_length,
            gpt_cond_len,
            gpt_cond_chunk_len,
        )
        cond_latents = [
            # Conditioning first, then text, along the sequence dimension.
            torch.cat([gpt_cond_latent, text_embedding], dim=1)
            .squeeze(0)
            .to(self.llm_engine.engine.model_config.dtype)
            for text_embedding in text_embeddings
        ]
        return text_tokens, cond_latents, speaker_embeddings

    async def get_conditioning_latents_async(
        self,
        audio_path,
        max_ref_length=30,
        gpt_cond_len=6,
        gpt_cond_chunk_len=6,
        librosa_trim_db=None,
        sound_norm_refs=False,
        load_sr=22050,
    ):
        """Run :meth:`get_conditioning_latents` in the default executor.

        At most ``max_concurrency`` computations run at once (semaphore-limited).
        """
        async with self.semaphore:
            return await asyncio.get_running_loop().run_in_executor(
                None,
                functools.partial(
                    self.get_conditioning_latents,
                    audio_path,
                    max_ref_length,
                    gpt_cond_len,
                    gpt_cond_chunk_len,
                    librosa_trim_db,
                    sound_norm_refs,
                    load_sr,
                ),
            )

    async def get_model_logits(self, token_ids: List[int], conditioning: MultiModalDataDict) -> torch.Tensor:
        """Re-run the GPT over generated audio tokens and return their hidden states.

        ``[BOS] + token_ids + [EOS] * 5`` is submitted as a prompt with
        ``max_tokens=1``; the hidden states reported to the collector for that
        request are returned.

        Args:
            token_ids: Generated audio token ids.
            conditioning: Multi-modal data forwarded to the engine.

        Returns:
            The last ``len(token_ids) + 6`` collected rows with a leading batch
            dimension, on this module's device and dtype.

        Raises:
            RuntimeError: On a 300 s timeout or if no hidden states were collected.
        """
        request_id = uuid.uuid4().hex
        token_ids = [self.mel_bos_token_id] + token_ids + [self.mel_eos_token_id] * 5
        engine_inputs = TokensPrompt(prompt_token_ids=token_ids)
        engine_inputs["multi_modal_data"] = conditioning

        bound_collector = self.hidden_states_collector.bind_to_request(request_id)
        sampling_params = ExtendedSamplingParams(
            detokenize=False,
            max_tokens=1,
            hidden_state_collector=bound_collector,
        )
        generator = self.llm_engine.generate(
            prompt=engine_inputs,
            sampling_params=sampling_params,
            request_id=request_id,
        )

        async def consume_generator():
            async for _ in generator:
                pass

        try:
            await asyncio.wait_for(consume_generator(), timeout=300)
        except asyncio.TimeoutError as exc:
            # Ask the engine to drop the request instead of leaving it running.
            try:
                await self.llm_engine.abort(request_id)
            except Exception:
                logging.exception(f"Failed to abort timed-out request {request_id}")
            raise RuntimeError("Timeout while generating logits") from exc
        finally:
            # Always pop this request's entry (also on failure/cancellation) so
            # partial hidden states do not accumulate in the shared collector.
            hidden_states = self.hidden_states_collector.get_hidden_states(request_id)

        if hidden_states is None:
            raise RuntimeError(f"No hidden states collected for request {request_id}")
        return hidden_states[-len(token_ids):, ...].unsqueeze(0).to(self.device).to(self.dtype)

    async def process_tokens_to_speech(
        self,
        generators: List[AsyncGenerator[RequestOutput, None]],
        speaker_embeddings: torch.Tensor,
        multimodal_data: List[torch.Tensor],
        chunk_size: int = 20,
    ) -> AsyncGenerator[XTTSOutput, None]:
        """Decode several token streams concurrently and yield audio in input order.

        Each generator feeds its own queue from a background task; queues are
        drained one after another, so outputs follow the order of ``generators``.
        Unfinished tasks are cancelled when the consumer stops early.
        """
        queues = [asyncio.Queue() for _ in generators]
        tasks = [
            asyncio.create_task(
                self._process_single_generator(
                    generator,
                    queue,
                    speaker_embeddings,
                    embeds,
                    chunk_size,
                )
            )
            for generator, queue, embeds in zip(generators, queues, multimodal_data)
        ]
        try:
            for queue in queues:
                while True:
                    result = await queue.get()
                    if result is None:  # End marker from _process_single_generator.
                        break
                    yield result
        finally:
            for task in tasks:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _process_single_generator(
        self,
        generator: AsyncGenerator[RequestOutput, None],
        queue: asyncio.Queue,
        speaker_embeddings: torch.Tensor,
        gpt_embed_input: torch.Tensor,
        chunk_size: int,
    ) -> None:
        """Decode one token stream into audio and publish it on ``queue``.

        Tokens are accumulated until the stream reports ``finished``; their hidden
        states are fetched with :meth:`get_model_logits` and decoded by HiFi-GAN
        on :attr:`executor`. ``None`` is always put on ``queue`` last, even on error.
        ``chunk_size`` is currently unused: decoding happens once per finished stream.
        """
        try:
            loop = asyncio.get_running_loop()
            consumed = 0
            accumulated_tokens: List[int] = []
            async for output in generator:
                token_ids = output.outputs[0].token_ids
                # token_ids is treated as the cumulative sequence (vLLM's default
                # output kind), so only the suffix past `consumed` is new.
                accumulated_tokens.extend(token_ids[consumed:])
                # Track the position in the cumulative sequence itself. Deriving it
                # from len(accumulated_tokens) re-reads old tokens once the buffer is
                # reset for chunked decoding, which is why enabling a
                # `len(accumulated_tokens) >= chunk_size` trigger repeated tokens.
                consumed = len(token_ids)
                if not output.finished:
                    continue
                hidden_states = await self.get_model_logits(
                    accumulated_tokens,
                    {
                        "audio": {
                            'embeds': gpt_embed_input,
                            "is_logits_only_mode": True,
                        }
                    },
                )
                wav = await loop.run_in_executor(
                    self.executor,
                    lambda: self.hifigan_decoder.inference(
                        hidden_states,
                        g=speaker_embeddings,
                    ).cpu().numpy().squeeze(),
                )
                await queue.put(XTTSOutput(request_id=output.request_id, wav=wav))
                break
        except Exception as e:
            # Best effort: keep the traceback; the consumer still gets the end marker.
            logging.exception(f"Error in generator processing: {e}")
        finally:
            await queue.put(None)

    def _build_sampling_params(self, request: XTTSRequest) -> SamplingParams:
        """Build vLLM sampling parameters for audio-token generation from ``request``."""
        return SamplingParams(
            temperature=request.temperature,
            top_p=request.top_p,
            detokenize=False,
            top_k=request.top_k,
            logits_processors=[LogitsRepetitionPenalizer(request.repetition_penalty)],
            repetition_penalty=1.0,  # Since we're handling repetition penalty manually
            max_tokens=self.gpt_config.gpt_max_audio_tokens,
            ignore_eos=True,  # Ignore the tokenizer eos token since it is for textual generation
            stop_token_ids=[self.mel_eos_token_id],
        )

    async def _synthesize_streamed_chunk(
        self,
        request: XTTSRequest,
        text: str,
        gpt_cond_latent: torch.Tensor,
        speaker_embeddings: torch.Tensor,
        chunk_index: int,
    ) -> AsyncGenerator[XTTSOutput, None]:
        """Synthesize one buffered text chunk of a streaming request."""
        tokens, embeddings = await self.prepare_text_tokens_async(text, request.language)
        # Same layout as prepare_inputs_async: conditioning latents followed by the
        # text embeddings along the sequence axis, cast to the engine dtype.
        gpt_embed_input = (
            torch.cat([gpt_cond_latent, embeddings[0]], dim=1)
            .squeeze(0)
            .to(self.llm_engine.engine.model_config.dtype)
        )
        engine_inputs = TokensPrompt(prompt_token_ids=tokens)
        engine_inputs["multi_modal_data"] = {"audio": {"embeds": gpt_embed_input, "is_logits_only_mode": False}}
        token_generator = self.llm_engine.generate(
            prompt=engine_inputs,
            sampling_params=self._build_sampling_params(request),
            # One distinct engine request per chunk, like the non-streaming path.
            request_id=f"{request.request_id}_{chunk_index}",
        )
        async for output in self.process_tokens_to_speech(
            [token_generator],
            speaker_embeddings,
            [gpt_embed_input],
            chunk_size=50,
        ):
            yield output

    async def generate_speech_async_from_streaming_source(self, request: XTTSRequest) -> AsyncGenerator[XTTSOutput, None]:
        """Synthesize speech from an async stream of text fragments.

        Fragments are buffered until the buffered text (ignoring surrounding
        whitespace) is longer than ``request.generate_every_n_chars`` characters
        (``None`` is treated as 0, i.e. every non-blank fragment triggers
        synthesis). Text still buffered when the stream ends is synthesized too.
        """
        assert isinstance(request.text, AsyncGenerator), "Text must be an AsyncGenerator for streaming source."

        # Speaker conditioning is computed once and reused for every chunk.
        gpt_cond_latent, speaker_embeddings = await self.get_conditioning_latents_async(
            request.speaker_file,
            request.max_ref_length,
            request.gpt_cond_len,
            request.gpt_cond_chunk_len,
        )

        threshold = request.generate_every_n_chars or 0
        buffered_text = ""
        chunk_index = 0
        async for fragment in request.text:
            # Keep fragments verbatim: stripping each one before concatenating glues
            # neighbouring words together ("Hello " + "world" -> "Helloworld").
            buffered_text += fragment
            if len(buffered_text.strip()) > threshold:
                async for output in self._synthesize_streamed_chunk(
                    request, buffered_text.strip(), gpt_cond_latent, speaker_embeddings, chunk_index
                ):
                    yield output
                chunk_index += 1
                buffered_text = ""

        # Do not drop the tail that never reached the threshold.
        if buffered_text.strip():
            async for output in self._synthesize_streamed_chunk(
                request, buffered_text.strip(), gpt_cond_latent, speaker_embeddings, chunk_index
            ):
                yield output

    async def generate_speech_from_text_async(self, request: XTTSRequest) -> AsyncGenerator[XTTSOutput, None]:
        """Synthesize speech for ``request.text`` given as a complete string.

        The text is split by the tokenizer (to avoid OOM on long inputs), all
        pieces are submitted to vLLM at once, and the decoded segments are
        yielded in text order.
        """
        tokens_list, gpt_embed_inputs, speaker_embeddings = await self.prepare_inputs_async(
            request.text,
            request.language,
            request.speaker_file,
            request.max_ref_length,
            request.gpt_cond_len,
            request.gpt_cond_chunk_len,
            split_text=True,
        )

        generators = []
        for seq_index, sequence in enumerate(tokens_list):
            engine_inputs = TokensPrompt(prompt_token_ids=sequence)
            engine_inputs["multi_modal_data"] = {
                "audio": {"embeds": gpt_embed_inputs[seq_index], "is_logits_only_mode": False}
            }
            generators.append(
                self.llm_engine.generate(
                    prompt=engine_inputs,
                    sampling_params=self._build_sampling_params(request),
                    request_id=f"{request.request_id}_{seq_index}",
                )
            )

        async for output in self.process_tokens_to_speech(
            generators,
            speaker_embeddings,
            gpt_embed_inputs,
            chunk_size=50,
        ):
            yield output

    def generate_speech_from_text(self, request: XTTSRequest) -> List[XTTSOutput]:
        """Synchronous wrapper for :meth:`generate_speech_from_text_async`.

        Args:
            request: XTTSRequest object containing generation parameters.

        Returns:
            List of XTTSOutput containing the generated speech segments.

        Raises:
            RuntimeError: If called while an event loop is already running in this
                thread (e.g. inside a Jupyter cell); use the async API there.
        """
        async def _collect_outputs():
            return [output async for output in self.generate_speech_from_text_async(request)]

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No loop running in this thread: safe to drive one here.
        else:
            # run_until_complete() cannot be nested inside a running loop, not even
            # on a fresh loop, so fail with an actionable message.
            raise RuntimeError(
                "generate_speech_from_text() cannot be called from a running event loop; "
                "use generate_speech_from_text_async() with `async for` instead."
            )

        # Reuse the thread's loop across calls instead of asyncio.run(): the vLLM
        # engine's background work is presumably bound to the loop it started on.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(_collect_outputs())
|