|
import gradio as gr |
|
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration |
|
import fitz |
|
from datasets import load_dataset |
|
from llama_index.core import Document, VectorStoreIndex, StorageContext, load_index_from_storage, Settings |
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding |
|
from llama_index.llms.ollama import Ollama |
|
|
|
|
|
# Facebook RAG-Sequence pipeline (tokenizer + retriever + generator).
# Used only by ``rag_answer`` below; the Gradio UI path goes through llama_index.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")

# NOTE(review): for a custom index, ``passages_path`` is normally a serialized
# passages dataset directory and the faiss file goes in ``index_path`` —
# confirm that pointing ``passages_path`` at a ``.faiss`` file loads correctly.
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="custom", passages_path="my_knowledge_base.faiss")

model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
|
|
|
|
|
# Embedding model used by llama_index when building/querying the vector store.
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")




# Local Ollama-served LLM; 60 s timeout per request.
llm = Ollama(model="llama3:instruct", request_timeout=60.0)




# Global llama_index configuration: every index/query built after this point
# picks up this LLM, chunking size, and embedding model.
Settings.llm = llm

Settings.chunk_size = 512

Settings.embed_model = embed_model
|
|
|
|
|
def extract_text_from_pdf(pdf_files):
    """Extract the full text of each uploaded PDF.

    Args:
        pdf_files: iterable of uploaded-file objects, each exposing a
            ``.name`` attribute holding a filesystem path (Gradio upload API).

    Returns:
        list[str]: one concatenated text string per input PDF, in input order.
    """
    texts = []
    for pdf in pdf_files:
        # Context manager guarantees the document handle is closed even if a
        # page fails to parse — the original version leaked the handle.
        with fitz.open(pdf.name) as doc:
            # join() over a generator avoids quadratic repeated string +=.
            texts.append("".join(page.get_text() for page in doc))
    return texts
|
|
|
|
|
def rag_answer(question, pdf_files):
    # Answer *question* with the facebook/rag-sequence-nq pipeline, using the
    # extracted text of *pdf_files* as extra context.
    # NOTE(review): this function is not wired into the Gradio UI below.
    texts = extract_text_from_pdf(pdf_files)

    # All PDFs flattened into one space-joined context string.
    context = " ".join(texts)

    inputs = tokenizer(question, return_tensors="pt")

    # NOTE(review): ``context_input`` is not a documented ``generate()`` kwarg
    # for RagSequenceForGeneration — the retriever normally fetches context on
    # its own; verify this call does not raise on unexpected keyword.
    outputs = model.generate(**inputs, context_input=context)

    # Return the first (and only) decoded sequence for the single question.
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
|
|
|
|
def create_vector_store_index(documents):
    """Build a vector index over *documents* and persist it to ``pdf_docs/``.

    Returns the freshly built ``VectorStoreIndex`` so callers can query it
    immediately without a reload round-trip.
    """
    built = VectorStoreIndex.from_documents(documents)
    # Persist to disk so the next startup can load instead of re-embedding.
    built.storage_context.persist(persist_dir="pdf_docs")
    return built
|
|
|
|
|
# Source corpus for the vector index; replace 'your-dataset-name' with a real
# HuggingFace dataset id before running.
pdf_docs = load_dataset('your-dataset-name', split='train')

# A HuggingFace ``datasets.Dataset`` is iterated directly and yields one dict
# per row; ``iterrows()`` is a pandas API and raised AttributeError here.
documents = [
    Document(text=row['text'], metadata={'title': row['title']})
    for row in pdf_docs
]
|
|
|
|
|
# Reuse a previously persisted index when one exists on disk; otherwise build
# it from ``documents`` and persist it for the next run.
try:
    storage_context = StorageContext.from_defaults(persist_dir="pdf_docs")
    vector_index = load_index_from_storage(storage_context)
except Exception:
    # First run: persist dir missing/empty. ``except Exception`` is narrower
    # than the original bare ``except:``, which would also have swallowed
    # KeyboardInterrupt and SystemExit.
    vector_index = create_vector_store_index(documents)




# similarity_top_k=10: feed the ten most similar chunks to the LLM per query.
query_engine = vector_index.as_query_engine(similarity_top_k=10)
|
|
|
|
|
def query(text):
    """Run *text* through the module-level query engine; return the raw response object."""
    return query_engine.query(text)
|
|
|
def interface(text):
    """Gradio callback: answer *text* and return the response as a plain string."""
    result = query(text)
    # ``.response`` holds the synthesized answer text of the llama_index result.
    return result.response
|
|
|
|
|
# Chat UI. The centering CSS belongs in ``Blocks(css=...)`` — the original
# passed the raw CSS string as literal Markdown text, so "h1 {text-align: ...}"
# was rendered on screen instead of styling the heading.
with gr.Blocks(
    css="h1 {text-align: center; display: block;}",
    theme=gr.themes.Glass().set(
        block_title_text_color="black",
        body_background_fill="black",
        input_background_fill="black",
        body_text_color="white",
    ),
) as demo:

    gr.Markdown("# Information Custodian Chat Agent")

    with gr.Row():
        # Read-only answer area.
        output_text = gr.Textbox(lines=20)

    with gr.Row():
        input_text = gr.Textbox(label='Enter your query here')

    # Pressing Enter in the input box routes the query through ``interface``.
    input_text.submit(fn=interface, inputs=input_text, outputs=output_text)


# share=True exposes a public gradio.live tunnel URL in addition to localhost.
demo.launch(share=True)
|
|