Upload 3 files
Browse files- code/inference.py +28 -0
- code/requirements.txt +2 -0
- eval/similarity_evaluation_sts-dev_results.csv +6 -0
code/inference.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from transformers import AutoModel, AutoTokenizer
|
4 |
+
from sentence_transformers import SentenceTransformer
|
5 |
+
from sagemaker_inference import content_types, decoder, default_inference_handler, encoder
|
6 |
+
|
7 |
+
def model_fn(model_dir):
|
8 |
+
model = SentenceTransformer(model_dir)
|
9 |
+
return model
|
10 |
+
|
11 |
+
def input_fn(request_body, request_content_type):
|
12 |
+
if request_content_type == content_types.JSON:
|
13 |
+
input_data = decoder.decode(request_body, content_types.JSON)
|
14 |
+
return input_data
|
15 |
+
else:
|
16 |
+
raise ValueError(f"Requested unsupported ContentType in content_type: {request_content_type}")
|
17 |
+
|
18 |
+
def predict_fn(input_data, model):
|
19 |
+
embeddings = model.encode(input_data)
|
20 |
+
return embeddings
|
21 |
+
|
22 |
+
def output_fn(prediction, accept):
|
23 |
+
if accept == content_types.JSON:
|
24 |
+
output = encoder.encode(prediction, content_types.JSON)
|
25 |
+
return output
|
26 |
+
else:
|
27 |
+
raise ValueError(f"Requested unsupported ContentType in Accept: {accept}")
|
28 |
+
|
code/requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
sentence-transformers==2.4.0
|
2 |
+
torch==1.7.1
|
eval/similarity_evaluation_sts-dev_results.csv
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
epoch,steps,cosine_pearson,cosine_spearman,euclidean_pearson,euclidean_spearman,manhattan_pearson,manhattan_spearman,dot_pearson,dot_spearman
|
2 |
+
0,-1,0.8698176470151903,0.8703809106165765,0.8613456349850496,0.8713671870317011,0.8612671961630662,0.8711084727834425,0.8616164593092901,0.8605345981877489
|
3 |
+
1,-1,0.8734464873783272,0.8736683657907032,0.8675749215906989,0.8757654192103659,0.8672802412614803,0.8752541800842721,0.8657575423252234,0.8644435734300635
|
4 |
+
2,-1,0.8710017129066742,0.872910314428584,0.8663801605893812,0.8740412659842393,0.8659927634570299,0.873583159340305,0.8628811327109139,0.8624866717237128
|
5 |
+
3,-1,0.8739040673272463,0.8741376943020817,0.8682234999427393,0.8756945433821925,0.8677010320744731,0.875246831324674,0.8656950557273269,0.8641310893502131
|
6 |
+
4,-1,0.8742169848606813,0.8745454590505839,0.868420878937881,0.8760332005360969,0.8678115978073914,0.8753381392178419,0.8657302849379587,0.8640112492882958
|