diff --git "a/cf-gen-pipeline.ipynb" "b/cf-gen-pipeline.ipynb" deleted file mode 100644--- "a/cf-gen-pipeline.ipynb" +++ /dev/null @@ -1,1279 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "7a3ff9c9-8642-43b9-b1f3-302c227d6acd", - "metadata": {}, - "outputs": [], - "source": [ - "#Import the libraries we know we'll need for the Generator.\n", - "import pandas as pd, spacy, nltk, numpy as np, re, ssl\n", - "from spacy import displacy\n", - "from spacy.matcher import Matcher\n", - "from nltk.corpus import wordnet\n", - "#!python -m spacy download en_core_web_md\n", - "nlp = spacy.load(\"en_core_web_md\")\n", - "lemmatizer = nlp.get_pipe(\"lemmatizer\")\n", - "\n", - "#Import the libraries to support the model, predictions, and LIME.\n", - "from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline\n", - "import lime\n", - "import torch\n", - "import torch.nn.functional as F\n", - "from lime.lime_text import LimeTextExplainer\n", - "\n", - "#Import the libraries for generating interactive visualizations.\n", - "import altair as alt" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "730cc6fd-f125-42be-ba42-9a7741f82ef5", - "metadata": {}, - "outputs": [], - "source": [ - "#Defining all necessary variables and instances.\n", - "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")\n", - "model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")\n", - "pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)\n", - "class_names = ['negative', 'positive']\n", - "explainer = LimeTextExplainer(class_names=class_names)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "1a40f224-25ce-4b57-88fb-404cccbc878b", - "metadata": {}, - "outputs": [], - "source": [ - "#Defining a Predictor required for LIME to function.\n", - "def predictor(texts):\n", - " outputs = model(**tokenizer(texts, return_tensors=\"pt\", padding=True))\n", - " probas = F.softmax(outputs.logits, dim=1).detach().numpy()\n", - " return probas" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "201c4792-868b-4e6a-a8d2-f6ffb2f8971a", - "metadata": {}, - "outputs": [], - "source": [ - "# A simple function to pull synonyms and antonyms using spacy's POS\n", - "def syn_ant(word,POS=False,human=True):\n", - " pos_options = ['NOUN','VERB','ADJ','ADV']\n", - " synonyms = [] \n", - " antonyms = []\n", - " #WordNet hates spaces so you have to remove them\n", - " if \" \" in word:\n", - " word = word.replace(\" \", \"_\")\n", - " \n", - " if POS in pos_options:\n", - " for syn in wordnet.synsets(word, pos=getattr(wordnet, POS)): \n", - " for l in syn.lemmas(): \n", - " current = l.name()\n", - " if human:\n", - " current = re.sub(\"_\",\" \",current)\n", - " synonyms.append(current) \n", - " if l.antonyms():\n", - " for ant in l.antonyms():\n", - " cur_ant = ant.name()\n", - " if human:\n", - " cur_ant = re.sub(\"_\",\" \",cur_ant)\n", - " antonyms.append(cur_ant)\n", - " else: \n", - " for syn in wordnet.synsets(word): \n", - " for l in syn.lemmas(): \n", - " current = l.name()\n", - " if human:\n", - " current = re.sub(\"_\",\" \",current)\n", - " synonyms.append(current) \n", - " if l.antonyms():\n", - " for ant in l.antonyms():\n", - " cur_ant = ant.name()\n", - " if human:\n", - " cur_ant = re.sub(\"_\",\" \",cur_ant)\n", - " antonyms.append(cur_ant)\n", - " synonyms = 
list(set(synonyms))\n", - " antonyms = list(set(antonyms))\n", - " return synonyms, antonyms" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "89aaea6d-a4b7-4995-85b5-625e6db292da", - "metadata": {}, - "outputs": [], - "source": [ - "# Builds a list dynamically from WordNet using NLTK.\n", - "def wordnet_list(word,POS=False):\n", - " word = word.lower()\n", - " pos_options = ['NOUN','VERB','ADJ','ADV']\n", - " synonyms, antonyms = syn_ant(word,POS,False)\n", - " #print(synonyms, antonyms)\n", - " base = []\n", - " final = [word]\n", - " #WordNet hates spaces so you have to remove them\n", - " m_word = word.replace(\" \", \"_\")\n", - " \n", - " if POS in pos_options:\n", - " for syn in wordnet.synsets(m_word, pos=getattr(wordnet, POS)):\n", - " base.extend(syn.hyponyms())\n", - " base.append(syn)\n", - " \n", - " if len(synonyms) > 0:\n", - " for w in synonyms:\n", - " w = w.replace(\" \",\"_\")\n", - " for syn in wordnet.synsets(w, pos=getattr(wordnet, POS)):\n", - " base.extend(syn.hyponyms())\n", - " base.append(syn)\n", - " if len(antonyms) > 0:\n", - " for a in antonyms:\n", - " a = a.replace(\" \",\"_\")\n", - " for syn in wordnet.synsets(a, pos=getattr(wordnet, POS)):\n", - " base.extend(syn.hyponyms())\n", - " base.append(syn)\n", - " else:\n", - " for syn in wordnet.synsets(m_word):\n", - " base.extend(syn.hyponyms())\n", - " base.append(syn)\n", - " \n", - " if len(synonyms) > 0:\n", - " for w in synonyms:\n", - " w = w.replace(\" \",\"_\")\n", - " for syn in wordnet.synsets(w):\n", - " base.extend(syn.hyponyms())\n", - " base.append(syn)\n", - " if len(antonyms) > 0:\n", - " for a in antonyms:\n", - " a = a.replace(\" \",\"_\")\n", - " for syn in wordnet.synsets(a):\n", - " base.extend(syn.hyponyms())\n", - " base.append(syn)\n", - " base = list(set(base))\n", - " for b in base:\n", - " cur_words = []\n", - " cur_words.extend([re.sub(\"_\",\" \",lemma.name()) for lemma in b.lemmas()])\n", - " final.extend(cur_words)\n", - "\n", - " \n", - " \n", - " final = list(set(final)) \n", - " return final" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "01c74ce5-8a31-4a28-bb12-655349fee453", - "metadata": {}, - "outputs": [], - "source": [ - "def eval_pred_test(text, return_all = False):\n", - " '''A basic function for evaluating the prediction from the model and turning it into a visualization friendly number.'''\n", - " preds = pipe(text)\n", - " neg_score = -1 * preds[0][0]['score']\n", - " sent_neg = preds[0][0]['label']\n", - " pos_score = preds[0][1]['score']\n", - " sent_pos = preds[0][1]['label']\n", - " prediction = 0\n", - " sentiment = ''\n", - " if pos_score > abs(neg_score):\n", - " prediction = pos_score\n", - " sentiment = sent_pos\n", - " elif abs(neg_score) > pos_score:\n", - " prediction = neg_score\n", - " sentiment = sent_neg\n", - " \n", - " if return_all:\n", - " return prediction, sentiment\n", - " else:\n", - " return prediction" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "7864d1db-399f-4d23-a036-30da34fe3805", - "metadata": {}, - "outputs": [], - "source": [ - "def cf_from_wordnet_list(seed,text):\n", - " seed_token = nlp(seed)\n", - " seed_POS = seed_token[0].pos_\n", - " #print(seed_POS)\n", - " words = wordnet_list(seed,seed_POS)\n", - " \n", - " df = pd.DataFrame()\n", - " df[\"Words\"] = words\n", - " df[\"Sentences\"] = df.Words.apply(lambda x: re.sub(r'\\b'+seed+r'\\b',x,text))\n", - " df[\"Similarity\"] = df.Words.apply(lambda x: seed_token[0].similarity(nlp(x)[0]))\n", - " df = 
df[df.Similarity > 0].reset_index(drop=True)\n",
- "    df[\"Prediction\"] = df.Sentences.apply(eval_pred_test)\n",
- "    # Flag the seed row so downstream views can separate it from the generated alternatives.\n",
- "    df['Seed'] = df.Words.apply(lambda x: 'Seed' if x.lower() == seed else 'Alternative')\n",
- "    return df"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "527f037b-79a1-4636-88ba-efd06b9cc20d",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'This film was filmed in Iraq.'"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "seed = \"film\"\n",
- "text = f\"This {seed} was filmed in Iraq.\"\n",
- "text"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "dc6b503c-0e7e-438e-8079-14479de4e246",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "[displaCy dependency parse of \"I met a naked doctor.\": nsubj(met, I), det(doctor, a), amod(doctor, naked), dobj(met, doctor)]"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "test = \"I met a naked doctor.\"\n",
- "testdoc = nlp(test)\n",
- "displacy.render(testdoc, style=\"dep\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "id": "04509f16-9098-4b67-9fea-3a256d1debe9",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'adjectival modifier'"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "spacy.explain(\"amod\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "id": "ebe15fd4-4c1e-4a10-a886-6698620b3b72",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/var/folders/lx/xt9qnk8569n7xy_d7knh3npr0000gp/T/ipykernel_6702/625873274.py:10: UserWarning: [W008] Evaluating Token.similarity based on empty vectors.\n",
- "  df[\"Similarity\"] = df.Words.apply(lambda x: seed_token[0].similarity(nlp(x)[0]))\n"
- ]
- }
- ],
- "source": [
- "cf_df = cf_from_wordnet_list(seed,text)"
- ]
- },
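- {
- "cell_type": "code",
- "execution_count": null,
- "id": "editor-sketch-has-vector",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Editor's sketch (unexecuted; the helper name is hypothetical): the W008 warning\n",
- "# above comes from candidate words that have no vector in en_core_web_md, for which\n",
- "# similarity() silently returns 0.0. The Similarity > 0 filter only drops such rows\n",
- "# after the fact; checking spaCy's has_vector first avoids the warning entirely.\n",
- "def first_token_has_vector(word):\n",
- "    return nlp(word)[0].has_vector\n",
- "\n",
- "# cf_df = cf_df[cf_df.Words.apply(first_token_has_vector)].reset_index(drop=True)"
- ]
- },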
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
WordsSentencesSimilarityPredictionSeed
0dioramaThis diorama was filmed in Iraq.0.1270230.793419Alternative
1longshotThis longshot was filmed in Iraq.0.050408-0.991559Alternative
2musical comedyThis musical comedy was filmed in Iraq.1.0000000.910803Alternative
3characterisationThis characterisation was filmed in Iraq.0.216481-0.987333Alternative
4PolaroidThis Polaroid was filmed in Iraq.0.171456-0.979913Alternative
\n", - "
" - ], - "text/plain": [ - " Words Sentences Similarity \\\n", - "0 diorama This diorama was filmed in Iraq. 0.127023 \n", - "1 longshot This longshot was filmed in Iraq. 0.050408 \n", - "2 musical comedy This musical comedy was filmed in Iraq. 1.000000 \n", - "3 characterisation This characterisation was filmed in Iraq. 0.216481 \n", - "4 Polaroid This Polaroid was filmed in Iraq. 0.171456 \n", - "\n", - " Prediction Seed \n", - "0 0.793419 Alternative \n", - "1 -0.991559 Alternative \n", - "2 0.910803 Alternative \n", - "3 -0.987333 Alternative \n", - "4 -0.979913 Alternative " - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cf_df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "6d133205-08ad-43c8-8140-ece8d5943a7f", - "metadata": {}, - "outputs": [], - "source": [ - "def max_min(df):\n", - " maximum = df[df.Words != \"girl\"].Similarity.max()\n", - " text3 = df.loc[df['Similarity'] == maximum, 'Words'].iloc[0]\n", - " minimum = df.Similarity.min()\n", - " text2 = df.loc[df['Similarity'] == minimum, 'Words'].iloc[0]\n", - " return text2, text3" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "8311596b-10cf-4915-82f2-5e3e50000436", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "
\n", - "" - ], - "text/plain": [ - "alt.Chart(...)" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "single_nearest = alt.selection_single(on='mouseover', nearest=True)\n", - "full = alt.Chart(cf_df).encode(\n", - " alt.X('Similarity:Q'), # specify nominal data\n", - " alt.Y('Prediction:Q'), # specify quantitative data\n", - " color=alt.Color('Seed:N', legend=alt.Legend(title=\"Seed or Alternative\")),\n", - " size='Seed:N',\n", - " tooltip=('Words','Prediction','Similarity')\n", - ").mark_circle(opacity=.5).properties(width=300).add_selection(single_nearest)\n", - "\n", - "full" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "1898eb99-22fc-4c6f-b918-1f972c4edb9b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "
\n", - "" - ], - "text/plain": [ - "alt.Chart(...)" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df2 = cf_df.nlargest(5, 'Prediction')\n", - "df3 = cf_df.nsmallest(5, 'Prediction')\n", - "df4 = cf_df[cf_df.Seed == \"Seed\"]\n", - "frames = [df2,df3,df4]\n", - "results = pd.concat(frames)\n", - "\n", - "bar = alt.Chart(results).encode( \n", - " alt.X('Prediction:Q'), \n", - " alt.Y('Words:N', sort=\"-x\"),\n", - " color=alt.Color('Seed:N', legend=alt.Legend(title=\"Seed or Alternative\")),\n", - " size='Seed:N',\n", - " tooltip=('Words','Prediction','Similarity')\n", - ").mark_circle().properties(width=300).add_selection(single_nearest)\n", - "\n", - "bar" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "52a4fcfa-80d9-456e-a474-8192ac1c02e8", - "metadata": {}, - "outputs": [], - "source": [ - "# Builds a list dynamically from WordNet using NLTK.\n", - "def wordnet_df(word,POS=False):\n", - " pos_options = ['NOUN','VERB','ADJ','ADV']\n", - " synonyms, antonyms = syn_ant(word,POS,False)\n", - " words = []\n", - " cats = []\n", - " #WordNet hates spaces so you have to remove them\n", - " m_word = word.replace(\" \", \"_\")\n", - " \n", - " if POS in pos_options:\n", - " for syn in wordnet.synsets(m_word, pos=getattr(wordnet, POS)):\n", - " cur_lemmas = syn.lemmas()\n", - " hypos = syn.hyponyms()\n", - " for hypo in hypos:\n", - " cur_lemmas.extend(hypo.lemmas())\n", - " for lemma in cur_lemmas:\n", - " ll = lemma.name()\n", - " cats.append(re.sub(\"_\",\" \", syn.name().split(\".\")[0]))\n", - " words.append(re.sub(\"_\",\" \",ll))\n", - " \n", - " if len(synonyms) > 0:\n", - " for w in synonyms:\n", - " w = w.replace(\" \",\"_\")\n", - " for syn in wordnet.synsets(w, pos=getattr(wordnet, POS)):\n", - " cur_lemmas = syn.lemmas()\n", - " hypos = syn.hyponyms()\n", - " for hypo in hypos:\n", - " cur_lemmas.extend(hypo.lemmas())\n", - " for lemma in cur_lemmas:\n", - " ll = lemma.name()\n", - " cats.append(re.sub(\"_\",\" \", syn.name().split(\".\")[0]))\n", - " words.append(re.sub(\"_\",\" \",ll))\n", - " if len(antonyms) > 0:\n", - " for a in antonyms:\n", - " a = a.replace(\" \",\"_\")\n", - " for syn in wordnet.synsets(a, pos=getattr(wordnet, POS)):\n", - " cur_lemmas = syn.lemmas()\n", - " hypos = syn.hyponyms()\n", - " for hypo in hypos:\n", - " cur_lemmas.extend(hypo.lemmas())\n", - " for lemma in cur_lemmas:\n", - " ll = lemma.name()\n", - " cats.append(re.sub(\"_\",\" \", syn.name().split(\".\")[0]))\n", - " words.append(re.sub(\"_\",\" \",ll))\n", - " else:\n", - " for syn in wordnet.synsets(m_word):\n", - " cur_lemmas = syn.lemmas()\n", - " hypos = syn.hyponyms()\n", - " for hypo in hypos:\n", - " cur_lemmas.extend(hypo.lemmas())\n", - " for lemma in cur_lemmas:\n", - " ll = lemma.name()\n", - " cats.append(re.sub(\"_\",\" \", syn.name().split(\".\")[0]))\n", - " words.append(re.sub(\"_\",\" \",ll))\n", - " \n", - " if len(synonyms) > 0:\n", - " for w in synonyms:\n", - " w = w.replace(\" \",\"_\")\n", - " for syn in wordnet.synsets(w):\n", - " cur_lemmas = syn.lemmas()\n", - " hypos = syn.hyponyms()\n", - " for hypo in hypos:\n", - " cur_lemmas.extend(hypo.lemmas())\n", - " for lemma in cur_lemmas:\n", - " ll = lemma.name()\n", - " cats.append(re.sub(\"_\",\" \", syn.name().split(\".\")[0]))\n", - " words.append(re.sub(\"_\",\" \",ll))\n", - " if len(antonyms) > 0:\n", - " for a in antonyms:\n", - " a = a.replace(\" \",\"_\")\n", - " for syn in wordnet.synsets(a):\n", - " cur_lemmas = 
- {
- "cell_type": "code",
- "execution_count": 17,
- "id": "8a61922c-09f8-4a2b-8317-36783016fc2a",
- "metadata": {},
- "outputs": [],
- "source": [
- "def cf_from_wordnet_df(seed,text):\n",
- "    seed_token = nlp(seed)\n",
- "    seed_POS = seed_token[0].pos_\n",
- "    print(seed_POS)\n",
- "    df = wordnet_df(seed,seed_POS)\n",
- "    \n",
- "    df[\"Sentences\"] = df.Words.apply(lambda x: re.sub(r'\\b'+seed+r'\\b',x,text))\n",
- "    df[\"Word Similarity\"] = df.Words.apply(lambda x: seed_token.similarity(nlp(x)))\n",
- "    df = df[df[\"Word Similarity\"] > 0].reset_index(drop=True)\n",
- "    df[\"Prediction\"] = df.Sentences.apply(eval_pred_test)\n",
- "    # Flag the seed row so downstream views can separate it from the generated alternatives.\n",
- "    df['Seed'] = df.Words.apply(lambda x: 'Seed' if x.lower() == seed else 'Alternative')\n",
- "    return df"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "id": "ebeafe57-44e2-4ae2-bdab-f16050efc4f1",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "NOUN\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/var/folders/lx/xt9qnk8569n7xy_d7knh3npr0000gp/T/ipykernel_6702/2903488048.py:8: UserWarning: [W008] Evaluating Doc.similarity based on empty vectors.\n",
- "  df[\"Word Similarity\"] = df.Words.apply(lambda x: seed_token.similarity(nlp(x)))\n"
- ]
- }
- ],
- "source": [
- "panic = cf_from_wordnet_df(seed,text)"
- ]
- },
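- {
- "cell_type": "code",
- "execution_count": null,
- "id": "editor-sketch-category-summary",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Editor's sketch (unexecuted): the Categories column records the source synset\n",
- "# for each alternative, so a plain pandas aggregate shows which WordNet senses\n",
- "# push the classifier positive or negative on average.\n",
- "summary = panic.groupby('Categories')['Prediction'].agg(['mean', 'count'])\n",
- "# summary.sort_values('mean')"
- ]
- },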
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
CategoriesWordsSentencesWord SimilarityPredictionSeed
0moviemovieThis movie was filmed in Iraq.0.519086-0.985851Alternative
1moviefilmThis film was filmed in Iraq.1.000000-0.976839Seed
2moviepictureThis picture was filmed in Iraq.0.275934-0.966598Alternative
3moviemoving pictureThis moving picture was filmed in Iraq.0.3170250.951934Alternative
4moviemoving-picture showThis moving-picture show was filmed in Iraq.0.438731-0.891211Alternative
\n", - "
" - ], - "text/plain": [ - " Categories Words \\\n", - "0 movie movie \n", - "1 movie film \n", - "2 movie picture \n", - "3 movie moving picture \n", - "4 movie moving-picture show \n", - "\n", - " Sentences Word Similarity Prediction \\\n", - "0 This movie was filmed in Iraq. 0.519086 -0.985851 \n", - "1 This film was filmed in Iraq. 1.000000 -0.976839 \n", - "2 This picture was filmed in Iraq. 0.275934 -0.966598 \n", - "3 This moving picture was filmed in Iraq. 0.317025 0.951934 \n", - "4 This moving-picture show was filmed in Iraq. 0.438731 -0.891211 \n", - "\n", - " Seed \n", - "0 Alternative \n", - "1 Seed \n", - "2 Alternative \n", - "3 Alternative \n", - "4 Alternative " - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "panic.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "ea7ad4f3-98d0-4760-9bae-e94210886f08", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "
\n", - "" - ], - "text/plain": [ - "alt.Chart(...)" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "single_nearest = alt.selection_single(on='mouseover', nearest=True)\n", - "full = alt.Chart(panic).encode(\n", - " alt.X('Word Similarity:Q'), # specify nominal data\n", - " alt.Y('Prediction:Q'), # specify quantitative data\n", - " color=alt.Color('Seed:N', legend=alt.Legend(title=\"Seed or Alternative\")),\n", - " size='Seed:N',\n", - " tooltip=('Words','Prediction','Word Similarity')\n", - ").mark_circle(opacity=.5).properties(width=300).add_selection(single_nearest)\n", - "\n", - "full" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "68dfc91f-3a54-4d96-8be9-887415299735", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "isinstance(cf_df, pd.DataFrame)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "9b646bd3-fe80-460d-8237-d181554992a0", - "metadata": {}, - "outputs": [], - "source": [ - "#https://github.com/tvst/st-annotated-text/blob/master/example.py" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "77b86e94-e55a-4c2c-a2ed-bba4e504fbc8", - "metadata": {}, - "outputs": [], - "source": [ - "def get_sampled(df, seed, fixed=False):\n", - " sub_df = df[df['Words'] != seed]\n", - " if fixed:\n", - " sample = sub_df.sample(n=2, random_state = 2052)\n", - " else:\n", - " sample = sub_df.sample(n=2)\n", - " text2 = sample.Sentences.iloc[0]\n", - " text3 = sample.Sentences.iloc[1]\n", - " return text2,text3" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "99253045-8671-47d0-8bed-17b20476c196", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "('This cheesecake was filmed in Iraq.', 'This scum was filmed in Iraq.')" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "get_sampled(panic,\"film\")" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "f438ae0e-7a93-4021-a3ad-5ce36ad37984", - "metadata": {}, - "outputs": [], - "source": [ - "text2, text3 = get_sampled(panic, \"film\")" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "06b9e22f-7562-423d-b96d-19b706869c6a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'This montage was filmed in Iraq.'" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "text2" - ] - }, - { - "cell_type": "code", - "execution_count": 107, - "id": "bc4b5467-c8f6-4a88-8fa2-e781e1eba23e", - "metadata": {}, - "outputs": [], - "source": [ - "#inspired by https://stackoverflow.com/questions/17758023/return-rows-in-a-dataframe-closest-to-a-user-defined-number/17758115#17758115\n", - "def abs_dif(df,seed):\n", - " target = df[df['Words'] == seed].Prediction.iloc[0]\n", - " sub_df = df[df['Words'] != seed].reset_index()\n", - " nearest_prediction = sub_df.Prediction[(sub_df.Prediction-target).abs().argsort()[:1]]\n", - " farthest_prediction = sub_df.Prediction[(sub_df.Prediction-target).abs().argsort()[-1:]]\n", - " nearest = sub_df.Sentences.iloc[nearest_prediction.index[0]]\n", - " farthest = sub_df.Sentences.iloc[farthest_prediction.index[0]]\n", - " return target, nearest, farthest" - ] - }, - { - "cell_type": "code", - "execution_count": 108, - "id": "ae73d9ac-0a05-4d30-8aec-0f40f05b7b16", 
- "metadata": {}, - "outputs": [], - "source": [ - "target, near, far = abs_dif(panic,\"film\")" - ] - }, - { - "cell_type": "code", - "execution_count": 100, - "id": "48312660-33af-4909-98f1-6134215e63bb", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'This abstraction was filmed in Iraq.'" - ] - }, - "execution_count": 100, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "near" - ] - }, - { - "cell_type": "code", - "execution_count": 101, - "id": "37f226d4-dad7-42b1-9b1c-4e60fb5f3504", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'This positive was filmed in Iraq.'" - ] - }, - "execution_count": 101, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "far" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "id": "2a631966-025c-4964-bb22-bb245f0aeb57", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "-0.9771453142166138" - ] - }, - "execution_count": 102, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval_pred_test(near)" - ] - }, - { - "cell_type": "code", - "execution_count": 103, - "id": "777afc55-f701-4452-82ad-1b5487b68d29", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.9987342953681946" - ] - }, - "execution_count": 103, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval_pred_test(far)" - ] - }, - { - "cell_type": "code", - "execution_count": 104, - "id": "c9a92eef-9649-439e-b669-15cd360a3019", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "-0.9768388867378235" - ] - }, - "execution_count": 104, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "target" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e69f1778-22bf-47cc-8ee9-dd90d7e6cebf", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}