music-d3xy / handler.py
samsonsbike's picture
Update handler.py
99557a7 verified
raw
history blame contribute delete
No virus
1.66 kB
import logging
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration
class EndpointHandler:
    """Inference handler that turns text prompts into audio with a MusicGen model.

    The processor and model are loaded once from ``path`` in ``__init__``; each
    call tokenizes the prompt(s), runs ``model.generate`` and returns the raw
    generated audio as nested Python lists.
    """

    def __init__(self, path: str = "", device: Optional[str] = None):
        """Load the processor and model from ``path``.

        Args:
            path: Directory or model id passed to ``from_pretrained``.
            device: Torch device to run on (e.g. ``"cuda"``, ``"cuda:0"``,
                ``"cpu"``). Defaults to ``"cuda"`` when a GPU is available,
                otherwise ``"cpu"``.

        Raises:
            Exception: Whatever loading raises; it is logged and re-raised.
        """
        logging.basicConfig(level=logging.INFO)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # float16 halves GPU memory; float16 kernel coverage on CPU is limited,
        # so fall back to float32 there.
        dtype = torch.float16 if torch.device(device).type == "cuda" else torch.float32
        try:
            self.processor = AutoProcessor.from_pretrained(path)
            self.model = MusicgenForConditionalGeneration.from_pretrained(
                path, torch_dtype=dtype
            ).to(device)
        except Exception as e:
            # logging.exception keeps the traceback, which logging.error dropped.
            logging.exception("Error loading model or processor: %s", e)
            raise

    def __call__(
        self, data: Dict[str, Any]
    ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """Generate audio for the prompt(s) in ``data``.

        Args:
            data: Request payload. ``data["inputs"]`` is a prompt string or a
                non-empty list/tuple of prompt strings. The optional
                ``data["parameters"]`` is a dict of keyword arguments forwarded
                to ``model.generate``; ``None`` is treated as ``{}``.

        Returns:
            On success, a list of ``{"generated_audio": <nested lists>}`` dicts.
            For a single string prompt this holds only the first generated
            sample; for a list of prompts it holds one entry per generated sample.
            On any failure, ``{"error": <message>}``. The exception is logged,
            not raised.
        """
        try:
            inputs = data.get("inputs")
            if not inputs:
                raise ValueError("No inputs provided")
            batched = isinstance(inputs, (list, tuple))
            # A list is already a batch; wrapping it again would nest it.
            texts = list(inputs) if batched else [inputs]
            if not all(isinstance(t, str) for t in texts):
                raise ValueError("inputs must be a string or a list of strings")

            # An explicit JSON null must not break the ** unpacking below.
            parameters = data.get("parameters") or {}
            if not isinstance(parameters, dict):
                raise ValueError("parameters must be a dict")

            processed_inputs = self.processor(
                text=texts,
                padding=True,
                return_tensors="pt",
            ).to(self.device)

            # autocast wants the device *type* ("cuda"), not "cuda:0".
            device_type = torch.device(self.device).type
            with torch.autocast(device_type, enabled=device_type == "cuda"):
                outputs = self.model.generate(**processed_inputs, **parameters)

            audio = outputs.cpu().numpy()
            if not batched:
                # Preserve the original single-prompt contract: first sample only.
                audio = audio[:1]
            return [{"generated_audio": sample.tolist()} for sample in audio]
        except Exception as e:
            logging.exception("Error during model inference: %s", e)
            return {"error": str(e)}
# Example usage:
# handler = EndpointHandler(path="your_model_path")
# result = handler({"inputs": "your input text"})