# Spaces: Build error
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers.tools import PipelineTool
class TextPairClassificationTool(PipelineTool):
    """Pipeline tool that classifies whether two English texts are equivalent.

    The tokenizer encodes both texts as a single sequence pair, the
    sequence-classification model scores it, and the highest-scoring class is
    returned as a label string taken from the model config (per the tool
    description: ``'equivalent'`` or ``'not_equivalent'``).
    """

    # Checkpoint loaded when the caller does not supply a model.
    default_checkpoint = "sgugger/bert-finetuned-mrpc"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSequenceClassification

    # Two text inputs (``text`` and ``second_text``), one text output (the label).
    inputs = ["text", "text"]
    outputs = ["text"]

    description = (
        "This is a tool that classifies if two texts in English are similar or not using the labels 'equivalent' and "
        "'not_equivalent'. It takes two inputs named `text` and `second_text` which should be in English and returns "
        "the predicted label."
    )

    def encode(self, text, second_text):
        """Tokenize the two texts together as one sequence-pair input.

        Args:
            text: First text of the pair.
            second_text: Second text, compared against ``text``.

        Returns:
            The pre-processor output, with PyTorch tensors
            (``return_tensors="pt"``).
        """
        # truncation=True caps over-long pairs at the tokenizer's maximum
        # length; without it such a pair would presumably exceed what the
        # model can accept and fail in the forward pass. Pairs that already
        # fit are tokenized exactly as before.
        return self.pre_processor(text, second_text, truncation=True, return_tensors="pt")

    def decode(self, outputs):
        """Convert the model output into the predicted label string.

        Args:
            outputs: Model output exposing ``logits``; only the first row
                (``logits[0]``) is read.

        Returns:
            The ``self.model.config.id2label`` entry for the highest-scoring
            class.
        """
        scores = outputs.logits[0]
        label_id = torch.argmax(scores).item()
        return self.model.config.id2label[label_id]