musicgen-large-csi / handler.py
from typing import Any, Dict, List

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the text processor and the MusicGen model in half precision on the GPU.
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        ).to("cuda")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # "inputs" holds the text prompt; "parameters" holds optional generation kwargs.
        text_input = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        inputs = self.processor(
            text=[text_input],
            return_tensors="pt",
            padding=True,
        ).to("cuda")

        # Run generation under autocast, forwarding any user-supplied parameters.
        generate_kwargs = parameters if parameters is not None else {}
        with torch.autocast("cuda"):
            outputs = self.model.generate(**inputs, **generate_kwargs)

        # Return the generated waveform of the first (and only) batch item as a plain list.
        prediction = outputs[0].cpu().numpy().tolist()
        return [{"generated_audio": prediction}]
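

# --- Usage sketch, not part of the deployed handler ---
# A minimal local smoke test, assuming a CUDA-capable GPU and a local copy of
# the model weights; the checkpoint path, prompt, and generation parameters
# below are illustrative placeholders, not values taken from this repository.
if __name__ == "__main__":
    handler = EndpointHandler(path="path/to/local/checkpoint")  # hypothetical path
    response = handler(
        {
            "inputs": "an upbeat electronic track with a driving bassline",
            "parameters": {"max_new_tokens": 256},
        }
    )
    # The handler returns a list with a single dict holding the raw audio values
    # as a nested list (channels x samples).
    print(len(response[0]["generated_audio"]))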