kleinay committed
Commit 5bc50e8
Parent: f5d506d

Adding Pipeline support for easy usage
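For quick reference, a minimal usage sketch based on the `__main__` demo at the bottom of the new file (the repo id, the "<predicate>" marker convention, and the keyword arguments are all taken from that demo; importing from a local `pipeline.py` is assumed):

    from pipeline import QASRL_Pipeline

    pipe = QASRL_Pipeline("kleinay/qanom-seq2seq-model-baseline")
    # Mark the target predicate with the "<predicate>" token and pass its verb form.
    res = pipe("The doctor was interested in Luke 's <predicate> treatment .",
               verb_form="treat", predicate_type="nominal")
    print(res)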

Files changed (1):
  1. pipeline.py +180 -0
pipeline.py ADDED
@@ -0,0 +1,180 @@
+ from typing import Optional
+ import json
+ from argparse import Namespace
+ from pathlib import Path
+ from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer
+ import preprocessing
+
+ def get_markers_for_model(is_t5_model: bool) -> Namespace:
+     special_tokens_constants = Namespace()
+     if is_t5_model:
+         # T5 models have 100 special tokens by default
+         special_tokens_constants.separator_input_question_predicate = "<extra_id_1>"
+         special_tokens_constants.separator_output_answers = "<extra_id_3>"
+         special_tokens_constants.separator_output_questions = "<extra_id_5>"  # if using only questions
+         special_tokens_constants.separator_output_question_answer = "<extra_id_7>"
+         special_tokens_constants.separator_output_pairs = "<extra_id_9>"
+         special_tokens_constants.predicate_generic_marker = "<extra_id_10>"
+         special_tokens_constants.predicate_verb_marker = "<extra_id_11>"
+         special_tokens_constants.predicate_nominalization_marker = "<extra_id_12>"
+
+     else:
+         special_tokens_constants.separator_input_question_predicate = "<question_predicate_sep>"
+         special_tokens_constants.separator_output_answers = "<answers_sep>"
+         special_tokens_constants.separator_output_questions = "<question_sep>"  # if using only questions
+         special_tokens_constants.separator_output_question_answer = "<question_answer_sep>"
+         special_tokens_constants.separator_output_pairs = "<qa_pairs_sep>"
+         special_tokens_constants.predicate_generic_marker = "<predicate_marker>"
+         special_tokens_constants.predicate_verb_marker = "<verbal_predicate_marker>"
+         special_tokens_constants.predicate_nominalization_marker = "<nominalization_predicate_marker>"
+     return special_tokens_constants
+
+ def load_trained_model(name_or_path):
+     import huggingface_hub as HFhub
+     tokenizer = AutoTokenizer.from_pretrained(name_or_path)
+     model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)
+     # load preprocessing_kwargs from the model repo on the HF hub, or from the local model directory
+     kwargs_filename = None
+     if name_or_path.startswith("kleinay/") and 'preprocessing_kwargs.json' in HFhub.list_repo_files(name_or_path):
+         kwargs_filename = HFhub.hf_hub_download(repo_id=name_or_path, filename="preprocessing_kwargs.json")
+     elif Path(name_or_path).is_dir() and (Path(name_or_path) / "experiment_kwargs.json").exists():
+         kwargs_filename = Path(name_or_path) / "experiment_kwargs.json"
+
+     if kwargs_filename:
+         preprocessing_kwargs = json.load(open(kwargs_filename))
+         # integrate into model.config (for decoding args, e.g. "num_beams"), and also keep as a standalone object for preprocessing
+         model.config.preprocessing_kwargs = Namespace(**preprocessing_kwargs)
+         model.config.update(preprocessing_kwargs)
+     return model, tokenizer
+
+
+ class QASRL_Pipeline(Text2TextGenerationPipeline):
+     def __init__(self, model_repo: str, **kwargs):
+         model, tokenizer = load_trained_model(model_repo)
+         super().__init__(model, tokenizer, framework="pt")
+         self.is_t5_model = "t5" in model.config.model_type
+         self.special_tokens = get_markers_for_model(self.is_t5_model)
+         # self.preprocessor = preprocessing.Preprocessor(model.config.preprocessing_kwargs, self.special_tokens)
+         self.data_args = model.config.preprocessing_kwargs
+         # backward compatibility - default keyword values are implemented in `run_summarization`, thus not saved in `preprocessing_kwargs`
+         if "predicate_marker_type" not in vars(self.data_args):
+             self.data_args.predicate_marker_type = "generic"
+         if "use_bilateral_predicate_marker" not in vars(self.data_args):
+             self.data_args.use_bilateral_predicate_marker = True
+         if "append_verb_form" not in vars(self.data_args):
+             self.data_args.append_verb_form = True
+
+
+     def _sanitize_parameters(self, **kwargs):
+         preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {}  # super()._sanitize_parameters(**kwargs)
+         if "predicate_marker" in kwargs:
+             preprocess_kwargs["predicate_marker"] = kwargs["predicate_marker"]
+         if "predicate_type" in kwargs:
+             preprocess_kwargs["predicate_type"] = kwargs["predicate_type"]
+         if "verb_form" in kwargs:
+             preprocess_kwargs["verb_form"] = kwargs["verb_form"]
+         return preprocess_kwargs, forward_kwargs, postprocess_kwargs
+
+     def preprocess(self, inputs, predicate_marker="<predicate>", predicate_type=None, verb_form=None):
+         # Here, inputs is a string or a list of strings; apply string preprocessing
+         if isinstance(inputs, str):
+             processed_inputs = self._preprocess_string(inputs, predicate_marker, predicate_type, verb_form)
+         elif hasattr(inputs, "__iter__"):
+             processed_inputs = [self._preprocess_string(s, predicate_marker, predicate_type, verb_form) for s in inputs]
+         else:
+             raise ValueError("inputs must be str or Iterable[str]")
+         # Now pass to super().preprocess for tokenization
+         return super().preprocess(processed_inputs)
+
+     def _preprocess_string(self, seq: str, predicate_marker: str, predicate_type: Optional[str], verb_form: Optional[str]) -> str:
+         sent_tokens = seq.split(" ")
+         assert predicate_marker in sent_tokens, f"Input sentence must include a predicate-marker token ('{predicate_marker}') before the target predicate word"
+         predicate_idx = sent_tokens.index(predicate_marker)
+         sent_tokens.remove(predicate_marker)
+         sentence_before_predicate = " ".join([sent_tokens[i] for i in range(predicate_idx)])
+         predicate = sent_tokens[predicate_idx]
+         sentence_after_predicate = " ".join([sent_tokens[i] for i in range(predicate_idx + 1, len(sent_tokens))])
+
+         if self.data_args.predicate_marker_type == "generic":
+             predicate_marker = self.special_tokens.predicate_generic_marker
+         # in case we want a special marker for each predicate type:
+         elif self.data_args.predicate_marker_type == "pred_type":
+             assert predicate_type is not None, "For this model, you must provide the `predicate_type` either when initializing QASRL_Pipeline(...) or when applying __call__(...) on it"
+             assert predicate_type in ("verbal", "nominal"), f"`predicate_type` must be either 'verbal' or 'nominal'; got '{predicate_type}'"
+             predicate_marker = {"verbal": self.special_tokens.predicate_verb_marker,
+                                 "nominal": self.special_tokens.predicate_nominalization_marker
+                                 }[predicate_type]
+
+         if self.data_args.use_bilateral_predicate_marker:
+             seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {predicate_marker} {sentence_after_predicate}"
+         else:
+             seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {sentence_after_predicate}"
+
+         # also embed the verb_form
+         if self.data_args.append_verb_form and verb_form is None:
+             raise ValueError("For this model, you must provide the `verb_form` of the predicate when applying __call__(...)")
+         elif self.data_args.append_verb_form:
+             seq = f"{seq} {self.special_tokens.separator_input_question_predicate} {verb_form} "
+         else:
+             seq = f"{seq} "
+
+         # append the source prefix (for t5 models)
+         prefix = self._get_source_prefix(predicate_type)
+
+         return prefix + seq
+
+     def _get_source_prefix(self, predicate_type: Optional[str]):
+         if not self.is_t5_model or self.data_args.source_prefix is None:
+             return ''
+         if "Generate QAs for <predicate_type> QASRL: " in self.data_args.source_prefix:
+             if predicate_type is None:
+                 raise ValueError("source_prefix includes 'Generate QAs for <predicate_type> QASRL: ' but the input has no `predicate_type`.")
+             if self.data_args.source_prefix == "Generate QAs for <predicate_type> QASRL: ":  # backward compatibility - this prefix alone was a sign for a longer prefix
+                 return f"Generate QAs for {predicate_type} QASRL: "
+             else:
+                 return self.data_args.source_prefix.replace("<predicate_type>", predicate_type)
+         else:
+             return self.data_args.source_prefix
+
+
+     def _forward(self, *args, **kwargs):
+         outputs = super()._forward(*args, **kwargs)
+         return outputs
+
+
+     def postprocess(self, model_outputs):
+         output_seq = self.tokenizer.decode(
+             model_outputs["output_ids"][0],
+             skip_special_tokens=False,
+             clean_up_tokenization_spaces=False,
+         )
+         output_seq = output_seq.strip(self.tokenizer.pad_token).strip(self.tokenizer.eos_token).strip()
+         qa_subseqs = output_seq.split(self.special_tokens.separator_output_pairs)
+         qas = [self._postprocess_qa(qa_subseq) for qa_subseq in qa_subseqs]
+         return {"generated_text": output_seq,
+                 "QAs": qas}
+
+     def _postprocess_qa(self, seq: str) -> Optional[dict]:
+         # split question and answers
+         if self.special_tokens.separator_output_question_answer in seq:
+             question, answer = seq.split(self.special_tokens.separator_output_question_answer)[:2]
+         else:
+             print("invalid format: no separator between question and answer found...")
+             return None
+             # question, answer = seq, ''  # Or: back off to question only
+         # skip "_" slots in questions
+         question = ' '.join(t for t in question.split(' ') if t != '_')
+         answers = [a.strip() for a in answer.split(self.special_tokens.separator_output_answers)]
+         return {"question": question, "answers": answers}
+
+
+ if __name__ == "__main__":
+     pipe = QASRL_Pipeline("kleinay/qanom-seq2seq-model-baseline")
+     res1 = pipe("The student was interested in Luke 's <predicate> research about sea animals .", verb_form="research", predicate_type="nominal")
+     res2 = pipe(["The doctor was interested in Luke 's <predicate> treatment .",
+                  "The Veterinary student was interested in Luke 's <predicate> treatment of sea animals ."], verb_form="treat", predicate_type="nominal", num_beams=10)
+     res3 = pipe("A number of professions have <predicate> developed that specialize in the treatment of mental disorders .", verb_form="develop", predicate_type="verbal")
+     print(res1)
+     print(res2)
+     print(res3)
+
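For orientation, each call returns the dict built in `postprocess()` above: the raw decoded sequence under "generated_text" and the parsed QA pairs under "QAs". The sketch below shows the shape only; the values are placeholders, not actual model output:

    {"generated_text": "...",
     "QAs": [{"question": "...", "answers": ["..."]}]}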