zsolt-klang committed on
Commit
004c940
1 Parent(s): ddf8395

Use audiocraft based implementation

Browse files
Files changed (1) hide show
  1. handler.py +7 -11
handler.py CHANGED
@@ -1,13 +1,15 @@
 
 
 
 
 
1
  from typing import Dict, List, Any
2
- from transformers import AutoProcessor, MusicgenForConditionalGeneration
3
  import torch
4
  import logging
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
- # load model and processor from path
9
- self.processor = AutoProcessor.from_pretrained(path)
10
- self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
11
 
12
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
13
  """
@@ -21,19 +23,13 @@ class EndpointHandler:
21
  parameters = data.pop("parameters", None)
22
  self.model.set_generation_params(**parameters)
23
 
24
- # preprocess
25
- inputs = self.processor(
26
- text=[inputs],
27
- padding=True,
28
- return_tensors="pt",).to("cuda")
29
-
30
  # pass inputs with all kwargs in data
31
  if parameters is not None:
32
  with torch.autocast("cuda"):
33
  outputs = self.model.generate(**inputs)
34
  else:
35
  with torch.autocast("cuda"):
36
- outputs = self.model.generate(**inputs,)
37
 
38
  # postprocess the prediction
39
  prediction = outputs[0].cpu().numpy().tolist()
 
1
+ from audiocraft.data.audio_utils import convert_audio
2
+ from audiocraft.data.audio import audio_write
3
+ from audiocraft.models.encodec import InterleaveStereoCompressionModel
4
+ from audiocraft.models import MusicGen, MultiBandDiffusion
5
+
6
  from typing import Dict, List, Any
 
7
  import torch
8
  import logging
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
+ self.model = MusicGen.get_pretrained("musicgen-medium")
 
 
13
 
14
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
15
  """
 
23
  parameters = data.pop("parameters", None)
24
  self.model.set_generation_params(**parameters)
25
 
 
 
 
 
 
 
26
  # pass inputs with all kwargs in data
27
  if parameters is not None:
28
  with torch.autocast("cuda"):
29
  outputs = self.model.generate(**inputs)
30
  else:
31
  with torch.autocast("cuda"):
32
+ outputs = self.model.generate(**inputs)
33
 
34
  # postprocess the prediction
35
  prediction = outputs[0].cpu().numpy().tolist()