dfsgs / pipeline.py
dgergherherherhererher's picture
Create pipeline.py
3184f38
raw
history blame
579 Bytes
import json
from typing import Any, Dict, List
import sklearn
import os
import joblib
import numpy as np
import whatlies
class PreTrainedPipeline:
    """Classification pipeline wrapping a joblib-serialized probabilistic classifier.

    The wrapped model must expose ``predict_proba``. Each call scores a single
    input and returns one ``{"label", "score"}`` entry per class column.
    """

    def __init__(self, path: str):
        """Load ``model.pkl`` from the directory *path*.

        Args:
            path: Directory containing the serialized model file ``model.pkl``.
        """
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only point this at model files from a trusted source.
        self.model = joblib.load(os.path.join(path, "model.pkl"))

    def __call__(self, inputs: Any) -> List[Dict[str, Any]]:
        """Score a single input and return one entry per class.

        Args:
            inputs: One sample, in whatever form the loaded model's
                ``predict_proba`` accepts for a single row. (Presumably raw
                text, given the text-oriented dependencies; confirm against
                the model.)

        Returns:
            A list of ``{"label": "LABEL_<i>", "score": <float>}`` dicts. ``i``
            is the column index in ``predict_proba``'s output.
        """
        # predict_proba expects a batch, so wrap the single sample in a list
        # and take the first (only) row of the result.
        row = self.model.predict_proba([inputs])[0]
        # Enumerate to get the column index. Iterating the row directly would
        # yield probabilities, not class indices. float() converts numpy
        # scalars (e.g. float32) into JSON-serializable builtin floats.
        return [
            {"label": f"LABEL_{idx}", "score": float(score)}
            for idx, score in enumerate(row)
        ]