{ "cells": [ { "cell_type": "markdown", "id": "8ea54fcd-ef4a-42cb-ae26-cbdc6f6ffc64", "metadata": { "tags": [] }, "source": [ "# Duct Tape Pipeline\n", "To explore how users may interact with interactive visualizations of counterfactuals for evolving the Interactive Model Card, we will need to first find a way to generate counterfactuals based on a given input. We want the user to be able to provide their input and direct the system to generate counterfactuals based on a part of speech that is significant to the model. The system should then provide a data frame of counterfactuals to be used in an interactive visualization. Below is an example wireframe of the experience based on previous research.\n", "\n", "![wireframe](Assets/VizNLC-Wireframe-example.png)\n", "\n", "## Goals of this notebook\n", "* Clean up the flow in the \"duct tape pipeline\".\n", "* See if I can extract the LIME list for visualization" ] }, { "cell_type": "markdown", "id": "736e6375-dd6d-4188-b8b1-92bded2bcd02", "metadata": {}, "source": [ "## Loading the libraries and models" ] }, { "cell_type": "code", "execution_count": 1, "id": "7f581785-e642-4f74-9f67-06a63820eaf2", "metadata": {}, "outputs": [], "source": [ "#Import the libraries we know we'll need for the Generator.\n", "import pandas as pd, spacy, nltk, numpy as np\n", "from spacy import displacy\n", "from spacy.matcher import Matcher\n", "#!python -m spacy download en_core_web_sm\n", "nlp = spacy.load(\"en_core_web_md\")\n", "lemmatizer = nlp.get_pipe(\"lemmatizer\")\n", "\n", "#Import the libraries to support the model, predictions, and LIME.\n", "from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline\n", "import lime\n", "import torch\n", "import torch.nn.functional as F\n", "from lime.lime_text import LimeTextExplainer\n", "\n", "#Import the libraries for generating interactive visualizations.\n", "import altair as alt" ] }, { "cell_type": "code", "execution_count": 2, "id": "cbe2b292-e33e-4915-8e61-bba5327fb643", "metadata": {}, "outputs": [], "source": [ "#Defining all necessary variables and instances.\n", "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")\n", "model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")\n", "class_names = ['negative', 'positive']\n", "explainer = LimeTextExplainer(class_names=class_names)" ] }, { "cell_type": "code", "execution_count": 3, "id": "197c3e26-0fdf-49c6-9135-57f1fd55d3e3", "metadata": {}, "outputs": [], "source": [ "#Defining a Predictor required for LIME to function.\n", "def predictor(texts):\n", " outputs = model(**tokenizer(texts, return_tensors=\"pt\", padding=True))\n", " probas = F.softmax(outputs.logits, dim=1).detach().numpy()\n", " return probas" ] }, { "cell_type": "code", "execution_count": 4, "id": "013af6ac-f7d1-41d2-a601-b0f9a4870815", "metadata": {}, "outputs": [], "source": [ "#Instantiate a matcher and use it to test some patterns.\n", "matcher = Matcher(nlp.vocab)\n", "pattern = [{\"ENT_TYPE\": {\"IN\":[\"NORP\",\"GPE\"]}}]\n", "matcher.add(\"proper_noun\", [pattern])\n", "pattern_test = [{\"DEP\": \"amod\"},{\"DEP\":\"attr\"},{\"TEXT\":\"-\"},{\"DEP\":\"attr\",\"OP\":\"+\"}]\n", "matcher.add(\"amod_attr\",[pattern_test])\n", "pattern_an = [{\"DEP\": \"amod\"},{\"POS\":{\"IN\":[\"NOUN\",\"PROPN\"]}},{\"DEP\":{\"NOT_IN\":[\"attr\"]}}]\n", "matcher.add(\"amod_noun\", [pattern_an])" ] }, { "cell_type": "code", "execution_count": 5, "id": 
"f6ac821d-7b56-446e-b9ca-42a5f5afd198", "metadata": {}, "outputs": [], "source": [ "def match_this(matcher, doc):\n", " matches = matcher(doc)\n", " for match_id, start, end in matches:\n", " matched_span = doc[start:end]\n", " print(f\"Mached {matched_span.text} by the rule {nlp.vocab.strings[match_id]}.\")\n", " return matches" ] }, { "cell_type": "markdown", "id": "c23d48c4-f5ab-4428-9244-0786e9903a8e", "metadata": { "tags": [] }, "source": [ "## Building the Duct-Tape Pipeline cell-by-cell" ] }, { "cell_type": "code", "execution_count": 6, "id": "a373fc00-401a-4def-9f09-de73d485ac13", "metadata": {}, "outputs": [], "source": [ "gender = [\"man\", \"woman\",\"girl\",\"boy\",\"male\",\"female\",\"husband\",\"wife\",\"girlfriend\",\"boyfriend\",\"brother\",\"sister\",\"aunt\",\"uncle\",\"grandma\",\"grandpa\",\"granny\",\"granps\",\"grandmother\",\"grandfather\",\"mama\",\"dada\",\"Ma\",\"Pa\",\"lady\",\"gentleman\"]" ] }, { "cell_type": "code", "execution_count": 7, "id": "8b02a5d4-8a6b-4e5e-8f15-4f9182fe341f", "metadata": {}, "outputs": [], "source": [ "def select_crit(document, options=False, limelist=False):\n", " '''This function is meant to select the critical part of a sentence. Critical, in this context means\n", " the part of the sentence that is either: A) a PROPN from the correct entity group; B) an ADJ associated with a NOUN;\n", " C) a NOUN that represents gender. It also checks this against what the model thinks is important if the user defines \"options\" as \"LIME\" or True.'''\n", " chunks = list(document.noun_chunks)\n", " pos_options = []\n", " lime_options = []\n", " \n", " #Identify what the model cares about.\n", " if options:\n", " exp = explainer.explain_instance(document.text, predictor, num_features=15, num_samples=2000)\n", " lime_results = exp.as_list()\n", " #prints the results from lime for QA.\n", " if limelist == True:\n", " print(lime_results)\n", " for feature in lime_results:\n", " lime_options.append(feature[0])\n", " lime_results = pd.DataFrame(lime_results, columns=[\"Word\",\"Weight\"])\n", " \n", " #Identify what we care about \"parts of speech\"\n", " for chunk in chunks:\n", " #The use of chunk[-1] is due to testing that it appears to always match the root\n", " root = chunk[-1]\n", " #This currently matches to a list I've created. I don't know the best way to deal with this so I'm leaving it as is for the moment.\n", " if root.text.lower() in gender:\n", " cur_values = [token.text for token in chunk if token.pos_ in [\"NOUN\",\"ADJ\"]]\n", " if (all(elem in lime_options for elem in cur_values) and ((options == \"LIME\") or (options == True))) or ((options != \"LIME\") and (options != True)):\n", " pos_options.extend(cur_values)\n", " #print(f\"From {chunk.text}, {cur_values} added to pos_options due to gender.\") #for QA\n", " #This is currently set to pick up entities in a particular set of groups (which I recently expanded). 
Should it just pick up all named entities?\n", "        elif root.ent_type_ in [\"GPE\",\"NORP\",\"DATE\",\"EVENT\"]:\n", "            cur_values = []\n", "            if (len(chunk) > 1) and (chunk[-2].dep_ == \"compound\"):\n", "                #creates the compound element of the noun\n", "                compound = [x.text for x in chunk if x.dep_ == \"compound\"]\n", "                print(f\"This is the contents of {compound} and it is {all(elem in lime_options for elem in compound)} that all elements are present in {lime_options}.\") #for QA\n", "                #checks to see all elements in the compound are important to the model or use the compound if not checking importance.\n", "                if (all(elem in lime_options for elem in compound) and ((options == \"LIME\") or (options == True))) or ((options != \"LIME\") and (options != True)):\n", "                    #creates a span for the entirety of the compound noun and adds it to the list.\n", "                    span = -1 * (1 + len(compound))\n", "                    pos_options.append(chunk[span:].text)\n", "                    #also surface any adjectives in the chunk as options.\n", "                    pos_options += [token.text for token in chunk if token.pos_ == \"ADJ\"]\n", "            else:\n", "                cur_values = [token.text for token in chunk if (token.ent_type_ in [\"GPE\",\"NORP\",\"DATE\",\"EVENT\"]) or (token.pos_ == \"ADJ\")]\n", "                if (all(elem in lime_options for elem in cur_values) and ((options == \"LIME\") or (options == True))) or ((options != \"LIME\") and (options != True)):\n", "                    pos_options.extend(cur_values)\n", "                    print(f\"From {chunk.text}, {cur_values} and {pos_options} added to pos_options due to entity recognition.\") #for QA\n", "        elif len(chunk) > 1:\n", "            cur_values = [token.text for token in chunk if token.pos_ in [\"NOUN\",\"ADJ\"]]\n", "            if (all(elem in lime_options for elem in cur_values) and ((options == \"LIME\") or (options == True))) or ((options != \"LIME\") and (options != True)):\n", "                pos_options.extend(cur_values)\n", "                print(f\"From {chunk.text}, {cur_values} added to pos_options due to wildcard.\") #for QA\n", "        else:\n", "            print(f\"No options added for \\'{chunk.text}\\' \")\n", "    \n", "    \n", "    #Return the correct set of options based on user input, defaults to POS for simplicity.\n", "    if options == \"LIME\":\n", "        return pos_options, lime_results\n", "    else:\n", "        return pos_options" ] }, { "cell_type": "code", "execution_count": 8, "id": "d43e202e-64b9-4cea-b117-82492c9ee5f4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "From This film, ['film'] added to pos_options due to wildcard.\n", "From Iraq, ['Iraq'] and ['film', 'Iraq'] added to pos_options due to entity recognition.\n" ] } ], "source": [ "#Test to make sure all three options work\n", "text4 = \"This film was filmed in Iraq.\"\n", "doc4 = nlp(text4)\n", "lime4, limedf = select_crit(doc4,options=\"LIME\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "a0e55a24-65df-429e-a0cd-8daf91a5d242", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.LayerChart(...)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "single_nearest = alt.selection_single(on='mouseover', nearest=True)\n", "viz = alt.Chart(limedf).encode(\n", " alt.X('Weight:Q', scale=alt.Scale(domain=(-1, 1))),\n", " alt.Y('Word:N', sort='x', axis=None),\n", " color=alt.Color(\"Weight\", scale=alt.Scale(scheme='blueorange', domain=[0], type=\"threshold\", range='diverging'), legend=None),\n", " tooltip = (\"Word\",\"Weight\")\n", ").mark_bar().properties(title =\"Importance of individual words\")\n", "\n", "text = viz.mark_text(\n", " fill=\"black\",\n", " align='right',\n", " baseline='middle'\n", ").encode(\n", " text='Word:N'\n", ")\n", "limeplot = alt.LayerChart(layer=[viz,text], width = 300).configure_axis(grid=False).configure_view(strokeWidth=0)\n", "limeplot" ] }, { "cell_type": "markdown", "id": "bf0512b6-336e-4842-9bde-34e03a1ca7c6", "metadata": {}, "source": [ "### Testing predictions and visualization\n", "Here I will attempt to import the model from huggingface, generate predictions for each of the sentences, and then visualize those predictions into a dot plot. If I can get this to work then I will move on to testing a full pipeline for letting the user pick which part of the sentence they wish to generate counterfactuals for." ] }, { "cell_type": "code", "execution_count": 10, "id": "74c639bb-e74a-4a46-8047-3552265ae6a4", "metadata": {}, "outputs": [], "source": [ "#Discovering that there's a pipeline specifically to provide scores. \n", "#I used it to get a list of lists of dictionaries that I can then manipulate to calculate the proper prediction score.\n", "pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)" ] }, { "cell_type": "code", "execution_count": 11, "id": "8726a284-99bd-47f1-9756-1c3ae603db10", "metadata": {}, "outputs": [], "source": [ "def eval_pred(text):\n", " '''A basic function for evaluating the prediction from the model and turning it into a visualization friendly number.'''\n", " preds = pipe(text)\n", " neg_score = preds[0][0]['score']\n", " pos_score = preds[0][1]['score']\n", " if pos_score >= neg_score:\n", " return pos_score\n", " if neg_score >= pos_score:\n", " return -1 * neg_score" ] }, { "cell_type": "code", "execution_count": 12, "id": "f38f5061-f30a-4c81-9465-37951c3ad9f4", "metadata": {}, "outputs": [], "source": [ "def eval_pred_test(text, return_all = False):\n", " '''A basic function for evaluating the prediction from the model and turning it into a visualization friendly number.'''\n", " preds = pipe(text)\n", " neg_score = -1 * preds[0][0]['score']\n", " sent_neg = preds[0][0]['label']\n", " pos_score = preds[0][1]['score']\n", " sent_pos = preds[0][1]['label']\n", " prediction = 0\n", " sentiment = ''\n", " if pos_score > abs(neg_score):\n", " prediction = pos_score\n", " sentiment = sent_pos\n", " elif abs(neg_score) > pos_score:\n", " prediction = neg_score\n", " sentiment = sent_neg\n", " \n", " if return_all:\n", " return prediction, sentiment\n", " else:\n", " return prediction" ] }, { "cell_type": "markdown", "id": "8b349a87-fe83-4045-a63a-d054489bb461", "metadata": {}, "source": [ "## Load the dummy countries I created to test generating counterfactuals\n", "I decided to test the pipeline with a known problem space. 
Taking the text from Aurélien Géron's observations on Twitter, I built a small-scale test, using what I had learned, to prove that we can identify a particular part of speech, use it to generate counterfactuals, and then build a visualization from it." ] }, { "cell_type": "code", "execution_count": 13, "id": "46ab3332-964c-449f-8cef-a9ff7df397a4", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
CountryContinent
0AlgeriaAfrica
1AngolaAfrica
2BeninAfrica
3BotswanaAfrica
4BurkinaAfrica
\n", "
" ], "text/plain": [ " Country Continent\n", "0 Algeria Africa\n", "1 Angola Africa\n", "2 Benin Africa\n", "3 Botswana Africa\n", "4 Burkina Africa" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#load my test data from https://github.com/dbouquin/IS_608/blob/master/NanosatDB_munging/Countries-Continents.csv\n", "df = pd.read_csv(\"Assets/Countries/countries.csv\")\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 14, "id": "51c75894-80af-4625-8ce8-660e500b496b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "From This film, ['film'] added to pos_options due to wildcard.\n", "From Iraq, ['Iraq'] and ['film', 'Iraq'] added to pos_options due to entity recognition.\n", "['film', 'Iraq']\n" ] }, { "data": { "text/plain": [ "'Iraq'" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#Note: we will need to build the function that lets the user choose from the options available. For now I have hard coded it as \"selection\", from \"user_options\".\n", "user_options = select_crit(doc4)\n", "print(user_options)\n", "selection = user_options[1]\n", "selection" ] }, { "cell_type": "code", "execution_count": 15, "id": "3d6419f1-bf7d-44bc-afb8-ac26ef9002df", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
CountryContinenttextpredictionseed
0AlgeriaAfricaThis film was filmed in Algeria.0.806454alternative
1AngolaAfricaThis film was filmed in Angola.-0.775854alternative
2BeninAfricaThis film was filmed in Benin.0.962272alternative
3BotswanaAfricaThis film was filmed in Botswana.0.785837alternative
4BurkinaAfricaThis film was filmed in Burkina.0.872980alternative
\n", "
" ], "text/plain": [ " Country Continent text prediction \\\n", "0 Algeria Africa This film was filmed in Algeria. 0.806454 \n", "1 Angola Africa This film was filmed in Angola. -0.775854 \n", "2 Benin Africa This film was filmed in Benin. 0.962272 \n", "3 Botswana Africa This film was filmed in Botswana. 0.785837 \n", "4 Burkina Africa This film was filmed in Burkina. 0.872980 \n", "\n", " seed \n", "0 alternative \n", "1 alternative \n", "2 alternative \n", "3 alternative \n", "4 alternative " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#Create a function that generates the counterfactuals within a data frame.\n", "def gen_cf_country(df,document,selection):\n", " df['text'] = df.Country.apply(lambda x: document.text.replace(selection,x))\n", " df['prediction'] = df.text.apply(eval_pred_test)\n", " #added this because I think it will make the end results better if we ensure the seed is in the data we generate counterfactuals from.\n", " df['seed'] = df.Country.apply(lambda x: 'seed' if x == selection else 'alternative')\n", " return df\n", "\n", "df = gen_cf_country(df,doc4,selection)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 16, "id": "ecb9dd41-2fab-49bd-bae5-30300ce39e41", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.Chart(...)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "single_nearest = alt.selection_single(on='mouseover', nearest=True)\n", "full = alt.Chart(df).encode(\n", " alt.X('Continent:N'), # specify nominal data\n", " alt.Y('prediction:Q'), # specify quantitative data\n", " color=alt.Color('seed:N', legend=alt.Legend(title=\"Seed or Alternative\")),\n", " size='seed:N',\n", " tooltip=('Country','prediction')\n", ").mark_circle(opacity=.5).properties(width=300).add_selection(single_nearest)\n", "\n", "full" ] }, { "cell_type": "code", "execution_count": 17, "id": "56bc30d7-03a5-43ff-9dfe-878197628305", "metadata": {}, "outputs": [], "source": [ "df2 = df.nlargest(5, 'prediction')\n", "df3 = df.nsmallest(5, 'prediction')\n", "frames = [df2,df3]\n", "results = pd.concat(frames)" ] }, { "cell_type": "code", "execution_count": 18, "id": "1610bb48-c9b9-4bee-bcb5-999886acb9e3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.Chart(...)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bar = alt.Chart(results).encode( \n", " alt.X('prediction:Q'), \n", " alt.Y('Country:N', sort=\"-x\"),\n", " color=alt.Color('seed:N', legend=alt.Legend(title=\"Seed or Alternative\")),\n", " size='seed:N',\n", " tooltip=('Country','prediction')\n", ").mark_circle().properties(width=300).add_selection(single_nearest)\n", "\n", "bar" ] }, { "cell_type": "code", "execution_count": 34, "id": "96cd0798-5ac5-4ede-8373-e8ed71ab07b3", "metadata": {}, "outputs": [], "source": [ "def critical_words(document, options=False):\n", " '''This function is meant to select the critical part of a sentence. Critical, in this context means\n", " the part of the sentence that is either: A) a PROPN from the correct entity group; B) an ADJ associated with a NOUN;\n", " C) a NOUN that represents gender. It also checks this against what the model thinks is important if the user defines \"options\" as \"LIME\" or True.'''\n", " if type(document) is not spacy.tokens.doc.Doc:\n", " document = nlp(document)\n", " chunks = list(document.noun_chunks)\n", " pos_options = []\n", " lime_options = []\n", " \n", " #Identify what the model cares about.\n", " if options:\n", " exp = explainer.explain_instance(document.text, predictor, num_features=15, num_samples=2000)\n", " lime_results = exp.as_list()\n", " for feature in lime_results:\n", " lime_options.append(feature[0])\n", " lime_results = pd.DataFrame(lime_results, columns=[\"Word\",\"Weight\"])\n", " \n", " #Identify what we care about \"parts of speech\". The first section focuses on NOUNs and related ADJ.\n", " for chunk in chunks:\n", " #The use of chunk[-1] is due to testing that it appears to always match the root\n", " root = chunk[-1]\n", " #This currently matches to a list I've created. 
I don't know the best way to deal with this so I'm leaving it as is for the moment.\n", "        if root.ent_type_:\n", "            cur_values = []\n", "            if (len(chunk) > 1) and (chunk[-2].dep_ == \"compound\"):\n", "                #creates the compound element of the noun\n", "                compound = [x.text for x in chunk if x.dep_ == \"compound\"]\n", "                print(f\"This is the contents of {compound} and it is {all(elem in lime_options for elem in compound)} that all elements are present in {lime_options}.\") #for QA\n", "                #checks to see all elements in the compound are important to the model or use the compound if not checking importance.\n", "                if (all(elem in lime_options for elem in compound) and (options is True)) or (options is False):\n", "                    #creates a span for the entirety of the compound noun and adds it to the list.\n", "                    span = -1 * (1 + len(compound))\n", "                    pos_options.append(chunk[span:].text)\n", "                    #also surface any adjectives in the chunk as options.\n", "                    pos_options += [token.text for token in chunk if token.pos_ == \"ADJ\"]\n", "                else:\n", "                    print(f\"The elements in {compound} could not be added to the final list because they are not all relevant to the model.\")\n", "            else:\n", "                cur_values = [token.text for token in chunk if (token.ent_type_) or (token.pos_ == \"ADJ\")]\n", "                if (all(elem in lime_options for elem in cur_values) and (options is True)) or (options is False):\n", "                    pos_options.extend(cur_values)\n", "                    print(f\"From {chunk.text}, {cur_values} added to pos_options due to entity recognition.\") #for QA\n", "        elif len(chunk) >= 1:\n", "            cur_values = [token.text for token in chunk if token.pos_ in [\"NOUN\",\"ADJ\"]]\n", "            if (all(elem in lime_options for elem in cur_values) and (options is True)) or (options is False):\n", "                pos_options.extend(cur_values)\n", "                print(f\"From {chunk.text}, {cur_values} added to pos_options due to wildcard.\") #for QA\n", "        else:\n", "            print(f\"No options added for \\'{chunk.text}\\' \")\n", "    # Here I am going to try to pick up personal pronouns (which refer to people) and adjectival complements.\n", "    for token in document:\n", "        if (token.text not in pos_options) and ((token.text in lime_options) or (options == False)):\n", "            #print(f\"executed {token.text} with {token.pos_} and {token.dep_}\") #QA\n", "            if (token.pos_ == \"ADJ\") and (token.dep_ in [\"acomp\",\"conj\"]):\n", "                pos_options.append(token.text)\n", "            elif (token.pos_ == \"PRON\") and (token.morph.get(\"PronType\")[0] == \"Prs\"):\n", "                pos_options.append(token.text)\n", "    \n", "    #Return the correct set of options based on user input, defaults to POS for simplicity.\n", "    if options:\n", "        return pos_options, lime_results\n", "    else:\n", "        return pos_options" ] }, { "cell_type": "code", "execution_count": 20, "id": "b04e7783-e51b-49b0-8165-afe1d5a1c576", "metadata": {}, "outputs": [], "source": [ "#Testing new code\n", "a = \"People are fat and lazy.\"\n", "b = \"I think she is beautiful.\"\n", "doca = nlp(a)\n", "docb = nlp(b)" ] }, { "cell_type": "code", "execution_count": 21, "id": "0a6bc521-9282-41ad-82c9-29e447d77635", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No options added for 'People' \n" ] }, { "data": { "text/plain": [ "['fat', 'lazy']" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "optsa, limea = critical_words(doca, True)\n", "optsa" ] }, { "cell_type": "code", "execution_count": 22, "id": "042e94d3-65a5-4a20-b69a-96ec3296d7d4", "metadata": {}, "outputs": [], "source": [ "def lime_viz(df):\n", "    single_nearest = alt.selection_single(on='mouseover', nearest=True)\n", "    
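# Layer a bar chart of the LIME weights (diverging color around zero) with the word labels drawn as text.\n", "    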
viz = alt.Chart(df).encode(\n", " alt.X('Weight:Q', scale=alt.Scale(domain=(-1, 1))),\n", " alt.Y('Word:N', sort='x', axis=None),\n", " color=alt.Color(\"Weight\", scale=alt.Scale(scheme='blueorange', domain=[0], type=\"threshold\", range='diverging'), legend=None),\n", " tooltip = (\"Word\",\"Weight\")\n", " ).mark_bar().properties(title =\"Importance of individual words\")\n", "\n", " text = viz.mark_text(\n", " fill=\"black\",\n", " align='right',\n", " baseline='middle'\n", " ).encode(\n", " text='Word:N'\n", " )\n", " limeplot = alt.LayerChart(layer=[viz,text], width = 300).configure_axis(grid=False).configure_view(strokeWidth=0)\n", " return limeplot" ] }, { "cell_type": "code", "execution_count": 23, "id": "924eeea8-1d5d-4fe7-8308-164521919269", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No options added for 'I' \n", "From a white woman, ['white', 'woman'] added to pos_options due to wildcard.\n", "From the street, ['street'] added to pos_options due to wildcard.\n", "From an asian man, ['asian', 'man'] added to pos_options due to wildcard.\n" ] }, { "data": { "text/plain": [ "['white', 'woman', 'street', 'asian', 'man', 'I']" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test8 = \"I saw a white woman walking down the street with an asian man.\"\n", "opts8, lime8 = critical_words(test8,True)\n", "opts8" ] }, { "cell_type": "code", "execution_count": 24, "id": "734366df-ad99-4d80-87e1-51793e150681", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.LayerChart(...)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lime_viz(lime8)" ] }, { "cell_type": "code", "execution_count": 25, "id": "816e1c4b-7f02-41b1-b430-2f3750ae6c4a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No options added for 'I' \n", "From a white woman, ['white', 'woman'] added to pos_options due to wildcard.\n", "From the street, ['street'] added to pos_options due to wildcard.\n", "From an asian man, ['asian', 'man'] added to pos_options due to wildcard.\n" ] } ], "source": [ "probability, sentiment = eval_pred_test(test8, return_all=True)\n", "options, lime = critical_words(test8,options=True)" ] }, { "cell_type": "code", "execution_count": 38, "id": "a437a4eb-73b3-4b3c-a719-8dde2ad6dd3c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "From I, [] added to pos_options due to wildcard.\n", "From men, ['men'] added to pos_options due to wildcard.\n", "From women, ['women'] added to pos_options due to wildcard.\n", "From the same respect, ['same', 'respect'] added to pos_options due to wildcard.\n" ] } ], "source": [ "bug = \"I find men and women deserve the same respect.\"\n", "options = critical_words(bug)" ] }, { "cell_type": "code", "execution_count": 29, "id": "8676defd-0908-4218-a1d6-218de3fb7119", "metadata": {}, "outputs": [], "source": [ "bug_doc = nlp(bug)" ] }, { "cell_type": "code", "execution_count": 35, "id": "21b9e39b-2fcd-4c6f-8fe6-0d571cd79cca", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "I\n", "PRON\n", "a man\n", "NOUN\n", "woman\n", "NOUN\n", "the same respect\n", "NOUN\n" ] } ], "source": [ "for chunk in bug_doc.noun_chunks:\n", " print(chunk.text)\n", " print(chunk[-1].pos_)" ] }, { "cell_type": "code", "execution_count": null, "id": "38279d2d-e763-4329-a65e-1a67d6f5ebb8", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }