# musicgen-large / handler.py
# slavocado's picture
# Override EndpointHandler class of Inference Endpoints
# 0a16498
# raw
# history blame contribute delete
# No virus
# 1.32 kB
from typing import Dict, List, Any
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
class EndpointHandler:
    """Custom Inference Endpoints handler that generates audio with MusicGen.

    The processor and model are loaded once at construction time; each call
    turns a text prompt into generated audio values.
    """

    def __init__(self, path=""):
        """Load the processor and model stored at ``path``.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                processor and the model.
        """
        # Use the GPU when one is present; otherwise fall back to the CPU
        # instead of failing at load time.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is kept for the GPU path only; the CPU fallback
        # loads the weights in full precision.
        dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=dtype
        ).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for a text prompt.

        Args:
            data: Request payload. ``data["inputs"]`` is the text prompt (the
                whole payload is used when the key is absent), and the
                optional ``data["parameters"]`` mapping is forwarded to
                ``model.generate`` as keyword arguments. Both keys are
                removed from ``data``.

        Returns:
            A one-element list holding ``{"generated_audio": values}``, where
            ``values`` is the first generated item converted to nested
            Python lists.
        """
        prompt = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        if parameters is None:
            # No generation parameters supplied: call generate with the
            # model's own defaults.
            parameters = {}

        model_inputs = self.processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        # Autocast is only active on the GPU path, matching the original
        # CUDA-only behaviour; it is a no-op on the CPU fallback.
        with torch.autocast(self.device, enabled=self.device == "cuda"):
            outputs = self.model.generate(**model_inputs, **parameters)

        # The prompt is wrapped in a one-element batch above, so only the
        # first generated item is returned.
        prediction = outputs[0].cpu().tolist()
        return [{"generated_audio": prediction}]