dthomas84 committed on
Commit
d6b621b
1 Parent(s): 0a2222a

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +69 -0
handler.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import json
3
+ import numpy as np
4
+ from transformers import AutoProcessor, MusicgenForConditionalGeneration
5
+ import torch
6
+
7
+
8
class EndpointHandler:
    """Endpoint handler that runs text- (and optionally audio-) conditioned
    generation with a MusicGen checkpoint.

    The processor and model are loaded once in ``__init__``; each request is
    served by ``__call__``.
    """

    # Token budget used when the request supplies neither ``duration`` nor an
    # explicit ``max_new_tokens``.
    DEFAULT_MAX_NEW_TOKENS = 256
    # Tokens per second of audio assumed when the model config does not
    # expose ``audio_encoder.frame_rate`` (the value this handler originally
    # hard-coded).
    DEFAULT_FRAME_RATE = 50

    def __init__(self, path=""):
        """Load the processor and model found at ``path``.

        Args:
            path: Directory or identifier passed to ``from_pretrained`` for
                both the processor and the model.
        """
        self.processor = AutoProcessor.from_pretrained(path)

        # Prefer the GPU when one is visible, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = MusicgenForConditionalGeneration.from_pretrained(path)
        self.model.to(self.device)

    def _max_new_tokens(self, duration):
        """Translate a requested ``duration`` into a token budget.

        Returns ``DEFAULT_MAX_NEW_TOKENS`` when ``duration`` is ``None``.
        Otherwise multiplies the duration by the model's
        ``config.audio_encoder.frame_rate`` when that attribute exists, or by
        ``DEFAULT_FRAME_RATE`` when it does not.
        """
        if duration is None:
            return self.DEFAULT_MAX_NEW_TOKENS
        audio_config = getattr(getattr(self.model, "config", None), "audio_encoder", None)
        frame_rate = getattr(audio_config, "frame_rate", None) or self.DEFAULT_FRAME_RATE
        return int(duration * frame_rate)

    @staticmethod
    def _decode_audio(audio):
        """Convert the request's ``audio`` value to a NumPy array.

        Accepts either a JSON-encoded string (the original wire format) or an
        already-decoded sequence; ``None`` is passed through unchanged.
        """
        if audio is None:
            return None
        if isinstance(audio, (str, bytes, bytearray)):
            audio = json.loads(audio)
        return np.array(audio)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is the text prompt (the
                whole payload is used when the key is absent).
                ``data["parameters"]`` is optional and may contain:

                * ``duration`` - converted to ``max_new_tokens``;
                * ``audio`` - conditioning audio as a JSON string or a
                  sequence, forwarded to the processor;
                * ``sampling_rate`` - forwarded to the processor;
                * any other key - forwarded to ``model.generate`` as-is. An
                  explicit ``max_new_tokens`` takes precedence over
                  ``duration``.

        Returns:
            A one-element list ``[{"generated_text": prediction}]`` where
            ``prediction`` is the first generated sequence moved to CPU and
            converted with ``.numpy()``.
        """
        prompt = data.get("inputs", data)

        # Work on a copy so the caller's payload is never mutated, and treat
        # a missing/None "parameters" entry as "no parameters".
        generate_kwargs = dict(data.get("parameters") or {})

        # These keys are handler-level options, not generate() arguments.
        duration = generate_kwargs.pop("duration", None)
        audio = self._decode_audio(generate_kwargs.pop("audio", None))
        sampling_rate = generate_kwargs.pop("sampling_rate", None)

        # Only derive the budget when the caller did not set one explicitly;
        # passing it twice would raise a duplicate-keyword TypeError.
        if "max_new_tokens" not in generate_kwargs:
            generate_kwargs["max_new_tokens"] = self._max_new_tokens(duration)

        # preprocess
        model_inputs = self.processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
            audio=audio,
            sampling_rate=sampling_rate,
        ).to(self.device)

        outputs = self.model.generate(**model_inputs, **generate_kwargs)

        # postprocess the prediction
        prediction = outputs[0].cpu().numpy()

        return [{"generated_text": prediction}]