djsull commited on
Commit
649aa1c
·
verified ·
1 Parent(s): 95d256f

Upload 3 files

Browse files
code/inference.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import AutoModel, AutoTokenizer
4
+ from sentence_transformers import SentenceTransformer
5
+ from sagemaker_inference import content_types, decoder, default_inference_handler, encoder
6
+
7
+ def model_fn(model_dir):
8
+ model = SentenceTransformer(model_dir)
9
+ return model
10
+
11
+ def input_fn(request_body, request_content_type):
12
+ if request_content_type == content_types.JSON:
13
+ input_data = decoder.decode(request_body, content_types.JSON)
14
+ return input_data
15
+ else:
16
+ raise ValueError(f"Requested unsupported ContentType in content_type: {request_content_type}")
17
+
18
+ def predict_fn(input_data, model):
19
+ embeddings = model.encode(input_data)
20
+ return embeddings
21
+
22
+ def output_fn(prediction, accept):
23
+ if accept == content_types.JSON:
24
+ output = encoder.encode(prediction, content_types.JSON)
25
+ return output
26
+ else:
27
+ raise ValueError(f"Requested unsupported ContentType in Accept: {accept}")
28
+
code/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ sentence-transformers==2.4.0
2
+ torch==1.7.1
eval/similarity_evaluation_sts-dev_results.csv ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ epoch,steps,cosine_pearson,cosine_spearman,euclidean_pearson,euclidean_spearman,manhattan_pearson,manhattan_spearman,dot_pearson,dot_spearman
2
+ 0,-1,0.8698176470151903,0.8703809106165765,0.8613456349850496,0.8713671870317011,0.8612671961630662,0.8711084727834425,0.8616164593092901,0.8605345981877489
3
+ 1,-1,0.8734464873783272,0.8736683657907032,0.8675749215906989,0.8757654192103659,0.8672802412614803,0.8752541800842721,0.8657575423252234,0.8644435734300635
4
+ 2,-1,0.8710017129066742,0.872910314428584,0.8663801605893812,0.8740412659842393,0.8659927634570299,0.873583159340305,0.8628811327109139,0.8624866717237128
5
+ 3,-1,0.8739040673272463,0.8741376943020817,0.8682234999427393,0.8756945433821925,0.8677010320744731,0.875246831324674,0.8656950557273269,0.8641310893502131
6
+ 4,-1,0.8742169848606813,0.8745454590505839,0.868420878937881,0.8760332005360969,0.8678115978073914,0.8753381392178419,0.8657302849379587,0.8640112492882958