# RAG_QnA/app.py
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core.prompts.prompts import SimpleInputPrompt
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import gradio as gr
import torch
import os
import pickle
# Example setup for cache directories
os.environ['TRANSFORMERS_CACHE'] = './transformers_cache'
os.environ['TORCH_HOME'] = './torch_cache'
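# Note: recent transformers releases deprecate TRANSFORMERS_CACHE in favor of HF_HOME;
# setting HF_HOME as well is a hedged fallback if the variable above stops being honored.
# os.environ['HF_HOME'] = './hf_cache'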
# Load or process data
processed_data_path = 'processed_data.pkl'
if os.path.exists(processed_data_path):
    with open(processed_data_path, 'rb') as f:
        documents = pickle.load(f)
else:
    documents = SimpleDirectoryReader("Data").load_data()
    with open(processed_data_path, 'wb') as f:
        pickle.dump(documents, f)
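# Delete processed_data.pkl to force the Data directory to be re-read after its contents change.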
# Define system and query wrapper prompts
system_prompt = "You are a Q&A assistant. Your goal is to answer questions as accurately as possible based on the instructions and context provided."
query_wrapper_prompt = SimpleInputPrompt("{query_str}")
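# The wrapper above passes the query through unchanged. phi-2's model card suggests an
# instruct-style QA format, so a variant like the one below (an optional tweak, not part of
# the original app) may yield cleaner answers:
# query_wrapper_prompt = SimpleInputPrompt("Instruct: {query_str}\nOutput:")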
# Initialize models and services without using Settings
llm = HuggingFaceLLM(
    context_window=4096,
    max_new_tokens=256,
    generate_kwargs={"temperature": 0.0, "do_sample": False},
    system_prompt=system_prompt,
    query_wrapper_prompt=query_wrapper_prompt,
    tokenizer_name="microsoft/phi-2",
    model_name="microsoft/phi-2",
    # device_map="cuda",
    model_kwargs={"torch_dtype": torch.bfloat16},
)
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
# Build or load the vector index. In current llama_index releases the embedding model can be
# passed to from_documents directly, so no global Settings/ServiceContext object is needed here.
index_path = 'vector_store_index.pkl'
if os.path.exists(index_path):
    with open(index_path, 'rb') as f:
        index = pickle.load(f)
else:
    # The LLM is not needed to build the index; it is supplied to the query engine below.
    index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)
    with open(index_path, 'wb') as f:
        pickle.dump(index, f)
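# Caveat: pickling a VectorStoreIndex can be fragile across llama_index versions. A sketch of the
# library's own persistence API (assuming the default simple document/index stores) would be:
#   index.storage_context.persist(persist_dir="./storage")
# and, to reload later:
#   from llama_index.core import StorageContext, load_index_from_storage
#   index = load_index_from_storage(
#       StorageContext.from_defaults(persist_dir="./storage"), embed_model=embed_model
#   )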
# Answer synthesis uses the local phi-2 LLM configured above.
query_engine = index.as_query_engine(llm=llm)
def predict(input_text):
    # Retrieve relevant chunks and synthesize an answer, returning plain text for Gradio.
    response = query_engine.query(input_text)
    return str(response)
# Set up Gradio interface
iface = gr.Interface(fn=predict, inputs="text", outputs="text", title="Q&A Assistant")
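# On Hugging Face Spaces the share=True flag is unnecessary (the Space is already served publicly)
# and Gradio may warn that it is unsupported there; a plain iface.launch() should suffice in that case.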
iface.launch(share=True)