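"""FastAPI chatbot with retrieval-augmented generation: answers are grounded
in a FAISS index built over knowledge_base.txt, embedded with a multilingual
E5 model, and generated by a local Mistral model through Ollama.

One way to serve the API (assuming this file is saved as main.py):

    uvicorn main:app --reload

and to exercise the /chat endpoint:

    curl -X POST http://localhost:8000/chat \
         -H "Content-Type: application/json" \
         -d '{"user_id": "u1", "message": "Hello"}'
"""
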
import os
import traceback

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain_ollama import OllamaLLM
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader

# Point the Hugging Face cache at a writable location and silence the
# symlink warning before any model is downloaded.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    # Wildcard origins are convenient in development, but browsers refuse to
    # send credentials to a "*" origin; pin allowed origins in production.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Load and split documents
loader = TextLoader("knowledge_base.txt", encoding="utf-8")
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    # Prefer paragraph and sentence boundaries; the Arabic comma (،) and
    # question mark (؟) let Arabic text split cleanly as well.
    separators=["\n\n", "\n", ".", "!", "?", "،", "؟", ";", ","],
)
splits = text_splitter.split_documents(documents)

# Generate embeddings and store in FAISS
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
vectorstore = FAISS.from_documents(splits, embeddings)
# A score_threshold only takes effect with the similarity_score_threshold
# search type; under the default "similarity" search it may be silently ignored.
retriever = vectorstore.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": 5, "score_threshold": 0.4},
)
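# Note: E5-family models are trained with "query: " / "passage: " prefixes,
# which HuggingFaceEmbeddings does not add by default; prefixing chunks and
# queries accordingly may improve retrieval quality.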

# Define improved prompt template
template = """
You are an AI assistant. You must ALWAYS respond in the EXACT SAME LANGUAGE as the user's question or message. This is crucial:
- If the user writes in English, you MUST respond in English
- If the user writes in Arabic, you MUST respond in Arabic (Modern Standard Arabic)
- Mixed-language messages should get a response in the predominant language of the message
Conversation history:
{history}
Relevant information from knowledge base:
{context}
User's message: {question}
Key requirements:
1. MATCH THE LANGUAGE OF THE USER'S MESSAGE EXACTLY
2. Use the provided context and history to answer the question
3. Maintain your identity as an AI assistant
4. Never pretend to be the user or adopt their name
5. For greetings and casual conversation, respond naturally without using the knowledge base
6. Only use the knowledge base content when directly relevant to a specific question
Response:
"""

prompt = ChatPromptTemplate.from_template(template)

# Local Mistral model served by Ollama; a low temperature and top_p keep
# answers close to the retrieved context
model = OllamaLLM(
    model="mistral",
    temperature=0.1,
    num_ctx=8192,
    top_p=0.8,
)


def format_conversation_history(history):
    # One conversation turn per line: "User: ..." / "Assistant: ..."
    return "".join(f"{entry}\n" for entry in history)


# Manual RAG pipeline: retrieve context, fill the prompt, call the model
def generate_response(question, history, retriever):
    # Get relevant documents from the vector store
    docs = retriever.invoke(question)
    context_str = "\n".join(doc.page_content for doc in docs)

    # Format the conversation history
    history_str = format_conversation_history(history)

    # Fill the prompt template and generate the response
    formatted_prompt = prompt.format(
        context=context_str, history=history_str, question=question
    )
    return model.invoke(formatted_prompt)


# Terminal chat loop for local testing; not used by the API endpoints below
# (invoked from the __main__ guard at the bottom of the file).
def chatbot_conversation():
    print("Hello! I'm an AI assistant. Type 'exit' to quit.")

    conversation_history = []

    while True:
        user_input = input("You: ").strip()
        if user_input.lower() == 'exit':
            break

        try:
            # Generate response
            result = generate_response(user_input, conversation_history, retriever)

            print(f"Assistant: {result}")

            # Store the exchange in history
            conversation_history.append(f"User: {user_input}")
            conversation_history.append(f"Assistant: {result}")

        except Exception as e:
            print(f"An error occurred: {str(e)}")
            print(
                "Assistant: I apologize, but I encountered an error. Please try again."
            )


# In-memory conversation history keyed by user_id; per-process, unbounded,
# and lost on restart, so replace with a persistent store for production.
chat_histories = {}


class ChatRequest(BaseModel):
    user_id: str  # Unique ID for tracking history per user
    message: str


@app.post("/chat")
def chat(request: ChatRequest):
    try:
        # Fetch this user's conversation history, creating it on first contact
        history = chat_histories.setdefault(request.user_id, [])

        # Generate response
        response = generate_response(request.message, history, retriever)

        # Update history
        history.append(f"User: {request.message}")
        history.append(f"Assistant: {response}")

        return {"response": response}
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
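

if __name__ == "__main__":
    # Optional local smoke test: run the terminal chat loop instead of the
    # API server. This is a convenience sketch; serve the API with uvicorn
    # as shown in the module docstring.
    chatbot_conversation()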