# app.py: author Mattral; last commit 4c6bffd ("Update app.py"), 4.02 kB.
import gradio as gr
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from ctransformers import AutoModelForCausalLM
# Module-level setup: everything below runs once, when this file is executed,
# before any UI callback can fire.

# Load the embedding model.
# The same encoder embeds both the indexed PDF chunks (setup_database) and the
# incoming questions (answer), so stored vectors and queries are comparable.
encoder = SentenceTransformer('jinaai/jina-embedding-b-en-v1')
print("Embedding model loaded...")
# Load the LLM.
# NOTE(review): callback_manager is constructed here but never passed to `llm`
# nor referenced anywhere else in this file, so it has no effect as written.
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
# Llama-2-7B-Chat in the Q3_K_S GGUF file, loaded through ctransformers.
# Generation settings applied to every call: temperature 0.2,
# repetition_penalty 1.5, and at most 300 new tokens per answer.
llm = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-Chat-GGUF",
    model_file="llama-2-7b-chat.Q3_K_S.gguf",
    model_type="llama",
    temperature=0.2,
    repetition_penalty=1.5,
    max_new_tokens=300,
)
print("LLM loaded...")
# Initialize QdrantClient.
# Passing `path=` selects qdrant-client's local (embedded) mode with storage
# under ./db, so no separate Qdrant server is contacted.
client = QdrantClient(path="./db")
print("DB created...")
def get_chunks(text, chunk_size=250, chunk_overlap=50):
    """Split ``text`` into overlapping character chunks for embedding.

    Args:
        text: The full document text to split.
        chunk_size: Target maximum chunk length, measured with ``len``
            (i.e. in characters). Defaults to 250.
        chunk_overlap: Maximum number of characters shared by consecutive
            chunks, so context at a boundary appears in both. Defaults to 50.

    Returns:
        The list of chunk strings produced by
        ``RecursiveCharacterTextSplitter.split_text``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    return text_splitter.split_text(text)
def _read_pdf_text(file):
    """Return the text of every page of one PDF, pages separated by newlines.

    ``file`` may be a path, or an object carrying its path as a string in
    ``.name`` (such as a temp-file wrapper); anything else is handed to
    ``PdfReader`` unchanged.
    """
    name = getattr(file, "name", None)
    reader = PdfReader(name if isinstance(name, str) else file)
    # Join with "\n" so the last word of one page does not fuse with the first
    # word of the next; `or ""` tolerates pages that yield no text.
    return "\n".join(page.extract_text() or "" for page in reader.pages)


def setup_database(files):
    """Rebuild the ``my_facts`` Qdrant collection from the text of ``files``.

    Each PDF is read page by page, split with ``get_chunks``, embedded with
    ``encoder`` and stored with the chunk text as its ``"text"`` payload.
    The collection is recreated on every call, so only the PDFs passed to the
    most recent call remain searchable.

    Args:
        files: Iterable of PDFs (paths, or objects exposing a path as ``.name``).
    """
    all_chunks = []
    for file in files:
        all_chunks.extend(get_chunks(_read_pdf_text(file)))
    print(f"Total chunks: {len(all_chunks)}")
    print("Chunks are ready...")

    client.recreate_collection(
        collection_name="my_facts",
        vectors_config=models.VectorParams(
            size=encoder.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )
    print("Collection created...")

    if not all_chunks:
        # e.g. image-only PDFs: nothing to embed, so leave the collection empty
        # instead of asking the encoder for a zero-size batch.
        print("No text extracted; collection left empty.")
        return

    # One batched encode call instead of a model call per chunk; the result has
    # one embedding row per chunk, in input order.
    vectors = encoder.encode(all_chunks)
    records = [
        models.Record(id=idx, vector=vector.tolist(), payload={"text": chunk})
        for idx, (chunk, vector) in enumerate(zip(all_chunks, vectors))
    ]
    client.upload_records(
        collection_name="my_facts",
        records=records,
    )
    print("Records uploaded...")
def answer(question, top_k=3):
    """Answer ``question`` with the Llama-2 chat model, grounded in retrieved chunks.

    The question is embedded with ``encoder``. The ``top_k`` most similar
    chunks are fetched from the ``my_facts`` collection and placed in the user
    turn of a Llama-2 chat prompt, together with a system prompt that restricts
    the model to that context.

    Args:
        question: The user's question.
        top_k: Number of chunks to retrieve as context. Defaults to 3.

    Returns:
        The text ``llm`` generates for the assembled prompt.
    """
    hits = client.search(
        collection_name="my_facts",
        query_vector=encoder.encode(question).tolist(),
        limit=top_k,
    )
    # Keep retrieved chunks visibly separate so the model does not read two
    # unrelated fragments as one continuous sentence.
    context = "\n\n".join(hit.payload["text"] for hit in hits)

    system_prompt = (
        "You are a helpful co-worker, you will use the provided context to answer user questions.\n"
        "Read the given context before answering questions and think step by step. "
        "If you cannot answer a user question based on\n"
        "the provided context, inform the user. Do not use any other information for answering user. "
        "Provide a detailed answer to the question."
    )
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    instruction = f"Context: {context}\nUser: {question}"
    # Llama-2 chat layout: "[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]"
    # (note the spaces after [INST] and before [/INST]).
    prompt_template = f"{B_INST} {B_SYS}{system_prompt}{E_SYS}{instruction} {E_INST}"
    print(prompt_template)
    result = llm(prompt_template)
    return result
def chat(messages, files, question=None):
    """Handle one chat turn: optionally (re)index PDFs, then answer the latest question.

    Args:
        messages: Chat history as a list of ``{"role": ..., "content": ...}``
            dicts. May be ``None`` or empty.
        files: PDFs to index before answering, or a falsy value to keep the
            current collection.
        question: The new user question. When given and not blank, it is
            appended to the history as a user message and answered. When
            omitted, the content of the last message in ``messages`` is answered.

    Returns:
        A new history list. It ends with an assistant message containing
        either the answer or a request to upload PDFs when the ``my_facts``
        collection does not exist yet. A blank ``question``, or an empty
        history with nothing to answer, returns the history unchanged.
    """
    history = list(messages or [])
    if files:
        setup_database(files)

    if question is not None:
        if not question.strip():
            return history
        history.append({"role": "user", "content": question})

    existing = {collection.name for collection in client.get_collections().collections}
    if "my_facts" not in existing:
        # No PDFs indexed yet: answer() would otherwise search a missing collection.
        history.append(
            {"role": "assistant", "content": "Please upload PDF documents to initialize the database."}
        )
        return history

    if not history:
        return history

    response = answer(history[-1]["content"])
    history.append({"role": "assistant", "content": response})
    return history
with gr.Blocks() as demo:
    # The callbacks exchange {"role": ..., "content": ...} dicts, which is the
    # Chatbot "messages" format (the default pair format would not match).
    chatbot = gr.Chatbot(type="messages")
    file_input = gr.File(label="Upload PDFs", file_count="multiple")
    with gr.Row():
        # Integer scales keep the 85/15 width split; `container=False` is passed
        # to the constructor because Component.style() no longer exists.
        with gr.Column(scale=85):
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter your question here...",
                container=False,
            )
        with gr.Column(scale=15, min_width=0):
            send_btn = gr.Button("Send")

    def respond(messages, files, question):
        """Add the typed question to the history, answer it, and reset the inputs.

        Returns a triple for (chatbot, file_input, txt):
        - the updated history;
        - ``None`` to clear the file picker, so ``chat`` does not re-index the
          same PDFs on the next message;
        - ``""`` to clear the textbox.

        A blank question changes nothing. Any chosen files stay in the picker
        and are indexed with the next real question.
        """
        history = list(messages or [])
        if not question or not question.strip():
            return history, files, question
        # The new question must be the last message, because chat() answers
        # the last entry of the history it receives.
        history.append({"role": "user", "content": question})
        return chat(history, files), None, ""

    send_btn.click(respond, [chatbot, file_input, txt], [chatbot, file_input, txt])
    txt.submit(respond, [chatbot, file_input, txt], [chatbot, file_input, txt])

demo.launch()