Spaces:
Build error
Build error
import json | |
import openai | |
import tiktoken | |
import numpy as np | |
import pandas as pd | |
import gradio as gr | |
import pickle | |
# OpenAI model used to generate answers.
COMPLETIONS_MODEL = "text-davinci-003"
# OpenAI model used to embed both the documents and incoming queries.
EMBEDDING_MODEL = "text-embedding-ada-002"

# Pre-computed section embeddings, presumably keyed by the same
# (title, section) tuples as df's index below — TODO confirm against the
# script that produced the pickle. pickle.load must only ever be used on
# trusted files such as this one shipped with the app.
with open("document_embeddings_clean_code.pickle", "rb") as document_embeddings_clean_code:
    document_embeddings = pickle.load(document_embeddings_clean_code)

# Section table indexed by (title, section); sections with 40 or fewer
# tokens are dropped.
df = pd.read_csv('./clean_code_processed.csv').set_index(["title", "section"])
df = df[df["tokens"] > 40]
df_for_embeddings = df
def get_embedding(text, model=EMBEDDING_MODEL):
    """Embed ``text`` with the OpenAI embeddings endpoint.

    Args:
        text: The string to embed.
        model: Embedding model name; defaults to EMBEDDING_MODEL.

    Returns:
        The ``embedding`` field of the first item in the API response's
        ``data`` list.
    """
    response = openai.Embedding.create(model=model, input=text)
    first_item = response["data"][0]
    return first_item["embedding"]
def vector_similarity(x, y):
    """Return the dot product of vectors ``x`` and ``y``.

    OpenAI embeddings are normalized to length 1, so for them this dot
    product is the same as the cosine similarity.
    """
    left = np.array(x)
    right = np.array(y)
    return left.dot(right)
def order_document_sections_by_query_similarity(query, contexts):
    """Rank document sections by their similarity to ``query``.

    Embeds ``query`` once, then scores it against every pre-computed section
    embedding in ``contexts`` (a mapping of section index -> embedding).

    Returns:
        A list of ``(similarity, section_index)`` tuples sorted in descending
        order. Equal similarities are ordered by section index, also
        descending, because whole tuples are compared.
    """
    query_vector = get_embedding(query)
    scored = [
        (vector_similarity(query_vector, section_vector), section_index)
        for section_index, section_vector in contexts.items()
    ]
    scored.sort(reverse=True)
    return scored
def construct_prompt(question, context_embeddings, df):
    """Build a completion prompt that answers ``question`` from the most relevant sections.

    Sections are ranked by embedding similarity to the question. Each one is
    appended to the context, with its newlines flattened to spaces and
    SEPARATOR as a prefix. Appending stops once the running token count
    (section tokens + separator_len per section) would exceed MAX_SECTION_LEN.

    The following sections are skipped rather than crashing:
    - sections whose index is missing from ``df``, e.g. filtered out by the
      token threshold at load time;
    - sections whose token count is not a single integer;
    - sections whose content is not a single string.

    Args:
        question: The user's question.
        context_embeddings: Mapping of section index -> embedding.
        df: Sections table indexed like the keys of ``context_embeddings``,
            with ``tokens`` and ``content`` columns.

    Returns:
        A tuple ``(prompt, chosen_sections_indexes)``. The second item lists
        the indexes of the sections included in the prompt, in prompt order.
    """
    most_relevant_document_sections = order_document_sections_by_query_similarity(
        question, context_embeddings
    )

    chosen_sections = []
    chosen_sections_len = 0
    chosen_sections_indexes = []

    for _, section_index in most_relevant_document_sections:
        # NOTE(review): _get_value is a private pandas accessor, kept from the
        # original implementation; consider a public replacement.
        try:
            tokens = df._get_value(section_index, "tokens")
            content = df._get_value(section_index, "content")
        except KeyError:
            # Embeddings can exist for sections no longer present in df.
            continue
        # A non-scalar result (e.g. when the index entry is not unique) or a
        # non-integer token count cannot be budgeted reliably; skip it.
        if not isinstance(tokens, (int, np.integer)) or not isinstance(content, str):
            continue

        # Add contexts until we run out of space.
        chosen_sections_len += tokens + separator_len
        if chosen_sections_len > MAX_SECTION_LEN:
            break

        chosen_sections.append(SEPARATOR + content.replace("\n", " "))
        chosen_sections_indexes.append(section_index)

    # Useful diagnostic information
    print(f"Selected {len(chosen_sections)} document sections:")
    print("\n".join(str(index) for index in chosen_sections_indexes))

    header = """Answer the question as truthfully as possible using the provided context, and if the answer is not contained within the text below, say "I don't know."\n\nContext:\n"""
    return (header + "".join(chosen_sections) + "\n\n Q: " + question + "\n A:", chosen_sections_indexes)
# Upper bound on the running token count of context sections (plus one
# separator per section) that construct_prompt adds to a prompt.
MAX_SECTION_LEN = 2000
# Prefix placed before each context section in the prompt.
SEPARATOR = "\n* "
# text-davinci-003 tokenizes with tiktoken's "p50k_base" encoding. The
# previous value "gpt2" contradicted its own comment.
ENCODING = "p50k_base"  # encoding for text-davinci-003
encoding = tiktoken.get_encoding(ENCODING)
# Token cost of one SEPARATOR, charged per section against MAX_SECTION_LEN.
separator_len = len(encoding.encode(SEPARATOR))

COMPLETIONS_API_PARAMS = {
    # We use temperature of 0.0 because it gives the most predictable, factual answer.
    "temperature": 0.0,
    "max_tokens": 1500,
    "model": COMPLETIONS_MODEL,
}
def answer_query_with_context(
    query,
    df,
    document_embeddings
):
    """Answer ``query`` with a completion grounded in the most relevant sections.

    Args:
        query: The user's question.
        df: Sections table with ``tokens`` and ``content`` columns, indexed
            like the keys of ``document_embeddings``.
        document_embeddings: Mapping of section index -> embedding.

    Returns:
        A tuple ``(answer, titles)``:
        - ``answer`` is the completion text with leading/trailing spaces and
          newlines stripped;
        - ``titles`` holds the first index level of each section used as
          context, in prompt order.
    """
    prompt, chosen_sections_indexes = construct_prompt(
        query,
        document_embeddings,
        df
    )
    # Keep only the first level of each section index (the title).
    chosen_titles = [section_index[0] for section_index in chosen_sections_indexes]

    response = openai.Completion.create(
        prompt=prompt,
        **COMPLETIONS_API_PARAMS
    )
    return (response["choices"][0]["text"].strip(" \n"), chosen_titles)
def handle(question):
    """Gradio callback: answer ``question`` and list the chapters used as context.

    Returns:
        The answer text followed by a "Related chapters:" list. A chapter
        that contributed several sections is listed only once, in
        first-seen order.
    """
    answer, related_documents = answer_query_with_context(question, df_for_embeddings, document_embeddings)
    # dict.fromkeys de-duplicates while preserving order. str() guards
    # against non-string titles, which would make str.join raise TypeError.
    unique_chapters = [str(title) for title in dict.fromkeys(related_documents)]
    return answer + "\n\nRelated chapters:\n" + "\n".join(unique_chapters)
# Text-in / text-out web UI around `handle`, with a few example questions.
demo = gr.Interface(
    fn=handle,
    inputs="text",
    outputs="text",
    cache_examples=False,
    examples=[
        "How to properly name a variable?",
        "How to write a good comment?",
        "What are best practices of unit testing?",
    ],
)

# Only start the server when run as a script, so that merely importing this
# module does not launch (and block on) the web app.
if __name__ == "__main__":
    demo.launch()