from transformers import Pipeline
import requests
from PIL import Image
import torchvision.transforms as transforms
import torch


class MnistPipe(Pipeline):
    """Custom pipeline that turns an image into a class label.

    The image is converted to grayscale, turned into a tensor, resized to
    28x28 and given a leading batch dimension before being passed to
    ``self.model``. (The class name suggests an MNIST digit classifier;
    the model itself is supplied by the caller through ``Pipeline``.)

    Call-time parameters:
        download: when truthy, the input is treated as a URL; the image is
            downloaded to ``image.png`` in the current working directory
            and then loaded from there.
        clean_output: when truthy (the default), return the predicted class
            index of the first batch item; otherwise return the raw model
            output.
    """

    def __init__(self, **kwargs):
        """Initialise the base pipeline, then build the image transform.

        Args:
            **kwargs: forwarded unchanged to ``Pipeline.__init__`` (which
                is where ``self.model`` gets set).
        """
        # self.tokenizer = (...)  # instantiate extra attributes here if needed
        super().__init__(**kwargs)
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((28, 28), antialias=True),
            ]
        )

    def _sanitize_parameters(self, **kwargs):
        """Route call-time keyword arguments to the right pipeline stage.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            tuple. ``download`` goes to ``preprocess``, ``clean_output``
            goes to ``postprocess``; the forward stage takes no parameters.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "download" in kwargs:
            preprocess_kwargs["download"] = kwargs["download"]
        if "clean_output" in kwargs:
            postprocess_kwargs["clean_output"] = kwargs["clean_output"]
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, download=False):
        """Load an image and convert it into a batched tensor.

        Args:
            inputs: a path or file object accepted by ``PIL.Image.open``,
                an already-loaded ``PIL.Image.Image``, or (when
                ``download`` is truthy) the URL of the image.
            download: when truthy, fetch ``inputs`` as a URL into
                ``image.png`` first and load that file.

        Returns:
            The transformed grayscale image with a leading batch dimension
            added.
        """
        if download:
            # download_img always writes to "image.png".
            self.download_img(inputs)
            inputs = "image.png"

        if isinstance(inputs, Image.Image):
            gray = inputs.convert("L")
        else:
            # Convert inside the context manager: convert() loads the pixel
            # data, so the file handle can be released straight after.
            with Image.open(inputs) as img:
                gray = img.convert("L")

        tensor = self.transform(gray)
        return tensor.unsqueeze(0)

    def _forward(self, tensor):
        """Run ``self.model`` on ``tensor`` with gradient tracking disabled."""
        with torch.no_grad():
            # self.model was set up by Pipeline.__init__.
            return self.model(tensor)

    def postprocess(self, out, clean_output=True):
        """Convert the model output into the value returned to the caller.

        Args:
            out: the value returned by ``_forward``.
            clean_output: when truthy, reduce ``out`` to a class index.

        Returns:
            The argmax over the last dimension for the first batch item
            when ``clean_output`` is truthy, otherwise ``out`` unchanged.
        """
        if not clean_output:
            return out
        label = torch.argmax(out, dim=-1)
        # Only the first batch item is reported.
        return label.tolist()[0]

    def download_img(self, url, timeout=30):
        """Download ``url`` and save it as ``image.png`` in the working directory.

        Args:
            url: address of the image to fetch.
            timeout: seconds to wait for the server before giving up;
                passed to ``requests.get``.

        Raises:
            requests.HTTPError: if the server answers with an error status,
                so that an error page is never saved as the image.
            requests.RequestException: on connection problems or timeout.
        """
        # The context manager releases the streamed connection even if
        # writing the file fails part-way through.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open("image.png", "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        print("image saved as image.png")