|
from typing import Dict, List, Any |
|
from PIL import Image |
|
from io import BytesIO |
|
from optimum.nvidia.pipelines import pipeline |
|
import base64 |
|
|
|
|
|
class EndpointHandler:
    """Custom inference handler for zero-shot image classification.

    The ``EndpointHandler`` name and the ``__call__(data)`` contract follow the
    Hugging Face Inference Endpoints custom-handler convention (presumably the
    runtime that instantiates this class -- NOTE(review): confirm against the
    deployment).
    """

    def __init__(self, path: str = "", device: int = 0):
        """Build the zero-shot image classification pipeline.

        Args:
            path: Model identifier or local directory, passed as ``model=``.
            device: Device index forwarded to ``pipeline``. Defaults to ``0``,
                the value that used to be hard-coded (presumably the first GPU).
        """
        self.pipeline = pipeline(
            "zero-shot-image-classification", model=path, device=device
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one image against caller-supplied candidate labels.

        Args:
            data: Request payload. ``data["inputs"]`` is a base64-encoded image
                (``str`` or ``bytes``) or an already-decoded ``PIL.Image.Image``.
                ``data["parameters"]["candidate_labels"]`` is a sequence of
                labels or a comma-separated string. The payload is not modified.

        Returns:
            The pipeline result for the single image: one dict per candidate
            label (``label``/``score`` keys in the transformers implementation
            -- NOTE(review): confirm for optimum-nvidia).

        Raises:
            ValueError: if ``inputs`` is missing, empty or not valid base64, or
                if no candidate labels are supplied.
        """
        # Use .get() so the caller's payload is left untouched and an explicit
        # "parameters": null is tolerated instead of raising AttributeError.
        inputs = data.get("inputs")
        parameters = data.get("parameters") or {}

        if inputs is None or (isinstance(inputs, (str, bytes)) and not inputs):
            raise ValueError("Request is missing 'inputs' (a base64-encoded image).")

        candidate_labels = self._parse_candidate_labels(
            parameters.get("candidate_labels")
        )
        if not candidate_labels:
            raise ValueError(
                "'parameters.candidate_labels' must contain at least one label."
            )

        image = self._load_image(inputs)
        prediction = self.pipeline(images=[image], candidate_labels=candidate_labels)
        # A one-image batch yields a one-element list of results; unwrap it.
        return prediction[0]

    @staticmethod
    def _parse_candidate_labels(raw: Any) -> List[str]:
        """Normalise candidate labels to a list; split comma-separated strings."""
        if raw is None:
            return []
        if isinstance(raw, str):
            # Iterating a bare string would yield one "label" per character.
            return [label.strip() for label in raw.split(",") if label.strip()]
        return list(raw)

    @staticmethod
    def _load_image(inputs: Any) -> "Image.Image":
        """Return a PIL image from a decoded image or a base64 payload."""
        # NOTE(review): some serving toolkits decode binary image requests
        # into a PIL image before calling the handler; accept that as-is.
        if isinstance(inputs, Image.Image):
            return inputs
        try:
            raw = base64.b64decode(inputs)
        except (TypeError, ValueError) as err:  # binascii.Error is a ValueError
            raise ValueError("'inputs' is not a valid base64-encoded image.") from err
        return Image.open(BytesIO(raw))