from typing import Dict, List, Any
from scipy.io import wavfile
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
import io
import base64

def create_params(params, fr):
    """Merge user-supplied generation parameters with the defaults below.

    `fr` is the audio encoder frame rate (tokens per second of audio), used to
    convert a requested `duration` in seconds into `max_new_tokens`.
    """
    # Defaults: 256 new tokens corresponds to roughly 5 seconds of audio.
    out = {
        "do_sample": True,
        "guidance_scale": 3,
        "max_new_tokens": 256,
    }

    if params is None:
        return out

    has_tokens = False

    if 'duration' in params:
        # duration (seconds) * frame rate (tokens per second) -> tokens to generate
        out['max_new_tokens'] = int(params['duration'] * fr)
        has_tokens = True

    for k, p in params.items():
        if k in out:
            # a requested 'duration' takes precedence over an explicit max_new_tokens
            if has_tokens and k == 'max_new_tokens':
                continue
            out[k] = p

    return out
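
# Illustrative mapping (MusicGen's audio encoder runs at roughly 50 tokens per
# second, so the example below is approximate and not part of the handler):
#   create_params({"duration": 8}, 50)
#   -> {"do_sample": True, "guidance_scale": 3, "max_new_tokens": 400}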


class EndpointHandler:
    def __init__(self, path="pbotsaris/musicgen-small"):
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(path)
        # The handler assumes a GPU-backed endpoint
        self.model.to('cuda:0')  # type: ignore

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (:dict:):
                The payload, with the text prompt under "inputs" and optional
                generation parameters under "parameters".

        Returns:
            A single-element list whose dict holds the generated audio as a
            base64-encoded WAV string ("audio") and its sampling rate ("sr").
        """

        # Accept either {"inputs": ..., "parameters": ...} or a bare prompt payload
        inputs = data.pop("inputs", data)
        params = data.pop("parameters", None)

        inputs = self.processor(
            text=[inputs],
            padding=True,
            return_tensors="pt"
        )

        params = create_params(params, self.model.config.audio_encoder.frame_rate)  # type: ignore

        outputs = self.model.generate(**inputs.to('cuda:0'), **params)  # type: ignore

        # outputs has shape (batch, channels, samples); take the single generated waveform
        pred = outputs[0, 0].cpu().numpy()
        sr = self.model.config.audio_encoder.sampling_rate  # type: ignore

        # Serialize the waveform to an in-memory WAV file and base64-encode it
        wav_buffer = io.BytesIO()
        wavfile.write(wav_buffer, rate=sr, data=pred)
        wav_data = wav_buffer.getvalue()

        base64_encoded_wav = base64.b64encode(wav_data).decode('utf-8')
        return [{"audio": base64_encoded_wav, "sr": sr}]

if __name__ == "__main__":
    handler = EndpointHandler()
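
    # Illustrative local check (assumes a CUDA device is available; the prompt
    # and parameters below are arbitrary examples, not a fixed API):
    result = handler({
        "inputs": "upbeat electronic track with a steady kick drum",
        "parameters": {"duration": 5},
    })
    with open("sample.wav", "wb") as f:
        f.write(base64.b64decode(result[0]["audio"]))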