import os
import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import chromadb
from chromadb.utils import embedding_functions
from huggingface_hub import login
# Set Hugging Face token and authenticate
HF_TOKEN = os.getenv("HF_TOKEN")
login(token=HF_TOKEN)
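# Note: HF_TOKEN is assumed to be configured as a secret in the Space settings;
# it is needed if the fine-tuned model repository is gated or private.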
# Configure cache directory (important for Hugging Face Spaces)
CACHE_DIR = "/tmp/huggingface_cache"
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)
# Define your fine-tuned model repository and the local ChromaDB path
MODEL_REPO = "soureesh1211/fine-tuned-gemma-2b"
CHROMA_DB_PATH = "./chroma_db"
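# The app expects a pre-built ChromaDB directory (./chroma_db) containing a
# collection named "rag_collection" to already be present in the Space repository.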
# Cache the model so it is loaded once, not on every Streamlit rerun
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch.float16,
        device_map="auto",
        token=HF_TOKEN,
    )
    return model, tokenizer

model, tokenizer = load_model()
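# Note: a 2B-parameter model in float16 takes on the order of 5 GB of memory,
# so a GPU Space or an upgraded CPU Space is assumed here.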
def load_chroma_db():
    if not os.path.exists(CHROMA_DB_PATH):
        st.error(f"ChromaDB directory {CHROMA_DB_PATH} not found. Please upload your database.")
        return None
    chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
    collection = chroma_client.get_or_create_collection(
        name="rag_collection",
        embedding_function=embedding_functions.DefaultEmbeddingFunction()
    )
    return collection

collection = load_chroma_db()
# Streamlit Interface
st.title("RAG-Powered Code Assistant")
st.subheader("Fine-tuned Gemma-2B + ChromaDB")

query = st.text_input("Enter your query:")
num_results = st.slider("Number of retrieved documents:", min_value=1, max_value=5, value=1)
if st.button("Search"): | |
if not query: | |
st.warning("Please enter a query.") | |
else: | |
with st.spinner("π Retrieving relevant context from ChromaDB..."): | |
results = collection.query(query_texts=[query], n_results=num_results) | |
retrieved_docs = results['documents'][0] if results else [] | |
context = "\n".join(retrieved_docs) | |
prompt = f"### Context:\n{context}\n\n### Query:\n{query}\n\n### Answer:" | |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
with torch.no_grad(): | |
output = model.generate(**inputs, max_length=512, temperature=0.3, top_p=0.9) | |
response = tokenizer.decode(output[0], skip_special_tokens=True) | |
st.success("β Generated Response:") | |
st.markdown(response) | |
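# Assumed Space dependencies (requirements.txt), based on the imports above:
#   streamlit, torch, transformers, chromadb, huggingface_hub,
#   accelerate (needed for device_map="auto")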