distil-bert / inference.py
Gideonah's picture
Upload folder using huggingface_hub
9537917 verified
raw
history blame contribute delete
854 Bytes
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
def model_fn(model_dir):
    """Load the fine-tuned classifier and its tokenizer from *model_dir*.

    SageMaker-style entry point: called once at container start-up.

    Args:
        model_dir: Directory containing the saved tokenizer and model files.

    Returns:
        A ``(model, tokenizer)`` pair consumed by ``predict_fn``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    return loaded_model, loaded_tokenizer
def predict_fn(data, model_and_tokenizer):
    """Run sequence-classification inference on an incoming request.

    Args:
        data: The request payload. When it is a dict, the text is read from
            its ``"inputs"`` key (falling back to the dict itself when the
            key is absent); any other payload (str or list of str) is used
            directly.
        model_and_tokenizer: ``(model, tokenizer)`` pair as returned by
            ``model_fn``.

    Returns:
        ``{"predicted_class": int}`` for a single input, or
        ``{"predicted_class": list[int]}`` for a batch of inputs.
    """
    model, tokenizer = model_and_tokenizer
    # Use .get, not .pop: popping mutated the caller's request dict, and
    # .pop crashed outright on non-dict payloads.
    inputs = data.get("inputs", data) if isinstance(data, dict) else data
    tokenized = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True)
    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        output = model(**tokenized)
    predictions = torch.argmax(output.logits, dim=1)
    # .item() raises for batched inputs; keep the original scalar return for
    # a single example and return a list when the request is a batch.
    if predictions.numel() == 1:
        return {"predicted_class": predictions.item()}
    return {"predicted_class": predictions.tolist()}