|
import gradio as gr |
|
import os |
|
import time |
|
from datetime import datetime |
|
from zoneinfo import ZoneInfo |
|
from langchain_community.llms import HuggingFaceHub |
|
from langchain_community.retrievers import BM25Retriever |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
from langchain_huggingface import ChatHuggingFace |
|
from langchain.prompts import ChatPromptTemplate |
|
from langchain_huggingface import HuggingFaceEndpoint |
|
from langchain_community.document_loaders import PyPDFLoader |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain_community.vectorstores import Chroma |
|
from langchain.retrievers import EnsembleRetriever |
|
from langchain_core.prompts import MessagesPlaceholder |
|
from langchain.chains import create_history_aware_retriever |
|
from langchain.chains.combine_documents import create_stuff_documents_chain |
|
from langchain.chains import create_retrieval_chain |
|
from langchain_core.messages import HumanMessage,AIMessage |
|
from langchain.tools.retriever import create_retriever_tool |
|
from langchain_groq import ChatGroq |
|
from transformers import pipeline |
|
|
|
token = os.getenv('gr_tkn') |
|
os.environ["GROQ_API_KEY"] = token |
|
def build_rag_chain(): |
|
pdfloader = PyPDFLoader('kuDoc.pdf') |
|
docs = pdfloader.load() |
|
|
|
splitter = RecursiveCharacterTextSplitter(chunk_size=1000,chunk_overlap=0) |
|
texts = splitter.split_documents(docs) |
|
|
|
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") |
|
db = Chroma.from_documents(texts,embedding=embeddings) |
|
|
|
vector_retriever = db.as_retriever(search_type='similarity',search_kwargs = {'k':5}) |
|
keyword_retriever = BM25Retriever.from_documents(documents=texts,k=5) |
|
|
|
ensemble_retriever = EnsembleRetriever( |
|
retrievers=[keyword_retriever, vector_retriever], |
|
weights=[0.5, 0.5] |
|
) |
|
|
|
llm = ChatGroq( |
|
model = "llama-3.1-8b-instant", |
|
temperature = 0, |
|
max_tokens = None, |
|
timeout = None, |
|
max_retries = 2, |
|
streaming = True |
|
) |
|
|
|
chat_model = llm |
|
|
|
memory_system_prompt = ( |
|
"You are ChatKU, an AI assistant that helps users learn more about Kenyatta University. " |
|
"Greet the user by saying: 'Hello my name is ChatKU, I can help you to get to know more about Kenyatta University, so how can I help you dear?' " |
|
"You should help reformulate follow-up questions into standalone questions. " |
|
"Given the chat history and the latest user message, rewrite the user’s message as a clear, self-contained question that incorporates all relevant context. " |
|
"Do not invent new information. " |
|
"If the user introduces themselves (e.g., 'Hello, I am Steve' or 'My name is Joy'), remember their name for the rest of the conversation. " |
|
"If the user later asks 'What is my name?' or similar, respond using the name they previously provided." |
|
"And don't keep saying their name while answering questions,you may only say it at the beginning and end of conversation" |
|
) |
|
memory_prompt = ChatPromptTemplate.from_messages([ |
|
('system',memory_system_prompt), |
|
MessagesPlaceholder('chat_history'), |
|
('human','{input}') |
|
]) |
|
|
|
system_prompt = ( |
|
"You are ChatKU, an AI assistant that helps users learn more about Kenyatta University." |
|
"Greet the user by saying: 'Hello my name is ChatKU, I can help you to get to know more about Kenyatta University, so how can I help you dear?' " |
|
"Remember the user's name when they introduce it (e.g., 'Hello, am Steve')" |
|
"And don't keep saying their name while answering questions,you may only say it at the beginning and end of conversation" |
|
"Use only the information provided in the context below. " |
|
"think like an agent before answering the question and give the correct answer" |
|
"Do not make up information or add external knowledge. " |
|
"If the answer cannot be found in the context, say so clearly. " |
|
"Keep your answers concise, natural, and friendly. " |
|
"Feel free to address the user by name if mentioned in the chat history. " |
|
"Avoid repeating long context word-for-word. " |
|
"Never start your answer with phrases like 'Based on the provided context'...or 'According to the information i have...' " |
|
"and don't include them anywhere in your answer" |
|
"You are free to use emojis" |
|
"Consider bolded or stylized text in the context as important keywords.\n\n" |
|
"Context:\n{context}\n\n" |
|
) |
|
|
|
qa_prompt = ChatPromptTemplate.from_messages( |
|
[ |
|
("system", system_prompt), |
|
MessagesPlaceholder("chat_history"), |
|
("human", "{input}"), |
|
] |
|
) |
|
|
|
history_aware_retriever = create_history_aware_retriever(chat_model,ensemble_retriever,memory_prompt) |
|
|
|
question_answer_chain = create_stuff_documents_chain(chat_model, qa_prompt) |
|
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) |
|
|
|
return rag_chain |
|
|
|
rag_chain = build_rag_chain() |
|
chat_history = [] |
|
nairobi_time = datetime.now(ZoneInfo("Africa/Nairobi")) |
|
def get_greeting(): |
|
hour = nairobi_time.hour |
|
if 5 <= hour < 12: |
|
return "Good morning" |
|
elif 12 <= hour < 17: |
|
return "Good afternoon" |
|
elif 17<= hour < 22: |
|
return "Good evening" |
|
else: |
|
return "Good night" |
|
|
|
greeting = get_greeting() |
|
|
|
def chatku_fn(message, history): |
|
|
|
chat_history = [] |
|
for human, ai in history: |
|
chat_history.append(HumanMessage(content=human)) |
|
chat_history.append(AIMessage(content=ai)) |
|
|
|
|
|
response_stream = rag_chain.stream({ |
|
"input": message, |
|
"chat_history": chat_history |
|
}) |
|
|
|
partial_answer = "" |
|
for chunk in response_stream: |
|
delta = chunk.get("answer", "") |
|
partial_answer += delta |
|
|
|
time.sleep(0.06) |
|
yield partial_answer |
|
|
|
|
|
with gr.Blocks(fill_height = True) as demo: |
|
gr.Markdown( |
|
f"<h2 style='text-align: center;'>{greeting}!,any Queries about Kenyatta University comrade!?</h2>" |
|
) |
|
|
|
gr.ChatInterface( |
|
fn=chatku_fn, |
|
chatbot=gr.Chatbot(label="💬 ChatKU"), |
|
autoscroll=True |
|
) |
|
|
|
gr.Markdown( |
|
"⚠️ **ChatKU can make mistakes, verify important information.**", |
|
elem_id="footer" |
|
) |
|
if __name__ == "__main__": |
|
port = int(os.environ.get("PORT", 7860)) |
|
demo.launch(server_name="0.0.0.0", server_port=port) |