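"""Interactive question answering over a Chroma collection using Gemini.

The script connects to a persistent Chroma collection, retrieves the five most
relevant documents for each query, and asks Gemini to answer using only that
retrieved context (a simple retrieval-augmented generation loop).

Example invocation (the script name is a placeholder; the collection is assumed
to have been populated beforehand, e.g. by a companion loading script):

    python query.py --collection_name documents_collection --persist_directory chroma_storage
"""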
import argparse
import os
from typing import List
import google.generativeai as genai
import chromadb
from chromadb.utils import embedding_functions
model = genai.GenerativeModel("gemini-pro")


def build_prompt(query: str, context: List[str]) -> str:
    """
    Builds a prompt for the LLM.

    This function builds a prompt for the LLM. It takes the original query and
    the retrieved context, and asks the model to answer the question based only
    on what is in the context, not on what is stored in its weights.

    Args:
        query (str): The original query.
        context (List[str]): The context of the query, returned by embedding search.

    Returns:
        A prompt for the LLM (str).
    """
    base_prompt = {
        "content": "I am going to ask you a question, which I would like you to answer"
        " based only on the provided context, and not any other information."
        " If there is not enough information in the context to answer the question,"
        ' say "I am not sure", then try to make a guess.'
        " Break your answer up into nicely readable paragraphs.",
    }
    user_prompt = {
        "content": f" The question is '{query}'. Here is all the context you have:"
        f'{(" ").join(context)}',
    }
    # Combine the two prompts into a single prompt string.
    system = f"{base_prompt['content']} {user_prompt['content']}"
    return system
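# A hypothetical call, for illustration:
#   build_prompt("What is Chroma?", ["Chroma is an open-source embedding database."])
# returns a single string: the instruction preamble, followed by the question and
# the space-joined context chunks.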
def get_gemini_response(query: str, context: List[str]) -> str:
    """
    Queries the Gemini API to get a response to the question.

    Args:
        query (str): The original query.
        context (List[str]): The context of the query, returned by embedding search.

    Returns:
        A response to the question.
    """
    response = model.generate_content(build_prompt(query, context))
    return response.text
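# Note: with the google-generativeai SDK, accessing `response.text` can raise if
# the model returned no text parts (for example, a blocked response); error
# handling is omitted here for brevity.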
def main(
    collection_name: str = "documents_collection", persist_directory: str = "."
) -> None:
    # Check if the GOOGLE_API_KEY environment variable is set. Prompt the user to set it if not.
    google_api_key = None
    if "GOOGLE_API_KEY" not in os.environ:
        gapikey = input("Please enter your Google API Key: ")
        genai.configure(api_key=gapikey)
        google_api_key = gapikey
    else:
        google_api_key = os.environ["GOOGLE_API_KEY"]
        # Configure the SDK explicitly with the key from the environment as well.
        genai.configure(api_key=google_api_key)
    # Instantiate a persistent chroma client in the persist_directory.
    # This will automatically load any previously saved collections.
    # Learn more at docs.trychroma.com
    client = chromadb.PersistentClient(path=persist_directory)

    # Create the embedding function used to embed incoming queries.
    embedding_function = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
        api_key=google_api_key, task_type="RETRIEVAL_QUERY"
    )

    # Get the collection. get_collection raises if it does not exist, so the
    # collection is assumed to have been populated beforehand under this name.
    collection = client.get_collection(
        name=collection_name, embedding_function=embedding_function
    )
    # We use a simple input loop.
    while True:
        # Get the user's query
        query = input("Query: ")
        if len(query) == 0:
            print("Please enter a question. Ctrl+C to Quit.\n")
            continue
        print("\nThinking...\n")

        # Query the collection to get the 5 most relevant results
        results = collection.query(
            query_texts=[query], n_results=5, include=["documents", "metadatas"]
        )
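        # Chroma returns results keyed by field, with one list per query text, so
        # results["documents"][0] and results["metadatas"][0] are the documents
        # and metadata dicts for this single query, ordered by relevance.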
        sources = "\n".join(
            [
                f"{result['filename']}: line {result['line_number']}"
                for result in results["metadatas"][0]  # type: ignore
            ]
        )
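        # The "filename" and "line_number" metadata keys are assumed to have been
        # attached to each document when the collection was built; collections
        # loaded differently would need a different formatting step here.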
        # Get the response from Gemini
        response = get_gemini_response(query, results["documents"][0])  # type: ignore

        # Output, with sources
        print(response)
        print("\n")
        print(f"Source documents:\n{sources}")
        print("\n")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Query a Chroma collection and answer questions with Gemini"
    )

    parser.add_argument(
        "--persist_directory",
        type=str,
        default="chroma_storage",
        help="The directory where the Chroma collection is stored",
    )

    parser.add_argument(
        "--collection_name",
        type=str,
        default="documents_collection",
        help="The name of the Chroma collection",
    )

    # Parse arguments
    args = parser.parse_args()

    main(
        collection_name=args.collection_name,
        persist_directory=args.persist_directory,
    )