bark / handler.py
H-H-E's picture
Update handler.py
62f977c
raw
history blame
No virus
669 Bytes
from typing import Any, Dict, List, Union

from transformers import pipeline
class EndpointHandler:
    """Custom inference-endpoint handler that synthesizes speech with Bark.

    Wraps a ``transformers`` "text-to-speech" pipeline built from the
    ``suno/bark`` checkpoint and runs it once per request.
    """

    def __init__(self, path=""):
        """Build the text-to-speech pipeline.

        Args:
            path: Accepted for compatibility with the endpoint handler
                interface but not used; the model is always loaded from the
                ``suno/bark`` hub id.
        """
        # NOTE(review): ``path`` is ignored, so weights come from the hub id
        # rather than the local repository path — confirm this is intended.
        self.model = pipeline("text-to-speech", "suno/bark")

    def __call__(
        self, data: Dict[str, Any]
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Generate audio for the text prompt contained in ``data``.

        Args:
            data: Request payload. The prompt is read from ``data["inputs"]``;
                if that key is absent, the whole payload is passed to the
                pipeline as the prompt. ``data`` itself is not modified.

        Returns:
            The pipeline's output, unchanged; it will be serialized and
            returned to the client. Per the ``transformers`` text-to-audio
            pipeline documentation this is a dict (keys ``"audio"`` and
            ``"sampling_rate"``) for a single string prompt, or a list of
            such dicts for a list of prompts.
        """
        # Read without popping so the caller's payload is left intact.
        text_prompt = data.get("inputs", data)
        # Forwarded to the model's generation step: do_sample=True requests
        # sampling rather than greedy decoding, so output may vary per call.
        speech = self.model(text_prompt, forward_params={"do_sample": True})
        return speech