sridhar-cd commited on
Commit
8595e5b
1 Parent(s): 3afb3ae

Custom inference code for SageMaker deployment (#9)

Browse files

- Custom inference code for aws (a211030719303b038c551e8378aeb0c21ea4c79f)

Files changed (2) hide show
  1. code/inference.py +24 -0
  2. code/requirements.txt +0 -0
code/inference.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
2
+
3
+
4
+ def model_fn(model_dir):
5
+ """
6
+ Load the model and tokenizer from the specified paths
7
+ :param model_dir:
8
+ :return:
9
+ """
10
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
11
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
12
+ return model, tokenizer
13
+
14
+
15
+ def predict_fn(data, model_and_tokenizer):
16
+ # destruct model and tokenizer
17
+ model, tokenizer = model_and_tokenizer
18
+
19
+ bert_pipe = pipeline("text-classification", model=model, tokenizer=tokenizer,
20
+ truncation=True, max_length=512, return_all_scores=True)
21
+ # Tokenize the input, pick up first 512 tokens before passing it further
22
+ tokens = tokenizer.encode(data['inputs'], add_special_tokens=False, max_length=512, truncation=True)
23
+ input_data = tokenizer.decode(tokens)
24
+ return bert_pipe(input_data)
code/requirements.txt ADDED
File without changes