pablorocg committed
Commit
ee8e6af
1 Parent(s): bdd2583

Update app.py

Files changed (1): app.py (+26 -3)
app.py CHANGED
@@ -336,8 +336,8 @@ def answer_query(query_text, index, documents, llm_model, llm_tokenizer, embeddi
     retrieved_info = get_retrieved_info(documents, I, D)
     formatted_info = format_retrieved_info(retrieved_info)
     prompt = generate_prompt(query_text, formatted_info)
-    answer = answer_using_gemma(prompt, llm_model, llm_tokenizer)
-    return answer
+    # answer = answer_using_gemma(prompt, llm_model, llm_tokenizer)
+    return prompt



@@ -393,7 +393,30 @@ model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=HF_TOKE


 def make_inference(query, hist):
-    return answer_query(query, index, documents, model, tokenizer, CFG.embedding_model, CFG.n_samples, CFG.device)
+    prompt = answer_query(query, index, documents, model, tokenizer, CFG.embedding_model, CFG.n_samples, CFG.device)
+    # answer = answer_using_gemma(prompt, llm_model, llm_tokenizer)
+    model_inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
+    count_tokens = lambda text: len(tokenizer.tokenize(text))
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=540., skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=6000 - count_tokens(prompt),
+        top_p=0.2,
+        top_k=20,
+        temperature=0.1,
+        repetition_penalty=2.0,
+        length_penalty=-0.5,
+        num_beams=1
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()  # Starting the generation in a separate thread.
+    partial_message = ""
+    for new_token in streamer:
+        partial_message += new_token
+        yield partial_message

 demo = gr.ChatInterface(fn = make_inference,
                         examples = ["What is diabetes?", "Is ginseng good for diabetes?", "What are the symptoms of diabetes?", "What is Celiac disease?"],
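
Note on the change: make_inference switches from returning one completed answer to streaming it, using the standard transformers TextIteratorStreamer pattern, where model.generate runs in a background thread while the Gradio callback yields the partially decoded text as it arrives. Below is a minimal, self-contained sketch of that pattern. The model name is taken from this commit, but the stream_answer helper, the device handling, and max_new_tokens=256 are illustrative placeholders, not the app's exact code; app.py is assumed to import Thread and TextIteratorStreamer elsewhere.

# Sketch of the streaming pattern used in the new make_inference (illustrative, not the app's exact code).
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "google/gemma-2b-it"  # same model as the commit; auth token handling omitted here
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

def stream_answer(prompt):
    # Tokenize the prompt and move the tensors to the model's device.
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # skip_prompt=True keeps the echoed prompt out of the streamed output.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # BatchEncoding is a mapping, so dict(model_inputs, ...) merges input_ids and
    # attention_mask with the extra generation arguments.
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=256)
    # model.generate blocks until generation finishes, so run it in a background
    # thread and consume the streamer in the foreground.
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial  # each yielded string replaces the reply shown so far

Yielding the accumulated string rather than individual tokens matches what gr.ChatInterface expects from a generator callback: every yield replaces the currently displayed reply, so the answer appears to grow in place.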