import torch

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers.tools import PipelineTool


class TextPairClassificationTool(PipelineTool):
    default_checkpoint = "sgugger/bert-finetuned-mrpc"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSequenceClassification

    description = (
        "classifies if two texts in English are similar or not using the labels {labels}. It takes two inputs named "
        "`text` and `second_text` which should be in English and returns a dictionary with two keys named 'label' "
        "(the predicted label) and 'score' (the probability associated with it)."
    )

    def post_init(self):
        # Resolve the model config whether `self.model` is still a checkpoint
        # name or an already-instantiated model.
        if isinstance(self.model, str):
            config = AutoConfig.from_pretrained(self.model)
        else:
            config = self.model.config

        labels = list(config.label2id.keys())
        if len(labels) > 1:
            # Build a human-readable list like "'a', 'b', and 'c'" to splice
            # into the `{labels}` placeholder of the description.
            labels = [f"'{label}'" for label in labels]
            labels_string = ", ".join(labels[:-1])
            labels_string += f", and {labels[-1]}"
        else:
            raise ValueError("Not enough labels.")

        self.description = self.description.replace("{labels}", labels_string)

    def encode(self, text, second_text):
        # Tokenize the two texts as a sentence pair, returning PyTorch tensors.
        return self.pre_processor(text, second_text, return_tensors="pt")

    def decode(self, outputs):
        # Turn logits into probabilities and return the top label with its score.
        logits = outputs.logits
        scores = torch.nn.functional.softmax(logits, dim=-1)
        label_id = torch.argmax(logits[0]).item()
        label = self.model.config.id2label[label_id]
        return {"label": label, "score": scores[0][label_id].item()}
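

# Minimal usage sketch (an illustration, not part of the tool above): this
# assumes the `transformers.tools` PipelineTool call protocol, where the first
# call lazily loads the checkpoint and then dispatches encode -> forward ->
# decode. The example sentences and the printed output are hypothetical.
if __name__ == "__main__":
    classifier = TextPairClassificationTool()
    result = classifier("I went to the zoo yesterday.", "I visited the zoo yesterday.")
    print(result)  # e.g. {"label": "equivalent", "score": 0.97} for the MRPC labels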