# NOTE(review): the lines "Spaces:" / "Sleeping" / "Sleeping" above the imports
# appear to be a page-scrape artifact (e.g. a Hugging Face Spaces status banner),
# not source code; kept here as a comment so the module remains valid Python.
import os

from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.vectorstores.chroma import Chroma
from llamaapi import LlamaAPI
# Constants
CHROMA_PATH = "chroma"  # directory holding the persisted Chroma vector store

# Prompt sent to the LLM in the high-relevance branch of main(); the retrieved
# document chunks fill {context} and the user's query fills {question}.
PROMPT_TEMPLATE = """
Answer the question based only on the following context:
{context}
---
Answer the question based on the above context: {question}
"""
def generate_reworded_question(prompt, model="llama-13b-chat", max_tokens=250,
                               temperature=0.1, top_p=0.9):
    """Send *prompt* to the Llama chat API and return the generated replies.

    Parameters
    ----------
    prompt : str
        User message to send (either a raw question or a full RAG prompt).
    model : str
        Llama API model identifier (was hard-coded; now a parameter).
    max_tokens : int
        Upper bound on the length of each generated reply.
    temperature : float
        Sampling temperature; lower values are less creative.
    top_p : float
        Nucleus-sampling cutoff controlling reply diversity.

    Returns
    -------
    list[str]
        One string per choice in the API response, or an empty list when the
        request fails for any reason (best-effort contract preserved from the
        original implementation).
    """
    try:
        # SECURITY: the API key used to be a hard-coded literal. Prefer the
        # LLAMA_API_KEY environment variable; the old literal remains only as
        # a backward-compatible fallback — rotate that key and delete the
        # fallback as soon as possible.
        api_key = os.environ.get(
            "LLAMA_API_KEY",
            "LL-0tVJ5OwMLdglnL5Okd94ScFHyT6FMPP33oClu8i5cXWPScRswldmqXI7VH1JaT3x",
        )
        llama = LlamaAPI(api_key)

        # Build the chat-completion request.
        api_request_json = {
            "model": model,
            "messages": [
                {"role": "user", "content": prompt},
            ],
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        # Run llama and collect one string per returned choice.
        response = llama.run(api_request_json)
        response_json = response.json()
        return [choice['message']['content'] for choice in response_json['choices']]
    except Exception as e:
        # Deliberate best-effort: callers expect a list, so report and
        # degrade to empty rather than propagate the error.
        print(f"Error generating reworded questions: {e}")
        return []  # Return an empty list if there's an error
def main(query_text):
    """Answer *query_text* with retrieval-augmented generation.

    Searches the persisted Chroma store for the 3 most relevant chunks. When
    the best match scores >= 0.7, the chunks are packed into PROMPT_TEMPLATE
    and sent to the LLM; otherwise the raw query is sent as-is.

    Returns the LLM reply (joined with newlines) followed by the list of
    source identifiers from the retrieved documents' metadata.
    """
    # Prepare the DB.
    embedding_function = OpenAIEmbeddings()
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
    print('Query Text:', query_text)

    # Search the DB; results are (Document, relevance_score) pairs.
    results = db.similarity_search_with_relevance_scores(query_text, k=3)
    if not results or results[0][1] < 0.7:
        print('No results found or low score')
        # BUG FIX: the original printed results[0][1] unconditionally here,
        # which raised IndexError whenever the search returned no results.
        if results:
            print(results[0][1])
        response_text = generate_reworded_question(query_text)
    else:
        print('Inside the high score condition')
        print(results[0][1])
        print('Results:', results)
        context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)
        prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
        prompt = prompt_template.format(context=context_text, question=query_text)
        response_text = generate_reworded_question(prompt)

    sources = [doc.metadata.get("source", None) for doc, _score in results]

    # Response formatting. (The original chained two str.replace() calls that
    # replaced each substring with itself — no-ops — so they were removed.)
    formatted_response = "\n".join(response_text)
    final_out = f"{formatted_response}\n\nSources: {sources}"
    return final_out
# Script entry point: run one sample query through the RAG pipeline.
if __name__ == "__main__":
    sample_query = "What are the specifications of Advanced Energy's LCM300."
    main(sample_query)