sklearn-transformers / pipeline.py
merve's picture
merve HF staff
Update pipeline.py
8c353a4
raw history blame
No virus
587 Bytes
import json
from typing import Any, Dict, List
import sklearn
import os
import joblib
import numpy as np
import whatlies
class PreTrainedPipeline:
    """Inference wrapper around a joblib-serialized scikit-learn pipeline.

    The pipeline is loaded from ``<path>/pipeline.pkl``. Calling the wrapper
    with a single text input returns one ``{"label", "score"}`` dict per
    class column produced by ``predict_proba``.
    """

    def __init__(self, path: str) -> None:
        """Load the serialized pipeline.

        Args:
            path: Directory containing ``pipeline.pkl``.
        """
        # SECURITY: joblib.load unpickles the file, which can execute
        # arbitrary code. Only point this at trusted model artifacts.
        self.model = joblib.load(os.path.join(path, "pipeline.pkl"))

    def __call__(self, inputs: str) -> List[Dict[str, Any]]:
        """Score a single input and return per-class probabilities.

        Args:
            inputs: One raw input (e.g. a text string). It is wrapped in a
                one-element list because ``predict_proba`` expects a batch.

        Returns:
            A list with one ``{"label": "LABEL_<i>", "score": p}`` dict per
            column ``i`` of the probability row, in column order.
        """
        # predict_proba returns one row per sample; there is exactly one sample.
        probabilities = self.model.predict_proba([inputs])[0]
        # Enumerate so the label is the column index and the score is that
        # column's probability. The previous code iterated over the
        # probability values and then used them as indices, which fails.
        # float() turns NumPy scalars into plain floats for JSON output.
        return [
            {"label": f"LABEL_{index}", "score": float(score)}
            for index, score in enumerate(probabilities)
        ]