import contextlib
import logging
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference handler that generates audio from a text prompt with MusicGen.

    The processor and model are loaded once from ``path`` in ``__init__``;
    each request is then served by calling the instance with a payload dict.
    """

    def __init__(self, path: str = "", device: Optional[str] = None) -> None:
        """Load the processor and model from ``path``.

        Args:
            path: Directory or model id passed to ``from_pretrained``.
            device: Torch device to run on. When omitted, ``"cuda"`` is used if
                a GPU is available, otherwise ``"cpu"`` (the original code
                assumed CUDA unconditionally and crashed on CPU-only hosts).

        Raises:
            Exception: Any error raised while loading is logged with its
                traceback and re-raised.
        """
        logging.basicConfig(level=logging.INFO)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = str(device)
        # Half precision is what the original used on GPU; many float16 ops
        # are unsupported or very slow on CPU, so fall back to float32 there.
        dtype = torch.float16 if self._on_cuda() else torch.float32
        try:
            # load model and processor from path
            self.processor = AutoProcessor.from_pretrained(path)
            self.model = MusicgenForConditionalGeneration.from_pretrained(
                path, torch_dtype=dtype
            ).to(self.device)
        except Exception as e:
            logger.exception("Error loading model or processor: %s", e)
            raise

    def _on_cuda(self) -> bool:
        """Return True when the configured device is a CUDA device."""
        return self.device.startswith("cuda")

    def __call__(
        self, data: Dict[str, Any]
    ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """Generate audio for the text prompt in ``data``.

        Args:
            data: Request payload. ``data["inputs"]`` is the text prompt;
                optional ``data["parameters"]`` holds keyword arguments that
                are forwarded unchanged to ``model.generate``.

        Returns:
            On success, a one-element list ``[{"generated_audio": values}]``,
            where ``values`` is the nested-list form of the first item of the
            ``generate`` output. On any failure, a plain dict
            ``{"error": message}`` (not a list); this asymmetry is kept for
            compatibility with existing callers.
        """
        try:
            # validate and process input
            inputs = data.get("inputs")
            if not inputs:
                raise ValueError("No inputs provided")
            # ``"parameters": null`` in the payload would make ``**parameters``
            # raise a TypeError; treat it the same as an omitted key.
            parameters = data.get("parameters") or {}

            # preprocess
            processed_inputs = self.processor(
                text=[inputs],
                padding=True,
                return_tensors="pt",
            ).to(self.device)

            # generate outputs; autocast only applies on GPU (on CPU a "cuda"
            # autocast is disabled with a warning, so skip it explicitly).
            amp = torch.autocast("cuda") if self._on_cuda() else contextlib.nullcontext()
            with amp:
                outputs = self.model.generate(**processed_inputs, **parameters)

            # postprocess the prediction
            prediction = outputs[0].cpu().numpy().tolist()
            return [{"generated_audio": prediction}]
        except Exception as e:
            logger.exception("Error during model inference: %s", e)
            return {"error": str(e)}


# Example usage:
# handler = EndpointHandler(path="your_model_path")
# result = handler({"inputs": "your input text"})