jason9693 committed
Commit 4a26f8c
1 Parent(s): c8fceef

modified args

Files changed (1): app.py (+2 -2)
app.py CHANGED
@@ -23,12 +23,12 @@ model.eval()
 pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, eos_token_id=tokenizer.eos_token_id)
 
 def predict(text):
-    stopping_cond = StoppingCriteriaList([tokenizer.encode('<|endoftext|>')])
     with torch.no_grad():
         tokens = tokenizer(text, return_tensors="pt").input_ids
+        # generate, and stop once <|endoftext|> is produced (it is not in the input text)
         gen_tokens = model.generate(
             tokens, do_sample=True, temperature=0.8, max_new_tokens=64, top_k=50, top_p=0.8,
-            no_repeat_ngram_size=3, repetition_penalty=1.2
+            no_repeat_ngram_size=3, repetition_penalty=1.2, bad_words_ids=[[11066]], eos_token_id=tokenizer.eos_token_id
         )
         generated = tokenizer.batch_decode(gen_tokens)[0]
     return generated
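For context, below is a minimal, self-contained sketch of how the modified generation call behaves after this commit. It is not the Space's actual app.py: the "gpt2" checkpoint stands in for whatever model and tokenizer app.py loads above this hunk, and token id 11066 is model-specific and only kept to mirror the committed arguments.

# Minimal sketch of the post-commit generation path (placeholder checkpoint).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")      # placeholder for the Space's checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

def predict(text):
    with torch.no_grad():
        tokens = tokenizer(text, return_tensors="pt").input_ids
        gen_tokens = model.generate(
            tokens,
            do_sample=True, temperature=0.8, max_new_tokens=64,
            top_k=50, top_p=0.8,
            no_repeat_ngram_size=3, repetition_penalty=1.2,
            bad_words_ids=[[11066]],              # block one model-specific token id
            eos_token_id=tokenizer.eos_token_id,  # stop as soon as <|endoftext|> is sampled
        )
    return tokenizer.batch_decode(gen_tokens)[0]

print(predict("The weather today is"))

Passing eos_token_id to generate() replaces the removed StoppingCriteriaList line: generation now ends as soon as the end-of-text token is sampled, instead of always running for the full max_new_tokens.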