kumararvindibs committed on
Commit
5c6e3ec
1 Parent(s): ee3fbca

Update handlerForAudio.py

Browse files
Files changed (1) hide show
  1. handlerForAudio.py +18 -24
handlerForAudio.py CHANGED
@@ -1,11 +1,9 @@
1
- import requests
2
  from typing import Dict, Any
3
- from dotenv import load_dotenv, find_dotenv
4
- import os
5
- import streamlit as st
6
- import json
7
  from textToStoryGeneration import *
8
  import logging
 
 
 
9
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.DEBUG)
@@ -14,33 +12,29 @@ logging.basicConfig(level=logging.ERROR)
14
  # Configure logging
15
  logging.basicConfig(level=logging.WARNING)
16
 
17
# Load environment variables from the nearest .env file into the process env.
load_dotenv(find_dotenv())
# NOTE(review): the key "HUGNINGFACEHUB_API_TOKEN" looks misspelled
# (probably meant HUGGINGFACEHUB_API_TOKEN) — confirm against the .env file.
# If the .env uses the correct spelling, this silently returns None and the
# Authorization header built from it becomes "Bearer None".
HUGGINFACE_API = os.getenv("HUGNINGFACEHUB_API_TOKEN")
19
 
20
class CustomHandler:
    """Text-to-speech handler that sends text to the Hugging Face Inference
    API (espnet/kan-bayashi_ljspeech_vits) and saves the returned audio to
    StoryAudio.mp3 in the current working directory.
    """

    def __init__(self):
        # Hosted TTS model and the hosted-inference endpoint derived from it.
        self.model_name = "espnet/kan-bayashi_ljspeech_vits"
        self.endpoint = f"https://api-inference.huggingface.co/models/{self.model_name}"

    def __call__(self, data: Dict[str, Any]) -> str:
        """Synthesize speech for ``data`` and return the output file path.

        Parameters
        ----------
        data : the text (or payload) to synthesize; forwarded as the
            ``inputs`` field of the API request.

        Returns
        -------
        str : path of the written audio file ('StoryAudio.mp3').

        Raises
        ------
        requests.HTTPError : if the inference API responds with an error
            status, instead of silently writing the error body to the file.
        """
        logging.warning(f"------input_data-- {str(data)}")
        payload = {"inputs": data}
        # Use logging (consistent with the rest of the module) instead of print.
        logging.debug("payload---- %s", payload)
        # Authorization header with the API token loaded at module import.
        headers = {"Authorization": f"Bearer {HUGGINFACE_API}"}
        # Send POST request to the Hugging Face model endpoint.
        response = requests.post(self.endpoint, json=payload, headers=headers)
        # Check that the request was successful before persisting the body —
        # the original wrote response.content unconditionally, so an HTML/JSON
        # error page could end up saved as "audio".
        response.raise_for_status()
        with open('StoryAudio.mp3', 'wb') as file:
            file.write(response.content)
        return 'StoryAudio.mp3'
38
  # Check if the request was successful
39
 
40
 
41
- # Example usage
- # if __name__ == "__main__":
- #     handler = CustomHandler()
- #     input_data = "Today I have tried many models, but I didn't find any model which gives us a better result and can be deployed on the endpoints. I think we need to create a custom inference handler, and then it can be deployed on the inference endpoint. As I have deployed a model on the inference endpoint, i.e. text-to-story generation, I have also compared the result created with this endpoint against my local server, and it is not the same. The endpoint is generating a different story."
- #     result = handler(input_data)
- #     print(result)
 
 
1
  from typing import Dict, Any
 
 
 
 
2
  from textToStoryGeneration import *
3
  import logging
4
+ import torch
5
+ import soundfile as sf
6
+ from transformers import AutoTokenizer, AutoModelForTextToWaveform
7
 
8
  # Configure logging
9
  logging.basicConfig(level=logging.DEBUG)
 
12
  # Configure logging
13
  logging.basicConfig(level=logging.WARNING)
14
 
15
+
 
16
 
17
class CustomHandler:
    """Text-to-speech handler that synthesizes audio locally with the
    facebook/mms-tts-eng model and writes it to StoryAudio.wav in the
    current working directory.
    """

    def __init__(self):
        # Load the tokenizer and TTS model once so repeated calls reuse them
        # (from_pretrained may download weights on first use).
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
        self.model = AutoModelForTextToWaveform.from_pretrained("facebook/mms-tts-eng")

    def __call__(self, data: Dict[str, Any]) -> str:
        """Synthesize speech for ``data`` and return the output file path.

        Parameters
        ----------
        data : either the raw text to speak, or a dict carrying the text
            under the conventional "inputs" key.

        Returns
        -------
        str : path of the written audio file ('StoryAudio.wav').
        """
        logging.warning(f"------input_data-- {str(data)}")
        # BUG FIX: the original did payload = str(data), which for a dict
        # serialises the braces and keys into the text that gets spoken.
        # Extract the conventional "inputs" field when present; fall back to
        # the old behavior for plain-string callers.
        if isinstance(data, dict) and "inputs" in data:
            payload = str(data["inputs"])
        else:
            payload = str(data)
        logging.warning(f"payload----{str(payload)}")
        inputs = self.tokenizer(payload, return_tensors="pt")
        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Save the first (only) waveform at the model's native sampling rate.
        sf.write("StoryAudio.wav", outputs["waveform"][0].numpy(), self.model.config.sampling_rate)
        return 'StoryAudio.wav'
 
 
 
 
38
  # Check if the request was successful
39
 
40