osanseviero's picture
osanseviero HF staff
Update pipeline.py
6293397
raw history blame
No virus
1.58 kB
from typing import Dict, List, Any
from PIL import Image
import os
import json
import numpy as np
from fastai.learner import load_learner
from helpers import is_cat
class PreTrainedPipeline():
    """Inference wrapper that exposes a fastai image classifier as a pipeline.

    The model directory ``path`` must contain:
      * ``model.pkl``   - a learner loadable with fastai's ``load_learner``;
      * ``config.json`` - a JSON object with an ``id2label`` mapping whose keys
        are stringified class indices ("0", "1", ...).
    """

    def __init__(self, path=""):
        """Load the learner and label mapping once, up front.

        Args:
            path: Directory holding ``model.pkl`` and ``config.json``.
        """
        # NOTE(security): load_learner unpickles model.pkl, so only point this
        # at model files you trust.
        self.model = load_learner(os.path.join(path, "model.pkl"))
        with open(os.path.join(path, "config.json"), encoding="utf-8") as config_file:
            config = json.load(config_file)
        self.id2label = config["id2label"]

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Classify one image.

        Args:
            inputs (:obj:`PIL.Image`):
                The raw image representation as PIL.
                No transformation has been applied; any needed conversion
                happens here.
        Return:
            A :obj:`list` of dicts shaped like {"label": "XXX", "score": 0.82},
            one per score returned by the model, sorted by decreasing score.
        """
        # fastai's predict is fed a numpy array rather than a PIL image.
        _, _, preds = self.model.predict(np.array(inputs))
        scores = preds.tolist()
        # One entry per class score instead of assuming exactly two classes;
        # id2label keys are strings because they come from JSON.
        labels = [
            {"label": str(self.id2label[str(idx)]), "score": score}
            for idx, score in enumerate(scores)
        ]
        # Highest-confidence label first; sort is stable so ties keep index order.
        labels.sort(key=lambda item: item["score"], reverse=True)
        return labels