|
import torch

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers.tools import PipelineTool
|
class TextPairClassificationTool(PipelineTool):
    default_checkpoint = "sgugger/bert-finetuned-mrpc"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSequenceClassification

    description = (
        "classifies whether two texts in English are similar or not using the labels {labels}. It takes two "
        "inputs named `text` and `second_text`, which should be in English, and returns a dictionary with two "
        "keys named 'label' (the predicted label) and 'score' (the probability associated with it)."
    )
|
    def post_init(self):
        # The tool can be built from a checkpoint name (str) or from an
        # already instantiated model; fetch the config accordingly.
        if isinstance(self.model, str):
            config = AutoConfig.from_pretrained(self.model)
        else:
            config = self.model.config

        labels = list(config.label2id.keys())

        if len(labels) > 1:
            # Quote each label and join them into a readable enumeration,
            # e.g. "'not_equivalent', and 'equivalent'".
            labels = [f"'{label}'" for label in labels]
            labels_string = ", ".join(labels[:-1])
            labels_string += f", and {labels[-1]}"
        else:
            raise ValueError("Not enough labels.")

        # Substitute the concrete label names into the description template.
        self.description = self.description.replace("{labels}", labels_string)
|
    def encode(self, text, second_text):
        # Tokenize the two texts together so the model receives a single
        # sequence-pair input.
        return self.pre_processor(text, second_text, return_tensors="pt")
|
    def decode(self, outputs):
        # Turn the raw logits into probabilities and return the most likely
        # label together with its score.
        logits = outputs.logits
        scores = torch.nn.functional.softmax(logits, dim=-1)
        label_id = torch.argmax(logits[0]).item()
        label = self.model.config.id2label[label_id]
        return {"label": label, "score": scores[0][label_id].item()}
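
# A minimal usage sketch (an addition, not part of the original example). It
# assumes the `transformers.tools` agent API, where calling a PipelineTool
# loads the model on first use and then runs `encode`, `forward`, and
# `decode` in sequence; the sentences and printed values are illustrative only.
if __name__ == "__main__":
    classifier = TextPairClassificationTool()
    # Assuming `post_init` runs at construction time, `{labels}` in the
    # description has already been replaced with the checkpoint's real labels.
    print(classifier.description)
    result = classifier("Paris is the capital of France.", "France's capital is Paris.")
    print(result)  # e.g. {'label': 'equivalent', 'score': 0.98}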
|
|