onurio committed on
Commit
e657366
1 Parent(s): c53d19c

working handler

Browse files
__pycache__/handler.cpython-311.pyc ADDED
Binary file (3.29 kB). View file
 
__pycache__/handler.cpython-312.pyc ADDED
Binary file (2.54 kB). View file
 
call_endpoint.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ # Define the URL of the FastAPI server
4
+ url = "http://localhost:8000/generate_audio"
5
+
6
+ # Define the text for which you want to generate audio
7
+ text = "lo-fi music with a soothing melody"
8
+
9
+ # Define the headers for the request
10
+ headers = {"Content-Type": "application/json"}
11
+
12
+ # Make a POST request to the endpoint with the text data in the request body and the specified header
13
+ response = requests.post(url, json={"text": text}, headers=headers)
14
+
15
+ # Check if the request was successful
16
+ if response.status_code == 200:
17
+ # Save the audio file
18
+ with open("generated_audio.wav", "wb") as f:
19
+ f.write(response.content)
20
+ print("Audio file saved successfully.")
21
+ else:
22
+ print("Error:", response.text)
handler.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import pipeline
3
+ import soundfile as sf
4
+ import torch
5
+ import logging
6
+ import io
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class EndpointHandler():
11
+ def __init__(self, path=""):
12
+ # load the optimized model
13
+ # create inference pipeline
14
+ self.pipeline = pipeline("text-to-audio", "facebook/musicgen-stereo-large", device="mps", torch_dtype=torch.float16)
15
+
16
+ def generate_audio(self, text: str):
17
+ # Here you can implement your audio generation logic
18
+ # For demonstration purposes, let's use your existing code
19
+ logger.info("Generating audio for text: %s", text)
20
+ try:
21
+ music = self.pipeline(text, forward_params={"max_new_tokens": 256})
22
+ return music["audio"][0].T, music["sampling_rate"]
23
+ except Exception as e:
24
+ logger.error("Error generating audio for text: %s", text, exc_info=True)
25
+ raise e
26
+
27
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
28
+ """
29
+ Args:
30
+ data (:obj:):
31
+ includes the input data and the parameters for the inference.
32
+ Return:
33
+ A :obj:`list`:. The object returned should be a list of one list like [[{"label": 0.9939950108528137}]] containing :
34
+ - "label": A string representing what the label/class is. There can be multiple labels.
35
+ - "score": A score between 0 and 1 describing how confident the model is for this label/class.
36
+ """
37
+ input = data.pop("input", data)
38
+
39
+ audio_data, sampling_rate = self.generate_audio(input)
40
+
41
+
42
+ with io.BytesIO() as buffer:
43
+ sf.write(buffer, audio_data, sampling_rate, format="WAV")
44
+ buffer.seek(0)
45
+ return buffer.getvalue()
test.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+
3
+ # init handler
4
+ my_handler = EndpointHandler(path=".")
5
+
6
+ # prepare sample payload
7
+ payload = {"input": "Lowfi hiphop with deep bass"}
8
+
9
+ # test the handler
10
+ pred=my_handler(payload)
11
+
12
+
13
+ with open("generated_audio.wav", "wb") as f:
14
+ f.write(pred)