# model_utils.py
from typing import List, Optional
import re
import os
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import qa_store
from loader import (
    load_curriculum,
    load_manual_qa,
    rebuild_combined_qa,
    load_glossary,
    sync_download_manual_qa,
    sync_download_cache,
    sync_upload_cache,
    CACHE_PATH,
)
# -----------------------------
# Base chat model
# -----------------------------
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Local embeddings cache; assumed to point at the same file as loader.CACHE_PATH,
# which the admin rebuild and cloud sync below write to.
CACHE_FILE = os.path.join(BASE_DIR, "data", "cached_embeddings.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Use float16 on GPU to save memory, float32 on CPU
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype)
model.to(device)
model.eval()
embed_model = SentenceTransformer(EMBED_MODEL_NAME)
embed_model = embed_model.to(device)
# Number of textbook entries to include in the RAG context
MAX_CONTEXT_ENTRIES = 4
# -----------------------------
# Embedding builders
# -----------------------------
def admin_force_rebuild_cache() -> str:
    """
    Force re-computation of all embeddings and upload the cache to cloud storage.
    Triggered by the Teacher Panel button.
    """
    status_msg = []

    # 1. Compute textbook embeddings
    print("[ADMIN] Rebuilding Textbook Embeddings...")
    texts = []
    for e in qa_store.ENTRIES:
        chapter = e.get("chapter_title", "") or ""
        section = e.get("section_title", "") or ""
        text = e.get("text", "") or ""
        texts.append(f"{chapter}\n{section}\n{text}")
    if texts:
        qa_store.TEXT_EMBEDDINGS = embed_model.encode(texts, convert_to_tensor=True)
        status_msg.append(f"✅ Textbook ({len(texts)})")

    # 2. Compute glossary embeddings
    print("[ADMIN] Rebuilding Glossary Embeddings...")
    gloss_texts = [f"{i.get('term')} :: {i.get('definition')}" for i in qa_store.GLOSSARY]
    if gloss_texts:
        qa_store.GLOSSARY_EMBEDDINGS = embed_model.encode(
            gloss_texts, convert_to_numpy=True, normalize_embeddings=True
        )
        status_msg.append(f"✅ Glossary ({len(gloss_texts)})")

    # 3. Save to disk
    print("[ADMIN] Saving to disk...")
    torch.save({
        "textbook": qa_store.TEXT_EMBEDDINGS,
        "glossary": qa_store.GLOSSARY_EMBEDDINGS,
    }, CACHE_PATH)

    # 4. Upload to cloud
    upload_status = sync_upload_cache()
    return f"Rebuild Complete: {', '.join(status_msg)} | {upload_status}"
def _build_entry_embeddings() -> None:
    """
    Load pre-computed embeddings if available, otherwise build them.
    """
    if not getattr(qa_store, "ENTRIES", None):
        qa_store.TEXT_EMBEDDINGS = None
        return

    # 1. Try loading from cache
    if os.path.exists(CACHE_FILE):
        try:
            print(f"[INFO] Loading cached embeddings from {CACHE_FILE}...")
            cache = torch.load(CACHE_FILE, map_location=device)
            if "textbook" in cache and cache["textbook"] is not None:
                # Validate that the cache size matches the current data
                if len(cache["textbook"]) == len(qa_store.ENTRIES):
                    qa_store.TEXT_EMBEDDINGS = cache["textbook"].to(device)
                    print("[INFO] Textbook embeddings loaded successfully.")
                    return
                else:
                    print("[WARN] Cache size mismatch (data changed?). Re-computing...")
        except Exception as e:
            print(f"[WARN] Failed to load cache: {e}")

    # 2. Fallback: compute from scratch (the old slow way)
    print("[INFO] Computing textbook embeddings from scratch...")
    texts: List[str] = []
    for e in qa_store.ENTRIES:
        chapter = e.get("chapter_title", "") or e.get("chapter", "") or ""
        section = e.get("section_title", "") or e.get("section", "") or ""
        text = e.get("text", "") or ""
        combined = f"{chapter}\n{section}\n{text}"
        texts.append(combined)
    qa_store.TEXT_EMBEDDINGS = embed_model.encode(
        texts,
        convert_to_tensor=True,
        show_progress_bar=False,
    )
def _build_glossary_embeddings() -> None:
    """Create embeddings for glossary terms + definitions."""
    if not getattr(qa_store, "GLOSSARY", None):
        qa_store.GLOSSARY_EMBEDDINGS = None
        print("[INFO] No glossary terms to embed.")
        return

    # Embed term + definition together
    texts = [
        f"{item.get('term', '')} :: {item.get('definition', '')}"
        for item in qa_store.GLOSSARY
    ]
    embeddings = embed_model.encode(
        texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    qa_store.GLOSSARY_EMBEDDINGS = embeddings
    print(f"[INFO] Built glossary embeddings for {len(texts)} terms.")
# -----------------------------
# Load data once at import time
# -----------------------------
sync_download_manual_qa()
sync_download_cache()
load_curriculum()
load_manual_qa()
load_glossary()
rebuild_combined_qa()
_build_entry_embeddings()
_build_glossary_embeddings()
# -----------------------------
# System prompt (Natural Science)
# -----------------------------
SYSTEM_PROMPT = (
    "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານວິທະຍາສາດທໍາມະຊາດ "
    "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1-ມ.4. "
    "ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "
    "ໃຫ້ອີງຈາກຂໍ້ມູນອ້າງອີງຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
)
# -----------------------------
# Helper: history formatting
# -----------------------------
def _format_history(history: Optional[List]) -> str:
    """
    Convert the last few chat turns into a Lao conversation snippet
    to give the model context for follow-up questions.

    Gradio history format: [[user_msg, bot_msg], [user_msg, bot_msg], ...]
    """
    if not history:
        return ""

    # Keep only the last 3 turns to avoid very long prompts
    recent = history[-3:]
    lines: List[str] = []
    for turn in recent:
        if not isinstance(turn, (list, tuple)) or len(turn) != 2:
            continue
        user_msg, bot_msg = turn
        lines.append(f"ນັກຮຽນ: {user_msg}")
        lines.append(f"ອາຈານ AI: {bot_msg}")
    if not lines:
        return ""
    joined = "\n".join(lines) + "\n\n"
    return joined
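# Example (hypothetical turn) of what _format_history produces from Gradio history:
#   _format_history([["ວິທະຍາສາດແມ່ນຫຍັງ?", "ວິທະຍາສາດແມ່ນ..."]])
#   -> "ນັກຮຽນ: ວິທະຍາສາດແມ່ນຫຍັງ?\nອາຈານ AI: ວິທະຍາສາດແມ່ນ...\n\n"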
# -----------------------------
# RAG: retrieve textbook context
# -----------------------------
def retrieve_context(question: str, max_entries: int = MAX_CONTEXT_ENTRIES) -> str:
    """
    Embedding-based retrieval over textbook entries.
    Falls back to concatenated raw knowledge if embeddings are missing.
    """
    if not getattr(qa_store, "ENTRIES", None):
        # Fallback: raw knowledge (if available) or empty string
        return getattr(qa_store, "RAW_KNOWLEDGE", "")

    if qa_store.TEXT_EMBEDDINGS is None:
        top_entries = qa_store.ENTRIES[:max_entries]
    else:
        # 1) Encode the question
        q_vec = embed_model.encode(
            question,
            convert_to_tensor=True,
            show_progress_bar=False,
        )
        # 2) Cosine similarity with all entry embeddings
        sims = cos_sim(q_vec, qa_store.TEXT_EMBEDDINGS)[0]  # shape [N]
        # 3) Take top-k
        top_indices = torch.topk(sims, k=min(max_entries, sims.shape[0])).indices
        top_entries = [qa_store.ENTRIES[i] for i in top_indices.tolist()]

    # Build context string for the prompt
    context_blocks: List[str] = []
    for e in top_entries:
        header = (
            f"[ຊັ້ນ {e.get('grade','')}, "
            f"ໜ່ວຍ {e.get('unit','')}, "
            f"ບົດ {e.get('chapter_title','')}, "
            f"ຫົວຂໍ້ {e.get('section_title','')}]"
        )
        context_blocks.append(f"{header}\n{e.get('text','')}")
    return "\n\n".join(context_blocks)
# -----------------------------
# Glossary-based answering
# -----------------------------
def answer_from_glossary(message: str) -> Optional[str]:
    """
    Try to answer using the glossary index.
    Priority 1: Exact string match of the term inside the user's message.
    Priority 2: Vector embedding match (if confidence is high).
    """
    if not getattr(qa_store, "GLOSSARY", None):
        return None

    # Exact term match first. This avoids cases where e.g. "What is Science"
    # matches "Pollution" just because the "Pollution" definition contains
    # the word "Science".
    normalized_msg = message.lower().strip()
    for item in qa_store.GLOSSARY:
        term = item.get("term", "").lower().strip()
        # The specific term appears in the message (e.g. "Science" in "What is Science?")
        if term and term in normalized_msg:
            # Only accept short messages, so long sentences don't trigger this accidentally
            if len(normalized_msg) < len(term) + 20:
                definition = item.get("definition", "").strip()
                example = item.get("example", "").strip()
                if example:
                    return f"{definition} ຕົວຢ່າງ: {example}"
                return definition

    # No exact text match: fall back to vector similarity
    if qa_store.GLOSSARY_EMBEDDINGS is None:
        return None
    q_emb = embed_model.encode(
        [message],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )[0]
    sims = np.dot(qa_store.GLOSSARY_EMBEDDINGS, q_emb)
    best_idx = int(np.argmax(sims))
    best_sim = float(sims[best_idx])
    # Require a fairly strong match (cosine >= 0.65) to reject weak hits
    # such as "Science" landing on the "Pollution" entry.
    if best_sim < 0.65:
        return None
    item = qa_store.GLOSSARY[best_idx]
    definition = item.get("definition", "").strip()
    example = item.get("example", "").strip()
    if example:
        return f"{definition} ຕົວຢ່າງ: {example}"
    return definition
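# Behaviour sketch for the exact-match guard (hypothetical inputs):
#   "What is Science?" (len 16) with term "science" (len 7): 16 < 7 + 20, exact match used.
#   A long sentence that merely mentions "science": the length guard rejects it and
#   the embedding fallback (cosine threshold 0.65) decides instead.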
# -----------------------------
# Prompt + LLM generation
# -----------------------------
def build_prompt(question: str, history: Optional[List] = None) -> str:
    context = retrieve_context(question, max_entries=MAX_CONTEXT_ENTRIES)
    history_block = _format_history(history)
    return f"""{SYSTEM_PROMPT}
{history_block}ຂໍ້ມູນອ້າງອີງ:
{context}
ຄຳຖາມ: {question}
ຄຳຕອບດ້ວຍພາສາລາວ:"""
def generate_answer(question: str, history: Optional[List] = None) -> str:
    prompt = build_prompt(question, history)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=160,
            do_sample=False,
        )
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Enforce 2–3 sentence answers for students
    sentences = re.split(r"(?<=[\.?!…])\s+", answer)
    short_answer = " ".join(sentences[:3]).strip()
    return short_answer if short_answer else answer
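# Example of the sentence trim (hypothetical model output):
#   re.split(r"(?<=[\.?!…])\s+", "A. B! C? D.")  -> ["A.", "B!", "C?", "D."]
#   " ".join(sentences[:3])                      -> "A. B! C?"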
# -----------------------------
# QA lookup (exact + fuzzy)
# -----------------------------
def answer_from_qa(question: str) -> Optional[str]:
    """
    1) Exact match in QA_INDEX
    2) Fuzzy match via word overlap with ALL_QA_KNOWLEDGE
    """
    norm_q = qa_store.normalize_question(question)
    if not norm_q:
        return None

    # Exact match
    if norm_q in qa_store.QA_INDEX:
        return qa_store.QA_INDEX[norm_q]

    # Fuzzy match
    q_terms = [t for t in norm_q.split(" ") if len(t) > 1]
    if not q_terms:
        return None
    best_score = 0
    best_answer: Optional[str] = None
    for item in qa_store.ALL_QA_KNOWLEDGE:
        stored_terms = [t for t in item["norm_q"].split(" ") if len(t) > 1]
        overlap = sum(1 for t in q_terms if t in stored_terms)
        if overlap > best_score:
            best_score = overlap
            best_answer = item["a"]

    # Require at least 2 overlapping words to accept a fuzzy match
    if best_score >= 2 and best_answer is not None:
        # Optional: log when a fuzzy match is used
        print(f"[FUZZY MATCH] score={best_score} -> {best_answer[:50]!r}")
        return best_answer
    return None
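# Worked example of the fuzzy overlap score (hypothetical normalized questions):
#   query terms:  ["photosynthesis", "plants", "energy"]
#   stored terms: ["photosynthesis", "plants", "light"]
#   overlap = 2  -> best_score >= 2, so that stored answer is returned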
# -----------------------------
# Main chatbot entry
# -----------------------------
def laos_science_bot(message: str, history: List) -> str:
    """
    Main chatbot function for the Student tab (Gradio ChatInterface).
    """
    if not message.strip():
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."

    # 0) Try the glossary first for key concepts
    gloss = answer_from_glossary(message)
    if gloss:
        return gloss

    # 1) Then try exact / fuzzy Q&A
    direct = answer_from_qa(message)
    if direct:
        return direct

    # 2) Fall back to the LLM + retrieved context
    try:
        answer = generate_answer(message, history)
    except Exception as e:  # noqa: BLE001
        return f"ລະບົບມີບັນຫາ: {e}"
    return answer
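# Minimal local smoke test (illustrative; assumes the data files loaded above exist
# and the model weights have been downloaded). Runs the full answer pipeline once.
if __name__ == "__main__":
    demo_question = "ວິທະຍາສາດແມ່ນຫຍັງ?"  # hypothetical sample question ("What is science?")
    print(laos_science_bot(demo_question, history=[]))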