from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType

# Prompt template: the user instruction is placed between the prompt marker
# and the end-of-text/answer markers before being handed to the base pipeline.
STYLE = "<|prompt|>{instruction}<|endoftext|><|answer|>"


class H2OTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that wraps inputs in the ``STYLE`` template.

    ``preprocess`` formats the incoming instruction with the template and
    ``postprocess`` reduces each ``"generated_text"`` entry to the text that
    follows the ``<|answer|>`` marker (cut at the next ``<|prompt|>`` marker).
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base pipeline and store the template."""
        super().__init__(*args, **kwargs)
        self.prompt = STYLE

    def preprocess(
        self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
    ):
        """Wrap ``prompt_text`` in the prompt template, then delegate.

        Args:
            prompt_text: The raw instruction; substituted for ``{instruction}``
                in ``self.prompt``.
            prefix: Passed through unchanged to the base ``preprocess``.
            handle_long_generation: Passed through unchanged.
            **generate_kwargs: Passed through unchanged.

        Returns:
            Whatever the base class ``preprocess`` returns for the formatted
            prompt.
        """
        prompt_text = self.prompt.format(instruction=prompt_text)
        return super().preprocess(
            prompt_text,
            prefix=prefix,
            handle_long_generation=handle_long_generation,
            **generate_kwargs,
        )

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
    ):
        """Run the base ``postprocess`` and keep only the answer text.

        Args:
            model_outputs: Passed through unchanged to the base ``postprocess``.
            return_type: Passed through unchanged.
            clean_up_tokenization_spaces: Passed through unchanged.

        Returns:
            The records produced by the base class. Every record holding a
            string under ``"generated_text"`` has that string replaced by the
            extracted answer; all other records are returned untouched.
        """
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        for rec in records:
            text = rec.get("generated_text")
            # Records without a textual "generated_text" (e.g. when token ids
            # are requested instead of text) are left as-is rather than
            # raising KeyError.
            if isinstance(text, str):
                rec["generated_text"] = self._extract_answer(text)
        return records

    @staticmethod
    def _extract_answer(text):
        """Return the answer portion of ``text`` with surrounding whitespace removed.

        When ``<|answer|>`` is present, the segment between its first and
        second occurrence (or the end of the string) is used. When it is
        absent -- e.g. the prompt was not echoed back in the output -- the
        whole text is used instead of raising ``IndexError``. In both cases
        the result is cut at the first ``<|prompt|>`` marker.
        """
        segments = text.split("<|answer|>")
        answer = segments[1] if len(segments) > 1 else segments[0]
        return answer.strip().split("<|prompt|>")[0].strip()