item-desirability-demo / modeling.py
bjorn-hommel's picture
refactor
28183db
raw
history blame
2.46 kB
import os
import logging
import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
# Hugging Face Hub id of the pretrained sentiment classifier; load_model()
# uses it for both the pipeline's model and its tokenizer.
sentiment_model_path = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
def load_model():
    """Load the desirability model, its tokenizer and the sentiment classifier.

    The three objects are cached in ``st.session_state`` under the keys
    ``tokenizer``, ``model`` and ``classifier``. If all three are already
    present (not None) this is a no-op; otherwise all three are (re)loaded
    while a spinner is shown.

    Model source:
        * If the ``item-desirability`` environment variable is set (non-empty),
          the ``magnolia-psychometrics/item-desirability`` Hub repository is
          used and the variable's value is passed as the auth token.
        * Otherwise the model path/id is read from the ``model_path``
          environment variable and ``use_auth_token=True`` is passed.

    Errors are not propagated: any failure is logged with its traceback and
    reported in the app via ``st.error``. Objects that failed to load keep
    their previous value in session state, so a later call retries.
    """
    keys = ['tokenizer', 'model', 'classifier']
    if all(st.session_state.get(key) is not None for key in keys):
        return

    with st.spinner('Loading the model might take a couple of seconds...'):
        try:
            # Read the secret once: its presence selects the Hub repository
            # and its value doubles as the Hub auth token.
            hub_token = os.environ.get('item-desirability')
            if hub_token:
                model_path = 'magnolia-psychometrics/item-desirability'
            else:
                model_path = os.getenv('model_path')
                if not model_path:
                    # Fail with a clear message instead of an opaque error
                    # raised by from_pretrained(None).
                    raise ValueError(
                        "No model configured: set the 'item-desirability' or "
                        "'model_path' environment variable."
                    )
            # True => transformers falls back to the locally stored Hub token.
            auth_token = hub_token or True

            # NOTE(review): `use_auth_token` is deprecated in recent
            # transformers releases in favour of `token`; kept here for
            # compatibility with the installed version - confirm before
            # switching.
            st.session_state.tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path=model_path,
                use_fast=True,
                use_auth_token=auth_token,
            )
            # num_labels=1 yields a single output logit, which score_text
            # standardizes into the desirability score; ignore_mismatched_sizes
            # presumably allows loading a checkpoint whose head size differs.
            st.session_state.model = AutoModelForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path=model_path,
                num_labels=1,
                ignore_mismatched_sizes=True,
                use_auth_token=auth_token,
            )
            # top_k=3 requests the top three label scores (presumably every
            # label of this model); score_text reads 'positive' and 'negative'.
            st.session_state.classifier = pipeline(
                task='sentiment-analysis',
                model=sentiment_model_path,
                tokenizer=sentiment_model_path,
                use_fast=False,
                top_k=3,
            )
            logging.info('Loaded models and tokenizer!')
        except Exception as e:
            # Best-effort load: keep the app running, but preserve the
            # traceback and tell the user instead of failing silently.
            logging.exception('Error while loading models/tokenizer: %s', e)
            st.error('Could not load the models. Please check the logs for details.')
def z_score(y, mean=.04853076, sd=.9409466):
    """Standardize ``y`` against a fixed mean and standard deviation.

    Args:
        y: Raw value to standardize; anything that supports ``-`` and ``/``
            with floats (e.g. a float).
        mean: Reference mean subtracted from ``y``. The default is presumably
            the mean raw desirability-model output on a reference sample -
            confirm against the model's calibration data.
        sd: Reference standard deviation that the centered value is divided
            by (same caveat as ``mean``).

    Returns:
        ``(y - mean) / sd``.
    """
    deviation = y - mean
    return deviation / sd
def score_text(input_text):
    """Score one text for sentiment and standardized desirability.

    Requires ``load_model()`` to have populated ``st.session_state`` with
    ``classifier``, ``tokenizer`` and ``model``.

    Args:
        input_text: The text to score.

    Returns:
        ``(sentiment, desirability)``: ``sentiment`` is the classifier's
        'positive' score minus its 'negative' score; ``desirability`` is the
        desirability model's output logit passed through ``z_score``.
    """
    with st.spinner('Predicting...'):
        # truncation=True keeps over-long input within the tokenizer's
        # maximum length, so the forward pass does not fail on sequences
        # longer than the model supports.
        classifier_output = st.session_state.classifier(input_text, truncation=True)
        # Expected shape: a list whose first element is the list of
        # per-label {'label': ..., 'score': ...} dicts for this text.
        label_scores = {x['label']: x['score'] for x in classifier_output[0]}
        sentiment = label_scores['positive'] - label_scores['negative']

        inputs = st.session_state.tokenizer(
            text=input_text,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            # For a single text the logits presumably have shape (1, 1), so
            # squeeze().tolist() yields a plain Python float.
            score = st.session_state.model(**inputs).logits.squeeze().tolist()
        desirability = z_score(score)
    return sentiment, desirability