# NOTE(review): "Spaces: Sleeping" below is a Hugging Face Space status badge
# captured by the page scrape, not part of the application source.
| from langchain_openai import ChatOpenAI | |
| from dotenv import load_dotenv | |
| import os | |
| from langchain.schema import AIMessage, HumanMessage | |
| import gradio as gr | |
| from transformers import pipeline, AutoTokenizer, TFAutoModelForSeq2SeqLM | |
| import subprocess | |
| import torch | |
| import tempfile | |
| from langdetect import detect | |
| from transformers import MarianMTModel, MarianTokenizer | |
| import boto3 | |
| # Additional imports for loading PDF documents and QA chain. | |
| from langchain_community.document_loaders import PyPDFLoader | |
| # Additional imports for loading Wikipedia content and QA chain | |
| from langchain_community.document_loaders import WikipediaLoader | |
| from langchain.chains.question_answering import load_qa_chain | |
| # Import RegEx for translate function, to split sentences in avoiding token limits | |
| import re | |
# Get keys #########################################################################################
# Load variables from a local .env file into the process environment.
load_dotenv()
# Set the model name for our LLMs.
OPENAI_MODEL = "gpt-3.5-turbo"
# Store the API key in a variable.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Define variables for AWS Polly Access#############################################################
# NOTE(review): os.getenv returns None when a variable is unset; the boto3
# client built from these would then fall back to its own credential chain --
# confirm the deployment environment always provides all three.
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
aws_default_region = os.getenv('AWS_DEFAULT_REGION')
# Define language variables ###################################################################################
# Map language codes to Amazon Polly neural voice IDs.
# NOTE(review): langdetect emits ISO 639-1 codes ('ja', 'tr', and 'zh-cn'/'zh-tw'
# for Chinese), but this map keys Japanese/Turkish as 'jap'/'trk' and Chinese as
# 'zh' -- detected text in those languages will not find a voice here. Confirm
# the intended code scheme before relying on detection for these languages.
voice_map = {
    "ar": "Hala",
    "en": "Gregory",
    "es": "Mia",
    "fr": "Liam",
    "de": "Vicki",
    "it": "Bianca",
    "zh": "Hiujin",
    "hi": "Kajal",
    "jap": "Tomoko",
    "trk": "Burcu"
}
# Define language map from full names (shown in the UI dropdown) to the codes
# used by voice_map and the Helsinki translation checkpoints.
language_map = {
    "Arabic (Gulf)": "ar",
    "Chinese (Cantonese)": "zh",
    "English": "en",
    "French": "fr",
    "German": "de",
    "Hindi": "hi",
    "Italian": "it",
    "Japanese": "jap",
    "Spanish": "es",
    "Turkish": "trk"
}
# Dropdown component listing the translation target languages by full name.
languages = gr.Dropdown(
    label="Click in the middle of the dropdown bar to select translation language",
    choices=list(language_map.keys()))
# Define default language used when the caller supplies no target.
default_language = "English"
# Setting the Chatbot Model #################################################################################
# Instantiating the llm we'll use and the arguments to pass.
# This is done at a global level, and not within the definition of a function, to improve
# the speed and efficiency of the app: the model is not re-instantiated every time
# a new question is submitted. A similar setup is used for all of the models called.
# temperature=0.0 keeps answers deterministic for the QA use case.
llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, model_name=OPENAI_MODEL, temperature=0.0)
# Define the wikipedia topic as a string.
wiki_topic = "diabetes"
# Load the wikipedia results as documents, using a max of 2.
# Error handling: if Wikipedia is unreachable, fall back to an empty corpus;
# handle_query() checks for this and reports it to the user.
try:
    documents = WikipediaLoader(query=wiki_topic, load_max_docs=2, load_all_available_meta=True).load()
except Exception as e:
    print("Failed to load documents:", str(e))
    documents = []
# Create the QA chain using the LLM.
chain = load_qa_chain(llm)
##############################################################################################################
# Answer a user's question against the preloaded Wikipedia documents.
def handle_query(user_query):
    """Run *user_query* through the QA chain over the cached documents.

    Returns the chain's answer text, "Goodbye!" for the literal input
    'quit', or an explanatory error string when the documents failed to
    load or the chain raises.
    """
    if not documents:
        return "Source not loading info; please try again later."
    if user_query.lower() == 'quit':
        return "Goodbye!"
    try:
        # Hand the cached documents plus the question to the QA chain.
        answer = chain.invoke(
            {"input_documents": documents, "question": user_query}
        )["output_text"]
    except Exception as err:
        return "An error occurred while searching for the answer: " + str(err)
    return answer if answer.strip() else "No answer found, try a different question."
# Language models and functions ############################################################################
# Cache of (tokenizer, model) pairs keyed by Helsinki-NLP checkpoint name, so
# each translation model is downloaded/instantiated at most once per process.
helsinki_model_cache = {}

def get_helsinki_model_and_tokenizer(src_lang, target_lang):
    """Return the (tokenizer, model) pair for the opus-mt checkpoint that
    translates *src_lang* into *target_lang*, loading and caching it on
    first use."""
    checkpoint = f"Helsinki-NLP/opus-mt-{src_lang}-{target_lang}"
    try:
        return helsinki_model_cache[checkpoint]
    except KeyError:
        # First request for this language pair: load once, then cache.
        pair = (
            MarianTokenizer.from_pretrained(checkpoint),
            MarianMTModel.from_pretrained(checkpoint),
        )
        helsinki_model_cache[checkpoint] = pair
        return pair
# Translate text into the specified language, sentence by sentence.
def translate(transcribed_text, target_lang="es"):
    """Translate *transcribed_text* into *target_lang* using a Helsinki-NLP
    MarianMT model, detecting the source language automatically.

    Returns the translated text, or the string
    "Error in transcription or translation" on any failure (e.g. language
    detection failing on empty input, or no model for the language pair).
    """
    try:
        src_lang = detect(transcribed_text)
        tokenizer, model = get_helsinki_model_and_tokenizer(src_lang, target_lang)
        max_length = tokenizer.model_max_length
        # Translate sentence-by-sentence so long inputs don't blow past the
        # model's token limit and come back truncated (observed in earlier
        # iterations of the app).
        sentences = re.split(r'(?<=[.!?]) +', transcribed_text)
        segments = []
        for sentence in sentences:
            # truncation=True caps each sentence at max_length, so the token
            # count can never exceed the limit here. (The original code also
            # checked tokens.size(1) > max_length and skipped -- that branch
            # was unreachable for this very reason and has been removed.)
            tokens = tokenizer.encode(
                sentence, return_tensors="pt", truncation=True, max_length=max_length
            )
            translated_tokens = model.generate(tokens)
            segments.append(
                tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
            )
        # join() instead of repeated += avoids quadratic string building.
        return " ".join(segments).strip()
    except Exception as e:
        print(f"An error occurred: {e}")
        return "Error in transcription or translation"
# Whisper ASR pipeline, created lazily on first use and shared across calls.
transcription_pipeline = None

def initialize_transcription_model():
    """Build the shared Whisper speech-recognition pipeline (idempotent:
    does nothing if it already exists)."""
    global transcription_pipeline
    if transcription_pipeline is None:
        transcription_pipeline = pipeline(
            "automatic-speech-recognition", model="openai/whisper-large"
        )
# Transcribe audio to text using Whisper, in the original spoken language.
def transcribe_audio_original(audio_filepath):
    """Transcribe the audio file at *audio_filepath* with the shared Whisper
    pipeline and return the recognized text, or "Error in transcription" on
    failure."""
    try:
        # Lazily build the shared pipeline on first call.
        if transcription_pipeline is None:
            initialize_transcription_model()
        transcription_result = transcription_pipeline(audio_filepath)
        return transcription_result['text']
    except Exception as e:
        # Fixed message typo/casing ("an error occured") to match the
        # file's other error handlers.
        print(f"An error occurred: {e}")
        return "Error in transcription"
# Initialize Polly client at module level, so every text-to-speech call
# reuses one connection instead of re-authenticating per request.
polly_client = boto3.client(
    'polly',
    region_name=aws_default_region,
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key
)
# Define text-to-speech function using Amazon Polly.
def polly_text_to_speech(text, lang_code):
    """Synthesize *text* with an Amazon Polly neural voice for *lang_code*.

    Returns the path of a temporary .mp3 file, or None when the language
    has no configured voice, the response carries no AudioStream, or Polly
    raises an error.
    """
    # Fix: voice_map[lang_code] raised an uncaught KeyError (not a
    # Boto3Error, so the except below never saw it) for codes that
    # langdetect actually emits but the map lacks, e.g. 'ja', 'tr', 'zh-cn'.
    voice_id = voice_map.get(lang_code)
    if voice_id is None:
        print(f"No Polly voice configured for language code '{lang_code}'")
        return None
    try:
        # Request speech synthesis.
        response = polly_client.synthesize_speech(
            Engine='neural',
            Text=text,
            OutputFormat='mp3',
            VoiceId=voice_id
        )
        # Save the audio to a temporary file and return its path.
        if "AudioStream" in response:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as audio_file:
                audio_file.write(response['AudioStream'].read())
                return audio_file.name
    except boto3.exceptions.Boto3Error as e:
        print(f"Error accessing Polly: {e}")
    # Explicit: an error, or a response without an AudioStream, yields no file.
    return None
# Route a typed or spoken question to the Wikipedia QA chain for the Gradio app.
def submit_question(audio_filepath=None, typed_text=None, target_lang=default_language):
    """Answer a question supplied as typed text or recorded audio.

    Typed text takes precedence when both inputs are given. Returns
    (answer_text, answer_audio_path); the audio slot holds the string
    "No audio available" when speech synthesis produced nothing.
    """
    if not audio_filepath and not typed_text:
        return "Please provide input by typing or speaking", None
    # Pick the query source: typed text wins; otherwise transcribe the audio.
    if typed_text:
        query_text = typed_text
    else:
        query_text = transcribe_audio_original(audio_filepath)
    # Detect the question's language so the answer is spoken in kind.
    detected_lang_code = detect(query_text)
    response_text = handle_query(query_text)
    response_speech = polly_text_to_speech(response_text, detected_lang_code)
    if not response_speech:
        response_speech = "No audio available"
    return response_text, response_speech
# Transcribe audio (or echo typed text) and voice it back with Polly.
def transcribe_and_speech(audio_filepath=None, typed_text=None, target_lang=default_language):
    """Return (text_or_None, speech_audio_path) for exactly one input source.

    Typed text is voiced back as-is (the text slot is None, since there is
    nothing to transcribe); recorded audio is transcribed with Whisper and
    then voiced. Supplying both inputs, or neither, returns an instruction
    message with no audio. *target_lang* is accepted for interface
    compatibility but is not used here (translation happens elsewhere).

    Fix: the original trailing block (target-language mapping and a final
    return) sat after every live path had already returned and was
    unreachable dead code; it has been removed with no behavior change.
    """
    if audio_filepath and typed_text:
        return "Please use only one input method at a time", None
    if typed_text:
        # Voice the typed text in its detected language; no transcription needed.
        detected_lang_code = detect(typed_text)
        return None, polly_text_to_speech(typed_text, detected_lang_code)
    if audio_filepath:
        # Transcribe first, then voice the transcription.
        query_text = transcribe_audio_original(audio_filepath)
        detected_lang_code = detect(query_text)
        return query_text, polly_text_to_speech(query_text, detected_lang_code)
    # Neither input was supplied.
    return "Please provide input by typing or speaking.", None
# Translate a response into the target language, as both text and audio.
def translate_and_speech(response_text=None, target_lang=default_language):
    """Translate *response_text* into *target_lang* and synthesize speech.

    Returns (translated_text, speech_audio_path_or_None). Translation is
    skipped when the detected language already matches the target.
    """
    # Detect the input language and map its code back to a display name.
    detected_lang_code = detect(response_text)
    # Fix: the original `[k for k, v in ... if v == code][0]` raised an
    # uncaught IndexError whenever the detected code was not a value of
    # language_map (e.g. 'pt', 'zh-cn'); next() with a None default keeps
    # going and simply forces a translation attempt instead.
    detected_lang = next(
        (name for name, code in language_map.items() if code == detected_lang_code),
        None,
    )
    # Check if the language is specified. Default to English if not.
    target_lang_code = language_map.get(target_lang, "en")
    # Skip the translation model when source and target already match.
    if detected_lang == target_lang:
        translated_response = response_text
    else:
        translated_response = translate(response_text, target_lang_code)
    # Convert the (possibly translated) text to speech.
    translated_speech = polly_text_to_speech(translated_response, target_lang_code)
    return translated_response, translated_speech
# Reset every Gradio component wired to this callback.
def clear_inputs():
    """Return six Nones, one per input/output component being cleared."""
    return (None,) * 6