# -*- coding: utf-8 -*-
r"""
Streamlit front-end for an augmented-retrieval Q&A chat.

To run:
- activate the virtual environment
- streamlit run path\to\streamlit_app.py
"""
import logging
import os
import re
import sys
import time
import warnings
import shutil

from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
import openai
import pandas as pd
import streamlit as st
from st_aggrid import GridOptionsBuilder, AgGrid, GridUpdateMode, ColumnsAutoSizeMode
from streamlit_chat import message

from streamlit_langchain_chat.constants import *
from streamlit_langchain_chat.customized_langchain.llms import OpenAI, AzureOpenAI, AzureOpenAIChat
from streamlit_langchain_chat.dataset import Dataset

# Configure logger
logging.basicConfig(format="\n%(asctime)s\n%(message)s", level=logging.INFO, force=True)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

warnings.filterwarnings('ignore')

# Session-state keys and the factory producing their initial value.
# Streamlit re-executes this script on every interaction, so a key is only
# initialised when it is not already present.
_SESSION_DEFAULTS = {
    'generated': list,
    'past': list,
    'costs': list,
    'contexts': list,
    'chunks': list,
    'user_input': str,
    'dataset': lambda: None,
}
for _state_key, _state_factory in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_factory()

# Session-state lists that together make up one conversation.
_CHAT_HISTORY_KEYS = ('past', 'generated', 'contexts', 'chunks', 'costs')


def _env_is_set(name: str) -> bool:
    """Return True when environment variable *name* holds a non-empty value."""
    return bool(os.getenv(name, ''))


def _has_selected_files(uploaded_files_rows) -> bool:
    """Return True when the selected grid rows contain a real file path.

    An empty upload is represented by a single placeholder row whose
    'filepath' is the empty string (see ``request_pathname``).
    """
    return bool(uploaded_files_rows) and uploaded_files_rows[-1].get('filepath') != ""


def _has_urls(urls_df) -> bool:
    """Return True when *urls_df* has a 'urls' column with a truthy entry."""
    if 'urls' not in urls_df.columns:
        # e.g. the empty DataFrame used when the source is not "urls"
        return False
    return any(bool(url) for url in urls_df['urls'])


def check_api_keys() -> bool:
    """Return True when every API key needed to build the index is set.

    The OpenAI key is always required; the Pinecone key only when the
    selected index id is 2.
    """
    openai_key_ready = _env_is_set('OPENAI_API_KEY')
    if app.params['index_id'] == 2:
        pinecone_key_ready = _env_is_set('PINECONE_API_KEY')
    else:
        pinecone_key_ready = True
    return openai_key_ready and pinecone_key_ready


def check_combination_point() -> bool:
    """Return True when step 2 of the sidebar (model selection) is complete.

    Type id 1 needs an API key, an API base and a deployment id; type id 2
    needs an API key and an API base. Any other type id is never ready.
    """
    type_id = app.params['type_id']
    key_ready = _env_is_set('OPENAI_API_KEY')
    # .get(): the sidebar only fills these entries for the known type ids.
    api_base = app.params.get('api_base')
    if type_id == 1:
        return bool(key_ready and api_base and app.params.get('deployment_id'))
    if type_id == 2:
        return bool(key_ready and api_base)
    return False


def check_index() -> bool:
    """Return True when the chat can be used.

    That is the case when the stored dataset exposes a truthy
    ``index_docstore`` or when the selected source id is 4.
    """
    dataset = st.session_state['dataset']
    index_built = getattr(dataset, "index_docstore", False)
    without_source = app.params['source_id'] == 4
    return bool(index_built or without_source)


def check_index_point() -> bool:
    """Return True when step 3 of the sidebar (index store) is complete.

    Index id 2 additionally requires the Pinecone API key and environment.
    """
    index_id = app.params['index_id']
    if index_id == 2:
        pinecone_ready = _env_is_set('PINECONE_API_KEY') and _env_is_set('PINECONE_ENVIRONMENT')
    else:
        pinecone_ready = True
    return bool(index_id and pinecone_ready)


def check_params_point() -> bool:
    """Return True when step 4 of the sidebar (configuration) is valid."""
    max_sources = app.params['max_sources']
    temperature = app.params['temperature']
    return bool(max_sources and isinstance(temperature, float))


def check_source_point() -> bool:
    """Return True: step 1 of the sidebar always has a selection."""
    return True


def clear_chat_history():
    """Reset every conversation-related list in the session state."""
    for history_key in _CHAT_HISTORY_KEYS:
        st.session_state[history_key] = []


def clear_index():
    """Delete the index directory from disk and forget the dataset.

    Without a stored dataset, the "default" directory under TEMP_DIR is
    removed instead when it exists.
    """
    dataset = st.session_state['dataset']
    if dataset:
        # delete directory (with files)
        index_path = dataset.index_path
        if index_path.exists():
            shutil.rmtree(str(index_path))
        # update variable
        st.session_state['dataset'] = None
    elif (TEMP_DIR / "default").exists():
        shutil.rmtree(str(TEMP_DIR / "default"))


def check_sources() -> bool:
    """Return True when step 5 of the sidebar (sources) is complete.

    Ready means: only local files, only urls, or source id 4.
    """
    some_files = _has_selected_files(app.params['uploaded_files_rows'])
    some_urls = _has_urls(app.params['urls_df'])
    only_local_files = some_files and not some_urls
    only_urls = some_urls and not some_files
    return only_local_files or only_urls or (app.params['source_id'] == 4)


def collect_dataset_and_built_index():
    """Create the dataset from the chosen sources and build its index.

    Reads the sidebar parameters from ``app.params``, configures the
    ``openai`` module, adds every url and selected file to a new
    ``Dataset``, builds the index (index id 1 uses the faiss builder, any
    other id the pinecone builder) and stores the dataset in the session
    state.
    """
    start = time.time()
    uploaded_files_rows = app.params['uploaded_files_rows']
    urls_df = app.params['urls_df']
    type_id = app.params['type_id']
    temperature = app.params['temperature']
    index_id = app.params['index_id']
    api_base = app.params['api_base']
    deployment_id = app.params['deployment_id']

    openai.api_type = "azure" if type_id == 1 else "open_ai"
    openai.api_base = api_base
    openai.api_version = "2023-03-15-preview" if type_id == 1 else None

    if deployment_id != "text-davinci-003":
        llm = ChatOpenAI(
            temperature=temperature,
            max_tokens=512,
            deployment_id=deployment_id,
        )
    else:
        # This branch used to look the deployment name up through an
        # undefined `combination_id` (NameError). In this branch
        # deployment_id is exactly "text-davinci-003", so pass it directly.
        llm = OpenAI(
            temperature=temperature,
            max_tokens=512,
            deployment_id=deployment_id,
        )
    dataset = Dataset(llm=llm)

    # get url documents
    if _has_urls(urls_df):
        for _, url_row in urls_df.reset_index().iterrows():
            url = url_row.get('urls', '')
            citation = url_row.get('citation string', '')
            if not url:
                continue
            try:
                dataset.add(
                    url,
                    citation,
                    citation,
                    disable_check=True  # True to accept Japanese letters
                )
            except Exception:
                # Best effort: one bad url must not abort the whole build,
                # but the failure has to be visible in the log.
                logging.exception("Could not add url %r to the dataset", url)

    # get local documents
    if _has_selected_files(uploaded_files_rows):
        for uploaded_files_row in uploaded_files_rows:
            citation = uploaded_files_row.get('citation string')
            # A citation containing a comma is not used as key.
            key = citation if citation is not None and ',' not in citation else None
            dataset.add(
                uploaded_files_row.get('filepath'),
                citation,
                key=key,
                disable_check=True  # True to accept Japanese letters
            )

    openai_embeddings = OpenAIEmbeddings(
        document_model_name="text-embedding-ada-002",
        query_model_name="text-embedding-ada-002",
    )
    if index_id == 1:
        dataset._build_faiss_index(openai_embeddings)
    else:
        dataset._build_pinecone_index(openai_embeddings)
    st.session_state['dataset'] = dataset

    if OPERATING_MODE == "debug":
        print(f"time to collect dataset: {time.time() - start:.2f} [s]")


def configure_streamlit_and_page():
    """Apply the Streamlit page configuration and page-wide markup."""
    # Configure Streamlit page and state
    st.set_page_config(**ST_CONFIG)

    # Force responsive layout for columns also on mobile
    # NOTE(review): the markup string is empty in this copy of the file;
    # presumably a <style> block was lost - confirm against the upstream
    # source.
    st.write(
        """""",
        unsafe_allow_html=True,
    )


def get_answer():
    """Query the dataset with the current user input.

    Returns:
        The object returned by ``dataset.query``, or None when the query,
        the dataset, the type id or the index id is missing.
    """
    query = st.session_state['user_input']
    dataset = st.session_state['dataset']
    type_id = app.params['type_id']
    index_id = app.params['index_id']
    # NOTE(review): app.params['max_sources'] (Top-k) is never forwarded to
    # the query - confirm whether Dataset.query should receive it.
    if not (query and dataset and type_id and index_id):
        return None

    chat_history = list(zip(st.session_state['past'], st.session_state['generated']))
    marginal_relevance = index_id == 1
    start = time.time()
    openai_embeddings = OpenAIEmbeddings(
        document_model_name="text-embedding-ada-002",
        query_model_name="text-embedding-ada-002",
    )
    result = dataset.query(
        query,
        openai_embeddings,
        chat_history,
        marginal_relevance=marginal_relevance,  # if pinecone is used it must be False
    )
    if OPERATING_MODE == "debug":
        print(f"time to get answer: {time.time() - start:.2f} [s]")
        print("-" * 10)
    return result


def load_main_page():
    """
    Load the body of web.
    """
    # Streamlit      HTML    Markdown
    # st.title       <h1>    #
    # st.header      <h2>    ##
    # st.subheader   <h3>    ###
    st.markdown(f"## Augmented-Retrieval Q&A ChatGPT ({APP_VERSION})")
    validate_status()
    st.markdown(f"#### **Status**: {app.params['status']}")

    # hidden div with anchor
    # NOTE(review): the anchor/link markup on this page was missing from
    # this copy of the file and has been rebuilt; the element ids are an
    # assumption - confirm against the upstream source.
    st.markdown("<div id='linkto_top'></div>", unsafe_allow_html=True)
    col1, col2, col3 = st.columns(3)
    col1.button(label="clear index", type="primary", on_click=clear_index)
    col2.button(label="clear conversation", type="primary", on_click=clear_chat_history)
    col3.markdown("<a href='#linkto_bottom'>Link to bottom</a>", unsafe_allow_html=True)

    turns = zip(*(st.session_state[history_key] for history_key in _CHAT_HISTORY_KEYS))
    for i, (past, generated, context, chunks, costs) in enumerate(turns):
        message(past, is_user=True, key=str(i) + '_user')
        message(generated, key=str(i))
        with st.expander("See context"):
            st.write(context)
        with st.expander("See chunks"):
            st.write(chunks)
        with st.expander("See costs"):
            st.write(costs)

    st.text_input("You:",
                  key='user_input',
                  on_change=on_enter,
                  disabled=not check_index()
                  )

    st.markdown("<a href='#linkto_top'>Link to top</a>", unsafe_allow_html=True)
    # hidden div with anchor
    st.markdown("<div id='linkto_bottom'></div>", unsafe_allow_html=True)


def load_sidebar_page():
    """Render the sidebar and store every chosen option in ``app.params``."""
    st.sidebar.markdown("## Instructions")

    # ############ #
    # SOURCES TYPE #
    # ############ #
    st.sidebar.markdown("1. Select a source:")
    source_selected = st.sidebar.selectbox(
        "Choose the location of your info to give context to chatgpt",
        list(SOURCES_IDS))
    app.params['source_id'] = SOURCES_IDS.get(source_selected, None)

    # ##### #
    # MODEL #
    # ##### #
    st.sidebar.markdown("2. Select a model (LLM):")
    combination_selected = st.sidebar.selectbox(
        "Choose type: MSF Azure OpenAI and model / OpenAI",
        list(TYPE_IDS))
    app.params['type_id'] = TYPE_IDS.get(combination_selected, None)

    if app.params['type_id'] == 1:  # with AzureOpenAI endpoint
        # https://docs.streamlit.io/library/api-reference/widgets/st.text_input
        os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
            label="Enter Azure OpenAI API Key",
            type="password"
        ).strip()
        app.params['api_base'] = st.sidebar.text_input(
            label="Enter Azure API base",
            placeholder="https://.openai.azure.com/",
        ).strip()
        app.params['deployment_id'] = st.sidebar.text_input(
            label="Enter Azure deployment_id",
        ).strip()
    elif app.params['type_id'] == 2:  # with OpenAI endpoint
        os.environ['OPENAI_API_KEY'] = st.sidebar.text_input(
            label="Enter OpenAI API Key",
            placeholder="sk-...",
            type="password"
        ).strip()
        app.params['api_base'] = "https://api.openai.com/v1"
        app.params['deployment_id'] = None
    else:
        # Unknown type id: keep the entries defined so later readers of
        # app.params do not hit a KeyError.
        app.params['api_base'] = None
        app.params['deployment_id'] = None

    # ####### #
    # INDEXES #
    # ####### #
    st.sidebar.markdown("3. Select a index store:")
    index_selected = st.sidebar.selectbox(
        "Type of Index",
        list(INDEX_IDS))
    app.params['index_id'] = INDEX_IDS.get(index_selected, None)
    if app.params['index_id'] == 2:  # with pinecone
        os.environ['PINECONE_API_KEY'] = st.sidebar.text_input(
            label="Enter pinecone API Key",
            type="password"
        ).strip()
        os.environ['PINECONE_ENVIRONMENT'] = st.sidebar.text_input(
            label="Enter pinecone environment",
            placeholder="eu-west1-gcp",
        ).strip()

    # ############## #
    # CONFIGURATIONS #
    # ############## #
    st.sidebar.markdown("4. Choose configuration:")
    # https://docs.streamlit.io/library/api-reference/widgets/st.number_input
    max_sources = st.sidebar.number_input(
        label="Top-k: Number of chunks/sections (1-5)",
        step=1,
        format="%d",
        value=5,
        # enforce the range the label advertises
        min_value=1,
        max_value=5,
    )
    app.params['max_sources'] = max_sources
    temperature = st.sidebar.number_input(
        label="Temperature (0.0 – 1.0)",
        step=0.1,
        format="%f",
        value=0.0,
        min_value=0.0,
        max_value=1.0
    )
    app.params['temperature'] = round(temperature, 1)

    # ############## #
    # UPLOAD SOURCES #
    # ############## #
    app.params['uploaded_files_rows'] = []
    if app.params['source_id'] == 1:
        # https://docs.streamlit.io/library/api-reference/widgets/st.file_uploader
        # https://towardsdatascience.com/make-dataframes-interactive-in-streamlit-c3d0c4f84ccb
        st.sidebar.markdown("""5. Upload your local documents and modify citation strings (optional)""")
        uploaded_files = st.sidebar.file_uploader(
            "Choose files",
            accept_multiple_files=True,
            type=['pdf', 'PDF',
                  'txt', 'TXT',
                  'html',
                  'docx', 'DOCX',
                  'pptx', 'PPTX',
                  ],
        )
        uploaded_files_dataset = request_pathname(uploaded_files)
        uploaded_files_df = pd.DataFrame(
            uploaded_files_dataset,
            columns=['filepath', 'citation string'])
        uploaded_files_grid_options_builder = GridOptionsBuilder.from_dataframe(uploaded_files_df)
        # Pre-select every row unless the frame only holds the empty
        # placeholder row.
        if uploaded_files_df.iloc[-1, 0] != "":
            pre_selected_rows = list(range(uploaded_files_df.shape[0]))
        else:
            pre_selected_rows = []
        uploaded_files_grid_options_builder.configure_selection(
            selection_mode='multiple',
            pre_selected_rows=pre_selected_rows,
            use_checkbox=True,
        )
        uploaded_files_grid_options_builder.configure_column("citation string", editable=True)
        uploaded_files_grid_options_builder.configure_auto_height()
        uploaded_files_grid_options = uploaded_files_grid_options_builder.build()
        with st.sidebar:
            uploaded_files_ag_grid = AgGrid(
                uploaded_files_df,
                gridOptions=uploaded_files_grid_options,
                update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
            )
        app.params['uploaded_files_rows'] = uploaded_files_ag_grid["selected_rows"]

    app.params['urls_df'] = pd.DataFrame()
    if app.params['source_id'] == 3:
        st.sidebar.markdown("""5. Write some urls and modify citation strings if you want (to look prettier)""")
        # option 1: with streamlit version 1.20.0+
        # app.params['urls_df'] = st.sidebar.experimental_data_editor(
        #     pd.DataFrame([["", ""]], columns=['urls', 'citation string']),
        #     use_container_width=True,
        #     num_rows="dynamic",
        # )

        # option 2: with streamlit version 1.19.0
        urls_dataset = [["", ""],
                        ["", ""],
                        ["", ""],
                        ["", ""],
                        ["", ""]]
        urls_df = pd.DataFrame(
            urls_dataset,
            columns=['urls', 'citation string'])
        urls_grid_options_builder = GridOptionsBuilder.from_dataframe(urls_df)
        urls_grid_options_builder.configure_columns(['urls', 'citation string'], editable=True)
        urls_grid_options_builder.configure_auto_height()
        urls_grid_options = urls_grid_options_builder.build()
        with st.sidebar:
            urls_ag_grid = AgGrid(
                urls_df,
                gridOptions=urls_grid_options,
                update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
            )
        df = urls_ag_grid.data
        # keep only the rows where a url was typed
        df = df[df.urls != ""]
        app.params['urls_df'] = df

    if app.params['source_id'] in (1, 2, 3):
        st.sidebar.markdown("""6. Build an index where you can ask""")
        api_keys_ready = check_api_keys()
        source_ready = check_sources()
        enable_index_button = api_keys_ready and source_ready
        if st.sidebar.button("Build index", disabled=not enable_index_button):
            collect_dataset_and_built_index()


def main():
    """Render the whole page: configuration, sidebar, then main body."""
    configure_streamlit_and_page()
    load_sidebar_page()
    load_main_page()


def on_enter():
    """Answer the submitted question and append the turn to the history."""
    output = get_answer()
    if output:
        st.session_state.past.append(st.session_state['user_input'])
        st.session_state.generated.append(output.answer)
        st.session_state.contexts.append(output.context)
        st.session_state.chunks.append(output.chunks)
        st.session_state.costs.append(output.cost_str)
        st.session_state['user_input'] = ""


def request_pathname(files):
    """Save the uploaded files under TEMP_DIR and describe them as rows.

    Args:
        files: uploaded file objects exposing ``name`` and ``getbuffer()``.

    Returns:
        One ``[filepath, citation]`` row per file, where the path is
        relative to ROOT_DIR and the citation is the uploaded file name;
        ``[["", ""]]`` when there are no files.
    """
    if not files:
        return [["", ""]]

    # make sure the temporal directory exists
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    rows = []
    for file in files:
        # The name comes from the client: keep only its last component so
        # the file cannot be written outside TEMP_DIR.
        safe_name = os.path.basename(file.name)
        # relative path
        file_path = str((TEMP_DIR / safe_name).relative_to(ROOT_DIR))
        with open(file_path, "wb") as f:
            f.write(file.getbuffer())
        rows.append([file_path, file.name])
    return rows


def validate_status():
    """Store in ``app.params['status']`` the first sidebar step that is not ready.

    The steps are checked in sidebar order; when all of them pass the
    status is the "Ready" message.
    """
    checks = [
        (check_source_point(),
         "⚠️Review step 1 on the sidebar."),
        (check_combination_point(),
         "⚠️Review step 2 on the sidebar. API Keys or endpoint, ..."),
        (check_index_point(),
         "⚠️Review step 3 on the sidebar. Index API Key or environment."),
        (check_params_point(),
         "⚠️Review step 4 on the sidebar"),
        (check_sources(),
         "⚠️Review step 5 on the sidebar. Waiting for some source..."),
        (check_index(),
         "⚠️Review step 6 on the sidebar. Waiting for press button to create index ..."),
    ]
    status = "✨Ready✨"
    for is_ready, pending_message in checks:
        if not is_ready:
            status = pending_message
            break
    app.params['status'] = status


class StreamlitLangchainChatApp():
    """Application object holding the parameters chosen on the sidebar."""

    def __init__(self) -> None:
        """Use __init__ to define instance variables. It cannot have any arguments."""
        # option name -> value selected on the sidebar
        self.params = dict()

    def run(self, **state) -> None:
        """Define here all logic required by your application."""
        main()


if __name__ == "__main__":
    # The functions above read the module-level name `app`.
    app = StreamlitLangchainChatApp()
    app.run()