File size: 2,173 Bytes
387a2e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

from transformers import Pipeline
import requests
from PIL import Image
import torchvision.transforms as transforms
import torch

class MnistPipe(Pipeline):
    """Hugging Face ``Pipeline`` that classifies a single digit image.

    The input image is converted to single-channel grayscale, turned into a
    tensor, resized to 28x28 and given a leading batch dimension before being
    passed to ``self.model``.

    Call-time options (routed by ``_sanitize_parameters``):

    * ``download`` (preprocess): treat the input as a URL, fetch it into
      ``image.png`` and classify that file.
    * ``clean_output`` (postprocess): return the argmax class index
      (default) instead of the raw model output.
    """

    def __init__(self, **kwargs):
        # Extra attributes (e.g. a tokenizer) could be created here if needed.
        # The base initializer stores ``self.model`` and routes call-time
        # kwargs through ``_sanitize_parameters``.
        super().__init__(**kwargs)

        # ToTensor converts the PIL image to a tensor; Resize then brings the
        # spatial size to the 28x28 the model is fed with.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((28, 28), antialias=True),
            ]
        )

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        Only ``download`` (for ``preprocess``) and ``clean_output`` (for
        ``postprocess``) are recognised; any other key is ignored here.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "download" in kwargs:
            preprocess_kwargs["download"] = kwargs["download"]
        if "clean_output" in kwargs:
            postprocess_kwargs["clean_output"] = kwargs["clean_output"]
        # ``_forward`` takes no options.
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, download=False):
        """Load an image and turn it into a model-ready batch of one.

        Args:
            inputs: Path to an image file, an already-loaded
                ``PIL.Image.Image``, or (when ``download`` is true) the URL
                of an image.
            download: If true, fetch ``inputs`` over HTTP into ``image.png``
                first and classify that file.

        Returns:
            The transformed grayscale image with a leading batch dimension
            of size 1.
        """
        if download:
            # download_img returns the path it wrote to.
            inputs = self.download_img(inputs)

        if isinstance(inputs, Image.Image):
            gray = inputs.convert("L")
        else:
            # The context manager closes the underlying file handle; convert()
            # loads the pixel data, so ``gray`` stays valid after the block.
            with Image.open(inputs) as img:
                gray = img.convert("L")

        return self.transform(gray).unsqueeze(0)

    def _forward(self, tensor):
        """Run ``self.model`` on the preprocessed batch without tracking gradients."""
        with torch.no_grad():
            # ``self.model`` is set up by the base Pipeline initializer.
            out = self.model(tensor)
        return out

    def postprocess(self, out, clean_output=True):
        """Turn raw model output into the pipeline's result.

        Args:
            out: Output returned by ``_forward``.
            clean_output: If true (default), return the argmax index of the
                first (only) prediction as a Python value; otherwise return
                ``out`` unchanged.
        """
        if not clean_output:
            return out
        # NOTE(review): assumes ``out`` holds per-class scores in its last
        # dimension, e.g. (batch, num_classes); confirm against the model.
        label = torch.argmax(out, dim=-1)
        # preprocess builds a batch of one, so keep only the first prediction.
        return label.tolist()[0]

    def download_img(self, url, filename="image.png", timeout=30):
        """Download ``url`` and save the response body to ``filename``.

        Args:
            url: Address of the image to fetch.
            filename: Destination path (default ``"image.png"``, relative to
                the current working directory).
            timeout: Seconds to wait on the server (connect / read) before
                giving up, so a dead host cannot hang the pipeline forever.

        Returns:
            The path the image was written to.

        Raises:
            requests.HTTPError: If the server answers with an error status,
                so an error page is never saved as the image.
        """
        # Using the response as a context manager releases the streamed
        # connection even if writing fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(filename, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        print(f"image saved as {filename}")
        return filename