import io
from typing import Any, Dict, List

from PIL import Image
import requests
import torch
from torchvision.io import ImageReadMode, read_image
from torchvision.models import resnet18, ResNet18_Weights
import torchvision.transforms.functional as transform


class EndpointHandler():
    """Inference handler that classifies a single image with ResNet-18.

    The handler is constructed once and then called with a request payload
    whose ``"inputs"`` entry is either a URL or a local image file path.
    """

    # Seconds to wait for a remote image before giving up; without a timeout
    # a stalled server would block the worker indefinitely.
    _REQUEST_TIMEOUT_S = 10

    # Prefixes that mark ``inputs`` as a remote resource rather than a path.
    _URL_PREFIXES = ("http://", "https://", "www")

    def __init__(self, path=""):
        """Load the ResNet-18 model and its matching preprocessing.

        Args:
            path: Accepted for interface compatibility; not used, because the
                weights come from ``ResNet18_Weights.DEFAULT``.
        """
        weights = ResNet18_Weights.DEFAULT
        self.pipeline = resnet18(weights=weights)
        self.preprocess = weights.transforms()
        self.pipeline.eval()

    def __call__(self, data: Dict[str, Any]) -> List[float]:
        """Classify the image referenced by ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` must be a string holding
                an image URL (``http://``, ``https://`` or starting with
                ``www``) or a local image file path. The ``"inputs"`` key is
                popped from ``data``.

        Returns:
            A flat list of softmax probabilities, one per model output class.

        Raises:
            ValueError: If ``inputs`` is missing or is not a string.
            requests.RequestException: If a remote image cannot be fetched.
        """
        inputs = data.pop("inputs", data)
        if not isinstance(inputs, str):
            raise ValueError(
                "`inputs` must be a string containing an image URL or file path"
            )

        img = self._load_image(inputs)
        batch = self.preprocess(img).unsqueeze(0)

        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.inference_mode():
            prediction = self.pipeline(batch).squeeze(0).softmax(0)
        return prediction.tolist()

    def _load_image(self, source: str) -> torch.Tensor:
        """Return the image at ``source`` as a 3-channel (RGB) tensor."""
        # NOTE(security): ``source`` comes straight from the request, so this
        # fetches arbitrary URLs and reads arbitrary local paths. Restrict the
        # allowed hosts/directories if the endpoint is exposed to untrusted
        # callers.
        if source.startswith(self._URL_PREFIXES):
            # requests rejects scheme-less URLs, so "www..." needs a scheme.
            url = f"https://{source}" if source.startswith("www") else source
            response = requests.get(url, timeout=self._REQUEST_TIMEOUT_S)
            # Fail on HTTP errors instead of handing an error page to PIL.
            response.raise_for_status()
            with Image.open(io.BytesIO(response.content)) as pil_img:
                # The model expects 3 channels; grayscale, palette and RGBA
                # images would otherwise yield 1 or 4 channels.
                return transform.to_tensor(pil_img.convert("RGB"))

        # Force RGB decoding for the same reason as above.
        return read_image(source, mode=ImageReadMode.RGB)