dev6696 commited on
Commit
f788fb6
Β·
verified Β·
1 Parent(s): 9b06b81

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch, faiss, pickle, numpy as np
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
+ from sentence_transformers import SentenceTransformer
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ MODEL_REPO = "dev6696/edu-llm-llama3" # πŸ‘ˆ your repo
8
+
9
+ # ── Load RAG ───────────────────────────────────────────
10
+ index_path = hf_hub_download(MODEL_REPO, "faiss_index.bin")
11
+ chunks_path = hf_hub_download(MODEL_REPO, "chunks_meta.pkl")
12
+
13
+ index = faiss.read_index(index_path)
14
+ with open(chunks_path, "rb") as f:
15
+ store = pickle.load(f)
16
+ all_chunks = store["chunks"]
17
+ metadata = store["metadata"]
18
+
19
+ embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
20
+
21
+ def retrieve(query, top_k=3):
22
+ q_emb = embedder.encode([query], normalize_embeddings=True).astype("float32")
23
+ scores, indices = index.search(q_emb, top_k)
24
+ results = []
25
+ for score, idx in zip(scores[0], indices[0]):
26
+ if idx >= 0 and score > 0.3:
27
+ results.append(f"[{metadata[idx]['source']}]\n{all_chunks[idx]}")
28
+ return "\n\n---\n\n".join(results) if results else ""
29
+
30
+ # ── Load Model ─────────────────────────────────────────
31
+ bnb_config = BitsAndBytesConfig(
32
+ load_in_4bit=True,
33
+ bnb_4bit_quant_type="nf4",
34
+ bnb_4bit_compute_dtype=torch.bfloat16,
35
+ bnb_4bit_use_double_quant=True,
36
+ )
37
+
38
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
39
+ tokenizer.pad_token = tokenizer.eos_token
40
+
41
+ model = AutoModelForCausalLM.from_pretrained(
42
+ MODEL_REPO,
43
+ quantization_config=bnb_config,
44
+ device_map="auto",
45
+ torch_dtype=torch.bfloat16,
46
+ low_cpu_mem_usage=True,
47
+ )
48
+ model.eval()
49
+
50
+ # ── Inference ──────────────────────────────────────────
51
+ def answer(query, history):
52
+ context = retrieve(query)
53
+ system_msg = "You are an expert educational assistant."
54
+ if context:
55
+ system_msg += f"\n\nContext:\n{context}"
56
+
57
+ prompt = (
58
+ f"<|begin_of_text|>"
59
+ f"<|start_header_id|>system<|end_header_id|>\n{system_msg}\n<|eot_id|>"
60
+ f"<|start_header_id|>user<|end_header_id|>\n{query}\n<|eot_id|>"
61
+ f"<|start_header_id|>assistant<|end_header_id|>\n"
62
+ )
63
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to("cuda")
64
+ with torch.no_grad():
65
+ out = model.generate(
66
+ **inputs,
67
+ max_new_tokens=512,
68
+ temperature=0.7,
69
+ top_p=0.9,
70
+ do_sample=True,
71
+ repetition_penalty=1.1,
72
+ pad_token_id=tokenizer.eos_token_id,
73
+ )
74
+ decoded = tokenizer.decode(out[0], skip_special_tokens=True)
75
+ return decoded.split("<|start_header_id|>assistant<|end_header_id|>")[-1].strip()
76
+
77
+ # ── Gradio UI ──────────────────────────────────────────
78
+ with gr.Blocks(theme=gr.themes.Soft(), title="EduLLM") as demo:
79
+ gr.Markdown("# πŸ“š EduLLM β€” AI Educational Assistant")
80
+ gr.Markdown("Powered by Llama-3.1-1B + QLoRA + RAG")
81
+ chatbot = gr.ChatInterface(
82
+ fn=answer,
83
+ examples=["Explain Newton's second law", "What is photosynthesis?"],
84
+ cache_examples=False,
85
+ )
86
+
87
+ demo.launch()