dgergherherherhererher commited on
Commit
3184f38
1 Parent(s): ae9dfaf

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +25 -0
pipeline.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Any, Dict, List
3
+
4
+ import sklearn
5
+ import os
6
+ import joblib
7
+ import numpy as np
8
+ import whatlies
9
+
10
+
11
+
12
+ class PreTrainedPipeline():
13
+ def __init__(self, path: str):
14
+ # load the model
15
+ self.model = joblib.load(os.path.join(path, "model.pkl"))
16
+
17
+ def __call__(self, inputs):
18
+ predictions = self.model.predict_proba([inputs])
19
+ labels = []
20
+ for cls in predictions[0]:
21
+ labels.append({
22
+ "label": f"LABEL_{cls}",
23
+ "score": predictions[0][cls],
24
+ })
25
+ return labels