# RAG-Test / query_data.py
# Author: satendra4u2022
# Last change: "Update query_data.py" (commit ab3b5dc, verified)
from langchain.vectorstores.chroma import Chroma
from langchain.embeddings import OpenAIEmbeddings
from llamaapi import LlamaAPI
from langchain.prompts import ChatPromptTemplate
import os
# Constants
# Directory where the persisted Chroma vector store lives on disk.
CHROMA_PATH = "chroma"
# Chat prompt used when relevant context is found: the retrieved documents
# are substituted for {context} and the user's query for {question}.
PROMPT_TEMPLATE = """
Answer the question based only on the following context:
{context}
---
Answer the question based on the above context: {question}
"""
def generate_reworded_question(prompt):
    """Send *prompt* to the Llama chat API and return the generated answer(s).

    Args:
        prompt: The full text to send as the user message (either a raw
            query or a context-filled prompt built by the caller).

    Returns:
        A list with one string per API choice, or an empty list when the
        API key is missing or the request fails (best-effort contract).
    """
    try:
        # SECURITY: never hard-code API secrets in source — read the key
        # from the environment instead.
        api_key = os.environ.get("LLAMA_API_KEY")
        if not api_key:
            print("Error generating reworded questions: LLAMA_API_KEY environment variable is not set")
            return []
        llama = LlamaAPI(api_key)
        # API Request
        api_request_json = {
            "model": "llama-13b-chat",
            "messages": [
                {"role": "user", "content": prompt},
            ],
            "max_tokens": 250,   # cap the length of the generated question
            "temperature": 0.1,  # low temperature keeps the rewording conservative
            "top_p": 0.9,        # nucleus-sampling cutoff for diversity
        }
        # Run llama
        response = llama.run(api_request_json)
        response_json = response.json()
        return [choice['message']['content'] for choice in response_json['choices']]
    except Exception as e:
        # Best-effort: report the failure and return the empty-list fallback
        # so callers can continue without a reworded question.
        print(f"Error generating reworded questions: {e}")
        return []  # Return an empty list if there's an error
def main(query_text):
    """Answer *query_text* using RAG over the local Chroma store.

    High-relevance matches (score >= 0.7) are folded into a context prompt;
    otherwise the raw query is sent to the LLM as-is.

    Args:
        query_text: The user's question.

    Returns:
        A formatted string containing the LLM response followed by the
        source metadata of the retrieved documents.
    """
    # Prepare the DB.
    embedding_function = OpenAIEmbeddings()
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
    print('Query Text:', query_text)
    # Search the DB.
    results = db.similarity_search_with_relevance_scores(query_text, k=3)
    if not results or results[0][1] < 0.7:
        print('No results found or low score')
        # BUG FIX: the original printed results[0][1] unconditionally here,
        # which raises IndexError when the result list is empty.
        if results:
            print(results[0][1])
        response_text = generate_reworded_question(query_text)
    else:
        print('Inside the high score condition')
        print(results[0][1])
        print('Results:', results)
        # Join the retrieved documents into a single context blob.
        context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
        prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
        prompt = prompt_template.format(context=context_text, question=query_text)
        response_text = generate_reworded_question(prompt)
    sources = [doc.metadata.get("source", None) for doc, _score in results]
    # Response Formatting. (The original also ran two str.replace() calls
    # that replaced substrings with identical substrings — no-ops, removed.)
    formatted_response = "\n".join(response_text)
    final_out = f"{formatted_response}\n\nSources: {sources}"
    return final_out
# Script entry point: run a sample query against the vector store.
if __name__ == "__main__":
    sample_query = "What are the specifications of Advanced Energy's LCM300."
    main(sample_query)