# Hugging Face Spaces status at time of capture: Runtime error
import json
import os
from typing import Dict, Generator, List, Optional, Tuple

import gradio as gr
import openai
from langchain.chat_models import ChatOpenAI
# The chat history is persisted as JSON in a file next to this script.
parent = os.path.split(__file__)[0]
data_path = os.path.join(parent, 'data.json')
def create_formatted_messages( messages: List[dict]) -> List[Tuple[str, str]]: | |
''' | |
create formatted history | |
= | |
Create formatted history for gradio chatbot | |
''' | |
user_messages = [] | |
formatted_messages = [] | |
assistant_messages = [] | |
for message in messages: | |
if message['role'] == 'user': | |
user_messages.append(message['content']) | |
elif message['role'] == 'assistant': | |
assistant_messages.append(message['content']) | |
if user_messages and assistant_messages: | |
formatted_messages.append( | |
(''.join(user_messages), ''.join(assistant_messages)) | |
) | |
user_messages = [] | |
assistant_messages = [] | |
# append any remaining messages | |
if user_messages: | |
formatted_messages.append((''.join(user_messages), None)) | |
elif assistant_messages: | |
formatted_messages.append((None, ''.join(assistant_messages))) | |
return formatted_messages | |
class ChatGPT:
    '''
    Gradio callbacks for a streaming ChatGPT conversation.

    ``init_agent`` must run first: it sets ``self.openai`` (a LangChain
    ``ChatOpenAI``), whose ``client`` and ``model_name`` ``chat`` uses.
    The conversation is persisted as JSON to the module-level ``data_path``.
    '''

    # Sent as the first message of every request. It is deliberately not
    # stored in the history, so it is neither duplicated on every turn nor
    # pushed out of the request once the history window fills up.
    system_prompt = 'ChatDefense is available to assist you with your legal questions.'

    def chat(
        self, message: str, messages: List[Dict[str, str]], message_length: int
    ) -> Generator[
        Tuple[str, List[Dict[str, str]], List[Tuple[Optional[str], Optional[str]]]],
        None,
        None,
    ]:
        '''
        Send ``message`` to the model and stream the reply into the history.

        Args:
            message: the user's new message.
            messages: conversation history as role/content dicts; ``None``
                starts a new conversation. It is updated in place.
            message_length: number of previous user/assistant exchanges sent
                as context. Values below 1 (or empty) are treated as 1.

        Yields:
            ``('', messages, formatted)`` after each streamed chunk: an empty
            string that clears the input box, the updated history, and the
            history formatted for ``gr.Chatbot``.
        '''
        if messages is None:
            messages = []
        # Context: the most recent exchanges only. System entries that older
        # versions of this app stored in the history are left out, because
        # the prompt is prepended here exactly once.
        turns = max(1, int(message_length or 0))
        history = [m for m in messages if m.get('role') != 'system']
        request = (
            [{'role': 'system', 'content': self.system_prompt}]
            + history[-2 * turns:]
            + [{'role': 'user', 'content': message}]
        )

        messages.append({'role': 'user', 'content': message})
        # Placeholder for the reply, filled in as tokens stream in. It is not
        # part of ``request``, so the model never sees an empty assistant turn.
        messages.append({'role': 'assistant', 'content': ''})
        response_message = ''
        try:
            chat_generator = self.openai.client.create(
                messages=request,
                model=self.openai.model_name,
                stream=True,
            )
            for chunk in chat_generator:
                if 'choices' in chunk:
                    for choice in chunk['choices']:
                        if 'delta' in choice and 'content' in choice['delta']:
                            response_message += choice['delta']['content'] or ''
                            messages[-1]['content'] = response_message
                        if 'finish_reason' in choice and choice['finish_reason'] == 'stop':
                            break
                yield '', messages, create_formatted_messages(messages)
        except Exception as e:
            # Best effort: report the failure (e.g. init_agent not run yet, a
            # bad key, a network error) and keep the app running.
            print(str(e))
        finally:
            # Persist once per reply instead of once per token. This also runs
            # if the stream fails or the generator is closed early.
            self._save(messages)

    @staticmethod
    def _save(messages: List[Dict[str, str]]) -> None:
        '''Write ``messages`` to ``data_path`` as JSON, closing the file.'''
        # TODO database
        with open(data_path, 'w') as f:
            json.dump(messages, f)

    def clear(self):
        '''Erase the stored history; returns empty values for (chatbot, messages).'''
        self._save([])
        return [], []

    def init_agent(self, openai_api_key: str):
        '''
        Create the LangChain ``ChatOpenAI`` wrapper used by ``chat``.

        Returns a gradio update that reveals the message input row.
        '''
        # Drop surrounding whitespace that a pasted key may carry.
        self.openai = ChatOpenAI(openai_api_key=(openai_api_key or '').strip())
        return gr.update(visible=True)
chatgpt = ChatGPT()


def _load_history() -> list:
    '''
    Load the persisted chat history from ``data_path``.

    Returns an empty list when the file is missing or unreadable, is not
    valid JSON (e.g. left half-written), or does not hold a list. A damaged
    file therefore cannot stop the app from starting.
    '''
    try:
        with open(data_path) as f:
            history = json.load(f)
    except (OSError, ValueError):
        return []
    return history if isinstance(history, list) else []


with gr.Blocks() as iface:
    with gr.Row():
        openai_api_key_textbox = gr.Textbox(
            placeholder='Paste your OpenAI API key here to start ChatGPT(sk-...) and press Enter',
            label='OpenAI API key',
            type='password',
        )
        # Number of previous exchanges sent as context. A larger value can
        # give more accurate answers, but every request grows with it.
        message_length = gr.Number(value=5, label='History message length')
    data = _load_history()
    # NOTE(review): ``.style()`` is deprecated in late Gradio 3.x and removed
    # in Gradio 4. Pin gradio<4, or move these options into the constructors.
    chatbot = gr.Chatbot(value=create_formatted_messages(data)).style()
    with gr.Row(visible=False) as input_raws:
        with gr.Column(scale=.9):
            txt = gr.Textbox(placeholder='Enter message here', show_label=False).style(container=False)
        with gr.Column(scale=.1, min_width=0):
            clear = gr.Button(value='Clear')
    messages = gr.State(data)
    txt.submit(chatgpt.chat, [txt, messages, message_length], [txt, messages, chatbot])
    clear.click(chatgpt.clear, [], [chatbot, messages])
    # The input row stays hidden until a key has been submitted.
    openai_api_key_textbox.submit(chatgpt.init_agent, [openai_api_key_textbox], [input_raws])
iface.queue().launch()