|
from huggingface_hub import Repository |
|
from typing import List, Union |
|
from transformers import pipeline |
|
from transformers.pipelines import PIPELINE_REGISTRY |
|
from transformers import AutoModelForSequenceClassification, Pipeline |
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
class MyPipeline(Pipeline):
    """Pipeline that scores a premise, optionally paired with a hypothesis.

    The tokenizer encodes the premise (and hypothesis, when given), the model
    is run on that encoding, and the softmax over the first row of logits is
    returned as percentages keyed by label name.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route call-time keyword arguments to the pipeline stages.

        Only ``hypothesis`` is recognised; it is forwarded to ``preprocess``.

        Returns:
            A 3-tuple of keyword dicts; the second and third are always empty.
        """
        preprocess_kwargs = {}
        if "hypothesis" in kwargs:
            preprocess_kwargs["hypothesis"] = kwargs["hypothesis"]
        return preprocess_kwargs, {}, {}

    def __call__(
        self,
        sequences: Union[str, List[str]],
        *args,
        **kwargs,
    ):
        """Run the pipeline on ``sequences``.

        A single extra positional argument is accepted as a shorthand for the
        ``hypothesis`` keyword argument.

        Raises:
            ValueError: If more than one extra positional argument is given,
                or if one is given together with a ``hypothesis`` keyword.
        """
        if len(args) > 1 or (args and "hypothesis" in kwargs):
            raise ValueError(f"Unable to understand extra arguments {args}")
        if args:
            kwargs["hypothesis"] = args[0]

        return super().__call__(sequences, **kwargs)

    def preprocess(self, premise, hypothesis=None):
        """Tokenize ``premise`` (paired with ``hypothesis`` when provided).

        Returns:
            A dict holding every tensor the tokenizer produced, so the model
            also receives the attention mask and, where the tokenizer emits
            them, the segment ids that distinguish premise from hypothesis.
        """
        encoded = self.tokenizer(
            premise,
            hypothesis,
            truncation=True,
            return_tensors="pt",
        )
        # Keep the complete encoding: dropping everything except "input_ids"
        # would withhold the attention mask / token type ids from the model.
        return dict(encoded)

    def _forward(self, input_ids):
        """Run the model on the dict returned by ``preprocess``.

        The parameter keeps its historical name ``input_ids`` but is the whole
        mapping of model inputs, which is unpacked as keyword arguments.
        """
        return self.model(**input_ids)

    def postprocess(self, model_outputs):
        """Turn the first row of logits into ``{label: percentage}``.

        Each percentage is the softmax probability times 100, rounded to one
        decimal place.
        """
        probabilities = torch.softmax(model_outputs["logits"][0], -1).tolist()
        # NOTE(review): assumes the model's logit order is entailment,
        # neutral, contradiction -- confirm against the checkpoint's config.
        label_names = ["entailment", "neutral", "contradiction"]
        return {
            name: round(float(prob) * 100, 1)
            for prob, name in zip(probabilities, label_names)
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|