Text Classification
Transformers
PyTorch
bert
Inference Endpoints
File size: 1,050 Bytes
320e6ef
362cab8
 
320e6ef
362cab8
 
 
 
 
 
 
 
 
b827343
4d90ca7
 
362cab8
 
 
 
 
 
 
 
 
 
 
 
 
4d90ca7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from typing import List
import torch
from transformers import BertTokenizer
from foodybert import FoodyBertForSequenceClassification

class PreTrainedPipeline():
    """Sentiment-classification pipeline wrapping a FoodyBert model.

    A single input string is tokenized with the ``bert-base-uncased``
    tokenizer and run through ``FoodyBertForSequenceClassification``. The
    index of the highest logit is mapped to a label from ``LABELS``.
    """

    # Label for each output index of the classifier head. This is the order
    # the pipeline has always used; presumably it matches the label ids the
    # model was trained with — TODO confirm against the model config.
    LABELS = ("positive", "neutral", "negative")

    def __init__(self, path: str = ""):
        """
        Load the tokenizer and the classification model.

        Args:
            path (:obj:`str`, optional):
                Directory to load the fine-tuned model from. An empty string
                (the default) falls back to the current directory ``"."``,
                which is where the model was previously always loaded from.
        """
        self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = FoodyBertForSequenceClassification.from_pretrained(path or ".")
        # from_pretrained normally already does this; done explicitly so
        # inference never runs with dropout active.
        self.model.eval()

    def __call__(self, inputs: str) -> str:
        """
        Classify the sentiment of a piece of text.

        Args:
            inputs (:obj:`str`):
                The text to classify. Text longer than the tokenizer's
                maximum sequence length is truncated instead of making the
                model fail.
        Return:
            A :obj:`str`: one of ``"positive"``, ``"neutral"`` or
            ``"negative"`` — the label with the highest logit.
        """
        input_ids = self.bert_tokenizer.encode(
            inputs, add_special_tokens=True, truncation=True
        )
        # The model expects a batch dimension; this is a batch of one.
        batch = torch.tensor([input_ids])
        with torch.no_grad():
            logits = self.model(batch).logits
        predicted_class_id = logits.argmax(dim=-1).item()
        return self.LABELS[predicted_class_id]