Spaces:
Build error
Build error
modified args
Browse files
app.py
CHANGED
@@ -23,12 +23,12 @@ model.eval()
|
|
23 |
pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, eos_token_id=tokenizer.eos_token_id)
|
24 |
|
25 |
def predict(text):
|
26 |
-
stopping_cond = StoppingCriteriaList([tokenizer.encode('<|endoftext|>')])
|
27 |
with torch.no_grad():
|
28 |
tokens = tokenizer(text, return_tensors="pt").input_ids
|
|
|
29 |
gen_tokens = model.generate(
|
30 |
tokens, do_sample=True, temperature=0.8, max_new_tokens=64, top_k=50, top_p=0.8,
|
31 |
-
no_repeat_ngram_size=3, repetition_penalty=1.2
|
32 |
)
|
33 |
generated = tokenizer.batch_decode(gen_tokens)[0]
|
34 |
return generated
|
|
|
23 |
pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, eos_token_id=tokenizer.eos_token_id)
|
24 |
|
25 |
def predict(text):
|
|
|
26 |
with torch.no_grad():
|
27 |
tokens = tokenizer(text, return_tensors="pt").input_ids
|
28 |
+
# generate, stopping early once <|endoftext|> (eos_token_id) is produced
|
29 |
gen_tokens = model.generate(
|
30 |
tokens, do_sample=True, temperature=0.8, max_new_tokens=64, top_k=50, top_p=0.8,
|
31 |
+
no_repeat_ngram_size=3, repetition_penalty=1.2, bad_words_ids=[[11066]], eos_token_id=tokenizer.eos_token_id
|
32 |
)
|
33 |
generated = tokenizer.batch_decode(gen_tokens)[0]
|
34 |
return generated
|