File size: 5,480 Bytes
90d4aa5
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: altair_plot\n",
    "\n",
    "Choose a plot type with the radio buttons to render the matching Altair chart in a `gr.Plot` component. Most charts use sample data from `vega_datasets`; the heatmap and radial charts use small synthetic datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q gradio \"altair>=5\" vega_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "import altair as alt\n",
    "import gradio as gr\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from vega_datasets import data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "altair-plot-builders-md",
   "metadata": {},
   "source": [
    "## Chart builders\n",
    "\n",
    "Each helper returns one Altair chart. `PLOT_BUILDERS` maps every plot-type name to its helper, and `make_plot` dispatches on that name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "altair-plot-builders",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _scatter_plot():\n",
    "    \"\"\"Scatter plot of horsepower vs. miles per gallon (`cars` dataset), colored by origin.\"\"\"\n",
    "    cars = data.cars()\n",
    "    return alt.Chart(cars).mark_point().encode(\n",
    "        x='Horsepower',\n",
    "        y='Miles_per_Gallon',\n",
    "        color='Origin',\n",
    "    )\n",
    "\n",
    "\n",
    "def _heatmap():\n",
    "    \"\"\"Heatmap of z = x**2 + y**2 over the integer grid x, y in range(-5, 5).\"\"\"\n",
    "    x, y = np.meshgrid(range(-5, 5), range(-5, 5))\n",
    "    z = x ** 2 + y ** 2\n",
    "\n",
    "    # Flatten the 2D grid into the one-row-per-cell columnar data Altair expects.\n",
    "    source = pd.DataFrame({\n",
    "        'x': x.ravel(),\n",
    "        'y': y.ravel(),\n",
    "        'z': z.ravel(),\n",
    "    })\n",
    "    return alt.Chart(source).mark_rect().encode(\n",
    "        x='x:O',\n",
    "        y='y:O',\n",
    "        color='z:Q',\n",
    "    )\n",
    "\n",
    "\n",
    "def _us_map():\n",
    "    \"\"\"Choropleth of `pct` from the `income` dataset on US states, faceted by `group`.\"\"\"\n",
    "    states = alt.topo_feature(data.us_10m.url, 'states')\n",
    "    source = data.income.url\n",
    "\n",
    "    return alt.Chart(source).mark_geoshape().encode(\n",
    "        shape='geo:G',\n",
    "        color='pct:Q',\n",
    "        tooltip=['name:N', 'pct:Q'],\n",
    "        facet=alt.Facet('group:N', columns=2),\n",
    "    ).transform_lookup(\n",
    "        # Attach each state's geometry to the income rows by matching `id`.\n",
    "        lookup='id',\n",
    "        from_=alt.LookupData(data=states, key='id'),\n",
    "        as_='geo',\n",
    "    ).properties(\n",
    "        width=300,\n",
    "        height=175,\n",
    "    ).project(\n",
    "        type='albersUsa'\n",
    "    )\n",
    "\n",
    "\n",
    "def _interactive_barplot():\n",
    "    \"\"\"Binned IMDB vs. Rotten Tomatoes heatmap linked to a genre bar chart (`movies` dataset).\"\"\"\n",
    "    source = data.movies.url\n",
    "\n",
    "    # Point selection on the bar chart's x channel (genre). It filters the circle\n",
    "    # overlay and decides which bars are drawn in the highlight color.\n",
    "    pts = alt.selection_point(encodings=['x'])\n",
    "\n",
    "    rect = alt.Chart(source).mark_rect().encode(\n",
    "        alt.X('IMDB_Rating:Q', bin=True),\n",
    "        alt.Y('Rotten_Tomatoes_Rating:Q', bin=True),\n",
    "        alt.Color(\n",
    "            'count()',\n",
    "            scale=alt.Scale(scheme='greenblue'),\n",
    "            legend=alt.Legend(title='Total Records'),\n",
    "        ),\n",
    "    )\n",
    "\n",
    "    circ = rect.mark_point().encode(\n",
    "        alt.ColorValue('grey'),\n",
    "        alt.Size(\n",
    "            'count()',\n",
    "            legend=alt.Legend(title='Records in Selection'),\n",
    "        ),\n",
    "    ).transform_filter(\n",
    "        pts\n",
    "    )\n",
    "\n",
    "    bar = alt.Chart(source).mark_bar().encode(\n",
    "        x='Major_Genre:N',\n",
    "        y='count()',\n",
    "        color=alt.condition(pts, alt.ColorValue(\"steelblue\"), alt.ColorValue(\"grey\")),\n",
    "    ).properties(\n",
    "        width=550,\n",
    "        height=200,\n",
    "    ).add_params(pts)\n",
    "\n",
    "    # Give the concatenated views independent color and size legends.\n",
    "    return alt.vconcat(\n",
    "        rect + circ,\n",
    "        bar,\n",
    "    ).resolve_legend(\n",
    "        color=\"independent\",\n",
    "        size=\"independent\",\n",
    "    )\n",
    "\n",
    "\n",
    "def _radial():\n",
    "    \"\"\"Radial chart of a fixed list of values, with each arc labelled by its value.\"\"\"\n",
    "    source = pd.DataFrame({\"values\": [12, 23, 47, 6, 52, 19]})\n",
    "\n",
    "    base = alt.Chart(source).encode(\n",
    "        theta=alt.Theta(\"values:Q\", stack=True),\n",
    "        radius=alt.Radius(\"values\", scale=alt.Scale(type=\"sqrt\", zero=True, rangeMin=20)),\n",
    "        color=\"values:N\",\n",
    "    )\n",
    "\n",
    "    arcs = base.mark_arc(innerRadius=20, stroke=\"#fff\")\n",
    "    labels = base.mark_text(radiusOffset=10).encode(text=\"values:Q\")\n",
    "    return arcs + labels\n",
    "\n",
    "\n",
    "def _multiline():\n",
    "    \"\"\"Stock price line per symbol; hovering thickens the line nearest the pointer.\"\"\"\n",
    "    source = data.stocks()\n",
    "\n",
    "    highlight = alt.selection_point(on='mouseover', fields=['symbol'], nearest=True)\n",
    "\n",
    "    base = alt.Chart(source).encode(\n",
    "        x='date:T',\n",
    "        y='price:Q',\n",
    "        color='symbol:N',\n",
    "    )\n",
    "\n",
    "    # Fully transparent points exist only to carry the nearest-point hover selection.\n",
    "    points = base.mark_circle().encode(\n",
    "        opacity=alt.value(0)\n",
    "    ).add_params(\n",
    "        highlight\n",
    "    ).properties(\n",
    "        width=600\n",
    "    )\n",
    "\n",
    "    # Lines outside the selection are drawn at size 1, the highlighted one at size 3.\n",
    "    lines = base.mark_line().encode(\n",
    "        size=alt.condition(~highlight, alt.value(1), alt.value(3))\n",
    "    )\n",
    "\n",
    "    return points + lines\n",
    "\n",
    "\n",
    "# Maps each plot-type name to the helper that builds its chart. Insertion order is\n",
    "# the order in which the choices appear in the Radio component.\n",
    "PLOT_BUILDERS = {\n",
    "    \"scatter_plot\": _scatter_plot,\n",
    "    \"heatmap\": _heatmap,\n",
    "    \"us_map\": _us_map,\n",
    "    \"interactive_barplot\": _interactive_barplot,\n",
    "    \"radial\": _radial,\n",
    "    \"multiline\": _multiline,\n",
    "}\n",
    "\n",
    "\n",
    "def make_plot(plot_type):\n",
    "    \"\"\"Build the Altair chart for the selected plot type.\n",
    "\n",
    "    Args:\n",
    "        plot_type: Name of the chart to build; must be a key of ``PLOT_BUILDERS``.\n",
    "\n",
    "    Returns:\n",
    "        The Altair chart returned by the matching builder, for display in ``gr.Plot``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``plot_type`` is not a key of ``PLOT_BUILDERS``.\n",
    "    \"\"\"\n",
    "    builder = PLOT_BUILDERS.get(plot_type)\n",
    "    if builder is None:\n",
    "        raise ValueError(\n",
    "            f\"Unknown plot_type {plot_type!r}; expected one of {list(PLOT_BUILDERS)}\"\n",
    "        )\n",
    "    return builder()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "altair-plot-ui-md",
   "metadata": {},
   "source": [
    "## Gradio UI\n",
    "\n",
    "The radio choices are taken from `PLOT_BUILDERS`; the chart is redrawn whenever the selection changes and once when the page loads."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "altair-plot-ui",
   "metadata": {},
   "outputs": [],
   "source": [
    "with gr.Blocks() as demo:\n",
    "    # Choices come from PLOT_BUILDERS so the UI and the dispatcher cannot drift apart.\n",
    "    button = gr.Radio(label=\"Plot type\",\n",
    "                      choices=list(PLOT_BUILDERS), value='scatter_plot')\n",
    "    plot = gr.Plot(label=\"Plot\")\n",
    "    # Redraw on every selection change, and once on page load for the default choice.\n",
    "    button.change(make_plot, inputs=button, outputs=[plot])\n",
    "    demo.load(make_plot, inputs=[button], outputs=[plot])\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch()"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}