import os

import torch
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
from sagemaker_inference import content_types, decoder, default_inference_handler, encoder


def model_fn(model_dir):
    # Load the SentenceTransformer model from the directory where SageMaker
    # unpacks the model artifact (model.tar.gz).
    model = SentenceTransformer(model_dir)
    return model


def input_fn(request_body, request_content_type):
    # Deserialize the request payload; only JSON input is supported.
    if request_content_type == content_types.JSON:
        input_data = decoder.decode(request_body, content_types.JSON)
        return input_data
    else:
        raise ValueError(f"Requested unsupported ContentType in content_type: {request_content_type}")


def predict_fn(input_data, model):
    # Encode the input sentence(s) into dense embedding vectors.
    embeddings = model.encode(input_data)
    return embeddings


def output_fn(prediction, accept):
    # Serialize the embeddings for the response; only JSON output is supported.
    if accept == content_types.JSON:
        output = encoder.encode(prediction, content_types.JSON)
        return output
    else:
        raise ValueError(f"Requested unsupported ContentType in Accept: {accept}")
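

# Optional local sanity check: a minimal sketch of how SageMaker chains the
# handlers above (model_fn -> input_fn -> predict_fn -> output_fn) for a single
# request. The model directory path and the sample sentences are hypothetical
# placeholders, not part of the handler contract; adjust them to a local copy
# of an unpacked SentenceTransformer model before running.
if __name__ == "__main__":
    import json

    model = model_fn("./model")  # assumed path to an unpacked model directory
    request_body = json.dumps(["What is Amazon SageMaker?", "How do I deploy a model?"])
    data = input_fn(request_body, content_types.JSON)
    embeddings = predict_fn(data, model)
    response = output_fn(embeddings, content_types.JSON)
    print(response[:200])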