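# PDF Data Extractor: a Streamlit app that indexes uploaded PDFs with LlamaIndex
# and answers user queries against them using a HuggingFace LLM (Gemma) together
# with sentence-transformer embeddings.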
import os
import streamlit as st
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core.prompts.prompts import SimpleInputPrompt
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index.embeddings.langchain import LangchainEmbedding
import torch
# Set the HuggingFace token needed to download gated models such as Gemma.
# Assumes the token is stored in Streamlit secrets under the key "HF_TOKEN".
os.environ["HF_TOKEN"] = st.secrets["HF_TOKEN"]
# Streamlit app
st.title("PDF Data Extractor")

# File uploader for PDFs
uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)

# Text input for the query prompt
user_query = st.text_input("Enter your query")

# Button to trigger processing
if st.button("Extract Data"):
    if uploaded_files and user_query:
        # Save uploaded files to a local directory so SimpleDirectoryReader can pick them up
        os.makedirs("./docs", exist_ok=True)
        for uploaded_file in uploaded_files:
            with open(os.path.join("./docs", uploaded_file.name), "wb") as f:
                f.write(uploaded_file.getbuffer())

        # Load documents from the specified directory
        documents = SimpleDirectoryReader("./docs").load_data()
        # Define system prompt for the LLM
        system_prompt = """
        You are a data extractor. Your goal is to analyze the given PDF document and extract the table containing information relevant to the user query.
        """

        # Define query wrapper prompt
        query_wrapper_prompt = SimpleInputPrompt("{query_str}")
        # Initialize the LLM (Gemma is a gated model, so the HF token set above is required)
        llm = HuggingFaceLLM(
            context_window=4096,
            max_new_tokens=256,
            generate_kwargs={"temperature": 0.0, "do_sample": False},
            system_prompt=system_prompt,
            query_wrapper_prompt=query_wrapper_prompt,
            tokenizer_name="google/gemma-1.1-2b-it",
            model_name="google/gemma-1.1-2b-it",
            device_map="auto",
            model_kwargs={"torch_dtype": torch.float16}
        )
        st.write("LLM loaded successfully")
        # Initialize the embedding model
        embed_model = LangchainEmbedding(
            HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        )

        # Create service context with appropriate configurations
        service_context = ServiceContext.from_defaults(
            chunk_size=1024,
            llm=llm,
            embed_model=embed_model
        )
st.write("Before Vector Index") | |
index = VectorStoreIndex.from_documents(documents, service_context=service_context) | |
st.write("After Vector Index") | |
# Create query engine from the index | |
query_engine = index.as_query_engine() | |
# Execute query | |
response = query_engine.query(user_query) | |
# Display response | |
st.write("Generated Response:") | |
st.write(response) | |
    else:
        st.error("Please upload PDF files and enter a query.")
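
# To run locally (assuming this script is saved as app.py):
#   streamlit run app.py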