fsojni committed (verified)
Commit fef5c81 · Parent(s): 4872cd0

REEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE

Files changed (1): app.py (+39 −10)
app.py CHANGED
@@ -17,6 +17,7 @@ from collections import defaultdict
 HF_TOKEN = os.getenv("HF_token")
 CHAT_MODEL_ID = "QWen/Qwen1.5-7B-Chat"
 EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"
+MAX_PROMPT_TOKENS = 8192
 
 # --- lazy loaders (unchanged) -------------------------------------------------
 tokenizer, chat_model = None, None
@@ -42,11 +43,11 @@ def load_embedder():
 
 @torch.no_grad()
 def embed(text:str)->torch.Tensor:
-    """Return L2-normalised embedding vector."""
     load_embedder()
-    inputs = emb_tokenizer(text, return_tensors="pt", truncation=True).to(emb_model.device)
-    vec = emb_model(**inputs).last_hidden_state[:, 0]  # CLS pooling
-    return F.normalize(vec, dim=-1).squeeze(0)
+    with torch.no_grad():
+        inputs = emb_tokenizer(text, return_tensors="pt", truncation=True).to(emb_model.device)
+        vec = emb_model(**inputs).last_hidden_state[:, 0]
+        return F.normalize(vec, dim=-1).cpu()
 
 # ---------- 2. tiny in-memory KB shared by Gradio & API ----------------------
 # ---------- 2. Tiny in-memory knowledge-base -------------------------------
@@ -67,7 +68,7 @@ def add_docs(user_id: str, docs: list[str]) -> int:
         return 0
 
     load_embedder()  # lazy-load once
-    new_vecs = torch.stack([embed(t) for t in docs])
+    new_vecs = torch.stack([embed(t) for t in docs]).cpu()
     store = kb[user_id]  # auto-creates via defaultdict
     store["texts"].extend(docs)
     store["vecs"] = (
@@ -119,7 +120,7 @@ def answer(system: str, context: str, question: str, user_id="demo", history="No
     context_list = [context]
     # 1. Retrieve top-k similar passages
     if history == "Some":
-        q_vec = embed(question)
+        q_vec = embed(question).cpu()
         store = kb[user_id]
         sims = torch.matmul(store["vecs"], q_vec)  # [N]
         k = min(4, sims.numel())
@@ -134,8 +135,22 @@ def answer(system: str, context: str, question: str, user_id="demo", history="No
 
     # 3. Generate and strip everything before the assistant tag
     load_chat()
-    inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
-    output = chat_model.generate(**inputs, max_new_tokens=512)
+    tokens = tokenizer(
+        prompt,
+        return_tensors="pt",
+        add_special_tokens=False,  # important – we already built chat template
+    )
+    if tokens.input_ids.size(1) > MAX_PROMPT_TOKENS:
+        tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()}
+
+    tokens = {k: v.to(chat_model.device) for k, v in tokens.items()}
+
+    # --- generate ------------------------------------------------------
+    output = chat_model.generate(
+        **tokens,
+        max_new_tokens=512,
+        max_length=MAX_PROMPT_TOKENS + 512,
+    )
     full = tokenizer.decode(output[0], skip_special_tokens=True)
     reply = full.split("<|im_start|>assistant")[-1].strip()
     return reply
@@ -220,8 +235,22 @@ def rag(req:QueryReq):
     prompt = build_qwen_prompt(context, req.question)
 
     load_chat()
-    inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
-    out = chat_model.generate(**inputs, max_new_tokens=512)
+    tokens = tokenizer(
+        prompt,
+        return_tensors="pt",
+        add_special_tokens=False,
+    )
+    if tokens.input_ids.size(1) > MAX_PROMPT_TOKENS:
+        tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()}
+
+    tokens = {k: v.to(chat_model.device) for k, v in tokens.items()}
+
+    out = chat_model.generate(
+        **tokens,
+        max_new_tokens=512,
+        max_length=MAX_PROMPT_TOKENS + 512,
+    )
+
     full = tokenizer.decode(out[0], skip_special_tokens=True)
     ans = full.split("<|im_start|>assistant")[-1].strip()
     return {"answer": ans}
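On the embedding side, the change keeps every vector on the CPU, so the stored knowledge-base matrix and the query vector sit on the same device when they are multiplied in answer(). A minimal standalone sketch of that pattern, assuming the same mixedbread-ai/mxbai-embed-large-v1 model loaded through AutoTokenizer/AutoModel; the helper names embed_cpu and top_k are illustrative, not functions from app.py, and the sketch squeezes to a 1-D vector for clarity:

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"
emb_tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_ID)
emb_model = AutoModel.from_pretrained(EMB_MODEL_ID)

@torch.no_grad()
def embed_cpu(text: str) -> torch.Tensor:
    """CLS-pool the last hidden state, L2-normalise, and return a CPU vector."""
    inputs = emb_tokenizer(text, return_tensors="pt", truncation=True).to(emb_model.device)
    vec = emb_model(**inputs).last_hidden_state[:, 0]   # [1, D] CLS token
    return F.normalize(vec, dim=-1).squeeze(0).cpu()    # [D], unit length

def top_k(question: str, doc_vecs: torch.Tensor, texts: list[str], k: int = 4) -> list[str]:
    """doc_vecs is [N, D] of unit vectors, e.g. torch.stack([embed_cpu(t) for t in texts])."""
    q_vec = embed_cpu(question)                          # [D], same device as doc_vecs
    sims = doc_vecs @ q_vec                              # [N] cosine similarities
    idx = sims.topk(min(k, sims.numel())).indices
    return [texts[i] for i in idx.tolist()]

Because both sides are unit-normalised, the matmul is exactly cosine similarity, which is what the top-k retrieval step relies on.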
 
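Both answer() and the /rag endpoint now apply the same guard before generation: tokenize the already-templated prompt with add_special_tokens=False, keep only the last MAX_PROMPT_TOKENS tokens if the prompt is longer, move the tensors to the model's device, and then call generate. A minimal sketch of that guard in isolation, assuming an already-loaded causal chat model and tokenizer; generate_reply is an illustrative name, not a function from app.py:

import torch

MAX_PROMPT_TOKENS = 8192

@torch.no_grad()
def generate_reply(prompt: str, tokenizer, chat_model, max_new_tokens: int = 512) -> str:
    # The prompt already contains the chat template, so skip extra special tokens.
    tokens = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)

    # Left-truncate: keep the most recent MAX_PROMPT_TOKENS tokens of the prompt.
    if tokens.input_ids.size(1) > MAX_PROMPT_TOKENS:
        tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()}

    # Move every tensor (input_ids, attention_mask, ...) to the model's device.
    tokens = {k: v.to(chat_model.device) for k, v in tokens.items()}

    output = chat_model.generate(**tokens, max_new_tokens=max_new_tokens)

    # Strip everything before the assistant tag, as in the app.
    full = tokenizer.decode(output[0], skip_special_tokens=True)
    return full.split("<|im_start|>assistant")[-1].strip()

Truncating from the left keeps the end of the prompt, that is the most recent context, the question, and the assistant tag, at the cost of dropping the oldest tokens (such as the system message) when the prompt overflows.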