sappho192 commited on
Commit
137b040
1 Parent(s): 6b1a623

Set max_length=500 when calling generate()

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -12,7 +12,7 @@ def translate(text_src):
12
  embeddings = src_tokenizer(text_src, return_attention_mask=False, return_token_type_ids=False, return_tensors='pt')
13
  embeddings = {k: v for k, v in embeddings.items()}
14
  # using default generation method: GreedySearch, No LogitsProcessor
15
- output = model.generate(**embeddings)[0, 1:-1]
16
  text_trg = trg_tokenizer.decode(output.cpu())
17
  return text_trg
18
 
 
12
  embeddings = src_tokenizer(text_src, return_attention_mask=False, return_token_type_ids=False, return_tensors='pt')
13
  embeddings = {k: v for k, v in embeddings.items()}
14
  # using default generation method: GreedySearch, No LogitsProcessor
15
+ output = model.generate(**embeddings, max_length=500)[0, 1:-1]
16
  text_trg = trg_tokenizer.decode(output.cpu())
17
  return text_trg
18