needs / app.py
ryanrwatkins's picture
Update app.py
8313f09
raw
history blame
No virus
7.76 kB
# Configure the git identity used when committing to the prompts archive repo.
# The original lines were IPython "!" shell escapes, which are a SyntaxError in
# a plain .py module; the same commands are issued through subprocess instead.
import subprocess

for _git_key, _git_value in (
    ("user.email", "rwatkins@gwu.edu"),
    ("user.name", "Ryan Watkins"),
):
    # check=False: a non-zero exit from git config should not stop app start-up.
    subprocess.run(
        ["git", "config", "--global", _git_key, _git_value],
        check=False,
    )
import gradio as gr
import openai
import requests
import csv
import os
import langchain
import chromadb
import glob
import pickle
import huggingface_hub
from huggingface_hub import Repository
from datetime import datetime
from PyPDF2 import PdfReader
from PyPDF2 import PdfWriter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.chains.question_answering import load_qa_chain
# --- Credentials -----------------------------------------------------------
# The OpenAI key is read from the "openai_key" environment variable (a missing
# variable raises KeyError at start-up). It is also re-exported under
# OPENAI_API_KEY — presumably the name langchain's OpenAI wrappers look for;
# confirm against the langchain version in use.
openai.api_key = os.environ['openai_key']
os.environ["OPENAI_API_KEY"] = os.environ['openai_key']

# --- Built-in expert ---------------------------------------------------------
# Expert name -> system prompt. download_prompt_templates() adds more entries
# at load time; this entry is the default and is kept first in the dropdown.
prompt_templates = {
    "All Needs Experts": "Respond as if you are combination of all needs assessment experts.",
}

# Expert name -> HTML/Markdown shown in the preview pane next to the dropdown.
actor_description = {
    "All Needs Experts": (
        "<div style='float: left;margin: 0px 5px 0px 5px;'>"
        "<img src='https://na.weshareresearch.com/wp-content/uploads/2023/04/experts2.jpg'"
        " alt='needs expert image' style='width:70px;align:top;'></div>"
        # Fixed user-visible typo: "combiation" -> "combination".
        "A combination of all needs assessment experts."
    ),
}

# --- Prompt archive ----------------------------------------------------------
# Every submitted prompt is appended to a CSV inside a local clone of this
# Hugging Face dataset repo and pushed back (see submit_message).
prompts_archive_url = "https://huggingface.co/datasets/ryanrwatkins/na_prompts_archive"
prompts_archive_file_name = "prompts_archives.csv"
prompts_archive_file = os.path.join("prompts_archive", prompts_archive_file_name)

# Write token for the dataset repo; None when the variable is not set.
HF_TOKEN = os.environ.get("HF_token_write")

# Clones the dataset repo into ./prompts_archive at import time.
repo = Repository(
    local_dir="prompts_archive",
    clone_from=prompts_archive_url,
    use_auth_token=HF_TOKEN,
)
def get_empty_state():
    """Build a fresh session state: a dict holding an empty "messages" list.

    A new dict and a new list are created on every call, so callers never
    share state with each other.
    """
    fresh_state = dict(messages=[])
    return fresh_state
def download_prompt_templates():
    """Download the expert list and build the dropdown update.

    Fetches ``gurus.txt`` from the Space repository and parses it as CSV,
    skipping the header row. For every row with at least three columns,
    column 0 (expert name) is mapped to column 1 in ``prompt_templates`` and
    to column 2 in ``actor_description``; surrounding double quotes are
    stripped from each value. Shorter rows are ignored.

    Returns:
        A ``gr.update`` for the expert dropdown: the first registered expert
        stays first, the rest are sorted alphabetically, and the first one is
        selected. If the download fails, the error is printed and the update
        is built from the experts already registered, so the dropdown is
        never left without a selection.
    """
    url = "https://huggingface.co/spaces/ryanrwatkins/needs/raw/main/gurus.txt"
    try:
        # A timeout keeps a stalled connection from hanging the page load;
        # raise_for_status() keeps an HTTP error page from being parsed as CSV.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"An error occurred while downloading prompt templates: {e}")
    else:
        reader = csv.reader(response.text.splitlines())
        next(reader, None)  # skip the header row; tolerate an empty body
        for row in reader:
            # All three columns are read below, so require all three
            # (the previous ">= 2" guard allowed an IndexError on row[2]).
            if len(row) < 3:
                continue
            act, prompt, description = (cell.strip('"') for cell in row[:3])
            prompt_templates[act] = prompt
            actor_description[act] = description

    choices = list(prompt_templates.keys())
    choices = choices[:1] + sorted(choices[1:])
    return gr.update(value=choices[0], choices=choices)
def on_prompt_template_change(prompt_template):
    """Return the prompt text registered for the selected expert.

    Non-string selections yield ``None``. A string that is not a key of
    ``prompt_templates`` raises ``KeyError``.
    """
    if isinstance(prompt_template, str):
        return prompt_templates[prompt_template]
    return None
def on_prompt_template_change_description(prompt_template):
    """Return the description markup registered for the selected expert.

    Non-string selections yield ``None``. A string that is not a key of
    ``actor_description`` raises ``KeyError``.
    """
    selection_is_name = isinstance(prompt_template, str)
    if not selection_is_name:
        return None
    description_markup = actor_description[prompt_template]
    return description_markup
def _history_to_chat_pairs(history):
    """Pair consecutive (user, assistant) messages into the tuples gr.Chatbot displays."""
    return [
        (history[i]['content'], history[i + 1]['content'])
        for i in range(0, len(history) - 1, 2)
    ]


def _archive_prompt(prompt):
    """Append the prompt and a timestamp to the archive CSV, then push the repo to the hub."""
    try:
        # newline="" is what the csv module asks for on files it writes to.
        with open(prompts_archive_file, "a", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=["prompt", "time"])
            writer.writerow({"prompt": str(prompt), "time": str(datetime.now())})
        commit_url = repo.push_to_hub()
        print(commit_url)
    except OSError as e:
        # Archiving is auxiliary: a file-system or git failure (expected to
        # surface as OSError — confirm for the huggingface_hub version in use)
        # is reported but must not prevent the user from getting an answer.
        print(f"An error occurred while archiving the prompt: {e}")


def submit_message(prompt, prompt_template, temperature, max_tokens, context_length, state):
    """Answer a user question and update the conversation.

    Args:
        prompt: Text the user typed; an empty value only refreshes the chat.
        prompt_template: Name of the selected expert (key of ``prompt_templates``).
            An unknown or missing selection means no system prompt is used.
        temperature: Sampling temperature passed to ``ChatOpenAI``.
        max_tokens: Response length limit passed to ``ChatOpenAI``.
        context_length: Number of previous question/answer turns included
            in the query.
        state: Session dict; ``state['messages']`` holds the conversation as
            alternating user/assistant message dicts and is updated in place.

    Returns:
        A 3-tuple for the (input_message, chatbot, state) outputs: the new
        textbox value, the list of (question, answer) pairs for the whole
        conversation, and the updated state.
    """
    history = state['messages']
    if not prompt:
        return gr.update(value=''), _history_to_chat_pairs(history), state

    # .get(): the dropdown can be empty (None) before templates have loaded.
    prompt_template = prompt_templates.get(prompt_template, "")

    _archive_prompt(prompt)

    system_prompt = []
    if prompt_template:
        system_prompt = [{"role": "system", "content": prompt_template}]
    prompt_msg = {"role": "user", "content": prompt}

    # NOTE(security): pickle.load can execute arbitrary code; embeddings.pkl
    # must only ever come from a trusted source.
    with open("embeddings.pkl", 'rb') as f:
        new_docsearch = pickle.load(f)

    # Keep only the most recent turns (two messages per turn). A plain
    # history[-0:] would select everything, hence the explicit zero case.
    turns = max(int(context_length), 0)
    recent_history = history[-2 * turns:] if turns else []

    query = str(system_prompt + recent_history + [prompt_msg])
    docs = new_docsearch.similarity_search(query)
    chain = load_qa_chain(
        ChatOpenAI(temperature=temperature, max_tokens=max_tokens, model_name="gpt-3.5-turbo"),
        chain_type="stuff",
    )
    completion = chain.run(input_documents=docs, question=query)

    # Record the exchange so the chat window shows the whole conversation and
    # later questions can draw on it (previously only the last turn survived).
    history.append(prompt_msg)
    history.append({"role": "assistant", "content": completion})
    state['content'] = completion  # kept: latest answer, as stored before

    return '', _history_to_chat_pairs(history), state
def clear_conversation():
    """Reset the UI for a new conversation.

    Returns one value for each output the "Start New Conversation" button is
    wired to, in order: an update that empties (and shows) the input textbox,
    ``None`` to clear the chatbot, and a fresh state dict.

    The previous version returned an extra ``""`` before the state — a
    leftover from a removed token-counter output — which put the state dict
    out of position for the three wired outputs.
    """
    return gr.update(value=None, visible=True), None, get_empty_state()
# Custom stylesheet handed to gr.Blocks(css=css). Several selectors target
# elem_id values assigned in the layout (col-container, chatbox, header,
# prompt_template_preview).
css = """
#col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
#chatbox {min-height: 400px;}
#header {text-align: center;}
#prompt_template_preview {padding: 1em; border-width: 1px; border-style: solid; border-color: #e0e0e0; border-radius: 4px; min-height: 150px;}
#total_tokens_str {text-align: right; font-size: 0.8em; color: #666;}
#label {font-size: 0.8em; padding: 0.5em; margin: 0;}
.message { font-size: 1.2em; }
"""
# Build the Gradio interface: chat column on the left, expert picker and
# advanced parameters on the right.
with gr.Blocks(css=css) as demo:
    # Per-session conversation state.
    state = gr.State(get_empty_state())

    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            "## Ask questions of *needs assessment* experts,\n"
            "## get responses from a *needs assessment experts* version of ChatGPT.\n"
            "Ask questions of all of them, or pick your expert below.",
            elem_id="header",
        )

        with gr.Row():
            with gr.Column():
                chatbot = gr.Chatbot(elem_id="chatbox")
                input_message = gr.Textbox(
                    show_label=False,
                    placeholder="Enter your needs assessment question",
                    visible=True,
                ).style(container=False)
                btn_submit = gr.Button("Submit")
                btn_clear_conversation = gr.Button("Start New Conversation")

            with gr.Column():
                prompt_template = gr.Dropdown(
                    label="Choose an Expert:",
                    choices=list(prompt_templates.keys()),
                )
                prompt_template_preview = gr.Markdown(elem_id="prompt_template_preview")
                with gr.Accordion("Advanced parameters", open=False):
                    temperature = gr.Slider(
                        minimum=0, maximum=2.0, value=0.7, step=0.1,
                        label="Flexibility",
                        info="Higher = More AI, Lower = More Expert",
                    )
                    max_tokens = gr.Slider(
                        minimum=100, maximum=400, value=200, step=1,
                        label="Length of Response.",
                    )
                    context_length = gr.Slider(
                        minimum=1, maximum=5, value=2, step=1,
                        label="Context Length",
                        info="Number of previous questions you have asked.",
                    )

    # Both the button and pressing Enter in the textbox submit the question.
    submit_inputs = [input_message, prompt_template, temperature, max_tokens, context_length, state]
    submit_outputs = [input_message, chatbot, state]
    btn_submit.click(submit_message, submit_inputs, submit_outputs)
    input_message.submit(submit_message, submit_inputs, submit_outputs)

    btn_clear_conversation.click(clear_conversation, [], [input_message, chatbot, state])

    # Show the selected expert's description in the preview pane.
    prompt_template.change(
        on_prompt_template_change_description,
        inputs=[prompt_template],
        outputs=[prompt_template_preview],
    )

    # Populate the dropdown from the remote expert list when the page loads.
    # Fixed keyword typo: "queur" -> "queue".
    demo.load(download_prompt_templates, inputs=None, outputs=[prompt_template], queue=False)

demo.queue(concurrency_count=10)
demo.launch(height='800px')