akhaliq HF staff commited on
Commit
07a12d1
·
1 Parent(s): 8377aee

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import sys
3
+ import numpy as np
4
+ import mxnet as mx
5
+ import os
6
+ from __future__ import absolute_import
7
+ from __future__ import division
8
+ from __future__ import print_function
9
+
10
+ from scipy import misc
11
+ import random
12
+ import sklearn
13
+ from sklearn.decomposition import PCA
14
+ from time import sleep
15
+ from easydict import EasyDict as edict
16
+ from mtcnn_detector import MtcnnDetector
17
+ from skimage import transform as trans
18
+ import matplotlib.pyplot as plt
19
+ from mxnet.contrib.onnx.onnx2mx.import_model import import_model
20
+
21
+
22
+ def get_model(ctx, model):
23
+ image_size = (112,112)
24
+ # Import ONNX model
25
+ sym, arg_params, aux_params = import_model(model)
26
+ # Define and binds parameters to the network
27
+ model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
28
+ model.bind(data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])
29
+ model.set_params(arg_params, aux_params)
30
+ return model
31
+
32
+ for i in range(4):
33
+ mx.test_utils.download(dirname='mtcnn-model', url='https://s3.amazonaws.com/onnx-model-zoo/arcface/mtcnn-model/det{}-0001.params'.format(i+1))
34
+ mx.test_utils.download(dirname='mtcnn-model', url='https://s3.amazonaws.com/onnx-model-zoo/arcface/mtcnn-model/det{}-symbol.json'.format(i+1))
35
+ mx.test_utils.download(dirname='mtcnn-model', url='https://s3.amazonaws.com/onnx-model-zoo/arcface/mtcnn-model/det{}.caffemodel'.format(i+1))
36
+ mx.test_utils.download(dirname='mtcnn-model', url='https://s3.amazonaws.com/onnx-model-zoo/arcface/mtcnn-model/det{}.prototxt'.format(i+1))
37
+
38
+ # Determine and set context
39
+ if len(mx.test_utils.list_gpus())==0:
40
+ ctx = mx.cpu()
41
+ else:
42
+ ctx = mx.gpu(0)
43
+ # Configure face detector
44
+ det_threshold = [0.6,0.7,0.8]
45
+ mtcnn_path = os.path.join(os.path.dirname('__file__'), 'mtcnn-model')
46
+ detector = MtcnnDetector(model_folder=mtcnn_path, ctx=ctx, num_worker=1, accurate_landmark = True, threshold=det_threshold)
47
+
48
+ def preprocess(img, bbox=None, landmark=None, **kwargs):
49
+ M = None
50
+ image_size = []
51
+ str_image_size = kwargs.get('image_size', '')
52
+ # Assert input shape
53
+ if len(str_image_size)>0:
54
+ image_size = [int(x) for x in str_image_size.split(',')]
55
+ if len(image_size)==1:
56
+ image_size = [image_size[0], image_size[0]]
57
+ assert len(image_size)==2
58
+ assert image_size[0]==112
59
+ assert image_size[0]==112 or image_size[1]==96
60
+
61
+ # Do alignment using landmark points
62
+ if landmark is not None:
63
+ assert len(image_size)==2
64
+ src = np.array([
65
+ [30.2946, 51.6963],
66
+ [65.5318, 51.5014],
67
+ [48.0252, 71.7366],
68
+ [33.5493, 92.3655],
69
+ [62.7299, 92.2041] ], dtype=np.float32 )
70
+ if image_size[1]==112:
71
+ src[:,0] += 8.0
72
+ dst = landmark.astype(np.float32)
73
+ tform = trans.SimilarityTransform()
74
+ tform.estimate(dst, src)
75
+ M = tform.params[0:2,:]
76
+ assert len(image_size)==2
77
+ warped = cv2.warpAffine(img,M,(image_size[1],image_size[0]), borderValue = 0.0)
78
+ return warped
79
+
80
+ # If no landmark points available, do alignment using bounding box. If no bounding box available use center crop
81
+ if M is None:
82
+ if bbox is None:
83
+ det = np.zeros(4, dtype=np.int32)
84
+ det[0] = int(img.shape[1]*0.0625)
85
+ det[1] = int(img.shape[0]*0.0625)
86
+ det[2] = img.shape[1] - det[0]
87
+ det[3] = img.shape[0] - det[1]
88
+ else:
89
+ det = bbox
90
+ margin = kwargs.get('margin', 44)
91
+ bb = np.zeros(4, dtype=np.int32)
92
+ bb[0] = np.maximum(det[0]-margin/2, 0)
93
+ bb[1] = np.maximum(det[1]-margin/2, 0)
94
+ bb[2] = np.minimum(det[2]+margin/2, img.shape[1])
95
+ bb[3] = np.minimum(det[3]+margin/2, img.shape[0])
96
+ ret = img[bb[1]:bb[3],bb[0]:bb[2],:]
97
+ if len(image_size)>0:
98
+ ret = cv2.resize(ret, (image_size[1], image_size[0]))
99
+ return ret
100
+
101
+ def get_input(detector,face_img):
102
+ # Pass input images through face detector
103
+ ret = detector.detect_face(face_img, det_type = 0)
104
+ if ret is None:
105
+ return None
106
+ bbox, points = ret
107
+ if bbox.shape[0]==0:
108
+ return None
109
+ bbox = bbox[0,0:4]
110
+ points = points[0,:].reshape((2,5)).T
111
+ # Call preprocess() to generate aligned images
112
+ nimg = preprocess(face_img, bbox, points, image_size='112,112')
113
+ nimg = cv2.cvtColor(nimg, cv2.COLOR_BGR2RGB)
114
+ aligned = np.transpose(nimg, (2,0,1))
115
+ return aligned
116
+
117
+ def get_feature(model,aligned):
118
+ input_blob = np.expand_dims(aligned, axis=0)
119
+ data = mx.nd.array(input_blob)
120
+ db = mx.io.DataBatch(data=(data,))
121
+ model.forward(db, is_train=False)
122
+ embedding = model.get_outputs()[0].asnumpy()
123
+ embedding = sklearn.preprocessing.normalize(embedding).flatten()
124
+ return embedding
125
+
126
+ # Download first image
127
+ mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/arcface/player1.jpg')
128
+ # Download second image
129
+ mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/arcface/player2.jpg')
130
+ # Download onnx model
131
+ mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/arcface/resnet100.onnx')
132
+ # Path to ONNX model
133
+ model_name = 'resnet100.onnx'
134
+
135
+ # Load ONNX model
136
+ model = get_model(ctx , model_name)
137
+
138
+ def inference(img1,img2):
139
+ # Load first image
140
+ img1 = cv2.imread(img1)
141
+
142
+ # Preprocess first image
143
+ pre1 = get_input(detector,img1)
144
+
145
+ # Get embedding of first image
146
+ out1 = get_feature(model,pre1)
147
+
148
+ # Load second image
149
+ img2 = cv2.imread('player2.jpg')
150
+
151
+ # Preprocess second image
152
+ pre2 = get_input(detector,img2)
153
+
154
+ # Get embedding of second image
155
+ out2 = get_feature(model,pre2)
156
+
157
+ # Compute squared distance between embeddings
158
+ dist = np.sum(np.square(out1-out2))
159
+ # Compute cosine similarity between embedddings
160
+ sim = np.dot(out1, out2.T)
161
+ # Print predictions
162
+ return 'Distance = %f' %(dist),'Similarity = %f' %(sim)
163
+
164
+ gr.Interface(inference,[gr.inputs.Image(type="file"),gr.inputs.Image(type="file")],["text","text"]).launch()