not-lain commited on
Commit
dff2014
1 Parent(s): 8fd08bb

Upload MnistPipe

Browse files
Files changed (2) hide show
  1. MyPipe.py +65 -0
  2. config.json +13 -2
MyPipe.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import Pipeline
3
+ import requests
4
+ from PIL import Image
5
+ import torchvision.transforms as transforms
6
+ import torch
7
+
8
+ class MnistPipe(Pipeline):
9
+ def __init__(self,**kwargs):
10
+
11
+ # self.tokenizer = (...) # code if you want to instantiate more parameters
12
+
13
+ Pipeline.__init__(self,**kwargs) # self.model automatically instantiated here
14
+
15
+ self.transform = transforms.Compose(
16
+ [transforms.ToTensor(),
17
+ transforms.Resize((28,28), antialias=True)
18
+ ])
19
+
20
+ def _sanitize_parameters(self, **kwargs):
21
+ # will make sure where each parameter goes
22
+ preprocess_kwargs = {}
23
+ postprocess_kwargs = {}
24
+ if "download" in kwargs:
25
+ preprocess_kwargs["download"] = kwargs["download"]
26
+ if "clean_output" in kwargs :
27
+ postprocess_kwargs["clean_output"] = kwargs["clean_output"]
28
+ return preprocess_kwargs, {}, postprocess_kwargs
29
+
30
+ def preprocess(self, inputs, download=False):
31
+ if download == True :
32
+ # call download_img method and name image as "image.png"
33
+ self.download_img(inputs)
34
+ inputs = "image.png"
35
+
36
+ # we open and process the image
37
+ img = Image.open(inputs)
38
+ gray = img.convert('L')
39
+ tensor = self.transform(gray)
40
+ tensor = tensor.unsqueeze(0)
41
+ return tensor
42
+
43
+ def _forward(self, tensor):
44
+ with torch.no_grad():
45
+ # the model has been automatically instantiated
46
+ # in the __init__ method
47
+ out = self.model(tensor)
48
+ return out
49
+
50
+ def postprocess(self, out, clean_output=True):
51
+ if clean_output ==True :
52
+ label = torch.argmax(out,axis=-1) # get class
53
+ label = label.tolist()[0]
54
+ return label
55
+ else :
56
+ return out
57
+
58
+ def download_img(self,url):
59
+ # if download = True download image and name it image.png
60
+ response = requests.get(url, stream=True)
61
+
62
+ with open("image.png", "wb") as f:
63
+ for chunk in response.iter_content(chunk_size=8192):
64
+ f.write(chunk)
65
+ print("image saved as image.png")
config.json CHANGED
@@ -1,13 +1,24 @@
1
  {
 
2
  "architectures": [
3
  "MnistModel"
4
  ],
5
  "auto_map": {
6
- "AutoConfig": "MyConfig.MnistConfig",
7
- "AutoModelForImageClassification": "MyModel.MnistModel"
8
  },
9
  "conv1": 10,
10
  "conv2": 20,
 
 
 
 
 
 
 
 
 
 
11
  "model_type": "MobileNetV1",
12
  "torch_dtype": "float32",
13
  "transformers_version": "4.39.0.dev0"
 
1
  {
2
+ "_name_or_path": "not-lain/mycustomrepo",
3
  "architectures": [
4
  "MnistModel"
5
  ],
6
  "auto_map": {
7
+ "AutoConfig": "not-lain/mycustomrepo--MyConfig.MnistConfig",
8
+ "AutoModelForImageClassification": "not-lain/mycustomrepo--MyModel.MnistModel"
9
  },
10
  "conv1": 10,
11
  "conv2": 20,
12
+ "custom_pipelines": {
13
+ "image-classification": {
14
+ "impl": "MyPipe.MnistPipe",
15
+ "pt": [
16
+ "AutoModelForImageClassification"
17
+ ],
18
+ "tf": [],
19
+ "type": "image"
20
+ }
21
+ },
22
  "model_type": "MobileNetV1",
23
  "torch_dtype": "float32",
24
  "transformers_version": "4.39.0.dev0"