File size: 3,816 Bytes
fd5f784
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os

from dotenv import load_dotenv
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_chroma import Chroma
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
# Directory holding the persisted Chroma vector store built elsewhere.
CHROMA_PATH = "chroma"

# Pull environment variables from a local .env file before reading the key.
load_dotenv()
# NOTE(review): env var is named OPEN_AI_KEY (not the conventional
# OPENAI_API_KEY) — confirm the .env file matches.
API_KEY = os.getenv("OPEN_AI_KEY")
def build_rag_chain(api_key):
    """Build a history-aware RAG chain over the persisted ``linux_funds`` store.

    Args:
        api_key: OpenAI API key, used for both the embedding model and the
            chat model.

    Returns:
        A retrieval chain whose ``invoke`` accepts
        ``{"input": str, "chat_history": list}`` and returns a dict that
        includes an ``"answer"`` key.
    """
    embed = OpenAIEmbeddings(
        api_key=api_key,
        model="text-embedding-3-large",
    )
    # Reopen the existing collection persisted under CHROMA_PATH; the
    # embedding function must match the one used at indexing time.
    db = Chroma(
        collection_name="linux_funds",
        embedding_function=embed,
        persist_directory=CHROMA_PATH,
    )
    # Return up to 4 documents, but only those above the score threshold.
    retriever = db.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": 4, "score_threshold": 0.3},
    )
    model = ChatOpenAI(api_key=api_key, model="gpt-4o")
    # Prompt that rewrites the latest question into a standalone one, so the
    # retriever can search without needing the chat history itself.
    context = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, just "
        "reformulate it if needed and otherwise return it as is."
    )
    context_with_history = ChatPromptTemplate(
        [
            ("system", context),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        model, retriever, context_with_history
    )
    # Answering prompt; {context} is filled with the retrieved documents.
    # BUG FIX: the first two fragments previously concatenated to "Usethe" —
    # a trailing space was missing after "Use".
    main_query = (
        "You are an assistant for question-answering tasks. Use "
        "the following pieces of retrieved context to answer the "
        "question. If you don't know the answer, just say "
        "you don't know. Use 10 sentences maximum and keep the answer "
        "concise. You will most likely have to write bash scripts, so make"
        " this presentable on HuggingFace in markdown if needed."
        "\n\n"
        "{context}"
    )
    prompt = ChatPromptTemplate(
        [
            ("system", main_query),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    # Stuff retrieved docs into the prompt, then wire retriever -> answerer.
    qna_chain = create_stuff_documents_chain(model, prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, qna_chain)
    return rag_chain

def chat():
    """Run an interactive console loop against the Linux-fundamentals RAG chain.

    Maintains the conversation in ``chat_history`` (alternating
    HumanMessage/AIMessage) so follow-up questions can reference earlier
    turns. Type 'exit' to quit.
    """
    # The collection is "linux_funds" and the prompt targets bash scripting,
    # so the greeting should match (was "Theory of Computation").
    print("Start asking about Linux fundamentals. Type 'exit' to end the conversation.")
    chat_history = []
    # Build the chain once, outside the loop: rebuilding it per question
    # recreated the embeddings client, vector store, and model every turn.
    rag_chain = build_rag_chain(API_KEY)

    while True:
        query = input("You: ")
        if query.strip().lower() == "exit":
            break
        result = rag_chain.invoke({"input": query, "chat_history": chat_history})
        print(f"AI: {result['answer']}")
        chat_history.append(HumanMessage(content=query))
        # The model's replies belong in history as AIMessage, not
        # SystemMessage — the wrong role skews the history-aware retriever.
        chat_history.append(AIMessage(content=result["answer"]))

# ABOVE IS FOR LOCAL TESTING ONLY ^ ONLY KEEPING IT FOR FUTURE USE


    # messages = [
    #     SystemMessage(content="You are a helpful assistant."),
    #     HumanMessage(content=query_input),
    # ]
    #
    # result = model.invoke(messages)
    #
    # print("\n--- Generated Response ---")
    # print("Result:")
    # print(result)
    # print("Content only:")
    # print(result.content)

# Start the interactive chat loop only when run as a script, not on import.
if __name__ == "__main__":
    chat()