File size: 854 Bytes
9537917
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

def model_fn(model_dir):
    """Load the tokenizer and sequence-classification model from *model_dir*.

    Args:
        model_dir: Location passed unchanged to ``from_pretrained`` for both
            the tokenizer and the model.

    Returns:
        A ``(model, tokenizer)`` tuple, in that order.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
    classifier = AutoModelForSequenceClassification.from_pretrained(model_dir)

    # Model comes first in the pair; consumers unpack it as (model, tokenizer).
    bundle = (classifier, loaded_tokenizer)
    return bundle

def predict_fn(data, model_and_tokenizer):
    """Classify the text in *data* and return the predicted class index(es).

    Args:
        data: Request payload. When it is a dict, the text is read from its
            ``"inputs"`` key (falling back to the dict itself if the key is
            absent). Any other payload is handed to the tokenizer as-is.
            The payload is never modified.
        model_and_tokenizer: The ``(model, tokenizer)`` pair, in the order
            ``model_fn`` returns it.

    Returns:
        ``{"predicted_class": value}`` where ``value`` is a single int when
        the model produced exactly one prediction, and a list of ints (one
        per input, in input order) when it produced several.
    """
    model, tokenizer = model_and_tokenizer

    # Read without popping so the caller's payload stays intact, and tolerate
    # payloads that are already the raw text / list of texts.
    if isinstance(data, dict):
        inputs = data.get("inputs", data)
    else:
        inputs = data

    # Tokenize the input; padding lets several texts share one batch tensor.
    tokenized = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True)

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        output = model(**tokenized)

    # One class index per row of logits.
    predictions = torch.argmax(output.logits, dim=1)

    # Tensor.item() only works for exactly one element; keep the scalar
    # result for a single prediction and return a list for a real batch.
    if predictions.numel() == 1:
        predicted_class = predictions.item()
    else:
        predicted_class = predictions.tolist()

    return {"predicted_class": predicted_class}