fine-tune-rag / app.py
import os

# Configure the cache directory BEFORE importing transformers / huggingface_hub,
# so the writable /tmp cache is actually picked up (important on Hugging Face Spaces).
CACHE_DIR = "/tmp/huggingface_cache"
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)

import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import chromadb
from chromadb.utils import embedding_functions
from huggingface_hub import login

# Authenticate with the Hugging Face Hub (skipped if no token is configured)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Define your fine-tuned model repository
MODEL_REPO = "soureesh1211/fine-tuned-gemma-2b"
CHROMA_DB_PATH = "./chroma_db"
@st.cache_resource
def load_model():
    """Load the fine-tuned model and tokenizer once and cache them across Streamlit reruns."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch.float16,
        device_map="auto",
        token=HF_TOKEN,
    )
    return model, tokenizer

model, tokenizer = load_model()

@st.cache_resource
def load_chroma_db():
    """Open the persistent ChromaDB collection used for retrieval, or return None if it is missing."""
    if not os.path.exists(CHROMA_DB_PATH):
        st.error(f"ChromaDB directory {CHROMA_DB_PATH} not found. Please upload your database.")
        return None
    chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
    collection = chroma_client.get_or_create_collection(
        name="rag_collection",
        embedding_function=embedding_functions.DefaultEmbeddingFunction(),
    )
    return collection

collection = load_chroma_db()
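
# Note: this app assumes ./chroma_db already contains an indexed "rag_collection".
# A minimal, illustrative sketch of how such a collection could be built offline
# (the ids and documents below are placeholders, not part of this app):
#
#   client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
#   col = client.get_or_create_collection(
#       name="rag_collection",
#       embedding_function=embedding_functions.DefaultEmbeddingFunction(),
#   )
#   col.add(ids=["doc-0"], documents=["example code snippet or documentation chunk"])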
# Streamlit Interface
st.title("🔍 RAG-Powered Code Assistant")
st.subheader("Fine-tuned Gemma-2B + ChromaDB")
query = st.text_input("Enter your query:")
num_results = st.slider("Number of retrieved documents:", min_value=1, max_value=5, value=1)
if st.button("Search"):
    if not query:
        st.warning("Please enter a query.")
    elif collection is None:
        st.error("ChromaDB collection is unavailable. Please upload your database and restart the app.")
    else:
        with st.spinner("🔎 Retrieving relevant context from ChromaDB..."):
            results = collection.query(query_texts=[query], n_results=num_results)
            retrieved_docs = results["documents"][0] if results and results.get("documents") else []
            context = "\n".join(retrieved_docs)
            prompt = f"### Context:\n{context}\n\n### Query:\n{query}\n\n### Answer:"
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            with torch.no_grad():
                # do_sample=True is required for temperature/top_p to take effect;
                # max_new_tokens bounds the answer length rather than prompt + answer.
                output = model.generate(
                    **inputs, max_new_tokens=512, do_sample=True, temperature=0.3, top_p=0.9
                )
            # Decode only the newly generated tokens so the prompt is not echoed back.
            response = tokenizer.decode(
                output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
            )
        st.success("✅ Generated Response:")
        st.markdown(response)
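
# Assumed local usage (on Hugging Face Spaces the platform launches the app automatically):
#   streamlit run app.py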