"""Inference endpoint handler serving a TensorRT-optimised ResNet-18 classifier."""

import io
from typing import Any, Dict, List

import requests
import torch
import torchvision.transforms.functional as transform
from PIL import Image
from torch2trt import torch2trt
from torchvision.io import ImageReadMode, read_image
from torchvision.models import ResNet18_Weights, resnet18
from torchvision.models.alexnet import alexnet

# Seconds to wait for a remote image before giving up, so that a stalled
# server cannot block the worker indefinitely.
_REQUEST_TIMEOUT_S = 10

# NOTE(review): the AlexNet conversion below is not used by EndpointHandler and
# looks like leftover example code. It downloads weights and builds a TensorRT
# engine at import time. It is kept only so the module-level names `model`,
# `x` and `model_trt` stay available; consider removing it once confirmed that
# nothing imports them.

# create some regular pytorch model...
model = alexnet(pretrained=True).eval().cuda()

# create example data
x = torch.ones((1, 3, 224, 224)).cuda()

# convert to TensorRT feeding sample data as input
model_trt = torch2trt(model, [x])


class EndpointHandler():
    """Classify an image with a ResNet-18 converted to TensorRT via torch2trt."""

    def __init__(self, path=""):
        """Build the TensorRT engine and the matching preprocessing transform.

        Args:
            path: Accepted for interface compatibility; not used, because the
                weights come from ``ResNet18_Weights.DEFAULT``.
        """
        weights = ResNet18_Weights.DEFAULT

        # create some regular pytorch model...
        model = resnet18(weights=weights).eval().cuda()

        # create example data
        x = torch.ones((1, 3, 224, 224)).cuda()

        # convert to TensorRT feeding sample data as input
        self.pipeline = torch2trt(model, [x])
        self.preprocess = weights.transforms()

    def _load_image(self, source: str) -> torch.Tensor:
        """Return the image at *source* (URL or local path) as a 3-channel tensor."""
        # SECURITY NOTE(review): `source` comes straight from the request
        # payload, so this fetches arbitrary URLs and reads arbitrary local
        # paths. Restrict hosts/paths if the endpoint is exposed to untrusted
        # callers.
        if source.startswith("www."):
            # requests rejects URLs without a scheme (MissingSchema).
            source = "https://" + source

        if source.startswith(("http://", "https://")):
            response = requests.get(source, timeout=_REQUEST_TIMEOUT_S)
            # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            with Image.open(io.BytesIO(response.content)) as remote_image:
                # Force RGB so grayscale/palette/RGBA files still produce the
                # three channels the model expects.
                return transform.to_tensor(remote_image.convert("RGB"))

        return read_image(source, mode=ImageReadMode.RGB)

    def __call__(self, data: Dict[str, Any]) -> List[float]:
        """Run classification on the image referenced by the request.

        Args:
            data: Request payload. ``data["inputs"]`` must be a string holding
                either an image URL or a local file path. The key is popped
                from ``data``.

        Returns:
            The softmax probabilities over the model's output classes, as a
            flat list of floats.

        Raises:
            TypeError: If ``inputs`` is missing or is not a string.
            requests.RequestException: If a remote image cannot be downloaded.
        """
        # get inputs
        inputs = data.pop("inputs", data)
        if not isinstance(inputs, str):
            raise TypeError(
                "'inputs' must be a string containing an image URL or file path"
            )

        img = self._load_image(inputs)

        # The TensorRT engine was built from CUDA tensors and reads raw device
        # pointers, so the batch must live on the GPU as well.
        batch = self.preprocess(img).unsqueeze(0).cuda()

        with torch.no_grad():
            prediction = self.pipeline(batch).squeeze(0).softmax(0)
        return prediction.tolist()