import os import time from pathlib import Path import openai import pandas as pd import streamlit as st from streamlit.elements.utils import _shown_default_value_warning _shown_default_value_warning = True # https://discuss.streamlit.io/t/why-do-default-values-cause-a-session-state-warning/15485/21 st.set_page_config(page_title="ChatGPT", page_icon="🌐") @st.cache_resource def init_openai_settings(): openai.api_key = os.getenv("OPENAI_API_KEY") if os.getenv("OPENAI_PROXY"): openai.proxy = os.getenv("OPENAI_PROXY") def init_session(): if not st.session_state.get("params"): st.session_state["params"] = dict() if not st.session_state.get("chats"): st.session_state["chats"] = {} if "input" not in st.session_state: st.session_state["input"] = "Hello, how are you?" def new_chat(chat_name): if not st.session_state["chats"].get(chat_name): st.session_state["chats"][chat_name] = { "answer": [], "question": [], "messages": [ {"role": "system", "content": st.session_state["params"]["prompt"]} ], "is_delete": False, "display_name": chat_name, } return chat_name def switch_chat(chat_name): if st.session_state.get("current_chat") != chat_name: st.session_state["current_chat"] = chat_name render_chat(chat_name) st.stop() def switch_chat_name(chat_name): if st.session_state.get("current_chat") != chat_name: st.session_state["current_chat"] = chat_name render_sidebar() render_chat(chat_name) st.stop() def delete_chat(chat_name): if chat_name in st.session_state['chats']: st.session_state['chats'][chat_name]['is_delete'] = True current_chats = [chat for chat, value in st.session_state['chats'].items() if not value['is_delete']] if len(current_chats) == 0: switch_chat(new_chat(f"Chat{len(st.session_state['chats'])}")) st.stop() if st.session_state["current_chat"] == chat_name: del st.session_state["current_chat"] switch_chat_name(current_chats[0]) def edit_chat(chat_name, zone): def edit(): if not st.session_state['edited_name']: print('name is empty!') return None if (st.session_state['edited_name'] != chat_name and st.session_state['edited_name'] in st.session_state['chats']): print('name is duplicated!') return None if st.session_state['edited_name'] == chat_name: print('name is not modified!') return None st.session_state['chats'][chat_name]['display_name'] = st.session_state['edited_name'] edit_zone = zone.empty() time.sleep(0.1) with edit_zone.container(): st.text_input('New Name', st.session_state['chats'][chat_name]['display_name'], key='edited_name') column1, _, column2 = st.columns([1, 5, 1]) column1.button('βœ…', on_click=edit) column2.button('❌') def render_sidebar_chat_management(zone): new_chat_button = zone.button(label="βž• New Chat", use_container_width=True) if new_chat_button: new_chat_name = f"Chat{len(st.session_state['chats'])}" st.session_state["current_chat"] = new_chat_name new_chat(new_chat_name) with st.sidebar.container(): for chat_name in st.session_state["chats"].keys(): if st.session_state['chats'][chat_name]['is_delete']: continue if chat_name == st.session_state.get('current_chat'): column1, column2, column3 = zone.columns([7, 1, 1]) column1.button( label='πŸ’¬ ' + st.session_state['chats'][chat_name]['display_name'], on_click=switch_chat_name, key=chat_name, args=(chat_name,), type='primary', use_container_width=True, ) column2.button(label='πŸ“', key='edit', on_click=edit_chat, args=(chat_name, zone)) column3.button(label='πŸ—‘οΈ', key='remove', on_click=delete_chat, args=(chat_name,)) else: zone.button( label='πŸ’¬ ' + st.session_state['chats'][chat_name]['display_name'], on_click=switch_chat_name, key=chat_name, args=(chat_name,), use_container_width=True, ) if new_chat_button: switch_chat(new_chat_name) def render_sidebar_gpt_config_tab(zone): st.session_state["params"] = dict() st.session_state["params"]["model"] = zone.selectbox( "Please select a model", ["gpt-3.5-turbo"], # , "gpt-4" help="ID of the model to use", ) st.session_state["params"]["temperature"] = zone.slider( "Temperature", min_value=0.0, max_value=2.0, value=1.2, step=0.1, format="%0.2f", help="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.", ) st.session_state["params"]["max_tokens"] = zone.slider( "Max Tokens", value=2000, step=1, min_value=100, max_value=2000, help="The maximum number of tokens to generate in the completion", ) st.session_state["params"]["stream"] = zone.checkbox( "Streaming output", value=True, help="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message", ) zone.caption('Looking for help at https://platform.openai.com/docs/api-reference/chat') def render_sidebar_prompt_config_tab(zone): prompt_text = zone.empty() st.session_state["params"]["prompt"] = prompt_text.text_area( "System Prompt", "You are a helpful assistant that translates answer from English to Chinese.", help="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.", ) template = zone.selectbox('Loading From Prompt Template', load_prompt_templates()) if template: prompts_df = load_prompts(template) actor = zone.selectbox('Loading Prompts', prompts_df.index.tolist()) if actor: st.session_state["params"]["prompt"] = prompt_text.text_area( "System Prompt", prompts_df.loc[actor].prompt, help="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.", ) def render_download_zone(zone): from io import BytesIO, StringIO if not st.session_state.get('current_chat'): return chat = st.session_state['chats'][st.session_state['current_chat']] col1, col2 = zone.columns([1, 1]) chat_messages = ['# ' + chat['display_name']] if chat["question"]: for i in range(len(chat["question"])): chat_messages.append(f"""πŸ˜ƒ **YOU:** {chat["question"][i]}""") if i < len(chat["answer"]): chat_messages.append(f"""πŸ€– **AI:** {chat["answer"][i]}""") col1.download_button('πŸ“€ Markdown', '\n'.join(chat_messages).encode('utf-8'), file_name=f"{chat['display_name']}.md", help="Download messages to a markdown file", use_container_width=True) tables = [] for answer in chat["answer"]: filter_table_str = '\n'.join([m.strip() for m in answer.split('\n') if m.strip().startswith('| ') or m == '']) try: tables.extend([pd.read_table(StringIO(filter_table_str.replace(' ', '')), sep='|').dropna(axis=1, how='all').iloc[1:] for m in filter_table_str.split('\n\n')]) except Exception as e: print(e) if tables: buffer = BytesIO() with pd.ExcelWriter(buffer) as writer: for index, table in enumerate(tables): table.to_excel(writer, sheet_name=str(index + 1), index=False) col2.download_button('πŸ“‰ Excel', buffer.getvalue(), file_name=f"{chat['display_name']}.xlsx", help="Download tables to a excel file", use_container_width=True) def render_sidebar(): st.sidebar.title("ChatGPT") chat_name_container = st.sidebar.container() chat_config_expander = st.sidebar.expander('βš™οΈ Chat configuration', True) tab_gpt, tab_prompt = chat_config_expander.tabs(['🌐 ChatGPT', 'πŸ‘₯ Prompt']) download_zone = st.sidebar.empty() github_zone = st.sidebar.empty() render_sidebar_gpt_config_tab(tab_gpt) render_sidebar_prompt_config_tab(tab_prompt) render_sidebar_chat_management(chat_name_container) render_download_zone(download_zone) render_github_info(github_zone) def render_user_message(message, zone): col1, col2 = zone.columns([1,8]) col1.markdown("πŸ˜ƒ **YOU:**") col2.markdown(message) def render_ai_message(message, zone): col1, col2 = zone.columns([1,8]) col1.markdown("πŸ€– **AI:**") col2.markdown(message) def render_history_answer(chat, zone): zone.empty() time.sleep(0.1) # https://github.com/streamlit/streamlit/issues/5044 with zone.container(): if chat['messages']: st.caption(f"""ℹ️ Prompt: {chat["messages"][0]['content']}""") if chat["question"]: for i in range(len(chat["question"])): render_user_message(chat["question"][i], st) if i < len(chat["answer"]): render_ai_message(chat["answer"][i], st) def render_last_answer(question, chat, zone): answer_zone = zone.empty() chat["messages"].append({"role": "user", "content": question}) chat["question"].append(question) if st.session_state["params"]["stream"]: answer = "" chat["answer"].append(answer) chat["messages"].append({"role": "assistant", "content": answer}) for response in get_openai_response(chat["messages"]): answer += response["choices"][0]['delta'].get("content", '') chat["answer"][-1] = answer chat["messages"][-1] = {"role": "assistant", "content": answer} render_ai_message(answer, answer_zone) else: with st.spinner("Wait for responding..."): answer = get_openai_response(chat["messages"])["choices"][0]["message"]["content"] chat["answer"].append(answer) chat["messages"].append({"role": "assistant", "content": answer}) render_ai_message(answer, zone) def render_stop_generate_button(zone): def stop(): st.session_state['regenerate'] = False zone.columns((2, 1, 2))[1].button('β–‘ Stop', on_click=stop) def render_regenerate_button(chat, zone): def regenerate(): chat["messages"].pop(-1) chat["messages"].pop(-1) chat["answer"].pop(-1) st.session_state['regenerate'] = True st.session_state['last_question'] = chat["question"].pop(-1) zone.columns((2, 1, 2))[1].button('πŸ”„Regenerate', type='primary', on_click=regenerate) def render_chat(chat_name): def handle_ask(): if st.session_state['input']: re_generate_zone.empty() render_user_message(st.session_state['input'], last_question_zone) render_stop_generate_button(stop_generate_zone) render_last_answer(st.session_state['input'], chat, last_answer_zone) st.session_state['input'] = '' if chat_name not in st.session_state["chats"]: st.error(f'{chat_name} is not exist') return chat = st.session_state["chats"][chat_name] if chat['is_delete']: st.error(f"{chat_name} is deleted") st.stop() if len(chat['messages']) == 1 and st.session_state["params"]["prompt"]: chat["messages"][0]['content'] = st.session_state["params"]["prompt"] conversation_zone = st.container() history_zone = conversation_zone.empty() last_question_zone = conversation_zone.empty() last_answer_zone = conversation_zone.empty() ask_form_zone = st.empty() render_history_answer(chat, history_zone) ask_form = ask_form_zone.form(chat_name) col1, col2 = ask_form.columns([10, 1]) col1.text_area("πŸ˜ƒ You: ", key="input", max_chars=2000, label_visibility='collapsed') with col2.container(): for _ in range(2): st.write('\n') st.form_submit_button("πŸš€", on_click=handle_ask) stop_generate_zone = conversation_zone.empty() re_generate_zone = conversation_zone.empty() if st.session_state.get('regenerate'): render_user_message(st.session_state['last_question'], last_question_zone) render_stop_generate_button(stop_generate_zone) render_last_answer(st.session_state['last_question'], chat, last_answer_zone) st.session_state['regenerate'] = False if chat["answer"]: stop_generate_zone.empty() render_regenerate_button(chat, re_generate_zone) # render_footer() def render_footer(): st.markdown( "

Made with ❀️ by ChatGPT and StreamLit.

", unsafe_allow_html=True) st.markdown("", unsafe_allow_html=True) def render_github_info(zone): with zone.container(): for i in range(1): st.write("\n") st.markdown('' 'GitHub' '', unsafe_allow_html=True) def get_openai_response(messages): if st.session_state["params"]["model"] in {'gpt-3.5-turbo', 'gpt4'}: response = openai.ChatCompletion.create( model=st.session_state["params"]["model"], temperature=st.session_state["params"]["temperature"], messages=messages, stream=st.session_state["params"]["stream"], max_tokens=st.session_state["params"]["max_tokens"], ) else: raise NotImplementedError('Not implemented yet!') return response def load_prompt_templates(): path = Path(__file__).parent / "templates" return [f.name for f in path.glob("*.json")] def load_prompts(template_name): if template_name: path = Path(__file__).parent / "templates" / template_name return pd.read_json(path).drop_duplicates(subset='act').set_index('act') # act, prompt if __name__ == "__main__": print("---- page reloading ----") init_openai_settings() init_session() render_sidebar() if st.session_state.get("current_chat"): render_chat(st.session_state["current_chat"]) if len(st.session_state["chats"]) == 0: switch_chat(new_chat(f"Chat{len(st.session_state['chats'])}"))