"""Query a Chroma vector store and answer questions via the Llama chat API.

If sufficiently relevant documents are found in the local Chroma DB they are
used as context for the LLM prompt; otherwise the raw question is sent to the
LLM directly as a fallback.
"""

from langchain.vectorstores.chroma import Chroma
from langchain.embeddings import OpenAIEmbeddings
from llamaapi import LlamaAPI
from langchain.prompts import ChatPromptTemplate
import os

# Constants
CHROMA_PATH = "chroma"
PROMPT_TEMPLATE = """
Answer the question based only on the following context:

{context}

---

Answer the question based on the above context: {question}
"""


def generate_reworded_question(prompt):
    """Send *prompt* to the Llama chat API and return the generated answers.

    Returns:
        list[str]: one response string per API "choice"; an empty list if
        the request fails for any reason (best-effort — errors are printed,
        not raised, so callers always receive a list).
    """
    try:
        # SECURITY: the API key must come from the environment — never
        # hard-code secrets in source control.  A missing variable raises
        # KeyError, which the handler below reports like any other failure.
        llama = LlamaAPI(os.environ["LLAMA_API_KEY"])

        # API Request
        api_request_json = {
            "model": "llama-13b-chat",
            "messages": [
                {"role": "user", "content": prompt},
            ],
            "max_tokens": 250,   # cap the length of the generated answer
            "temperature": 0.1,  # low temperature => less creative output
            "top_p": 0.9,        # nucleus-sampling cutoff for diversity
        }

        # Run llama
        response = llama.run(api_request_json)
        response_json = response.json()
        reworded_questions = [
            choice['message']['content'] for choice in response_json['choices']
        ]
        return reworded_questions
    except Exception as e:
        print(f"Error generating reworded questions: {e}")
        return []  # Return an empty list if there's an error


def main(query_text):
    """Answer *query_text*, using Chroma DB context when it is relevant enough.

    Falls back to querying the LLM with the bare question when no document
    scores at least 0.7 relevance.

    Args:
        query_text: the user's natural-language question.

    Returns:
        str: the LLM response text followed by the list of document sources.
    """
    # Prepare the DB.
    embedding_function = OpenAIEmbeddings()
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
    print('Query Text:', query_text)

    # Search the DB.
    results = db.similarity_search_with_relevance_scores(query_text, k=3)
    if len(results) == 0 or results[0][1] < 0.7:
        print('No results found or low score')
        # Guard: results may be empty here — the original unconditional
        # results[0][1] access raised IndexError on an empty result set.
        if results:
            print(results[0][1])
        response_text = generate_reworded_question(query_text)
    else:
        print('Inside the high score condition')
        print(results[0][1])
        print('Results:', results)
        context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
        prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
        prompt = prompt_template.format(context=context_text, question=query_text)
        response_text = generate_reworded_question(prompt)

    sources = [doc.metadata.get("source", None) for doc, _score in results]

    # Response Formatting.  (The original chained self-replacements such as
    # .replace('\n', '\n') were no-ops and have been removed.)
    formatted_response = "\n".join(response_text)
    final_out = f"{formatted_response}\n\nSources: {sources}"
    return final_out


# Call the main function
if __name__ == "__main__":
    query_text = "What are the specifications of Advanced Energy's LCM300."
    main(query_text)