from typing import Dict, List, Any
import json
import logging

import numpy as np
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

logger = logging.getLogger(__name__)

# Decoder tokens generated per second of audio. Used to turn a requested
# ``duration`` into ``max_new_tokens`` when the audio-encoder config does not
# expose ``frame_rate`` itself (50 matches the previous hard-coded factor).
_FALLBACK_FRAME_RATE = 50

# Token budget used when the request gives neither ``duration`` nor
# ``max_new_tokens``.
_DEFAULT_MAX_NEW_TOKENS = 256


class EndpointHandler:
    """Custom inference handler that generates audio with a MusicGen model.

    The processor and model are loaded once from ``path``. Each call turns a
    request payload into a single ``model.generate`` call and returns the
    first generated sample.
    """

    def __init__(self, path=""):
        """Load the processor and model from ``path`` onto the best device.

        Args:
            path: Local directory or hub id passed to ``from_pretrained``.
        """
        # Configure logging once per handler rather than on every request;
        # basicConfig is a no-op if the root logger already has handlers.
        logging.basicConfig(level=logging.DEBUG)

        self.processor = AutoProcessor.from_pretrained(path)
        # Prefer the GPU when one is visible, otherwise run on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = MusicgenForConditionalGeneration.from_pretrained(path)
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for a text prompt, optionally conditioned on audio.

        Args:
            data: Request payload. ``data["inputs"]`` holds the text prompt
                (if the key is missing, the rest of the payload is used as
                the input, as before). ``data["parameters"]`` is optional and
                may be missing or null. It may contain:

                * ``duration`` - desired length in seconds, converted to
                  ``max_new_tokens`` via the audio encoder's frame rate.
                * ``max_new_tokens`` - explicit token budget, used only when
                  ``duration`` is absent.
                * ``audio`` - prompt audio samples (array-like).
                * ``sampling_rate`` - sampling rate of ``audio``; defaults to
                  the model's audio-encoder sampling rate.
                * any other keys are forwarded to ``model.generate``.

                The caller's dicts are never modified.

        Returns:
            A one-element list holding ``{"generated_text": <numpy array of
            the first generated sample>, "sampling_rate": <sampling rate of
            the generated audio>}``.
        """
        logger.debug("Data: %s", data)

        # Shallow-copy so popping keys never mutates the caller's payload.
        payload = dict(data)
        inputs = payload.pop("inputs", payload)
        logger.debug("Inputs: %s", inputs)
        # A missing or null "parameters" entry simply means "no options".
        parameters = dict(payload.pop("parameters", None) or {})
        logger.debug("Parameters: %s", parameters)

        duration = parameters.pop("duration", None)
        logger.debug("Duration: %s", duration)
        audio = parameters.pop("audio", None)
        logger.debug("Audio: %s", audio)
        sampling_rate = parameters.pop("sampling_rate", None)
        logger.debug("Sampling Rate: %s", sampling_rate)

        audio_config = self.model.config.audio_encoder
        if not sampling_rate:
            sampling_rate = audio_config.sampling_rate

        if audio is not None:
            audio_array = np.asarray(audio)
            # NOTE(review): only the first third of the prompt audio is kept.
            # This looks carried over from an example snippet; confirm it is
            # intended for real requests.
            audio = audio_array[: len(audio_array) // 3]

        max_new_tokens = self._resolve_max_new_tokens(duration, parameters, audio_config)

        model_inputs = self.processor(
            text=[inputs],
            padding=True,
            return_tensors="pt",
            audio=audio,
            sampling_rate=sampling_rate,
        ).to(self.device)

        # Every remaining parameter is forwarded verbatim to generate().
        outputs = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens, **parameters)

        prediction = outputs[0].cpu().numpy()
        # The generated waveform is produced at the audio encoder's rate,
        # regardless of the rate of any prompt audio.
        return [{"generated_text": prediction, "sampling_rate": audio_config.sampling_rate}]

    @staticmethod
    def _resolve_max_new_tokens(duration, parameters, audio_config):
        """Return the token budget for ``generate`` and drop it from ``parameters``.

        ``duration`` (seconds) wins; otherwise an explicit
        ``parameters["max_new_tokens"]`` is used, then the module default.
        The key is always removed from ``parameters`` so it is never passed
        to ``generate`` twice.
        """
        requested = parameters.pop("max_new_tokens", None)
        if duration is not None:
            frame_rate = getattr(audio_config, "frame_rate", None) or _FALLBACK_FRAME_RATE
            # float() also accepts numeric strings from JSON payloads; a raw
            # string multiplied by an int would be repeated, not scaled.
            return int(float(duration) * frame_rate)
        if requested is not None:
            return int(requested)
        return _DEFAULT_MAX_NEW_TOKENS