# generative-qa-model / new_task.py
from transformers import TextGenerationPipeline
import tensorflow as tf
import numpy as np
import math

class MyTestPipeline(TextGenerationPipeline):
    def preprocess(self, text, **kwargs):
        # Wrap the incoming question in the chat template the model expects and
        # ask for a short, unabbreviated English answer.
        prompt = 'Answer the following question/statement in English without any explanation, do not abbreviate names.'
        txt = f"<|user|>\n{prompt} {text}\n<|end|>\n<|assistant|>"
        return self.tokenizer(txt, return_tensors=self.framework)
    def _forward(self, model_inputs, **generate_kwargs):
        if self.framework == "pt":
            in_b, input_length = model_inputs["input_ids"].shape
        elif self.framework == "tf":
            in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy()
        # Request per-step scores so postprocess() can turn them into token probabilities.
        outputs = self.model.generate(**model_inputs, **generate_kwargs, return_dict_in_generate=True, output_scores=True)
        output_ids = outputs.sequences
        out_b = output_ids.shape[0]
        # Group the generated sequences by input example, as the parent pipeline expects.
        if self.framework == "pt":
            output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
        elif self.framework == "tf":
            output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:]))
        # Pass the prompt ids through so the parent postprocess can strip the prompt,
        # and keep the raw generation scores for the confidence estimate.
        return {"input_ids": model_inputs["input_ids"], "generated_sequence": output_ids, "output_scores": outputs.scores, "prompt_text": ""}
    def postprocess(self, model_outputs, **kwargs):
        # Let the parent pipeline decode the generation, then keep only the assistant's final line.
        guess_text = super().postprocess(model_outputs)[0]['generated_text'].split('\n')[-1].strip()
        # Per-token log-probabilities of the generated answer for the first (only) input example.
        transition_scores = self.model.compute_transition_scores(model_outputs['generated_sequence'][0], model_outputs['output_scores'], normalize_logits=True)
        token_probs = np.round(np.exp(transition_scores.cpu().numpy()), 3)[0]
        # Joint probability of the whole answer, squashed through a logistic curve into a 0-1 confidence.
        guess_prob = np.prod(token_probs)
        confidence = math.exp(12 * (guess_prob - 0.5)) / (1 + math.exp(12 * (guess_prob - 0.5)))
        return {'guess': guess_text, 'confidence': confidence}
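

# ---------------------------------------------------------------------------
# Usage sketch (not part of the pipeline itself): one way this class could be
# registered and called. The task name "generative-qa" and the checkpoint
# "microsoft/Phi-3-mini-4k-instruct" are assumptions for illustration only;
# the <|user|>/<|assistant|> template above suggests a Phi-style chat model,
# but substitute whichever checkpoint this repository actually targets.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
    from transformers.pipelines import PIPELINE_REGISTRY

    # Register the custom task so pipeline() knows how to build MyTestPipeline.
    PIPELINE_REGISTRY.register_pipeline(
        "generative-qa",  # assumed task name
        pipeline_class=MyTestPipeline,
        pt_model=AutoModelForCausalLM,
    )

    checkpoint = "microsoft/Phi-3-mini-4k-instruct"  # assumed checkpoint
    qa = pipeline(
        "generative-qa",
        model=AutoModelForCausalLM.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )

    result = qa("Who wrote the novel Moby-Dick?", max_new_tokens=32)
    print(result)  # a dict of the form {'guess': ..., 'confidence': ...}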