# Chat with the FOMC meeting minutes: a Gradio app that answers questions
# against a FAISS vector store of Federal Reserve meeting minutes, filtering
# retrieved documents by the date mentioned in the user's query.
import json
import logging

import gradio as gr
from langchain import LLMChain, PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

from prompts import PROMPT_EXTRACT_DATE, PROMPT_FED_ANALYST
from filterminutes import search_with_filter

# ---------------- Load the sentence transformer and the vector store ---------------- #
# CPU-only embeddings; normalization disabled to match how the index was built.
model_name = 'sentence-transformers/all-mpnet-base-v2'
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)
# NOTE(review): assumes the 'MINUTES_FOMC_HISTORY' index directory sits next to
# this script and was built with the same embedding model — confirm.
vs = FAISS.load_local("MINUTES_FOMC_HISTORY", embeddings)

# ---------------- Build the prompt templates ---------------- #
# PROMPT_DATE extracts date filters from a query; PROMPT_ANALYST answers it.
PROMPT_DATE = PromptTemplate.from_template(PROMPT_EXTRACT_DATE)
PROMPT_ANALYST = PromptTemplate.from_template(PROMPT_FED_ANALYST)
# --------------------------define the qa chain for answering queries--------------------------# | |
def load_chains(open_ai_key):
    """Build the two LLM chains used to answer a query.

    Parameters
    ----------
    open_ai_key : str
        OpenAI API key, passed through to both ChatOpenAI instances.

    Returns
    -------
    tuple
        ``(date_extractor, fed_chain)`` — an LLMChain that extracts date
        filters from a query, and a 'stuff' QA chain that answers the query
        from the supplied documents.
    """
    # temperature=0 on both chains for deterministic extraction/answers.
    date_extractor = LLMChain(llm=ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo', openai_api_key=open_ai_key),
                              prompt=PROMPT_DATE)
    fed_chain = load_qa_chain(llm=ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0, openai_api_key=open_ai_key),
                              chain_type='stuff', prompt=PROMPT_ANALYST)
    return date_extractor, fed_chain
def get_chain(query, api_key):
    """Answer *query* using only documents matching the date it mentions.

    The query is first passed to the date extractor; if a date is found,
    similarity search is restricted to documents whose metadata matches it.
    Otherwise a plain (unfiltered) similarity search is used.

    Parameters
    ----------
    query : str
        Query to be answered.
    api_key : str
        OpenAI API key.

    Returns
    -------
    str
        Answer to the query.
    """
    date_extractor, fed_chain = load_chains(api_key)
    logging.info('Extracting the date in numeric format..')
    date_response = date_extractor.run(query)

    # The extractor returns the literal string 'False' when no date is found;
    # otherwise it should return a JSON object of metadata filters.
    filter_date = None
    if date_response != 'False':
        try:
            filter_date = json.loads(date_response)
        except json.JSONDecodeError:
            # Model output was not valid JSON — fall back to unfiltered search
            # rather than crashing the whole request.
            logging.warning('Unparseable date response %r; ignoring filter.', date_response)

    if filter_date is not None:
        logging.info('Date parameters retrieved: %s', filter_date)
        logging.info('Running the qa with filtered context..')
        filtered_context = search_with_filter(vs, query, init_k=200, step=300, target_k=7, filter_dict=filter_date)
        logging.info('%sMetadata for the documents to be used%s', 20 * '-', 20 * '-')
        for doc in filtered_context:
            logging.info(doc.metadata)
    else:
        logging.info('No date elements found. Running the qa without filtering can output incorrect results.')
        filtered_context = vs.similarity_search(query, k=7)

    # Cap at 7 documents to keep the 'stuff' chain within the context window.
    return fed_chain({'input_documents': filtered_context[:7], 'question': query})['output_text']
if __name__ == '__main__':
    # Two free-text inputs: the question itself and the caller's OpenAI key.
    query_box = gr.Textbox(lines=2, placeholder="Enter your query", label='Your query')
    key_box = gr.Textbox(lines=1, placeholder="Your OpenAI API key here", label='OpenAI Key')
    # Canned example queries shown beneath the interface.
    example_queries = [
        ['What was the economic outlook from the staff presented in the meeting '
         'of April 2009 with respect to labour market developments and industrial production?'],
        ['Who were the voting members present in the meeting on March 2010?'],
        ['How important was the pandemic of Covid-19 in the discussions during 2020?'],
        ['What was the impact of the oil crisis for the economic outlook during 1973?'],
    ]
    app = gr.Interface(
        fn=get_chain,
        inputs=[query_box, key_box],
        outputs=gr.Textbox(lines=1, label='Answer'),
        title='Chat with the FOMC meeting minutes',
        description='Query the public database in FRED from 1936-2023',
        examples=example_queries,
    )
    app.launch()