needs / app.py
ryanrwatkins's picture
Update app.py
57ff103 verified
raw
history blame
No virus
10.5 kB
import gradio as gr
import openai
import requests
import csv
import os
import langchain
import chromadb
import glob
import pickle
import huggingface_hub
from huggingface_hub import Repository
from datetime import datetime
from PyPDF2 import PdfReader
from PyPDF2 import PdfWriter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.chains.question_answering import load_qa_chain
from langchain_google_genai import ChatGoogleGenerativeAI
# Turned off because people were using it in unethical ways.
openai.api_key = os.environ['openai_key']
os.environ["OPENAI_API_KEY"] = os.environ['openai_key']
# No `gemini` module is imported in this file, so the former
# `gemini.api_key = ...` assignment raised NameError at import time.
# The Gemini key is only exported through the environment.
os.environ["GEMINI_API_KEY"] = os.environ['gemini_key']

# Built-in expert, available even before the remote expert list is downloaded.
# Both dicts are keyed by the expert's display name.
prompt_templates = {"All Needs Experts": "Respond as if you are combination of all needs assessment experts."}
actor_description = {"All Needs Experts": "<div style='float: left;margin: 0px 5px 0px 5px;'><img src='https://na.weshareresearch.com/wp-content/uploads/2023/04/experts2.jpg' alt='needs expert image' style='width:70px;align:top;'></div>A combiation of all needs assessment experts."}
#repo_url = create_repo(repo_id="prompts_archive")
#prompts_archive_url = "https://huggingface.co/datasets/ryanrwatkins/prompts_archive"
#prompts_archive_file_name = "prompts_archive.csv"
#prompts_archive_file = os.path.join("prompts_archive", prompts_archive_file_name)
#print(prompts_archive_file)
#HF_TOKEN = os.environ.get("HF_token_write")
#repo = Repository(
# local_dir="prompts_archive", clone_from=repo_url, use_auth_token=HF_TOKEN, git_user="ryanrwatkins", git_email="rwatkins@gwu.edu"
#)
def get_empty_state():
    """Build the initial per-session chat state.

    Returns:
        A new dict whose "messages" entry is a new, empty list.
    """
    fresh_state = dict(messages=list())
    return fresh_state
def download_prompt_templates():
    """Download the expert list and refresh the expert dropdown choices.

    The remote file is parsed as CSV with a header row; each data row is read
    as (expert name, system prompt, description). Parsed rows are merged into
    the module-level ``prompt_templates`` and ``actor_description`` dicts.

    Returns:
        A ``gr.update`` for the dropdown: the built-in first entry stays first
        and selected, followed by the remaining experts in alphabetical order.
        If the download fails, the error is printed and the update is built
        from whatever experts are already known, so the dropdown is never
        handed ``None``.
    """
    url = "https://huggingface.co/spaces/ryanrwatkins/needs/raw/main/gurus.txt"
    try:
        # Timeout so a stalled connection cannot hang the page load.
        response = requests.get(url, timeout=10)
        # Without this an HTTP error page would be parsed as CSV rows.
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"An error occurred while downloading prompt templates: {e}")
    else:
        reader = csv.reader(response.text.splitlines())
        next(reader, None)  # skip the header row; tolerate an empty body
        for row in reader:
            # Three columns are read below, so require three (the previous
            # check for two raised IndexError on short rows).
            if len(row) >= 3:
                act = row[0].strip('"')
                prompt_templates[act] = row[1].strip('"')
                actor_description[act] = row[2].strip('"')

    choices = list(prompt_templates.keys())
    choices = choices[:1] + sorted(choices[1:])
    return gr.update(value=choices[0], choices=choices)
def on_prompt_template_change(prompt_template):
    """Look up the system prompt text for the selected expert.

    Args:
        prompt_template: Expert name used as a key into ``prompt_templates``.

    Returns:
        The stored prompt, or ``None`` when the argument is not a string.

    Raises:
        KeyError: If the string is not a known expert name.
    """
    if isinstance(prompt_template, str):
        return prompt_templates[prompt_template]
    return None
def on_prompt_template_change_description(prompt_template):
    """Look up the description markup for the selected expert.

    Args:
        prompt_template: Expert name used as a key into ``actor_description``.

    Returns:
        The stored description, or ``None`` when the argument is not a string.

    Raises:
        KeyError: If the string is not a known expert name.
    """
    if isinstance(prompt_template, str):
        return actor_description[prompt_template]
    return None
def _history_to_chat_pairs(history):
    """Pair alternating user/assistant messages as (question, answer) tuples."""
    return [
        (history[i]['content'], history[i + 1]['content'])
        for i in range(0, len(history) - 1, 2)
    ]


def submit_message(prompt, prompt_template, temperature, max_tokens, context_length, state):
    """Answer one user question with Gemini over the pickled document index.

    Args:
        prompt: Text typed by the user. When empty, nothing is sent and the
            existing conversation is redisplayed.
        prompt_template: Expert name selected in the dropdown; looked up in
            ``prompt_templates`` to obtain the system prompt.
        temperature: Sampling temperature passed to the model.
        max_tokens: Upper bound on the length of the generated answer.
        context_length: Number of previous question/answer pairs included in
            the query.
        state: Session dict; ``state['messages']`` holds the conversation as
            alternating user/assistant message dicts.

    Returns:
        A 3-tuple for the (input box, chatbot, state) outputs: the cleared
        input value, the conversation as (question, answer) pairs, and the
        updated state.
    """
    history = state['messages']

    if not prompt:
        return gr.update(value=''), _history_to_chat_pairs(history), state

    # An unknown or unselected expert simply means "no system prompt".
    prompt_template = prompt_templates.get(prompt_template, "")

    # Archive every submitted prompt with a timestamp. newline="" is what the
    # csv module expects from files it writes to.
    with open("prompts_archive.csv", "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=["prompt", "time"])
        writer.writerow(
            {"prompt": str(prompt), "time": str(datetime.now())}
        )

    system_prompt = []
    if prompt_template:
        system_prompt = [{"role": "system", "content": prompt_template}]

    prompt_msg = {"role": "user", "content": prompt}

    # NOTE(review): unpickling executes arbitrary code; embeddings.pkl must
    # only ever come from a trusted source.
    with open("embeddings.pkl", 'rb') as f:
        new_docsearch = pickle.load(f)

    # Only the most recent `context_length` exchanges are sent along.
    pairs_to_keep = int(context_length)
    recent_history = history[-2 * pairs_to_keep:] if pairs_to_keep > 0 else []
    query = str(system_prompt + recent_history + [prompt_msg])
    docs = new_docsearch.similarity_search(query)

    # The previous code called an undefined `GenerativeAI` client and never
    # used the retrieved documents. Use the Gemini chat model this file
    # already imports, through the same "stuff" QA chain as the OpenAI path.
    # NOTE(review): the UI slider allows temperatures up to 2.0; confirm the
    # installed langchain_google_genai accepts that range.
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.0-pro",
        temperature=temperature,
        max_output_tokens=int(max_tokens),
        google_api_key=os.environ['gemini_key'],
        # The "stuff" chat prompt starts with a system message.
        # NOTE(review): assumes the installed langchain_google_genai still
        # accepts this flag — confirm.
        convert_system_message_to_human=True,
    )
    chain = load_qa_chain(llm, chain_type="stuff")
    completion = chain.run(input_documents=docs, question=query)

    # Record the exchange so later turns (and the empty-prompt redisplay
    # above) actually see it.
    history.append(prompt_msg)
    history.append({"role": "assistant", "content": completion})
    # The latest answer was also stored under 'content' before; kept as-is.
    state['content'] = completion

    return '', _history_to_chat_pairs(history), state
def submit_message_OLD(prompt, prompt_template, temperature, max_tokens, context_length, state):
    """Legacy OpenAI (gpt-3.5-turbo) version of the submit handler.

    Not connected to any UI event in this file. Pushing the prompt archive to
    a Hugging Face dataset repo used to happen here and is disabled.

    Args:
        prompt: Text typed by the user. When empty, nothing is sent and the
            stored conversation is redisplayed.
        prompt_template: Expert name; key into ``prompt_templates``.
        temperature: Sampling temperature passed to the model.
        max_tokens: Upper bound on the length of the generated answer.
        context_length: Accepted for signature compatibility; unused.
        state: Session dict; ``state['messages']`` is read and
            ``state['content']`` receives the latest answer.

    Returns:
        A 3-tuple for the (input box, chatbot, state) outputs. The chatbot
        value contains only the latest (question, answer) pair.
    """
    history = state['messages']

    if not prompt:
        return gr.update(value=''), [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history)-1, 2)], state

    prompt_template = prompt_templates[prompt_template]

    # Archive every submitted prompt with a timestamp.
    with open("prompts_archive.csv", "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=["prompt", "time"])
        writer.writerow(
            {"prompt": str(prompt), "time": str(datetime.now())}
        )

    system_prompt = []
    if prompt_template:
        system_prompt = [{"role": "system", "content": prompt_template}]

    prompt_msg = {"role": "user", "content": prompt}

    # NOTE(review): unpickling executes arbitrary code; embeddings.pkl must
    # only ever come from a trusted source.
    with open("embeddings.pkl", 'rb') as f:
        new_docsearch = pickle.load(f)

    query = str(system_prompt + history + [prompt_msg])
    docs = new_docsearch.similarity_search(query)

    chain = load_qa_chain(ChatOpenAI(temperature=temperature, max_tokens=max_tokens, model_name="gpt-3.5-turbo"), chain_type="stuff")
    # This call had been commented out, which left `completion` unassigned
    # and made every non-empty prompt fail on the line below.
    completion = chain.run(input_documents=docs, question=query)

    state['content'] = completion
    completion = {"content": completion}

    chat_messages = [(prompt_msg['content'], completion['content'])]
    return '', chat_messages, state
def clear_conversation():
    """Reset the UI for a new conversation.

    Returns:
        A 3-tuple matching the click handler's outputs
        (input_message, chatbot, state): an update that empties and shows the
        input box, ``None`` to clear the chatbot, and a fresh empty state.
    """
    # A fourth value ("" for the removed token-count label) used to sit before
    # the state. With only three outputs wired up, that empty string was what
    # lined up with the state slot, breaking the next state['messages'] lookup.
    return gr.update(value=None, visible=True), None, get_empty_state()
# Custom stylesheet handed to gr.Blocks(css=css). Selectors such as
# #col-container, #chatbox, #header and #prompt_template_preview correspond to
# the elem_id values given to the components in the layout.
css = """
#col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
#chatbox {min-height: 400px;}
#header {text-align: center;}
#prompt_template_preview {padding: 1em; border-width: 1px; border-style: solid; border-color: #e0e0e0; border-radius: 4px; min-height: 150px;}
#total_tokens_str {text-align: right; font-size: 0.8em; color: #666;}
#label {font-size: 0.8em; padding: 0.5em; margin: 0;}
.message { font-size: 1.2em; }
"""
# Page layout: a chat column (chatbot, input box, buttons) beside an expert
# picker column with its description preview and advanced sliders.
with gr.Blocks(css=css) as demo:

    # Per-session conversation state.
    state = gr.State(get_empty_state())

    with gr.Column(elem_id="col-container"):
        gr.Markdown("""## Ask questions of *needs assessment* experts,
## get responses from a *needs assessment experts* version of ChatGPT.
Ask questions of all of them, or pick your expert below.
This is a free resource but it does cost us money to run. Unfortunately someone has been abusing this approach.
In response, we have had to temporarily turn it off until we can put improve the monitoring. Sorry for the inconvenience.""" ,
                    elem_id="header")

        with gr.Row():
            with gr.Column():
                chatbot = gr.Chatbot(elem_id="chatbox")
                input_message = gr.Textbox(show_label=False, placeholder="Enter your needs assessment question", visible=True).style(container=False)
                btn_submit = gr.Button("Submit")
                #total_tokens_str = gr.Markdown(elem_id="total_tokens_str")
                btn_clear_conversation = gr.Button("Start New Conversation")
            with gr.Column():
                prompt_template = gr.Dropdown(label="Choose an Expert:", choices=list(prompt_templates.keys()))
                prompt_template_preview = gr.Markdown(elem_id="prompt_template_preview")
                with gr.Accordion("Advanced parameters", open=False):
                    temperature = gr.Slider(minimum=0, maximum=2.0, value=0.7, step=0.1, label="Flexibility", info="Higher = More AI, Lower = More Expert")
                    max_tokens = gr.Slider(minimum=100, maximum=400, value=200, step=1, label="Length of Response.")
                    context_length = gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Context Length", info="Number of previous questions you have asked.")

    # Both the button and pressing Enter in the textbox submit the question.
    btn_submit.click(submit_message, [ input_message, prompt_template, temperature, max_tokens, context_length, state], [input_message, chatbot, state])
    input_message.submit(submit_message, [ input_message, prompt_template, temperature, max_tokens, context_length, state], [input_message, chatbot, state])
    btn_clear_conversation.click(clear_conversation, [], [input_message, chatbot, state])
    # Show the selected expert's description next to the dropdown.
    prompt_template.change(on_prompt_template_change_description, inputs=[prompt_template], outputs=[prompt_template_preview])

    # Fetch the remote expert list on page load. The keyword was misspelled
    # as `queur`, which is not an accepted argument; `queue=False` runs the
    # load outside the request queue.
    demo.load(download_prompt_templates, inputs=None, outputs=[prompt_template], queue=False)


demo.queue(concurrency_count=10)
demo.launch(height='800px')