|
import os |
|
import logging |
|
import torch |
|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
from transformers import pipeline |
|
|
|
# Hugging Face Hub id of the sentiment model; used as both the model and the
# tokenizer source for the sentiment-analysis pipeline built in load_model().
sentiment_model_path: str = 'cardiffnlp/twitter-xlm-roberta-base-sentiment'
|
|
|
def load_model():
    """Load the desirability tokenizer/model and the sentiment pipeline.

    The three objects are stored in ``st.session_state`` under the keys
    ``tokenizer``, ``model`` and ``classifier``. Nothing is loaded when all
    three are already present.

    Model source:
      * If the ``item-desirability`` environment variable is set (non-empty),
        the model is read from ``magnolia-psychometrics/item-desirability``
        and the variable's value is passed as the auth token.
      * Otherwise the path in the ``model_path`` environment variable is used,
        and ``True`` is passed as the auth token (presumably meaning "use the
        locally stored Hugging Face credentials" — TODO confirm).

    Errors are logged with a traceback and reported in the UI, but not
    re-raised. Session state is only updated once every object has loaded,
    so a failed attempt leaves it untouched and the next call retries.
    """
    keys = ['tokenizer', 'model', 'classifier']

    if all(st.session_state.get(key) is not None for key in keys):
        return

    with st.spinner('Loading the model might take a couple of seconds...'):
        try:
            hf_token = os.environ.get('item-desirability')

            if hf_token:
                model_path = 'magnolia-psychometrics/item-desirability'
            else:
                model_path = os.getenv('model_path')
                if not model_path:
                    # Fail with a clear message instead of passing None on
                    # to from_pretrained().
                    raise ValueError(
                        "No model configured: set the 'item-desirability' or "
                        "'model_path' environment variable."
                    )

            auth_token = hf_token or True

            # Load into locals first so session state is never half-populated.
            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path=model_path,
                use_fast=True,
                use_auth_token=auth_token,
            )

            model = AutoModelForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path=model_path,
                num_labels=1,
                ignore_mismatched_sizes=True,
                use_auth_token=auth_token,
            )

            classifier = pipeline(
                task='sentiment-analysis',
                model=sentiment_model_path,
                tokenizer=sentiment_model_path,
                use_fast=False,
                top_k=3,
            )

        except Exception as e:
            # Best-effort: keep the app alive, but keep the traceback and
            # tell the user instead of failing silently.
            logging.exception(f'Error while loading models/tokenizer: {e}')
            st.error('Error while loading the model. Please try again later.')

        else:
            st.session_state.tokenizer = tokenizer
            st.session_state.model = model
            st.session_state.classifier = classifier
            logging.info('Loaded models and tokenizer!')
|
|
|
def z_score(y, mean=.04853076, sd=.9409466):
    """Standardize ``y``: subtract ``mean``, then divide by ``sd``.

    The defaults are fixed constants (presumably the mean/SD of the raw
    desirability model output on some reference data — TODO confirm).
    ``y`` may be any value that supports subtraction and division by a float.
    """
    centered = y - mean
    return centered / sd
|
|
|
def score_text(input_text):
    """Score ``input_text`` for sentiment and standardized desirability.

    Requires ``st.session_state`` to already hold ``classifier``,
    ``tokenizer`` and ``model`` (populated by ``load_model()``).

    Args:
        input_text: the text to score. Only the first classifier result is
            used, so a single string is expected.

    Returns:
        ``(sentiment, desirability)``:
          * ``sentiment`` is the classifier's ``positive`` score minus its
            ``negative`` score.
          * ``desirability`` is the model's squeezed logit converted with
            ``.tolist()`` and then standardized by ``z_score``.
    """
    with st.spinner('Predicting...'):
        # truncation=True: without it, inputs longer than the model's maximum
        # sequence length raise an error instead of being scored.
        classifier_output = st.session_state.classifier(input_text, truncation=True)
        label_scores = {x['label']: x['score'] for x in classifier_output[0]}
        sentiment = label_scores['positive'] - label_scores['negative']

        inputs = st.session_state.tokenizer(
            text=input_text,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )

        # Inference only: no gradient tracking needed.
        with torch.no_grad():
            score = st.session_state.model(**inputs).logits.squeeze().tolist()

        desirability = z_score(score)

    return sentiment, desirability