minjibi committed on
Commit
52d9b6e
1 Parent(s): 6471fde

Update app.py

Files changed (1)
  1. app.py +21 -24
app.py CHANGED
@@ -16,32 +16,29 @@ tokenizer = MT5TokenizerFast.from_pretrained(
 )
 
 def predict(text):
-    with torch.no_grad():
-        input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
+    # with torch.no_grad():
+    input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
+    generated_ids = model.generate(
+        input_ids=input_ids,
+        num_beams=5,
+        max_length=1000,
+        repetition_penalty=3.0, #default = 2.5
+        length_penalty=1.0,
+        early_stopping=True,
+        top_p=50, #default 50
+        top_k=20, #default 20
+        num_return_sequences=3,
+    )
 
-        input_ids = input_ids.cuda()
-
-        generated_ids = model.generate(
-            input_ids=input_ids,
-            num_beams=5,
-            max_length=1000,
-            repetition_penalty=3.0, #default = 2.5
-            length_penalty=1.0,
-            early_stopping=True,
-            top_p=50, #default 50
-            top_k=20, #default 20
-            num_return_sequences=3,
+    preds = [
+        tokenizer.decode(
+            g,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=True,
         )
-
-        preds = [
-            tokenizer.decode(
-                g,
-                skip_special_tokens=True,
-                clean_up_tokenization_spaces=True,
-            )
-            for g in generated_ids
-        ]
-        return preds
+        for g in generated_ids
+    ]
+    return preds
 
 # text_to_predict = predict(text)
 # predicted = ['Q: ' + text for text in predict(text_to_predict)]
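For readers following along, here is a minimal, self-contained sketch of how the updated predict() could be exercised outside this Space. The checkpoint name "google/mt5-small" and the surrounding setup are assumptions for illustration (the app loads its own model and tokenizer earlier in app.py), the torch.no_grad() context that the commit comments out is restored here for plain CPU inference, and top_p/top_k are left out because they only take effect when sampling is enabled:

import torch
from transformers import MT5ForConditionalGeneration, MT5TokenizerFast

# Assumption: a generic MT5 checkpoint stands in for the app's own weights.
model_name = "google/mt5-small"
tokenizer = MT5TokenizerFast.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name)
model.eval()

def predict(text):
    # Tokenize the input; without the old .cuda() call the tensors stay on
    # the same device as the model (CPU by default).
    input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
    # Inference only, so gradients are disabled (the commit comments this out,
    # but it is harmless and saves memory).
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_ids,
            num_beams=5,              # beam search over 5 candidates
            max_length=1000,
            repetition_penalty=3.0,
            length_penalty=1.0,
            early_stopping=True,
            num_return_sequences=3,   # return the top 3 beams
        )
    # Decode each returned sequence into plain text.
    return [
        tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for g in generated_ids
    ]

print(predict("Hello, world."))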