Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,145 +1,92 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import re
|
| 3 |
-
import zipfile
|
| 4 |
-
import gradio as gr
|
| 5 |
-
from langchain_openai import ChatOpenAI
|
| 6 |
-
from langchain.embeddings import HuggingFaceEmbeddings
|
| 7 |
-
from langchain_chroma import Chroma
|
| 8 |
-
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
|
| 9 |
-
from langchain.chains import LLMChain
|
| 10 |
-
|
| 11 |
-
# Unzip vector DB if not already extracted
if not os.path.exists("geometry_chroma"):
    with zipfile.ZipFile("geometry_chroma.zip", 'r') as zip_ref:
        zip_ref.extractall(".")

# Load vector DB
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectordb = Chroma(persist_directory="geometry_chroma", embedding_function=embedding_model)
retriever = vectordb.as_retriever()

# Set OpenAI key (use Secrets or .env later).
# BUG FIX: os.environ values must be str — the original
# `os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")` raised
# TypeError at import time whenever the variable was unset. Guard instead.
_openai_key = os.getenv("OPENAI_API_KEY")
if _openai_key:
    os.environ["OPENAI_API_KEY"] = _openai_key
else:
    print("⚠️ OPENAI_API_KEY is not set; OpenAI calls will fail until it is configured.")

llm = ChatOpenAI(model_name="gpt-4.1", temperature=0.2)
|
| 25 |
-
|
| 26 |
-
# ✅ Prompt templates: one entry per UI "Prompt Type".
# All templates receive the retrieved SOL text as {context} and the
# user's topic as {query}.
templates = {
    # Five Q/A flashcards on the requested topic.
    "flashcard": PromptTemplate(
        input_variables=["context", "query"],
        template="""
{context}

Create 5 flashcards based on the topic: "{query}"
Each flashcard should include:
- A clear question
- A short answer
Focus on high school geometry understanding.
"""
    ),
    # Concept explanation + real-world example + class activity.
    "lesson plan": PromptTemplate(
        input_variables=["context", "query"],
        template="""
Given the following retrieved SOL text:
{context}

Generate a Geometry lesson plan based on: "{query}"
Include:
1. Simple explanation of the concept.
2. Real-world example.
3. Engaging class activity.
Be concise and curriculum-aligned for high school.
"""
    ),
    # Printable student handout: summary, worked example, practice.
    "worksheet": PromptTemplate(
        input_variables=["context", "query"],
        template="""
{context}

Create a student worksheet for: "{query}"
Include:
- Concept summary
- A worked example
- 3 practice problems
"""
    ),
    # Lesson plan variant focused on geometric proofs.
    "proofs": PromptTemplate(
        input_variables=["context", "query"],
        template="""
{context}

Generate a proof-focused geometry lesson plan for: "{query}"
Include:
- Student-friendly explanation
- Real-world connection
- One short class activity
"""
    ),
    # SOL-code lookup: answer must be copied verbatim from the context.
    "general question": ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template(
            """
You are a Virginia Geometry SOL assistant.

From the following SOL context:
{context}

Identify the SOL standard (e.g., G.RLT.1) that best matches this query: "{query}"

Respond with:
1. The exact SOL code (e.g., G.RLT.1)
2. The exact description line from the SOL guide

Do not summarize. Only copy from the context.
"""
        )
    ])
}
|
| 100 |
-
|
| 101 |
-
def generate_prompt_output(prompt_type, query, retriever=None, llm=None):
    """Retrieve SOL context for *query* and run the selected prompt template.

    Args:
        prompt_type: key into the module-level ``templates`` dict.
        query: user-entered topic; may contain an explicit SOL code
            (e.g. "G.RLT.1"), in which case retrieval is filtered to that
            standard instead of doing a similarity search.
        retriever: vector-store retriever; defaults to the module-level one.
        llm: chat model; defaults to the module-level one.

    Returns:
        The generated text, or an "❌ Error: ..." string on failure (the
        Gradio UI displays either directly).
    """
    try:
        # BUG FIX: the Gradio click handler only supplies (prompt_type, query),
        # so retriever/llm default to the module-level instances. globals() is
        # needed because the parameters shadow the module names.
        if retriever is None:
            retriever = globals()["retriever"]
        if llm is None:
            llm = globals()["llm"]

        # Explicit SOL code in the query, e.g. "G.RLT.1"?
        # (Top-level `import re` covers this; the original re-imported locally.)
        sol_match = re.search(r"\bG\.[A-Z]+\.\d+\b", query)
        matched_code = sol_match.group(0) if sol_match else None

        if matched_code:
            # NOTE(review): reaches into Chroma's private `_collection` API to
            # do an exact metadata match — works today but may break across
            # langchain/chromadb upgrades; confirm on dependency bumps.
            all_docs = retriever.vectorstore._collection.get(include=['documents', 'metadatas'])
            filtered = [
                doc_text
                for doc_text, metadata in zip(all_docs['documents'], all_docs['metadatas'])
                if metadata.get('standard') == matched_code
            ]
            context = "\n\n".join(filtered)
        else:
            # No explicit code: fall back to semantic similarity search.
            docs = retriever.get_relevant_documents(query)
            context = "\n\n".join(doc.page_content for doc in docs)

        chain = LLMChain(llm=llm, prompt=templates[prompt_type])
        return chain.run({"context": context, "query": query}).strip()

    except Exception as e:
        # Surface the error in the output textbox instead of crashing the app.
        return f"❌ Error: {str(e)}"
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
# ✅ Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 📐 Geometry Teaching Assistant")

    with gr.Row():
        query = gr.Textbox(label="Enter a geometry topic")
        prompt_type = gr.Dropdown(
            ["general question", "lesson plan", "worksheet", "proofs", "flashcard"],
            value="general question",
            label="Prompt Type"
        )

    output = gr.Textbox(label="Generated Output", lines=12, interactive=True)
    btn = gr.Button("Generate")

    # BUG FIX: generate_prompt_output takes (prompt_type, query, retriever, llm)
    # but Gradio only passes the listed inputs; the original wiring therefore
    # raised a missing-argument TypeError on every click. Bind the module-level
    # retriever/llm here so the handler arity matches the UI inputs.
    btn.click(
        fn=lambda pt, q: generate_prompt_output(pt, q, retriever, llm),
        inputs=[prompt_type, query],
        outputs=output,
    )
    # NOTE(review): no demo.launch() is visible in this file — confirm the
    # hosting environment auto-launches `demo`, otherwise add a launch call.
|
| 144 |
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
import gradio as gr
|
| 3 |
import os
|
| 4 |
+
from transformers import pipeline
|
| 5 |
+
from sentence_transformers import SentenceTransformer
|
| 6 |
+
import faiss
|
| 7 |
+
import numpy as np
|
| 8 |
+
import json
|
| 9 |
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
# --- Load necessary components for the RAG system ---
# These paths are relative to the Space's root directory.
FAISS_INDEX_PATH = "sol_faiss_index.bin"
DOCUMENT_IDS_PATH = "sol_document_ids.json"
# Pre-processed SOL documents (the `documents` list saved as JSON) — uploaded
# to the Space alongside the index.
SOL_DOCUMENTS_PATH = "sol_documents.json"

# Load SentenceTransformer embedding model. In a Space the weights are
# downloaded to the cache on first use; internet access must be enabled.
try:
    model = SentenceTransformer('all-mpnet-base-v2')
except Exception as e:
    print(f"Error loading SentenceTransformer model: {e}")
    print("Attempting to load from local cache or download on first use.")
    # BUG FIX: the original except fell through WITHOUT assigning `model`,
    # so any later `if not model` check raised NameError instead of the
    # intended graceful "not initialized" message.
    model = None

# Load FAISS index.
try:
    index = faiss.read_index(FAISS_INDEX_PATH)
except Exception as e:
    print(f"Error loading FAISS index: {e}")
    index = None  # placeholder if loading fails

# Load document IDs (row i of the index maps to document_ids[i]).
try:
    with open(DOCUMENT_IDS_PATH, "r") as f:
        document_ids = json.load(f)
except Exception as e:
    print(f"Error loading document IDs: {e}")
    document_ids = []  # placeholder if loading fails

# Load the SOL document contents.
try:
    with open(SOL_DOCUMENTS_PATH, "r") as f:
        documents = json.load(f)
except Exception as e:
    print(f"Error loading sol documents: {e}")
    documents = []  # placeholder

# Load LLM for generation. 'google/gemma-2b-it' is a good option; set up
# secrets/environment variables first if a gated or paid model is used.
try:
    llm_pipeline = pipeline("text-generation", model="google/gemma-2b-it")
except Exception as e:
    print(f"Error loading LLM pipeline: {e}")
    llm_pipeline = None  # placeholder
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def retrieve_and_generate_app(query, top_k=3):
|
| 73 |
+
if not model or not index or not document_ids or not documents or not llm_pipeline:
|
| 74 |
+
return "System not fully initialized. Please check logs for missing components."
|
| 75 |
+
|
| 76 |
+
# 1. Query Embedding
|
| 77 |
+
query_embedding = model.encode([query])
|
| 78 |
+
|
| 79 |
+
# 2. Retrieval using FAISS
|
| 80 |
+
D, I = index.search(query_embedding, top_k)
|
| 81 |
+
|
| 82 |
+
retrieved_docs = []
|
| 83 |
+
for i in I[0]:
|
| 84 |
+
sol_id = document_ids[i]
|
| 85 |
+
# Find the full content of the retrieved SOL
|
| 86 |
+
# This relies on the 'documents' list being correctly loaded and matching by ID
|
| 87 |
+
retrieved_content = next((doc["content"] for doc in documents if doc["id"] == sol_id), "Content not found.")
|
| 88 |
+
retrieved_docs.append({"id": sol_id, "content": retrieved_content})
|
| 89 |
+
|
| 90 |
+
context = "\n\n".join([f"SOL {doc['id']}: {doc['content']}" for doc in retrieved_docs])
|
| 91 |
+
|
| 92 |
+
prompt = f"""
|