# Simon Tang
# commit files to HF hub
# fb4253c
from huggingface_hub import Repository
from typing import List, Union
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers import AutoModelForSequenceClassification, Pipeline
import torch
# from loguru import logger
class MyPipeline(Pipeline):
    """Custom pipeline that scores a premise against an optional hypothesis.

    The premise (and hypothesis, when given) are tokenized as a text pair,
    run through ``self.model``, and the first row of logits is turned into
    per-label percentages.

    Usage: ``pipe(premise, hypothesis="...")`` or ``pipe(premise, "...")``.
    """

    # Labels are paired with the softmax outputs by position.
    # NOTE(review): assumes the model emits logits in this order -- confirm
    # against the checkpoint's ``config.id2label`` before using another model.
    _LABEL_NAMES = ("entailment", "neutral", "contradiction")

    def _sanitize_parameters(self, **kwargs):
        """Route call-time kwargs to the pipeline stages.

        Only ``hypothesis`` is recognised, and it is forwarded to
        ``preprocess``. Returns the ``(preprocess, forward, postprocess)``
        kwarg dicts; the last two are always empty.
        """
        preprocess_kwargs = {}
        if "hypothesis" in kwargs:
            preprocess_kwargs["hypothesis"] = kwargs["hypothesis"]
        return preprocess_kwargs, {}, {}

    def __call__(
        self,
        sequences: Union[str, List[str]],
        *args,
        **kwargs,
    ):
        """Run the pipeline on ``sequences``.

        Args:
            sequences: Premise text, or a list of premise texts.
            *args: At most one extra positional value, used as the
                hypothesis when ``hypothesis`` is not passed by keyword.
            **kwargs: Forwarded to the base ``Pipeline.__call__``.

        Raises:
            ValueError: If more than one extra positional argument is given,
                or a positional hypothesis is combined with the
                ``hypothesis`` keyword.
        """
        if args:
            # A lone positional argument is shorthand for ``hypothesis=...``;
            # anything else is ambiguous and rejected.
            if len(args) > 1 or "hypothesis" in kwargs:
                raise ValueError(f"Unable to understand extra arguments {args}")
            kwargs["hypothesis"] = args[0]
        return super().__call__(sequences, **kwargs)

    def preprocess(self, premise, hypothesis=None):
        """Tokenize ``premise`` and ``hypothesis`` together as PyTorch tensors.

        Returns a dict holding only the ``input_ids`` tensor; every other
        field produced by the tokenizer is dropped.
        """
        encoded = self.tokenizer(
            premise,
            hypothesis,
            truncation=True,
            return_tensors="pt",
        )
        return {"input_ids": encoded["input_ids"]}

    def _forward(self, input_ids):
        """Run the model on the output of ``preprocess``.

        Despite its name, ``input_ids`` is the dict returned by
        ``preprocess``; the tensor stored under its ``"input_ids"`` key is
        what gets passed to the model.
        """
        return self.model(input_ids["input_ids"])

    def postprocess(self, model_outputs):
        """Convert the first row of logits into label percentages.

        Returns:
            A dict mapping each label name to its softmax probability
            expressed as a percentage rounded to one decimal place.
        """
        probabilities = torch.softmax(model_outputs["logits"][0], -1).tolist()
        return {
            name: round(float(prob) * 100, 1)
            for prob, name in zip(probabilities, self._LABEL_NAMES)
        }
# PIPELINE_REGISTRY.register_pipeline(
# "test",
# pipeline_class=MyPipeline,
# pt_model=AutoModelForSequenceClassification,
# # default={"pt": ("MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli", "retina")},
# # type="text",
# )
# classifier = pipeline("test",
# model="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
# # tokenizer="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
# )
# output = classifier(
# "Angela Merkel is a politician in Germany and leader of the CDU",
# hypothesis="this is a test"
# )
# # logger.info(output)
# # repo = Repository("entailment-classifier",
# # clone_from="Tverous/entailment-classifier")
# classifier.save_pretrained("entailment-classifier")
# # repo.push_to_hub()
# logger.info("Finished")