cadige commited on
Commit
50e4490
1 Parent(s): 35ba67d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +51 -0
  2. qasrl_model_pipeline.py +183 -0
  3. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from qasrl_model_pipeline import QASRL_Pipeline
3
+
4
+ models = ["kleinay/qanom-seq2seq-model-baseline",
5
+ "kleinay/qanom-seq2seq-model-joint"]
6
+ pipelines = {model: QASRL_Pipeline(model) for model in models}
7
+
8
+
9
+ description = f"""Using Seq2Seq T5 model which takes a sequence of items and outputs another sequence this model generates Questions and Answers (QA) with focus on Semantic Role Labeling (SRL)"""
10
+ title="Seq2Seq T5 Questions and Answers (QA) with Semantic Role Labeling (SRL)"
11
+ examples = [[models[0], "In March and April the patient <p> had two falls. One was related to asthma, heart palpitations. The second was due to syncope and post covid vaccination dizziness during exercise. The patient is now getting an EKG. Former EKG had shown that there was a bundle branch block. Patient had some uncontrolled immune system reactions like anaphylaxis and shortness of breath.", True, "fall"],
12
+ [models[1], "In March and April the patient had two falls. One was related to asthma, heart palpitations. The second was due to syncope and post covid vaccination dizziness during exercise. The patient is now getting an EKG. Former EKG had shown that there was a bundle branch block. Patient had some uncontrolled immune system reactions <p> like anaphylaxis and shortness of breath.", True, "reactions"],
13
+ [models[0], "In March and April the patient had two falls. One was related <p> to asthma, heart palpitations. The second was due to syncope and post covid vaccination dizziness during exercise. The patient is now getting an EKG. Former EKG had shown that there was a bundle branch block. Patient had some uncontrolled immune system reactions like anaphylaxis and shortness of breath.", True, "relate"],
14
+ [models[1], "In March and April the patient <p> had two falls. One was related to asthma, heart palpitations. The second was due to syncope and post covid vaccination dizziness during exercise. The patient is now getting an EKG. Former EKG had shown that there was a bundle branch block. Patient had some uncontrolled immune system reactions like anaphylaxis and shortness of breath.", False, "fall"]]
15
+
16
+ input_sent_box_label = "Insert sentence here. Mark the predicate by adding the token '<p>' before it."
17
+ verb_form_inp_placeholder = "e.g. 'decide' for the nominalization 'decision', 'teach' for 'teacher', etc."
18
+ links = """<p style='text-align: center'>
19
+ <a href='https://www.qasrl.org' target='_blank'>QASRL Website</a> | <a href='https://huggingface.co/kleinay/qanom-seq2seq-model-baseline' target='_blank'>Model Repo at Huggingface Hub</a>
20
+ </p>"""
21
+ def call(model_name, sentence, is_nominal, verb_form):
22
+ predicate_marker="<p>"
23
+ if predicate_marker not in sentence:
24
+ raise ValueError("You must highlight one word of the sentence as a predicate using preceding '<p>'.")
25
+
26
+ if not verb_form:
27
+ if is_nominal:
28
+ raise ValueError("You should provide the verbal form of the nominalization")
29
+
30
+ toks = sentence.split(" ")
31
+ pred_idx = toks.index(predicate_marker)
32
+ predicate = toks(pred_idx+1)
33
+ verb_form=predicate
34
+ pipeline = pipelines[model_name]
35
+ pipe_out = pipeline([sentence],
36
+ predicate_marker=predicate_marker,
37
+ predicate_type="nominal" if is_nominal else "verbal",
38
+ verb_form=verb_form)[0]
39
+ return pipe_out["QAs"], pipe_out["generated_text"]
40
+ iface = gr.Interface(fn=call,
41
+ inputs=[gr.inputs.Radio(choices=models, default=models[0], label="Model"),
42
+ gr.inputs.Textbox(placeholder=input_sent_box_label, label="Sentence", lines=4),
43
+ gr.inputs.Checkbox(default=True, label="Is Nominalization?"),
44
+ gr.inputs.Textbox(placeholder=verb_form_inp_placeholder, label="Verbal form (for nominalizations)", default='')],
45
+ outputs=[gr.outputs.JSON(label="Model Output - QASRL"), gr.outputs.Textbox(label="Raw output sequence")],
46
+ title=title,
47
+ description=description,
48
+ article=links,
49
+ examples=examples )
50
+
51
+ iface.launch()
qasrl_model_pipeline.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import json
3
+ from argparse import Namespace
4
+ from pathlib import Path
5
+ from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer
6
+
7
+ def get_markers_for_model(is_t5_model: bool) -> Namespace:
8
+ special_tokens_constants = Namespace()
9
+ if is_t5_model:
10
+ # T5 model have 100 special tokens by default
11
+ special_tokens_constants.separator_input_question_predicate = "<extra_id_1>"
12
+ special_tokens_constants.separator_output_answers = "<extra_id_3>"
13
+ special_tokens_constants.separator_output_questions = "<extra_id_5>" # if using only questions
14
+ special_tokens_constants.separator_output_question_answer = "<extra_id_7>"
15
+ special_tokens_constants.separator_output_pairs = "<extra_id_9>"
16
+ special_tokens_constants.predicate_generic_marker = "<extra_id_10>"
17
+ special_tokens_constants.predicate_verb_marker = "<extra_id_11>"
18
+ special_tokens_constants.predicate_nominalization_marker = "<extra_id_12>"
19
+
20
+ else:
21
+ special_tokens_constants.separator_input_question_predicate = "<question_predicate_sep>"
22
+ special_tokens_constants.separator_output_answers = "<answers_sep>"
23
+ special_tokens_constants.separator_output_questions = "<question_sep>" # if using only questions
24
+ special_tokens_constants.separator_output_question_answer = "<question_answer_sep>"
25
+ special_tokens_constants.separator_output_pairs = "<qa_pairs_sep>"
26
+ special_tokens_constants.predicate_generic_marker = "<predicate_marker>"
27
+ special_tokens_constants.predicate_verb_marker = "<verbal_predicate_marker>"
28
+ special_tokens_constants.predicate_nominalization_marker = "<nominalization_predicate_marker>"
29
+ return special_tokens_constants
30
+
31
+ def load_trained_model(name_or_path):
32
+ import huggingface_hub as HFhub
33
+ tokenizer = AutoTokenizer.from_pretrained(name_or_path)
34
+ model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)
35
+ # load preprocessing_kwargs from the model repo on HF hub, or from the local model directory
36
+ kwargs_filename = None
37
+ if name_or_path.startswith("kleinay/"): # and 'preprocessing_kwargs.json' in HFhub.list_repo_files(name_or_path): # the supported version of HFhub doesn't support list_repo_files
38
+ kwargs_filename = HFhub.hf_hub_download(repo_id=name_or_path, filename="preprocessing_kwargs.json")
39
+ elif Path(name_or_path).is_dir() and (Path(name_or_path) / "experiment_kwargs.json").exists():
40
+ kwargs_filename = Path(name_or_path) / "experiment_kwargs.json"
41
+
42
+ if kwargs_filename:
43
+ preprocessing_kwargs = json.load(open(kwargs_filename))
44
+ # integrate into model.config (for decoding args, e.g. "num_beams"), and save also as standalone object for preprocessing
45
+ model.config.preprocessing_kwargs = Namespace(**preprocessing_kwargs)
46
+ model.config.update(preprocessing_kwargs)
47
+ return model, tokenizer
48
+
49
+
50
+ class QASRL_Pipeline(Text2TextGenerationPipeline):
51
+ def __init__(self, model_repo: str, **kwargs):
52
+ model, tokenizer = load_trained_model(model_repo)
53
+ super().__init__(model, tokenizer, framework="pt")
54
+ self.is_t5_model = "t5" in model.config.model_type
55
+ self.special_tokens = get_markers_for_model(self.is_t5_model)
56
+ self.data_args = model.config.preprocessing_kwargs
57
+ # backward compatibility - default keyword values implemeted in `run_summarization`, thus not saved in `preprocessing_kwargs`
58
+ if "predicate_marker_type" not in vars(self.data_args):
59
+ self.data_args.predicate_marker_type = "generic"
60
+ if "use_bilateral_predicate_marker" not in vars(self.data_args):
61
+ self.data_args.use_bilateral_predicate_marker = True
62
+ if "append_verb_form" not in vars(self.data_args):
63
+ self.data_args.append_verb_form = True
64
+ self._update_config(**kwargs)
65
+
66
+ def _update_config(self, **kwargs):
67
+ " Update self.model.config with initialization parameters and necessary defaults. "
68
+ # set default values that will always override model.config, but can overriden by __init__ kwargs
69
+ kwargs["max_length"] = kwargs.get("max_length", 80)
70
+ # override model.config with kwargs
71
+ for k,v in kwargs.items():
72
+ self.model.config.__dict__[k] = v
73
+
74
+ def _sanitize_parameters(self, **kwargs):
75
+ preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {}
76
+ if "predicate_marker" in kwargs:
77
+ preprocess_kwargs["predicate_marker"] = kwargs["predicate_marker"]
78
+ if "predicate_type" in kwargs:
79
+ preprocess_kwargs["predicate_type"] = kwargs["predicate_type"]
80
+ if "verb_form" in kwargs:
81
+ preprocess_kwargs["verb_form"] = kwargs["verb_form"]
82
+ return preprocess_kwargs, forward_kwargs, postprocess_kwargs
83
+
84
+ def preprocess(self, inputs, predicate_marker="<predicate>", predicate_type=None, verb_form=None):
85
+ # Here, inputs is string or list of strings; apply string postprocessing
86
+ if isinstance(inputs, str):
87
+ processed_inputs = self._preprocess_string(inputs, predicate_marker, predicate_type, verb_form)
88
+ elif hasattr(inputs, "__iter__"):
89
+ processed_inputs = [self._preprocess_string(s, predicate_marker, predicate_type, verb_form) for s in inputs]
90
+ else:
91
+ raise ValueError("inputs must be str or Iterable[str]")
92
+ # Now pass to super.preprocess for tokenization
93
+ return super().preprocess(processed_inputs)
94
+
95
+ def _preprocess_string(self, seq: str, predicate_marker: str, predicate_type: Optional[str], verb_form: Optional[str]) -> str:
96
+ sent_tokens = seq.split(" ")
97
+ assert predicate_marker in sent_tokens, f"Input sentence must include a predicate-marker token ('{predicate_marker}') before the target predicate word"
98
+ predicate_idx = sent_tokens.index(predicate_marker)
99
+ sent_tokens.remove(predicate_marker)
100
+ sentence_before_predicate = " ".join([sent_tokens[i] for i in range(predicate_idx)])
101
+ predicate = sent_tokens[predicate_idx]
102
+ sentence_after_predicate = " ".join([sent_tokens[i] for i in range(predicate_idx+1, len(sent_tokens))])
103
+
104
+ if self.data_args.predicate_marker_type == "generic":
105
+ predicate_marker = self.special_tokens.predicate_generic_marker
106
+ # In case we want special marker for each predicate type: """
107
+ elif self.data_args.predicate_marker_type == "pred_type":
108
+ assert predicate_type is not None, "For this model, you must provide the `predicate_type` either when initializing QASRL_Pipeline(...) or when applying __call__(...) on it"
109
+ assert predicate_type in ("verbal", "nominal"), f"`predicate_type` must be either 'verbal' or 'nominal'; got '{predicate_type}'"
110
+ predicate_marker = {"verbal": self.special_tokens.predicate_verb_marker ,
111
+ "nominal": self.special_tokens.predicate_nominalization_marker
112
+ }[predicate_type]
113
+
114
+ if self.data_args.use_bilateral_predicate_marker:
115
+ seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {predicate_marker} {sentence_after_predicate}"
116
+ else:
117
+ seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {sentence_after_predicate}"
118
+
119
+ # embed also verb_form
120
+ if self.data_args.append_verb_form and verb_form is None:
121
+ raise ValueError(f"For this model, you must provide the `verb_form` of the predicate when applying __call__(...)")
122
+ elif self.data_args.append_verb_form:
123
+ seq = f"{seq} {self.special_tokens.separator_input_question_predicate} {verb_form} "
124
+ else:
125
+ seq = f"{seq} "
126
+
127
+ # append source prefix (for t5 models)
128
+ prefix = self._get_source_prefix(predicate_type)
129
+
130
+ return prefix + seq
131
+
132
+ def _get_source_prefix(self, predicate_type: Optional[str]):
133
+ if not self.is_t5_model or self.data_args.source_prefix is None:
134
+ return ''
135
+ if not self.data_args.source_prefix.startswith("<"): # Regular prefix - not dependent on input row x
136
+ return self.data_args.source_prefix
137
+ if self.data_args.source_prefix == "<predicate-type>":
138
+ if predicate_type is None:
139
+ raise ValueError("source_prefix is '<predicate-type>' but input no `predicate_type`.")
140
+ else:
141
+ return f"Generate QAs for {predicate_type} QASRL: "
142
+
143
+ def _forward(self, *args, **kwargs):
144
+ outputs = super()._forward(*args, **kwargs)
145
+ return outputs
146
+
147
+
148
+ def postprocess(self, model_outputs):
149
+ output_seq = self.tokenizer.decode(
150
+ model_outputs["output_ids"].squeeze(),
151
+ skip_special_tokens=False,
152
+ clean_up_tokenization_spaces=False,
153
+ )
154
+ output_seq = output_seq.strip(self.tokenizer.pad_token).strip(self.tokenizer.eos_token).strip()
155
+ qa_subseqs = output_seq.split(self.special_tokens.separator_output_pairs)
156
+ qas = [self._postrocess_qa(qa_subseq) for qa_subseq in qa_subseqs]
157
+ return {"generated_text": output_seq,
158
+ "QAs": qas}
159
+
160
+ def _postrocess_qa(self, seq: str) -> str:
161
+ # split question and answers
162
+ if self.special_tokens.separator_output_question_answer in seq:
163
+ question, answer = seq.split(self.special_tokens.separator_output_question_answer)[:2]
164
+ else:
165
+ print("invalid format: no separator between question and answer found...")
166
+ return None
167
+ # question, answer = seq, '' # Or: backoff to only question
168
+ # skip "_" slots in questions
169
+ question = ' '.join(t for t in question.split(' ') if t != '_')
170
+ answers = [a.strip() for a in answer.split(self.special_tokens.separator_output_answers)]
171
+ return {"question": question, "answers": answers}
172
+
173
+
174
+ if __name__ == "__main__":
175
+ pipe = QASRL_Pipeline("kleinay/qanom-seq2seq-model-baseline")
176
+ res1 = pipe("The student was interested in Luke 's <predicate> research about sea animals .", verb_form="research", predicate_type="nominal")
177
+ res2 = pipe(["The doctor was interested in Luke 's <predicate> treatment .",
178
+ "The Veterinary student was interested in Luke 's <predicate> treatment of sea animals ."], verb_form="treat", predicate_type="nominal", num_beams=10)
179
+ res3 = pipe("A number of professions have <predicate> developed that specialize in the treatment of mental disorders .", verb_form="develop", predicate_type="verbal")
180
+ print(res1)
181
+ print(res2)
182
+ print(res3)
183
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers==4.15.0
2
+ torch