Spaces:
Sleeping
Sleeping
#Import the libraries we know we'll need for the Generator. | |
import pandas as pd, spacy, nltk, numpy as np | |
from spacy.matcher import Matcher | |
# The spaCy model loaded below must be installed; if it is missing, run the
# command below once. (The earlier note referenced en_core_web_md, but the
# pipeline actually loaded is en_core_web_lg.)
#!python -m spacy download en_core_web_lg
# NOTE(review): this runs at module level, so under Streamlit's rerun model it
# likely reloads on every interaction — consider caching it.
nlp = spacy.load("en_core_web_lg")
# NOTE(review): `lemmatizer` is not referenced in the visible part of this file;
# confirm it is still needed.
lemmatizer = nlp.get_pipe("lemmatizer")
#Import the libraries to support the model and predictions. | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline | |
import lime | |
import torch | |
import torch.nn.functional as F | |
from lime.lime_text import LimeTextExplainer | |
#Import the libraries for human interaction and visualization. | |
import altair as alt | |
import streamlit as st | |
from annotated_text import annotated_text as ant | |
#Import functions needed to build dataframes of keywords from WordNet | |
from WNgen import * | |
from NLselector import * | |
def set_up_explainer():
    """Create the LIME text explainer used for every sentence shown in the app.

    Class index 0 is labelled 'negative' and index 1 'positive' (presumably the
    label order of the sentiment model — confirm against the predictor).

    Returns:
        A ``LimeTextExplainer`` configured with those two class names.
    """
    return LimeTextExplainer(class_names=['negative', 'positive'])
# Streamlit re-executes this script on every widget interaction, so without a
# resource cache the tokenizer and model would be re-downloaded/re-loaded on
# each rerun. `st.cache_resource` only exists in newer Streamlit releases, so
# fall back to the older `st.experimental_singleton`, and finally to no caching.
_cache_resource = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or (lambda func: func)
)


@_cache_resource
def prepare_model(model_name="distilbert-base-uncased-finetuned-sst-2-english"):
    """Load the sentiment model and wrap it in a text-classification pipeline.

    Args:
        model_name: Hugging Face Hub id used for both the tokenizer and the
            sequence-classification model. Defaults to the DistilBERT SST-2
            checkpoint the app was built around.

    Returns:
        ``(tokenizer, model, pipe)`` where ``pipe`` is a
        ``TextClassificationPipeline`` over ``model`` and ``tokenizer``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    # return_all_scores=True is kept deliberately: the newer top_k=None option
    # may return a differently nested result, which callers would need to handle.
    pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)
    return tokenizer, model, pipe
def prepare_lists(countries_path="Assets/Countries/combined-countries.csv",
                  professions_path="Assets/Professions/soc-professions-2018.csv"):
    """Load the country and profession word lists used to route generation.

    Args:
        countries_path: CSV file with a ``Words`` column; defaults to the
            bundled countries list.
        professions_path: CSV file with a ``Words`` column; defaults to the
            bundled SOC 2018 professions list.

    Returns:
        ``(countries, professions, word_lists)``: the two DataFrames as read
        from disk, plus ``[country_words, profession_words]`` built from their
        ``Words`` columns, in that order.
    """
    countries = pd.read_csv(countries_path)
    professions = pd.read_csv(professions_path)
    word_lists = [list(frame["Words"]) for frame in (countries, professions)]
    return countries, professions, word_lists
#Provide all the functions necessary to run the app | |
#get definitions for control flow in Streamlit | |
def get_def(word, POS=False):
    """Let the user pick a WordNet sense for *word* in the first column.

    Args:
        word: Word or phrase to look up; spaces become underscores before the
            WordNet query.
        POS: Optional part-of-speech filter: 'NOUN', 'VERB', 'ADJ' or 'ADV'.
            Any other value (including the default False) searches every
            part of speech.

    Returns:
        The definition currently selected in the selectbox if the
        "Choose Definition" button was clicked on this run (it is also saved
        to ``st.session_state.definition``); otherwise None.
    """
    # `wordnet` is not defined in this block; presumably it comes from one of
    # the star imports above — confirm. `col1` is the module-level column
    # created by the app script before this function is called.
    lookup_key = word.replace(" ", "_")
    if POS in ('NOUN', 'VERB', 'ADJ', 'ADV'):
        synsets = wordnet.synsets(lookup_key, pos=getattr(wordnet, POS))
    else:
        synsets = wordnet.synsets(lookup_key)
    candidates = [synset.definition() for synset in synsets]

    chosen = col1.selectbox("Which definition is most relevant?", candidates, key="WN_definition")
    if not col1.button("Choose Definition"):
        col1.write("Please choose a definition.")
        return None

    col1.write("You've chosen a definition.")
    st.session_state.definition = chosen
    return chosen
###Start coding the actual app###
# Page configuration goes before the other st.* element calls in this script;
# Streamlit's docs (at least for older releases) require set_page_config first.
st.set_page_config(layout="wide", page_title="MultiNLC Generator Test")
st.title('MultiNLC Generator Test')
# Short description of the experiment, shown under the title.
st.write('This is a test of the pipeline Nathan built to generate counterfactuals for the STP-3 research project. Here we test the initial proposal from Ana for comparing the Natural Language Explanation of multiple alternatives against the original input from a person.')
#Prepare the model
# Load the classifier, the country/profession word lists and the LIME explainer.
tokenizer, model, pipe = prepare_model()
countries, professions, word_lists = prepare_lists()
explainer = set_up_explainer()
# Defaults for the two alternatives rendered in col2/col3 and for the table of
# generated counterfactuals; they stay empty until alternatives are generated.
text2 = ""
text3 = ""
cf_df = pd.DataFrame()
# Persist the WordNet definition chosen in get_def() across reruns.
if 'definition' not in st.session_state:
    st.session_state.definition = None
# Kept disabled: the st.radio below is created with key="option".
#if 'option' not in st.session_state:
#    st.session_state.option = None
#Get the user to input a sentence
st.write('This first iteration only allows you to evaluate countries.')
col1, col2, col3 = st.columns(3)
with col1:
    text = st.text_input('Provide a sentence you want to evaluate.', placeholder = "I like you. I love you.", key="input")
    #Use spaCy to make the sentence into a doc so we can do NLP.
    doc = nlp(st.session_state.input)
    #Evaluate the provided sentence for sentiment and probability.
    if st.session_state.input != "":
        probability, sentiment = eval_pred(text, return_all=True)
        # Store the LIME weights as `lime_results` (matching lime_results2 and
        # lime_results3 below) rather than `lime`, which would shadow the
        # imported `lime` module for the rest of the script.
        options, lime_results = critical_words(st.session_state.input,options=True)
        nat_lang_explanation = construct_nlexp(text,sentiment,probability)
        st.altair_chart(lime_viz(lime_results))
        #Allow the user to pick an option to generate counterfactuals from.
        option = st.radio('Which word would you like to use to generate alternatives?', options, key = "option")
        if (any(option in sublist for sublist in word_lists)):
            st.write(f'You selected {option}. It matches a list.')
        elif option:
            st.write(f'You selected {option}. It does not match a list.')
            # Words outside the country/profession lists need a WordNet sense;
            # get_def saves the chosen one in st.session_state.definition.
            definition = get_def(option)
        else:
            st.write('Awaiting your selection.')
        # st.button is True only on the rerun triggered by the click, so the
        # alternatives are generated (and rendered below) once per click.
        if st.button('Generate Alternatives'):
            if option in list(countries.Words):
                cf_df = gen_cf_country(countries, doc, option)
                col1.write('Alternatives created.')
            elif option in list(professions.Words):
                cf_df = gen_cf_profession(professions, doc, option)
                col1.write('Alternatives created.')
            else:
                # NOTE(review): st.session_state.definition is still None if the
                # user never pressed "Choose Definition" — confirm ant() and
                # cf_from_wordnet_df accept None here.
                ant("Generating alternatives for",(option,"opt","#E0FBFB"), "with a definition of: ",(st.session_state.definition,"def","#E0FBFB"),".")
                cf_df = cf_from_wordnet_df(option,text,seed_definition=st.session_state.definition)
                col1.write('Alternatives created.')
            if len(cf_df) != 0:
                # get_min_max picks the two alternatives rendered in col2 and col3.
                text2, text3 = get_min_max(cf_df, option)
with col2:
    # First alternative from get_min_max, presented as similar to the chosen
    # word; text2 stays "" (nothing rendered) until alternatives exist.
    if text2 != "":
        # Similarity of the first cf_df row whose text equals text2.
        sim2 = cf_df.loc[cf_df['text'] == text2, 'similarity'].iloc[0]
        st.write(f"This alternate example is similar to {option}.")
        # "Num Checked" is the row count of cf_df.
        st.write(f" Similarity Score: {np.round(sim2, 2)}, Num Checked: {len(cf_df)}") #for QA purposes
        st.write(text2)
        # LIME explanation with at most 15 features from 2000 perturbed samples.
        # NOTE(review): `predictor` is not defined in the visible code;
        # presumably it comes from one of the star imports — confirm.
        exp2 = explainer.explain_instance(text2, predictor, num_features=15, num_samples=2000)
        lime_results2 = exp2.as_list()
        probability2, sentiment2 = eval_pred(text2, return_all=True)
        # NOTE(review): nat_lang_explanation is recomputed here but not
        # displayed anywhere in the visible code.
        nat_lang_explanation = construct_nlexp(text2,sentiment2,probability2)
        st.altair_chart(lime_viz(lime_results2))
with col3:
    # Second alternative from get_min_max, presented as not similar to the
    # chosen word; text3 stays "" (nothing rendered) until alternatives exist.
    if text3 != "":
        # Similarity of the first cf_df row whose text equals text3.
        sim3 = cf_df.loc[cf_df['text'] == text3, 'similarity'].iloc[0]
        st.write(f"This alternate example is not similar to {option}.")
        # "Num Checked" is the row count of cf_df.
        st.write(f"Similarity Score: {np.round(sim3, 2)}, Num Checked: {len(cf_df)}") #for QA purposes
        st.write(text3)
        # LIME explanation with at most 15 features from 2000 perturbed samples.
        # NOTE(review): `predictor` is not defined in the visible code;
        # presumably it comes from one of the star imports — confirm.
        exp3 = explainer.explain_instance(text3, predictor, num_features=15, num_samples=2000)
        lime_results3 = exp3.as_list()
        probability3, sentiment3 = eval_pred(text3, return_all=True)
        # NOTE(review): nat_lang_explanation is recomputed here but not
        # displayed anywhere in the visible code.
        nat_lang_explanation = construct_nlexp(text3,sentiment3,probability3)
        st.altair_chart(lime_viz(lime_results3))