# Streamlit page: question answering over an earnings-call transcript.
#
# Flow of this script:
#   1. Configure the page and let the user pick an embedding model.
#   2. When the transcript data is present in st.session_state (presumably
#      stored by another page of the app -- verify against the caller),
#      generate question/answer pairs and list them in the sidebar.
#   3. Build a vector store and an agent, replay the stored chat history,
#      answer a newly entered question and tag the answer with a sentiment.
#
# Comments are used instead of a module docstring on purpose, so that no bare
# string literal sits at the top level of a Streamlit script.

import streamlit as st
# NOTE(review): generate_eval, create_vectorstore, create_memory_and_agent,
# AIMessage, HumanMessage, StreamlitCallbackHandler, send_feedback,
# gen_sentiment, gen_annotated_text, annotated_text and pd are not imported
# explicitly here; presumably they all come from this star import.
from functions import *
from langchain.chains import QAGenerationChain
import itertools

st.set_page_config(page_title="Earnings Question/Answering", page_icon="🔎")
st.sidebar.header("Semantic Search")
st.markdown("Earnings Semantic Search with LangChain, OpenAI & SBert")

# Used both as the first assistant message and as the chat input placeholder.
starter_message = "Ask me anything about the Earnings Call!"

st.markdown(
    """ """,
    unsafe_allow_html=True,
)

# Sidebar label -> embedding model identifier handed to create_vectorstore.
bi_enc_dict = {
    'mpnet-base-v2': "all-mpnet-base-v2",
    'instructor-base': 'hkunlp/instructor-base',
    'FlagEmbedding': 'BAAI/bge-base-en',
}

sbert_model_name = st.sidebar.selectbox(
    "Embedding Model", options=list(bi_enc_dict.keys()), key='sbox'
)

st.sidebar.markdown('Earnings QnA Generator')

# Not referenced anywhere else in this script as shown; kept unchanged.
chunk_size = 1000
overlap_size = 50

# Session-state entries this page reads. 'title' is part of the guard because
# it is read unconditionally below; without it a missing title raised a
# KeyError that the RuntimeError handler does not catch.
required_state_keys = ("sen_df", "earnings_passages", "title")

# Single definition of the hint shown when the page cannot run.
missing_input_message = 'Please ensure you have entered the YouTube URL or uploaded the Earnings Call file'

try:
    if all(key in st.session_state for key in required_state_keys):
        # Sentence-level dataframe stored in session state (not used further
        # in this script as shown).
        sen_df = st.session_state['sen_df']
        title = st.session_state['title']
        print(f'Earnings Call title: {title}')
        earnings_text = st.session_state['earnings_passages']

        # NOTE(review): the meaning of the arguments 10 and 3000 is defined by
        # generate_eval, which is not visible here.
        st.session_state.eval_set = generate_eval(earnings_text, 10, 3000)

        # Display the question-answer pairs in the sidebar. The text is built
        # with explicit newlines so that source indentation never leaks into
        # the rendered markdown.
        for i, qa_pair in enumerate(st.session_state.eval_set):
            st.sidebar.markdown(
                f"\nQuestion {i + 1}\n\n"
                f"{qa_pair['question']}\n\n"
                f"{qa_pair['answer']}\n\n",
                unsafe_allow_html=True,
            )

        embedding_model = bi_enc_dict[sbert_model_name]

        # NOTE(review): the original indentation was lost; both calls are
        # assumed to run while the spinner is displayed.
        with st.spinner(
            text=f"Loading {embedding_model} embedding model and creating vectorstore..."
        ):
            docsearch = create_vectorstore(earnings_text, title, embedding_model)
            memory, agent_executor = create_memory_and_agent(docsearch)

        # Start a fresh history on first visit or when the user clears it.
        if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
            st.session_state["messages"] = [AIMessage(content=starter_message)]

        # Re-render the stored conversation and copy every stored message
        # into the memory object returned above.
        for msg in st.session_state.messages:
            if isinstance(msg, AIMessage):
                st.chat_message("assistant").write(msg.content)
            elif isinstance(msg, HumanMessage):
                st.chat_message("user").write(msg.content)
            memory.chat_memory.add_message(msg)

        if user_question := st.chat_input(placeholder=starter_message):
            st.chat_message("user").write(user_question)

            with st.chat_message("assistant"):
                st_callback = StreamlitCallbackHandler(st.container())
                response = agent_executor(
                    {"input": user_question, "history": st.session_state.messages},
                    callbacks=[st_callback],
                    include_run_info=True,
                )
                answer = response["output"]

                st.session_state.messages.append(AIMessage(content=answer))
                st.write(answer)

                # Record the exchange in memory, then make the memory buffer
                # the stored history (this replaces the list appended to
                # just above).
                memory.save_context({"input": user_question}, response)
                st.session_state["messages"] = memory.buffer

                # The run id ties the feedback buttons to this agent run.
                run_id = response["__run"].run_id

                col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
                with col_text:
                    st.text("Feedback:")
                with col1:
                    st.button("👍", on_click=send_feedback, args=(run_id, 1))
                with col2:
                    st.button("👎", on_click=send_feedback, args=(run_id, 0))

                # NOTE(review): the original indentation was lost; the
                # expander is assumed to belong to the assistant message.
                with st.expander(label='Query Result with Sentiment Tag', expanded=True):
                    sentiment_label = gen_sentiment(answer)
                    df = pd.DataFrame.from_dict(
                        {'Text': [answer], 'Sentiment': [sentiment_label]}
                    )
                    text_annotations = gen_annotated_text(df)[0]
                    annotated_text(text_annotations)
    else:
        st.write(missing_input_message)

except RuntimeError:
    # Any RuntimeError raised above is reported with the same hint.
    st.write(missing_input_message)