{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "7a3ff9c9-8642-43b9-b1f3-302c227d6acd", "metadata": {}, "outputs": [], "source": [ "#Import the libraries we know we'll need for the Generator.\n", "import pandas as pd, spacy, nltk, numpy as np, re, ssl\n", "from spacy import displacy\n", "from spacy.matcher import Matcher\n", "from nltk.corpus import wordnet\n", "#!python -m spacy download en_core_web_md\n", "nlp = spacy.load(\"en_core_web_md\")\n", "lemmatizer = nlp.get_pipe(\"lemmatizer\")\n", "\n", "#Import the libraries to support the model, predictions, and LIME.\n", "from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline\n", "import lime\n", "import torch\n", "import torch.nn.functional as F\n", "from lime.lime_text import LimeTextExplainer\n", "\n", "#Import the libraries for generating interactive visualizations.\n", "import altair as alt" ] }, { "cell_type": "code", "execution_count": 2, "id": "730cc6fd-f125-42be-ba42-9a7741f82ef5", "metadata": {}, "outputs": [], "source": [ "#Defining all necessary variables and instances.\n", "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")\n", "model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")\n", "pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)\n", "class_names = ['negative', 'positive']\n", "explainer = LimeTextExplainer(class_names=class_names)" ] }, { "cell_type": "code", "execution_count": 3, "id": "1a40f224-25ce-4b57-88fb-404cccbc878b", "metadata": {}, "outputs": [], "source": [ "#Defining a Predictor required for LIME to function.\n", "def predictor(texts):\n", " outputs = model(**tokenizer(texts, return_tensors=\"pt\", padding=True))\n", " probas = F.softmax(outputs.logits, dim=1).detach().numpy()\n", " return probas" ] }, { "cell_type": "code", "execution_count": 4, "id": "201c4792-868b-4e6a-a8d2-f6ffb2f8971a", "metadata": {}, "outputs": [], "source": [ "# A simple function to pull synonyms and antonyms using spacy's POS\n", "def syn_ant(word,POS=False,human=True):\n", " pos_options = ['NOUN','VERB','ADJ','ADV']\n", " synonyms = [] \n", " antonyms = []\n", " #WordNet hates spaces so you have to remove them\n", " if \" \" in word:\n", " word = word.replace(\" \", \"_\")\n", " \n", " if POS in pos_options:\n", " for syn in wordnet.synsets(word, pos=getattr(wordnet, POS)): \n", " for l in syn.lemmas(): \n", " current = l.name()\n", " if human:\n", " current = re.sub(\"_\",\" \",current)\n", " synonyms.append(current) \n", " if l.antonyms():\n", " for ant in l.antonyms():\n", " cur_ant = ant.name()\n", " if human:\n", " cur_ant = re.sub(\"_\",\" \",cur_ant)\n", " antonyms.append(cur_ant)\n", " else: \n", " for syn in wordnet.synsets(word): \n", " for l in syn.lemmas(): \n", " current = l.name()\n", " if human:\n", " current = re.sub(\"_\",\" \",current)\n", " synonyms.append(current) \n", " if l.antonyms():\n", " for ant in l.antonyms():\n", " cur_ant = ant.name()\n", " if human:\n", " cur_ant = re.sub(\"_\",\" \",cur_ant)\n", " antonyms.append(cur_ant)\n", " synonyms = list(set(synonyms))\n", " antonyms = list(set(antonyms))\n", " return synonyms, antonyms" ] }, { "cell_type": "code", "execution_count": 5, "id": "89aaea6d-a4b7-4995-85b5-625e6db292da", "metadata": {}, "outputs": [], "source": [ "# Builds a list dynamically from WordNet using NLTK.\n", "def wordnet_list(word,POS=False):\n", " word = word.lower()\n", " 
pos_options = ['NOUN','VERB','ADJ','ADV']\n", " synonyms, antonyms = syn_ant(word,POS,False)\n", " #print(synonyms, antonyms)\n", " base = []\n", " final = [word]\n", " #WordNet hates spaces so you have to remove them\n", " m_word = word.replace(\" \", \"_\")\n", " \n", " if POS in pos_options:\n", " for syn in wordnet.synsets(m_word, pos=getattr(wordnet, POS)):\n", " base.extend(syn.hyponyms())\n", " base.append(syn)\n", " \n", " if len(synonyms) > 0:\n", " for w in synonyms:\n", " w = w.replace(\" \",\"_\")\n", " for syn in wordnet.synsets(w, pos=getattr(wordnet, POS)):\n", " base.extend(syn.hyponyms())\n", " base.append(syn)\n", " if len(antonyms) > 0:\n", " for a in antonyms:\n", " a = a.replace(\" \",\"_\")\n", " for syn in wordnet.synsets(a, pos=getattr(wordnet, POS)):\n", " base.extend(syn.hyponyms())\n", " base.append(syn)\n", " else:\n", " for syn in wordnet.synsets(m_word):\n", " base.extend(syn.hyponyms())\n", " base.append(syn)\n", " \n", " if len(synonyms) > 0:\n", " for w in synonyms:\n", " w = w.replace(\" \",\"_\")\n", " for syn in wordnet.synsets(w):\n", " base.extend(syn.hyponyms())\n", " base.append(syn)\n", " if len(antonyms) > 0:\n", " for a in antonyms:\n", " a = a.replace(\" \",\"_\")\n", " for syn in wordnet.synsets(a):\n", " base.extend(syn.hyponyms())\n", " base.append(syn)\n", " base = list(set(base))\n", " for b in base:\n", " cur_words = []\n", " cur_words.extend([re.sub(\"_\",\" \",lemma.name()) for lemma in b.lemmas()])\n", " final.extend(cur_words)\n", "\n", " \n", " \n", " final = list(set(final)) \n", " return final" ] }, { "cell_type": "code", "execution_count": 6, "id": "01c74ce5-8a31-4a28-bb12-655349fee453", "metadata": {}, "outputs": [], "source": [ "def eval_pred_test(text, return_all = False):\n", " '''A basic function for evaluating the prediction from the model and turning it into a visualization friendly number.'''\n", " preds = pipe(text)\n", " neg_score = -1 * preds[0][0]['score']\n", " sent_neg = preds[0][0]['label']\n", " pos_score = preds[0][1]['score']\n", " sent_pos = preds[0][1]['label']\n", " prediction = 0\n", " sentiment = ''\n", " if pos_score > abs(neg_score):\n", " prediction = pos_score\n", " sentiment = sent_pos\n", " elif abs(neg_score) > pos_score:\n", " prediction = neg_score\n", " sentiment = sent_neg\n", " \n", " if return_all:\n", " return prediction, sentiment\n", " else:\n", " return prediction" ] }, { "cell_type": "code", "execution_count": 7, "id": "7864d1db-399f-4d23-a036-30da34fe3805", "metadata": {}, "outputs": [], "source": [ "def cf_from_wordnet_list(seed,text):\n", " seed_token = nlp(seed)\n", " seed_POS = seed_token[0].pos_\n", " #print(seed_POS)\n", " words = wordnet_list(seed,seed_POS)\n", " \n", " df = pd.DataFrame()\n", " df[\"Words\"] = words\n", " df[\"Sentences\"] = df.Words.apply(lambda x: re.sub(r'\\b'+seed+r'\\b',x,text))\n", " df[\"Similarity\"] = df.Words.apply(lambda x: seed_token[0].similarity(nlp(x)[0]))\n", " df = df[df.Similarity > 0].reset_index()\n", " df.drop(\"index\", axis=1, inplace=True)\n", " df[\"Prediction\"] = df.Sentences.apply(eval_pred_test)\n", " #added this because I think it will make the end results better if we ensure the seed is in the data we generate counterfactuals from.\n", " df['Seed'] = df.Words.apply(lambda x: 'Seed' if x.lower() == seed else 'Alternative')\n", " return df\n", " " ] }, { "cell_type": "code", "execution_count": 8, "id": "527f037b-79a1-4636-88ba-efd06b9cc20d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'This film was filmed in Iraq.'" ] }, 
"execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "seed = \"film\"\n", "text = f\"This {seed} was filmed in Iraq.\"\n", "text" ] }, { "cell_type": "code", "execution_count": 9, "id": "dc6b503c-0e7e-438e-8079-14479de4e246", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " I\n", " PRON\n", "\n", "\n", "\n", " met\n", " VERB\n", "\n", "\n", "\n", " a\n", " DET\n", "\n", "\n", "\n", " naked\n", " ADJ\n", "\n", "\n", "\n", " doctor.\n", " NOUN\n", "\n", "\n", "\n", " \n", " \n", " nsubj\n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " det\n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " amod\n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " dobj\n", " \n", " \n", "\n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "test = \"I met a naked doctor.\"\n", "testdoc = nlp(test)\n", "displacy.render(testdoc, style=\"dep\")" ] }, { "cell_type": "code", "execution_count": 10, "id": "04509f16-9098-4b67-9fea-3a256d1debe9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'adjectival modifier'" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "spacy.explain(\"amod\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "ebe15fd4-4c1e-4a10-a886-6698620b3b72", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/lx/xt9qnk8569n7xy_d7knh3npr0000gp/T/ipykernel_6702/625873274.py:10: UserWarning: [W008] Evaluating Token.similarity based on empty vectors.\n", " df[\"Similarity\"] = df.Words.apply(lambda x: seed_token[0].similarity(nlp(x)[0]))\n" ] } ], "source": [ "cf_df = cf_from_wordnet_list(seed,text)" ] }, { "cell_type": "code", "execution_count": 12, "id": "a8d44862-d39d-45e5-a030-49d91e9e0712", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
WordsSentencesSimilarityPredictionSeed
0dioramaThis diorama was filmed in Iraq.0.1270230.793419Alternative
1longshotThis longshot was filmed in Iraq.0.050408-0.991559Alternative
2musical comedyThis musical comedy was filmed in Iraq.1.0000000.910803Alternative
3characterisationThis characterisation was filmed in Iraq.0.216481-0.987333Alternative
4PolaroidThis Polaroid was filmed in Iraq.0.171456-0.979913Alternative
\n", "
" ], "text/plain": [ " Words Sentences Similarity \\\n", "0 diorama This diorama was filmed in Iraq. 0.127023 \n", "1 longshot This longshot was filmed in Iraq. 0.050408 \n", "2 musical comedy This musical comedy was filmed in Iraq. 1.000000 \n", "3 characterisation This characterisation was filmed in Iraq. 0.216481 \n", "4 Polaroid This Polaroid was filmed in Iraq. 0.171456 \n", "\n", " Prediction Seed \n", "0 0.793419 Alternative \n", "1 -0.991559 Alternative \n", "2 0.910803 Alternative \n", "3 -0.987333 Alternative \n", "4 -0.979913 Alternative " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cf_df.head()" ] }, { "cell_type": "code", "execution_count": 13, "id": "6d133205-08ad-43c8-8140-ece8d5943a7f", "metadata": {}, "outputs": [], "source": [ "def max_min(df):\n", " maximum = df[df.Words != \"girl\"].Similarity.max()\n", " text3 = df.loc[df['Similarity'] == maximum, 'Words'].iloc[0]\n", " minimum = df.Similarity.min()\n", " text2 = df.loc[df['Similarity'] == minimum, 'Words'].iloc[0]\n", " return text2, text3" ] }, { "cell_type": "code", "execution_count": 14, "id": "8311596b-10cf-4915-82f2-5e3e50000436", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.Chart(...)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "single_nearest = alt.selection_single(on='mouseover', nearest=True)\n", "full = alt.Chart(cf_df).encode(\n", " alt.X('Similarity:Q'), # specify nominal data\n", " alt.Y('Prediction:Q'), # specify quantitative data\n", " color=alt.Color('Seed:N', legend=alt.Legend(title=\"Seed or Alternative\")),\n", " size='Seed:N',\n", " tooltip=('Words','Prediction','Similarity')\n", ").mark_circle(opacity=.5).properties(width=300).add_selection(single_nearest)\n", "\n", "full" ] }, { "cell_type": "code", "execution_count": 15, "id": "1898eb99-22fc-4c6f-b918-1f972c4edb9b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.Chart(...)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df2 = cf_df.nlargest(5, 'Prediction')\n", "df3 = cf_df.nsmallest(5, 'Prediction')\n", "df4 = cf_df[cf_df.Seed == \"Seed\"]\n", "frames = [df2,df3,df4]\n", "results = pd.concat(frames)\n", "\n", "bar = alt.Chart(results).encode( \n", " alt.X('Prediction:Q'), \n", " alt.Y('Words:N', sort=\"-x\"),\n", " color=alt.Color('Seed:N', legend=alt.Legend(title=\"Seed or Alternative\")),\n", " size='Seed:N',\n", " tooltip=('Words','Prediction','Similarity')\n", ").mark_circle().properties(width=300).add_selection(single_nearest)\n", "\n", "bar" ] }, { "cell_type": "code", "execution_count": 16, "id": "52a4fcfa-80d9-456e-a474-8192ac1c02e8", "metadata": {}, "outputs": [], "source": [ "# Builds a list dynamically from WordNet using NLTK.\n", "def wordnet_df(word,POS=False):\n", " pos_options = ['NOUN','VERB','ADJ','ADV']\n", " synonyms, antonyms = syn_ant(word,POS,False)\n", " words = []\n", " cats = []\n", " #WordNet hates spaces so you have to remove them\n", " m_word = word.replace(\" \", \"_\")\n", " \n", " if POS in pos_options:\n", " for syn in wordnet.synsets(m_word, pos=getattr(wordnet, POS)):\n", " cur_lemmas = syn.lemmas()\n", " hypos = syn.hyponyms()\n", " for hypo in hypos:\n", " cur_lemmas.extend(hypo.lemmas())\n", " for lemma in cur_lemmas:\n", " ll = lemma.name()\n", " cats.append(re.sub(\"_\",\" \", syn.name().split(\".\")[0]))\n", " words.append(re.sub(\"_\",\" \",ll))\n", " \n", " if len(synonyms) > 0:\n", " for w in synonyms:\n", " w = w.replace(\" \",\"_\")\n", " for syn in wordnet.synsets(w, pos=getattr(wordnet, POS)):\n", " cur_lemmas = syn.lemmas()\n", " hypos = syn.hyponyms()\n", " for hypo in hypos:\n", " cur_lemmas.extend(hypo.lemmas())\n", " for lemma in cur_lemmas:\n", " ll = lemma.name()\n", " cats.append(re.sub(\"_\",\" \", syn.name().split(\".\")[0]))\n", " words.append(re.sub(\"_\",\" \",ll))\n", " if len(antonyms) > 0:\n", " for a in antonyms:\n", " a = a.replace(\" \",\"_\")\n", " for syn in wordnet.synsets(a, pos=getattr(wordnet, POS)):\n", " cur_lemmas = syn.lemmas()\n", " hypos = syn.hyponyms()\n", " for hypo in hypos:\n", " cur_lemmas.extend(hypo.lemmas())\n", " for lemma in cur_lemmas:\n", " ll = lemma.name()\n", " cats.append(re.sub(\"_\",\" \", syn.name().split(\".\")[0]))\n", " words.append(re.sub(\"_\",\" \",ll))\n", " else:\n", " for syn in wordnet.synsets(m_word):\n", " cur_lemmas = syn.lemmas()\n", " hypos = syn.hyponyms()\n", " for hypo in hypos:\n", " cur_lemmas.extend(hypo.lemmas())\n", " for lemma in cur_lemmas:\n", " ll = lemma.name()\n", " cats.append(re.sub(\"_\",\" \", syn.name().split(\".\")[0]))\n", " words.append(re.sub(\"_\",\" \",ll))\n", " \n", " if len(synonyms) > 0:\n", " for w in synonyms:\n", " w = w.replace(\" \",\"_\")\n", " for syn in wordnet.synsets(w):\n", " cur_lemmas = syn.lemmas()\n", " hypos = syn.hyponyms()\n", " for hypo in hypos:\n", " cur_lemmas.extend(hypo.lemmas())\n", " for lemma in cur_lemmas:\n", " ll = lemma.name()\n", " cats.append(re.sub(\"_\",\" \", syn.name().split(\".\")[0]))\n", " words.append(re.sub(\"_\",\" \",ll))\n", " if len(antonyms) > 0:\n", " for a in antonyms:\n", " a = a.replace(\" \",\"_\")\n", " for syn in wordnet.synsets(a):\n", " cur_lemmas = syn.lemmas()\n", " hypos = syn.hyponyms()\n", " for hypo in hypos:\n", " cur_lemmas.extend(hypo.lemmas())\n", " for lemma in cur_lemmas:\n", " ll = lemma.name()\n", " cats.append(re.sub(\"_\",\" \", 
syn.name().split(\".\")[0]))\n", " words.append(re.sub(\"_\",\" \",ll))\n", "\n", " df = {\"Categories\":cats, \"Words\":words}\n", " df = pd.DataFrame(df) \n", " df = df.drop_duplicates().reset_index()\n", " df = df.drop(\"index\", axis=1)\n", " return df" ] }, { "cell_type": "code", "execution_count": 17, "id": "8a61922c-09f8-4a2b-8317-36783016fc2a", "metadata": {}, "outputs": [], "source": [ "def cf_from_wordnet_df(seed,text):\n", " seed_token = nlp(seed)\n", " seed_POS = seed_token[0].pos_\n", " print(seed_POS)\n", " df = wordnet_df(seed,seed_POS)\n", " \n", " df[\"Sentences\"] = df.Words.apply(lambda x: re.sub(r'\\b'+seed+r'\\b',x,text))\n", " df[\"Word Similarity\"] = df.Words.apply(lambda x: seed_token.similarity(nlp(x)))\n", " df = df[df[\"Word Similarity\"] > 0].reset_index()\n", " df.drop(\"index\", axis=1, inplace=True)\n", " df[\"Prediction\"] = df.Sentences.apply(eval_pred_test)\n", " #added this because I think it will make the end results better if we ensure the seed is in the data we generate counterfactuals from.\n", " df['Seed'] = df.Words.apply(lambda x: 'Seed' if x.lower() == seed else 'Alternative')\n", " return df" ] }, { "cell_type": "code", "execution_count": 18, "id": "ebeafe57-44e2-4ae2-bdab-f16050efc4f1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NOUN\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/lx/xt9qnk8569n7xy_d7knh3npr0000gp/T/ipykernel_6702/2903488048.py:8: UserWarning: [W008] Evaluating Doc.similarity based on empty vectors.\n", " df[\"Word Similarity\"] = df.Words.apply(lambda x: seed_token.similarity(nlp(x)))\n" ] } ], "source": [ "panic = cf_from_wordnet_df(seed,text)" ] }, { "cell_type": "code", "execution_count": 24, "id": "c534bbab-7d5d-4695-8ab0-948ff88463de", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
CategoriesWordsSentencesWord SimilarityPredictionSeed
0moviemovieThis movie was filmed in Iraq.0.519086-0.985851Alternative
1moviefilmThis film was filmed in Iraq.1.000000-0.976839Seed
2moviepictureThis picture was filmed in Iraq.0.275934-0.966598Alternative
3moviemoving pictureThis moving picture was filmed in Iraq.0.3170250.951934Alternative
4moviemoving-picture showThis moving-picture show was filmed in Iraq.0.438731-0.891211Alternative
\n", "
" ], "text/plain": [ " Categories Words \\\n", "0 movie movie \n", "1 movie film \n", "2 movie picture \n", "3 movie moving picture \n", "4 movie moving-picture show \n", "\n", " Sentences Word Similarity Prediction \\\n", "0 This movie was filmed in Iraq. 0.519086 -0.985851 \n", "1 This film was filmed in Iraq. 1.000000 -0.976839 \n", "2 This picture was filmed in Iraq. 0.275934 -0.966598 \n", "3 This moving picture was filmed in Iraq. 0.317025 0.951934 \n", "4 This moving-picture show was filmed in Iraq. 0.438731 -0.891211 \n", "\n", " Seed \n", "0 Alternative \n", "1 Seed \n", "2 Alternative \n", "3 Alternative \n", "4 Alternative " ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "panic.head()" ] }, { "cell_type": "code", "execution_count": 19, "id": "ea7ad4f3-98d0-4760-9bae-e94210886f08", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.Chart(...)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "single_nearest = alt.selection_single(on='mouseover', nearest=True)\n", "full = alt.Chart(panic).encode(\n", " alt.X('Word Similarity:Q'), # specify nominal data\n", " alt.Y('Prediction:Q'), # specify quantitative data\n", " color=alt.Color('Seed:N', legend=alt.Legend(title=\"Seed or Alternative\")),\n", " size='Seed:N',\n", " tooltip=('Words','Prediction','Word Similarity')\n", ").mark_circle(opacity=.5).properties(width=300).add_selection(single_nearest)\n", "\n", "full" ] }, { "cell_type": "code", "execution_count": 20, "id": "68dfc91f-3a54-4d96-8be9-887415299735", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "isinstance(cf_df, pd.DataFrame)" ] }, { "cell_type": "code", "execution_count": 21, "id": "9b646bd3-fe80-460d-8237-d181554992a0", "metadata": {}, "outputs": [], "source": [ "#https://github.com/tvst/st-annotated-text/blob/master/example.py" ] }, { "cell_type": "code", "execution_count": 35, "id": "77b86e94-e55a-4c2c-a2ed-bba4e504fbc8", "metadata": {}, "outputs": [], "source": [ "def get_sampled(df, seed, fixed=False):\n", " sub_df = df[df['Words'] != seed]\n", " if fixed:\n", " sample = sub_df.sample(n=2, random_state = 2052)\n", " else:\n", " sample = sub_df.sample(n=2)\n", " text2 = sample.Sentences.iloc[0]\n", " text3 = sample.Sentences.iloc[1]\n", " return text2,text3" ] }, { "cell_type": "code", "execution_count": 26, "id": "99253045-8671-47d0-8bed-17b20476c196", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('This cheesecake was filmed in Iraq.', 'This scum was filmed in Iraq.')" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_sampled(panic,\"film\")" ] }, { "cell_type": "code", "execution_count": 27, "id": "f438ae0e-7a93-4021-a3ad-5ce36ad37984", "metadata": {}, "outputs": [], "source": [ "text2, text3 = get_sampled(panic, \"film\")" ] }, { "cell_type": "code", "execution_count": 28, "id": "06b9e22f-7562-423d-b96d-19b706869c6a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'This montage was filmed in Iraq.'" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text2" ] }, { "cell_type": "code", "execution_count": 107, "id": "bc4b5467-c8f6-4a88-8fa2-e781e1eba23e", "metadata": {}, "outputs": [], "source": [ "#inspired by https://stackoverflow.com/questions/17758023/return-rows-in-a-dataframe-closest-to-a-user-defined-number/17758115#17758115\n", "def abs_dif(df,seed):\n", " target = df[df['Words'] == seed].Prediction.iloc[0]\n", " sub_df = df[df['Words'] != seed].reset_index()\n", " nearest_prediction = sub_df.Prediction[(sub_df.Prediction-target).abs().argsort()[:1]]\n", " farthest_prediction = sub_df.Prediction[(sub_df.Prediction-target).abs().argsort()[-1:]]\n", " nearest = sub_df.Sentences.iloc[nearest_prediction.index[0]]\n", " farthest = sub_df.Sentences.iloc[farthest_prediction.index[0]]\n", " return target, nearest, farthest" ] }, { "cell_type": "code", "execution_count": 108, "id": "ae73d9ac-0a05-4d30-8aec-0f40f05b7b16", "metadata": {}, "outputs": [], "source": [ "target, near, far = abs_dif(panic,\"film\")" ] }, { "cell_type": "code", "execution_count": 100, "id": "48312660-33af-4909-98f1-6134215e63bb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'This abstraction was filmed in Iraq.'" ] }, 
"execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "near" ] }, { "cell_type": "code", "execution_count": 101, "id": "37f226d4-dad7-42b1-9b1c-4e60fb5f3504", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'This positive was filmed in Iraq.'" ] }, "execution_count": 101, "metadata": {}, "output_type": "execute_result" } ], "source": [ "far" ] }, { "cell_type": "code", "execution_count": 102, "id": "2a631966-025c-4964-bb22-bb245f0aeb57", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-0.9771453142166138" ] }, "execution_count": 102, "metadata": {}, "output_type": "execute_result" } ], "source": [ "eval_pred_test(near)" ] }, { "cell_type": "code", "execution_count": 103, "id": "777afc55-f701-4452-82ad-1b5487b68d29", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9987342953681946" ] }, "execution_count": 103, "metadata": {}, "output_type": "execute_result" } ], "source": [ "eval_pred_test(far)" ] }, { "cell_type": "code", "execution_count": 104, "id": "c9a92eef-9649-439e-b669-15cd360a3019", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-0.9768388867378235" ] }, "execution_count": 104, "metadata": {}, "output_type": "execute_result" } ], "source": [ "target" ] }, { "cell_type": "code", "execution_count": null, "id": "e69f1778-22bf-47cc-8ee9-dd90d7e6cebf", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }