# Spaces status: Runtime error
# for setting/extracting environment variables such as API keys
import os
# for cleaning up, downloading and extracting the pre-built vector store
import shutil
import zipfile
### 1. For Web Scraping
# for querying Financial Modelling Prep API
from urllib.request import urlopen
import json
### 2. For Converting Scraped Text Into a Vector Store of Chunked Documents
# for tokenizing texts and splitting them into chunks of documents
from langchain.text_splitter import RecursiveCharacterTextSplitter
# for turning documents into embeddings before putting them in vector store
from langchain.embeddings import HuggingFaceEmbeddings
# for vector store for documents
from langchain.vectorstores import Chroma
### 3. For Querying LLM
# for loading HuggingFace LLM models from the hub
#from langchain.llms import HuggingFaceHub
from langchain_community.llms import HuggingFaceEndpoint
# for querying LLM conveniently using the context
from langchain.chains.question_answering import load_qa_chain
### 4. For Gradio App UI
import gradio as gr
# Financial Modelling Prep API key (for the web-scraping step). Nothing in this
# script reads it, so use .get() instead of failing at startup when it is unset.
fmp_api_key = os.environ.get('FMP_API_KEY')

# Initialize the default model for embedding the tokenized texts; the transcript
# chunks are stored in this embedded form in the vector database.
hf_embeddings = HuggingFaceEmbeddings()

# Where the pre-built Chroma vector store is fetched from and unpacked to.
CHROMA_ZIP_URL = "https://github.com/damianboh/test_earnings_calls/raw/main/earnings_transcripts_chromadb.zip"
CHROMA_ZIP_PATH = "earnings_transcripts_chromadb.zip"
CHROMA_EXTRACT_DIR = "chromadb_earnings_transcripts_extracted"

# Remove any copy left over from a previous run so a stale or partially
# extracted store is never loaded.
if os.path.isdir(CHROMA_EXTRACT_DIR):
    shutil.rmtree(CHROMA_EXTRACT_DIR)
if os.path.exists(CHROMA_ZIP_PATH):
    os.remove(CHROMA_ZIP_PATH)

# Download and extract with the standard library instead of shelling out to
# wget/unzip. This does not depend on those binaries being installed, and a
# failed download or corrupt archive now raises here. Previously os.system's
# exit status was discarded, so such failures went unnoticed.
with urlopen(CHROMA_ZIP_URL) as response, open(CHROMA_ZIP_PATH, 'wb') as zip_file:
    shutil.copyfileobj(response, zip_file)
with zipfile.ZipFile(CHROMA_ZIP_PATH) as archive:
    archive.extractall(CHROMA_EXTRACT_DIR)

chroma_db = Chroma(
    persist_directory=os.path.join(CHROMA_EXTRACT_DIR, 'chromadb_earnings_transcripts'),
    embedding_function=hf_embeddings,
)
# Load the Hugging Face inference endpoint of an LLM model.
# Hub repo id of the LLM model we are using, feel free to try others!
model = "mistralai/Mistral-7B-Instruct-v0.1"
# This is an inference endpoint API from Hugging Face: the model is not run
# locally, it is run on Hugging Face. A Hub model id is passed as `repo_id`;
# `endpoint_url` is meant for the URL of a dedicated inference endpoint.
hf_llm = HuggingFaceEndpoint(
    repo_id=model,
    huggingfacehub_api_token=os.environ['HUGGINGFACEHUB_API_TOKEN'],
    task="text-generation",
    max_new_tokens=512,
)
def source_question_answer(query: str, vectorstore: Chroma = chroma_db, llm: HuggingFaceEndpoint = hf_llm):
    """
    Answer ``query`` from the earnings-call transcripts and return the sources used.

    Retrieves the 4 chunks most similar to ``query`` from ``vectorstore``, passes
    them ("stuff" chain) together with the question to ``llm``, and returns the
    answer alongside the retrieved chunks.

    Args:
        query: The user's question.
        vectorstore: Vector store searched for supporting chunks.
        llm: LLM used by the question-answering chain.

    Returns:
        A 9-tuple ``(answer, text_1, text_2, text_3, text_4,
        title_1, title_2, title_3, title_4)`` in the order the Gradio outputs
        expect. If fewer than 4 chunks are retrieved, the missing texts and
        titles are empty strings. A chunk whose metadata has no ``'title'``
        gets an empty title.
    """
    num_sources = 4
    input_docs = vectorstore.similarity_search(query, k=num_sources)[:num_sources]
    qa_chain = load_qa_chain(llm, chain_type="stuff")
    # Mistral instruct format wraps the instruction in [INST] ... [/INST].
    prompt = f"[INST]According to the earnings calls transcripts earlier, {query}[/INST]"
    response = qa_chain.run(input_documents=input_docs, question=prompt)

    source_texts = [doc.page_content for doc in input_docs]
    source_titles = [doc.metadata.get('title', '') for doc in input_docs]
    # Pad so the tuple always has exactly 9 entries; the UI binds 9 outputs.
    padding = [''] * (num_sources - len(input_docs))
    return (response, *source_texts, *padding, *source_titles, *padding)
# Gradio UI: a question box, the generated answer, and the four retrieved
# source chunks (each with its title) laid out as two rows of two columns.
with gr.Blocks() as app:
    with gr.Row():
        gr.HTML("<h1>Chat with Tesla 2023 Earnings Calls Transcripts</h1>")

    # Question input and the button that submits it.
    with gr.Row():
        query = gr.Textbox(
            "How is Tesla planning to expand?",
            placeholder="Enter question here...",
            label="Enter question",
        )
        btn = gr.Button("Ask Question")

    with gr.Row():
        gr.HTML("<h3>Answer</h3>")
    with gr.Row():
        answer = gr.Textbox(label="Answer")

    with gr.Row():
        gr.HTML("<h3>Sources Referenced from Tesla 2023 Earnings Calls Transcripts</h3>")

    # Sources 1 and 2.
    with gr.Row():
        with gr.Column():
            source_title_1 = gr.Markdown()
            source1 = gr.Textbox(label="Source Text 1")
        with gr.Column():
            source_title_2 = gr.Markdown()
            source2 = gr.Textbox(label="Source Text 2")

    # Sources 3 and 4.
    with gr.Row():
        with gr.Column():
            source_title_3 = gr.Markdown()
            source3 = gr.Textbox(label="Source Text 3")
        with gr.Column():
            source_title_4 = gr.Markdown()
            source4 = gr.Textbox(label="Source Text 4")

    # Must match the order of the tuple returned by source_question_answer:
    # the answer, the four source texts, then the four source titles.
    qa_outputs = [
        answer,
        source1, source2, source3, source4,
        source_title_1, source_title_2, source_title_3, source_title_4,
    ]
    # Pressing Enter in the textbox and clicking the button run the same handler.
    for trigger in (query.submit, btn.click):
        trigger(fn=source_question_answer, inputs=[query], outputs=qa_outputs)

app.launch()