import io
import json
import os
import queue
import re
import time
import traceback
import wave
from argparse import ArgumentParser
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Any

import librosa
import numpy as np
import ormsgpack
import pyrootutils
import soundfile as sf
import torch
import torchaudio
from baize.datastructures import ContentType
from kui.asgi import (
    Body,
    FactoryClass,
    HTTPException,
    HttpRequest,
    HttpView,
    JSONResponse,
    Kui,
    OpenAPI,
    StreamResponse,
    request,
)
from kui.asgi.routing import MultimethodRoutes
from loguru import logger

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

import struct
from threading import Lock

import httpx
from cachetools import LRUCache, cached
from funasr import AutoModel
from silero_vad import get_speech_timestamps, load_silero_vad

from fish_speech.models.text2semantic.llama import BaseModelArgs

# from fish_speech.models.vqgan.lit_module import VQGAN
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText

# from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
from fish_speech.utils import autocast_exclude_mps, set_seed
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
from tools.llama.generate import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
    launch_thread_safe_queue,
    launch_thread_safe_queue_agent,
)
from tools.schema import (
    GLOBAL_NUM_SAMPLES,
    ASRPackRequest,
    ServeASRRequest,
    ServeASRResponse,
    ServeASRSegment,
    ServeAudioPart,
    ServeForwardMessage,
    ServeMessage,
    ServeRequest,
    ServeResponse,
    ServeStreamDelta,
    ServeStreamResponse,
    ServeTextPart,
    ServeTimedASRResponse,
    ServeTTSRequest,
    ServeVQGANDecodeRequest,
    ServeVQGANDecodeResponse,
    ServeVQGANEncodeRequest,
    ServeVQGANEncodeResponse,
    ServeVQPart,
)
from tools.vqgan.inference import load_model as load_decoder_model

global_lock = Lock()

# Whether to disable keepalive (which is helpful if the server is in the same cluster)
DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
async_client = httpx.AsyncClient(
    timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
)

backends = torchaudio.list_audio_backends()
if "ffmpeg" in backends:
    backend = "ffmpeg"
else:
    backend = "soundfile"

def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes

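
# Usage sketch (illustrative only): when streaming, the inference() generator below
# first yields this header and then raw 16-bit PCM frames, so a client can write the
# concatenation straight to a .wav file:
#
#   header = wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1)
#   with open("out.wav", "wb") as f:
#       f.write(header)
#       for pcm_chunk in int16_pcm_chunks:  # hypothetical client-side iterable
#           f.write(pcm_chunk)
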
# Define utils for web server
async def http_exception_handler(exc: HTTPException):
    return JSONResponse(
        dict(
            statusCode=exc.status_code,
            message=exc.content,
            error=HTTPStatus(exc.status_code).phrase,
        ),
        exc.status_code,
        exc.headers,
    )


async def other_exception_handler(exc: "Exception"):
    traceback.print_exc()

    status = HTTPStatus.INTERNAL_SERVER_ERROR
    return JSONResponse(
        dict(statusCode=status, message=str(exc), error=status.phrase),
        status,
    )

def load_audio(reference_audio, sr):
    if len(reference_audio) > 255 or not Path(reference_audio).exists():
        audio_data = reference_audio
        reference_audio = io.BytesIO(audio_data)

    waveform, original_sr = torchaudio.load(reference_audio, backend=backend)

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if original_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
        waveform = resampler(waveform)

    audio = waveform.squeeze().numpy()
    return audio

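
# Note: load_audio() accepts either a filesystem path or raw encoded audio bytes.
# Anything longer than 255 characters, or that is not an existing path, is treated
# as in-memory bytes and wrapped in a BytesIO. A sketch (file names illustrative):
#
#   audio = load_audio("references/foo.wav", 44100)          # from disk
#   audio = load_audio(open("foo.wav", "rb").read(), 44100)   # from bytes
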
def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
    if enable_reference_audio and reference_audio is not None:
        # Load audios, and prepare basic info here
        reference_audio_content = load_audio(
            reference_audio, decoder_model.spec_transform.sample_rate
        )

        audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
            None, None, :
        ]
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=decoder_model.device, dtype=torch.long
        )
        logger.info(
            f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        if isinstance(decoder_model, FireflyArchitecture):
            prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]

        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
    else:
        prompt_tokens = None
        logger.info("No reference audio provided")

    return prompt_tokens

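
# The returned prompt tokens condition the text2semantic model on the reference
# speaker. A minimal sketch, assuming decoder_model is the FireflyArchitecture
# loaded in initialize_app and the file name is illustrative:
#
#   tokens = encode_reference(
#       decoder_model=decoder_model,
#       reference_audio=audio_to_bytes("speaker.wav"),
#       enable_reference_audio=True,
#   )  # -> 2-D tensor of VQ codes (roughly codebooks x frames), or None
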
def decode_vq_tokens(
    *,
    decoder_model,
    codes,
):
    feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
    logger.info(f"VQ features: {codes.shape}")

    if isinstance(decoder_model, FireflyArchitecture):
        # VQGAN Inference
        return decoder_model.decode(
            indices=codes[None],
            feature_lengths=feature_lengths,
        )[0].squeeze()

    raise ValueError(f"Unknown model type: {type(decoder_model)}")

routes = MultimethodRoutes(base_class=HttpView)


def get_content_type(audio_format):
    if audio_format == "wav":
        return "audio/wav"
    elif audio_format == "flac":
        return "audio/flac"
    elif audio_format == "mp3":
        return "audio/mpeg"
    else:
        return "application/octet-stream"

def batch_encode(model, audios: list[bytes | torch.Tensor]):
    audios = [
        (
            torch.from_numpy(
                librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
            )[None]
            if isinstance(audio, bytes)
            else audio
        )
        for audio in audios
    ]

    # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
    #     raise ValueError("Single audio length is too long (>120s)")

    max_length = max(audio.shape[-1] for audio in audios)
    print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")

    lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
    max_length = lengths.max().item()
    padded = torch.stack(
        [
            torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
            for audio in audios
        ]
    ).to(model.device)

    features, feature_lengths = model.encode(padded, audio_lengths=lengths)
    features, feature_lengths = features.cpu(), feature_lengths.cpu()

    return [feature[..., :length] for feature, length in zip(features, feature_lengths)]

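
# batch_encode() pads every clip to the longest one, encodes the whole batch in a
# single model.encode() call, then trims each result back to its reported feature
# length, so callers get one variable-length code tensor per input clip. Sketch:
#
#   codes = batch_encode(decoder_model, [wav_bytes_a, wav_bytes_b])  # illustrative
#   # len(codes) == 2; codes[i] keeps only the frames valid for clip i
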
def cached_vqgan_batch_encode(model, audios: list[bytes]):
    return batch_encode(model, audios)


def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
    start_time = time.time()
    tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
    logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")

    return ormsgpack.packb(
        ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )

def vqgan_decode(model, features):
    lengths = torch.tensor(
        [feature.shape[-1] for feature in features], device=model.device
    )
    max_length = lengths.max().item()
    padded = torch.stack(
        [
            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
            for feature in features
        ]
    ).to(model.device)

    # If the batch is too large, decode in micro-batches of 8
    audios, audio_lengths = [], []
    for i in range(0, padded.shape[0], 8):
        audio, audio_length = model.decode(
            padded[i : i + 8], feature_lengths=lengths[i : i + 8]
        )
        audios.append(audio)
        audio_lengths.append(audio_length)
    audios = torch.cat(audios, dim=0)
    audio_lengths = torch.cat(audio_lengths, dim=0)
    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()

    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]

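
# vqgan_decode() mirrors batch_encode(): features are padded to a common length,
# decoded in micro-batches of 8 to bound peak GPU memory, and each waveform is
# trimmed back to its reported audio length before being returned as a numpy array.
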
def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
    start_time = time.time()
    audios = vqgan_decode(decoder_model, tokens)
    logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
    audios = [audio.astype(np.float16).tobytes() for audio in audios]

    return ormsgpack.packb(
        ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
    )

def batch_asr(model, audios, sr, language="auto"):
    resampled_audios = []
    for audio in audios:
        audio = torchaudio.functional.resample(audio, sr, 16000)
        assert audio.ndim == 1
        resampled_audios.append(audio)

    with global_lock:
        res = model.generate(
            input=resampled_audios,
            batch_size=len(resampled_audios),
            language=language,
            use_itn=True,
        )

    results = []
    for r, audio in zip(res, audios):
        text = r["text"]
        text = re.sub(r"<\|.*?\|>", "", text)
        duration = len(audio) / sr * 1000
        huge_gap = False

        if "timestamp" in r and len(r["timestamp"]) > 2:
            for timestamp_a, timestamp_b in zip(
                r["timestamp"][:-1], r["timestamp"][1:]
            ):
                # If there is a gap of more than 5 seconds, we consider it as a huge gap
                if timestamp_b[0] - timestamp_a[1] > 5000:
                    huge_gap = True
                    break

            # Doesn't make sense to have a huge gap at the end
            if duration - r["timestamp"][-1][1] > 3000:
                huge_gap = True

        results.append(
            {
                "text": text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results

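
# The "huge_gap" flag is a heuristic quality check on the input audio: it is set when
# the ASR model reports more than 5 s of silence between consecutive word timestamps,
# or more than 3 s of trailing silence (all timestamps are in milliseconds).
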
def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
    start_time = time.time()
    audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
    audios = [torch.from_numpy(audio).float() for audio in audios]

    if any(audio.shape[-1] >= 30 * payload.sample_rate for audio in audios):
        raise HTTPException(
            HTTPStatus.BAD_REQUEST, content="Audio length is too long"
        )

    transcriptions = batch_asr(
        asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
    )
    logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")

    return ormsgpack.packb(
        ServeASRResponse(transcriptions=transcriptions),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )

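
# Wire format note: payload.audios is expected to carry raw float16 PCM bytes
# (see np.frombuffer(..., dtype=np.float16) above), the same encoding that
# api_vqgan_decode returns. A client-side sketch, using the httpx/ormsgpack stack
# already imported in this module (route registration not shown in this file):
#
#   body = ormsgpack.packb(
#       {"audios": [pcm_f16_bytes], "sample_rate": 44100, "language": "auto"}
#   )
#   # POST body with Content-Type: application/msgpack to the ASR route
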
from fish_speech.conversation import Conversation, Message


def execute_request(
    input_queue: queue.Queue,
    tokenizer: FishTokenizer,
    config: BaseModelArgs,
    request: ServeRequest,
    device: str = "cuda:0",
):
    im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
    messages = []
    for message in request.messages:
        messages.append(message.to_conversation_message())

    assert len(messages) >= 1, "At least one message is required"
    # assert messages[-1].role == "user", "The last message must be from the user"

    if messages[-1].role == "user":
        messages.append(
            Message(role="assistant", parts=[], add_im_end=False, modality="voice")
        )
    elif messages[-1].role == "raw":
        messages[-1].add_im_start = False
        messages[-1].add_im_end = False
        messages[-1].modality = "voice"
    else:
        assert (
            messages[-1].role == "assistant"
        ), "The last message must be from the assistant"
        messages[-1].add_im_end = False

    conv = Conversation(messages=messages)

    # conv.visualize(tokenizer)
    prompt = conv.encode_for_inference(
        tokenizer=tokenizer, num_codebooks=config.num_codebooks
    ).to(device)

    if request.streaming:
        for i in range(request.num_samples):
            yield ServeStreamResponse(
                sample_id=i,
                delta=ServeStreamDelta(
                    role="assistant",
                ),
            )

    req = {
        "prompt": prompt,
        "max_new_tokens": request.max_new_tokens,
        "im_end_id": im_end_id,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "num_samples": request.num_samples,
        "early_stop_threshold": request.early_stop_threshold,
    }

    start = time.time()
    response_queue = queue.Queue()
    input_queue.put(GenerateRequest(req, response_queue))

    # Decoding
    decode_buffer = [[] for _ in range(request.num_samples)]
    parts = [[] for _ in range(request.num_samples)]

    def send_reset_buffer(sample_id):
        nonlocal decode_buffer
        if len(decode_buffer[sample_id]) == 0:
            return

        decoded = tokenizer.decode(decode_buffer[sample_id])
        part = ServeTextPart(text=decoded)

        if request.streaming:
            yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
        else:
            parts[sample_id].append(part)

        decode_buffer[sample_id] = []

    # Decode process
    finished = [False for _ in range(request.num_samples)]
    stats = {}
    idx = 0
    while True:
        response = response_queue.get()

        if response in ["stop", "error"]:
            break

        for sample_id, tokens in enumerate(response):
            if finished[sample_id]:
                continue

            if tokens[0] == im_end_id:
                finished[sample_id] = True
                if request.streaming:
                    yield from send_reset_buffer(sample_id)
                    yield ServeStreamResponse(
                        sample_id=sample_id,
                        finish_reason="stop",
                        stats=stats,
                    )
                continue

            is_semantic = (
                tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
            )
            if is_semantic and request.streaming:
                yield from send_reset_buffer(sample_id)
                # Streaming vq
                _tokens = tokens[1:].clone()

                if config.share_codebook_embeddings is False:
                    for i in range(len(_tokens)):
                        _tokens[i] -= config.codebook_size * i

                yield ServeStreamResponse(
                    sample_id=sample_id,
                    delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
                )
                continue

            if is_semantic:
                yield from send_reset_buffer(sample_id)
                # Non-streaming vq
                if len(parts[sample_id]) == 0 or not isinstance(
                    parts[sample_id][-1], ServeVQPart
                ):
                    _tokens = tokens[1:].clone()

                    if config.share_codebook_embeddings is False:
                        for i in range(len(_tokens)):
                            _tokens[i] -= config.codebook_size * i

                    parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
                else:
                    for codebook_id, value in enumerate(tokens[1:, :]):
                        val = value.item()
                        if config.share_codebook_embeddings is False:
                            val -= config.codebook_size * codebook_id

                        parts[sample_id][-1].codes[codebook_id].append(val)
                continue

            if not is_semantic:
                # Streaming text decode is not supported yet
                decode_buffer[sample_id].append(tokens[0, 0])

        if idx == 0:
            stats["time_to_first_token"] = (time.time() - start) * 1000

        idx += 1

    for sample_id in range(request.num_samples):
        yield from send_reset_buffer(sample_id)

    stats["total_time"] = (time.time() - start) * 1000
    stats["total_tokens"] = idx

    if request.streaming:
        for sample_id in range(request.num_samples):
            if finished[sample_id]:
                continue
            yield ServeStreamResponse(
                finish_reason=response, stats=stats, sample_id=sample_id
            )
        return

    yield ServeResponse(
        messages=[
            ServeMessage(role="assistant", parts=parts[i])
            for i in range(request.num_samples)
        ],
        finish_reason=response,
        stats=stats,
    )

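
# Token layout assumed above: each item pulled from response_queue is a batch of
# per-sample tensors whose first row (tokens[0]) is the text/semantic token id and
# whose remaining rows (tokens[1:]) are the per-codebook VQ codes. Semantic steps
# are emitted as ServeVQPart codes, everything else is buffered and decoded into
# ServeTextPart text, and the "stop"/"error" sentinel strings terminate the loop.
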
def api_invoke_chat(
    req: Annotated[ServeRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """

    # This makes torch compile happy
    assert (
        req.num_samples == GLOBAL_NUM_SAMPLES
    ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"

    content_type = request.headers.get("Content-Type", "application/json")
    json_mode = "application/json" in content_type

    async def wrapped_generator():
        generator = execute_request(llama_queue, tokenizer, config, req, args.device)

        for i in generator:
            if json_mode:
                body = i.model_dump_json().encode("utf-8")
                yield b"data: " + body + b"\n\n"
            else:
                body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
                yield struct.pack("I", len(body)) + body

    # Naive mode
    if req.streaming is False:
        result = next(execute_request(llama_queue, tokenizer, config, req, args.device))

        if json_mode:
            return JSONResponse(result.model_dump())
        else:
            return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)

    return StreamResponse(
        iterable=wrapped_generator(), content_type="text/event-stream"
    )

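
# Framing note: in msgpack mode each streamed chat event is prefixed with a length
# header produced by struct.pack("I", ...) (4 bytes on typical platforms). A
# client-side parsing sketch, where stream is a hypothetical file-like object over
# the response body:
#
#   def iter_frames(stream):
#       while header := stream.read(4):
#           (length,) = struct.unpack("I", header)
#           yield ormsgpack.unpackb(stream.read(length))
#
# In JSON mode the same events are sent as standard "data: ..." server-sent events.
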
def inference(req: ServeTTSRequest):
    idstr: str | None = req.reference_id
    if idstr is not None:
        ref_folder = Path("references") / idstr
        ref_folder.mkdir(parents=True, exist_ok=True)
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )

        prompt_tokens = [
            encode_reference(
                decoder_model=decoder_model,
                reference_audio=audio_to_bytes(str(ref_audio)),
                enable_reference_audio=True,
            )
            for ref_audio in ref_audios
        ]
        prompt_texts = [
            read_ref_text(str(ref_audio.with_suffix(".lab")))
            for ref_audio in ref_audios
        ]
    else:
        # Parse reference audio aka prompt
        refs = req.references

        prompt_tokens = [
            encode_reference(
                decoder_model=decoder_model,
                reference_audio=ref.audio,
                enable_reference_audio=True,
            )
            for ref in refs
        ]
        prompt_texts = [ref.text for ref in refs]

    if req.seed is not None:
        set_seed(req.seed)
        logger.warning(f"set seed: {req.seed}")

    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
        max_new_tokens=req.max_new_tokens,
        text=(
            req.text
            if not req.normalize
            else ChnNormedText(raw_text=req.text).normalize()
        ),
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=4096,
        prompt_tokens=prompt_tokens,
        prompt_text=prompt_texts,
    )

    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    if req.streaming:
        yield wav_chunk_header()

    segments = []

    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            raise result.response

        result: GenerateResponse = result.response
        if result.action == "next":
            break

        with autocast_exclude_mps(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()

        if req.streaming:
            yield (fake_audios * 32768).astype(np.int16).tobytes()
        else:
            segments.append(fake_audios)

    if req.streaming:
        return

    if len(segments) == 0:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR,
            content="No audio generated, please check the input text.",
        )

    fake_audios = np.concatenate(segments, axis=0)
    yield fake_audios

async def inference_async(req: ServeTTSRequest):
    for chunk in inference(req):
        yield chunk


async def buffer_to_async_generator(buffer):
    yield buffer

async def api_invoke_model(
    req: Annotated[ServeTTSRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """

    if args.max_text_length > 0 and len(req.text) > args.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {args.max_text_length}",
        )

    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    if req.streaming:
        return StreamResponse(
            iterable=inference_async(req),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
    else:
        fake_audios = next(inference(req))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            decoder_model.spec_transform.sample_rate,
            format=req.format,
        )

        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )

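
# Client sketch for the non-streaming TTS path, assuming this handler is registered
# on the routes object (the route path below is illustrative, not taken from this
# file), with the default --listen address:
#
#   payload = ormsgpack.packb(
#       {"text": "Hello world.", "format": "wav", "references": [], "reference_id": None}
#   )
#   resp = httpx.post("http://127.0.0.1:8080/v1/tts", content=payload,
#                     headers={"Content-Type": "application/msgpack"})
#   open("audio.wav", "wb").write(resp.content)
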
async def api_health():
    """
    Health check
    """
    return JSONResponse({"status": "ok"})

def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
    parser.add_argument("--load-asr-model", action="store_true")
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)
    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
    parser.add_argument("--workers", type=int, default=1)

    return parser.parse_args()

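
# Example launch (flags as defined above; checkpoints assumed to exist locally), e.g.:
#
#   python -m tools.api --listen 0.0.0.0:8080 \
#       --llama-checkpoint-path checkpoints/fish-speech-1.4 \
#       --decoder-checkpoint-path checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \
#       --compile
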
# Define Kui app
openapi = OpenAPI(
    {
        "title": "Fish Speech API",
        "version": "1.4.2",
    },
).routes

class MsgPackRequest(HttpRequest):
    async def data(
        self,
    ) -> Annotated[
        Any, ContentType("application/msgpack"), ContentType("application/json")
    ]:
        if self.content_type == "application/msgpack":
            return ormsgpack.unpackb(await self.body)

        elif self.content_type == "application/json":
            return await self.json

        raise HTTPException(
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            headers={"Accept": "application/msgpack, application/json"},
        )

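
# Content negotiation sketch: the same request body can be sent either way,
# e.g. (illustrative client code using the httpx/ormsgpack imports above):
#
#   httpx.post(url, content=ormsgpack.packb(data),
#              headers={"Content-Type": "application/msgpack"})
#   httpx.post(url, json=data)  # Content-Type: application/json
#
# Any other Content-Type is rejected with 415 UNSUPPORTED_MEDIA_TYPE.
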
app = Kui(
    routes=routes + openapi[1:],  # Remove the default route
    exception_handlers={
        HTTPException: http_exception_handler,
        Exception: other_exception_handler,
    },
    factory_class=FactoryClass(http=MsgPackRequest),
    cors_config={},
)

def load_asr_model(*, device="cuda", hub="ms"):
    return AutoModel(
        model="iic/SenseVoiceSmall",
        device=device,
        disable_pbar=True,
        hub=hub,
    )

# Each worker process created by Uvicorn has its own memory space,
# meaning that models and variables are not shared between processes.
# Therefore, any global variables (like `llama_queue` or `decoder_model`)
# will not be shared across workers.

# Multi-threading for deep learning can cause issues, such as inconsistent
# outputs if multiple threads access the same buffers simultaneously.
# Instead, it's better to use multiprocessing or independent models per thread.

def initialize_app(app: Kui):
    global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts

    prompt_tokens, prompt_texts = [], []

    args = parse_args()  # args same as ones in other processes
    args.precision = torch.half if args.half else torch.bfloat16

    if args.load_asr_model:
        logger.info("Loading ASR model...")
        asr_model = load_asr_model(device=args.device)

    logger.info("Loading Llama model...")

    if args.mode == "tts":
        llama_queue = launch_thread_safe_queue(
            checkpoint_path=args.llama_checkpoint_path,
            device=args.device,
            precision=args.precision,
            compile=args.compile,
        )
    else:
        llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
            checkpoint_path=args.llama_checkpoint_path,
            device=args.device,
            precision=args.precision,
            compile=args.compile,
        )

    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("VQ-GAN model loaded, warming up...")

    vad_model = load_silero_vad()

    logger.info("VAD model loaded, warming up...")

    if args.mode == "tts":
        # Dry run to ensure models work and avoid first-time latency
        list(
            inference(
                ServeTTSRequest(
                    text="Hello world.",
                    references=[],
                    reference_id=None,
                    max_new_tokens=0,
                    chunk_length=200,
                    top_p=0.7,
                    repetition_penalty=1.5,
                    temperature=0.7,
                    emotion=None,
                    format="wav",
                )
            )
        )

    logger.info(f"Warming up done, starting server at http://{args.listen}")

if __name__ == "__main__":
    import uvicorn

    args = parse_args()
    host, port = args.listen.split(":")

    uvicorn.run(
        "tools.api:app",
        host=host,
        port=int(port),
        workers=args.workers,
        log_level="info",
    )