pbotsaris commited on
Commit
acb1db0
·
1 Parent(s): ceb1cc0

added tests and changed handler to respond with an wav file

Browse files
Files changed (4) hide show
  1. __pycache__/handler.cpython-310.pyc +0 -0
  2. handler.py +13 -2
  3. requirements.txt +1 -0
  4. test.py +17 -0
__pycache__/handler.cpython-310.pyc ADDED
Binary file (2.13 kB). View file
 
handler.py CHANGED
@@ -1,6 +1,8 @@
1
  from typing import Dict, List, Any
 
2
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
3
  import torch
 
4
 
5
  def create_params(params, fr):
6
  # default
@@ -35,11 +37,13 @@ class EndpointHandler:
35
  self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
36
  self.model.to('cuda')
37
 
38
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
39
  """
40
  Args:
41
  data (:dict:):
42
  The payload with the text prompt and generation parameters.
 
 
43
  """
44
 
45
  inputs = data.pop("inputs", data)
@@ -65,7 +69,14 @@ class EndpointHandler:
65
  except:
66
  sr = 32000
67
 
68
- return [{"audio": pred, "sr":sr}]
 
 
 
 
 
 
 
69
 
70
 
71
  if __name__ == "__main__":
 
1
  from typing import Dict, List, Any
2
+ from scipy.io import wavfile
3
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
4
  import torch
5
+ import io
6
 
7
  def create_params(params, fr):
8
  # default
 
37
  self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
38
  self.model.to('cuda')
39
 
40
+ def __call__(self, data: Dict[str, Any]) -> bytes:
41
  """
42
  Args:
43
  data (:dict:):
44
  The payload with the text prompt and generation parameters.
45
+
46
+ Returns: wav file in bytes
47
  """
48
 
49
  inputs = data.pop("inputs", data)
 
69
  except:
70
  sr = 32000
71
 
72
+ # Convert the audio data to WAV format
73
+ wav_buffer = io.BytesIO()
74
+ wavfile.write(wav_buffer, sr, pred)
75
+
76
+ # Convert BytesIO to bytes
77
+ wav_data = wav_buffer.getvalue()
78
+
79
+ return wav_data
80
 
81
 
82
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  transformers==4.31.0
2
  accelerate>=0.20.3
 
 
1
  transformers==4.31.0
2
  accelerate>=0.20.3
3
+ scipy==1.11.1
test.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+
3
+ # init handler
4
+ my_handler = EndpointHandler(path="pbotsaris/musicgen-small")
5
+
6
+ non_holiday_payload = {"inputs": "I am quite excited how this will turn out", "duration": 2}
7
+ # prepare sample payload
8
+
9
+ # test the handler
10
+ non_holiday_pred=my_handler(non_holiday_payload)
11
+
12
+ # show results
13
+ print("done")
14
+ print(non_holiday_pred)
15
+
16
+ # non_holiday_pred [{'label': 'joy', 'score': 0.9985942244529724}]
17
+ # holiday_payload [{'label': 'happy', 'score': 1}]