# Hugging Face Hub file-view residue (page chrome, not part of the module):
# djsull's picture
# Upload 3 files
# 649aa1c verified
# raw
# history blame
# 968 Bytes
import os
import torch
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
from sagemaker_inference import content_types, decoder, default_inference_handler, encoder
def model_fn(model_dir):
    """Load the sentence-embedding model stored under ``model_dir``.

    Args:
        model_dir: Location passed straight to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_dir)
def input_fn(request_body, request_content_type):
    """Decode a JSON request body into the input handed to the model.

    Args:
        request_body: Raw request payload, passed to ``decoder.decode``.
        request_content_type: Content type declared for the payload.

    Returns:
        Whatever ``decoder.decode`` produces for ``content_types.JSON``.

    Raises:
        ValueError: If ``request_content_type`` is not ``content_types.JSON``.
    """
    # Reject every non-JSON content type up front; only JSON is decoded.
    if request_content_type != content_types.JSON:
        raise ValueError(f"Requested unsupported ContentType in content_type: {request_content_type}")
    return decoder.decode(request_body, content_types.JSON)
def predict_fn(input_data, model):
    """Compute embeddings for ``input_data`` with the given model.

    Args:
        input_data: Value forwarded unchanged to ``model.encode``
            (presumably the output of ``input_fn`` -- confirm against caller).
        model: Object exposing an ``encode`` method
            (presumably the model returned by ``model_fn`` -- confirm).

    Returns:
        The result of ``model.encode(input_data)``, unmodified.
    """
    return model.encode(input_data)
def output_fn(prediction, accept):
    """Serialize a prediction as JSON for the response.

    Args:
        prediction: Value to serialize, passed to ``encoder.encode``.
        accept: Content type requested for the response.

    Returns:
        Whatever ``encoder.encode`` produces for ``content_types.JSON``.

    Raises:
        ValueError: If ``accept`` is not ``content_types.JSON``.
    """
    # Only JSON responses are supported; fail fast on anything else.
    if accept != content_types.JSON:
        raise ValueError(f"Requested unsupported ContentType in Accept: {accept}")
    return encoder.encode(prediction, content_types.JSON)