from typing import List

import torch
from transformers import BertTokenizer

from foodybert import FoodyBertForSequenceClassification


class PreTrainedPipeline:
    """Sentiment-classification pipeline wrapping a FoodyBert model.

    A single input string is tokenized with the ``bert-base-uncased``
    tokenizer and passed through ``FoodyBertForSequenceClassification``.
    The index of the highest logit is mapped to a sentiment label.
    """

    # Position i in the model's logits is reported as LABELS[i].
    # NOTE(review): order copied from the original implementation; confirm it
    # matches the label mapping the model was trained with.
    LABELS = ("positive", "neutral", "negative")

    def __init__(self, path: str = ""):
        """Load the tokenizer and the classification model.

        Args:
            path: Directory holding the pretrained FoodyBert weights. An empty
                string (the default) falls back to the current working
                directory ``"."``, which is what was always used before.
        """
        self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = FoodyBertForSequenceClassification.from_pretrained(path or ".")
        # Inference only: ensure training-time behaviour (e.g. dropout) is off.
        self.model.eval()

    def __call__(self, inputs: str) -> str:
        """Classify the sentiment of a single text.

        Args:
            inputs: The text to classify.

        Returns:
            The label in ``LABELS`` ("positive", "neutral" or "negative")
            whose logit is highest.
        """
        # Truncate to the tokenizer's maximum sequence length instead of
        # producing an over-long sequence the model cannot accept.
        input_ids = self.bert_tokenizer.encode(
            inputs, add_special_tokens=True, truncation=True
        )
        # Batch of one sequence: shape (1, seq_len).
        batch = torch.tensor([input_ids])
        with torch.no_grad():
            logits = self.model(batch).logits
        # argmax over the class dimension; .item() requires a single result,
        # i.e. exactly one sequence in the batch.
        predicted_class_id = logits.argmax(dim=-1).item()
        return self.LABELS[predicted_class_id]