# NOTE: Hugging Face Spaces page artifact ("Spaces / Sleeping" status text)
# removed — it was web-page residue, not part of the source.
"""UTN Student Chatbot — Gradio app with CRAG pipeline."""
import logging
import re
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from prompt import REWRITE_PROMPT, build_chat_messages
from retriever import Retriever

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Finetuned Qwen3-0.6B with its LoRA adapter already merged into the weights.
MODEL_ID = "saeedbenadeeb/UTN-Qwen3-0.6B-LoRA-merged"

logger.info("Initializing retriever...")
# Dense retriever over a pre-built FAISS index + chunk metadata sidecar;
# returns the top-5 chunks per query using bge-small embeddings.
retriever = Retriever(
    faiss_index_path="faiss.index",
    chunks_meta_path="chunks_meta.jsonl",
    embedding_model="BAAI/bge-small-en-v1.5",
    top_k=5,
)

logger.info("Loading model: %s", MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Some tokenizers ship without a pad token; reuse EOS so padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Prefer GPU with bfloat16; fall back to CPU float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device)
model.eval()  # inference only — disables dropout etc.
logger.info("Model loaded.")
def _generate(
    messages: list[dict],
    max_tokens: int = 512,
    *,
    max_input_tokens: int = 2048,
    temperature: float = 0.3,
    top_p: float = 0.9,
) -> str:
    """Run one chat completion with the merged LoRA model.

    Args:
        messages: Chat messages (``role``/``content`` dicts) to render with
            the tokenizer's chat template.
        max_tokens: Maximum number of new tokens to generate.
        max_input_tokens: Truncation limit for the tokenized prompt
            (previously hard-coded to 2048).
        temperature: Sampling temperature (previously hard-coded to 0.3).
        top_p: Nucleus-sampling threshold (previously hard-coded to 0.9).

    Returns:
        The generated completion, decoded without special tokens and stripped.
    """
    # enable_thinking=False turns off Qwen3's "thinking" mode in the template.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=False,
    )
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=max_input_tokens
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
| def _grade_relevance(question: str, sources: list[dict]) -> bool: | |
| if not sources: | |
| return False | |
| top_score = sources[0].get("score", 0) | |
| q_tokens = set(re.findall(r"\w+", question.lower())) | |
| doc_tokens = set(re.findall(r"\w+", sources[0].get("text", "").lower())) | |
| stopwords = { | |
| "i", "a", "the", "is", "it", "to", "do", "if", "my", "can", "in", "of", | |
| "for", "and", "or", "at", "on", "no", "not", "what", "how", "when", "where", | |
| "who", "which", "this", "that", "be", "are", "was", "have", "has", | |
| } | |
| q_content = q_tokens - stopwords | |
| overlap = len(q_content & doc_tokens) / max(len(q_content), 1) | |
| return top_score >= 0.02 or overlap >= 0.35 | |
def crag_answer(message: str, history: list[dict]) -> str:
    """Answer a UTN question with Corrective RAG.

    Retrieves supporting chunks for the question; if the top hit fails the
    relevance grade, the question is rewritten once by the model and
    retrieval is retried before generating the final answer from the
    formatted context.
    """
    question = message.strip()
    if not question:
        return "Please ask a question about UTN."

    docs = retriever.retrieve(question)

    # Corrective step: on an off-topic retrieval, rewrite and retry once.
    if not _grade_relevance(question, docs):
        reformulated = _generate(
            [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": REWRITE_PROMPT.format(question=question)},
            ],
            max_tokens=100,
        ).split("\n")[0].strip()
        if reformulated and reformulated != question:
            docs = retriever.retrieve(reformulated)

    # Answer with the original question grounded in the retrieved context.
    chat = build_chat_messages(question, retriever.format_context(docs))
    return _generate(chat)
# Gradio chat UI; type="messages" passes history as role/content dicts,
# matching the crag_answer signature.
demo = gr.ChatInterface(
    fn=crag_answer,
    type="messages",
    title="UTN Student Chatbot",
    description="Ask questions about the University of Technology Nuremberg (UTN) — admissions, programs, courses, deadlines, and more. Powered by a finetuned Qwen3-0.6B with Corrective RAG.",
    examples=[
        "What are the admission requirements for AI & Robotics?",
        "Are there tuition fees?",
        "What courses are in the first semester?",
        "Is there a Welcome Week?",
        "What TOEFL score do I need?",
    ],
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch()