File size: 1,953 Bytes
fa4523a 180d507 fa4523a 7555d4c fa4523a 42c44d1 33b11ab fa4523a 180d507 3408f20 9edeb3b fa4523a 33b11ab 180d507 81df6e3 180d507 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
from typing import Dict, List, Any
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
import array
import base64
import io
import wave
import numpy as np
class EndpointHandler:
    """Custom inference handler that renders text prompts to audio with MusicGen.

    Each call returns base64-encoded, mono, 16-bit PCM WAV audio.
    """

    # Generation settings used when a request supplies no ``parameters``.
    # Request ``parameters`` are merged on top of these, so callers can
    # override any of them (or pass other ``generate`` kwargs) per request.
    DEFAULT_GENERATION_KWARGS: Dict[str, Any] = {
        "do_sample": True,
        "guidance_scale": 3,
        "max_new_tokens": 400,
    }

    def __init__(self, path=""):
        """Load the processor and the MusicGen model from ``path``.

        The model weights are loaded in float16 and moved to the "cuda"
        device, so a CUDA GPU is required.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        ).to("cuda")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for the prompt(s) in ``data``.

        Args:
            data: Request payload. ``data["inputs"]`` is a text prompt (or a
                list of prompts). The optional ``data["parameters"]`` dict is
                forwarded to ``model.generate`` and overrides
                ``DEFAULT_GENERATION_KWARGS``. Both keys are popped, so the
                caller's dict is mutated.

        Returns:
            One ``{"sampling_rate": ..., "audio": str}`` dict per generated
            sample, where ``audio`` is a base64-encoded mono 16-bit PCM WAV
            and ``sampling_rate`` comes from the model's audio-encoder config.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None) or {}
        # A single prompt becomes a batch of one; a list is already a batch.
        prompts = inputs if isinstance(inputs, list) else [inputs]

        model_inputs = self.processor(
            text=prompts,
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        # Request parameters win over the defaults; previously they were ignored.
        generation_kwargs = {**self.DEFAULT_GENERATION_KWARGS, **parameters}
        with torch.autocast("cuda"):
            audio_values = self.model.generate(**model_inputs, **generation_kwargs)

        sampling_rate = self.model.config.audio_encoder.sampling_rate
        # Keep channel 0 of every batch item (the original used [0][0] for the
        # first item). Upcast to float32 because the model runs in float16.
        waveforms = audio_values[:, 0].float().cpu().numpy()
        return [
            {
                "sampling_rate": sampling_rate,
                "audio": self._to_wav_base64(waveform, sampling_rate),
            }
            for waveform in waveforms
        ]

    @staticmethod
    def _to_wav_base64(samples, sampling_rate: int, scale: float = 32000.0) -> str:
        """Encode a mono float waveform as a base64 16-bit PCM WAV string.

        Samples are clipped to [-1, 1], multiplied by ``scale`` and truncated
        toward zero to int16, which is the same conversion the handler has
        always applied. The default 32000 stays below the int16 maximum
        of 32767. NaN samples are written as silence instead of raising.
        """
        pcm = np.nan_to_num(np.asarray(samples, dtype=np.float64), nan=0.0)
        # Explicit little-endian int16: the WAV format requires little-endian
        # samples regardless of the host byte order.
        pcm = (np.clip(pcm, -1.0, 1.0) * scale).astype("<i2")

        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)  # 2 bytes for 16-bit PCM
            wav_file.setframerate(sampling_rate)
            wav_file.writeframes(pcm.tobytes())
        return base64.b64encode(buffer.getvalue()).decode("utf-8")