{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# Spam message classifier app\n",
    "\n",
    "Loads the sequence-classification model saved in `./spam_model/`, wraps it in a Gradio text interface that reports a spam score for a message, and exports the `#|export` cells to `app.py` with nbdev."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9e3afc69",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp app"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ca02cd22",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import gradio as gr\n",
    "from datasets import Dataset\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "674fa5e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# Silence Python warnings and log records at WARNING level and below so the\n",
    "# app's console output stays clean.\n",
    "import warnings, logging\n",
    "warnings.simplefilter('ignore')\n",
    "logging.disable(logging.WARNING)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-load-model",
   "metadata": {},
   "source": [
    "## Load model and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "28150bb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# Model and tokenizer are read from the same local directory; the Trainer is\n",
    "# only used for inference through `trainer.predict`.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"./spam_model/\")\n",
    "tokz = AutoTokenizer.from_pretrained(\"./spam_model/\")\n",
    "trainer = Trainer(model, tokenizer=tokz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4f1da521",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-scoring",
   "metadata": {},
   "source": [
    "## Tokenize and score a message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cb001f05",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def tok_func(x):\n",
    "    \"\"\"Tokenize the `input` column of a dataset batch with `tokz`.\n",
    "\n",
    "    `truncation=True` caps each message at the tokenizer's maximum length so\n",
    "    that very long user input cannot exceed what the model accepts.\n",
    "    \"\"\"\n",
    "    return tokz(x[\"input\"], truncation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c6cc7802",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map: 0%| | 0/1 [00:00"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "0.8317995071411133"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "document = 'Send this message to 5 more people ASAP'\n",
    "input_ds = Dataset.from_pandas(pd.DataFrame([document], columns=['input'])).map(tok_func, batched=True)\n",
    "trainer.predict(input_ds).predictions.astype(float)[0, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d9e18de1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def classify_message(text):\n",
    "    \"\"\"Return a human-readable spam score for a single message.\n",
    "\n",
    "    The text is wrapped in a one-row dataset, tokenized with `tok_func` and\n",
    "    passed through `trainer.predict`; the first prediction value is clipped to\n",
    "    [0, 1] and reported as a percentage. Blank input is answered with a prompt\n",
    "    instead of a score.\n",
    "    \"\"\"\n",
    "    # Nothing to score: avoid running the model on an empty message.\n",
    "    if text is None or not text.strip():\n",
    "        return 'Please enter a message to classify.'\n",
    "    input_ds = Dataset.from_pandas(pd.DataFrame([text], columns=['input'])).map(tok_func, batched=True)\n",
    "    # NOTE(review): the model's first output is treated as the spam probability;\n",
    "    # confirm the saved model has a single-output head.\n",
    "    preds = trainer.predict(input_ds).predictions.astype(float)\n",
    "    spam_prob = np.clip(preds, 0, 1)[0, 0]\n",
    "    return f'{100*spam_prob:.1f}% probability being Spam'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-interface",
   "metadata": {},
   "source": [
    "## Gradio interface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c70fc002",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL: http://127.0.0.1:7860\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#|export\n",
    "# Plain text in, formatted score string out.\n",
    "intf = gr.Interface(fn=classify_message, inputs='text', outputs='text')\n",
    "intf.launch(inline=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-export",
   "metadata": {},
   "source": [
    "## Export to `app.py`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fdf43e45",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    }
   ],
   "source": [
    "from nbdev.export import nb_export\n",
    "nb_export('app.ipynb', '.')"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}