pbotsaris commited on
Commit
01d132b
1 Parent(s): e3598b0

added create_params func

Browse files
Files changed (1) hide show
  1. handler.py +27 -0
handler.py CHANGED
@@ -5,6 +5,33 @@ import torch
5
  import io
6
  import base64
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  class EndpointHandler:
9
  def __init__(self, path="pbotsaris/musicgen-small"):
10
  self.processor = AutoProcessor.from_pretrained(path)
 
5
  import io
6
  import base64
7
 
8
+ def create_params(params, fr):
9
+
10
+ # default
11
+ out = { "do_sample": True,
12
+ "guidance_scale": 3,
13
+ "max_new_tokens": 256
14
+ }
15
+
16
+ has_tokens = False
17
+
18
+ if params is None:
19
+ return out
20
+
21
+ if 'duration' in params:
22
+ out['max_new_tokens'] = params['duration'] * fr
23
+ has_tokens = True
24
+
25
+ for k, p in params.items():
26
+ if k in out:
27
+ if has_tokens and k == 'max_new_tokens':
28
+ continue
29
+
30
+ out[k] = p
31
+
32
+ return out
33
+
34
+
35
  class EndpointHandler:
36
  def __init__(self, path="pbotsaris/musicgen-small"):
37
  self.processor = AutoProcessor.from_pretrained(path)