samyakshrestha committed on
Commit 0f7b282 · 1 Parent(s): bf951a0

First commit

Files changed (4)
  1. .DS_Store +0 -0
  2. Dockerfile +16 -0
  3. app.py +181 -0
  4. requirements.txt +12 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ FROM python:3.10-slim
+
+ # System deps for git-lfs (model pulls) and faster tokenization wheels
+ RUN apt-get update && apt-get install -y git-lfs && git lfs install
+
+ WORKDIR /app
+
+ # Install Python deps first for cache efficiency
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ EXPOSE 7860
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,181 @@
+ # ------------------------------------------------------------------
+ # app.py
+ # FastAPI + Gradio hybrid RAG service
+ # (c) Samyak Shrestha — 2025
+ # ------------------------------------------------------------------
+
+ import os, json, time
+ from pathlib import Path
+ from typing import List
+
+ import torch
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ import gradio as gr
+
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     BitsAndBytesConfig,
+ )
+
+ from huggingface_hub import hf_hub_download
+ import faiss
+ from sentence_transformers import SentenceTransformer
+ import numpy as np
+
+ # ------------------------------------------------------------------
+ # Configuration
+ # ------------------------------------------------------------------
+ HF_MODEL_ID = "samyakshrestha/merged-finetuned-mistral"  # weights + FAISS live here
+ EMBED_MODEL = "BAAI/bge-base-en-v1.5"
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ TOP_K = 5
+ CTX_TOKEN_LIMIT = 2048
+ MAX_NEW_TOKENS = 256
+
+ DATA_DIR = Path("data")  # cached at runtime
+ DATA_DIR.mkdir(exist_ok=True)
+
+ FAISS_BIN_NAME = "data/faiss_index/faiss_index.bin"
+ META_JSON_NAME = "data/faiss_index/chunk_metadata.json"
+ INDEX_PATH = DATA_DIR / FAISS_BIN_NAME  # hf_hub_download(local_dir=...) keeps the repo-relative path
+ META_PATH = DATA_DIR / META_JSON_NAME
+
+ # ------------------------------------------------------------------
+ # 1) Embedding model
+ # ------------------------------------------------------------------
+ print("Loading embedding model …")
+ embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
+ embed_dim = embedder.get_sentence_embedding_dimension()
+ print(f"{EMBED_MODEL} ({embed_dim}-d vectors)")
+
+ # ------------------------------------------------------------------
+ # 2) Download / load FAISS index + metadata
+ # ------------------------------------------------------------------
+ def download_assets():
+     if not INDEX_PATH.exists():
+         print("Downloading FAISS index from Hub …")
+         hf_hub_download(
+             repo_id=HF_MODEL_ID,
+             filename=FAISS_BIN_NAME,
+             local_dir=DATA_DIR,
+             local_dir_use_symlinks=False,
+         )
+     if not META_PATH.exists():
+         print("Downloading metadata …")
+         hf_hub_download(
+             repo_id=HF_MODEL_ID,
+             filename=META_JSON_NAME,
+             local_dir=DATA_DIR,
+             local_dir_use_symlinks=False,
+         )
+
+ download_assets()
+
+ print("Loading FAISS index …")
+ index = faiss.read_index(str(INDEX_PATH))
+ with open(META_PATH) as f:
+     chunk_metadata = json.load(f)
+ assert index.ntotal == len(chunk_metadata), "Index / metadata size mismatch"
+ print(f"vectors = {index.ntotal}")
+
+ # ------------------------------------------------------------------
+ # 3) Load language model (4-bit if bitsandbytes is available)
+ # ------------------------------------------------------------------
+ print("Loading LoRA-fine-tuned Mistral …")
+
+ bnb_cfg = None
+ try:
+     import bitsandbytes  # noqa: F401
+     bnb_cfg = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_compute_dtype=torch.float16,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+     )
+     print("bitsandbytes detected → 4-bit quant enabled")
+ except ImportError:
+     print("bitsandbytes not found → loading in fp16 / fp32")
+
+ tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, use_fast=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     HF_MODEL_ID,
+     device_map="auto" if DEVICE == "cuda" else None,
+     torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+     quantization_config=bnb_cfg,
+ )
+ model.eval()
+ print("model ready")
+
+ # ------------------------------------------------------------------
+ # 4) Retrieval & Generation helpers
+ # ------------------------------------------------------------------
+ def retrieve_chunks(query: str, k: int = TOP_K) -> List[dict]:
+     emb = embedder.encode([query], normalize_embeddings=True)
+     _, idxs = index.search(emb, k)
+     return [chunk_metadata[int(i)] for i in idxs[0]]
+
+ def build_prompt(query: str, chunks: List[dict]) -> str:
+     ctx_blocks, total_tokens = [], 0
+     for ch in chunks:
+         block = f"[{ch['title']}]\n{ch['text']}\n"
+         toks = len(tokenizer.tokenize(block))
+         if total_tokens + toks <= CTX_TOKEN_LIMIT:
+             ctx_blocks.append(block)
+             total_tokens += toks
+     context = "\n\n".join(ctx_blocks)
+     return (
+         "You are an expert scientific assistant. "
+         "Use the excerpts to answer.\n\n"
+         f"Excerpts:\n{context}\n\n"
+         f"Question: {query}\nAnswer:"
+     )
+
+ @torch.inference_mode()
+ def generate_answer(query: str) -> str:
+     prompt = build_prompt(query, retrieve_chunks(query))
+     inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+     output = model.generate(
+         **inputs,
+         max_new_tokens=MAX_NEW_TOKENS,
+         do_sample=False,
+         top_p=1.0,
+     )
+     return (
+         tokenizer.decode(output[0], skip_special_tokens=True)
+         .split("Answer:")[-1]
+         .strip()
+     )
+
+ # ------------------------------------------------------------------
+ # 5) FastAPI backend
+ # ------------------------------------------------------------------
+ api = FastAPI(title="Finetuned Mistral RAG API")
+
+ class Question(BaseModel):
+     question: str
+
+ class Answer(BaseModel):
+     answer: str
+
+ @api.post("/rag", response_model=Answer)
+ def rag_endpoint(item: Question):
+     return Answer(answer=generate_answer(item.question))
+
+ # ------------------------------------------------------------------
+ # 6) Gradio chat UI
+ # ------------------------------------------------------------------
+ demo = gr.Interface(
+     fn=generate_answer,
+     inputs=gr.Textbox(label="Ask a question about LLM fine-tuning"),
+     outputs=gr.Textbox(label="Answer"),
+     title="Finetuned Mistral-7B — Retrieval-Augmented QA",
+ )
+
+ # ------------------------------------------------------------------
+ # 7) Launch (Spaces exposes port 7860)
+ # ------------------------------------------------------------------
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
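
Note that running `python app.py` starts only the Gradio server, so the FastAPI `/rag` route defined in section 5 is never bound to a port. A minimal sketch of how the two could share port 7860, assuming Gradio 4.x's `gr.mount_gradio_app` and the `uvicorn` pin already in requirements.txt; `api` and `demo` refer to the objects defined in app.py above:

# Sketch only: an alternative section 7 so FastAPI and the Gradio UI share one server.
# Assumes the `api` (FastAPI) and `demo` (gr.Interface) objects defined earlier in app.py.
import uvicorn
import gradio as gr

# Mount the Gradio UI onto the FastAPI app; "/rag" stays reachable alongside the UI.
app = gr.mount_gradio_app(api, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)  # Spaces routes traffic to port 7860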
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ fastapi==0.110.1
+ uvicorn[standard]==0.29.0
+ transformers==4.40.1
+ huggingface_hub==0.23.0
+ sentence-transformers==2.7.0
+ faiss-cpu==1.7.4
+ torch==2.2.2
+ gradio==4.24.0
+ pydantic>=2.6
+ numpy
+ bitsandbytes ; sys_platform == 'linux'  # only installs on Linux/GPU
+ accelerate  # optional, speeds HF model I/O
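
For reference, a hedged client-side example of calling the `/rag` endpoint once the FastAPI app is actually being served (for instance via the uvicorn sketch above). The host, port, example question, and the `requests` dependency are assumptions, not part of this commit:

# Hypothetical client call; assumes the API is reachable at localhost:7860
# and that `requests` is installed in the client environment.
import requests

resp = requests.post(
    "http://localhost:7860/rag",
    json={"question": "What does LoRA fine-tuning change inside a transformer layer?"},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["answer"])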