# diabetes_assistant/functions.py
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
import os
import gradio as gr
from transformers import pipeline, MarianMTModel, MarianTokenizer
import tempfile
from langdetect import detect
import boto3
# Additional imports for loading Wikipedia content and QA chain
from langchain_community.document_loaders import WikipediaLoader
from langchain.chains.question_answering import load_qa_chain
# Import RegEx for the translate function, to split sentences and avoid token limits
import re
# Get keys #########################################################################################
load_dotenv()
# Set the model name for our LLMs.
OPENAI_MODEL = "gpt-3.5-turbo"
# Store the API key in a variable.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Define variables for AWS Polly access ###########################################################
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
aws_default_region = os.getenv('AWS_DEFAULT_REGION')
# Define language variables ########################################################################
# Define voice map from language codes to Polly voice IDs.
# Note: "jap" and "trk" are app-specific keys; the ISO-639-1 codes are "ja" and "tr",
# which is what langdetect returns, so Japanese and Turkish input detected at runtime
# will not match these entries without normalization.
voice_map = {
    "ar": "Hala",
    "en": "Gregory",
    "es": "Mia",
    "fr": "Liam",
    "de": "Vicki",
    "it": "Bianca",
    "zh": "Hiujin",
    "hi": "Kajal",
    "jap": "Tomoko",
    "trk": "Burcu"
}
# Define language map from full names to the codes used above
language_map = {
    "Arabic (Gulf)": "ar",
    "Chinese (Cantonese)": "zh",
    "English": "en",
    "French": "fr",
    "German": "de",
    "Hindi": "hi",
    "Italian": "it",
    "Japanese": "jap",
    "Spanish": "es",
    "Turkish": "trk"
}
# List of languages and their codes for the dropdown
languages = gr.Dropdown(
    label="Click in the middle of the dropdown bar to select translation language",
    choices=list(language_map.keys())
)
# Define default language
default_language = "English"
# Setting the chatbot model ########################################################################
# Instantiate the LLM we'll use, along with the arguments to pass.
# This is done at module level, not inside a function definition, so the model is not
# re-instantiated every time a new question is submitted. The same setup is used for
# all of the models called below; this was part of our optimization process to make
# the app faster and more efficient.
llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, model_name=OPENAI_MODEL, temperature=0.0)
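# Example (illustrative): the module-level llm can be invoked directly; below it is
# instead wrapped in a QA chain over the loaded Wikipedia documents.
#   llm.invoke("Hello")  # returns an AIMessage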
# Define the wikipedia topic as a string.
wiki_topic = "diabetes"
# Load the Wikipedia results as documents, using a max of 2.
# Included error handling in case the documents fail to load.
try:
    documents = WikipediaLoader(query=wiki_topic, load_max_docs=2, load_all_available_meta=True).load()
except Exception as e:
    print("Failed to load documents:", str(e))
    documents = []
# Create the QA chain using the LLM.
chain = load_qa_chain(llm)
##############################################################################################################
# Define the function to call the OpenAI chat LLM
def handle_query(user_query):
    if not documents:
        return "Source not loading info; please try again later."
    if user_query.lower() == 'quit':
        return "Goodbye!"
    try:
        # Pass the documents and the user's query to the chain, and return the result.
        result = chain.invoke({"input_documents": documents, "question": user_query})
        return result["output_text"] if result["output_text"].strip() else "No answer found, try a different question."
    except Exception as e:
        return "An error occurred while searching for the answer: " + str(e)
# Language models and functions ####################################################################
# Set up a cache so translation models are initialized once at module level, improving app speed.
# The cache is keyed by Helsinki model name and holds (tokenizer, model) pairs.
helsinki_model_cache = {}
def get_helsinki_model_and_tokenizer(src_lang, target_lang):
    helsinki_model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{target_lang}"
    if helsinki_model_name not in helsinki_model_cache:
        tokenizer = MarianTokenizer.from_pretrained(helsinki_model_name)
        model = MarianMTModel.from_pretrained(helsinki_model_name)
        helsinki_model_cache[helsinki_model_name] = (tokenizer, model)
    return helsinki_model_cache[helsinki_model_name]
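# Example (illustrative): the first call downloads and caches the en-to-es model;
# repeat calls with the same language pair are served from the cache.
#   tokenizer, model = get_helsinki_model_and_tokenizer("en", "es")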
# Define function to translate text into the specified target language
def translate(transcribed_text, target_lang="es"):
    try:
        # Detect the source language, then load the matching model and tokenizer
        src_lang = detect(transcribed_text)
        tokenizer, model = get_helsinki_model_and_tokenizer(src_lang, target_lang)
        max_length = tokenizer.model_max_length
        # Split text on sentence endings to better manage translation segments.
        # In previous iterations of the app, some translations hit the model's max
        # number of tokens and the output was truncated; splitting was part of our
        # evaluation and optimization process.
        sentences = re.split(r'(?<=[.!?]) +', transcribed_text)
        full_translation = ""
        # Process each sentence individually
        for sentence in sentences:
            tokens = tokenizer.encode(sentence, return_tensors="pt", truncation=True, max_length=max_length)
            if tokens.size(1) > max_length:
                continue  # safety net; truncation above already caps the token count
            translated_tokens = model.generate(tokens)
            segment_translation = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
            full_translation += segment_translation + " "
        return full_translation.strip()
    except Exception as e:
        print(f"An error occurred: {e}")
        return "Error in transcription or translation"
# Initialize the Whisper model at module level so it can be reused across calls
transcription_pipeline = None
def initialize_transcription_model():
    global transcription_pipeline
    if transcription_pipeline is None:
        transcription_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-large")
# Define function to transcribe audio to text using Whisper, in the original spoken language
def transcribe_audio_original(audio_filepath):
    try:
        if transcription_pipeline is None:
            initialize_transcription_model()
        transcription_result = transcription_pipeline(audio_filepath)
        transcribed_text = transcription_result['text']
        return transcribed_text
    except Exception as e:
        print(f"An error occurred: {e}")
        return "Error in transcription"
# Initialize the Polly client at module level
polly_client = boto3.client(
    'polly',
    region_name=aws_default_region,
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key
)
# Define text-to-speech function using Amazon Polly
def polly_text_to_speech(text, lang_code):
    try:
        # Get the appropriate voice ID from the mapping
        voice_id = voice_map[lang_code]
        # Request speech synthesis
        response = polly_client.synthesize_speech(
            Engine='neural',
            Text=text,
            OutputFormat='mp3',
            VoiceId=voice_id
        )
        # Save the audio to a temporary file and return its path
        if "AudioStream" in response:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as audio_file:
                audio_file.write(response['AudioStream'].read())
            return audio_file.name
    except (KeyError, boto3.exceptions.Boto3Error) as e:
        # KeyError covers language codes with no mapped voice
        print(f"Error accessing Polly: {e}")
        return None  # Return None if there was an error
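# Example usage (illustrative; assumes valid AWS credentials were loaded above):
#   audio_path = polly_text_to_speech("Hola, ¿cómo estás?", "es")  # temp .mp3 path, or None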
# Define function to submit a query to the Wikipedia QA chain, to feed into the Gradio app
def submit_question(audio_filepath=None, typed_text=None, target_lang=default_language):
    # Determine source of text: audio transcription or direct text input
    if not audio_filepath and not typed_text:
        return "Please provide input by typing or speaking", None
    response_speech = None
    response_text = None
    if typed_text:
        # Submit typed text through the handle_query function
        detected_lang_code = detect(typed_text)
        response_text = handle_query(typed_text)
        response_speech = polly_text_to_speech(response_text, detected_lang_code)
    elif audio_filepath:
        # Transcribe audio to text in the background, then submit the query
        query_text = transcribe_audio_original(audio_filepath)
        detected_lang_code = detect(query_text)
        response_text = handle_query(query_text)
        response_speech = polly_text_to_speech(response_text, detected_lang_code)
    if not response_speech:
        response_speech = "No audio available"
    return response_text, response_speech
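# Example usage (illustrative):
#   answer_text, answer_speech = submit_question(typed_text="What is insulin?")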
# Define function to transcribe audio and return the input as both text and speech
def transcribe_and_speech(audio_filepath=None, typed_text=None, target_lang=default_language):
    # Determine source of text: audio transcription or direct text input
    if audio_filepath and typed_text:
        return "Please use only one input method at a time", None
    if typed_text:
        # Convert typed text to speech
        detected_lang_code = detect(typed_text)
        original_speech = polly_text_to_speech(typed_text, detected_lang_code)
        return None, original_speech
    if audio_filepath:
        # Transcribe audio to text, then synthesize the transcript back to speech
        query_text = transcribe_audio_original(audio_filepath)
        detected_lang_code = detect(query_text)
        original_speech = polly_text_to_speech(query_text, detected_lang_code)
        return query_text, original_speech
    # Neither input was provided
    return "Please provide input by typing or speaking.", None
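# Example usage (illustrative; the file path is hypothetical):
#   transcript, speech_path = transcribe_and_speech(audio_filepath="question.mp3")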
# Define function to translate a query into the target language, in text and audio
def translate_and_speech(response_text=None, target_lang=default_language):
    # Detect language of the input text
    detected_lang_code = detect(response_text)
    # Check if a target language is specified; default to English if not.
    target_lang_code = language_map.get(target_lang, "en")
    # Skip translation if the detected and target language codes are the same
    if detected_lang_code == target_lang_code:
        translated_response = response_text
    else:
        translated_response = translate(response_text, target_lang_code)
    # Convert the (possibly translated) text to speech
    translated_speech = polly_text_to_speech(translated_response, target_lang_code)
    return translated_response, translated_speech
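# Example usage (illustrative): translate an English answer into Spanish text and speech.
#   text_es, speech_es = translate_and_speech("Drink plenty of water.", target_lang="Spanish")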
# Function to clear out all inputs
def clear_inputs():
    return None, None, None, None, None, None
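# Minimal sketch (hypothetical; the real app wiring lives outside this module) of how
# these functions could be hooked into a Gradio Blocks UI. languages.render() places
# the module-level dropdown inside the layout.
#   with gr.Blocks() as app:
#       audio_in = gr.Audio(type="filepath")
#       text_in = gr.Textbox(label="Type your question")
#       languages.render()
#       answer_out = gr.Textbox(label="Answer")
#       speech_out = gr.Audio(label="Spoken answer")
#       gr.Button("Submit").click(
#           fn=submit_question,
#           inputs=[audio_in, text_in, languages],
#           outputs=[answer_out, speech_out],
#       )
#   app.launch()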