# Source: financIA/src/predict.py
# Author: Frorozcol -- commit 9ee675e ("Load the app"), 939 bytes
from pathlib import Path
import torch
from .tokenizer import load_tokenizer, preprocessing_text
from .model import load_model
# CONFIG
# The classifier scores NUM_VARIABLES targets over NUM_LABELS labels each;
# load_model is given their product as the total number of outputs
# (presumably one output per (variable, label) pair -- confirm in .model).
NUM_VARIABLES = 3
NUM_VARAIBLES = NUM_VARIABLES  # misspelled name kept as a backward-compatible alias
NUM_LABELS = 3
num_labels = NUM_LABELS * NUM_VARIABLES

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
divice = device  # misspelled name kept as a backward-compatible alias

model_name = "pysentimiento/robertuito-sentiment-analysis"
# The checkpoint sits in a "checkpoints" directory one level above the
# directory that contains this file.
checkpoint_path = Path(__file__).parent.parent / "checkpoints" / "model.ckpt"

# Tokenizer and model are loaded once at import time and shared by every
# get_predict call.
# NOTE(review): confirm load_model switches the model to eval mode
# (dropout off); it is not done here.
tokenizer = load_tokenizer(model_name)
model = load_model(checkpoint_path, model_name, num_labels, device)
def get_predict(text):
    """Run the module-level model on *text* and return sigmoid scores.

    The text is tokenized with ``preprocessing_text`` using the module-level
    ``tokenizer``. The resulting ``input_ids``, ``attention_mask`` and
    ``token_type_ids`` tensors are moved to the configured device, and the
    model output is passed through an element-wise sigmoid.

    Args:
        text: Input forwarded unchanged to ``preprocessing_text``.

    Returns:
        A NumPy array (on the CPU) holding the sigmoid of the model outputs.
        Presumably shaped (batch, num_labels) -- TODO confirm against .model.
    """
    inputs = preprocessing_text(text, tokenizer)
    input_ids = inputs["input_ids"].to(divice)
    attention_mask = inputs["attention_mask"].to(divice)
    token_type_ids = inputs["token_type_ids"].to(divice)
    # Inference only: disable autograd so no gradient graph is built,
    # saving memory and time. This also makes .detach() unnecessary.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask, token_type_ids)
        preds = torch.sigmoid(outputs)
    return preds.cpu().numpy()