needs / app.py
ryanrwatkins's picture
Update app.py
96637c9
raw
history blame
6.96 kB
import gradio as gr
import openai
import requests
import csv
import os
import langchain
import chromadb
import glob
import pickle
from PyPDF2 import PdfReader
from PyPDF2 import PdfWriter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.chains.question_answering import load_qa_chain
# Read the OpenAI key from the Space secret "openai_key" a single time and
# hand it to both the openai client and the OPENAI_API_KEY environment
# variable (presumably consumed by the LangChain OpenAI wrappers — confirm).
_openai_key = os.environ['openai_key']
openai.api_key = _openai_key
os.environ["OPENAI_API_KEY"] = _openai_key
# Built-in expert persona. download_prompt_templates() merges further entries
# from gurus.txt into both maps when the page loads.
# Maps expert name -> system prompt sent to the model for that expert.
prompt_templates = {"All Needs Experts": "Respond as if you are combination of all needs assessment experts."}
# Maps expert name -> HTML shown in the preview pane when that expert is selected.
actor_description = {"All Needs Experts": "<div style='float: left;margin: 0px 5px 0px 5px;'><img src='https://na.weshareresearch.com/wp-content/uploads/2023/04/experts2.jpg' alt='needs expert image' style='width:70px;align:top;'></div>A combination of all needs assessment experts."}
def get_empty_state():
    """Build a brand-new conversation state with no messages and zero tokens used.

    A new dict (and a new messages list) is created on every call, so callers
    never share mutable state.
    """
    fresh = dict(total_tokens=0)
    fresh["messages"] = []
    return fresh
def download_prompt_templates():
    """Fetch extra expert templates from the Space's gurus.txt and refresh the dropdown.

    After a header row, each CSV row is read as ``act, prompt[, description]``.
    Rows with fewer than two columns are skipped. Every valid row is merged into
    the module-level ``prompt_templates`` and ``actor_description`` dicts. A
    missing description becomes an empty preview.

    Returns:
        A ``gr.update`` that sets the dropdown choices and selects the first
        one. The built-in first entry stays first and the rest are sorted. If
        the download fails, the error is printed and the dropdown is still
        refreshed with the templates that are already loaded.
    """
    url = "https://huggingface.co/spaces/ryanrwatkins/needs/raw/main/gurus.txt"
    try:
        # A timeout keeps a hung request from blocking page load indefinitely;
        # raise_for_status stops an HTTP error page from being parsed as CSV.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"An error occurred while downloading prompt templates: {e}")
    else:
        reader = csv.reader(response.text.splitlines())
        next(reader, None)  # skip the header row; tolerates an empty body
        for row in reader:
            if len(row) < 2:
                continue
            act = row[0].strip('"')
            prompt_templates[act] = row[1].strip('"')
            # The description column is optional; without it the preview is empty.
            actor_description[act] = row[2].strip('"') if len(row) >= 3 else ""
    choices = list(prompt_templates.keys())
    choices = choices[:1] + sorted(choices[1:])
    return gr.update(value=choices[0], choices=choices)
def on_prompt_template_change(prompt_template):
    """Look up the system prompt for the chosen expert.

    Returns None when the dropdown value is not a string (e.g. nothing selected).
    """
    if isinstance(prompt_template, str):
        return prompt_templates[prompt_template]
    return None
def on_prompt_template_change_description(prompt_template):
    """Look up the preview HTML for the chosen expert; None for a non-string selection."""
    return actor_description[prompt_template] if isinstance(prompt_template, str) else None
def _history_to_pairs(history):
    """Pair consecutive user/assistant messages into Chatbot ``(question, answer)`` tuples."""
    return [(history[i]['content'], history[i + 1]['content']) for i in range(0, len(history) - 1, 2)]


def submit_message(prompt, prompt_template, temperature, max_tokens, context_length, state):
    """Answer *prompt* from the embedded document store and record the exchange.

    Args:
        prompt: The user's question. An empty value only re-renders the current chat.
        prompt_template: Key into ``prompt_templates`` selecting the expert persona.
        temperature: Sampling temperature passed to ``ChatOpenAI``.
        max_tokens: Maximum response length passed to ``ChatOpenAI``.
        context_length: Number of previous question/answer exchanges included
            in the retrieval query.
        state: Conversation state dict with ``"messages"`` and ``"total_tokens"``
            (as built by ``get_empty_state``). It is updated in place.

    Returns:
        A 4-tuple for ``[input_message, chatbot, total_tokens_str, state]``:
        the cleared textbox value, the whole conversation as (question, answer)
        pairs, the token-usage caption, and the updated state.
    """
    from langchain.callbacks import get_openai_callback

    history = state['messages']
    if not prompt:
        return gr.update(value=''), _history_to_pairs(history), f"Total tokens used: {state['total_tokens']}", state

    template_text = prompt_templates[prompt_template]
    system_prompt = [{"role": "system", "content": template_text}] if template_text else []
    prompt_msg = {"role": "user", "content": prompt}

    # Keep only the last ``context_length`` exchanges (two messages each).
    # Zero is guarded explicitly because history[-0:] would be the whole list.
    n_context = max(int(context_length), 0) * 2
    recent_history = history[-n_context:] if n_context else []

    # NOTE(review): pickle.load can execute arbitrary code from the file. This is
    # only safe if embeddings.pkl is a trusted artifact bundled with the app
    # (presumably so — confirm); never point it at user-supplied data.
    # The file is re-read on every call; consider caching it if it is large.
    with open("embeddings.pkl", 'rb') as f:
        new_docsearch = pickle.load(f)

    query = str(system_prompt + recent_history + [prompt_msg])
    docs = new_docsearch.similarity_search(query)
    llm = ChatOpenAI(temperature=temperature, max_tokens=max_tokens, model_name="gpt-3.5-turbo")
    chain = load_qa_chain(llm, chain_type="stuff")
    # chain.run returns a plain string with no usage info, so token counts
    # are collected through the OpenAI callback instead.
    with get_openai_callback() as cb:
        answer = chain.run(input_documents=docs, question=query)

    history.append(prompt_msg)
    history.append({"role": "assistant", "content": answer})
    state['total_tokens'] += cb.total_tokens

    total_tokens_used_msg = f"Total tokens used: {state['total_tokens']}"
    return '', _history_to_pairs(history), total_tokens_used_msg, state
def clear_conversation():
    """Reset the UI to a blank conversation.

    Returns values for ``[input_message, chatbot, total_tokens_str, state]``:
    an emptied (and visible) textbox, no chat history, an empty token caption,
    and a fresh state.
    """
    cleared_input = gr.update(value=None, visible=True)
    return cleared_input, None, "", get_empty_state()
# Custom stylesheet passed to gr.Blocks(css=...). Most of the #ids refer to
# elem_id values assigned to components in the Blocks layout of this file
# (col-container, chatbox, header, prompt_template_preview, total_tokens_str).
css = """
#col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
#chatbox {min-height: 400px;}
#header {text-align: center;}
#prompt_template_preview {padding: 1em; border-width: 1px; border-style: solid; border-color: #e0e0e0; border-radius: 4px; min-height: 150px;}
#total_tokens_str {text-align: right; font-size: 0.8em; color: #666;}
#label {font-size: 0.8em; padding: 0.5em; margin: 0;}
.message { font-size: 1.2em; }
"""
# Build the Gradio interface, wire its events, and start the server.
with gr.Blocks(css=css) as demo:
    state = gr.State(get_empty_state())
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            "## Ask questions of *needs assessment* experts,\n"
            "## get responses from a *needs assessment experts* version of ChatGPT.\n"
            "Ask questions of all of them, or pick your expert below.",
            elem_id="header",
        )
        with gr.Row():
            with gr.Column():
                chatbot = gr.Chatbot(elem_id="chatbox")
                input_message = gr.Textbox(show_label=False, placeholder="Enter your needs assessment question", visible=True).style(container=False)
                btn_submit = gr.Button("Submit")
                total_tokens_str = gr.Markdown(elem_id="total_tokens_str")
                btn_clear_conversation = gr.Button("Start New Conversation")
            with gr.Column():
                # Preselect the first expert so a submit before the on-load
                # refresh does not look up prompt_templates[None]; prefill its
                # preview because the initial value does not fire .change.
                expert_names = list(prompt_templates.keys())
                prompt_template = gr.Dropdown(label="Choose an Expert:", choices=expert_names, value=expert_names[0])
                prompt_template_preview = gr.Markdown(actor_description.get(expert_names[0], ""), elem_id="prompt_template_preview")
                with gr.Accordion("Advanced parameters", open=False):
                    temperature = gr.Slider(minimum=0, maximum=2.0, value=0.7, step=0.1, label="Flexibility", info="Higher = More AI, Lower = More Expert")
                    max_tokens = gr.Slider(minimum=100, maximum=400, value=200, step=1, label="Length of Response.")
                    context_length = gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Context Length", info="Number of previous questions you have asked.")

    submit_inputs = [input_message, prompt_template, temperature, max_tokens, context_length, state]
    submit_outputs = [input_message, chatbot, total_tokens_str, state]
    btn_submit.click(submit_message, submit_inputs, submit_outputs)
    input_message.submit(submit_message, submit_inputs, submit_outputs)
    btn_clear_conversation.click(clear_conversation, [], [input_message, chatbot, total_tokens_str, state])
    prompt_template.change(on_prompt_template_change_description, inputs=[prompt_template], outputs=[prompt_template_preview])
    # queue=False (previously misspelled "queur", which Blocks.load silently
    # accepted via **kwargs) lets the template download bypass the request queue.
    demo.load(download_prompt_templates, inputs=None, outputs=[prompt_template], queue=False)

demo.queue(concurrency_count=10)
demo.launch(height='800px')