# ML2_HW4_Kalinin / inference.py
# Alexey Kalinin
from joblib import load
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Pipeline,
    pipeline,
)

def load_model(path2chkpt: str, path2mapping: str):
    """Load the fine-tuned checkpoint as a text-classification pipeline,
    together with the joblib-pickled mapping from labels to class names."""
    model = AutoModelForSequenceClassification.from_pretrained(path2chkpt)
    tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-cased")
    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
    class2name = load(path2mapping)  # dict: pipeline label -> human-readable class name
    return pipe, class2name

def top_95_labels(pipe: Pipeline, class2name: dict[str, str], title: str, abstract: str):
    """Return the highest-scoring class names whose cumulative probability reaches 0.95."""
    inputs = ".".join([title, abstract])
    result = pipe(inputs, top_k=20)  # predictions sorted by score, descending
    proba = 0.0
    labels = []
    # Accumulate labels until 95% of the probability mass is covered,
    # stopping early if the returned predictions are exhausted.
    for prediction in result:
        if proba >= 0.95:
            break
        proba += prediction["score"]
        labels.append(prediction["label"])
    return [class2name[label] for label in labels]
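
# Minimal usage sketch (assumption: the checkpoint directory and mapping-file
# paths below are placeholders, not part of the original project).
if __name__ == "__main__":
    pipe, class2name = load_model("checkpoint", "class2name.joblib")
    names = top_95_labels(
        pipe,
        class2name,
        title="Attention Is All You Need",
        abstract="We propose a new network architecture based solely on attention mechanisms.",
    )
    print(names)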