shamaayan commited on
Commit
c70152c
1 Parent(s): c7c1092
Files changed (1) hide show
  1. code/inference.py +6 -0
code/inference.py CHANGED
@@ -1,7 +1,13 @@
 
1
  import numpy as np
2
  from typing import List, Union
3
 
4
 
 
 
 
 
 
5
  def predict_fn(data: Union[List[str], str], model):
6
  outputs = model(data, padding=False, truncation=True)
7
  embeddings = [np.array(r[0]).mean(axis=0).tolist() for r in outputs]
 
1
+ import json
2
  import numpy as np
3
  from typing import List, Union
4
 
5
 
6
+ def input_fn(input_data, content_type):
7
+ data = json.loads(input_data)
8
+ return data['inputs']
9
+
10
+
11
  def predict_fn(data: Union[List[str], str], model):
12
  outputs = model(data, padding=False, truncation=True)
13
  embeddings = [np.array(r[0]).mean(axis=0).tolist() for r in outputs]