# testrepo / MyPipe.py
# Author: not-lain
# Commit: "Upload MnistPipe" (cd530e9, verified)
# (Hugging Face Hub file-page residue: raw | history | blame | contribute | delete)
# Malware scan: no virus
# Size: 2.17 kB
from transformers import Pipeline
import requests
from PIL import Image
import torchvision.transforms as transforms
import torch
class MnistPipe(Pipeline):
    """Transformers ``Pipeline`` that classifies a single digit image.

    Stages (driven by ``Pipeline.__call__``):
    ``preprocess`` -> ``_forward`` -> ``postprocess``.

    Call-time options routed by ``_sanitize_parameters``:
        download (bool): treat the input as a URL, fetch it to ``image.png``
            in the current working directory, and classify that file.
        clean_output (bool): when true (the default) return the predicted
            class index; otherwise return the raw model output.
    """

    def __init__(self, *args, **kwargs):
        # self.tokenizer = (...)  # instantiate extra components here if needed
        # ``Pipeline.__init__`` stores the model on ``self.model``.
        super().__init__(*args, **kwargs)
        # PIL image -> float tensor, then resize to 28x28 (the MNIST image size).
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((28, 28), antialias=True),
            ]
        )

    def _sanitize_parameters(self, **kwargs):
        """Route call-time keyword arguments to the pipeline stages.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            triple. Only ``download`` (preprocess) and ``clean_output``
            (postprocess) are recognised; any other key is ignored.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "download" in kwargs:
            preprocess_kwargs["download"] = kwargs["download"]
        if "clean_output" in kwargs:
            postprocess_kwargs["clean_output"] = kwargs["clean_output"]
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, download=False):
        """Load one image and turn it into a batch of size 1.

        Args:
            inputs: a path or file object accepted by ``PIL.Image.open``, an
                already-loaded ``PIL.Image.Image``, or (with ``download=True``)
                an image URL.
            download: if true, fetch ``inputs`` with ``download_img`` and read
                the saved ``image.png`` instead.

        Returns:
            A float tensor of shape ``(1, 1, 28, 28)``.
        """
        if download:
            # download_img always writes to "image.png" in the current directory.
            self.download_img(inputs)
            inputs = "image.png"
        if isinstance(inputs, Image.Image):
            # Already decoded: convert without closing the caller's image.
            gray = inputs.convert("L")
        else:
            # ``with`` closes the file handle. ``convert`` returns a fully
            # loaded copy, so ``gray`` stays usable after the file is closed.
            with Image.open(inputs) as img:
                gray = img.convert("L")
        # Grayscale image -> (1, 28, 28) tensor; add a leading batch axis.
        return self.transform(gray).unsqueeze(0)

    def _forward(self, tensor):
        """Run the model on the preprocessed batch without tracking gradients."""
        with torch.no_grad():
            # ``self.model`` was set by ``Pipeline.__init__``.
            return self.model(tensor)

    def postprocess(self, out, clean_output=True):
        """Convert the raw model output into the pipeline result.

        Args:
            out: the value returned by ``_forward``. This is assumed to be a
                logits tensor whose last dimension indexes the classes
                (TODO confirm against the model).
            clean_output: if true, return the argmax class index of the first
                sample; otherwise return ``out`` unchanged.
        """
        if not clean_output:
            return out
        label = torch.argmax(out, dim=-1)  # predicted class per sample
        if label.dim() == 0:
            # Unbatched 1-D logits: argmax is already a scalar tensor, and
            # ``.tolist()[0]`` would fail on the resulting plain int.
            return label.item()
        return label.tolist()[0]

    def download_img(self, url, timeout=10):
        """Download ``url`` and save the body as ``image.png`` in the CWD.

        Args:
            url: address of the image to fetch.
            timeout: connect/read timeout in seconds, passed to
                ``requests.get``. ``None`` waits indefinitely.

        Raises:
            requests.HTTPError: if the server answers with an error status, so
                an error page is never saved as ``image.png``.
            requests.Timeout: if the server does not respond within ``timeout``.
        """
        # ``with`` closes the streamed response, which releases its connection.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open("image.png", "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        print("image saved as image.png")