"""Voice/text-to-SQL assistant.

Takes a natural-language question (typed, or spoken and transcribed with a
fine-tuned Whisper model), selects the relevant database tables, generates a
MySQL query via a few-shot LLM chain, executes it, and rephrases the raw SQL
result as a natural-language answer.
"""

import io
import json
import os
import subprocess
import sys
from argparse import ArgumentParser
from operator import itemgetter
from typing import List

import numpy as np
import pandas as pd
import sounddevice as sd
import soundfile

# NOTE(review): several of these imports are unused below (AutoModelForCausalLM,
# AutoTokenizer, Llama*, HuggingFacePipeline, librosa, datasets). Kept so that
# any external tooling relying on this module's import side effects still works.
import librosa  # noqa: F401
import datasets  # noqa: F401
from transformers import (  # noqa: F401
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaForCausalLM,
    LlamaTokenizer,
    pipeline,
)

from langchain.chains import create_sql_query_chain
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain.memory import ChatMessageHistory
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFacePipeline  # noqa: F401
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

model_id = "avnishkanungo/whisper-small-dv"  # update with your model id
# Loaded once at import time; used by transcribe_speech().
pipe = pipeline("automatic-speech-recognition", model=model_id)


def select_table(desc_path):
    """Build a runnable mapping ``{"question": ...}`` -> list of relevant table names.

    Args:
        desc_path: Path to a CSV with ``Table`` and ``Description`` columns
            describing every table in the database.

    Returns:
        A LangChain runnable that asks the LLM which tables might be relevant
        to the user's question and returns their names as a list of strings.

    Note:
        Relies on the module-level ``llm`` created in the ``__main__`` block.
    """

    def get_table_details():
        """Render the table-description CSV as a single prompt-friendly string."""
        table_description = pd.read_csv(desc_path)
        table_details = ""
        for _, row in table_description.iterrows():
            table_details += (
                "Table Name:" + row["Table"] + "\n"
                + "Table Description:" + row["Description"] + "\n\n"
            )
        return table_details

    class Table(BaseModel):
        """Table in SQL database."""

        name: str = Field(description="Name of table in SQL database.")

    table_details_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{get_table_details()}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

    def get_tables(tables: List[Table]) -> List[str]:
        """Extract plain table-name strings from the pydantic extraction output."""
        return [table.name for table in tables]

    # The original code also built an identical extraction chain into an unused
    # ``table_chain`` variable; that dead duplicate has been removed.
    return (
        {"input": itemgetter("question")}
        | create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt)
        | get_tables
    )


def prompt_creation(example_path):
    """Assemble the few-shot chat prompt used for SQL generation.

    Args:
        example_path: Path to a JSON file with an ``examples`` list of
            ``{"input": ..., "query": ...}`` question/SQL pairs.

    Returns:
        A ChatPromptTemplate combining the system instructions, the two most
        semantically similar examples, the chat history, and the user input.
    """
    with open(example_path, "r") as file:
        data = json.load(file)
    examples = data["examples"]

    example_prompt = ChatPromptTemplate.from_messages(
        [
            ("human", "{input}\nSQLQuery:"),
            ("ai", "{query}"),
        ]
    )

    # Start from a clean Chroma collection so stale example embeddings from a
    # previous run are not mixed in.
    vectorstore = Chroma()
    vectorstore.delete_collection()
    example_selector = SemanticSimilarityExampleSelector.from_examples(
        examples,
        OpenAIEmbeddings(),
        vectorstore,
        k=2,
        input_keys=["input"],
    )
    few_shot_prompt = FewShotChatMessagePromptTemplate(
        example_prompt=example_prompt,
        example_selector=example_selector,
        input_variables=["input", "top_k"],
    )
    final_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                # Fixed typo: "specificed" -> "specified".
                "You are a MySQL expert. Given an input question, create a "
                "syntactically correct MySQL query to run. Unless otherwise "
                "specified.\n\nHere is the relevant table info: {table_info}\n\n"
                "Below are a number of examples of questions and their "
                "corresponding SQL queries.",
            ),
            few_shot_prompt,
            MessagesPlaceholder(variable_name="messages"),
            ("human", "{input}"),
        ]
    )
    # Debug aid: show which examples were selected for a sample question.
    print(few_shot_prompt.format(input="How many products are there?"))
    return final_prompt


def rephrase_answer():
    """Return a runnable that turns (question, query, result) into plain English.

    Relies on the module-level ``llm`` created in the ``__main__`` block.
    """
    answer_prompt = PromptTemplate.from_template(
        """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
    )
    return answer_prompt | llm | StrOutputParser()


def is_ffmpeg_installed():
    """Return True if the ``ffmpeg`` executable is available on PATH."""
    try:
        subprocess.run(
            ["ffmpeg", "-version"],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False


def install_ffmpeg():
    """Best-effort ffmpeg install for the current platform.

    Returns:
        True on (apparent) success, False when installation failed or must be
        done manually (Windows / unknown OS).
    """
    try:
        if sys.platform.startswith("linux"):
            subprocess.run(["sudo", "apt-get", "update"], check=True)
            subprocess.run(["sudo", "apt-get", "install", "-y", "ffmpeg"], check=True)
        elif sys.platform == "darwin":  # macOS
            subprocess.run(["/bin/bash", "-c", "brew install ffmpeg"], check=True)
        elif sys.platform == "win32":
            print(
                "Please download ffmpeg from https://ffmpeg.org/download.html "
                "and install it manually."
            )
            return False
        else:
            print("Unsupported OS. Please install ffmpeg manually.")
            return False
    except subprocess.CalledProcessError as e:
        print(f"Failed to install ffmpeg: {e}")
        return False
    return True


def transcribe_speech(filepath):
    """Transcribe audio with the module-level Whisper pipeline.

    Args:
        filepath: What the ASR pipeline accepts — a path to an audio file or a
            raw waveform (the ``__main__`` loop passes the numpy array returned
            by ``record_command``; the pipeline assumes its native sample rate).

    Returns:
        The transcribed text.
    """
    output = pipe(
        filepath,
        max_new_tokens=256,
        generate_kwargs={
            "task": "transcribe",
            "language": "english",
        },  # update with the language you've fine-tuned on
        chunk_length_s=30,
        batch_size=8,
    )
    return output["text"]


def record_command():
    """Record 8 seconds of mono 16 kHz audio from the default microphone.

    Returns:
        The recorded waveform as a float32 numpy array (round-tripped through
        an in-memory WAV buffer).
    """
    sample_rate = 16000  # Sample rate in Hz
    duration = 8  # Duration in seconds

    print("Recording...")
    audio = sd.rec(
        int(sample_rate * duration),
        samplerate=sample_rate,
        channels=1,
        dtype="float32",
    )
    sd.wait()  # Block until the recording is finished
    print("Recording finished")

    # Round-trip through an in-memory WAV buffer; reading a mono WAV back
    # yields a 1-D array, which is what the ASR pipeline expects.
    audio_buffer = io.BytesIO()
    soundfile.write(audio_buffer, audio, sample_rate, format="WAV")
    audio_buffer.seek(0)
    audio_data, sample_rate = soundfile.read(audio_buffer)

    # Optional: Save the audio to a file for verification
    # with open('recorded_audio.wav', 'wb') as f:
    #     f.write(audio_buffer.getbuffer())

    print("Audio saved to variable")
    return audio_data


def check_libportaudio_installed():
    """Return True if the libportaudio2 package appears to be installed.

    Bug fix: the original ran ``libportaudio2 -version``, but libportaudio2 is
    a shared library, not an executable, so the check always failed. We query
    the package manager instead (on non-dpkg systems this still returns False,
    matching the old behavior of falling through to the installer).
    """
    try:
        subprocess.run(
            ["dpkg", "-s", "libportaudio2"],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False


def install_libportaudio():
    """Best-effort portaudio install for the current platform.

    Returns:
        True on (apparent) success, False when installation failed or must be
        done manually (Windows / unknown OS).
    """
    try:
        if sys.platform.startswith("linux"):
            subprocess.run(["sudo", "apt-get", "update"], check=True)
            subprocess.run(
                ["sudo", "apt-get", "install", "-y", "libportaudio2"], check=True
            )
        elif sys.platform == "darwin":  # macOS
            subprocess.run(["/bin/bash", "-c", "brew install portaudio"], check=True)
        elif sys.platform == "win32":
            # Fixed copy-paste message that referred to ffmpeg.
            print(
                "Please download portaudio from http://www.portaudio.com/ "
                "and install it manually."
            )
            return False
        else:
            print("Unsupported OS. Please install portaudio manually.")
            return False
    except subprocess.CalledProcessError as e:
        print(f"Failed to install portaudio: {e}")
        return False
    return True


if __name__ == "__main__":
    # Please configure your DB credentials and paths of the files for few shot
    # learning and fine tuning.
    parser = ArgumentParser()
    parser.add_argument(
        "--example_path", type=str, default=os.getcwd() + "/few_shot_samples.json"
    )
    parser.add_argument(
        "--desc_path",
        type=str,
        default=os.getcwd() + "/database_table_descriptions.csv",
    )
    parser.add_argument("--db_user", type=str, default="root")
    parser.add_argument("--db_password", type=str, default="")
    parser.add_argument("--db_host", type=str, default="localhost")
    parser.add_argument("--db_name", type=str, default="classicmodels")
    parser.add_argument("--open_ai_key", type=str)
    args = parser.parse_args()

    db_user = args.db_user
    db_password = args.db_password
    db_host = args.db_host
    db_name = args.db_name

    db = SQLDatabase.from_uri(
        f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}"
    )
    # print(db.dialect)
    # print(db.get_usable_table_names())
    # print(db.table_info)

    os.environ["OPENAI_API_KEY"] = args.open_ai_key

    history = ChatMessageHistory()
    # Module-level on purpose: select_table() and rephrase_answer() read ``llm``.
    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

    final_prompt = prompt_creation(args.example_path)
    generate_query = create_sql_query_chain(llm, db, final_prompt)
    execute_query = QuerySQLDataBaseTool(db=db)

    # question -> relevant tables -> SQL -> execution -> natural-language answer
    chain = (
        RunnablePassthrough.assign(table_names_to_use=select_table(args.desc_path))
        | RunnablePassthrough.assign(query=generate_query).assign(
            result=itemgetter("query") | execute_query
        )
        | rephrase_answer()
    )

    if is_ffmpeg_installed():
        print("ffmpeg is already installed.")
    else:
        print("ffmpeg is not installed. Installing ffmpeg...")
        if install_ffmpeg():
            print("ffmpeg installation successful.")
        else:
            print("ffmpeg installation failed. Please install it manually.")

    if check_libportaudio_installed():
        print("libportaudio is already installed.")
    else:
        # Fixed copy-paste message that said "Installing ffmpeg...".
        print("libportaudio is not installed. Installing libportaudio...")
        if install_libportaudio():
            print("libportaudio installation successful.")
        else:
            print("libportaudio installation failed. Please install it manually.")

    valid_interface_type = ["audio", "text", "quit"]
    while True:
        interface_type = input("Please enter 'audio', 'text', or 'quit': ").strip().lower()
        if interface_type in valid_interface_type:
            if interface_type == "quit":
                print("Exiting the loop.")
                break
            elif interface_type == "text":
                print(f"You selected '{interface_type}'.")
                while True:
                    user_input = input(
                        "Enter a question for the DB (or type 'quit' to exit): "
                    )
                    if user_input.lower() == "quit":
                        break
                    output = chain.invoke(
                        {"question": user_input, "messages": history.messages}
                    )
                    history.add_user_message(user_input)
                    history.add_ai_message(output)
                    print(output)
            elif interface_type == "audio":
                print(f"You selected '{interface_type}'.")
                command = record_command()
                sql_query = transcribe_speech(command)
                print(sql_query)
                output = chain.invoke(
                    {"question": sql_query, "messages": history.messages}
                )
                history.add_user_message(sql_query)
                history.add_ai_message(output)
                print(output)
        else:
            print("Invalid input. Please try again.")