File size: 897 Bytes
17ad6cf 297b879 17ad6cf 297b879 17ad6cf 297b879 17ad6cf 297b879 17ad6cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
from transformers import Pipeline
import torch
class TBCP(Pipeline):
    """Single-example sequence-classification pipeline.

    Tokenizes a text (optionally with a second text of a pair), runs the
    model, and converts the logits into a predicted label and score.

    Call-time keyword arguments:
        text_pair: optional second sequence, forwarded to the tokenizer.
        top_k: if given, return the ``top_k`` highest-scoring labels as a
            list instead of the single best-label dict.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route call-time kwargs to the preprocess / forward / postprocess steps.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)`` tuple,
            as required by ``Pipeline``.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        # Each option is checked under its own key. The previous version tested
        # "text_pair" but read "top_k", raising KeyError or dropping arguments.
        if "text_pair" in kwargs:
            preprocess_kwargs["text_pair"] = kwargs["text_pair"]
        if "top_k" in kwargs:
            postprocess_kwargs["top_k"] = kwargs["top_k"]
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, text, text_pair=None):
        """Tokenize ``text`` (and ``text_pair`` when given) into PyTorch tensors."""
        return self.tokenizer(text, text_pair=text_pair, return_tensors="pt")

    def _forward(self, model_inputs):
        """Run the model on the tokenized inputs and return its raw output."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, top_k=None):
        """Turn model logits into label predictions.

        Args:
            model_outputs: Model output exposing ``.logits``. Assumed to have
                shape ``(1, num_labels)`` for one example — TODO confirm for
                custom models.
            top_k: ``None`` for the single best label, otherwise the number of
                highest-scoring labels to return.

        Returns:
            If ``top_k`` is ``None``: a dict with ``"label"``, ``"score"``
            (softmax probability) and ``"logits"`` (list of raw logits).
            Otherwise: a list of ``{"label", "score"}`` dicts in descending
            score order, with at most ``top_k`` entries.
        """
        # Select the single example explicitly instead of squeeze(). squeeze()
        # would also drop the label axis when there is only one label, which
        # breaks the indexing below and makes "logits" a float, not a list.
        logits = model_outputs.logits[0]
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        id2label = self.model.config.id2label

        if top_k is None:
            best_class = probabilities.argmax().item()
            return {
                "label": id2label[best_class],
                "score": probabilities[best_class].item(),
                "logits": logits.tolist(),
            }

        # Cap k at the number of labels so torch.topk never over-asks.
        k = min(top_k, probabilities.shape[-1])
        scores, indices = torch.topk(probabilities, k)
        return [
            {"label": id2label[index], "score": score}
            for score, index in zip(scores.tolist(), indices.tolist())
        ]