import os
import logging

import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

sentiment_model_path = 'cardiffnlp/twitter-xlm-roberta-base-sentiment'


@st.cache_resource
def load_desirability_model():
    """Load the item-desirability model and tokenizer (cached by Streamlit across reruns)."""
    try:
        model_path = os.getenv('model_path', 'magnolia-psychometrics/item-desirability')
        # Use an access token from the environment if set; otherwise fall back to the
        # locally cached Hugging Face credentials (use_auth_token=True).
        auth_token = os.environ.get('item_desirability') or True

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=True,
            use_auth_token=auth_token
        )

        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            num_labels=1,  # single output: a regression head for desirability
            ignore_mismatched_sizes=True,
            use_auth_token=auth_token
        )

        logging.info('Loaded desirability model and tokenizer!')
        return tokenizer, model

    except Exception as e:
        logging.error(f'Error loading desirability model/tokenizer: {e}')
        return None, None


@st.cache_resource
def load_sentiment_classifier():
    """Load the multilingual sentiment pipeline (cached by Streamlit across reruns)."""
    try:
        classifier = pipeline(
            'sentiment-analysis',
            model=sentiment_model_path,
            tokenizer=sentiment_model_path,
            use_fast=False,
            top_k=3  # return scores for all three labels (negative, neutral, positive)
        )
        logging.info('Loaded sentiment classifier!')
        return classifier
    except Exception as e:
        logging.error(f'Error loading sentiment classifier: {e}')
        return None


def z_score(y, mean=0.04853076, sd=0.9409466):
    """Standardize a raw desirability score against fixed reference statistics."""
    return (y - mean) / sd


def score_text(input_text):
    """Return a (sentiment, desirability) pair of predictions for a single text item."""
    with st.spinner('Predicting...'):
        classifier = load_sentiment_classifier()
        if not classifier:
            st.error('Error loading sentiment classifier.')
            return None, None

        logging.info('Sentiment classifier loaded successfully.')

        # Collapse the per-label scores into one polarity value: P(positive) - P(negative).
        classifier_output = classifier(input_text)
        classifier_output_dict = {x['label']: x['score'] for x in classifier_output[0]}
        sentiment = classifier_output_dict.get('positive', 0) - classifier_output_dict.get('negative', 0)

        tokenizer, model = load_desirability_model()
        if not tokenizer or not model:
            st.error('Error loading desirability model.')
            return None, None

        logging.info('Desirability model and tokenizer loaded successfully.')

        # The regression head produces a single logit, which is standardized to a z-score.
        inputs = tokenizer(input_text, padding=True, return_tensors='pt')
        with torch.no_grad():
            score = model(**inputs).logits.squeeze().tolist()
        desirability = z_score(score)

        logging.info(f'Sentiment: {sentiment}, Desirability: {desirability}')
        return sentiment, desirability
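

# Minimal usage sketch (an illustrative assumption, not part of the original app):
# wires score_text into a simple Streamlit page when the script is run with
# `streamlit run app.py`. The widget labels and layout below are hypothetical.
if __name__ == '__main__':
    st.title('Item desirability and sentiment scoring')
    user_text = st.text_input('Enter an item or statement to score:')
    if user_text:
        sentiment, desirability = score_text(user_text)
        if sentiment is not None and desirability is not None:
            st.write(f'Sentiment (positive minus negative): {sentiment:.3f}')
            st.write(f'Desirability (z-score): {desirability:.3f}')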