| | |
| | import io, base64, wave |
| | import numpy as np |
| | import torch |
| | from transformers import AutoProcessor, CsmForConditionalGeneration |
| |
|
# Output sample rate (Hz) used when encoding the generated waveform to WAV.
SAMPLING_RATE = 24000
| |
|
class EndpointHandler:
    """Inference-endpoint handler that turns text into speech with a CSM model.

    Loads the processor and ``CsmForConditionalGeneration`` model from ``path``
    at construction time, and serves requests via ``__call__``, returning the
    generated audio as a base64-encoded 16-bit mono WAV.
    """

    def __init__(self, path=""):
        # Prefer GPU with fp16 weights to halve memory; fall back to CPU fp32.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = CsmForConditionalGeneration.from_pretrained(
            path, torch_dtype=dtype
        ).to(self.device)
        self.model.eval()

    def _wav_bytes(self, audio_f32, sr=SAMPLING_RATE):
        """Encode a float waveform in [-1, 1] as 16-bit mono PCM WAV bytes.

        Args:
            audio_f32: array-like float samples; a leading batch dim such as
                (1, N) is flattened. Values outside [-1, 1] are clipped.
            sr: sample rate written into the WAV header.

        Returns:
            The complete WAV file contents as ``bytes``.
        """
        # Promote to float32 before scaling so fp16 model output keeps full
        # precision, and ravel so a (1, N) batch becomes a valid mono stream.
        samples = np.asarray(audio_f32, dtype=np.float32).ravel()
        pcm = (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16)
        buf = io.BytesIO()
        with wave.open(buf, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)  # 2 bytes per sample -> int16 PCM
            wf.setframerate(sr)
            wf.writeframes(pcm.tobytes())
        return buf.getvalue()

    @staticmethod
    def _extract_audio(out):
        """Normalize ``model.generate(..., output_audio=True)`` output to numpy.

        ``generate`` may hand back a numpy array, a list of tensors/arrays,
        or a single tensor; return a plain ``np.ndarray`` in every case.
        """
        if isinstance(out, np.ndarray):
            return out
        if isinstance(out, list):
            first = out[0] if out else out
            if hasattr(first, "cpu"):
                # .detach() for consistency with the tensor branch below.
                return first.detach().cpu().numpy()
            return np.array(first)
        if hasattr(out, "cpu"):
            return out.detach().cpu().numpy()
        # Last resort: trust numpy to coerce whatever remains.
        return np.array(out)

    def __call__(self, data):
        """
        Accepts either:
          { "inputs": "Hello there", "parameters": {"speaker": 0, "max_length": 250} }
        or:
          {
            "conversation": [
                {"role":"0","content":[{"type":"text","text":"Say hi!"}]}
            ],
            "parameters": {"speaker": 0}
          }

        Returns:
            dict with ``audio_base64`` (base64 WAV), ``sampling_rate``, ``format``.
        """
        params = data.get("parameters") or {}
        speaker = int(params.get("speaker", 0))
        max_length = int(params.get("max_length", 250))

        if "conversation" in data:
            # Build copies instead of mutating the caller's payload in place;
            # every message that carries a role is forced to the requested
            # speaker id (the chat role encodes the speaker for CSM).
            conversation = [
                {**msg, "role": str(speaker)} if "role" in msg else dict(msg)
                for msg in data["conversation"]
            ]
        else:
            text = data.get("inputs") or ""
            conversation = [{"role": str(speaker),
                             "content": [{"type": "text", "text": text}]}]

        inputs = self.processor.apply_chat_template(
            conversation,
            tokenize=True,
            return_tensors="pt",
            return_dict=True
        ).to(self.device)

        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_length=max_length,
                output_audio=True,
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
            )

        audio = self._extract_audio(out)
        wav_b = self._wav_bytes(audio, SAMPLING_RATE)
        return {
            "audio_base64": base64.b64encode(wav_b).decode("ascii"),
            "sampling_rate": SAMPLING_RATE,
            "format": "wav"
        }
| |
|