# hfcustomarch / MyPipe.py
# Uploaded by not-lain ("Upload MnistPipe", commit 4f694c9).
from transformers import Pipeline
import requests
from PIL import Image
import torchvision.transforms as transforms
import torch
class MnistPipe(Pipeline):
    """Hugging Face ``Pipeline`` that classifies a single digit image.

    Flow: ``preprocess`` (load image -> grayscale -> 28x28 tensor with a
    leading batch dimension) -> ``_forward`` (run ``self.model`` with
    gradients disabled) -> ``postprocess`` (argmax class index, or the raw
    model output).

    Call-time options (routed by ``_sanitize_parameters``):
        download (bool): treat the input as a URL and download it first.
        clean_output (bool): return the predicted class index (default)
            instead of the raw model output.
    """

    def __init__(self, **kwargs):
        # self.tokenizer = (...)  # instantiate more attributes here if needed
        # ``Pipeline.__init__`` sets up ``self.model`` from the kwargs.
        super().__init__(**kwargs)
        # ToTensor turns the PIL image into a float (C, H, W) tensor in
        # [0, 1]; Resize then forces the spatial size to 28x28.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((28, 28), antialias=True),
            ]
        )

    def _sanitize_parameters(self, **kwargs):
        """Route call-time kwargs to the preprocess / forward / postprocess steps.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            tuple; ``_forward`` takes no extra parameters, so its dict is
            always empty.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "download" in kwargs:
            preprocess_kwargs["download"] = kwargs["download"]
        if "clean_output" in kwargs:
            postprocess_kwargs["clean_output"] = kwargs["clean_output"]
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, download=False):
        """Load an image and convert it into a batched model input tensor.

        Args:
            inputs: a file path, a URL (when ``download`` is true), or an
                already-loaded ``PIL.Image.Image``.
            download: if true, ``inputs`` is fetched with ``download_img``
                (saved as ``image.png``) before being opened.

        Returns:
            A tensor of shape (1, 1, 28, 28): a single grayscale channel
            with a leading batch dimension.
        """
        if isinstance(inputs, Image.Image):
            gray = inputs.convert("L")
        else:
            if download:
                self.download_img(inputs)
                inputs = "image.png"
            # Context manager releases the file handle that Image.open keeps
            # open lazily; convert() loads the pixels before it is closed.
            with Image.open(inputs) as img:
                gray = img.convert("L")
        tensor = self.transform(gray)
        return tensor.unsqueeze(0)

    def _forward(self, tensor):
        """Run ``self.model`` on the preprocessed batch without gradients."""
        with torch.no_grad():
            return self.model(tensor)

    def postprocess(self, out, clean_output=True):
        """Turn the raw model output into the pipeline result.

        Args:
            out: the value returned by ``self.model``; assumed to be a logits
                tensor whose last dimension indexes the classes (TODO
                confirm against the model definition).
            clean_output: if true, return the predicted class index of the
                first sample as a Python ``int``; otherwise return ``out``
                unchanged.
        """
        if not clean_output:
            return out
        # flatten() makes this work for both (1, num_classes) and
        # (num_classes,) logits; only the first prediction is returned.
        labels = torch.argmax(out, dim=-1).flatten()
        return labels[0].item()

    def download_img(self, url, timeout=30):
        """Download ``url`` and save it as ``image.png`` in the working directory.

        Args:
            url: address of the image to fetch.
            timeout: seconds to wait for the server before giving up.

        Raises:
            requests.HTTPError: if the server answers with an error status
                (instead of saving the error page as ``image.png``).
        """
        # The ``with`` block closes the streamed response, releasing the
        # underlying connection even if writing fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open("image.png", "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        print("image saved as image.png")