|
|
import torch |
|
|
from fish_speech.models.fish_speech import FishSpeech |
|
|
from fish_speech.inference import infer |
|
|
import io |
|
|
import base64 |
|
|
import soundfile as sf |
|
|
|
|
|
|
|
|
# Load the pretrained checkpoint once, at import time; `predict` reuses this
# single module-level instance on every call instead of reloading it.
model = FishSpeech.from_pretrained("fishaudio/fish-speech-1.5")
|
|
|
|
|
def _split_singing_tag(text):
    """Detect and remove the ``[singing]`` tag from ``text``.

    Detection and removal use the same case-insensitive match, so
    ``"[SINGING] la la"`` is handled exactly like ``"[singing] la la"``.

    Returns:
        A ``(mode, cleaned_text)`` tuple. ``mode`` is ``"singing"`` when at
        least one tag was found, otherwise ``"speech"``. When a tag is
        removed, the leading/trailing whitespace it leaves behind is
        stripped; text without a tag is returned unchanged.
    """
    import re  # local import keeps this block self-contained

    cleaned, n_tags = re.subn(r"\[singing\]", "", text, flags=re.IGNORECASE)
    if n_tags:
        return "singing", cleaned.strip()
    return "speech", text


def _encode_wav_base64(audio, sample_rate):
    """Serialize ``audio`` to an in-memory WAV file and return it base64-encoded."""
    # .numpy() raises on tensors that still require grad, so detach first.
    samples = audio.detach().cpu().numpy()
    # NOTE(review): soundfile expects (frames,) or (frames, channels); confirm
    # the shape `infer` returns — a (channels, frames) tensor would need `.T`.
    with io.BytesIO() as buffer:
        sf.write(buffer, samples, sample_rate, format="WAV")
        wav_bytes = buffer.getvalue()
    return base64.b64encode(wav_bytes).decode("ascii")


def predict(inputs: dict, *, sample_rate: int = 24000):
    """Synthesize speech (or singing) for the text in ``inputs``.

    Args:
        inputs: Request payload. The text to synthesize is read from the
            ``"inputs"`` key and defaults to ``"Hello world"`` when absent.
            A ``[singing]`` tag anywhere in the text (any letter case)
            selects ``mode="singing"`` and is removed from the text;
            otherwise ``mode="speech"`` is used.
        sample_rate: Sample rate written into the WAV header. Defaults to
            24000, the value previously hard-coded here.
            NOTE(review): confirm this matches the rate the model produces.

    Returns:
        ``{"audio": str}``, where the value is the base64-encoded WAV file.

    Raises:
        TypeError: if ``inputs["inputs"]`` is not a string.
    """
    text = inputs.get("inputs", "Hello world")
    if not isinstance(text, str):
        raise TypeError(
            f"inputs['inputs'] must be a str, got {type(text).__name__}"
        )

    mode, text = _split_singing_tag(text)

    # Pure inference: disable autograd to save memory and compute.
    with torch.no_grad():
        audio = infer(model, text, mode=mode)

    return {"audio": _encode_wav_base64(audio, sample_rate)}
|
|
|
|
|
def query(payload):
    """Alias entry point: run ``predict`` on ``payload`` and return its result unchanged."""
    response = predict(payload)
    return response