alexkueck committed
Commit 8f16105
1 Parent(s): 6f10600

Update utils.py

Files changed (1)
  1. utils.py +10 -2
utils.py CHANGED
@@ -388,9 +388,17 @@ def query(api_llm, payload):
 def llm_chain2(prompt, context):
     full_prompt = RAG_CHAIN_PROMPT.format(context=context, question=prompt)
     inputs = tokenizer_rag(full_prompt, return_tensors="pt", max_length=1024, truncation=True)
-
+    attention_mask = (inputs['input_ids'] != tokenizer_rag.pad_token_id).long()
     # Generate the answer
-    outputs = modell_rag.generate(inputs['input_ids'], max_new_tokens=1024, num_beams=2, early_stopping=True)
+    outputs = modell_rag.generate(
+        inputs['input_ids'],
+        attention_mask=attention_mask,
+        max_new_tokens=1024,
+        do_sample=True,
+        temperature=0.9,
+        pad_token_id=tokenizer_rag.eos_token_id
+    )
+    #outputs = modell_rag.generate(inputs['input_ids'], max_new_tokens=1024, num_beams=2, early_stopping=True)
     answer = tokenizer_rag.decode(outputs[0], skip_special_tokens=True)
 
     return answer
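
For context, a minimal runnable sketch of the generation path this commit introduces. RAG_CHAIN_PROMPT, tokenizer_rag, and modell_rag are module-level names in utils.py whose definitions are not shown in this diff; the flan-t5 checkpoint and the prompt template below are illustrative assumptions, not the repo's actual values, and the sketch assumes a seq2seq model.

# Hypothetical stand-ins for names defined elsewhere in utils.py.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

RAG_CHAIN_PROMPT = (
    "Answer the question using only the context.\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "Answer:"
)  # placeholder template, not the one used in this repo

tokenizer_rag = AutoTokenizer.from_pretrained("google/flan-t5-base")  # assumed checkpoint
modell_rag = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

def llm_chain2(prompt, context):
    full_prompt = RAG_CHAIN_PROMPT.format(context=context, question=prompt)
    inputs = tokenizer_rag(full_prompt, return_tensors="pt", max_length=1024, truncation=True)
    # Mark non-padding tokens as attendable; for a single unpadded sequence
    # this matches the inputs["attention_mask"] the tokenizer already returns.
    attention_mask = (inputs["input_ids"] != tokenizer_rag.pad_token_id).long()
    # Sampling-based decoding as introduced by this commit, replacing the
    # earlier beam search (num_beams=2, early_stopping=True).
    outputs = modell_rag.generate(
        inputs["input_ids"],
        attention_mask=attention_mask,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.9,
        pad_token_id=tokenizer_rag.eos_token_id,
    )
    return tokenizer_rag.decode(outputs[0], skip_special_tokens=True)

print(llm_chain2("What changed?", "Decoding switched from beam search to sampling."))

Switching from num_beams=2 to do_sample=True with temperature=0.9 trades deterministic, higher-likelihood outputs for more varied answers; passing an explicit attention_mask and pad_token_id also avoids the warnings generate() emits when they are left unset.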