import os
import gradio as gr
from huggingface_hub import InferenceClient
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import PyPDFLoader, UnstructuredFileLoader, CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
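# Dependency note (an assumption, not part of the original file): these imports typically map to
# roughly the following pip packages in the Space's requirements.txt -- gradio, huggingface_hub,
# langchain, langchain-community, chromadb (Chroma backend), sentence-transformers (BGE embeddings),
# pypdf (PDF loader), and unstructured plus python-docx (for .txt/.docx loading).
# Exact package names and versions may vary with your LangChain release.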
# Load Hugging Face API token
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
if not HF_API_TOKEN:
    raise ValueError("Hugging Face API token is not set in environment variables.")

# Initialize Zephyr client
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=HF_API_TOKEN)
# Load and chunk documents based on file type
def load_documents(file_path):
    if file_path.endswith(".pdf"):
        loader = PyPDFLoader(file_path)
    elif file_path.endswith(".txt") or file_path.endswith(".docx"):
        loader = UnstructuredFileLoader(file_path)
    elif file_path.endswith(".csv"):
        loader = CSVLoader(file_path)
    else:
        raise ValueError("Unsupported file format")
    documents = loader.load()
    # Split into overlapping chunks so the retriever can return focused passages
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return text_splitter.split_documents(documents)
# Create vector store
def create_vector_store(documents, persist_dir="vector_db"):
    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-large-en",
        model_kwargs={"device": "cpu"},
    )
    vector_store = Chroma.from_documents(documents, embeddings, persist_directory=persist_dir)
    return vector_store
# Initialize retriever and vector store
persist_dir = "vector_db"
retriever = None  # Will be dynamically updated after the first upload
# Handle uploads and queries (argument order must match the gr.Interface inputs below)
def handle_query(file, system_message, message, max_tokens, temperature, top_p):
    global retriever
    if file:  # If a file is uploaded, (re)build the vector store from it
        # gr.File may yield a path string or a file object depending on the Gradio version
        file_path = file if isinstance(file, str) else file.name
        documents = load_documents(file_path)
        vector_store = create_vector_store(documents, persist_dir)
        retriever = vector_store.as_retriever()
    if not retriever:
        # This function is a generator, so yield (not return) the error message
        yield "No documents uploaded yet. Please upload a file first."
        return
    # Retrieve relevant context
    relevant_docs = retriever.get_relevant_documents(message)
    context = "\n".join([doc.page_content for doc in relevant_docs])
    # Build the prompt
    prompt = f"""
Use the following context to answer the user's question.
Context:
{context}
Question:
{message}
Answer:"""
    # Stream the answer from Zephyr token by token
    response = ""
    for msg in client.chat_completion(
        messages=[{"role": "system", "content": system_message}, {"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = msg.choices[0].delta.content or ""  # delta.content can be None on some chunks
        response += token
        yield response
# Gradio app setup | |
demo = gr.Interface( | |
fn=handle_query, | |
inputs=[ | |
gr.File(label="Upload Document"), | |
gr.Textbox(value="You are a knowledgeable assistant.", label="System Message"), | |
gr.Textbox(label="Enter Your Query", placeholder="Ask a question..."), | |
gr.Slider(1, 2048, step=1, value=512, label="Max Tokens"), | |
gr.Slider(0.1, 4.0, step=0.1, value=0.7, label="Temperature"), | |
gr.Slider(0.1, 1.0, step=0.05, value=0.95, label="Top-p"), | |
], | |
outputs="text", | |
title="RAG with Zephyr-7B", | |
description="Upload documents and ask questions using RAG.", | |
) | |
if __name__ == "__main__":
    demo.launch()
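# Usage note (not in the original file; paths and token value are placeholders): to run outside a
# Space, set the token first, e.g. `export HF_API_TOKEN=<your token>`, then `python app.py`.
# On Hugging Face Spaces, add HF_API_TOKEN as a repository secret so os.getenv can read it at startup.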