from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch


def model_fn(model_dir):
    # Load the tokenizer and model from the model directory provided by SageMaker
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.eval()  # switch to inference mode (disables dropout, etc.)
    return model, tokenizer


def predict_fn(data, model_and_tokenizer):
    model, tokenizer = model_and_tokenizer
    # Pull the text from the "inputs" key; fall back to the raw payload
    inputs = data.pop("inputs", data)
    # Tokenize the input text
    tokenized = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True)
    # Run the forward pass without tracking gradients
    with torch.no_grad():
        output = model(**tokenized)
    # Take the highest-scoring class (assumes a single input sequence)
    predicted_class = torch.argmax(output.logits, dim=1).item()
    return {"predicted_class": predicted_class}