pebblo-classifier-v2/code/code_inference.py
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline


def model_fn(model_dir):
    """
    Load the fine-tuned classification model and its tokenizer.

    :param model_dir: path to the directory containing the saved model and tokenizer files
    :return: tuple of (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    return model, tokenizer


def predict_fn(data, model_and_tokenizer):
    """
    Run text classification on the request payload.

    :param data: request payload; expects the input text under the 'inputs' key
    :param model_and_tokenizer: tuple of (model, tokenizer) returned by model_fn
    :return: list of label/score dictionaries covering all classes
    """
    # Unpack the model and tokenizer loaded in model_fn
    model, tokenizer = model_and_tokenizer
    bert_pipe = pipeline("text-classification", model=model, tokenizer=tokenizer,
                         truncation=True, max_length=512, return_all_scores=True)
    # Keep only the first 512 tokens of the input before passing it to the pipeline
    tokens = tokenizer.encode(data['inputs'], add_special_tokens=False, max_length=512, truncation=True)
    input_data = tokenizer.decode(tokens)
    return bert_pipe(input_data)
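

# The block below is not part of the original file; it is a minimal local
# smoke-test sketch showing how a SageMaker-style container would chain
# model_fn and predict_fn. The model directory path and the sample text
# are placeholder assumptions, not values from the repository.
if __name__ == "__main__":
    # Assumed local directory holding the saved model and tokenizer files
    loaded = model_fn("./model")
    # Payload shape mirrors what predict_fn expects: the text under 'inputs'
    sample = {"inputs": "This document contains the quarterly financial report for internal use only."}
    print(predict_fn(sample, loaded))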