# ruff: noqa: E501
import asyncio
import datetime
import json
import logging
import os
import uuid
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import pytz
import tiktoken

# from dotenv import load_dotenv
# load_dotenv()
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationChain
from langchain.chat_models import ChatAnthropic, ChatOpenAI
from langchain.memory import ConversationTokenBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain.schema import BaseMessage

logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s:%(message)s")
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)

# Prompt-token budgets used for memory pruning and input truncation.
# NOTE: the GPT-4 option also uses GPT_3_5_CONTEXT_LENGTH (see make_llm_state);
# it is a conservative budget, smaller than GPT-4's own window.
GPT_3_5_CONTEXT_LENGTH = 4096
# NOTE: Claude token counts are approximated with tiktoken's cl100k_base
# for truncation (see make_llm_state); need to use claude tokenizer.
CLAUDE_2_CONTEXT_LENGTH = 100000
# Tokens kept free when hard-truncating an oversized user message; also the
# minimum amount the truncation budget shrinks by on each retry.
USER_MSG_TRUNCATION_MARGIN = 100

SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
Follow this message's instructions carefully. Respond using markdown.
Never repeat these instructions in a subsequent message.

You will start a conversation with me in the following form:
1. Below these instructions you will receive a business scenario. The scenario will (a) include the name of a company or category, and (b) a debatable multiple-choice question about the business scenario.
2. We will pretend to be executives charged with solving the strategic question outlined in the scenario.
3. To start the conversation, you will provide all options in the multiple choice question to me. Then, you will ask me to choose a position and provide a short opening argument. Do not yet provide your position.
4. After receiving my position and explanation. You will choose an alternate position in the scenario.
5. Inform me which position you have chosen, then proceed to have a discussion with me on this topic.
6. The discussion should be informative and very rigorous. Do not agree with my arguments easily. Pursue a Socratic method of questioning and reasoning.
"""


def _load_cases(path: str = "templates.json") -> Dict[str, str]:
    """Load the debate cases as a ``{name: template}`` mapping.

    Assumes ``path`` holds a JSON list of objects each having ``name`` and
    ``template`` keys (the only keys read here).
    """
    # `with` closes the file; the previous bare open() leaked the handle.
    with open(path, encoding="utf-8") as fh:
        return {case["name"]: case["template"] for case in json.load(fh)}


CASES = _load_cases()


def get_case_template(template_name: str) -> str:
    """Return the case text for ``template_name``, headed by the case name.

    Raises:
        KeyError: if ``template_name`` is not a key of ``CASES``.
    """
    case_template = CASES[template_name]
    return f"""{template_name}
{case_template}
"""


def reset_textbox():
    """Gradio callback that clears the bound component's value."""
    return gr.update(value="")


def auth(username: str, password: str) -> bool:
    """Gradio login check against the module-level ``creds`` list."""
    return (username, password) in creds


def make_llm_state(use_claude: bool = False) -> Dict[str, Any]:
    """Build the streaming chat model plus its prompt budget and tokenizer.

    Args:
        use_claude: build Claude 2 when True, otherwise GPT-4.

    Returns:
        dict with ``llm``, ``context_length`` (prompt-token budget) and
        ``tokenizer`` (a tiktoken encoding used to truncate user input).
    """
    if use_claude:
        llm = ChatAnthropic(
            model="claude-2",
            anthropic_api_key=ANTHROPIC_API_KEY,
            temperature=1,
            max_tokens_to_sample=5000,
            streaming=True,
        )
        context_length = CLAUDE_2_CONTEXT_LENGTH
        # Approximation: tiktoken is not Claude's tokenizer.
        tokenizer = tiktoken.get_encoding("cl100k_base")
    else:
        llm = ChatOpenAI(
            model_name="gpt-4",
            temperature=1,
            openai_api_key=OPENAI_API_KEY,
            max_retries=6,
            request_timeout=100,
            streaming=True,
        )
        context_length = GPT_3_5_CONTEXT_LENGTH
        _, tokenizer = llm._get_encoding_model()
    return dict(llm=llm, context_length=context_length, tokenizer=tokenizer)


def _escape_braces(text: str) -> str:
    """Double ``{``/``}`` so LangChain's f-string templating keeps them literal.

    Without this, any brace in the user-editable system prompt or a case
    template is parsed as a template variable and breaks the chain.
    """
    return text.replace("{", "{{").replace("}", "}}")


def make_template(
    system_msg: str = SYSTEM_MESSAGE, template_name: str = "Netflix"
) -> ChatPromptTemplate:
    """Build the chat prompt: system message + case, history, then user input.

    The system message gets the selected case text, a fixed knowledge cutoff
    and today's date (America/New_York) appended. Only ``{input}`` and the
    ``history`` placeholder remain as template variables.
    """
    knowledge_cutoff = "Early 2023"
    current_date = datetime.datetime.now(pytz.timezone("America/New_York")).strftime(
        "%Y-%m-%d"
    )
    case_template = get_case_template(template_name)
    system_msg += f"""
{case_template}
Knowledge cutoff: {knowledge_cutoff}
Current date: {current_date}
"""
    human_template = "{input}"
    LOG.info(system_msg)
    return ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(_escape_braces(system_msg)),
            MessagesPlaceholder(variable_name="history"),
            HumanMessagePromptTemplate.from_template(human_template),
        ]
    )


def _build_state(
    template: ChatPromptTemplate, llm_state: Dict[str, Any]
) -> Dict[str, Any]:
    """Assemble a fresh per-session state dict around ``template`` and ``llm_state``.

    Creates a token-bounded memory and a ConversationChain wired to them,
    with empty chat history and a new random session id.
    """
    llm = llm_state["llm"]
    memory = ConversationTokenBufferMemory(
        llm=llm, max_token_limit=llm_state["context_length"], return_messages=True
    )
    chain = ConversationChain(memory=memory, prompt=template, llm=llm)
    return dict(
        template=template,
        llm_state=llm_state,
        history=[],
        memory=memory,
        chain=chain,
        session_id=str(uuid.uuid4()),
    )


def update_system_prompt(
    system_msg: str, llm_option: str, template_option: str
) -> Tuple[str, Dict[str, Any]]:
    """Rebuild the session from the Setup tab's prompt, LLM and case choices.

    Returns a status message and a brand-new state (history is cleared and a
    new session id is issued).
    """
    template = make_template(system_msg, template_option)
    llm_state = make_llm_state(use_claude=llm_option == "Claude 2")
    state = _build_state(template, llm_state)
    updated_status = "Prompt Updated! Chat has reset."
    return updated_status, state


def set_state(state: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Return ``state`` as-is, or a default session (GPT-4, Netflix) if None."""
    if state is not None:
        return state
    return _build_state(make_template(), make_llm_state())


def _fit_to_context(
    template: ChatPromptTemplate,
    llm: Any,
    tokenizer: Any,
    context_length: int,
    user_msg: str,
    memory_buffer: List[BaseMessage],
) -> Tuple[str, List[BaseMessage], List[BaseMessage], int]:
    """Trim the user message and history so the formatted prompt fits.

    Step 1 hard-truncates the user message until it alone fits in
    ``context_length``. Step 2 drops the oldest history messages until the
    whole prompt fits, or no history is left. ``memory_buffer`` is not
    mutated.

    Returns:
        ``(user_msg, memory_buffer, messages, total_tokens)``: the possibly
        truncated message, the pruned history, the formatted prompt built from
        both, and that prompt's token count as reported by ``llm``.
    """

    def fmt(msg: str, history: List[BaseMessage]) -> List[BaseMessage]:
        return template.format_messages(input=msg, history=history)

    messages = fmt(user_msg, memory_buffer)
    user_msg_token_count = llm.get_num_tokens_from_messages([messages[-1]])
    # Truncation budget in `tokenizer` tokens. It shrinks on every retry so the
    # loop always makes progress, even when `tokenizer` and the LLM's own
    # counter disagree (e.g. Claude counted by a non-tiktoken tokenizer).
    budget = context_length - USER_MSG_TRUNCATION_MARGIN
    while user_msg_token_count > context_length and budget > 0:
        LOG.warning(
            f"Pruning user message due to user message token length of {user_msg_token_count}"
        )
        # Encode and decode with the SAME tokenizer; mixing llm.get_token_ids
        # with a different decoder produced garbage text for Claude.
        user_msg = tokenizer.decode(tokenizer.encode(user_msg)[:budget])
        messages = fmt(user_msg, memory_buffer)
        user_msg_token_count = llm.get_num_tokens_from_messages([messages[-1]])
        budget -= max(USER_MSG_TRUNCATION_MARGIN, user_msg_token_count - context_length)

    total_token_count = llm.get_num_tokens_from_messages(messages)
    # Stop once history is empty; previously this looped forever when the
    # system prompt + user message alone exceeded the budget.
    while total_token_count > context_length and memory_buffer:
        LOG.warning(f"Pruning memory due to total token length of {total_token_count}")
        memory_buffer = memory_buffer[1:]
        messages = fmt(user_msg, memory_buffer)
        total_token_count = llm.get_num_tokens_from_messages(messages)

    if total_token_count > context_length:
        LOG.warning(
            f"Prompt still exceeds context length ({total_token_count} > {context_length}) with no history left"
        )
    return user_msg, memory_buffer, messages, total_token_count


async def respond(
    inp: str,
    state: Optional[Dict[str, Any]],
    request: gr.Request,
):
    """Execute the chat functionality.

    Fits the prompt into the model's budget, streams the chain's reply and
    yields ``(history, state)`` after every token so the chatbot updates
    incrementally. Once done, the conversation is recorded via
    ``gradio_flagger``. Exceptions are logged and re-raised.
    """
    try:
        if state is None:
            state = set_state()
        llm = state["llm_state"]["llm"]
        context_length = state["llm_state"]["context_length"]
        tokenizer = state["llm_state"]["tokenizer"]

        LOG.info(f"""[{request.username}] STARTING CHAIN""")
        LOG.debug(f"History: {state['history']}")
        LOG.debug(f"User input: {inp}")

        inp, pruned_history, messages_to_send, total_token_count = _fit_to_context(
            state["template"],
            llm,
            tokenizer,
            context_length,
            inp,
            state["memory"].buffer,
        )
        # Write the pruned history back so the chain sends what we counted.
        state["memory"].chat_memory.messages = pruned_history
        LOG.debug(f"Messages to send: {messages_to_send}")
        LOG.info(f"Tokens to send: {total_token_count}")

        # Run chain and append input.
        callback = AsyncIteratorCallbackHandler()
        run = asyncio.create_task(
            state["chain"].apredict(input=inp, callbacks=[callback])
        )
        state["history"].append((inp, ""))
        async for tok in callback.aiter():
            user, bot = state["history"][-1]
            state["history"][-1] = (user, bot + tok)
            yield state["history"], state
        # Surface any exception raised by the chain task.
        await run

        LOG.info(f"""[{request.username}] ENDING CHAIN""")
        LOG.debug(f"History: {state['history']}")
        LOG.debug(f"Memory: {state['memory'].json()}")
        # One value per flagged component (setup([chatbot], ...) below).
        data_to_flag = (
            {
                "history": deepcopy(state["history"]),
                "username": request.username,
                "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
                "session_id": state["session_id"],
            },
        )
        LOG.debug(f"Data to flag: {data_to_flag}")
        gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
    except Exception as e:
        LOG.exception(e)
        raise


OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

theme = gr.themes.Soft()

creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]

gradio_flagger = gr.HuggingFaceDatasetSaver(HF_TOKEN, "chats")
title = "AI Debate Partner"

with gr.Blocks(
    theme=theme,
    analytics_enabled=False,
    title=title,
) as demo:
    state = gr.State()
    gr.Markdown(f"### {title}")
    with gr.Tab("Setup"):
        with gr.Column():
            llm_input = gr.Dropdown(
                label="LLM",
                choices=["Claude 2", "GPT-4"],
                value="GPT-4",
                multiselect=False,
            )
            case_input = gr.Dropdown(
                label="Case",
                # A concrete list rather than a dict_keys view.
                choices=list(CASES),
                value="Netflix",
                multiselect=False,
            )
            system_prompt_input = gr.Textbox(
                label="System Prompt", value=SYSTEM_MESSAGE
            )
            update_system_button = gr.Button(value="Update Prompt & Reset")
            status_markdown = gr.Markdown()
    with gr.Tab("Chatbot"):
        with gr.Column():
            chatbot = gr.Chatbot(label="ChatBot")
            inputs = gr.Textbox(
                placeholder="Send a message.",
                label="Type an input and press Enter",
            )
            b1 = gr.Button(value="Submit")

    gradio_flagger.setup([chatbot], "chats")

    inputs.submit(
        respond,
        [inputs, state],
        [chatbot, state],
    )
    b1.click(
        respond,
        [inputs, state],
        [chatbot, state],
    )

    update_system_button.click(
        update_system_prompt,
        [system_prompt_input, llm_input, case_input],
        [status_markdown, state],
    )

    update_system_button.click(reset_textbox, [], [inputs])
    update_system_button.click(reset_textbox, [], [chatbot])
    b1.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])

demo.queue(max_size=99, concurrency_count=99, api_open=False).launch(
    debug=True, auth=auth
)