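# Page header captured from the Hugging Face Space file view (author: cotxetj).
# Commit 4166423 "Update app.py"; file size 7.73 kB; the page links "raw",
# "history" and "blame" and shows a "No virus" scan badge.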
import torch
import json
import os
from transformers import pipeline, VitsModel, VitsTokenizer, SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
import numpy as np
os.system("pip install git+https://github.com/openai/whisper.git")
import gradio as gr
import whisper
import requests
# Chat-completion settings for the "GPT answer" mode; the endpoint URL and the
# API key are read from environment variables (None when unset).
MODEL = "gpt-3.5-turbo"
API_URL = os.environ.get("API_URL")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
NUM_THREADS = 2  # not referenced anywhere in the visible code

# openai-whisper "small" checkpoint; used by inference().
model = whisper.load_model("small")

# Run the Hugging Face pipeline on the first GPU when torch can see one.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
def parse_codeblock(text):
    """Convert a chat message containing Markdown code fences into HTML.

    Lines are inspected after stripping surrounding whitespace:

    * an opening fence such as ```python becomes ``<pre><code class="python">``;
    * a bare ``` becomes ``</code></pre>``;
    * every other line is HTML-escaped (``&``, ``<``, ``>``). Every line except
      the first also gets a ``<br/>`` prefix.

    The resulting pieces are concatenated without newlines.

    Args:
        text: Raw message text, e.g. a user question or a model reply.

    Returns:
        The HTML string for display in the chatbot.
    """
    import html  # local import keeps this block self-contained

    lines = text.split("\n")
    for i, line in enumerate(lines):
        fence = line.strip()
        if fence.startswith("```"):
            if fence == "```":
                lines[i] = "</code></pre>"
            else:
                # The language tag is placed inside an attribute, so quotes
                # must be escaped as well.
                lang = html.escape(fence[3:].strip())
                lines[i] = f'<pre><code class="{lang}">'
        else:
            # html.escape handles "&" first, so literal entities such as
            # "&lt;" in the text are preserved instead of being rendered.
            escaped = html.escape(line, quote=False)
            lines[i] = escaped if i == 0 else "<br/>" + escaped
    return "".join(lines)
def inference(audio):
    """Transcribe an audio file with the module-level openai-whisper ``model``.

    The clip is loaded, padded/trimmed with ``whisper.pad_or_trim`` and
    converted to a log-Mel spectrogram on the model's device. It is then
    decoded with fp16 disabled. The language detected here is handed to the
    decoder, so detection is not run a second time inside ``whisper.decode``.

    Args:
        audio: Path to an audio file readable by ``whisper.load_audio``.

    Returns:
        The decoded text.
    """
    audio = whisper.load_audio(audio)
    print("loading finished")
    audio = whisper.pad_or_trim(audio)
    print("audio trimed")
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    print("spectro finished")
    _, probs = model.detect_language(mel)
    # Most probable language code; reusing it avoids a second detection pass.
    language = max(probs, key=probs.get)
    print("lang detected")
    options = whisper.DecodingOptions(language=language, fp16=False)
    result = whisper.decode(model, mel, options)
    print(result.text)
    return result.text
# Hugging Face ASR pipeline around Whisper-small; used by translate().
# An alternative fine-tuned checkpoint that was tried:
#   pipeline(model="Sleepyp00/whisper-small-Swedish")
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=device,
)

# English MMS text-to-speech (VITS) tokenizer and model; used by synthesise().
tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
model2 = VitsModel.from_pretrained("facebook/mms-tts-eng")
def translate(audio):
    """Run the Whisper ASR pipeline in "translate" mode and return the text.

    Whisper's translate task produces English output.

    Args:
        audio: Input accepted by the ``pipe`` ASR pipeline (a file path from the UI).

    Returns:
        The ``"text"`` field of the pipeline output.
    """
    asr_kwargs = {"max_new_tokens": 256, "generate_kwargs": {"task": "translate"}}
    return pipe(audio, **asr_kwargs)["text"]
def synthesise(text):
    """Turn *text* into a speech waveform with the MMS VITS model.

    Args:
        text: English text to speak.

    Returns:
        The waveform of the single input (index 0 of the batch) as a torch tensor.
    """
    encoded = tokenizer(text, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        speech = model2(encoded["input_ids"])
    return speech.audio[0]
def _chat_pairs(history):
    """Group a flat user/assistant *history* into (user, assistant) HTML pairs."""
    return [
        (parse_codeblock(history[i]), parse_codeblock(history[i + 1]))
        for i in range(0, len(history) - 1, 2)
    ]


def _stream_content(response):
    """Yield the ``delta.content`` text pieces from a streamed chat-completion response."""
    for raw in response.iter_lines():
        line = raw.decode("utf-8").strip()
        # Only "data:" lines carry payloads; blank lines separate events.
        if not line.startswith("data:"):
            continue
        data = line[len("data:"):].strip()
        if data == "[DONE]":  # end-of-stream marker
            break
        choices = json.loads(data).get("choices") or []
        if choices:
            content = (choices[0].get("delta") or {}).get("content")
            if content:
                yield content


def gpt_predict(inputs, request: gr.Request = None, top_p=1, temperature=1, chat_counter=0, history=None):
    """Send *inputs* to the chat-completions endpoint and collect the streamed reply.

    Args:
        inputs: The user message.
        request: Unused; kept so existing callers keep working.
        top_p: Nucleus-sampling value forwarded in the request payload.
        temperature: Sampling temperature forwarded in the request payload.
        chat_counter: Number of earlier turns. When non-zero, *history* is sent
            as conversation context (even indices = user, odd = assistant).
        history: Flat list of alternating user/assistant messages. It is
            mutated in place: the new user message and, if any reply text
            arrives, the reply are appended. A fresh list is used when omitted.

    Returns:
        ``(chat_pairs, history, chat_counter, response, update, update)``:
        ``chat_pairs`` are HTML (user, assistant) pairs built from *history*;
        ``chat_counter`` is incremented by one; ``response`` is the
        ``requests.Response``, or ``None`` if ``requests.post`` raised; the two
        ``gr.update(interactive=True)`` values re-enable input widgets.

    Request and stream errors are printed, not raised. Any reply text received
    before the error is kept in *history*.
    """
    if history is None:
        # A fresh list per call; a shared default list would leak turns
        # between calls (and between users of the app).
        history = []

    messages = []
    if chat_counter != 0:
        roles = ("user", "assistant")
        messages = [
            {"role": roles[i % 2], "content": turn} for i, turn in enumerate(history)
        ]
    messages.append({"role": "user", "content": f"{inputs}"})

    # The same payload shape is used for every turn, so the caller's sampling
    # parameters are honoured on the first turn too.
    payload = {
        "model": MODEL,
        "messages": messages,
        "temperature": temperature,
        "top_p": top_p,
        "n": 1,
        "stream": True,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {OPENAI_API_KEY}",
    }

    chat_counter += 1
    history.append(inputs)
    partial_words = ""
    token_counter = 0
    response = None
    try:
        response = requests.post(API_URL, headers=headers, json=payload, stream=True)
        print(f"{response}")
        response.raise_for_status()  # surface HTTP errors instead of parsing an error body
        for piece in _stream_content(response):
            partial_words += piece
            if token_counter == 0:
                history.append(partial_words)
            else:
                history[-1] = partial_words
            token_counter += 1
    except Exception as e:  # best effort: report and return what was collected
        print(f"error found: {e}")
        print(json.dumps({"chat_counter": chat_counter, "payload": payload, "partial_words": partial_words, "token_counter": token_counter}))
    finally:
        if response is not None:
            # stream=True keeps the connection open until the body is fully
            # read or the response is closed.
            response.close()
    return (
        _chat_pairs(history),
        history,
        chat_counter,
        response,
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
def speech_to_speech_translation(audio):
    """Translate speech in *audio* to English text, then speak that text.

    Args:
        audio: Input accepted by ``translate`` (a file path from the UI).

    Returns:
        ``[translated_text, (sampling_rate, waveform)]``. ``waveform`` is a
        16-bit PCM NumPy array, matching the ``(rate, data)`` tuple used for
        the numpy-typed ``gr.Audio`` output.
    """
    translated_text = translate(audio)
    waveform = synthesise(translated_text).numpy()
    # Clip before scaling: samples outside [-1, 1] would otherwise wrap around
    # when cast to int16 and produce loud clicks.
    pcm = (np.clip(waveform, -1.0, 1.0) * 32767).astype(np.int16)
    # Use the TTS model's configured rate instead of a hard-coded value,
    # falling back to the previous 16000 Hz if the config lacks it.
    sampling_rate = getattr(model2.config, "sampling_rate", 16000)
    return [translated_text, (sampling_rate, pcm)]
def predict(transType, language, audio, audio_mic=None):
    """Gradio callback: translate the uploaded or recorded clip into the chosen output.

    Args:
        transType: One of ``"Text"``, ``"Audio"`` or ``"GPT answer"``.
        language: Source language selected in the UI. It is not used here,
            because ``translate`` calls the Whisper pipeline without a language.
        audio: Path of the uploaded clip, or ``None``.
        audio_mic: Path of the microphone recording, or ``None``. Used when
            no file was uploaded.

    Returns:
        A ``(text, audio)`` pair for the two Interface outputs. ``audio`` is
        ``None`` unless ``transType`` is ``"Audio"``.

    Raises:
        gr.Error: If no audio was provided or no output format was chosen.
    """
    print("debug1:", audio, "debug2", audio_mic)
    # Prefer the uploaded file; otherwise fall back to the recording.
    if not audio and audio_mic:
        audio = audio_mic
    if not audio:
        raise gr.Error("Please upload or record an audio clip.")

    if transType == "Text":
        return translate(audio), None
    if transType == "GPT answer":
        question = translate(audio)
        # An explicit empty history keeps earlier conversations out of this reply.
        _, history, *_ = gpt_predict(question, history=[])
        # history is [question] or [question, answer]; an odd length means no
        # reply text arrived.
        answer = history[-1] if len(history) % 2 == 0 else ""
        return answer, None
    if transType == "Audio":
        return speech_to_speech_translation(audio)
    raise gr.Error("Please choose an output format.")
# Page text and the choices offered by the two radio buttons.
title = "Swedish STSOT (Speech To Speech Or Text)"
description="Use Whisper pretrained model to convert swedish audio to english (text or audio)"
supportLangs = ["Swedish", "French (in training)"]
transTypes = ["Text", "Audio", "GPT answer"]
# Example rows follow the input order [output format, language, upload, recording],
# e.g. ["Text", "Swedish", "./ex1.wav", None]; none are bundled at the moment.
examples = []

demo = gr.Interface(
    fn=predict,
    inputs=[
        # Default to "Text" so submitting without picking a format still
        # produces output instead of an empty result.
        gr.Radio(label="Choose your output format", choices=transTypes, value="Text"),
        gr.Radio(label="Choose a source language", choices=supportLangs, value="Swedish"),
        gr.Audio(label="Import an audio", sources="upload", type="filepath"),
        gr.Audio(label="Record an audio", sources="microphone", type="filepath"),
    ],
    outputs=[
        gr.Text(label="Text translation or gpt answer"),
        gr.Audio(label="Audio translation", type="numpy"),
    ],
    title=title,
    description=description,
    article="",
    examples=examples,
)
demo.launch()