# Import the necessary libraries
import streamlit as st
from openai import OpenAI
from pinecone import Pinecone
import os
import pandas as pd
import numpy as np
# API keys are read from environment variables; the commented file-based
# loading below is kept as a local-development alternative.
# with open("pinecone_api.txt") as f:
#     PINECONE_KEY = f.readline().strip()
pc = Pinecone(api_key=os.environ.get("PINECONE_KEY"))
# with open("open_ai_key.txt") as f:
#     OPENAI_KEY = f.readline().strip()
client = OpenAI(api_key=os.environ.get("OPENAI_KEY"))
st.title("Seattle Pandas Super Duper ML Chatbot")
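
# Obnoxious_Agent: a guardrail agent that uses a few-shot prompt to classify
# each incoming query as obnoxious, machine-learning related, a general
# greeting, or other, so the head agent can route it appropriately.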
class Obnoxious_Agent:
    def __init__(self, client) -> None:
        self.client = client

    def set_prompt(self, query):
        prompt = f'''Is this query obnoxious, related to machine learning, or a general greeting?
Answer "obnoxious" if it is an obnoxious query, "machine learning" if it is related to machine learning,
"general greetings" if it is a general greeting, and "other" for all other queries. When considering whether
a query is related to machine learning, be sure to pay attention to common machine learning acronyms (RNN, CNN, CV, GAN)
and also consider topics from emerging fields like computer vision, deep learning, AI content generation, and others.
Examples are included.
"Query: You are stupid ; Answer: obnoxious"
"Query: poop ; Answer: obnoxious"
"Query: kdkdkspapemrmn ; Answer: obnoxious"
"Query: What is a random forest? ; Answer: machine learning"
"Query: How to train a model using a GPU? ; Answer: machine learning"
"Query: What is a CNN? ; Answer: machine learning"
"Query: RNN? ; Answer: machine learning"
"Query: Causal inference? ; Answer: machine learning"
"Query: What is computer vision or CV? ; Answer: machine learning"
"Query: How are you? ; Answer: general greetings"
"Query: I like shoes ; Answer: other"
Query: {query}'''
        return prompt

    def extract_action(self, response) -> str:
        # Map the model's free-text answer to one of four routing labels
        # ('gt' = general greeting, 'ml' = machine learning).
        if 'obnoxious' in response.lower():
            return 'obnoxious'
        elif 'general greetings' in response.lower():
            return 'gt'
        elif 'machine learning' in response.lower():
            return 'ml'
        else:
            return 'other'

    def check_query(self, query):
        prompt = self.set_prompt(query)
        message = {"role": "user", "content": prompt}
        response = self.client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[message]
        )
        return self.extract_action(response.choices[0].message.content)
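
# Query_Agent: embeds the user query with OpenAI and retrieves the top-k most
# similar chunks from a Pinecone index. text_embedding.csv is expected to have
# a 'Text' column whose row positions correspond to the Pinecone vector ids.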
class Query_Agent:
    def __init__(self, client, index='llm-chatbot-index') -> None:
        # This assumes the Pinecone index already exists and has been populated.
        self.pc = Pinecone(api_key=os.environ.get("PINECONE_KEY"))
        self.index = self.pc.Index(index)
        self.client = client
        self.df = pd.read_csv("text_embedding.csv")
        self.texts_size_250 = np.array(self.df['Text'])

    def get_embedding(self, text, model="text-embedding-ada-002"):
        text = text.replace("\n", " ")
        return self.client.embeddings.create(input=[text], model=model).data[0].embedding

    def send_pinecone_query(self, query, top_k=5, namespace="250_chunk"):
        # Query the Pinecone vector store and return the top-k results.
        return self.index.query(
            vector=query,
            top_k=top_k,
            namespace=namespace)

    def query_vector_store(self, query, top_k=5):
        # Concatenate the top-k matched chunks and report their average score.
        e = self.get_embedding(query)
        relevant = self.send_pinecone_query(e, top_k)
        scores = 0
        context = ""
        for result in relevant["matches"]:
            scores += float(result['score'])
            context += self.texts_size_250[int(result['id'])]
            context += '\n'
        return context, scores / top_k

    def set_prompt(self, query, context):
        prompt = "Given the following context, explain " + query + ": " + context
        return prompt

    def extract_action(self, response, query=None):
        context, avg_score = self.query_vector_store(query)
        if avg_score < 0.3:
            return 'non-relevant'
        else:
            return self.set_prompt(query, context)
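
# Answering_Agent: composes the final reply by conditioning the chat model on
# the query, the retrieved documents, and the conversation history, answering
# in the currently selected style (mode).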
class Answering_Agent:
    def __init__(self, openai_client, mode) -> None:
        self.client = openai_client
        self.mode = mode

    def generate_response(self, query, docs, conv_history, k=5):
        prompt = f'''You are a {self.mode} chatbot. Answer all queries in a {self.mode} style.
I will provide a user query you must answer, relevant documents which you
should reference in your answer, and conversation history which you should
refer to for context.
Query: {query}
Conversation History: {conv_history}
Relevant Documents: {docs}
'''
        message = {"role": "user", "content": prompt}
        # Use the injected client rather than the module-level one.
        response = self.client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[message]
        )
        return response.choices[0].message.content
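
# Relevant_Documents_Agent: asks the chat model to judge whether the retrieved
# documents actually answer the query, given their average cosine similarity.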
class Relevant_Documents_Agent:
    def __init__(self, client) -> None:
        self.client = client

    def get_relevance(self, query, documents, cosine_similarity) -> str:
        prompt = f'''Based on the following query, please decide whether the following
documents are relevant to it. For context, the average cosine similarity of these documents to the
query is {cosine_similarity}. Your response must be one of the following two options: [relevant, non-relevant].
Query: {query}
Documents: {documents}'''
        print("USER PROMPT:", prompt)
        message = {"role": "user", "content": prompt}
        response = self.client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[message]
        )
        rel_response = response.choices[0].message.content
        print("Relevance:", rel_response)
        return rel_response
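
# Head_Agent: owns the OpenAI client and all sub-agents, and drives the chat
# flow: filter with Obnoxious_Agent, retrieve with Query_Agent, gate with
# Relevant_Documents_Agent, then answer with Answering_Agent.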
class Head_Agent:
    def __init__(self, mode) -> None:
        self.client = OpenAI(api_key=os.environ.get("OPENAI_KEY"))
        # Keep this list in sync with the options offered in the mode selectbox below.
        self.possible_modes = ['concise', 'chatty', 'shakespearean']
        self.mode = mode
        self.setup_sub_agents()
        with st.chat_message("assistant"):
            st.write(f"Welcome to your {self.mode} chatbot!")

    def setup_sub_agents(self):
        self.obnoxious_agent = Obnoxious_Agent(self.client)
        self.query_agent = Query_Agent(self.client)
        self.answering_agent = Answering_Agent(self.client, self.mode)
        self.relevance_agent = Relevant_Documents_Agent(self.client)

    def evaluate_mode(self, query):
        # Classify a free-text request into one of the supported modes
        # (not currently called from main_loop).
        prompt = f'''Classify the following query to see if it most closely matches
an item in this list {self.possible_modes}.
Your response MUST be a single word from that list only. Query: {query}'''
        message = {"role": "user", "content": prompt}
        response = self.client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[message]
        )
        return response.choices[0].message.content
    def main_loop(self):
        # Run the main Streamlit chat loop.
        if "openai_model" not in st.session_state:
            st.session_state["openai_model"] = "gpt-3.5-turbo"
        if "messages" not in st.session_state:
            st.session_state.messages = []
        # Replay the conversation so far.
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
        if prompt := st.chat_input("Hi, how can I help you?"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            # First, classify the prompt: obnoxious, greeting ('gt'), ML ('ml'), or other.
            query_type = self.obnoxious_agent.check_query(prompt)
            print(query_type)
            if query_type == 'obnoxious':
                with st.chat_message("assistant"):
                    response = "Please refrain from obnoxious questions."
                    st.write(response)
            elif query_type == 'gt':
                with st.chat_message("assistant"):
                    response = "How can I assist you today?"
                    st.write(response)
            else:
                # Next, check whether the retrieved documents are relevant to the prompt.
                docs, cosine_similarity = self.query_agent.query_vector_store(prompt)
                relevance = self.relevance_agent.get_relevance(prompt, docs, cosine_similarity)
                if 'non-relevant' in relevance.lower():
                    if query_type == 'ml':
                        # ML question with no matching documents: answer without context.
                        with st.chat_message("assistant"):
                            response = self.answering_agent.generate_response(prompt, '', st.session_state['messages'])
                            st.write(response)
                    else:
                        with st.chat_message("assistant"):
                            response = "Please ask questions only related to Machine Learning!"
                            st.write(response)
                else:
                    with st.chat_message("assistant"):
                        response = self.answering_agent.generate_response(prompt, docs, st.session_state['messages'])
                        st.write(response)
            st.session_state.messages.append({"role": "assistant", "content": response})
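
# Top-level Streamlit script: choose a default mode, build the head agent,
# offer a mode selector, then enter the chat loop.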
if "mode" not in st.session_state: | |
st.session_state.mode = "Concise" | |
head_agent=Head_Agent(st.session_state.mode) | |
def set_mode(): | |
head_agent.answering_agent = Answering_Agent(head_agent.client, st.session_state.mode) | |
st.session_state.mode = st.selectbox( | |
'What kind of chatbot would you like today?', | |
('Concise', 'Chatty', 'Shakespearean'), on_change=set_mode) | |
st.write('You selected:', st.session_state.mode) | |
head_agent.main_loop() | |
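
# To run locally (a sketch; assumes both API keys are exported, that
# text_embedding.csv is present, and that the script is saved as app.py,
# a hypothetical filename):
#   export OPENAI_KEY=...
#   export PINECONE_KEY=...
#   streamlit run app.py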