svjack committed on
Commit
ec67560
1 Parent(s): b67a9e4

Upload predict.py

Files changed (1)
  1. predict.py +60 -0
predict.py ADDED
@@ -0,0 +1,60 @@
+import re
+
+
+def batch_as_list(a, batch_size=100000):
+    # Split an iterable into consecutive chunks of at most batch_size elements.
+    req = []
+    for ele in a:
+        if not req:
+            req.append([])
+        if len(req[-1]) < batch_size:
+            req[-1].append(ele)
+        else:
+            req.append([])
+            req[-1].append(ele)
+    return req
+
+
+class Obj:
+    # Thin wrapper around a seq2seq model/tokenizer pair.
+    def __init__(self, model, tokenizer, device="cpu"):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.device = device
+        self.model = self.model.to(self.device)
+
+    def predict(
+        self,
+        source_text: str,
+        max_length: int = 512,
+        num_return_sequences: int = 1,
+        num_beams: int = 2,
+        top_k: int = 50,
+        top_p: float = 0.95,
+        do_sample: bool = True,
+        repetition_penalty: float = 2.5,
+        length_penalty: float = 1.0,
+        early_stopping: bool = True,
+        skip_special_tokens: bool = True,
+        clean_up_tokenization_spaces: bool = True,
+    ):
+        # Tokenize the source text and move the ids to the model's device.
+        input_ids = self.tokenizer.encode(
+            source_text, return_tensors="pt", add_special_tokens=True
+        )
+        input_ids = input_ids.to(self.device)
+        generated_ids = self.model.generate(
+            input_ids=input_ids,
+            num_beams=num_beams,
+            max_length=max_length,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            early_stopping=early_stopping,
+            do_sample=do_sample,  # without this flag, top_k / top_p have no effect
+            top_p=top_p,
+            top_k=top_k,
+            num_return_sequences=num_return_sequences,
+        )
+        # Decode every returned sequence into plain text.
+        preds = [
+            self.tokenizer.decode(
+                g,
+                skip_special_tokens=skip_special_tokens,
+                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            )
+            for g in generated_ids
+        ]
+        return preds
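
For context, a minimal usage sketch of the uploaded module (not part of the commit). It assumes the model and tokenizer come from a Hugging Face transformers seq2seq checkpoint; the t5-small name and the example prompt are placeholders, not necessarily the checkpoint this repo targets.

# Usage sketch: load a seq2seq checkpoint and run predict.py's Obj on a small batch.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from predict import Obj, batch_as_list

model_name = "t5-small"  # placeholder checkpoint for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

obj = Obj(model, tokenizer, device="cpu")

texts = ["translate English to German: Hello, how are you?"]
for batch in batch_as_list(texts, batch_size=8):
    for text in batch:
        print(obj.predict(text, num_return_sequences=1))

batch_as_list is only needed when a long list of inputs should be processed in fixed-size chunks; for a single string, obj.predict can be called directly.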