pbotsaris commited on
Commit
ee4f8f7
1 Parent(s): acb1db0

updated handler to return audio base64 strings

Browse files
Files changed (3) hide show
  1. __pycache__/handler.cpython-310.pyc +0 -0
  2. handler.py +8 -3
  3. test.py +6 -11
__pycache__/handler.cpython-310.pyc CHANGED
Binary files a/__pycache__/handler.cpython-310.pyc and b/__pycache__/handler.cpython-310.pyc differ
 
handler.py CHANGED
@@ -3,6 +3,7 @@ from scipy.io import wavfile
3
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
4
  import torch
5
  import io
 
6
 
7
  def create_params(params, fr):
8
  # default
@@ -37,7 +38,7 @@ class EndpointHandler:
37
  self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
38
  self.model.to('cuda')
39
 
40
- def __call__(self, data: Dict[str, Any]) -> bytes:
41
  """
42
  Args:
43
  data (:dict:):
@@ -60,7 +61,7 @@ class EndpointHandler:
60
  with torch.cuda.amp.autocast():
61
  outputs = self.model.generate(**inputs, **params)
62
 
63
- pred = outputs[0].cpu().numpy().tolist()
64
  sr = 32000
65
 
66
  try:
@@ -76,7 +77,11 @@ class EndpointHandler:
76
  # Convert BytesIO to bytes
77
  wav_data = wav_buffer.getvalue()
78
 
79
- return wav_data
 
 
 
 
80
 
81
 
82
  if __name__ == "__main__":
 
3
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
4
  import torch
5
  import io
6
+ import base64
7
 
8
  def create_params(params, fr):
9
  # default
 
38
  self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
39
  self.model.to('cuda')
40
 
41
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
42
  """
43
  Args:
44
  data (:dict:):
 
61
  with torch.cuda.amp.autocast():
62
  outputs = self.model.generate(**inputs, **params)
63
 
64
+ pred = outputs[0].cpu().numpy()
65
  sr = 32000
66
 
67
  try:
 
77
  # Convert BytesIO to bytes
78
  wav_data = wav_buffer.getvalue()
79
 
80
+
81
+ # Convert the WAV binary data to Base64
82
+ base64_encoded_wav = base64.b64encode(wav_data).decode('utf-8')
83
+
84
+ return [{"audio": base64_encoded_wav, "sr": sr}]
85
 
86
 
87
  if __name__ == "__main__":
test.py CHANGED
@@ -1,17 +1,12 @@
1
  from handler import EndpointHandler
2
 
3
  # init handler
4
- my_handler = EndpointHandler(path="pbotsaris/musicgen-small")
 
 
5
 
6
- non_holiday_payload = {"inputs": "I am quite excited how this will turn out", "duration": 2}
7
- # prepare sample payload
8
 
9
- # test the handler
10
- non_holiday_pred=my_handler(non_holiday_payload)
11
-
12
- # show results
13
  print("done")
14
- print(non_holiday_pred)
15
-
16
- # non_holiday_pred [{'label': 'joy', 'score': 0.9985942244529724}]
17
- # holiday_payload [{'label': 'happy', 'score': 1}]
 
1
  from handler import EndpointHandler
2
 
3
  # init handler
4
+ print('init handler')
5
+ my_handler = EndpointHandler(path=".")
6
+ p = {"inputs": "I am quite excited how this will turn out", "duration": 2}
7
 
8
+ print('calling handler')
9
+ pred=my_handler(p)
10
 
 
 
 
 
11
  print("done")
12
+ print(pred)