import os
import sqlite3
from contextlib import closing

import gradio as gr
import numpy as np
import openai
import sklearn
from sklearn.metrics.pairwise import cosine_similarity

# Set OpenAI API key from environment variable
openai.api_key = os.environ["Secret"]

# SQLite file holding the `chunks(text, embedding)` table queried below.
DB_PATH = "text_chunks_with_embeddings.db"
# Maximum number of characters of retrieved context placed in the prompt.
MAX_CONTEXT_CHARS = 1500
# OpenAI embedding engine used to embed the user's query.
EMBEDDING_ENGINE = "text-embedding-ada-002"


def find_closest_neighbors(vector1, dictionary_of_vectors, k=4):
    """Return the ``k`` stored texts most similar to the query.

    Despite its name, ``vector1`` is the raw query text: it is embedded with
    the OpenAI embeddings endpoint and compared by cosine similarity against
    every vector in ``dictionary_of_vectors``.

    Args:
        vector1: Query text to embed.
        dictionary_of_vectors: Mapping of chunk text -> 1-D numpy embedding.
        k: Maximum number of neighbours to return (default 4, the value the
            script previously hard-coded).

    Returns:
        A list of ``(text, similarity)`` tuples, most similar first, with at
        most ``k`` entries. Equal scores keep the mapping's insertion order
        (Python's sort is stable, including with ``reverse=True``).
    """
    if not dictionary_of_vectors:
        # Nothing to compare against; skip the (billed) embedding request.
        return []

    response = openai.Embedding.create(input=vector1, engine=EMBEDDING_ENGINE)
    query = np.array(response["data"][0]["embedding"])

    texts = list(dictionary_of_vectors)
    # Stack every stored embedding so sklearn normalises and scores them all in
    # one call, instead of one validated call per row.
    matrix = np.vstack([np.ravel(dictionary_of_vectors[text]) for text in texts])
    scores = cosine_similarity(query.reshape(1, -1), matrix)[0]

    ranked = sorted(zip(texts, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


def _load_embeddings(db_path=DB_PATH):
    """Read every ``(text, embedding)`` row of the ``chunks`` table into a dict.

    Embeddings are parsed with ``np.fromstring(..., sep=' ')``, i.e. they are
    expected to be stored as whitespace-separated numbers -- confirm against
    the script that builds the database. Duplicate texts keep the last row.
    """
    # closing() guarantees the connection is released even if the query
    # raises; sqlite3's own context manager only manages transactions.
    with closing(sqlite3.connect(db_path)) as conn:
        rows = conn.execute("SELECT text, embedding FROM chunks").fetchall()
    return {
        text: np.fromstring(embedding_str, sep=" ")
        for text, embedding_str in rows
    }


def predict(message, history):
    """Stream a GPT-4 answer to ``message`` grounded in retrieved DB context.

    Args:
        message: The user's latest question.
        history: Prior ``(user, assistant)`` message pairs supplied by
            ``gr.ChatInterface``.

    Yields:
        The assistant reply accumulated so far, once per streamed text piece.
    """
    dictionary_of_vectors = _load_embeddings()
    match_list = find_closest_neighbors(message, dictionary_of_vectors)

    # Separate chunks so the last word of one does not fuse with the first
    # word of the next, then cap the context budget.
    context = "\n\n".join(str(text) for text, _score in match_list)
    context = context[:MAX_CONTEXT_CHARS]

    prep = (
        "This is an OpenAI model designed to answer questions specific to "
        "grant-making applications for an aquarium. Here is some "
        f"question-specific context: {context}. Q: {message} A: "
    )

    history_openai_format = []
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": prep})

    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=history_openai_format,
        temperature=1.0,
        stream=True,
    )

    partial_message = ""
    for chunk in response:
        # A streamed delta may carry only {"role": ...} or be empty, so
        # indexing ["content"] directly can raise KeyError.
        content = chunk["choices"][0]["delta"].get("content")
        if content:
            partial_message += content
            yield partial_message


gr.ChatInterface(predict).queue().launch()