working handler
- __pycache__/handler.cpython-311.pyc +0 -0
- __pycache__/handler.cpython-312.pyc +0 -0
- call_endpoint.py +22 -0
- handler.py +45 -0
- test.py +14 -0
__pycache__/handler.cpython-311.pyc
ADDED
Binary file (3.29 kB)

__pycache__/handler.cpython-312.pyc
ADDED
Binary file (2.54 kB)
call_endpoint.py
ADDED
@@ -0,0 +1,22 @@
import requests

# Define the URL of the FastAPI server
url = "http://localhost:8000/generate_audio"

# Define the text for which you want to generate audio
text = "lo-fi music with a soothing melody"

# Define the headers for the request
headers = {"Content-Type": "application/json"}

# Make a POST request to the endpoint with the text data in the request body and the specified header
response = requests.post(url, json={"text": text}, headers=headers)

# Check if the request was successful
if response.status_code == 200:
    # Save the audio file
    with open("generated_audio.wav", "wb") as f:
        f.write(response.content)
    print("Audio file saved successfully.")
else:
    print("Error:", response.text)
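
call_endpoint.py assumes a FastAPI server listening on localhost:8000, which is not included in this commit. A minimal sketch of what such a server could look like, wrapping the EndpointHandler below; the /generate_audio route matches the script above, but the module layout and request model are assumptions:

# server.py (hypothetical, not part of this commit): minimal FastAPI wrapper
# around handler.py, exposing the /generate_audio route call_endpoint.py expects.
from fastapi import FastAPI
from fastapi.responses import Response
from pydantic import BaseModel

from handler import EndpointHandler

app = FastAPI()
handler = EndpointHandler(path=".")

class GenerateRequest(BaseModel):
    text: str

@app.post("/generate_audio")
def generate_audio(req: GenerateRequest) -> Response:
    # the handler returns raw WAV bytes; send them back as audio/wav
    wav_bytes = handler({"input": req.text})
    return Response(content=wav_bytes, media_type="audio/wav")

# run with: uvicorn server:app --port 8000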
handler.py
ADDED
@@ -0,0 +1,45 @@
from typing import Any
from transformers import pipeline
import soundfile as sf
import torch
import logging
import io

logger = logging.getLogger(__name__)

class EndpointHandler:
    def __init__(self, path=""):
        # create the text-to-audio inference pipeline
        self.pipeline = pipeline("text-to-audio", "facebook/musicgen-stereo-large", device="mps", torch_dtype=torch.float16)

    def generate_audio(self, text: str):
        # generate audio with MusicGen and return it with its sampling rate
        logger.info("Generating audio for text: %s", text)
        try:
            music = self.pipeline(text, forward_params={"max_new_tokens": 256})
            # soundfile expects (frames, channels), so transpose the pipeline output
            return music["audio"][0].T, music["sampling_rate"]
        except Exception:
            logger.error("Error generating audio for text: %s", text, exc_info=True)
            raise

    def __call__(self, data: Any) -> bytes:
        """
        Args:
            data (:obj:`dict`):
                includes the input text under the "input" key.
        Return:
            The generated audio, encoded as WAV bytes.
        """
        text = data.pop("input", data)

        audio_data, sampling_rate = self.generate_audio(text)

        with io.BytesIO() as buffer:
            sf.write(buffer, audio_data, sampling_rate, format="WAV")
            buffer.seek(0)
            return buffer.getvalue()
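
The pipeline above is pinned to device="mps", so the handler only runs on Apple Silicon as written. If portability matters, a hedged sketch of a device fallback (the fallback order is an assumption, not part of this commit):

# hypothetical tweak, not in this commit: pick an available device instead of
# hard-coding "mps", falling back to CUDA and then CPU
import torch

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# half precision only pays off on an accelerator; keep float32 on CPU
dtype = torch.float16 if device != "cpu" else torch.float32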
test.py
ADDED
@@ -0,0 +1,14 @@
from handler import EndpointHandler

# init handler
my_handler = EndpointHandler(path=".")

# prepare sample payload
payload = {"input": "Lowfi hiphop with deep bass"}

# test the handler
pred = my_handler(payload)

with open("generated_audio.wav", "wb") as f:
    f.write(pred)
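
To sanity-check the result after running test.py, the file can be read back with the same soundfile dependency handler.py already uses; this check is a suggestion, not part of the commit:

# optional check: read the generated file back and inspect its shape and rate
import soundfile as sf

data, sr = sf.read("generated_audio.wav")
channels = data.shape[1] if data.ndim > 1 else 1
print(f"{data.shape[0]} frames, {channels} channel(s) at {sr} Hz")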