from argparse import Namespace
from typing import Optional

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Text2TextGenerationPipeline


def get_markers_for_model():
    """Special marker strings that the trained model expects in its inputs and outputs."""
    special_tokens_constants = Namespace()
    special_tokens_constants.separator_different_qa = "&&&"
    special_tokens_constants.separator_output_question_answer = "SSEEPP"
    special_tokens_constants.source_prefix = "qa: "
    return special_tokens_constants


def load_trained_model(name_or_path):
    tokenizer = AutoTokenizer.from_pretrained(name_or_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)
    return model, tokenizer


class QADiscourse_Pipeline(Text2TextGenerationPipeline):
    def __init__(self, model_repo: str, **kwargs):
        model, tokenizer = load_trained_model(model_repo)
        super().__init__(model, tokenizer, framework="pt", **kwargs)
        self.special_tokens = get_markers_for_model()

    def preprocess(self, inputs):
        # Prepend the task prefix to a single string or to each string in an iterable.
        if isinstance(inputs, str):
            processed_inputs = self._preprocess_string(inputs)
        elif hasattr(inputs, "__iter__"):
            processed_inputs = [self._preprocess_string(s) for s in inputs]
        else:
            raise ValueError("inputs must be str or Iterable[str]")
        # Pass to the parent preprocess for tokenization.
        return super().preprocess(processed_inputs)

    def _preprocess_string(self, seq: str) -> str:
        return self.special_tokens.source_prefix + seq

    def _forward(self, *args, **kwargs):
        # No custom forward logic; delegate to the parent pipeline.
        return super()._forward(*args, **kwargs)

    def postprocess(self, model_outputs):
        predictions = self.tokenizer.decode(
            model_outputs["output_ids"].squeeze(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        separated_qas = self._split_to_list(predictions)
        qas = []
        for qa_pair in separated_qas:
            post_processed = self._postprocess_qa(qa_pair)
            # Skip predictions that are not valid question-answer pairs.
            if post_processed is not None:
                qas.append(post_processed)
        return qas

    def _split_to_list(self, output_seq: str) -> list:
        return output_seq.split(self.special_tokens.separator_different_qa)

    def _postprocess_qa(self, seq: str) -> Optional[dict]:
        # Split the generated sequence into a question and an answer.
        if self.special_tokens.separator_output_question_answer in seq:
            question, answer = seq.split(self.special_tokens.separator_output_question_answer, 1)
        else:
            return None
        return {"question": question, "answer": answer}


if __name__ == "__main__":
    pipe = QADiscourse_Pipeline("RonEliav/QA_discourse_v2")
    res1 = pipe("I don't like chocolate, but I like cookies.")
    res2 = pipe(
        ["I don't like chocolate, but I like cookies.", "I dived in the sea easily"],
        num_beams=10,
    )
    print(res1)
    print(res2)