Commit 4ac4212 by Rachel Bawden
Parent(s): 744a36a

added beginning of pipeline

Files changed (1):
  1. pipeline.py +197 -0

pipeline.py ADDED
@@ -0,0 +1,197 @@
#!/usr/bin/env python3
import re
import sys
import unicodedata

from transformers import Pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.tokenization_utils_base import TruncationStrategy

# Translation table mapping non-printable characters to None, used by
# make_printable() below (the original script referenced NOPRINT_TRANS_TABLE
# without defining it; this is the standard recipe for that name)
NOPRINT_TRANS_TABLE = {
    i: None for i in range(sys.maxunicode + 1) if not chr(i).isprintable()
}

######## PredTitrages pipeline #########
class PredTitragesPipeline(Pipeline):

    def __init__(self, beam_size=5, batch_size=32, **kwargs):
        # beam_size must be set before super().__init__(), which calls
        # _sanitize_parameters(); batch_size is forwarded to the base Pipeline
        self.beam_size = beam_size
        super().__init__(batch_size=batch_size, **kwargs)

    def _sanitize_parameters(self, clean_up_tokenisation_spaces=None, truncation=None, **generate_kwargs):
        preprocess_params = {}
        if truncation is not None:
            preprocess_params["truncation"] = truncation
        forward_params = generate_kwargs
        postprocess_params = {}
        if clean_up_tokenisation_spaces is not None:
            postprocess_params["clean_up_tokenisation_spaces"] = clean_up_tokenisation_spaces
        return preprocess_params, forward_params, postprocess_params

    def check_inputs(self, input_length: int, min_length: int, max_length: int):
        """
        Checks whether there might be something wrong with the given input with regard to the model.
        """
        return True

    def make_printable(self, s):
        '''Replace non-printable characters in a string.'''
        return s.translate(NOPRINT_TRANS_TABLE)

    def normalise(self, line):
        line = unicodedata.normalize('NFKC', line)
        line = self.make_printable(line)
        for before, after in [('[«»“”]', '"'),
                              ('[‘’]', "'"),
                              (' +', ' '),
                              ('"+', '"'),
                              ("'+", "'"),
                              ('^ *', ''),
                              (' *$', '')]:
            line = re.sub(before, after, line)
        return line.strip() + ' </s>'

    def _parse_and_tokenise(self, *args, truncation):
        prefix = ""
        if isinstance(args[0], list):
            if self.tokenizer.pad_token_id is None:
                raise ValueError("Please make sure that the tokeniser has a pad_token_id when using a batch input")
            # normalise each sentence of the batch individually
            inputs = [self.normalise(prefix + arg) for arg in args[0]]
            padding = True
        elif isinstance(args[0], str):
            inputs = self.normalise(prefix + args[0])
            padding = False
        else:
            raise ValueError(
                f"`args[0]`: {args[0]} has the wrong format. It should be of type `str` or type `list`"
            )
        inputs = self.tokenizer(inputs, padding=padding, truncation=truncation, return_tensors=self.framework)
        # token_type_ids is produced by some tokenisers but is an invalid generate kwarg
        if "token_type_ids" in inputs:
            del inputs["token_type_ids"]
        return inputs

    def preprocess(self, inputs, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs):
        inputs = self._parse_and_tokenise(inputs, truncation=truncation, **kwargs)
        return inputs

    def _forward(self, model_inputs, **generate_kwargs):
        in_b, input_length = model_inputs["input_ids"].shape

        generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
        generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
        generate_kwargs["num_beams"] = self.beam_size
        self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
        output_ids = self.model.generate(**model_inputs, **generate_kwargs)
        out_b = output_ids.shape[0]
        # group the generated sequences by input example
        output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
        return {"output_ids": output_ids}

    def postprocess(self, model_outputs, clean_up_tokenisation_spaces=False):
        records = []
        for output_ids in model_outputs["output_ids"][0]:
            record = {
                "text": self.tokenizer.decode(
                    output_ids,
                    skip_special_tokens=True,
                    # NB: the Hugging Face decode() argument uses the -ization spelling
                    clean_up_tokenization_spaces=clean_up_tokenisation_spaces,
                )
            }
            records.append(record)
        return records

    def correct_hallucinations(self, orig, output):
        # TODO (not yet implemented):
        # - align the original and output tokens
        # - check that the correspondences are legitimate and correct them if not
        # - replace <EMOJI> symbols with the original ones
        return output

    def __call__(self, *args, **kwargs):
        r"""
        Generate the output text(s) using the text(s) given as inputs.

        Args:
            args (`str` or `List[str]`):
                Input text for the encoder.
            clean_up_tokenisation_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the potential extra spaces in the text output.
            truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
                The truncation strategy for the tokenisation within the pipeline.
                `TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes
                desirable to truncate the input to fit the model's max_length instead of throwing an error
                down the line.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the
                generate method corresponding to your framework [here](./model#generative-models)).

        Return:
            A list or a list of lists of `dict`: each result comes as a dictionary with the following key:

            - **text** (`str`) -- The generated text.
        """
        result = super().__call__(*args, **kwargs)
        if (isinstance(args[0], list)
                and all(isinstance(el, str) for el in args[0])
                and all(len(res) == 1 for res in result)):
            return result
        else:
            return result[0]  # check this

def predict_titrages(list_sents, batch_size=32, beam_size=5):
    tokeniser = AutoTokenizer.from_pretrained("rbawden/CCASS-pred-titrages-base", use_auth_token=True)
    model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/CCASS-pred-titrages-base", use_auth_token=True)
    pipeline = PredTitragesPipeline(model=model,
                                    tokenizer=tokeniser,
                                    batch_size=batch_size,
                                    beam_size=beam_size)
    outputs = pipeline(list_sents)
    return outputs

def predict_from_stdin(batch_size=32, beam_size=5):
    tokeniser = AutoTokenizer.from_pretrained("rbawden/CCASS-pred-titrages-base", use_auth_token=True)
    model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/CCASS-pred-titrages-base", use_auth_token=True)
    pipeline = PredTitragesPipeline(model=model,
                                    tokenizer=tokeniser,
                                    batch_size=batch_size,
                                    beam_size=beam_size)
    list_sents = [sent.strip() for sent in sys.stdin]
    outputs = pipeline(list_sents)
    for sent in outputs:
        print(sent)
    return outputs


if __name__ == '__main__':

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-k', '--batch_size', type=int, default=32, help='Set the batch size for decoding')
    parser.add_argument('-b', '--beam_size', type=int, default=5, help='Set the beam size for decoding')
    parser.add_argument('-i', '--input_file', type=str, default=None, help='Input file. If None, read from STDIN')
    args = parser.parse_args()

    if args.input_file is None:
        predict_from_stdin(batch_size=args.batch_size, beam_size=args.beam_size)
    else:
        list_sents = []
        with open(args.input_file) as fp:
            for line in fp:
                list_sents.append(line.strip())
        # the original called an undefined predict_text(); use predict_titrages() defined above
        output_sents = predict_titrages(list_sents, batch_size=args.batch_size, beam_size=args.beam_size)
        for output_sent in output_sents:
            print(output_sent)
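
For reference, a minimal usage sketch (not part of the commit). It assumes access to the gated `rbawden/CCASS-pred-titrages-base` checkpoint (loaded with `use_auth_token=True`, so a valid Hugging Face token must be configured); the input sentence, the file names `usage_example.py` and `decisions.txt` are hypothetical.

    # usage_example.py -- hypothetical driver for pipeline.py above
    from pipeline import predict_titrages

    # each input is the raw text of a court decision to be titled
    sents = ["Texte de la décision de la Cour de cassation."]
    outputs = predict_titrages(sents, batch_size=8, beam_size=5)
    for result in outputs:
        # for a list input, each result is a one-element list of dicts,
        # with a "text" key holding the predicted titrage
        print(result[0]["text"])

From the command line, the same script can be run over a file or STDIN:

    python pipeline.py -i decisions.txt -k 32 -b 5
    cat decisions.txt | python pipeline.py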