# item-desirability-demo / modeling.py
# bjorn-hommel's picture
# updated token ref
# f3e1a6f
import os
import logging
import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
# Model identifier handed to the sentiment pipeline as both its model and its
# tokenizer (see load_sentiment_classifier).
sentiment_model_path = 'cardiffnlp/twitter-xlm-roberta-base-sentiment'
@st.cache_resource
def _load_desirability_model_cached():
    """Load the desirability tokenizer and model; raise on any failure.

    Only this helper is cached. Letting errors propagate keeps a failed
    attempt out of the Streamlit resource cache, so the next call retries.
    """
    # The checkpoint location can be overridden through the `model_path`
    # environment variable.
    model_path = os.getenv('model_path', 'magnolia-psychometrics/item-desirability')
    # Use the token from the `item_desirability` environment variable when it
    # is set and non-empty; otherwise pass True.
    # NOTE(review): newer transformers releases appear to deprecate
    # `use_auth_token` in favour of `token` - confirm against the pinned version.
    auth_token = os.environ.get('item_desirability') or True

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=True,
        use_auth_token=auth_token,
    )
    # A single output label: score_text() reads the logit as one raw score.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        num_labels=1,
        ignore_mismatched_sizes=True,
        use_auth_token=auth_token,
    )

    logging.info('Loaded desirability model and tokenizer!')
    return tokenizer, model


def load_desirability_model():
    """Return the ``(tokenizer, model)`` pair for the desirability model.

    Returns:
        ``(tokenizer, model)`` on success, ``(None, None)`` if loading failed.
        The error is logged rather than raised.

    A successful load is cached by Streamlit. A failed load is not cached, so
    a transient error (for example a network hiccup) no longer leaves
    ``(None, None)`` stored until the cache is cleared.
    """
    try:
        return _load_desirability_model_cached()
    except Exception as e:
        # logging.exception logs at ERROR level and appends the traceback.
        logging.exception('Error loading desirability model/tokenizer: %s', e)
        return None, None


# Keep `load_desirability_model.clear()` working for callers that relied on
# the cache_resource decorator being applied to the public function.
load_desirability_model.clear = _load_desirability_model_cached.clear
@st.cache_resource
def _load_sentiment_classifier_cached():
    """Build the sentiment-analysis pipeline; raise on any failure.

    Only this helper is cached. Letting errors propagate keeps a failed
    attempt out of the Streamlit resource cache, so the next call retries.
    """
    classifier = pipeline(
        'sentiment-analysis',
        model=sentiment_model_path,
        tokenizer=sentiment_model_path,
        use_fast=False,
        # Request up to three labels per input; score_text() reads the
        # 'positive' and 'negative' entries from the result.
        top_k=3,
    )
    logging.info('Loaded sentiment classifier!')
    return classifier


def load_sentiment_classifier():
    """Return the sentiment pipeline, or ``None`` if it could not be loaded.

    A successful load is cached by Streamlit. A failed load is logged and is
    not cached, so a transient error no longer leaves ``None`` stored until
    the cache is cleared.
    """
    try:
        return _load_sentiment_classifier_cached()
    except Exception as e:
        # logging.exception logs at ERROR level and appends the traceback.
        logging.exception('Error loading sentiment classifier: %s', e)
        return None


# Keep `load_sentiment_classifier.clear()` working for callers that relied on
# the cache_resource decorator being applied to the public function.
load_sentiment_classifier.clear = _load_sentiment_classifier_cached.clear
def z_score(y, mean=0.04853076, sd=0.9409466):
    """Standardize ``y`` by subtracting ``mean`` and dividing by ``sd``.

    Args:
        y: Raw value to standardize.
        mean: Centre to subtract. NOTE(review): the default is presumably the
            mean of the desirability model's raw scores on a reference
            sample - confirm.
        sd: Scale to divide by. NOTE(review): presumably the matching
            standard deviation - confirm.

    Returns:
        ``(y - mean) / sd``.
    """
    centered = y - mean
    return centered / sd
def score_text(input_text):
    """Score ``input_text`` for sentiment and desirability.

    Args:
        input_text: Text to score. NOTE(review): the code path assumes a
            single string; a batch would make the desirability score a list -
            confirm with callers.

    Returns:
        ``(sentiment, desirability)`` where ``sentiment`` is the classifier's
        'positive' score minus its 'negative' score and ``desirability`` is
        the model's raw output passed through ``z_score``. Returns
        ``(None, None)`` and shows a Streamlit error if either model could
        not be loaded.
    """
    with st.spinner('Predicting...'):
        classifier = load_sentiment_classifier()
        if not classifier:
            st.error('Error loading sentiment classifier.')
            return None, None
        logging.info('Sentiment classifier loaded successfully.')

        # truncation=True asks the tokenizer to cut over-long text to its
        # configured maximum length, so long input no longer reaches the
        # model's forward pass and raises there.
        classifier_output = classifier(input_text, truncation=True)
        # The first element holds the label/score dicts for the input text.
        classifier_output_dict = {x['label']: x['score'] for x in classifier_output[0]}
        # A label missing from the output counts as a score of 0.
        sentiment = classifier_output_dict.get('positive', 0) - classifier_output_dict.get('negative', 0)

        tokenizer, model = load_desirability_model()
        if not tokenizer or not model:
            st.error('Error loading desirability model.')
            return None, None
        logging.info('Desirability model and tokenizer loaded successfully.')

        # Same truncation guard as for the sentiment classifier above.
        inputs = tokenizer(
            input_text,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        # Inference only: no gradients are needed.
        with torch.no_grad():
            score = model(**inputs).logits.squeeze().tolist()

        desirability = z_score(score)
        logging.info(f'Sentiment: {sentiment}, Desirability: {desirability}')
        return sentiment, desirability