# Spaces:
# Runtime error
# Runtime error
import gradio as gr
from gradio_pdf import PDF
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from ctransformers import AutoModelForCausalLM
# Load the embedding model once at import time; `encoder` is used both when
# indexing PDF chunks and when embedding the user's question.
encoder = SentenceTransformer('jinaai/jina-embedding-b-en-v1')
print("Embedding model loaded...")

# Load the LLM
# NOTE(review): `callback_manager` is never passed to the model in the visible
# code; it is kept so that any other reference to the name keeps working.
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
llm = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-Chat-GGUF",
    model_file="llama-2-7b-chat.Q3_K_S.gguf",
    model_type="llama",
    temperature=0.2,
    repetition_penalty=1.5,
    max_new_tokens=300,
)
print("LLM loaded...")
def get_chunks(text, chunk_size=250, chunk_overlap=50):
    """Split *text* into overlapping chunks.

    Args:
        text: Raw text to split.
        chunk_size: Maximum chunk length, measured with ``len``.
        chunk_overlap: Maximum overlap between consecutive chunks, measured
            with ``len``.

    Returns:
        The list of chunk strings returned by
        ``RecursiveCharacterTextSplitter.split_text``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    return text_splitter.split_text(text)
def setup_database(files):
    """Index the text of the given PDFs in the local ``my_facts`` collection.

    Each file is read with ``PdfReader``; its page texts are concatenated and
    split with ``get_chunks``. The chunks are embedded with the module-level
    ``encoder`` and uploaded to a freshly recreated Qdrant collection stored
    under ``./db``, so any previously indexed content is replaced.

    Args:
        files: Iterable of PDF sources accepted by ``PdfReader``.
    """
    all_chunks = []
    for file in files:
        reader = PdfReader(file)
        # Defensive: a page without extractable text must not break the join.
        text = "".join(page.extract_text() or "" for page in reader.pages)
        all_chunks.extend(get_chunks(text))

    client = QdrantClient(path="./db")
    try:
        client.recreate_collection(
            collection_name="my_facts",
            vectors_config=models.VectorParams(
                size=encoder.get_sentence_embedding_dimension(),
                distance=models.Distance.COSINE,
            ),
        )
        if not all_chunks:
            # Nothing to embed; leave the collection empty.
            return

        # Encode all chunks in one batch instead of one model call per chunk.
        vectors = encoder.encode(all_chunks)
        records = [
            models.Record(
                id=idx,
                vector=vector.tolist(),
                # answer_question() reads the text back under this per-id key.
                payload={f"chunk_{idx}": chunk},
            )
            for idx, (chunk, vector) in enumerate(zip(all_chunks, vectors))
        ]
        client.upload_records(
            collection_name="my_facts",
            records=records,
        )
    finally:
        # Release the on-disk storage so that a later
        # QdrantClient(path="./db") can open it again.
        client.close()
def answer_question(question):
    """Answer *question* from the chunks stored in the ``my_facts`` collection.

    The question is embedded with the module-level ``encoder``, the three
    closest chunks are fetched from the Qdrant store under ``./db``, and the
    module-level ``llm`` is called with a Llama-2 style ``[INST]``/``<<SYS>>``
    prompt built from those chunks.

    Args:
        question: The user's question text.

    Returns:
        Whatever ``llm(prompt)`` returns for the assembled prompt.
    """
    client = QdrantClient(path="./db")
    try:
        hits = client.search(
            collection_name="my_facts",
            query_vector=encoder.encode(question).tolist(),
            limit=3,
        )
    finally:
        # Release the on-disk storage so the next call can reopen it.
        client.close()

    # setup_database() stores each chunk under the key "chunk_<id>".
    context = " ".join(hit.payload[f"chunk_{hit.id}"] for hit in hits)

    system_prompt = (
        "You are a helpful co-worker, you will use the provided context to answer user questions.\n"
        "Read the given context before answering questions and think step by step. If you cannot answer a user question based on\n"
        "the provided context, inform the user. Do not use any other information for answering user. Provide a detailed answer to the question."
    )
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    instruction = f"Context: {context}\nUser: {question}"
    prompt_template = f"{B_INST}{B_SYS}{system_prompt}{E_SYS}{instruction}{E_INST}"
    return llm(prompt_template)
def chat(messages, files):
    """Handle one interface event: index uploaded PDFs, answer the last message.

    Args:
        messages: Chat history; each entry is read as a mapping with a
            ``"text"`` key. May be empty or ``None``.
        files: Uploaded PDF files; when truthy they are (re)indexed with
            ``setup_database`` before answering.

    Returns:
        The same ``messages`` object. When it was non-empty, an entry
        ``{"text": <answer>, "is_user": False}`` has been appended in place.

    NOTE(review): the ``"text"``/``"is_user"`` schema is what this function
    reads and writes; confirm that the ``gr.Chatbot`` component wired to it
    actually supplies and accepts entries of this shape.
    """
    if files:
        setup_database(files)

    if not messages:
        # No question to answer; return the history untouched.
        return messages

    question = messages[-1]["text"]
    answer = answer_question(question)
    messages.append({"text": answer, "is_user": False})
    return messages
# Build the Gradio UI: a chat history plus a multi-file PDF upload feed
# `chat`, whose return value is rendered in the output chat box.
# NOTE(review): the title/description literals contain mis-encoded (mojibake)
# emoji. They are left untouched because the intended characters cannot be
# recovered from this file; restore them from the original source.
# NOTE(review): with live=True the interface re-runs `chat` on input changes,
# and `chat` re-indexes the PDFs whenever files are present; confirm this is
# intended.
interface = gr.Interface(
    fn=chat,
    inputs=[
        gr.Chatbot(label="Chat"),
        gr.File(label="Upload PDFs", file_count="multiple"),
    ],
    outputs=gr.Chatbot(label="Chat"),
    title="Q&A with PDFs π©π»βπ»πβπ»π‘",
    description="This app facilitates a conversation with PDFs uploadedπ‘",
    theme="soft",
    live=True,
)
interface.launch()