{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "36b2fb45-9f0a-4943-834f-4e3602c9d592", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "#https://github.com/marcotcr/lime for reference\n", "import lime\n", "import torch\n", "import torch.nn.functional as F\n", "from lime.lime_text import LimeTextExplainer\n", "\n", "from transformers import AutoTokenizer, AutoModelForSequenceClassification" ] }, { "cell_type": "code", "execution_count": 2, "id": "73778913-2ad6-4e3d-b1ee-9eadbc17e823", "metadata": {}, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")\n", "model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "200277a1-15d2-4828-a742-05ef93f87bf5", "metadata": {}, "outputs": [], "source": [ "def predictor(texts, batch_size=64):\n", "    \"\"\"Return softmax class probabilities for raw strings (used as LIME's classifier_fn).\n", "\n", "    LIME's text explainer passes all `num_samples` perturbed strings in a single call,\n", "    so scoring them as one batch with autograd enabled is memory-hungry. Instead,\n", "    score fixed-size mini-batches under `torch.no_grad()`.\n", "\n", "    Args:\n", "        texts: iterable of str to classify.\n", "        batch_size: number of strings tokenized and scored per forward pass.\n", "\n", "    Returns:\n", "        np.ndarray of shape (len(texts), num_labels) with per-class probabilities.\n", "    \"\"\"\n", "    texts = list(texts)\n", "    if not texts:\n", "        return np.empty((0, model.config.num_labels))\n", "    chunks = []\n", "    with torch.no_grad():\n", "        for start in range(0, len(texts), batch_size):\n", "            # truncation avoids an error on inputs longer than the model's max length\n", "            enc = tokenizer(texts[start:start + batch_size], return_tensors=\"pt\", padding=True, truncation=True)\n", "            logits = model(**enc).logits\n", "            chunks.append(F.softmax(logits, dim=1).numpy())\n", "    return np.concatenate(chunks, axis=0)" ] }, { "cell_type": "code", "execution_count": 4, "id": "7ed43d08-259a-4a5d-8925-42a4ccbf7cea", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class_names = ['negative', 'positive']\n", "# Fixed seed: LIME samples text perturbations randomly, so without it every\n", "# Restart & Run All produces different explanation weights.\n", "explainer = LimeTextExplainer(class_names=class_names, random_state=42)\n", "\n", "str_to_predict = \"Native Americans deserve to have their lands back.\"\n", "exp = explainer.explain_instance(str_to_predict, predictor, num_features=20, num_samples=2000)\n", "exp.show_in_notebook(text=str_to_predict)" ] }, { "cell_type": "code", "execution_count": 5, "id": "15cdaeb1-4f1e-4084-80e0-7ea802449274", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Probe sentence combining a negation (\"not\") with a nationality word (\"German\").\n", "text1 = \"FC Barcelona is not a German football team.\"\n", "exp = explainer.explain_instance(\n", "    text1,\n", "    predictor,\n", "    num_features=20,\n", "    num_samples=2000,\n", ")\n", "exp.show_in_notebook(text=text1)" ] }, { "cell_type": "code", "execution_count": 6, "id": "0ea62f5e-b83d-4ce9-b92b-e09aeacd4741", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "text2 = \"He had a German car that he drove.\"\n", "exp = explainer.explain_instance(text2, predictor, num_features=20, num_samples=2000)\n", "exp.show_in_notebook(text=text2)" ] }, { "cell_type": "code", "execution_count": 7, "id": "813682cc-1271-4021-a932-dfe6347e13af", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('German', -0.6919233098787057),\n", " ('he', -0.2385283355148371),\n", " ('car', 0.17295381923211997),\n", " ('had', -0.14217629753917047),\n", " ('a', 0.12866437095896172),\n", " ('drove', 0.08097486924056058),\n", " ('He', 0.049226218092743),\n", " ('that', 0.00988639644487302)]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "exp.as_list()" ] }, { "cell_type": "code", "execution_count": 8, "id": "02d0c97e-098c-4932-9d67-285af7bbb1ec", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('Iraq', -0.9081354370318596),\n", " ('was', -0.03403640727982729),\n", " ('in', -0.030957227884937304),\n", " ('This', -0.03011625239348594),\n", " ('movie', -0.029048209127173097),\n", " ('filmed', -0.015553741829700034)]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "upt1 = \"I like movies starring black actors.\"\n", "upt2 = \"I am a black trans-woman.\"\n", "upt3 = \"Native Americans deserve to have their land back.\"\n", "upt4 = \"This movie was filmed in Iraq.\"\n", "# Explain every probe (not just a hand-edited one) so all results are reproducible;\n", "# the loop ends on upt4, so `exp` holds the same explanation target as before.\n", "for probe in (upt1, upt2, upt3, upt4):\n", "    exp = explainer.explain_instance(probe, predictor, num_features=20, num_samples=2000)\n", "    display(probe, exp.as_list())" ] }, { "cell_type": "code", "execution_count": 11, "id": "ede2760d-0e58-4779-83bb-6f94d053a7a0", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "halal_text = \"North Korea eats halal meat.\"\n", "exp = explainer.explain_instance(halal_text, predictor, num_features=20, num_samples=2000)\n", "exp.show_in_notebook(text=halal_text)" ] }, { "cell_type": "code", "execution_count": 12, "id": "7dd86b0b-497e-4f0b-bd28-aec1e6548663", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "halal_text = \"North Korea serves halal meat.\"\n", "exp = explainer.explain_instance(halal_text, predictor, num_features=20, num_samples=2000)\n", "exp.show_in_notebook(text=halal_text)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }