"""Custom Hugging Face pipeline that classifies scientific papers.

A "paper" is either a plain string, or a dict with the keys
``"title"``, ``"authors"`` (a string or a list of strings) and
``"abstract"``; the dict fields are flattened into a single prompt
string before tokenization.
"""

import typing as tp
from collections import namedtuple  # noqa: F401 — kept for backward compatibility with importers

import torch
from transformers import AutoModelForSequenceClassification, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY


class PapersClassificationPipeline(Pipeline):
    """Sequence-classification pipeline over paper records.

    Accepts a single paper or an iterable of papers and returns, for each
    input, a list of ``{"label": str, "score": float}`` dicts — one entry
    per class, ordered by class index.
    """

    def _sanitize_parameters(self, **kwargs):
        # No user-configurable parameters: preprocess / forward / postprocess
        # each receive an empty kwargs dict.
        return {}, {}, {}

    def preprocess(self, inputs):
        """Build prompt strings from the paper record(s) and tokenize them.

        Returns a tokenized batch (``return_tensors="pt"``) moved to the
        pipeline's device, truncated/padded to at most 256 tokens.
        """
        # Normalize to a batch: a lone string, a lone dict, or any
        # non-iterable object is treated as a single-item batch.
        if (
            not isinstance(inputs, tp.Iterable)
            or isinstance(inputs, (dict, str))
        ):
            inputs = [inputs]

        texts = []
        for paper in inputs:
            if isinstance(paper, str):
                # Already a ready-made prompt — pass through unchanged.
                texts.append(paper)
                continue
            authors = paper["authors"]
            # BUG FIX: the original joined paper["title"] here (spacing out
            # the title's characters) whenever authors was a list; the
            # intent is clearly to join the authors list itself.
            if isinstance(authors, list):
                authors = " ".join(authors)
            texts.append(
                f"AUTHORS: {authors} "
                f"TITLE: {paper['title']} ABSTRACT: {paper['abstract']}"
            )

        return self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=256,
            return_tensors="pt",
        ).to(self.device)

    def _forward(self, model_inputs):
        """Run the underlying model without tracking gradients."""
        with torch.no_grad():
            return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Convert logits to per-class label/score dicts.

        For a batch of one, the single result list is returned directly
        (matching the usual ``transformers`` pipeline convention); otherwise
        a list of result lists is returned.
        """
        probs = torch.nn.functional.softmax(model_outputs.logits, dim=-1)
        results = [
            [
                {"label": self.model.config.id2label[label_idx], "score": score.item()}
                for label_idx, score in enumerate(prob)
            ]
            for prob in probs
        ]
        if len(results) == 1:
            return results[0]
        return results


# Make the task available via pipeline("paper-classification", model=...).
PIPELINE_REGISTRY.register_pipeline(
    "paper-classification",
    pipeline_class=PapersClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,
)