{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "36b2fb45-9f0a-4943-834f-4e3602c9d592", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "# Reference: https://github.com/marcotcr/lime\n", "import lime\n", "import torch\n", "import torch.nn.functional as F\n", "from lime.lime_text import LimeTextExplainer\n", "\n", "from transformers import AutoTokenizer, AutoModelForSequenceClassification" ] }, { "cell_type": "code", "execution_count": 2, "id": "73778913-2ad6-4e3d-b1ee-9eadbc17e823", "metadata": {}, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")\n", "model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "200277a1-15d2-4828-a742-05ef93f87bf5", "metadata": {}, "outputs": [], "source": [ "def predictor(texts):\n", "    \"\"\"Return softmax class probabilities for LIME.\n", "\n", "    Parameters: texts -- a string or list of strings (LIME passes lists of\n", "    perturbed samples). Returns a numpy array of shape (n_texts, n_classes).\n", "    \"\"\"\n", "    # Inference only: no_grad() skips autograd graph construction (cheaper than\n", "    # building it and calling .detach() afterwards). truncation=True guards\n", "    # against perturbed inputs longer than the model's max sequence length,\n", "    # which would otherwise raise a runtime error.\n", "    with torch.no_grad():\n", "        outputs = model(**tokenizer(texts, return_tensors=\"pt\", padding=True, truncation=True))\n", "    probas = F.softmax(outputs.logits, dim=1).numpy()\n", "    return probas" ] }, { "cell_type": "code", "execution_count": 4, "id": "7ed43d08-259a-4a5d-8925-42a4ccbf7cea", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", "
\n", " \n", " \n", " \n", " " ], "text/plain": [ "