Spaces:
Sleeping
Sleeping
from typing import Optional | |
import json | |
from argparse import Namespace | |
from pathlib import Path | |
from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer | |
def get_markers_for_model(is_t5_model: bool) -> Namespace:
    """Return the special marker/separator tokens expected by a QASRL seq2seq model.

    T5 checkpoints reuse their built-in ``<extra_id_N>`` sentinel tokens, whereas
    other models use dedicated, human-readable special tokens.

    Args:
        is_t5_model: whether the underlying model is a T5 variant.

    Returns:
        A ``Namespace`` with one string attribute per marker/separator.
    """
    # attribute name -> (token for T5 models, token for other models)
    marker_table = {
        "separator_input_question_predicate": ("<extra_id_1>", "<question_predicate_sep>"),
        "separator_output_answers": ("<extra_id_3>", "<answers_sep>"),
        # only relevant when generating questions without answers
        "separator_output_questions": ("<extra_id_5>", "<question_sep>"),
        "separator_output_question_answer": ("<extra_id_7>", "<question_answer_sep>"),
        "separator_output_pairs": ("<extra_id_9>", "<qa_pairs_sep>"),
        "predicate_generic_marker": ("<extra_id_10>", "<predicate_marker>"),
        "predicate_verb_marker": ("<extra_id_11>", "<verbal_predicate_marker>"),
        "predicate_nominalization_marker": ("<extra_id_12>", "<nominalization_predicate_marker>"),
    }
    column = 0 if is_t5_model else 1
    return Namespace(**{attr: tokens[column] for attr, tokens in marker_table.items()})
def load_trained_model(name_or_path):
    """Load a seq2seq model and tokenizer and attach its QASRL preprocessing kwargs.

    The preprocessing kwargs are read from one of two places:
    ``preprocessing_kwargs.json`` in the HF Hub repo for ``kleinay/...`` models, or
    ``experiment_kwargs.json`` inside a local model directory.

    When a file is found, its content is stored on
    ``model.config.preprocessing_kwargs`` as a ``Namespace``. It is also merged into
    ``model.config`` itself, so decoding args such as ``num_beams`` take effect.
    When no file is found, ``model.config`` is left untouched.

    Args:
        name_or_path: HF Hub repo id, or path (str or path-like) to a local model directory.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    import huggingface_hub as HFhub
    tokenizer = AutoTokenizer.from_pretrained(name_or_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)
    # String form for the prefix/path checks below, so path-like inputs also work.
    name = str(name_or_path)
    # Load preprocessing_kwargs from the model repo on HF hub, or from the local model directory.
    kwargs_filename = None
    if name.startswith("kleinay/"):
        # Existence is not checked first: HFhub.list_repo_files is not available in the
        # supported huggingface_hub version.
        kwargs_filename = HFhub.hf_hub_download(repo_id=name, filename="preprocessing_kwargs.json")
    elif Path(name).is_dir() and (Path(name) / "experiment_kwargs.json").exists():
        kwargs_filename = Path(name) / "experiment_kwargs.json"
    if kwargs_filename:
        # Context manager so the file handle is closed deterministically.
        with open(kwargs_filename, encoding="utf-8") as kwargs_file:
            preprocessing_kwargs = json.load(kwargs_file)
        # Integrate into model.config (for decoding args, e.g. "num_beams"), and also
        # keep a standalone object for preprocessing.
        model.config.preprocessing_kwargs = Namespace(**preprocessing_kwargs)
        model.config.update(preprocessing_kwargs)
    return model, tokenizer
class QASRL_Pipeline(Text2TextGenerationPipeline):
    """Generate QASRL question-answer pairs for a marked predicate in a sentence.

    Every input sentence must contain a predicate-marker token (``"<predicate>"`` by
    default) as a separate, space-delimited token directly before the target
    predicate word.

    Inputs are rewritten into the model's input format: predicate markers, an
    optional verb form, and an optional T5 source prefix. The generated text is then
    parsed into ``{"question": str, "answers": List[str]}`` dicts.
    """

    # Call-time kwargs consumed by `preprocess`; every other call-time kwarg is
    # forwarded to generation (e.g. `num_beams`).
    _PREPROCESS_PARAM_NAMES = ("predicate_marker", "predicate_type", "verb_form")

    def __init__(self, model_repo: str, **kwargs):
        """Load the model from `model_repo` and write `kwargs` onto `model.config`.

        Args:
            model_repo: HF Hub repo id or local model directory.
            **kwargs: values written directly onto ``model.config`` (e.g. ``max_length``).
        """
        model, tokenizer = load_trained_model(model_repo)
        super().__init__(model, tokenizer, framework="pt")
        self.is_t5_model = "t5" in model.config.model_type
        self.special_tokens = get_markers_for_model(self.is_t5_model)
        # `load_trained_model` only attaches `preprocessing_kwargs` when it finds a
        # kwargs file. Without this fallback, __init__ would fail with AttributeError;
        # with it, the defaults below apply.
        data_args = getattr(model.config, "preprocessing_kwargs", None)
        self.data_args = data_args if data_args is not None else Namespace()
        # Backward compatibility - default keyword values implemented in
        # `run_summarization`, thus not saved in `preprocessing_kwargs`.
        backward_compat_defaults = {
            "predicate_marker_type": "generic",
            "use_bilateral_predicate_marker": True,
            "append_verb_form": True,
        }
        for key, default in backward_compat_defaults.items():
            if key not in vars(self.data_args):
                setattr(self.data_args, key, default)
        self._update_config(**kwargs)

    def _update_config(self, **kwargs):
        " Update self.model.config with initialization parameters and necessary defaults. "
        # Default values always override model.config, but can be overridden by __init__ kwargs.
        kwargs["max_length"] = kwargs.get("max_length", 80)
        for key, value in kwargs.items():
            self.model.config.__dict__[key] = value

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) parameter dicts.

        Input-formatting kwargs (`predicate_marker`, `predicate_type`, `verb_form`) go
        to `preprocess`. All remaining kwargs are forwarded to generation; previously
        they were silently dropped, so e.g. ``num_beams`` passed at call time had no effect.
        """
        preprocess_kwargs = {k: v for k, v in kwargs.items() if k in self._PREPROCESS_PARAM_NAMES}
        forward_kwargs = {k: v for k, v in kwargs.items() if k not in self._PREPROCESS_PARAM_NAMES}
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs, predicate_marker="<predicate>", predicate_type=None, verb_form=None):
        """Format `inputs` (a string or an iterable of strings) and tokenize them.

        Raises:
            ValueError: if `inputs` is neither a string nor iterable.
        """
        if isinstance(inputs, str):
            processed_inputs = self._preprocess_string(inputs, predicate_marker, predicate_type, verb_form)
        elif hasattr(inputs, "__iter__"):
            processed_inputs = [self._preprocess_string(s, predicate_marker, predicate_type, verb_form) for s in inputs]
        else:
            raise ValueError("inputs must be str or Iterable[str]")
        # Tokenization is delegated to the parent pipeline.
        return super().preprocess(processed_inputs)

    def _preprocess_string(self, seq: str, predicate_marker: str, predicate_type: Optional[str], verb_form: Optional[str]) -> str:
        """Rewrite one marked sentence into the model's input format.

        Args:
            seq: space-tokenized sentence containing `predicate_marker` before the predicate.
            predicate_marker: the user-facing marker token to look for in `seq`.
            predicate_type: "verbal" or "nominal"; required when the model uses
                per-type predicate markers.
            verb_form: verb form of the predicate; required when the model appends it.

        Raises:
            AssertionError: if the marker is missing, or `predicate_type` is missing/invalid
                for a model with per-type markers.
            ValueError: if the marker is the last token, or `verb_form` is required but missing.
        """
        sent_tokens = seq.split(" ")
        assert predicate_marker in sent_tokens, f"Input sentence must include a predicate-marker token ('{predicate_marker}') before the target predicate word"
        predicate_idx = sent_tokens.index(predicate_marker)
        del sent_tokens[predicate_idx]
        if predicate_idx >= len(sent_tokens):
            raise ValueError(f"The predicate-marker token ('{predicate_marker}') must be followed by the target predicate word")
        sentence_before_predicate = " ".join(sent_tokens[:predicate_idx])
        predicate = sent_tokens[predicate_idx]
        sentence_after_predicate = " ".join(sent_tokens[predicate_idx + 1:])

        # Marker token that will be fed to the model; defaults to the user's marker.
        marker_token = predicate_marker
        if self.data_args.predicate_marker_type == "generic":
            marker_token = self.special_tokens.predicate_generic_marker
        elif self.data_args.predicate_marker_type == "pred_type":
            # Special marker for each predicate type.
            assert predicate_type is not None, "For this model, you must provide the `predicate_type` either when initializing QASRL_Pipeline(...) or when applying __call__(...) on it"
            assert predicate_type in ("verbal", "nominal"), f"`predicate_type` must be either 'verbal' or 'nominal'; got '{predicate_type}'"
            marker_token = {"verbal": self.special_tokens.predicate_verb_marker,
                            "nominal": self.special_tokens.predicate_nominalization_marker,
                            }[predicate_type]

        if self.data_args.use_bilateral_predicate_marker:
            seq = f"{sentence_before_predicate} {marker_token} {predicate} {marker_token} {sentence_after_predicate}"
        else:
            seq = f"{sentence_before_predicate} {marker_token} {predicate} {sentence_after_predicate}"

        # Embed also verb_form.
        if self.data_args.append_verb_form and verb_form is None:
            raise ValueError("For this model, you must provide the `verb_form` of the predicate when applying __call__(...)")
        elif self.data_args.append_verb_form:
            seq = f"{seq} {self.special_tokens.separator_input_question_predicate} {verb_form} "
        else:
            seq = f"{seq} "
        # Prepend the source prefix (used by T5 models only).
        return self._get_source_prefix(predicate_type) + seq

    def _get_source_prefix(self, predicate_type: Optional[str]) -> str:
        """Return the source prefix to prepend to the model input ('' when none applies).

        Raises:
            ValueError: if the prefix template needs `predicate_type` but it is None, or
                the prefix is an unsupported ``<...>`` template.
        """
        source_prefix = getattr(self.data_args, "source_prefix", None)
        if not self.is_t5_model or source_prefix is None:
            return ''
        if not source_prefix.startswith("<"):  # Regular prefix - not dependent on input row x
            return source_prefix
        if source_prefix == "<predicate-type>":
            if predicate_type is None:
                raise ValueError("source_prefix is '<predicate-type>' but input no `predicate_type`.")
            return f"Generate QAs for {predicate_type} QASRL: "
        # Previously fell through returning None, which crashed later on `prefix + seq`.
        raise ValueError(f"Unsupported templated source_prefix: '{source_prefix}'")

    @staticmethod
    def _strip_boundary_tokens(text: str, tokens) -> str:
        """Remove whole occurrences of `tokens` (and whitespace) from both ends of `text`.

        Unlike ``str.strip(token)``, which treats its argument as a set of characters
        (``"animals</s>".strip("</s>")`` yields ``"animal"``), this removes only exact
        token matches. Falsy tokens (e.g. a tokenizer without a pad token) are ignored.
        """
        tokens = [tok for tok in tokens if tok]
        text = text.strip()
        removed = True
        while removed:
            removed = False
            for tok in tokens:
                if text.startswith(tok):
                    text = text[len(tok):].strip()
                    removed = True
                if text.endswith(tok):
                    text = text[:-len(tok)].strip()
                    removed = True
        return text

    def postprocess(self, model_outputs):
        """Decode generated ids and split them into question-answer dicts.

        Returns:
            ``{"generated_text": str, "QAs": list}``. Entries of "QAs" are None for
            sub-sequences lacking a question-answer separator.
        """
        output_seq = self.tokenizer.decode(
            model_outputs["output_ids"].squeeze(),
            skip_special_tokens=False,  # separators are special tokens and must be kept
            clean_up_tokenization_spaces=False,
        )
        output_seq = self._strip_boundary_tokens(
            output_seq, (self.tokenizer.pad_token, self.tokenizer.eos_token))
        qa_subseqs = output_seq.split(self.special_tokens.separator_output_pairs)
        qas = [self._postrocess_qa(qa_subseq) for qa_subseq in qa_subseqs]
        return {"generated_text": output_seq,
                "QAs": qas}

    def _postrocess_qa(self, seq: str) -> Optional[dict]:
        """Parse one generated QA sub-sequence into ``{"question": str, "answers": List[str]}``.

        Returns None (after printing a notice) when no question-answer separator is found.
        """
        qa_separator = self.special_tokens.separator_output_question_answer
        if qa_separator not in seq:
            print("invalid format: no separator between question and answer found...")
            return None
        question, answer = seq.split(qa_separator)[:2]
        # Skip "_" slots in questions; split() also trims surrounding whitespace,
        # consistent with the stripped answers.
        question = ' '.join(t for t in question.split() if t != '_')
        answers = [a.strip() for a in answer.split(self.special_tokens.separator_output_answers)]
        return {"question": question, "answers": answers}
if __name__ == "__main__":
    # Demo: generate QAs for one nominal predicate, a batch of two nominal
    # predicates (with beam search), and one verbal predicate.
    pipe = QASRL_Pipeline("kleinay/qanom-seq2seq-model-baseline")
    results = [
        pipe("The student was interested in Luke 's <predicate> research about sea animals .",
             verb_form="research", predicate_type="nominal"),
        pipe(["The doctor was interested in Luke 's <predicate> treatment .",
              "The Veterinary student was interested in Luke 's <predicate> treatment of sea animals ."],
             verb_form="treat", predicate_type="nominal", num_beams=10),
        pipe("A number of professions have <predicate> developed that specialize in the treatment of mental disorders .",
             verb_form="develop", predicate_type="verbal"),
    ]
    for result in results:
        print(result)