from typing import List, Union

import torch
from huggingface_hub import Repository
from transformers import AutoModelForSequenceClassification, Pipeline
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY

# from loguru import logger


class MyPipeline(Pipeline):
    """Custom pipeline that scores a premise against an optional hypothesis.

    The result of a call is a dict mapping the label names ``"entailment"``,
    ``"neutral"`` and ``"contradiction"`` to a percentage rounded to one
    decimal place.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route call-time keyword arguments to the pipeline stages.

        Only ``hypothesis`` is recognised; it is handed to ``preprocess``.
        The forward and postprocess stages receive no extra parameters.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            tuple of dicts.
        """
        preprocess_kwargs = {}
        if "hypothesis" in kwargs:
            preprocess_kwargs["hypothesis"] = kwargs["hypothesis"]
        return preprocess_kwargs, {}, {}

    def __call__(
        self,
        sequences: Union[str, List[str]],
        *args,
        **kwargs,
    ):
        """Run the pipeline on ``sequences``.

        A single extra positional argument is accepted as a shorthand for
        ``hypothesis=...`` as long as ``hypothesis`` is not also given by
        keyword.

        Raises:
            ValueError: If more than one extra positional argument is given,
                or one is given together with a ``hypothesis`` keyword.
        """
        if len(args) == 1 and "hypothesis" not in kwargs:
            kwargs["hypothesis"] = args[0]
        elif args:
            raise ValueError(f"Unable to understand extra arguments {args}")
        return super().__call__(sequences, **kwargs)

    def preprocess(self, premise, hypothesis=None):
        """Tokenize ``premise`` (paired with ``hypothesis`` when provided).

        Returns:
            The tokenizer's full encoding as PyTorch tensors. Every field the
            tokenizer produces (notably ``attention_mask``) is kept so that
            padded positions are masked out when inputs are batched.
        """
        return self.tokenizer(
            premise,
            hypothesis,
            # max_length=self.toke,
            # return_token_type_ids=True,
            truncation=True,
            return_tensors="pt",
        )

    def _forward(self, input_ids):
        """Run the model on the encoding produced by ``preprocess``.

        Despite its name, ``input_ids`` is the whole mapping returned by
        ``preprocess``; each of its entries is passed to the model as a
        keyword argument.
        """
        return self.model(**input_ids)

    def postprocess(self, model_outputs):
        """Convert the model logits into per-label percentages.

        Only the first row of ``model_outputs["logits"]`` is scored.

        NOTE(review): the label order below is hard-coded and assumes the
        model's class indices 0, 1, 2 mean entailment, neutral,
        contradiction -- confirm against the model's config.

        Returns:
            A dict mapping each label name to its softmax probability
            expressed as a percentage rounded to one decimal place.
        """
        probabilities = torch.softmax(model_outputs["logits"][0], -1).tolist()
        label_names = ["entailment", "neutral", "contradiction"]
        return {
            name: round(float(prob) * 100, 1)
            for prob, name in zip(probabilities, label_names)
        }


# Example registration and usage, left disabled:
#
# PIPELINE_REGISTRY.register_pipeline(
#     "test",
#     pipeline_class=MyPipeline,
#     pt_model=AutoModelForSequenceClassification,
#     # default={"pt": ("MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli", "retina")},
#     # type="text",
# )

# classifier = pipeline("test",
#                       model="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
#                       # tokenizer="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
#                       )

# output = classifier(
#     "Angela Merkel is a politician in Germany and leader of the CDU",
#     hypothesis="this is a test"
# )

# # logger.info(output)

# # repo = Repository("entailment-classifier",
# #                   clone_from="Tverous/entailment-classifier")
# classifier.save_pretrained("entailment-classifier")
# # repo.push_to_hub()

# logger.info("Finished")