YassineB commited on
Commit
9f2c502
1 Parent(s): 74475f6

Test cv model

Browse files
Files changed (2) hide show
  1. handler.py +33 -0
  2. requirements.txt +3 -0
handler.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from torchvision.models import resnet18, ResNet18_Weights
3
+ from torchvision.io import read_image
4
+ from PIL import Image
5
+ import io
6
+ import requests
7
+ import torchvision.transforms.functional as transform
8
+
9
+ class EndpointHandler():
10
+ def __init__(self, path=""):
11
+ weights = ResNet18_Weights.DEFAULT
12
+ self.pipeline = resnet18(weights=weights)
13
+ self.preprocess = weights.transforms()
14
+ self.pipeline.eval()
15
+
16
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
17
+ """
18
+ data args:
19
+ inputs (:obj: `str`)
20
+ Return:
21
+ A :obj:`list` | `dict`: will be serialized and returned
22
+ """
23
+ # get inputs
24
+ inputs = data.pop("inputs",data)
25
+ if inputs.startswith("http") or inputs.startswith("www"):
26
+ response = requests.get(inputs).content
27
+ img = transform.to_tensor(Image.open(io.BytesIO(response)))
28
+ else:
29
+ img = read_image(inputs)
30
+
31
+ batch = self.preprocess(img).unsqueeze(0)
32
+ prediction = self.pipeline(batch).squeeze(0).softmax(0)
33
+ return prediction
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ holidays
2
+ torch
3
+ torchvision