File size: 8,893 Bytes
4ac4212 276d772 4ac4212 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
#!/usr/bin/python
from transformers import Pipeline, pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.tokenization_utils_base import TruncationStrategy
from torch import Tensor
import html.parser
import unicodedata
import sys, os
import re
from tqdm.auto import tqdm
import operator
######## PredTitrage pipeline #########
class PredTitragesPipeline(Pipeline):
    """Sequence-to-sequence generation pipeline with input normalisation.

    Each input text is normalised (see `normalise`), tokenised, passed to
    `self.model.generate` using beam search with `beam_size` beams, and the
    generated token ids are decoded back to text.
    """

    # Ordered (pattern, replacement) rules applied by `normalise`.
    # Compiled once at class creation instead of on every call.
    _NORMALISATION_RULES = tuple(
        (re.compile(pattern), replacement)
        for pattern, replacement in (
            (r'[«»“”]', '"'),   # guillemets and curly double quotes -> "
            (r'[‘’]', "'"),     # curly single quotes -> '
            (r' +', ' '),       # collapse runs of spaces
            (r'"+', '"'),       # collapse repeated double quotes
            (r"'+", "'"),       # collapse repeated single quotes
            (r'^ *', ''),       # leading spaces
            (r' *$', ''),       # trailing spaces
        )
    )

    def __init__(self, beam_size=5, batch_size=32, **kwargs):
        """Create the pipeline.

        Args:
            beam_size: Number of beams passed to `generate` as `num_beams`.
            batch_size: Batch size handed to the base `Pipeline`.
            **kwargs: Forwarded unchanged to `Pipeline.__init__`
                (e.g. `model`, `tokenizer`).
        """
        self.beam_size = beam_size
        # The original accepted `batch_size` but never passed it on, so the
        # requested batch size had no effect. Forward it to the base class.
        super().__init__(batch_size=batch_size, **kwargs)

    def _sanitize_parameters(self, clean_up_tokenisation_spaces=None, truncation=None, **generate_kwargs):
        """Split call-time keyword arguments between the three pipeline stages.

        Returns:
            A `(preprocess_params, forward_params, postprocess_params)` tuple.
            `truncation` goes to preprocessing, `clean_up_tokenisation_spaces`
            to postprocessing (each only when not None), and every remaining
            keyword is treated as a generation argument.
        """
        preprocess_params = {}
        if truncation is not None:
            preprocess_params["truncation"] = truncation

        forward_params = generate_kwargs

        postprocess_params = {}
        if clean_up_tokenisation_spaces is not None:
            postprocess_params["clean_up_tokenisation_spaces"] = clean_up_tokenisation_spaces

        return preprocess_params, forward_params, postprocess_params

    def check_inputs(self, input_length: int, min_length: int, max_length: int):
        """
        Checks whether there might be something wrong with given input with regard to the model.

        Currently a placeholder: no check is performed and True is always returned.
        """
        return True

    def make_printable(self, s):
        """Remove non-printable characters from `s`, keeping tabs and newlines.

        The original body translated through a module-level
        `NOPRINT_TRANS_TABLE` that is not defined anywhere in this file, so it
        raised NameError. This filters on `str.isprintable()` directly instead.
        NOTE(review): presumably matches the intended table (drop everything
        non-printable except space, tab and newline) -- confirm.
        """
        return ''.join(ch for ch in s if ch.isprintable() or ch in '\t\n')

    def normalise(self, line):
        """Normalise one input line and append the ` </s>` end marker.

        Steps: NFKC Unicode normalisation, removal of non-printable
        characters, then the quote/whitespace rules in `_NORMALISATION_RULES`.

        Returns:
            The normalised, stripped line followed by `' </s>'`.
        """
        line = unicodedata.normalize('NFKC', line)
        line = self.make_printable(line)
        for pattern, replacement in self._NORMALISATION_RULES:
            line = pattern.sub(replacement, line)
        return line.strip() + ' </s>'

    def _parse_and_tokenise(self, *args, truncation):
        """Normalise and tokenise the first positional argument.

        Args:
            *args: Only `args[0]` is used; it must be a `str` or a `list`
                of strings.
            truncation: Truncation strategy forwarded to the tokeniser.

        Returns:
            The tokeniser output with any `token_type_ids` entry removed.

        Raises:
            ValueError: If `args[0]` is neither `str` nor `list`, or if a list
                is given and the tokeniser has no `pad_token_id`.
        """
        texts = args[0]
        if isinstance(texts, list):
            if self.tokenizer.pad_token_id is None:
                raise ValueError("Please make sure that the tokeniser has a pad_token_id when using a batch input")
            padding = True
        elif isinstance(texts, str):
            texts = [texts]
            padding = False
        else:
            raise ValueError(
                f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`"
            )

        # Normalise each text individually. The original iterated over the
        # `args` tuple, which passed a whole list to `normalise` for batch input.
        inputs = [self.normalise(text) for text in texts]
        inputs = self.tokenizer(inputs, padding=padding, truncation=truncation, return_tensors=self.framework)

        # This is produced by tokenisers but is an invalid generate kwargs
        if "token_type_ids" in inputs:
            del inputs["token_type_ids"]
        return inputs

    def preprocess(self, inputs, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs):
        """Turn raw input text into model inputs via `_parse_and_tokenise`."""
        inputs = self._parse_and_tokenise(inputs, truncation=truncation, **kwargs)
        return inputs

    def _forward(self, model_inputs, **generate_kwargs):
        """Run generation on tokenised inputs.

        `min_length` and `max_length` default to the model config values when
        not supplied; `num_beams` is always set to `self.beam_size`.

        Returns:
            A dict with key `"output_ids"` holding the generated ids reshaped
            to `(input_batch, sequences_per_input, ...)`.
        """
        in_b, input_length = model_inputs["input_ids"].shape
        generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
        generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
        generate_kwargs['num_beams'] = self.beam_size
        self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])

        output_ids = self.model.generate(**model_inputs, **generate_kwargs)
        out_b = output_ids.shape[0]
        # Group the generated sequences by the input they belong to.
        output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
        return {"output_ids": output_ids}

    def postprocess(self, model_outputs, clean_up_tokenisation_spaces=None):
        """Decode generated ids into a list of `{"text": ...}` records.

        Args:
            model_outputs: Dict produced by `_forward`; only the first entry
                of `model_outputs["output_ids"]` is decoded.
            clean_up_tokenisation_spaces: When not None, passed to the
                tokeniser's `decode` as `clean_up_tokenization_spaces`. When
                None the tokeniser's own default applies. The original passed
                a misspelled keyword that `decode` never read, so leaving this
                unset keeps the previous output unchanged.

        Returns:
            One dict with key `"text"` per generated sequence.
        """
        decode_kwargs = {"skip_special_tokens": True}
        if clean_up_tokenisation_spaces is not None:
            decode_kwargs["clean_up_tokenization_spaces"] = clean_up_tokenisation_spaces

        return [
            {"text": self.tokenizer.decode(output_ids, **decode_kwargs)}
            for output_ids in model_outputs["output_ids"][0]
        ]

    def correct_hallunications(self, orig, output):
        """Placeholder for output correction; currently returns `output` unchanged."""
        # align the original and output tokens
        # check that the correspondences are legitimate and correct if not
        # replace <EMOJI> symbols by the original ones
        return output

    def __call__(self, *args, **kwargs):
        r"""
        Generate the output text(s) using text(s) given as inputs.

        Args:
            args (`str` or `List[str]`):
                Input text for the encoder.
            clean_up_tokenisation_spaces (`bool`, *optional*):
                Whether or not to clean up the potential extra spaces in the text output. When omitted, the
                tokeniser's own default is used.
            truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
                The truncation strategy for the tokenisation within the pipeline.
            generate_kwargs:
                Additional keyword arguments passed along to the generate method of the model. `num_beams` is
                always overwritten with the pipeline's `beam_size`.

        Return:
            For a list of strings: the result of the base pipeline call, i.e. one list of records per input,
            where each record is a dict with the single key **text** (`str`), the decoded generated text.
            For any other input: the first element of the base pipeline result.
        """
        result = super().__call__(*args, **kwargs)
        # For a list of strings always return the results for every input. The
        # original fell through to `result[0]` whenever an input produced more
        # than one record, which discarded all inputs but the first.
        if isinstance(args[0], list) and all(isinstance(el, str) for el in args[0]):
            return result
        return result[0]  # check this
def predict_titrages(list_sents, batch_size=32, beam_size=5):
    """Load the titrage model and run it over `list_sents`.

    Args:
        list_sents: Input passed directly to the pipeline call.
        batch_size: Batch size given to `PredTitragesPipeline`.
        beam_size: Beam size given to `PredTitragesPipeline`.

    Returns:
        Whatever the `PredTitragesPipeline` call returns for `list_sents`.
    """
    # Tokeniser and model are fetched on every call (authenticated access).
    tok = AutoTokenizer.from_pretrained("rbawden/CCASS-pred-titrages-base", use_auth_token=True)
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/CCASS-pred-titrages-base", use_auth_token=True)
    titrage_pipe = PredTitragesPipeline(
        model=seq2seq_model,
        tokenizer=tok,
        batch_size=batch_size,
        beam_size=beam_size,
    )
    return titrage_pipe(list_sents)
def predict_from_stdin(batch_size=32, beam_size=5):
    """Read one input per line from STDIN, predict, and print each result.

    Model loading and prediction are delegated to `predict_titrages` rather
    than duplicating the tokeniser/model/pipeline construction here.

    Args:
        batch_size: Batch size forwarded to `predict_titrages`.
        beam_size: Beam size forwarded to `predict_titrages`.

    Returns:
        The outputs returned by `predict_titrages` for the lines read.
    """
    # Surrounding whitespace (including the trailing newline) is stripped.
    list_sents = [sent.strip() for sent in sys.stdin]
    outputs = predict_titrages(list_sents, batch_size=batch_size, beam_size=beam_size)
    for sent in outputs:
        print(sent)
    return outputs
if __name__ == '__main__':
    # Command-line entry point: predict for every line of a file, or of STDIN
    # when no input file is given, and print one result per line.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-k', '--batch_size', type=int, default=32, help='Set the batch size for decoding')
    parser.add_argument('-b', '--beam_size', type=int, default=5, help='Set the beam size for decoding')
    parser.add_argument('-i', '--input_file', type=str, default=None, help='Input file. If None, read from STDIN')
    args = parser.parse_args()

    if args.input_file is None:
        predict_from_stdin(batch_size=args.batch_size, beam_size=args.beam_size)
    else:
        with open(args.input_file) as fp:
            list_sents = [line.strip() for line in fp]
        # The original called `predict_text`, which is not defined in this
        # file; `predict_titrages` is the function with the matching signature.
        output_sents = predict_titrages(list_sents, batch_size=args.batch_size, beam_size=args.beam_size)
        for output_sent in output_sents:
            print(output_sent)
|