# Source snapshot: commit 816b46d, file size 1,322 bytes.
from typing import Dict, Any
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import scipy
import io
class EndpointHandler:
    """Generate WAV audio from a text prompt with a locally stored MusicGen model.

    Follows the ``__init__(path)`` / ``__call__(data)`` shape presumably
    expected by Hugging Face Inference Endpoints custom handlers — confirm
    against the deployment.
    """

    def __init__(self, path=""):
        """Load the processor and MusicGen model from the directory *path*.

        Args:
            path: Directory holding the model and processor files. Loading is
                restricted to local files (``local_files_only=True``).
        """
        # Explicitly load processor with local files
        self.processor = AutoProcessor.from_pretrained(
            path,
            local_files_only=True,
            trust_remote_code=True,
        )
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path,
            local_files_only=True,
            trust_remote_code=True,
        )

    def __call__(self, data: Dict[str, Any]) -> bytes:
        """Render the prompt in ``data`` to audio and return it as WAV bytes.

        Args:
            data: Request payload. ``data["inputs"]`` is the text prompt
                (missing or ``None`` becomes ``""``). ``data["parameters"]``
                is an optional dict whose ``"duration"`` entry is the target
                length in seconds (default 5; numeric strings are accepted).

        Returns:
            The contents of a WAV file written at the audio encoder's
            ``sampling_rate``.

        Raises:
            ValueError: If ``duration`` is not a positive, finite number.
        """
        # Explicit submodule import: a bare ``import scipy`` does not
        # guarantee that ``scipy.io.wavfile`` is loaded on every scipy version.
        from scipy.io import wavfile

        text = data.get("inputs") or ""
        # ``or {}`` also covers an explicit JSON ``null`` for "parameters".
        parameters = data.get("parameters") or {}
        raw_duration = parameters.get("duration", 5)
        try:
            # float() first: a string such as "5" would otherwise be
            # *repeated* by ``duration * 50`` instead of multiplied.
            duration = float(raw_duration)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"duration must be a number of seconds, got {raw_duration!r}"
            ) from err
        # The chained comparison also rejects NaN and infinity.
        if not 0 < duration < float("inf"):
            raise ValueError(
                f"duration must be a positive, finite number of seconds, got {raw_duration!r}"
            )

        inputs = self.processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        audio_config = self.model.config.audio_encoder
        # Tokens generated per second of audio. Falls back to the 50/s the
        # handler previously hard-coded when the config exposes no frame_rate.
        frame_rate = getattr(audio_config, "frame_rate", 50)
        max_new_tokens = max(1, int(duration * frame_rate))

        audio_values = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
        )

        # [0, 0] selects the first batch item and first channel; this assumes
        # generate returns (batch, channels, samples) — TODO confirm.
        # .cpu() keeps this working if the model was moved to a GPU.
        samples = audio_values[0, 0].detach().cpu().numpy()

        with io.BytesIO() as wav_io:
            wavfile.write(wav_io, rate=audio_config.sampling_rate, data=samples)
            # Read the buffer before the with-block closes it.
            return wav_io.getvalue()