|
import os |
|
os.environ['KMP_DUPLICATE_LIB_OK']='True' |
|
import streamlit as st |
|
|
|
|
|
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx |
|
from datetime import datetime |
|
import pandas as pd |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _record_chat(chat_content, answer, source_files):
    """Finalize one Q&A turn and append it to the session chat history.

    Mutates ``chat_content`` in place (adds answer/source/context/time) and
    appends it to ``st.session_state['chat_history']``, creating the list on
    the first turn.  Extracted because this exact sequence was duplicated on
    both answer-generation paths.
    """
    chat_content['answer'] = answer
    chat_content['source'] = source_files
    chat_content['context'] = st.session_state['context']
    chat_content['time'] = datetime.now().strftime("%d-%m-%Y %H:%M:%S")
    if st.session_state['chat_history']:
        st.session_state['chat_history'].append(chat_content)
    else:
        st.session_state['chat_history'] = [chat_content]
    print('------chat history-----', st.session_state['chat_history'])


def build_experimental_ui():
    """Render the experimental Streamlit UI for the selected task.

    Side effects only: draws sidebar/task widgets and mutates
    ``st.session_state`` (retriever, docs, context, chat_history, …).
    Relies on module-level collaborators defined elsewhere in this file
    (``models``, ``model_configs``, ``embeddings``, ``load_pdf_document``,
    ``get_retriever_from_text``, ``create_prompt``, ``truncate_context``,
    ``get_pdf_file_names``, ``get_feedback``, ``set_qa_instruction``,
    ``PDSCoverageChain``, ``pdf_docs``).
    """
    with st.sidebar:
        tabs = st.sidebar.selectbox('SELECT TASK', [
            "Question & Answer",
            "Question & Answer (Chat Mode)",
            "Transcript Intelligence",
        ])
        st.markdown('---')

        if tabs == 'Question & Answer':
            # NOTE(review): options lists are empty here — presumably filled
            # from the model/embeddings registries elsewhere; confirm.
            selected_model = st.selectbox("Select Model:", options=[], index=0)
            selected_embeddings = st.selectbox("Select Embeddings:", options=[], index=0)
            # Only this tab defines `strategy`; both strategy checks below sit
            # inside the same tab's branch, so no NameError on other tabs.
            strategy = ''
        elif tabs == 'Question & Answer (Chat Mode)':
            selected_model = st.selectbox("Select Model:", options=[], index=0)
            selected_embeddings = st.selectbox("Select Embeddings:", options=[], index=0)
        elif tabs == 'Transcript Intelligence':
            selected_model = st.selectbox("Select Model:", options=[], index=0)

    st.error("Disclaimer: All data processed in this application will be sent to OpenAI API based in the United States.")

    st.markdown('## ' + tabs)

    if tabs == 'Question & Answer':
        st.markdown('---')

        prompt = st.text_input('Input your prompt', disabled=False, key="text")

        questions_file = st.file_uploader('Upload a CSV file with questions', type=['csv'], accept_multiple_files=False)

        if questions_file:
            # BUG FIX: the original line had an unbalanced closing
            # parenthesis (`pd.read_csv(questions_file))`) — a SyntaxError.
            questions_df = pd.read_csv(questions_file)
            # NOTE(review): batch CSV handling appears truncated in this
            # revision; we bail out after ingesting the file — confirm the
            # intended follow-up processing.
            return

        chat_content = {}
        chat_content['question'] = st.session_state.query

        # BUG FIX: `disable_query` was read below but only assigned much
        # later in the function, raising NameError whenever
        # strategy == 'Without Chain-of-Thought'. Compute it up front.
        disable_query = st.session_state.query is None

        if strategy == 'Without Chain-of-Thought':
            instruction = st.text_area('Input your instruction (optional)', value=st.session_state['qa_instruction'], disabled=disable_query)
            with st.expander("Sample instruction"):
                sample_instruction = "Answer the question based on the context provided. Explain with reason in bullet points. Let's think step by step."
                button_sample_instruction = st.button(sample_instruction, key='instruction1', disabled=disable_query, on_click=set_qa_instruction, args=(sample_instruction,))

        # Submit is disabled until a query exists in session state.
        button_query = st.button('Submit', disabled=disable_query)

        if button_query:
            print('---- run query ----')
            print(f'model: {selected_model} embeddings: {selected_embeddings}')
            # Rebuild the retriever only when the embeddings choice changed.
            if selected_embeddings != st.session_state['selected_embeddings']:
                st.session_state['selected_embeddings'] = selected_embeddings
                texts = load_pdf_document(pdf_docs)
                st.session_state['retriever'] = get_retriever_from_text(texts, embeddings[selected_embeddings])

            st.session_state['docs'] = st.session_state['retriever'].get_relevant_documents(st.session_state.query)
            context = '\n\n'.join([doc.page_content for doc in st.session_state['docs']])
            st.session_state['context'] = context
            source_files = get_pdf_file_names(st.session_state['pdf_file'])

            if strategy == 'Without Chain-of-Thought':
                # Assemble the model-specific chat template, then truncate the
                # retrieved context so the prompt fits the model window.
                user_token = model_configs[selected_model]['USER_TOKEN']
                end_token = model_configs[selected_model]['END_TOKEN']
                assistant_token = model_configs[selected_model]['ASSISTANT_TOKEN']
                prompt_pattern, prompt = create_prompt(user_token, instruction, st.session_state.query, end_token, assistant_token, context)
                updated_context = truncate_context(prompt_pattern, context,
                                                   max_token_len=model_configs[selected_model]['MAX_TOKENS'],
                                                   max_new_token_length=model_configs[selected_model]['MAX_NEW_TOKEN_LENGTH'])
                updated_prompt = prompt_pattern.replace('{context}', updated_context)
                print(updated_prompt)

                with st.spinner():
                    answer = models[selected_model].generate([updated_prompt]).generations[0][0].text.strip()
                st.write(answer)
                _record_chat(chat_content, answer, source_files)
                # A prompt mismatch means truncation actually happened.
                if updated_prompt != prompt:
                    st.caption(f"Note: The context has been truncated to fit model max tokens of {model_configs[selected_model]['MAX_TOKENS']}. Original context contains {len(context.split())} words. Truncated context contains {len(updated_context.split())} words.")

            else:
                # Chain-of-thought path delegates prompt handling to the chain.
                chain = PDSCoverageChain()
                with st.spinner():
                    answer = chain.generate(models[selected_model], model_configs[selected_model], st.session_state.query, context)
                st.write(answer)
                _record_chat(chat_content, answer, source_files)

        if st.session_state['docs']:
            docs = st.session_state['docs']

            # Narrow columns for the feedback buttons, wide spacer on the right.
            col3, col4, col5, col6 = st.columns([0.2, 0.35, 0.65, 3.8])
            disable_query = st.session_state.query is None
            chat_history = st.session_state['chat_history']
            with col3:
                st.button(":thumbsup:", on_click=get_feedback, disabled=disable_query,
                          kwargs=dict(upvote=True, downvote=False,
                                      button='upvote'))
            with col4:
                st.button(":thumbsdown:", on_click=get_feedback, disabled=disable_query,
                          kwargs=dict(upvote=False, downvote=True,
                                      button='downvote'))

            with st.expander("References"):
                for doc in docs:
                    print('-------', doc)
                    # Escape '$' so Streamlit does not treat it as LaTeX;
                    # strip markdown bold markers from retrieved text.
                    st.write(doc.page_content.replace('\n', '\n\n').replace('$', '\\$').replace('**', ''))

            st.button("End Chat", on_click=get_feedback,
                      kwargs=dict(button='end-chat',
                                  chat_history=chat_history))

    else:
        st.info("Under Development")
|
|
|
build_experimental_ui() |