# Source: pair-classification-tool / pair_classification_tool.py
# Hub page header (kept for provenance): "sgugger's picture"
# Commit message: "Upload model and tool"
# Commit: 04b6c3e (unverified)
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers.tools import PipelineTool
class TextPairClassificationTool(PipelineTool):
    """Tool that classifies whether two English texts are similar.

    The pair is tokenized with ``pre_processor_class``, scored by a
    ``model_class`` sequence-classification model (``default_checkpoint`` is
    the checkpoint named here as the default), and the best-scoring label is
    returned together with its softmax probability.
    """

    default_checkpoint = "sgugger/bert-finetuned-mrpc"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSequenceClassification

    # The ``{labels}`` placeholder is substituted in ``post_init`` with the
    # label names read from the model configuration.
    description = (
        "classifies if two texts in English are similar or not using the labels {labels}. It takes two inputs named "
        "`text` and `second_text` which should be in English and returns a dictionary with two keys named 'label' "
        "(the predicted label ) and 'score' (the probability associated to it)."
    )

    def post_init(self):
        """Fill the ``{labels}`` placeholder of ``description``.

        The label names are the keys of ``config.label2id``. When
        ``self.model`` is still a string, it is passed to
        ``AutoConfig.from_pretrained`` to obtain the configuration; otherwise
        the configuration is read from ``self.model.config``.

        Raises:
            ValueError: If the configuration declares fewer than two labels.
        """
        if isinstance(self.model, str):
            config = AutoConfig.from_pretrained(self.model)
        else:
            config = self.model.config

        # Quote each label so it stands out inside the description sentence.
        labels = [f"'{label}'" for label in config.label2id]
        if len(labels) < 2:
            raise ValueError("Not enough labels.")

        if len(labels) == 2:
            # Two items read as "'a' and 'b'"; a comma before "and" is only
            # appropriate for lists of three or more items.
            labels_string = " and ".join(labels)
        else:
            labels_string = ", ".join(labels[:-1]) + f", and {labels[-1]}"

        self.description = self.description.replace("{labels}", labels_string)

    def encode(self, text, second_text):
        """Tokenize the text pair and return PyTorch tensors.

        Args:
            text: The first text of the pair.
            second_text: The second text of the pair.

        Returns:
            The output of ``self.pre_processor`` called on the pair with
            ``return_tensors="pt"``.
        """
        # Truncate over-long pairs to the tokenizer's maximum length instead
        # of handing the model a sequence longer than it can process.
        return self.pre_processor(text, second_text, truncation=True, return_tensors="pt")

    def decode(self, outputs):
        """Convert model outputs into the predicted label and its probability.

        Only the first row of ``outputs.logits`` is considered.

        Args:
            outputs: Model outputs exposing a ``logits`` tensor.

        Returns:
            A dictionary with the keys ``"label"`` (the ``id2label`` entry of
            the highest-scoring class) and ``"score"`` (its softmax
            probability as a Python float).
        """
        logits = outputs.logits
        scores = torch.nn.functional.softmax(logits, dim=-1)
        label_id = torch.argmax(logits[0]).item()
        label = self.model.config.id2label[label_id]
        return {"label": label, "score": scores[0][label_id].item()}