|
import typing as tp |
|
from collections import namedtuple |
|
|
|
import torch |
|
|
|
from transformers import Pipeline, AutoModelForSequenceClassification |
|
from transformers.pipelines import PIPELINE_REGISTRY |
|
|
|
|
|
class PapersClassificationPipeline(Pipeline):
    """Classify scientific papers given as raw strings or as dicts.

    Each input paper is either a plain string, or a mapping with keys
    "title", "authors" (a string or a list of strings — TODO confirm
    against callers) and "abstract". Dict papers are flattened into a
    single "AUTHORS: ... TITLE: ... ABSTRACT: ..." text before tokenization.
    """

    def _sanitize_parameters(self, **kwargs):
        # No extra runtime parameters are supported: forward nothing to
        # preprocess / _forward / postprocess.
        return {}, {}, {}

    def preprocess(self, inputs):
        """Tokenize one paper or a batch of papers into model tensors.

        Returns a BatchEncoding on ``self.device`` with truncation/padding
        applied (max_length=256).
        """
        # A single paper (string, dict, or any non-iterable) becomes a
        # batch of one; only genuine iterables of papers pass through.
        if not isinstance(inputs, tp.Iterable) or isinstance(inputs, (dict, str)):
            inputs = [inputs]

        texts = []
        for paper in inputs:
            if isinstance(paper, str):
                texts.append(paper)
                continue
            authors = paper["authors"]
            # BUG FIX: the original joined paper["title"] here, which
            # splits the title string character by character whenever
            # "authors" is a list. Join the authors themselves.
            if isinstance(authors, list):
                authors = " ".join(authors)
            texts.append(
                f"AUTHORS: {authors} "
                f"TITLE: {paper['title']} ABSTRACT: {paper['abstract']}"
            )

        return self.tokenizer(
            texts, truncation=True, padding=True, max_length=256, return_tensors="pt"
        ).to(self.device)

    def _forward(self, model_inputs):
        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Turn logits into per-label score dicts.

        Returns a list of ``{"label", "score"}`` dicts for a single input,
        or a list of such lists for a batch.
        """
        probs = torch.nn.functional.softmax(model_outputs.logits, dim=-1)
        results = [
            [
                {"label": self.model.config.id2label[label_idx], "score": score.item()}
                for label_idx, score in enumerate(prob)
            ]
            for prob in probs
        ]
        return results[0] if len(results) == 1 else results
|
|
|
|
|
# Register the custom task so `transformers.pipeline("paper-classification", ...)`
# resolves to PapersClassificationPipeline, loading PyTorch checkpoints through
# AutoModelForSequenceClassification.
PIPELINE_REGISTRY.register_pipeline(

    "paper-classification",

    pipeline_class=PapersClassificationPipeline,

    pt_model=AutoModelForSequenceClassification,

)
|
|