samsonsbike committed
Commit 99557a7
1 Parent(s): 084a6a5

Update handler.py

Files changed (1)
  1. handler.py +35 -27
handler.py CHANGED
@@ -1,38 +1,46 @@
-from typing import Dict, List, Any
-from transformers import AutoProcessor, MusicgenForConditionalGeneration
 import torch
+import logging
+from transformers import AutoProcessor, MusicgenForConditionalGeneration
+from typing import Dict, Any

 class EndpointHandler:
     def __init__(self, path=""):
-        # load model and processor from path
-        self.processor = AutoProcessor.from_pretrained(path)
-        self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
+        logging.basicConfig(level=logging.INFO)
+        try:
+            # load model and processor from path
+            self.processor = AutoProcessor.from_pretrained(path)
+            self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
+        except Exception as e:
+            logging.error(f"Error loading model or processor: {e}")
+            raise

     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
-        """
-        Args:
-            data (:dict:):
-                The payload with the text prompt and generation parameters.
-        """
-        # process input
-        inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", None)
+        try:
+            # validate and process input
+            inputs = data.get("inputs")
+            if not inputs:
+                raise ValueError("No inputs provided")
+
+            parameters = data.get("parameters", {})

-        # preprocess
-        inputs = self.processor(
-            text=[inputs],
-            padding=True,
-            return_tensors="pt",).to("cuda")
+            # preprocess
+            processed_inputs = self.processor(
+                text=[inputs],
+                padding=True,
+                return_tensors="pt"
+            ).to("cuda")

-        # pass inputs with all kwargs in data
-        if parameters is not None:
-            with torch.autocast("cuda"):
-                outputs = self.model.generate(**inputs, **parameters)
-        else:
+            # generate outputs
             with torch.autocast("cuda"):
-                outputs = self.model.generate(**inputs,)
+                outputs = self.model.generate(**processed_inputs, **parameters)

-        # postprocess the prediction
-        prediction = outputs[0].cpu().numpy().tolist()
+            # postprocess the prediction
+            prediction = outputs[0].cpu().numpy().tolist()
+            return [{"generated_audio": prediction}]
+        except Exception as e:
+            logging.error(f"Error during model inference: {e}")
+            return {"error": str(e)}

-        return [{"generated_audio": prediction}]
+# Example usage:
+# handler = EndpointHandler(path="your_model_path")
+# result = handler({"inputs": "your input text"})
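
For reference, a minimal client-side sketch of how the updated handler might be exercised and its output saved to disk. It is not part of the commit: the model path, prompt, generation parameters, and the 32 kHz sampling rate (MusicGen's default audio-encoder rate, which the handler does not return) are assumptions, and a CUDA device is required since the handler loads the model on GPU.

import numpy as np
import scipy.io.wavfile
from handler import EndpointHandler

# Hypothetical model path and prompt; substitute your own.
handler = EndpointHandler(path="facebook/musicgen-small")
result = handler({"inputs": "lo-fi beat with mellow piano",
                  "parameters": {"max_new_tokens": 256}})

# The handler returns a dict with an "error" key on failure.
if isinstance(result, dict) and "error" in result:
    raise RuntimeError(result["error"])

# "generated_audio" is a nested list of shape (num_channels, num_samples).
audio = np.array(result[0]["generated_audio"], dtype=np.float32)

# Assumed sampling rate: MusicGen's audio encoder defaults to 32 kHz.
scipy.io.wavfile.write("musicgen_out.wav", rate=32000, data=audio[0])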