Spaces:
Sleeping
Sleeping
File size: 6,520 Bytes
03287bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
#Import the libraries we know we'll need for the Generator.
import pandas as pd, spacy, nltk, numpy as np
from spacy.matcher import Matcher
# Load the large English spaCy pipeline used to parse the user's sentence.
# If the model is not installed yet, run:
#     python -m spacy download en_core_web_lg
nlp = spacy.load("en_core_web_lg")
# NOTE(review): `lemmatizer` is not referenced anywhere else in this file;
# presumably kept for experimentation - confirm before removing.
lemmatizer = nlp.get_pipe("lemmatizer")
#Import the libraries to support the model and predictions.
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
import lime
import torch
import torch.nn.functional as F
from lime.lime_text import LimeTextExplainer
#Import the libraries for human interaction and visualization.
import altair as alt
import streamlit as st
from annotated_text import annotated_text as ant
#Import functions needed to build dataframes of keywords from WordNet
from WNgen import *
from NLselector import *
@st.experimental_singleton
def set_up_explainer():
    """Create the LIME text explainer once and reuse it across reruns.

    Returns:
        A ``LimeTextExplainer`` labelled with the two sentiment classes.
        The label order presumably matches the column order of the
        probabilities returned by ``predictor`` - confirm there.
    """
    return LimeTextExplainer(class_names=['negative', 'positive'])
@st.experimental_singleton
def prepare_model():
    """Load the SST-2 DistilBERT sentiment classifier once per server.

    Returns:
        A ``(tokenizer, model, pipe)`` tuple, where ``pipe`` is a
        ``TextClassificationPipeline`` configured to return the score of
        every label rather than only the top one.
    """
    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    # NOTE(review): newer transformers releases deprecate return_all_scores in
    # favour of top_k=None, which changes the output nesting; keep as-is until
    # the callers of `pipe` are checked.
    pipe = TextClassificationPipeline(
        model=model,
        tokenizer=tokenizer,
        return_all_scores=True,
    )
    return tokenizer, model, pipe
@st.experimental_singleton
def prepare_lists():
    """Read the country and profession word tables from the Assets folder.

    Returns:
        ``(countries, professions, word_lists)``: the two DataFrames, plus
        a two-element list holding the ``Words`` column of each (countries
        first, then professions) as plain Python lists.
    """
    countries = pd.read_csv("Assets/Countries/combined-countries.csv")
    professions = pd.read_csv("Assets/Professions/soc-professions-2018.csv")
    word_lists = [list(frame.Words) for frame in (countries, professions)]
    return countries, professions, word_lists
#Provide all the functions necessary to run the app
#get definitions for control flow in Streamlit
def get_def(word, POS=False):
    """Let the user pick which WordNet sense of *word* seeds the alternatives.

    Renders a selectbox of WordNet definitions and a "Choose Definition"
    button in the module-level ``col1`` column, which must already exist
    when this is called.

    Args:
        word: Word or phrase to look up. Spaces are replaced with
            underscores, WordNet's convention for multi-word lemmas.
        POS: Optional part-of-speech filter: one of 'NOUN', 'VERB', 'ADJ',
            'ADV'. Any other value (including the default ``False``)
            searches every part of speech.

    Returns:
        The chosen definition string when the button is pressed (it is
        also stored in ``st.session_state.definition``); otherwise None.
    """
    pos_options = ['NOUN', 'VERB', 'ADJ', 'ADV']
    m_word = word.replace(" ", "_")
    if POS in pos_options:
        # getattr maps e.g. 'NOUN' to wordnet.NOUN, the constant the pos filter takes.
        synsets = wordnet.synsets(m_word, pos=getattr(wordnet, POS))
    else:
        synsets = wordnet.synsets(m_word)
    seed_definitions = [syn.definition() for syn in synsets]
    if not seed_definitions:
        # Without this guard an empty selectbox is shown, and pressing the
        # button would store None while claiming a definition was chosen.
        col1.write(f"No WordNet definitions were found for '{word}'.")
        return None
    seed_definition = col1.selectbox("Which definition is most relevant?", seed_definitions, key="WN_definition")
    if col1.button("Choose Definition"):
        col1.write("You've chosen a definition.")
        # Persist the choice so it survives Streamlit's rerun of the script.
        st.session_state.definition = seed_definition
        return seed_definition
    col1.write("Please choose a definition.")
    return None
###Start coding the actual app###
# set_page_config has to be the first Streamlit command the script executes.
st.set_page_config(layout="wide", page_title="VizNLC Generator Test")
st.title('VizNLC Generator Test')
st.write("This is a test of the pipeline Nathan built to generate counterfactuals for the STP-3 research project. Here we test Nathan's elaboration for comparing the Natural Language Explanation and a visual display against the original input from a person.")

#Prepare the model
# These loaders are wrapped in st.experimental_singleton, so the heavy
# resources are built once and reused on later reruns.
tokenizer, model, pipe = prepare_model()
countries, professions, word_lists = prepare_lists()
explainer = set_up_explainer()

# Defaults for values the column sections fill in. Streamlit reruns the
# whole script on each interaction, so these are reset on every run.
text2 = ""
text3 = ""
cf_df = pd.DataFrame()

# Session-state keys survive reruns; make sure the ones read later exist.
if 'definition' not in st.session_state:
    st.session_state.definition = None
if 'option' not in st.session_state:
    st.session_state.option = None
# NOTE(review): `proceed` is not read anywhere else in this file.
proceed = False
#Get the user to input a sentence
st.write('This first iteration only allows you to evaluate countries.')
col1, col2, col3 = st.columns(3)
with col1:
    text = st.text_input('Provide a sentence you want to evaluate.', placeholder = "I like you. I love you.", key="input")
    #Use spaCy to make the sentence into a doc so we can do NLP.
    doc = nlp(st.session_state.input)
    #Evaluate the provided sentence for sentiment and probability.
    if st.session_state.input != "":
        probability, sentiment = eval_pred(text, return_all=True)
        # Named lime_results (not `lime`) so the imported `lime` module is not
        # shadowed. Presumably LIME (word, weight) pairs - see critical_words.
        options, lime_results = critical_words(st.session_state.input, options=True)
        nat_lang_explanation = construct_nlexp(text, sentiment, probability)
        st.altair_chart(lime_viz(lime_results))

        #Allow the user to pick an option to generate counterfactuals from.
        option = st.radio('Which word would you like to use to generate alternatives?', options, key = "option")
        if any(option in sublist for sublist in word_lists):
            st.write(f'You selected {option}. It matches a list.')
        elif option:
            st.write(f'You selected {option}. It does not match a list.')
            # get_def stores the user's pick in st.session_state.definition,
            # which the WordNet branch below reads.
            get_def(option)
        else:
            st.write('Awaiting your selection.')

        if st.button('Generate Alternatives'):
            if not option:
                # Nothing selected: the generators below cannot work with None.
                col1.write('Please select a word before generating alternatives.')
            elif option in list(countries.Words):
                cf_df = gen_cf_country(countries, doc, option)
                col1.write('Alternatives created.')
            elif option in list(professions.Words):
                # NOTE(review): professions reuse the country generator; confirm
                # gen_cf_country handles the professions table's columns.
                cf_df = gen_cf_country(professions, doc, option)
                col1.write('Alternatives created.')
            else:
                ant("Generating alternatives for",(option,"opt","#E0FBFB"), "with a definition of: ",(st.session_state.definition,"def","#E0FBFB"),".")
                cf_df = cf_from_wordnet_df(option,text,seed_definition=st.session_state.definition)
                col1.write('Alternatives created.')

            if not cf_df.empty:
                # text2/text3 feed the comparison shown in the second column.
                text2, text3 = get_min_max(cf_df, option)
with col2:
    # Second column: show one generated alternative (text2) with its own
    # sentiment explanation and LIME chart. text2 is the first value returned
    # by get_min_max(cf_df, option); presumably the alternative closest to the
    # original - confirm in get_min_max. It stays "" until alternatives exist.
    if text2 != "":
        # Similarity score of the displayed alternative, looked up in cf_df by its text.
        sim2 = cf_df.loc[cf_df['text'] == text2, 'similarity'].iloc[0]
        st.write(f"This alternate example is similar to {option}.")
        st.write(f" Similarity Score: {np.round(sim2, 2)}, Num Checked: {len(cf_df)}") #for QA purposes
        st.write(text2)
        # Run LIME on the alternative: up to 15 features from 2000 perturbed samples.
        exp2 = explainer.explain_instance(text2, predictor, num_features=15, num_samples=2000)
        lime_results2 = exp2.as_list()
        probability2, sentiment2 = eval_pred(text2, return_all=True)
        nat_lang_explanation = construct_nlexp(text2,sentiment2,probability2)
        st.altair_chart(lime_viz(lime_results2))
with col3:
    # Third column: scatter of every generated alternative - similarity on x,
    # model prediction on y, coloured by category and sized by the `seed`
    # column; hovering highlights the nearest point.
    if not cf_df.empty:
        single_nearest = alt.selection_single(on='mouseover', nearest=True)
        full = (
            alt.Chart(cf_df)
            .mark_circle(opacity=.5)
            .encode(
                x=alt.X('similarity:Q', scale=alt.Scale(zero=False)),
                y=alt.Y('pred:Q'),
                color=alt.Color('Categories:N', legend=alt.Legend(title="Color of Categories")),
                size=alt.Size('seed:O'),
                tooltip=('Categories', 'text', 'pred'),
            )
            .properties(width=450, height=450)
            .add_selection(single_nearest)
        )
        st.altair_chart(full)