wyluilipe's picture
Create README.md
46f724b verified
|
raw
history blame
872 Bytes

from transformers import BertTokenizer, BertForSequenceClassification
import torch
import numpy as np
import json

class Prehibition:
    """Single-text classifier around a BERT sequence-classification checkpoint.

    By default it loads ``wyluilipe/prehibiton-themes-clf`` (tokenizer and
    model) via ``from_pretrained``. ``predict`` returns the integer index of
    the highest-scoring class. The mapping from index to label is not defined
    here; presumably it lives in ``self.model.config.id2label`` (verify
    against the checkpoint).
    """

    def __init__(self, model_name='wyluilipe/prehibiton-themes-clf'):
        """Load tokenizer and model.

        Args:
            model_name: Hub repo id or local path passed to ``from_pretrained``.
                Defaults to the original hard-coded checkpoint.
        """
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForSequenceClassification.from_pretrained(model_name)
        # Make inference mode explicit so dropout is disabled during predict().
        self.model.eval()

    def predict(self, text, max_length=512):
        """Classify ``text`` and return the index of the top-scoring class.

        Args:
            text: The input string to classify.
            max_length: Token budget. The input is truncated to this length
                and padded up to it (same behavior as the original fixed 512).

        Returns:
            int: argmax over the model's logits for this single text.
        """
        # ``padding='max_length'`` replaces the deprecated
        # ``pad_to_max_length=True`` with identical padding behavior.
        encoded = self.tokenizer(
            [text],
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )
        # Move inputs to wherever the model lives so a GPU-placed model works.
        device = self.model.device
        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)
        with torch.no_grad():
            model_output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Take argmax in torch (works on any device), not via numpy
        # (CPU tensors only). The batch has one row, so .item() yields that
        # row's class index as a Python int.
        return int(model_output['logits'].argmax(dim=-1).item())