originstory committed on
Commit
a47fd1a
1 Parent(s): ff9b931

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +32 -28
handler.py CHANGED
@@ -2,37 +2,41 @@ from transformers import AutoProcessor, MusicgenForConditionalGeneration
2
  import torch
3
 
4
class EndpointHandler:
    """Wraps a MusicGen checkpoint for text-conditioned music generation."""

    def __init__(self, path=""):
        # Processor and fp16 model are loaded from the same location; the
        # model is moved onto the GPU immediately (assumes CUDA is present —
        # TODO confirm the deployment target always has a GPU).
        self.processor = AutoProcessor.from_pretrained(path)
        model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        )
        self.model = model.to("cuda")

    def __call__(self, data: dict) -> dict:
        """
        Args:
            data (dict): Contains the text prompt, vibe, style, and public
                domain song reference.

        Returns:
            dict: {"generated_audio": nested list of floats}.
        """
        # Pull each field out of the request payload; .get() yields None for
        # any key the caller omitted.
        keys = ("text_prompt", "vibe", "style", "song_reference")
        text_prompt, vibe, style, song_reference = (data.get(k) for k in keys)

        # Build the single conditioning string fed to the processor.
        combined_prompt = f"{vibe} {style} version of {song_reference}: {text_prompt}"

        # Tokenize the prompt, then move the resulting tensors to the GPU.
        inputs = self.processor(text=[combined_prompt], padding=True, return_tensors="pt")
        inputs = inputs.to("cuda")

        # Run generation under CUDA autocast to match the fp16 weights.
        with torch.autocast("cuda"):
            audio_output = self.model.generate(**inputs)

        # First (only) sample -> CPU -> plain Python lists so the result is
        # JSON-serializable.
        audio_data = audio_output[0].cpu().numpy().tolist()
        return {"generated_audio": audio_data}
36
-
37
# Module-level example instance. Substitute the placeholder with a real
# local path or Hub model identifier before deploying.
model_location = "path-to-your-model"
handler = EndpointHandler(path=model_location)
 
 
 
 
 
2
  import torch
3
 
4
class EndpointHandler:
    """Inference endpoint handler for the MusicGen text-to-music model."""

    def __init__(self, model_path="originstory/holisleigh", use_auth_token=None):
        # Load processor and model from the same repository so the tokenizer
        # and weights stay in sync. fp16 halves GPU memory; .to("cuda")
        # assumes a GPU is available on the endpoint — TODO confirm.
        #
        # BUG FIX: the processor call previously hard-coded
        # use_auth_token=None, silently ignoring the caller's token, so
        # private/gated repositories would load the model but fail on the
        # processor. Pass the parameter through to both loads.
        self.processor = AutoProcessor.from_pretrained(
            model_path, use_auth_token=use_auth_token
        )
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch.float16, use_auth_token=use_auth_token
        ).to("cuda")

    def __call__(self, data: dict) -> dict:
        """
        Generate music for a request payload.

        Args:
            data (dict): Contains the text prompt, vibe, style, and public
                domain song reference under the keys "text_prompt", "vibe",
                "style", "song_reference". Missing keys become None in the
                combined prompt.

        Returns:
            dict: {"generated_audio": nested list of floats} on success,
                or {"error": message} on failure.
        """
        try:
            # Extract user inputs; .get() tolerates missing keys.
            text_prompt = data.get("text_prompt")
            vibe = data.get("vibe")
            style = data.get("style")
            song_reference = data.get("song_reference")

            # Combine user inputs into a single conditioning prompt.
            combined_prompt = f"{vibe} {style} version of {song_reference}: {text_prompt}"

            # Tokenize the prompt and move tensors to the GPU.
            inputs = self.processor(
                text=[combined_prompt], padding=True, return_tensors="pt"
            ).to("cuda")

            # Generate audio under autocast to match the fp16 weights.
            with torch.autocast("cuda"):
                audio_output = self.model.generate(**inputs)

            # Convert the first (only) sample to a JSON-serializable list.
            audio_data = audio_output[0].cpu().numpy().tolist()

            return {"generated_audio": audio_data}
        except Exception as e:
            # Endpoint boundary: report the failure to the caller instead of
            # letting the worker crash.
            return {"error": str(e)}
40
+
41
# Example usage: instantiate with the default model path at import time.
handler = EndpointHandler()