# Spaces:
# Sleeping
# Sleeping
import os
import sys

import openai
import requests
import torch as th
import whisper
from tenacity import retry, wait_random
from whisper.audio import SAMPLE_RATE

# os.environ['OPENAI_API_KEY'] = "sk-<API KEY>"


class WHISPERModel:
    """Speech-to-text helper built on OpenAI Whisper.

    Loads a Whisper checkpoint once and exposes helpers to read audio, detect
    its spoken language, transcribe it (translating non-English speech to
    English with Whisper's ``task='translate'``) and optionally post-process
    the result with an OpenAI completion model.
    """

    # Seconds to wait for a remote server when probing an audio URL; without a
    # timeout ``requests`` can block forever on an unresponsive host.
    request_timeout = 30

    def __init__(self, model_name='small', device='cuda', openai_flag=False):
        """Load the Whisper model.

        Args:
            model_name: Checkpoint name passed to ``whisper.load_model``.
            device: Device string passed to ``whisper.load_model``
                (e.g. ``'cuda'`` or ``'cpu'``).
            openai_flag: When true, transcripts of non-English audio are also
                passed through :meth:`translate_text`.
        """
        self.device = device
        self.openai_flag = openai_flag
        self.model = whisper.load_model(model_name, device=self.device)

    def get_info(self, audio_data, conv_duration=30):
        """Return the language code Whisper reports for the start of the audio.

        Only the first ``conv_duration`` seconds are used (the waveform is
        trimmed or padded to that length before transcription), which limits
        the work done just to identify the language.
        """
        clip_audio = whisper.pad_or_trim(audio_data, length=SAMPLE_RATE * conv_duration)
        result = self.model.transcribe(clip_audio)
        return result['language']

    def read_audio(self, audio_path, duration=None):
        """Load audio from a local path or URL with ``whisper.load_audio``.

        Args:
            audio_path: Path or URL accepted by ``whisper.load_audio``.
            duration: Optional length in *minutes*; when truthy the waveform is
                trimmed or padded to exactly that many minutes.

        Returns:
            The waveform sampled at ``SAMPLE_RATE`` (``shape[0]`` is the
            number of samples).
        """
        # Errors from whisper.load_audio propagate unchanged; the previous
        # ``except IOError: raise err`` only re-raised and added nothing.
        audio = whisper.load_audio(audio_path)
        if duration:
            audio = whisper.pad_or_trim(audio, length=SAMPLE_RATE * duration * 60)
        return audio

    def _check_url(self, url):
        """Raise IOError unless ``url`` answers an HTTP GET with status 200.

        The response body is streamed and closed without being read, because
        :meth:`read_audio` hands the URL to ``whisper.load_audio``, which
        fetches it itself; downloading it here as well would be wasted work.
        Connection errors and timeouts from ``requests`` are also IOError
        subclasses.
        """
        with requests.get(url, stream=True, timeout=self.request_timeout) as r:
            if r.status_code != 200:
                # Raising a plain string (as before) is itself a TypeError.
                raise IOError("Unable to reach for URL {} (HTTP {})".format(url, r.status_code))

    def speech_to_text(self, audio_path, duration=None):
        """Transcribe an audio file or URL, translating non-English speech to English.

        Args:
            audio_path: Local path or ``http://`` / ``https://`` URL.
            duration: Optional length in minutes to trim/pad the audio to
                (see :meth:`read_audio`).

        Returns:
            dict with keys ``'text'`` (the transcript/translation),
            ``'duration'`` (seconds of audio processed, after any trimming)
            and ``'language'`` (code returned by :meth:`get_info`).

        Raises:
            IOError: If ``audio_path`` is a URL that cannot be reached or does
                not answer with HTTP 200.
        """
        # A bare 'http' prefix would also match local files such as 'http.mp3'.
        if audio_path.startswith(('http://', 'https://')):
            self._check_url(audio_path)
        # ``duration`` was previously accepted but never forwarded.
        audio = self.read_audio(audio_path, duration=duration)

        conv_language = self.get_info(audio)
        if conv_language != 'en':
            res = self.model.transcribe(audio, task='translate')
            if self.openai_flag:
                # NOTE(review): res['text'] already comes from Whisper's
                # 'translate' task, so this asks the OpenAI model to translate
                # text it is told is in ``conv_language``. If the intent was
                # GPT-based translation of the original-language transcript,
                # drop task='translate' on this path -- confirm.
                res['text'] = self.translate_text(res['text'], orginal_text=conv_language, convert_to='English')
        else:
            res = self.model.transcribe(audio)

        return {
            'text': res['text'],
            'duration': audio.shape[0] / SAMPLE_RATE,
            'language': conv_language,
        }

    @staticmethod
    def _build_translation_prompt(text, source_language, target_language):
        """Build the completion prompt asking to translate ``text``."""
        # Both parts are f-strings; previously the trailing '{convert_to}' was
        # a plain string and reached the model literally.
        return (
            f'Translate the following {source_language} text to {target_language}:\n\n'
            f'{source_language}: {text}\n{target_language}:'
        )

    def translate_text(self, text, orginal_text='ar', convert_to='english',
                       max_tokens=100, temperature=0.7):
        """Translate ``text`` with an OpenAI completion model.

        Args:
            text: Text to translate.
            orginal_text: Source language name/code (despite the parameter
                name, this is the language, not the text).
            convert_to: Target language name.
            max_tokens: Upper bound on generated tokens; longer translations
                are cut off.
            temperature: Sampling temperature for the completion.

        Returns:
            The stripped text of the first completion choice.
        """
        prompt = self._build_translation_prompt(text, orginal_text, convert_to)
        # NOTE(review): ``openai.Completion.create(engine=...)`` is the legacy
        # (openai<1.0) interface and OpenAI has retired text-davinci-003 --
        # confirm the pinned openai version and model still work.
        response = openai.Completion.create(
            engine='text-davinci-003',
            prompt=prompt,
            max_tokens=max_tokens,
            n=1,
            stop=None,
            temperature=temperature,
        )
        return response.choices[0].text.strip()
if __name__ == '__main__':
    # Transcribe the audio path/URL given as the first command-line argument
    # (falling back to the original demo recording) and print the result
    # dict, which was previously computed and then discarded.
    url = sys.argv[1] if len(sys.argv) > 1 else "https://prypto-api.aswat.co/surveillance/recordings/5f53c28b-3504-4b8b-9db5-0c8b69a96233.mp3"
    audio2text = WHISPERModel()
    text = audio2text.speech_to_text(url)
    print(text)