|
|
"""Example script for generating audio using HiggsAudio.""" |
|
|
|
|
|
import click |
|
|
import soundfile as sf |
|
|
import langid |
|
|
import jieba |
|
|
import os |
|
|
import re |
|
|
import copy |
|
|
import torchaudio |
|
|
import tqdm |
|
|
import yaml |
|
|
|
|
|
from loguru import logger |
|
|
from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse |
|
|
from boson_multimodal.data_types import Message, ChatMLSample, AudioContent, TextContent |
|
|
|
|
|
from boson_multimodal.model.higgs_audio import HiggsAudioConfig, HiggsAudioModel |
|
|
from boson_multimodal.data_collator.higgs_audio_collator import HiggsAudioSampleCollator |
|
|
from boson_multimodal.audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer |
|
|
from boson_multimodal.dataset.chatml_dataset import ( |
|
|
ChatMLDatasetSample, |
|
|
prepare_chatml_sample, |
|
|
) |
|
|
from boson_multimodal.model.higgs_audio.utils import revert_delay_pattern |
|
|
from typing import List |
|
|
from transformers import AutoConfig, AutoTokenizer |
|
|
from transformers.cache_utils import StaticCache |
|
|
from typing import Optional |
|
|
from dataclasses import asdict |
|
|
import torch |
|
|
|
|
|
# Directory containing this script. The bundled prompt assets (``voice_prompts/`` and
# ``scene_prompts/``) are resolved relative to it.
CURR_DIR = os.path.dirname(os.path.abspath(__file__))


# Marker embedded in system-message text; each occurrence is replaced by an audio content
# entry when the system message is built.
AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"


# System prompt used when the transcript has more than one speaker tag and no reference
# audio was supplied.
MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech.
If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.
If no speaker tag is present, select a suitable voice on your own."""
|
|
|
|
|
|
|
|
def normalize_chinese_punctuation(text):
    """Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.

    Parameters
    ----------
    text : str
        The text to normalize.

    Returns
    -------
    str
        ``text`` with every mapped punctuation mark replaced. Characters that are not in
        the table, including punctuation that is already half-width, are left untouched.
    """
    # The keys are spelled as escapes on purpose: a literal full-width character is easily
    # folded into its ASCII look-alike by an editor or a text-normalization pass, and an
    # ASCII "," key would append a space after every half-width comma.
    chinese_to_english_punct = {
        "\uff0c": ", ",  # full-width comma
        "\u3002": ".",  # ideographic full stop
        "\uff1a": ":",  # full-width colon
        "\uff1b": ";",  # full-width semicolon
        "\uff1f": "?",  # full-width question mark
        "\uff01": "!",  # full-width exclamation mark
        "\uff08": "(",  # full-width left parenthesis
        "\uff09": ")",  # full-width right parenthesis
        "\u3010": "[",  # left black lenticular bracket
        "\u3011": "]",  # right black lenticular bracket
        "\u300a": "<",  # left double angle bracket
        "\u300b": ">",  # right double angle bracket
        "\u201c": '"',  # left double quotation mark
        "\u201d": '"',  # right double quotation mark
        "\u2018": "'",  # left single quotation mark
        "\u2019": "'",  # right single quotation mark
        "\u3001": ",",  # ideographic (enumeration) comma
        "\u2014": "-",  # em dash
        "\u2026": "...",  # horizontal ellipsis
        "\u00b7": ".",  # middle dot
        "\u300c": '"',  # left corner bracket
        "\u300d": '"',  # right corner bracket
        "\u300e": '"',  # left white corner bracket
        "\u300f": '"',  # right white corner bracket
    }

    # No replacement contains another key, so a single translate pass is equivalent to
    # applying the replacements one after another.
    return text.translate(str.maketrans(chinese_to_english_punct))
|
|
|
|
|
|
|
|
def prepare_chunk_text(
    text, chunk_method: Optional[str] = None, chunk_max_word_num: int = 100, chunk_max_num_turns: int = 1
):
    """Chunk the text into smaller pieces. We will later feed the chunks one by one to the model.

    Parameters
    ----------
    text : str
        The text to be chunked.
    chunk_method : str, optional
        The method to use for chunking. Options are "speaker", "word", or None. By default, we won't use any chunking and
        will feed the whole text to the model.
    chunk_max_word_num : int, optional
        The maximum number of words for each chunk when "word" chunking method is used, by default 100
    chunk_max_num_turns : int, optional
        The maximum number of turns for each chunk when "speaker" chunking method is used, by default 1

    Returns
    -------
    List[str]
        The list of text chunks.

    Raises
    ------
    ValueError
        If ``chunk_method`` is not one of None, "speaker" or "word".
    """
    if chunk_method is None:
        return [text]

    if chunk_method == "speaker":
        speaker_markers = ("[SPEAKER", "<|speaker_id_start|>")
        speaker_chunks = []
        speaker_utterance = ""
        for line in text.split("\n"):
            line = line.strip()
            if line.startswith(speaker_markers):
                # A new speaker turn begins: flush the turn collected so far.
                if speaker_utterance:
                    speaker_chunks.append(speaker_utterance.strip())
                speaker_utterance = line
            elif speaker_utterance:
                # Continuation line of the current turn.
                speaker_utterance += "\n" + line
            else:
                speaker_utterance = line
        if speaker_utterance:
            speaker_chunks.append(speaker_utterance.strip())

        if chunk_max_num_turns > 1:
            # Merge consecutive turns so that each chunk holds up to chunk_max_num_turns turns.
            return [
                "\n".join(speaker_chunks[i : i + chunk_max_num_turns])
                for i in range(0, len(speaker_chunks), chunk_max_num_turns)
            ]
        return speaker_chunks

    if chunk_method == "word":
        # The language is classified once on the whole text; "zh" text is segmented with
        # jieba, anything else is split on single spaces.
        language = langid.classify(text)[0]
        chunks = []
        for paragraph in text.split("\n\n"):
            if language == "zh":
                words = list(jieba.cut(paragraph, cut_all=False))
                joiner = ""
            else:
                words = paragraph.split(" ")
                joiner = " "
            for i in range(0, len(words), chunk_max_word_num):
                chunks.append(joiner.join(words[i : i + chunk_max_word_num]))
            # Mark the paragraph boundary on the most recent chunk. The guard avoids an
            # IndexError when nothing has been emitted yet (e.g. a leading empty paragraph
            # for which the segmenter yields no words).
            if chunks:
                chunks[-1] += "\n\n"
        return chunks

    raise ValueError(f"Unknown chunk method: {chunk_method}")
|
|
|
|
|
|
|
|
def _build_system_message_with_audio_prompt(system_message):
    """Turn system-message text into a ``Message`` with interleaved text and audio contents.

    Every occurrence of ``AUDIO_PLACEHOLDER_TOKEN`` is replaced by an ``AudioContent`` with
    an empty ``audio_url``. The text in front of each placeholder becomes a ``TextContent``
    (even when it is empty); text after the last placeholder is only added when non-empty.
    """
    # Splitting on the token yields the text in front of each placeholder, followed by
    # whatever remains after the final one.
    *leading_segments, trailing_text = system_message.split(AUDIO_PLACEHOLDER_TOKEN)

    contents = []
    for segment in leading_segments:
        contents.append(TextContent(segment))
        contents.append(AudioContent(audio_url=""))
    if trailing_text:
        contents.append(TextContent(trailing_text))

    return Message(role="system", content=contents)
|
|
|
|
|
|
|
|
class HiggsAudioModelClient:
    """Load a HiggsAudio generation model and synthesize audio for a list of text chunks.

    The client owns the model, the text tokenizer, the audio tokenizer and the sample
    collator. When ``use_static_kv_cache`` is truthy it also allocates one ``StaticCache``
    per entry of ``kv_cache_lengths`` at construction time.
    """

    def __init__(
        self,
        model_path,
        audio_tokenizer,
        device_id=None,
        max_new_tokens=2048,
        kv_cache_lengths: Optional[List[int]] = None,
        use_static_kv_cache=False,
    ):
        """Load the model, tokenizers and collator.

        Parameters
        ----------
        model_path : str
            Passed to ``from_pretrained`` for the model, the text tokenizer and the config.
        audio_tokenizer : str or object
            A string is loaded with ``load_higgs_audio_tokenizer``; anything else is used as-is.
        device_id : int, optional
            CUDA device index. When None, "cuda" is used if available, otherwise "cpu".
        max_new_tokens : int, optional
            Upper bound on newly generated tokens per chunk, by default 2048
        kv_cache_lengths : List[int], optional
            Cache lengths to allocate static KV caches for. Defaults to [1024, 4096, 8192].
        use_static_kv_cache : bool, optional
            Whether to allocate and use static KV caches, by default False
        """
        # A None sentinel replaces the former mutable list default so the default value
        # can never be shared between instances.
        if kv_cache_lengths is None:
            kv_cache_lengths = [1024, 4096, 8192]

        if device_id is None:
            self._device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self._device = f"cuda:{device_id}"
        self._audio_tokenizer = (
            load_higgs_audio_tokenizer(audio_tokenizer, device=self._device)
            if isinstance(audio_tokenizer, str)
            else audio_tokenizer
        )
        self._model = HiggsAudioModel.from_pretrained(
            model_path,
            device_map=self._device,
            torch_dtype=torch.bfloat16,
        )
        self._model.eval()
        self._kv_cache_lengths = list(kv_cache_lengths)
        self._use_static_kv_cache = use_static_kv_cache

        self._tokenizer = AutoTokenizer.from_pretrained(model_path)
        self._config = AutoConfig.from_pretrained(model_path)
        self._max_new_tokens = max_new_tokens
        self._collator = HiggsAudioSampleCollator(
            whisper_processor=None,
            audio_in_token_id=self._config.audio_in_token_idx,
            audio_out_token_id=self._config.audio_out_token_idx,
            audio_stream_bos_id=self._config.audio_stream_bos_id,
            audio_stream_eos_id=self._config.audio_stream_eos_id,
            encode_whisper_embed=self._config.encode_whisper_embed,
            pad_token_id=self._config.pad_token_id,
            return_audio_in_tokens=self._config.encode_audio_in_tokens,
            use_delay_pattern=self._config.use_delay_pattern,
            round_to=1,
            audio_num_codebooks=self._config.audio_num_codebooks,
        )
        # Stays None unless static KV caches are requested; passed to model.generate as-is.
        self.kv_caches = None
        if use_static_kv_cache:
            self._init_static_kv_cache()

    def _init_static_kv_cache(self):
        """Allocate one ``StaticCache`` per configured length and capture CUDA graphs on GPU."""
        cache_config = copy.deepcopy(self._model.config.text_config)
        cache_config.num_hidden_layers = self._model.config.text_config.num_hidden_layers
        if self._model.config.audio_dual_ffn_layers:
            # One extra cache layer is reserved per audio dual-FFN layer.
            cache_config.num_hidden_layers += len(self._model.config.audio_dual_ffn_layers)

        self.kv_caches = {
            length: StaticCache(
                config=cache_config,
                max_batch_size=1,
                max_cache_len=length,
                device=self._model.device,
                dtype=self._model.dtype,
            )
            for length in sorted(self._kv_cache_lengths)
        }

        if "cuda" in self._device:
            logger.info("Capturing CUDA graphs for each KV cache length")
            self._model.capture_model(self.kv_caches.values())

    def _prepare_kv_caches(self):
        """Reset every static KV cache so it can be reused for the next generation call."""
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()

    @torch.inference_mode()
    def generate(
        self,
        messages,
        audio_ids,
        chunked_text,
        generation_chunk_buffer_size,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        ras_win_len=7,
        ras_win_max_num_repeat=2,
        seed=123,
        *args,
        **kwargs,
    ):
        """Generate audio for ``chunked_text``, one chunk at a time.

        Each chunk is appended as a user message to the running conversation, and the
        audio generated for earlier chunks is fed back as context for later ones.

        Parameters
        ----------
        messages : list
            Context messages (system message, voice prompts) placed before every chunk.
        audio_ids : list
            Audio token tensors of the reference prompts; always kept in the context.
        chunked_text : List[str]
            The text chunks to synthesize, in order. Must not be empty.
        generation_chunk_buffer_size : int or None
            Number of most recent generated chunks kept as context. None keeps all of them.
        temperature, top_k, top_p : optional
            Sampling parameters forwarded to ``model.generate``.
        ras_win_len : int or None, optional
            Forwarded to ``model.generate``; a value <= 0 is converted to None.
        ras_win_max_num_repeat : int, optional
            Forwarded to ``model.generate``.
        seed : int, optional
            Forwarded to ``model.generate``.
        *args, **kwargs
            Accepted for call compatibility and ignored.

        Returns
        -------
        tuple
            ``(waveform, sample_rate, text)``: the waveform decoded from the audio tokens of
            all chunks concatenated along dim 1, the sample rate (24000), and the decoded
            text output of the last chunk.

        Raises
        ------
        ValueError
            If ``chunked_text`` is empty.
        """
        # Without at least one chunk there is no model output to decode; fail with a clear
        # message instead of a NameError on `outputs` further down.
        if not chunked_text:
            raise ValueError("chunked_text must contain at least one chunk of text.")

        if ras_win_len is not None and ras_win_len <= 0:
            ras_win_len = None
        sr = 24000  # hard-coded output rate; assumes the audio tokenizer decodes at 24 kHz -- TODO confirm
        audio_out_ids_l = []
        generated_audio_ids = []
        generation_messages = []

        # The assistant header is identical for every chunk, so it is tokenized only once.
        postfix = self._tokenizer.encode(
            "<|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False
        )

        for idx, chunk_text in tqdm.tqdm(
            enumerate(chunked_text), desc="Generating audio chunks", total=len(chunked_text)
        ):
            generation_messages.append(
                Message(
                    role="user",
                    content=chunk_text,
                )
            )
            chatml_sample = ChatMLSample(messages=messages + generation_messages)
            input_tokens, _, _, _ = prepare_chatml_sample(chatml_sample, self._tokenizer)
            input_tokens.extend(postfix)

            logger.info(f"========= Chunk {idx} Input =========")
            logger.info(self._tokenizer.decode(input_tokens))
            # Reference prompts first, then the audio generated for previous chunks.
            context_audio_ids = audio_ids + generated_audio_ids

            if context_audio_ids:
                audio_ids_concat = torch.concat([ele.cpu() for ele in context_audio_ids], dim=1)
                # Start offset of every clip inside the concatenation along dim 1.
                audio_ids_start = torch.cumsum(
                    torch.tensor([0] + [ele.shape[1] for ele in context_audio_ids], dtype=torch.long), dim=0
                )
            else:
                audio_ids_concat = None
                audio_ids_start = None

            curr_sample = ChatMLDatasetSample(
                input_ids=torch.LongTensor(input_tokens),
                label_ids=None,
                audio_ids_concat=audio_ids_concat,
                audio_ids_start=audio_ids_start,
                audio_waveforms_concat=None,
                audio_waveforms_start=None,
                audio_sample_rate=None,
                audio_speaker_indices=None,
            )

            batch_data = self._collator([curr_sample])
            batch = asdict(batch_data)
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.contiguous().to(self._device)

            if self._use_static_kv_cache:
                self._prepare_kv_caches()

            outputs = self._model.generate(
                **batch,
                max_new_tokens=self._max_new_tokens,
                use_cache=True,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                past_key_values_buckets=self.kv_caches,
                ras_win_len=ras_win_len,
                ras_win_max_num_repeat=ras_win_max_num_repeat,
                stop_strings=["<|end_of_text|>", "<|eot_id|>"],
                tokenizer=self._tokenizer,
                seed=seed,
            )

            step_audio_out_ids_l = []
            for ele in outputs[1]:
                audio_out_ids = ele
                if self._config.use_delay_pattern:
                    audio_out_ids = revert_delay_pattern(audio_out_ids)
                # Clamp ids into the codebook range and drop the first and last position
                # along dim 1 (presumably the stream BOS/EOS entries -- confirm).
                step_audio_out_ids_l.append(audio_out_ids.clip(0, self._audio_tokenizer.codebook_size - 1)[:, 1:-1])
            audio_out_ids = torch.concat(step_audio_out_ids_l, dim=1)
            audio_out_ids_l.append(audio_out_ids)
            generated_audio_ids.append(audio_out_ids)

            generation_messages.append(
                Message(
                    role="assistant",
                    content=AudioContent(audio_url=""),
                )
            )
            if generation_chunk_buffer_size is not None and len(generated_audio_ids) > generation_chunk_buffer_size:
                # Slice from an explicit start index: `[-size:]` would keep everything for
                # size == 0 because `-0` is `0`.
                keep = max(generation_chunk_buffer_size, 0)
                generated_audio_ids = generated_audio_ids[len(generated_audio_ids) - keep :]
                # Every kept chunk corresponds to one user and one assistant message.
                generation_messages = generation_messages[len(generation_messages) - 2 * keep :]

        # Only the text output of the last chunk is logged and returned.
        text_result = self._tokenizer.decode(outputs[0][0])
        logger.info("========= Final Text output =========")
        logger.info(text_result)
        concat_audio_out_ids = torch.concat(audio_out_ids_l, dim=1)
        concat_wv = self._audio_tokenizer.decode(concat_audio_out_ids.unsqueeze(0))[0, 0]
        return concat_wv, sr, text_result
|
|
|
|
|
|
|
|
def prepare_generation_context(scene_prompt, ref_audio, ref_audio_in_system_message, audio_tokenizer, speaker_tags):
    """Prepare the context for generation.

    The context contains the system message, user message, assistant message, and audio prompt if any.

    Parameters
    ----------
    scene_prompt : str or None
        Scene description placed between ``<|scene_desc_start|>`` and ``<|scene_desc_end|>``.
        Skipped when falsy.
    ref_audio : str or None
        Comma-separated voice prompts, one per speaker. An entry is either the base name of
        a ``.wav``/``.txt`` pair under ``voice_prompts/`` or ``profile:<name>``, which is
        looked up in ``voice_prompts/profile.yaml``. Whitespace around entries is ignored.
    ref_audio_in_system_message : bool
        Whether the voice prompts are described in the system message instead of being
        added as user/assistant turns. Forced to True when any entry is a profile.
    audio_tokenizer : object
        Its ``encode(path)`` method is called on each reference audio file.
    speaker_tags : list of str
        Speaker tags found in the transcript; only used when ``ref_audio`` is None.

    Returns
    -------
    tuple
        ``(messages, audio_ids)``: the context messages (system message first, if any) and
        the encoded reference audios in speaker order.

    Raises
    ------
    FileNotFoundError
        If the audio or text file of a referenced voice prompt does not exist.
    """
    system_message = None
    messages = []
    audio_ids = []
    if ref_audio is not None:
        voice_prompt_dir = f"{CURR_DIR}/voice_prompts"
        # Strip each entry so that "belinda, chadwick" resolves the same as "belinda,chadwick".
        speaker_info_l = [speaker_info.strip() for speaker_info in ref_audio.split(",")]
        num_speakers = len(speaker_info_l)
        # A voice profile is a text description, so it can only live in the system message.
        if any(speaker_info.startswith("profile:") for speaker_info in speaker_info_l):
            ref_audio_in_system_message = True

        if ref_audio_in_system_message:
            voice_profile = None
            speaker_desc = []
            for spk_id, character_name in enumerate(speaker_info_l):
                if character_name.startswith("profile:"):
                    # Load the profile file lazily, and only once.
                    if voice_profile is None:
                        with open(f"{voice_prompt_dir}/profile.yaml", "r", encoding="utf-8") as f:
                            voice_profile = yaml.safe_load(f)
                    character_desc = voice_profile["profiles"][character_name[len("profile:") :].strip()]
                    speaker_desc.append(f"SPEAKER{spk_id}: {character_desc}")
                else:
                    # The placeholder is turned into an audio content entry below.
                    speaker_desc.append(f"SPEAKER{spk_id}: {AUDIO_PLACEHOLDER_TOKEN}")
            if scene_prompt:
                system_message = (
                    "Generate audio following instruction."
                    "\n\n"
                    f"<|scene_desc_start|>\n{scene_prompt}\n\n" + "\n".join(speaker_desc) + "\n<|scene_desc_end|>"
                )
            else:
                system_message = (
                    "Generate audio following instruction.\n\n"
                    + "<|scene_desc_start|>\n"
                    + "\n".join(speaker_desc)
                    + "\n<|scene_desc_end|>"
                )
            system_message = _build_system_message_with_audio_prompt(system_message)
        elif scene_prompt:
            system_message = Message(
                role="system",
                content=f"Generate audio following instruction.\n\n<|scene_desc_start|>\n{scene_prompt}\n<|scene_desc_end|>",
            )

        for spk_id, character_name in enumerate(speaker_info_l):
            if character_name.startswith("profile:"):
                continue
            prompt_audio_path = os.path.join(voice_prompt_dir, f"{character_name}.wav")
            prompt_text_path = os.path.join(voice_prompt_dir, f"{character_name}.txt")
            # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
            if not os.path.exists(prompt_audio_path):
                raise FileNotFoundError(f"Voice prompt audio file {prompt_audio_path} does not exist.")
            if not os.path.exists(prompt_text_path):
                raise FileNotFoundError(f"Voice prompt text file {prompt_text_path} does not exist.")
            with open(prompt_text_path, "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
            audio_tokens = audio_tokenizer.encode(prompt_audio_path)
            audio_ids.append(audio_tokens)

            if not ref_audio_in_system_message:
                # Present the reference as a (transcript, audio) exchange.
                messages.append(
                    Message(
                        role="user",
                        content=f"[SPEAKER{spk_id}] {prompt_text}" if num_speakers > 1 else prompt_text,
                    )
                )
                messages.append(
                    Message(
                        role="assistant",
                        content=AudioContent(
                            audio_url=prompt_audio_path,
                        ),
                    )
                )
    elif len(speaker_tags) > 1:
        # No reference audio: alternate voice descriptions between the speaker tags.
        speaker_desc_l = [
            f"{tag}: {'feminine' if idx % 2 == 0 else 'masculine'}" for idx, tag in enumerate(speaker_tags)
        ]
        speaker_desc = "\n".join(speaker_desc_l)
        scene_desc_l = []
        if scene_prompt:
            scene_desc_l.append(scene_prompt)
        scene_desc_l.append(speaker_desc)
        scene_desc = "\n\n".join(scene_desc_l)

        system_message = Message(
            role="system",
            content=f"{MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE}\n\n<|scene_desc_start|>\n{scene_desc}\n<|scene_desc_end|>",
        )
    else:
        system_message_l = ["Generate audio following instruction."]
        if scene_prompt:
            system_message_l.append(f"<|scene_desc_start|>\n{scene_prompt}\n<|scene_desc_end|>")
        system_message = Message(
            role="system",
            content="\n\n".join(system_message_l),
        )
    if system_message:
        messages.insert(0, system_message)
    return messages, audio_ids
|
|
|
|
|
|
|
|
@click.command()
@click.option(
    "--model_path",
    type=str,
    default="bosonai/higgs-audio-v2-generation-3B-base",
    help="Path or Hugging Face repo id of the generation model.",
)
@click.option(
    "--audio_tokenizer",
    type=str,
    default="bosonai/higgs-audio-v2-tokenizer",
    help="Audio tokenizer path, if not set, use the default one.",
)
@click.option(
    "--max_new_tokens",
    type=int,
    default=2048,
    help="The maximum number of new tokens to generate.",
)
@click.option(
    "--transcript",
    type=str,
    default="transcript/single_speaker/en_dl.txt",
    help="The prompt to use for generation. If not set, we will use a default prompt.",
)
@click.option(
    "--scene_prompt",
    type=str,
    default=f"{CURR_DIR}/scene_prompts/quiet_indoor.txt",
    help="The scene description prompt to use for generation. If not set, or set to `empty`, we will leave it to empty.",
)
@click.option(
    "--temperature",
    type=float,
    default=1.0,
    help="The value used to module the next token probabilities.",
)
@click.option(
    "--top_k",
    type=int,
    default=50,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
)
@click.option(
    "--top_p",
    type=float,
    default=0.95,
    help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
)
@click.option(
    "--ras_win_len",
    type=int,
    default=7,
    help="The window length for RAS sampling. If set to 0 or a negative value, we won't use RAS sampling.",
)
@click.option(
    "--ras_win_max_num_repeat",
    type=int,
    default=2,
    help="The maximum number of times to repeat the RAS window. Only used when --ras_win_len is set.",
)
@click.option(
    "--ref_audio",
    type=str,
    default=None,
    help="The voice prompt to use for generation. If not set, we will let the model randomly pick a voice. "
    "For multi-speaker generation, you can specify the prompts as `belinda,chadwick` and we will use the voice of belinda as SPEAKER0 and the voice of chadwick as SPEAKER1.",
)
@click.option(
    "--ref_audio_in_system_message",
    is_flag=True,
    default=False,
    help="Whether to include the voice prompt description in the system message.",
    show_default=True,
)
@click.option(
    "--chunk_method",
    default=None,
    # Only real string choices are listed; "no chunking" is expressed by omitting the option.
    type=click.Choice(["speaker", "word"]),
    help="The method to use for chunking the prompt text. Options are 'speaker', 'word', or None. By default, we won't use any chunking and will feed the whole text to the model.",
)
@click.option(
    "--chunk_max_word_num",
    default=200,
    type=int,
    help="The maximum number of words for each chunk when 'word' chunking method is used. Only used when --chunk_method is set to 'word'.",
)
@click.option(
    "--chunk_max_num_turns",
    default=1,
    type=int,
    help="The maximum number of turns for each chunk when 'speaker' chunking method is used. Only used when --chunk_method is set to 'speaker'.",
)
@click.option(
    "--generation_chunk_buffer_size",
    default=None,
    type=int,
    help="The maximal number of chunks to keep in the buffer. We will always keep the reference audios, and keep `max_chunk_buffer` chunks of generated audio.",
)
@click.option(
    "--seed",
    default=None,
    type=int,
    help="Random seed for generation.",
)
@click.option(
    "--device_id",
    type=int,
    default=None,
    help="The device to run the model on.",
)
@click.option(
    "--out_path",
    type=str,
    default="generation.wav",
    help="Output wav file path.",
)
@click.option(
    "--use_static_kv_cache",
    type=int,
    default=1,
    help="Whether to use static KV cache for faster generation. Only works when using GPU.",
)
def main(
    model_path,
    audio_tokenizer,
    max_new_tokens,
    transcript,
    scene_prompt,
    temperature,
    top_k,
    top_p,
    ras_win_len,
    ras_win_max_num_repeat,
    ref_audio,
    ref_audio_in_system_message,
    chunk_method,
    chunk_max_word_num,
    chunk_max_num_turns,
    generation_chunk_buffer_size,
    seed,
    device_id,
    out_path,
    use_static_kv_cache,
):
    """Generate speech for a transcript with HiggsAudio and write it to a wav file."""
    # Resolve the device: an explicit id wins, otherwise the first GPU if any, else CPU.
    if device_id is not None:
        device = f"cuda:{device_id}"
    elif torch.cuda.is_available():
        device_id = 0
        device = "cuda:0"
    else:
        device = "cpu"
    audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer, device=device)

    model_client = HiggsAudioModelClient(
        model_path=model_path,
        audio_tokenizer=audio_tokenizer,
        device_id=device_id,
        max_new_tokens=max_new_tokens,
        use_static_kv_cache=use_static_kv_cache,
    )
    pattern = re.compile(r"\[(SPEAKER\d+)\]")

    # --transcript is either a path to a text file or the literal text itself.
    if os.path.exists(transcript):
        logger.info(f"Loading transcript from {transcript}")
        with open(transcript, "r", encoding="utf-8") as f:
            transcript = f.read().strip()

    # The scene prompt is only used when it names an existing file.
    if scene_prompt is not None and scene_prompt != "empty" and os.path.exists(scene_prompt):
        with open(scene_prompt, "r", encoding="utf-8") as f:
            scene_prompt = f.read().strip()
    else:
        scene_prompt = None

    # Order the tags by speaker number; a plain string sort would put SPEAKER10 before SPEAKER2.
    speaker_tags = sorted(set(pattern.findall(transcript)), key=lambda tag: int(tag[len("SPEAKER") :]))

    transcript = normalize_chinese_punctuation(transcript)

    # Further ad-hoc text normalization: drop parentheses and spell out temperature units.
    transcript = transcript.replace("(", " ")
    transcript = transcript.replace(")", " ")
    transcript = transcript.replace("°F", " degrees Fahrenheit")
    transcript = transcript.replace("°C", " degrees Celsius")

    # Map user-facing sound-event tags to the markup used in the model prompt.
    # NOTE(review): "[humming start]" maps to <SE> while every other "... start" tag maps
    # to <SE_s>; left unchanged here -- confirm against the model's expected format.
    for tag, replacement in [
        ("[laugh]", "<SE>[Laughter]</SE>"),
        ("[humming start]", "<SE>[Humming]</SE>"),
        ("[humming end]", "<SE_e>[Humming]</SE_e>"),
        ("[music start]", "<SE_s>[Music]</SE_s>"),
        ("[music end]", "<SE_e>[Music]</SE_e>"),
        ("[music]", "<SE>[Music]</SE>"),
        ("[sing start]", "<SE_s>[Singing]</SE_s>"),
        ("[sing end]", "<SE_e>[Singing]</SE_e>"),
        ("[applause]", "<SE>[Applause]</SE>"),
        ("[cheering]", "<SE>[Cheering]</SE>"),
        ("[cough]", "<SE>[Cough]</SE>"),
    ]:
        transcript = transcript.replace(tag, replacement)
    # Collapse runs of whitespace inside each line and drop blank lines.
    lines = transcript.split("\n")
    transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
    transcript = transcript.strip()

    # Make sure the transcript ends with closing punctuation or a closing sound-event tag.
    if not transcript.endswith((".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>")):
        transcript += "."

    messages, audio_ids = prepare_generation_context(
        scene_prompt=scene_prompt,
        ref_audio=ref_audio,
        ref_audio_in_system_message=ref_audio_in_system_message,
        audio_tokenizer=audio_tokenizer,
        speaker_tags=speaker_tags,
    )
    chunked_text = prepare_chunk_text(
        transcript,
        chunk_method=chunk_method,
        chunk_max_word_num=chunk_max_word_num,
        chunk_max_num_turns=chunk_max_num_turns,
    )

    logger.info("Chunks used for generation:")
    for idx, chunk_text in enumerate(chunked_text):
        logger.info(f"Chunk {idx}:")
        logger.info(chunk_text)
        logger.info("-----")

    concat_wv, sr, text_output = model_client.generate(
        messages=messages,
        audio_ids=audio_ids,
        chunked_text=chunked_text,
        generation_chunk_buffer_size=generation_chunk_buffer_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        ras_win_len=ras_win_len,
        ras_win_max_num_repeat=ras_win_max_num_repeat,
        seed=seed,
    )

    sf.write(out_path, concat_wv, sr)
    logger.info(f"Wav file is saved to '{out_path}' with sample rate {sr}")
|
|
|
|
|
|
|
|
# Run the click command only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|