File size: 925 Bytes
8595e5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline


def model_fn(model_dir):
    """
    Build the (model, tokenizer) pair used for inference from a single directory.

    Both objects are created with ``from_pretrained`` on the same ``model_dir``;
    the tokenizer is loaded first, then the sequence-classification model.
    Presumably this is the model-loading hook of a serving toolkit (the
    ``model_fn``/``predict_fn`` naming suggests SageMaker) — confirm.

    :param model_dir: path (or identifier) accepted by ``from_pretrained``,
        holding the tokenizer files and the classification model weights.
    :return: tuple ``(model, tokenizer)``; ``predict_fn`` expects exactly this
        order when it unpacks its ``model_and_tokenizer`` argument.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    pair = (loaded_model, loaded_tokenizer)
    return pair


def predict_fn(data, model_and_tokenizer):
    """
    Run text classification on ``data['inputs']`` and return every label's score.

    :param data: request payload; the text to classify is read from
        ``data['inputs']`` (assumed to be a string, or a list of strings for a
        batch — TODO confirm against the client).
    :param model_and_tokenizer: the ``(model, tokenizer)`` tuple produced by
        ``model_fn``.
    :return: the text-classification pipeline output with
        ``return_all_scores=True``, i.e. scores for all labels rather than
        only the top one.
    :raises KeyError: if ``data`` has no ``'inputs'`` entry.
    """
    # Unpack in the same order model_fn returns them.
    model, tokenizer = model_and_tokenizer

    # truncation/max_length are forwarded by the pipeline to its tokenizer, so
    # over-long texts are cut to 512 tokens *including* special tokens.
    # NOTE(review): the pipeline is rebuilt on every request; it could be
    # created once alongside the model if per-request latency matters.
    bert_pipe = pipeline("text-classification", model=model, tokenizer=tokenizer,
                         truncation=True, max_length=512, return_all_scores=True)

    # Hand the raw text straight to the pipeline. Pre-truncating with
    # tokenizer.encode()/decode() was redundant (the pipeline truncates anyway,
    # and 512 content tokens plus special tokens exceed the limit regardless),
    # lossy (decode() need not reproduce the original text, e.g. uncased
    # vocabularies come back lower-cased) and broke list inputs (encode() treats
    # a list of strings as one pre-tokenized sequence).
    return bert_pipe(data['inputs'])