import os
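# KMP_DUPLICATE_LIB_OK suppresses Intel OpenMP's "duplicate runtime" abort,
# which commonly occurs when faiss and torch each bundle their own libomp.
# It is a workaround rather than a fix and must be set before those imports.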
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import streamlit as st
#from streamlit_chat import message
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from datetime import datetime
import pandas as pd
# from services.qa_chat_mode import *   # chat-mode tab is still under development
# from services.transcripts import *    # transcript tab is still under development
# from langchain import HuggingFaceHub
from services.qna import *
from services.utils import *
from services.smart_prompt import PDSCoverageChain
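
# The services.* helpers are not included in this file. Inferred from the call
# sites below (an assumption, not their actual implementations), they provide:
#   load_pdf_document(pdf_docs)             -> list of text chunks
#   get_retriever_from_text(texts, emb)     -> vector-store retriever
#   get_pdf_file_names(pdf_docs)            -> list of source file names
#   create_prompt(user_token, instruction, query,
#                 end_token, assistant_token, context) -> (prompt_pattern, prompt)
#   truncate_context(prompt_pattern, context,
#                    max_token_len, max_new_token_length) -> trimmed context
#   set_qa_instruction(text), get_feedback(**kwargs)     -> session-state callbacks
# plus `models`, `model_configs`, and `embeddings` registries keyed by the
# sidebar selections.

# Illustrative only: a minimal sketch of the budget-based trimming that
# truncate_context presumably performs. The real helper may count tokens with
# the model's own tokenizer; this sketch approximates tokens with whitespace words.
def _truncate_context_sketch(prompt_pattern: str, context: str,
                             max_token_len: int, max_new_token_length: int) -> str:
    # Budget = model window, minus headroom reserved for generation, minus the
    # fixed prompt scaffolding (everything in the pattern except '{context}').
    scaffolding = prompt_pattern.replace('{context}', '')
    budget = max_token_len - max_new_token_length - len(scaffolding.split())
    words = context.split()
    return context if len(words) <= budget else ' '.join(words[:max(budget, 0)])
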
def build_experimental_ui():
with st.sidebar:
        tabs = st.selectbox('SELECT TASK', [
            "Question & Answer",
            "Question & Answer (Chat Mode)",
            "Transcript Intelligence",
        ])
st.markdown('---')
# Filters
        if tabs == 'Question & Answer':
            # NOTE: option lists are left empty in this snapshot; they are expected
            # to be filled from the models/embeddings registries.
            selected_model = st.selectbox("Select Model:", options=[], index=0)
            selected_embeddings = st.selectbox("Select Embeddings:", options=[], index=0)
            # strategy = st.selectbox("Select Strategy:", options=STRATEGY_OPTIONS, index=1, disabled=True)
            strategy = ''
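            # With strategy left as '', the 'Without Chain-of-Thought' branches below
            # never run, so answers always go through PDSCoverageChain.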
elif tabs=='Question & Answer (Chat Mode)':
selected_model = st.selectbox("Select Model:", options=[], index=0)
selected_embeddings = st.selectbox("Select Embeddings:", options=[], index=0)
elif tabs == 'Transcript Intelligence':
selected_model = st.selectbox("Select Model:", options=[], index=0)
# selected_embeddings = st.selectbox("Select Embeddings:", options=EMBEDDING_OPTIONS, index=0)
# claimnumber = st.selectbox("Select Claim:", options=CLAIM_OPTIONS, index=0, disabled=True)
st.error("Disclaimer: All data processed in this application will be sent to OpenAI API based in the United States.")
st.markdown('## '+tabs)
if tabs=='Question & Answer':
        pdf_docs = st.file_uploader('Upload a PDF file', type=['pdf'], accept_multiple_files=True)
        Process = st.button("Process", disabled=(pdf_docs == []))
        if Process and pdf_docs:
            # Re-embed only when the uploaded set of files actually changes.
            if pdf_docs != st.session_state['pdf_file']:
                st.session_state['pdf_file'] = pdf_docs
                with st.spinner('Creating embeddings...'):
                    texts = load_pdf_document(pdf_docs=pdf_docs)
                    retriever = get_retriever_from_text(texts, embeddings[selected_embeddings])
                    st.session_state['retriever'] = retriever
st.markdown('---')
        # Question & Answer
        # Allow queries only once a retriever has been built from an upload.
        if st.session_state['retriever'] is None:
            disable_query = True
        else:
            disable_query = False
        prompt = st.text_input('Input your prompt', disabled=disable_query, key="query")
questions_file = st.file_uploader('Upload a CSV file with questions', type=['csv'],accept_multiple_files=False)
        if questions_file:
            questions_df = pd.read_csv(questions_file)
            # Batch Q&A over the uploaded questions is not implemented yet; stop here.
            return
chat_content = {}
chat_content['question'] = st.session_state.query
if strategy=='Without Chain-of-Thought':
instruction = st.text_area('Input your instruction (optional)', value=st.session_state['qa_instruction'], disabled=disable_query)
with st.expander("Sample instruction"):
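                # Clicking the sample button stores the text through the
                # set_qa_instruction on_click callback; callbacks run before the
                # next rerun, so the text_area above picks the value up from
                # st.session_state['qa_instruction'].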
sample_instruction = "Answer the question based on the context provided. Explain with reason in bullet points. Let's think step by step."
button_sample_instruction = st.button(sample_instruction, key='instruction1', disabled=disable_query, on_click=set_qa_instruction, args=(sample_instruction,))
        # An empty text_input yields '' rather than None, so test for falsiness.
        disable = not st.session_state.get('query')
button_query = st.button('Submit', disabled=disable)
if button_query:
print('---- run query ----')
print(f'model: {selected_model} embeddings: {selected_embeddings}')
            if selected_embeddings != st.session_state['selected_embeddings']:
                # The embedding model changed: re-embed the documents and rebuild the retriever.
                st.session_state['selected_embeddings'] = selected_embeddings
                texts = load_pdf_document(pdf_docs=pdf_docs)
                st.session_state['retriever'] = get_retriever_from_text(texts, embeddings[selected_embeddings])
# qa = RetrievalQA.from_chain_type(llm=models[selected_model], chain_type="stuff",
# retriever=st.session_state['retriever'], return_source_documents=True)
st.session_state['docs'] = st.session_state['retriever'].get_relevant_documents(st.session_state.query)
context = '\n\n'.join([doc.page_content for doc in st.session_state['docs']])
st.session_state['context'] = context
source_files = get_pdf_file_names(st.session_state['pdf_file'])
#st.session_state['conversation']= get_conversation_chain(st.session_state['retriever'])
if strategy=='Without Chain-of-Thought':
user_token = model_configs[selected_model]['USER_TOKEN']
end_token = model_configs[selected_model]['END_TOKEN']
assistant_token = model_configs[selected_model]['ASSISTANT_TOKEN']
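                # create_prompt presumably interleaves the model's chat-template tokens
                # with the instruction, question, and a '{context}' placeholder, e.g.
                # (illustrative, not the actual template):
                #   "<user>{instruction}\nContext: {context}\nQuestion: {query}<end><assistant>"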
prompt_pattern, prompt = create_prompt(user_token, instruction, st.session_state.query, end_token, assistant_token, context)
updated_context = truncate_context(prompt_pattern, context,
max_token_len=model_configs[selected_model]['MAX_TOKENS'],
max_new_token_length=model_configs[selected_model]['MAX_NEW_TOKEN_LENGTH'])
updated_prompt = prompt_pattern.replace('{context}', updated_context)
print(updated_prompt)
with st.spinner():
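                    # LangChain's generate() returns an LLMResult; generations[0][0]
                    # is the first completion for the first (and only) prompt.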
answer = models[selected_model].generate([updated_prompt]).generations[0][0].text.strip()
st.write(answer)
chat_content['answer'] = answer
chat_content['source'] = source_files
chat_content['context']=st.session_state['context']
chat_content['time']=datetime.now().strftime("%d-%m-%Y %H:%M:%S")
if st.session_state['chat_history']:
st.session_state['chat_history'].append(chat_content)
else:
st.session_state['chat_history']=[chat_content]
print('------chat history-----',st.session_state['chat_history'])
if updated_prompt!=prompt:
st.caption(f"Note: The context has been truncated to fit model max tokens of {model_configs[selected_model]['MAX_TOKENS']}. Original context contains {len(context.split())} words. Truncated context contains {len(updated_context.split())} words.")
else:
chain = PDSCoverageChain()
with st.spinner():
answer = chain.generate(models[selected_model], model_configs[selected_model], st.session_state.query, context)
st.write(answer)
chat_content['answer'] = answer
chat_content['source'] = source_files
chat_content['context']=st.session_state['context']
chat_content['time']=datetime.now().strftime("%d-%m-%Y %H:%M:%S")
if st.session_state['chat_history']:
st.session_state['chat_history'].append(chat_content)
else:
st.session_state['chat_history']=[chat_content]
print('------chat history-----',st.session_state['chat_history'])
if st.session_state['docs']:
docs = st.session_state['docs']
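            # Two narrow columns hold the vote buttons; the wider col5/col6 are unused spacers.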
col3, col4, col5, col6 = st.columns([0.2,0.35, 0.65, 3.8])
            disable_query = not st.session_state.get('query')
chat_history = st.session_state['chat_history']
with col3:
st.button(":thumbsup:", on_click = get_feedback,disabled=disable_query,
kwargs=dict(upvote=True, downvote=False,
button='upvote'))
with col4:
st.button(":thumbsdown:", on_click = get_feedback,disabled=disable_query,
kwargs=dict(upvote=False, downvote=True,
button='downvote'))
with st.expander("References"):
for doc in docs:
print('-------',doc)
#st.markdown('###### Page {}'.format(doc.metadata['page']))
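                    # '\n' -> '\n\n' forces Markdown paragraph breaks, '$' is escaped so
                    # dollar amounts are not parsed as LaTeX, and '**' is stripped.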
                    st.write(doc.page_content.replace('\n', '\n\n').replace('$', '\\$').replace('**', ''))
st.button("End Chat", on_click = get_feedback,
kwargs=dict(button='end-chat',
chat_history=chat_history))
else:
st.info("Under Development")
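
# Entry point: Streamlit re-executes this module top-to-bottom on every
# interaction, so calling the builder at import time renders the UI.
# Launch with: streamlit run app.py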
build_experimental_ui()