tmskss committed on
Commit 22b1013
1 Parent(s): 33e5cf5

Update output format, batch inference

Files changed (1):
app.py +123 -49
app.py CHANGED
@@ -1,3 +1,4 @@
+import re
 import torch
 import time
 import pinecone
@@ -10,36 +11,47 @@ from transformers.generation.stopping_criteria import StoppingCriteria, Stopping
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from torch import nn
 from sentence_transformers.cross_encoder import CrossEncoder
-from peft import PeftModel
 from sentence_transformers import SentenceTransformer
+from peft import PeftModel
 from bs4 import BeautifulSoup
 import requests

 headers = {
-    "User-Agent":"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_5) AppleWebKit 537.36 (KHTML, like Gecko) Chrome",
-    "Accept":"text/html,application/xhtml+xml,application/xml; q=0.9,image/webp,*/*;q=0.8",
-    'Cookie':'CONSENT=YES+cb.20210418-17-p0.it+FX+917; '
+    "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_5) AppleWebKit 537.36 (KHTML, like Gecko) Chrome",
+    "Accept": "text/html,application/xhtml+xml,application/xml; q=0.9,image/webp,*/*;q=0.8",
+    "Cookie": "CONSENT=YES+cb.20210418-17-p0.it+FX+917; ",
 }

+
 def google_search(text):
     print(f"Google search on: {text}")
     try:
-        site = requests.get(f'https://www.google.com/search?hl=en&q={text}', headers=headers)
-        main = BeautifulSoup(site.text, features="html.parser").select_one('#main').select('.VwiC3b.lyLwlc.yDYNvb.W8l4ac')
-        res = '\n\n'.join([m.get_text() for m in main])
+        site = requests.get(f"https://www.google.com/search?hl=en&q={text}", headers=headers)
+        main = (
+            BeautifulSoup(site.text, features="html.parser").select_one("#main").select(".VwiC3b.lyLwlc.yDYNvb.W8l4ac")
+        )
+        res = []
+        for m in main:
+            t = m.get_text()
+            if "—" in t:
+                t = t[len("—") + t.index("—") :].strip()
+
+            res.append(t)
+
+        ans = "\n".join(res)
     except Exception as ex:
         print(f"Error: {ex}")
-        res = ""
+        ans = ""

-    print(f"The result of the google search is: {res}")
+    print(f"The result of the google search is: {ans}")

-    return res
+    return ans

 PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")

-sentencetransformer_model = SentenceTransformer('sentence-transformers/multi-qa-mpnet-base-cos-v1')
 pinecone.init(api_key=PINECONE_API_KEY, environment="gcp-starter")

+sentencetransformer_model = SentenceTransformer('sentence-transformers/multi-qa-mpnet-base-cos-v1')

 CACHE_DIR = "./.cache"
 INDEX_NAME = "k8s-semantic-search"
@@ -79,6 +91,7 @@ def create_embedding(text: str):

     return embed_text.tolist()

+
 index = pinecone.Index(INDEX_NAME)


@@ -190,6 +203,23 @@ start_template = "### Answer:"
 command_template = "# Command:"
 end_template = "#End"

+def str_to_md(text):
+    def escape_hash(line):
+        i = 0
+        while i < len(line) and line[i] == ' ':
+            i+=1
+
+        if i == len(line):
+            return line
+
+        if line[i] == '#':
+            line = line[:i] + '\\' + line[i:]
+
+        return line
+
+    lines = text.split('\n')
+    lines = [escape_hash(line) for line in lines]
+    return ' \n'.join(lines)

 def text_to_text_generation(verbose, prompt):
     prompt = prompt.strip()
@@ -205,6 +235,7 @@ def text_to_text_generation(verbose, prompt):
     )

     model_input = tokenizer(is_kubectl_prompt, return_tensors="pt").to("cuda")
+
     with torch.no_grad():
         response = tokenizer.decode(
             model.generate(
@@ -223,7 +254,7 @@

     response_num = 0 if "0" in response else (1 if "1" in response else 2)

-    def generate(response_num, prompt, retriever, verbose):
+    def create_generation_prompt(response_num, prompt, retriever):
         match response_num:
             case 0:
                 prompt = f"[INST] {prompt}\n Lets think step by step. [/INST] {start_template}"
@@ -241,56 +272,104 @@
             case _:
                 prompt = f"[INST] {prompt} [/INST]"

-        print("Query:")
-        print(prompt)
+        return prompt
+
+    def generate_batch(*prompts):
+        tokenized_inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")

-        # Generate output
-        model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
         with torch.no_grad():
-            response = tokenizer.decode(
+            responses = tokenizer.batch_decode(
                 model.generate(
-                    **model_input,
+                    **tokenized_inputs,
                     max_new_tokens=256,
                     pad_token_id=tokenizer.eos_token_id,
                     repetition_penalty=1.15,
                     stopping_criteria=StoppingCriteriaList([eval_stop_criterion]),
-                )[0],
+                ),
                 skip_special_tokens=True,
             )

-        decoded_prompt = tokenizer.decode(tokenizer(prompt).input_ids, skip_special_tokens=True)
+        decoded_prompts = tokenizer.batch_decode(tokenized_inputs.input_ids, skip_special_tokens=True)

-        start = (
-            response.index(start_template) + len(start_template) if start_template in response else len(decoded_prompt)
-        )
-        start = response.index(command_template) + len(command_template) if command_template in response else start
-        end = response.index(end_template) if end_template in response else len(response)
+        return [(prompt, answer) for prompt, answer in zip(decoded_prompts, responses)]

-        return response if verbose else response[start:end].strip()
+    def cleanup(prompt, answer):
+        start = answer.index(start_template) + len(start_template) if start_template in answer else len(prompt)
+        start = answer.index(command_template) + len(command_template) if command_template in answer else start
+        end = answer.index(end_template) if end_template in answer else len(answer)

-    true_response = generate(response_num, prompt, False, verbose)
-    true_response_semantic_search = generate(response_num, prompt, "semantic_search", verbose)
-    true_response_google_search = generate(response_num, prompt, "google_search", verbose)
+        return (prompt, answer[start:end].strip())

+    modes = ["Kubectl command", "Kubernetes definition", "Normal"]
+    modes[response_num] = f"**{modes[response_num]}**"
+    modes = " / ".join(modes)

-    print("Returned: " + true_response)
-    print(f'{" QUERY END ":-^40}')
+    if response_num == 2:
+        prompt = create_generation_prompt(response_num, prompt, False)
+        original, new = generate_batch(prompt)[0]
+        prompt, response = cleanup(original, new)
+        if verbose:
+            return f"{modes}\n\n" f"# Prompt given to the model:\n" f"{str_to_md(prompt)}\n" f"# Model's answer:\n" f"{str_to_md(response)}\n"
+        else:
+            return f"{modes}\n\n" f"# Answer:\n" f"{str_to_md(response)}"
+
+    if response_num == 0:
+        prompt = create_generation_prompt(response_num, prompt, False)
+        original, new = generate_batch(prompt)[0]
+        prompt, response = cleanup(original, new)
+        model_response = new[len(original):].strip()
+        if verbose:
+            return (
+                f"{modes}\n\n"
+                f"# Prompt given to the model:\n"
+                f"{str_to_md(prompt)}\n"
+                f"# Model's answer:\n"
+                f"{str_to_md(model_response)}\n"
+                f"# Processed answer:\n"
+                f"```bash\n{str_to_md(response)}\n```\n"
+            )
+        else:
+            return f"{modes}\n\n" f"# Answer:\n" f"```bash\n{str_to_md(response)}\n```\n"

-    match response_num:
-        case 0:
-            mode = "Kubectl"
-        case 1:
-            mode = "Kubernetes"
-        case _:
-            mode = "Normal"
+    res_prompt = create_generation_prompt(response_num, prompt, False)
+    res_semantic_search_prompt = create_generation_prompt(response_num, prompt, "semantic_search")
+    res_google_search_prompt = create_generation_prompt(response_num, prompt, "google_search")

-    return (
-        f"*Mode*: {mode}",
-        f"# Answer\n\n {true_response}",
-        f"# Answer with RAG\n\n {true_response_semantic_search}",
-        f"# Answer with Google search\n\n {true_response_google_search}"
+    gen_normal, gen_semantic_search, gen_google_search = generate_batch(
+        res_prompt, res_semantic_search_prompt, res_google_search_prompt
     )

+    res_prompt, res_normal = cleanup(*gen_normal)
+    res_semantic_search_prompt, res_semantic_search = cleanup(*gen_semantic_search)
+    res_google_search_prompt, res_google_search = cleanup(*gen_google_search)
+
+    if verbose:
+        return (
+            f"{modes}\n\n"
+            f"# Answer with finetuned model\n"
+            f"## Prompt given to the model:\n"
+            f"{str_to_md(res_prompt)}\n\n"
+            f"## Model's answer:\n"
+            f"{str_to_md(res_normal)}\n\n"
+            f"# Answer with RAG\n"
+            f"## Prompt given to the model:\n"
+            f"{str_to_md(res_semantic_search_prompt)}\n\n"
+            f"## Model's answer:\n"
+            f"{str_to_md(res_semantic_search)}\n\n"
+            f"# Answer with Google search\n"
+            f"## Prompt given to the model:\n"
+            f"{str_to_md(res_google_search_prompt)}\n\n"
+            f"## Model's answer:\n"
+            f"{str_to_md(res_google_search)}\n\n"
+        )
+    else:
+        return (
+            f"{modes}\n\n"
+            f"# Answer with finetuned model\n\n {str_to_md(res_normal)}\n"
+            f"# Answer with RAG\n\n {str_to_md(res_semantic_search)}\n"
+            f"# Answer with Google search\n\n {str_to_md(res_google_search)}"
        )
+

 iface = gr.Interface(
     fn=text_to_text_generation,
@@ -298,12 +377,7 @@ iface = gr.Interface(
         gr.components.Checkbox(label="Verbose"),
         gr.components.Text(placeholder="prompt here ...", label="Prompt"),
     ],
-    outputs=[
-        gr.components.Markdown(label="Mode"),
-        gr.components.Markdown(label="Answer Without Retriever"),
-        gr.components.Markdown(label="Answer With Retriever"),
-        gr.components.Markdown(label="Answer With Google search"),
-    ],
+    outputs=gr.components.Markdown(label="Answer"),
     allow_flagging="never",
 )

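
Below the diff, a standalone sketch of the batched-inference pattern this commit introduces in `generate_batch`: tokenize several prompts as one padded batch, run a single `model.generate` call, then `tokenizer.batch_decode`. The model name `gpt2`, the left-padding setup, and the toy prompts are assumptions chosen to keep the sketch small and runnable; they are not part of the commit.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Assumptions for this sketch: any causal LM works; "gpt2" keeps the download small.
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompts = ["What is a Kubernetes pod?", "How do I list all namespaces?"]

# One padded batch -> one generate call -> one batch_decode, as in generate_batch above.
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=32,
        pad_token_id=tokenizer.eos_token_id,
    )
answers = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

# Pair each decoded prompt with its answer, mirroring the commit's (prompt, answer) tuples.
decoded_prompts = tokenizer.batch_decode(inputs.input_ids, skip_special_tokens=True)
pairs = list(zip(decoded_prompts, answers))
```

For decoder-only models the padding side matters: with right padding, generation continues after the pad tokens rather than the prompt, which is why the pad token and padding side are set explicitly in this sketch. The app's actual tokenizer configuration is not shown in this diff.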
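
The output-format half of the commit funnels all answers into a single Markdown component, so `str_to_md` escapes a leading `#` in model output to keep it from rendering as a heading. A minimal sketch of the same idea; the helper below restates the commit's logic for illustration, and the two-space join used to force Markdown hard line breaks is an assumption of the sketch (the app itself joins with ' \n').

```python
def escape_leading_hash(line: str) -> str:
    """Escape a '#' that starts a line (after indentation) so Markdown keeps it literal."""
    stripped = line.lstrip(" ")
    if stripped.startswith("#"):
        indent = len(line) - len(stripped)
        return line[:indent] + "\\" + line[indent:]
    return line


text = "# kubectl get pods -A\nLists pods in every namespace."
md = "  \n".join(escape_leading_hash(line) for line in text.split("\n"))
print(md)  # "\# kubectl get pods -A" renders as literal text, not a heading
```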