from transformers import BertTokenizer, BertForSequenceClassification
import torch
import numpy as np
import json
class Prehibition:
    """Single-text theme classifier backed by a pretrained BERT model.

    The tokenizer and the sequence-classification model are both loaded
    from the same pretrained checkpoint when the object is constructed.
    """

    # Hub identifier of the checkpoint loaded by default (kept verbatim).
    DEFAULT_MODEL_NAME = 'wyluilipe/prehibiton-themes-clf'
    # Inputs are truncated and padded to this many tokens.
    MAX_LENGTH = 512

    def __init__(self, model_name=DEFAULT_MODEL_NAME):
        """Load the tokenizer and the classification model.

        Args:
            model_name: Name or path of the pretrained checkpoint passed to
                ``from_pretrained``. Defaults to ``DEFAULT_MODEL_NAME``.
        """
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForSequenceClassification.from_pretrained(model_name)

    def predict(self, text):
        """Return the index of the highest-scoring class for ``text``.

        Args:
            text: The text to classify.

        Returns:
            The position of the largest logit as a plain Python ``int``.
            NOTE(review): the mapping from index to theme label is not
            defined here - confirm against the model configuration.
        """
        encoded = self.tokenizer(
            [text],
            max_length=self.MAX_LENGTH,
            # ``pad_to_max_length=True`` is deprecated; this is the
            # supported spelling of the same fixed-length padding.
            padding='max_length',
            truncation=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )
        # Inference only: skip gradient tracking.
        with torch.no_grad():
            model_output = self.model(
                input_ids=encoded['input_ids'],
                attention_mask=encoded['attention_mask'],
            )
        # The batch holds exactly one text, so the arg-max over the last
        # axis has a single element. Staying in torch avoids the implicit
        # tensor -> ndarray conversion that ``np.argmax`` relied on.
        return model_output['logits'].argmax(dim=-1).item()