Valeriy Sinyukov
Remove model wrappers, use dict and model input
82ec9f7
raw
history blame contribute delete
1.89 kB
import typing as tp
from collections import namedtuple
import torch
from transformers import Pipeline, AutoModelForSequenceClassification
from transformers.pipelines import PIPELINE_REGISTRY
class PapersClassificationPipeline(Pipeline):
    """Classify scientific papers.

    Accepts a single paper or a batch; each paper is either a plain string
    or a mapping with "title", "authors" and "abstract" keys ("authors" may
    be a list of names or a single string).
    """

    def _sanitize_parameters(self, **kwargs):
        # No configurable parameters: nothing is forwarded to
        # preprocess/_forward/postprocess.
        return {}, {}, {}

    @staticmethod
    def _format_paper(paper):
        """Flatten a paper mapping into one text string for the tokenizer."""
        authors = paper["authors"]
        if isinstance(authors, list):
            # BUG FIX: the original joined paper["title"] here (splicing the
            # title character-by-character) instead of the authors list.
            authors = " ".join(authors)
        return (
            f"AUTHORS: {authors} "
            f"TITLE: {paper['title']} ABSTRACT: {paper['abstract']}"
        )

    def preprocess(self, inputs):
        """Tokenize one paper or a batch of papers into model inputs."""
        # Wrap a single item (string, dict, or any non-iterable) into a batch.
        if (
            not isinstance(inputs, tp.Iterable)
            or isinstance(inputs, (dict, str))
        ):
            inputs = [inputs]
        texts = [
            paper if isinstance(paper, str) else self._format_paper(paper)
            for paper in inputs
        ]
        return self.tokenizer(
            texts, truncation=True, padding=True, max_length=256, return_tensors="pt"
        ).to(self.device)

    def _forward(self, model_inputs):
        # Inference only: disable autograd to save memory/compute.
        with torch.no_grad():
            return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Convert logits to per-label {"label", "score"} dicts.

        Returns a list of dicts for a single input, or a list of such
        lists for a batch.
        """
        probs = torch.nn.functional.softmax(model_outputs.logits, dim=-1)
        results = [
            [
                {"label": self.model.config.id2label[idx], "score": score.item()}
                for idx, score in enumerate(prob)
            ]
            for prob in probs
        ]
        # Preserve the original single-input convenience: unwrap a batch of one.
        return results[0] if len(results) == 1 else results
# Expose the custom pipeline under its task name so callers can do
# transformers.pipeline("paper-classification", model=...).
PIPELINE_REGISTRY.register_pipeline(
    task="paper-classification",
    pipeline_class=PapersClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,
)