|
|
|
from transformers import Pipeline |
|
import requests |
|
from PIL import Image |
|
import torchvision.transforms as transforms |
|
import torch |
|
|
|
class MnistPipe(Pipeline):
    """Hugging Face ``Pipeline`` that runs ``self.model`` on a single image.

    The image is converted to single-channel grayscale ("L" mode), turned into
    a tensor with ``transforms.ToTensor`` and resized to 28x28 before being
    given a batch dimension and passed to the model.

    Call-time options (routed by ``_sanitize_parameters``):

    * ``download`` (default ``False``): treat the input as a URL, download it
      to ``image.png`` in the working directory and classify that file.
    * ``clean_output`` (default ``True``): return the arg-max index over the
      last dimension of the model output for the first batch item instead of
      the raw model output.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Grayscale PIL image -> tensor, then resize to the 28x28 model input.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((28, 28), antialias=True),
            ]
        )

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        Only options the caller actually passed are forwarded, so the defaults
        declared on ``preprocess`` / ``postprocess`` apply otherwise.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "download" in kwargs:
            preprocess_kwargs["download"] = kwargs["download"]
        if "clean_output" in kwargs:
            postprocess_kwargs["clean_output"] = kwargs["clean_output"]
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, download=False):
        """Load *inputs* as a grayscale image and return a batched tensor.

        Args:
            inputs: Path to a local image file, an already-loaded
                ``PIL.Image.Image``, or (when ``download`` is true) a URL.
            download: If true, fetch ``inputs`` via ``download_img`` (which
                writes ``image.png`` in the working directory) and load that
                file instead.

        Returns:
            The transformed image with a leading batch dimension of size 1.
        """
        if download:
            self.download_img(inputs)
            inputs = "image.png"

        if isinstance(inputs, Image.Image):
            gray = inputs.convert("L")
        else:
            # Image.open reads lazily and keeps the file open; convert() forces
            # the pixels to be loaded so the handle can be closed right away.
            with Image.open(inputs) as img:
                gray = img.convert("L")

        tensor = self.transform(gray)
        return tensor.unsqueeze(0)

    def _forward(self, tensor):
        """Run ``self.model`` on the preprocessed batch without tracking gradients."""
        with torch.no_grad():
            return self.model(tensor)

    def postprocess(self, out, clean_output=True):
        """Turn the model output into the pipeline result.

        Args:
            out: Raw output of ``self.model``. Assumed to be a logits tensor
                with classes on the last dimension (e.g. ``(1, num_classes)``)
                — TODO confirm against the model in use.
            clean_output: If true, return the predicted class index; otherwise
                return ``out`` unchanged.

        Returns:
            The arg-max over the last dimension for the first batch item, or
            ``out`` itself when ``clean_output`` is false.
        """
        if not clean_output:
            return out
        label = torch.argmax(out, dim=-1)
        # A 1-D ``out`` yields a 0-d tensor, whose tolist() is a plain int and
        # cannot be indexed.
        if label.dim() == 0:
            return label.item()
        return label.tolist()[0]

    def download_img(self, url, timeout=30):
        """Download *url* and save the response body as ``image.png`` in the CWD.

        Args:
            url: Address of the image to fetch.
            timeout: Seconds to wait for the server, passed to ``requests.get``.
                ``requests`` applies no timeout by default, so a stalled server
                would otherwise block the pipeline indefinitely.

        Raises:
            requests.HTTPError: If the server answers with an error status,
                rather than saving the error body as ``image.png``.
        """
        # The context manager releases the streamed connection when done.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open("image.png", "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        print("image saved as image.png")
|
|