supermomo668 committed on
Commit
692312c
1 Parent(s): 1ccfe17
Files changed (1)
  1. handler.py +110 -0
handler.py ADDED
@@ -0,0 +1,110 @@
+ from typing import Dict, List, Any
+ import io
+ import math
+
+ import numpy as np
+ import soundfile as sf
+ import torch
+ import yaml
+ from transformers import AutoProcessor, MusicgenForConditionalGeneration
+
+ from audiocraft.models import MusicGen
+
+ # Synthetic "bip" melody, handy as a test prompt for melody conditioning.
+ def get_bip_bip(
+         bip_duration=0.125, frequency=440, duration=0.5, sample_rate=32000, device="cuda"):
+     """Generate a series of bips at the given frequency."""
+     t = torch.arange(
+         int(duration * sample_rate), device=device, dtype=torch.float) / sample_rate
+     wav = torch.cos(2 * math.pi * frequency * t)[None]
+     tp = (t % (2 * bip_duration)) / (2 * bip_duration)
+     envelope = (tp >= 0.5).float()
+     return wav * envelope
+
+ def load_conf(conf):
+     """Load generation settings from a YAML config file."""
+     with open(conf, 'r') as f:
+         conf = yaml.safe_load(f)
+     return conf
+
+ class generator:
+     def __init__(self, conf_file):
+         """
+         Expected keys in the YAML config:
+             model: checkpoint name, e.g. "facebook/musicgen-melody"
+             sampling_rate: default sample rate of the melody prompt
+             duration: length of the generated audio, in seconds
+             nth_slice_prompt: divisor controlling how much of the prompt audio is kept
+         """
+         self.conf = load_conf(conf_file)
+         self.processor = AutoProcessor.from_pretrained(self.conf['model'])
+         # audiocraft's MusicGen wrapper takes the device at load time.
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = MusicGen.get_pretrained(self.conf['model'], device=device)
+         self.model.set_generation_params(
+             use_sampling=True,
+             top_k=250,
+             duration=self.conf['duration']
+         )
+         self.sampling_rate = self.model.sample_rate
+
+     def preprocess(self, text, audio):
+         # Keep only the leading 1/nth slice of the audio as the melody prompt.
+         return audio[: int(len(audio) // self.conf['nth_slice_prompt'])]
+
+     def generate(self, text: list, audio: np.ndarray, sampling_rate=None, **kwargs):
+         """
+         text: e.g. ["modern melodic electronic dance music", "80s blues track with groovy saxophone"]
+         audio (np.ndarray): melody prompt as returned by soundfile.read
+         sampling_rate (int, optional): sample rate of `audio`; defaults to the config value
+         """
+         # inputs = self.processor(
+         #     audio=audio,
+         #     sampling_rate=self.conf["sampling_rate"],
+         #     text=text,
+         #     padding=True,
+         #     return_tensors="pt",
+         # )
+         audio = self.preprocess(text, audio)
+         descriptions = text if isinstance(text, list) else [text]
+         # audiocraft expects melody_wavs as [C, T] tensors, one per description.
+         melody = torch.from_numpy(np.asarray(audio, dtype=np.float32))
+         melody = melody[None] if melody.dim() == 1 else melody.t()
+         output = self.model.generate_with_chroma(
+             descriptions=descriptions,
+             melody_wavs=[melody] * len(descriptions),
+             melody_sample_rate=sampling_rate or self.conf['sampling_rate'],
+             progress=True
+         )
+         return output
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # Load the transformers processor/model from the repository path
+         # (used only by the commented-out processor path below).
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.processor = AutoProcessor.from_pretrained(path)
+         self.model = MusicgenForConditionalGeneration.from_pretrained(
+             path, torch_dtype=torch.float16 if device == "cuda" else torch.float32
+         ).to(device)
+         # Melody-conditioned generation itself goes through the audiocraft wrapper.
+         self.generator = generator('conf.yaml')
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         Args:
+             data (dict):
+                 The payload with the text prompt, the raw bytes of the melody
+                 audio, and optional generation parameters.
+         """
+         # process input
+         text = data.pop("text", data)
+         audio = data.pop("audio", data)
+         parameters = data.pop("parameters", None)
+         # decode the raw audio bytes into a waveform and its sample rate
+         audio, sr = sf.read(io.BytesIO(audio))
+         output = self.generator.generate(text, audio, sampling_rate=sr)
+
+         # # pass inputs with all kwargs in data
+         # if parameters is not None:
+         #     with torch.autocast("cuda"):
+         #         outputs = self.model.generate(**inputs, **parameters)
+         # else:
+         #     with torch.autocast("cuda"):
+         #         outputs = self.model.generate(**inputs,)
+
+         # postprocess the prediction
+         prediction = output.squeeze().cpu().numpy().tolist()
+
+         return [{"generated_audio": prediction}]
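The handler reads its settings from conf.yaml and expects the payload to carry the raw bytes of a melody prompt. A minimal local smoke test might look like the sketch below; the config values, the sample.wav file, and the facebook/musicgen-melody checkpoint are assumptions for illustration, not part of the committed handler.

# local_test.py -- hypothetical smoke test for handler.py, run from the repo root
import soundfile as sf
import yaml

from handler import EndpointHandler

# Assumed conf.yaml contents; the keys mirror what handler.py reads
# (model, sampling_rate, duration, nth_slice_prompt) but the values are guesses.
conf = {
    "model": "facebook/musicgen-melody",
    "sampling_rate": 32000,
    "duration": 8,
    "nth_slice_prompt": 4,   # keep the first quarter of the prompt audio
}
with open("conf.yaml", "w") as f:
    yaml.safe_dump(conf, f)

# Any short WAV file works as the melody prompt; "sample.wav" is a placeholder.
with open("sample.wav", "rb") as f:
    audio_bytes = f.read()

handler = EndpointHandler(path=conf["model"])
result = handler({
    "text": ["modern melodic electronic dance music"],
    "audio": audio_bytes,
})

# The response is a list with one dict holding the generated waveform as a
# plain Python list; write it back to disk at the model's sample rate.
waveform = result[0]["generated_audio"]
sf.write("generated.wav", waveform, handler.generator.sampling_rate)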