Spaces:
Runtime error
Runtime error
File size: 2,579 Bytes
c409d3d |
1 |
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: pictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This demo uses the Gradio 3.x Sketchpad API (`shape`, `brush_radius`), which Gradio 4 removed,\n",
    "# so pin the major version. %pip installs into the running kernel's environment.\n",
    "%pip install -q \"gradio<4\" torch gdown numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the class-name list from the demo repo (-nc: skip if an earlier run already fetched it,\n",
    "# instead of piling up class_names.txt.1, .2, ...).\n",
    "!wget -q -nc https://github.com/gradio-app/gradio/raw/main/demo/pictionary/class_names.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44380577570523278879349135829904343037",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import gdown\n",
    "import gradio as gr\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "WEIGHTS_URL = 'https://drive.google.com/uc?id=1dsk2JNZLRDjC-0J4wIQX_FcVurPaXaAZ'\n",
    "WEIGHTS_PATH = Path('pytorch_model.bin')\n",
    "IMG_SIZE = 28  # sketches are downscaled to IMG_SIZE x IMG_SIZE before inference\n",
    "TOP_K = 5  # number of guesses returned to the Label output\n",
    "\n",
    "# Reuse weights cached by an earlier run instead of hitting Google Drive every time.\n",
    "if not WEIGHTS_PATH.exists():\n",
    "    gdown.download(WEIGHTS_URL, str(WEIGHTS_PATH), quiet=False)\n",
    "\n",
    "LABELS = Path('class_names.txt').read_text().splitlines()\n",
    "\n",
    "# Three 2x max-pools shrink a 28x28 input to 14 -> 7 -> 3, so the flattened size is 128 * 3 * 3 = 1152.\n",
    "model = nn.Sequential(\n",
    "    nn.Conv2d(1, 32, 3, padding='same'),\n",
    "    nn.ReLU(),\n",
    "    nn.MaxPool2d(2),\n",
    "    nn.Conv2d(32, 64, 3, padding='same'),\n",
    "    nn.ReLU(),\n",
    "    nn.MaxPool2d(2),\n",
    "    nn.Conv2d(64, 128, 3, padding='same'),\n",
    "    nn.ReLU(),\n",
    "    nn.MaxPool2d(2),\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(1152, 256),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(256, len(LABELS)),\n",
    ")\n",
    "\n",
    "# weights_only=True refuses to unpickle arbitrary Python objects from the downloaded file.\n",
    "state_dict = torch.load(WEIGHTS_PATH, map_location='cpu', weights_only=True)\n",
    "# strict=False tolerates extra keys, but a missing key would silently leave that layer randomly initialised.\n",
    "load_result = model.load_state_dict(state_dict, strict=False)\n",
    "assert not load_result.missing_keys, f'checkpoint is missing weights: {load_result.missing_keys}'\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pictionary-predict",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(im):\n",
    "    \"\"\"Return the model's top guesses for a sketch.\n",
    "\n",
    "    Args:\n",
    "        im: PIL image from the Sketchpad (``type='pil'``), or None.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping label -> probability for the ``TOP_K`` most likely classes\n",
    "        (fewer if there are fewer labels), or None when ``im`` is None.\n",
    "    \"\"\"\n",
    "    if im is None:\n",
    "        return None\n",
    "    # Collapse to one channel so the tensor is (1, 1, H, W) even if the canvas yields RGB/RGBA.\n",
    "    pixels = np.asarray(im.convert('L').resize((IMG_SIZE, IMG_SIZE)))\n",
    "    x = torch.tensor(pixels, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.\n",
    "\n",
    "    with torch.no_grad():\n",
    "        logits = model(x)\n",
    "\n",
    "    probabilities = torch.nn.functional.softmax(logits[0], dim=0)\n",
    "    values, indices = torch.topk(probabilities, min(TOP_K, len(LABELS)))\n",
    "    # .tolist() yields plain ints/floats: LABELS is indexed by int and the Label output receives floats.\n",
    "    return {LABELS[i]: p for i, p in zip(indices.tolist(), values.tolist())}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pictionary-launch",
   "metadata": {},
   "outputs": [],
   "source": [
    "interface = gr.Interface(\n",
    "    predict,\n",
    "    inputs=gr.Sketchpad(label='Draw Here', brush_radius=5, type='pil', shape=(120, 120)),\n",
    "    outputs=gr.Label(label='Guess'),\n",
    "    live=True,  # re-run predict() whenever the drawing changes\n",
    ")\n",
    "\n",
    "interface.queue().launch()"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
} |