diff --git "a/Lime Explorations.ipynb" "b/Lime Explorations.ipynb" new file mode 100644--- /dev/null +++ "b/Lime Explorations.ipynb" @@ -0,0 +1,185821 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "36b2fb45-9f0a-4943-834f-4e3602c9d592", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "#https://github.com/marcotcr/lime for reference\n", + "import lime\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from lime.lime_text import LimeTextExplainer\n", + "\n", + "from transformers import AutoTokenizer, AutoModelForSequenceClassification" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "73778913-2ad6-4e3d-b1ee-9eadbc17e823", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")\n", + "model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "200277a1-15d2-4828-a742-05ef93f87bf5", + "metadata": {}, + "outputs": [], + "source": [ + "def predictor(texts):\n", + " outputs = model(**tokenizer(texts, return_tensors=\"pt\", padding=True))\n", + " probas = F.softmax(outputs.logits, dim=1).detach().numpy()\n", + " return probas" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7ed43d08-259a-4a5d-8925-42a4ccbf7cea", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class_names = ['negative', 'positive']\n", + "explainer = LimeTextExplainer(class_names=class_names)\n", + "\n", + "str_to_predict = \"Native Americans deserve to have their lands back.\"\n", + "exp = explainer.explain_instance(str_to_predict, predictor, num_features=20, num_samples=2000)\n", + "exp.show_in_notebook(text=str_to_predict)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "15cdaeb1-4f1e-4084-80e0-7ea802449274", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "text1 = \"FC Barcelona is not a German football team.\"\n", + "exp = explainer.explain_instance(text1, predictor, num_features=20, num_samples=2000)\n", + "exp.show_in_notebook(text=text1)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0ea62f5e-b83d-4ce9-b92b-e09aeacd4741", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "text2 = \"He had a German car that he drove.\"\n", + "exp = explainer.explain_instance(text2, predictor, num_features=20, num_samples=2000)\n", + "exp.show_in_notebook(text=text2)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "813682cc-1271-4021-a932-dfe6347e13af", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('German', -0.6919233098787057),\n", + " ('he', -0.2385283355148371),\n", + " ('car', 0.17295381923211997),\n", + " ('had', -0.14217629753917047),\n", + " ('a', 0.12866437095896172),\n", + " ('drove', 0.08097486924056058),\n", + " ('He', 0.049226218092743),\n", + " ('that', 0.00988639644487302)]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "exp.as_list()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "02d0c97e-098c-4932-9d67-285af7bbb1ec", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('Iraq', -0.9081354370318596),\n", + " ('was', -0.03403640727982729),\n", + " ('in', -0.030957227884937304),\n", + " ('This', -0.03011625239348594),\n", + " ('movie', -0.029048209127173097),\n", + " ('filmed', -0.015553741829700034)]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "upt1 = \"I like movies starring black actors.\"\n", + "upt2 = \"I am a black trans-woman.\"\n", + "upt3 = \"Native Americans deserve to have their land back.\"\n", + "upt4 = \"This movie was filmed in Iraq.\"\n", + "exp = explainer.explain_instance(upt4, predictor, num_features=20, num_samples=2000)\n", + "exp.as_list()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ede2760d-0e58-4779-83bb-6f94d053a7a0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "iraq = \"North Korea eats halal meat.\"\n", + "exp = explainer.explain_instance(iraq, predictor, num_features=20, num_samples=2000)\n", + "exp.show_in_notebook(text=iraq)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "7dd86b0b-497e-4f0b-bd28-aec1e6548663", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "iraq = \"North Korea serves halal meat.\"\n", + "exp = explainer.explain_instance(iraq, predictor, num_features=20, num_samples=2000)\n", + "exp.show_in_notebook(text=iraq)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e8d3c98-d6a6-4189-ad69-f102964a7da4", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}