svjack commited on
Commit
9ea124a
1 Parent(s): 7162e60

Upload 3 files

Browse files
Files changed (3) hide show
  1. predict.py +59 -0
  2. reconstructor.py +39 -0
  3. requirements.txt +3 -0
predict.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ def batch_as_list(a, batch_size = int(100000)):
4
+ req = []
5
+ for ele in a:
6
+ if not req:
7
+ req.append([])
8
+ if len(req[-1]) < batch_size:
9
+ req[-1].append(ele)
10
+ else:
11
+ req.append([])
12
+ req[-1].append(ele)
13
+ return req
14
+
15
+ class Obj:
16
+ def __init__(self, model, tokenizer, device = "cpu"):
17
+ self.model = model
18
+ self.tokenizer = tokenizer
19
+ self.device = "cpu"
20
+
21
+ def predict(
22
+ self,
23
+ source_text: str,
24
+ max_length: int = 512,
25
+ num_return_sequences: int = 1,
26
+ num_beams: int = 2,
27
+ top_k: int = 50,
28
+ top_p: float = 0.95,
29
+ do_sample: bool = True,
30
+ repetition_penalty: float = 2.5,
31
+ length_penalty: float = 1.0,
32
+ early_stopping: bool = True,
33
+ skip_special_tokens: bool = True,
34
+ clean_up_tokenization_spaces: bool = True,
35
+ ):
36
+ input_ids = self.tokenizer.encode(
37
+ source_text, return_tensors="pt", add_special_tokens=True
38
+ )
39
+ input_ids = input_ids.to(self.device)
40
+ generated_ids = self.model.generate(
41
+ input_ids=input_ids,
42
+ num_beams=num_beams,
43
+ max_length=max_length,
44
+ repetition_penalty=repetition_penalty,
45
+ length_penalty=length_penalty,
46
+ early_stopping=early_stopping,
47
+ top_p=top_p,
48
+ top_k=top_k,
49
+ num_return_sequences=num_return_sequences,
50
+ )
51
+ preds = [
52
+ self.tokenizer.decode(
53
+ g,
54
+ skip_special_tokens=skip_special_tokens,
55
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
56
+ )
57
+ for g in generated_ids
58
+ ]
59
+ return preds
reconstructor.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from predict import *
2
+ from transformers import (
3
+ T5ForConditionalGeneration,
4
+ T5TokenizerFast as T5Tokenizer,
5
+ )
6
+ import jieba.posseg as posseg
7
+
8
+ model_path = "svjack/T5-dialogue-collect-v5"
9
+ tokenizer = T5Tokenizer.from_pretrained(model_path)
10
+ model = T5ForConditionalGeneration.from_pretrained(model_path)
11
+
12
+ rec_obj = Obj(model, tokenizer)
13
+
14
+ def process_one_sent(input_):
15
+ assert type(input_) == type("")
16
+ input_ = " ".join(map(lambda y: y.word.strip() ,filter(lambda x: x.flag != "x" ,
17
+ posseg.lcut(input_))))
18
+ return input_
19
+
20
+ def predict_split(sp_list, cut_tokens = True):
21
+ assert type(sp_list) == type([])
22
+ if cut_tokens:
23
+ src_text = '''
24
+ 根据下面的上下文进行分段:
25
+ 上下文:{}
26
+ 答案:
27
+ '''.format(" ".join(
28
+ map(process_one_sent ,sp_list)
29
+ ))
30
+ else:
31
+ src_text = '''
32
+ 根据下面的上下文进行分段:
33
+ 上下文:{}
34
+ 答案:
35
+ '''.format("".join(sp_list))
36
+ print(src_text)
37
+ pred = rec_obj.predict(src_text)[0]
38
+ pred = list(filter(lambda y: y ,map(lambda x: x.strip() ,pred.split("分段:"))))
39
+ return pred
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers==4.20.1
2
+ jieba
3
+ gradio