# LLAMA-QA-AudioFiles / whisper_app.py
import os
import torch as th
import whisper
from whisper.audio import SAMPLE_RATE
from tenacity import retry, wait_random
import openai
import requests
# os.environ['OPENAI_API_KEY'] = "sk-<API KEY>"
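# When openai_flag is True, the openai client reads the API key from the
# OPENAI_API_KEY environment variable (see the commented-out line above).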
class WHISPERModel:
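    """Wrapper around openai-whisper for transcription and language detection,
    with optional OpenAI-based translation of non-English audio."""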
    def __init__(self, model_name='small', device='cuda', openai_flag=False):
        # Fall back to CPU when CUDA is requested but not available.
        if device == 'cuda' and not th.cuda.is_available():
            device = 'cpu'
        self.device = device
        self.openai_flag = openai_flag
        self.model = whisper.load_model(model_name, device=self.device)
    def get_info(self, audio_data, conv_duration=30):
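        # Detect the spoken language from a clip of the first `conv_duration`
        # seconds; transcribing this short clip is enough to get the language tag.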
        clip_audio = whisper.pad_or_trim(audio_data, length=SAMPLE_RATE * conv_duration)
        result = self.model.transcribe(clip_audio)
        return result['language']
    def read_audio(self, audio_path, duration=None):
        # whisper.load_audio shells out to ffmpeg, so it handles local paths as well as URLs.
        try:
            audio = whisper.load_audio(audio_path)
        except IOError as err:
            raise IOError("Unable to load audio from {}".format(audio_path)) from err
        if duration:
            # `duration` is given in minutes; pad_or_trim expects a length in samples.
            audio = whisper.pad_or_trim(audio, length=SAMPLE_RATE * duration * 60)
        return audio
    def speech_to_text(self, audio_path, duration=None):
        text_data = dict()
        audio_duration = 0
        conv_language = ""
        if audio_path.startswith('http'):
            # Check that the remote file is reachable before handing the URL to ffmpeg.
            r = requests.get(audio_path, stream=True)
            if r.status_code == 200:
                audio = self.read_audio(audio_path, duration=duration)
            else:
                raise ConnectionError("Unable to reach URL {}".format(audio_path))
        else:
            audio = self.read_audio(audio_path, duration=duration)
        conv_language = self.get_info(audio)
        if conv_language != 'en':
            # Whisper's 'translate' task transcribes directly into English.
            res = self.model.transcribe(audio, task='translate')
            if self.openai_flag:
                # Optionally run the text through OpenAI as well (see translate_text below).
                res['text'] = self.translate_text(res['text'], original_language=conv_language, convert_to='English')
        else:
            res = self.model.transcribe(audio)
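        # Duration in seconds of the (possibly trimmed) audio signal.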
        audio_duration = audio.shape[0] / SAMPLE_RATE
        text_data['text'] = res['text']
        text_data['duration'] = audio_duration
        text_data['language'] = conv_language
        return text_data
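    # Retry the OpenAI call with a random 5-10 second wait if it raises (e.g. rate limits).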
    @retry(wait=wait_random(min=5, max=10))
    def translate_text(self, text, original_language='ar', convert_to='English'):
        prompt = (
            f'Translate the following {original_language} text to {convert_to}:\n\n'
            f'{original_language}: {text}\n{convert_to}:'
        )
        # Generate the translation with the OpenAI completions endpoint.
        response = openai.Completion.create(
            engine='text-davinci-003',
            prompt=prompt,
            max_tokens=100,
            n=1,
            stop=None,
            temperature=0.7
        )
        # Extract the translated text from the response.
        translation = response.choices[0].text.strip()
        return translation
if __name__ == '__main__':
    url = "https://prypto-api.aswat.co/surveillance/recordings/5f53c28b-3504-4b8b-9db5-0c8b69a96233.mp3"
    audio2text = WHISPERModel()
    text = audio2text.speech_to_text(url)
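    # Print the transcription result: text, duration (seconds) and detected language.
    print(text)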