File size: 3,308 Bytes
b8c0a56
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: outbreak_forecast\n",
    "### Generate a plot based on 5 inputs.\n",
    "\n",
    "Forecast a simple outbreak curve for the selected countries and month, and render it with Matplotlib, Plotly, or Altair. The five inputs are the plotting library, the growth exponent R, the month, the countries to plot, and whether social distancing is applied."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use %pip rather than !pip so packages go into the running kernel's environment.\n",
    "%pip install -q gradio numpy pandas matplotlib plotly altair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "\n",
    "import altair\n",
    "import gradio as gr\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "from matplotlib.figure import Figure"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "forecast-model-md",
   "metadata": {},
   "source": [
    "## Forecast model\n",
    "\n",
    "For each selected country, cases on day `d` are `d ** R * (POP_COUNT[country] + 1)`, evaluated over the 31 days starting at day `30 * month_index`. Ticking *Social Distancing?* replaces R with `sqrt(R)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "outbreak-forecast-fn",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Month choices offered in the UI; month i covers days 30*i through 30*(i+1).\n",
    "MONTHS = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n",
    "# Per-country multiplier applied to the growth curve; its keys are also the\n",
    "# country choices offered in the UI.\n",
    "POP_COUNT = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n",
    "# Plotting libraries the forecast can be rendered with.\n",
    "PLOT_TYPES = [\"Matplotlib\", \"Plotly\", \"Altair\"]\n",
    "\n",
    "\n",
    "def outbreak(plot_type, r, month, countries, social_distancing):\n",
    "    \"\"\"Build an outbreak-forecast chart for the selected month and countries.\n",
    "\n",
    "    Cases on day ``d`` for each country are ``d ** r * (POP_COUNT[country] + 1)``,\n",
    "    evaluated for the 31 days ``30 * m`` through ``30 * (m + 1)``, where ``m``\n",
    "    is the index of ``month`` in ``MONTHS``.\n",
    "\n",
    "    Args:\n",
    "        plot_type: One of ``PLOT_TYPES``; selects the plotting library.\n",
    "        r: Growth exponent (the \"R\" slider).\n",
    "        month: One of ``MONTHS``.\n",
    "        countries: Keys of ``POP_COUNT`` to plot, one line per country.\n",
    "        social_distancing: If truthy, ``r`` is replaced by ``sqrt(r)``.\n",
    "\n",
    "    Returns:\n",
    "        A matplotlib ``Figure``, a Plotly figure, or an Altair chart, matching\n",
    "        ``plot_type``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``plot_type`` or ``month`` is not one of the offered\n",
    "            choices (e.g. ``None`` when a dropdown was left empty).\n",
    "        KeyError: If ``countries`` contains a name not in ``POP_COUNT``.\n",
    "    \"\"\"\n",
    "    # Validate the dropdowns first: they are created without a default value,\n",
    "    # and a missing month would otherwise fail opaquely in MONTHS.index(None).\n",
    "    if plot_type not in PLOT_TYPES:\n",
    "        raise ValueError(\"A plot type must be selected\")\n",
    "    if month not in MONTHS:\n",
    "        raise ValueError(\"A month must be selected\")\n",
    "    countries = list(countries or [])\n",
    "\n",
    "    m = MONTHS.index(month)\n",
    "    x = np.arange(30 * m, 30 * (m + 1) + 1)\n",
    "    if social_distancing:\n",
    "        r = sqrt(r)\n",
    "    df = pd.DataFrame({\"day\": x})\n",
    "    for country in countries:\n",
    "        df[country] = x ** r * (POP_COUNT[country] + 1)\n",
    "\n",
    "    title = \"Outbreak in \" + month\n",
    "    x_label = \"Days since Day 0\"\n",
    "    y_label = \"Cases\"\n",
    "\n",
    "    if plot_type == \"Matplotlib\":\n",
    "        # Build the Figure without pyplot: pyplot keeps a global reference to\n",
    "        # every figure it opens (one would leak per call) and shares a single\n",
    "        # current-figure state across all callers in the process.\n",
    "        fig = Figure()\n",
    "        ax = fig.subplots()\n",
    "        ax.plot(df[\"day\"], df[countries].to_numpy())\n",
    "        ax.set_title(title)\n",
    "        ax.set_xlabel(x_label)\n",
    "        ax.set_ylabel(y_label)\n",
    "        if countries:  # no lines to label otherwise\n",
    "            ax.legend(countries)\n",
    "        return fig\n",
    "\n",
    "    if plot_type == \"Plotly\":\n",
    "        fig = px.line(df, x=\"day\", y=countries)\n",
    "        # Days run along the x axis and case counts along the y axis.\n",
    "        fig.update_layout(title=title, xaxis_title=x_label, yaxis_title=y_label)\n",
    "        return fig\n",
    "\n",
    "    # Remaining case is Altair, which expects long-form data: one row per\n",
    "    # (day, country) pair.\n",
    "    long_df = df.melt(id_vars=\"day\", var_name=\"country\", value_name=\"value\")\n",
    "    return (\n",
    "        altair.Chart(long_df)\n",
    "        .mark_line()\n",
    "        .encode(\n",
    "            x=altair.X(\"day\", title=x_label),\n",
    "            y=altair.Y(\"value\", title=y_label),\n",
    "            color=\"country\",\n",
    "        )\n",
    "        .properties(title=title)\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "gradio-interface-md",
   "metadata": {},
   "source": [
    "## Gradio interface\n",
    "\n",
    "The dropdown and checkbox choices come from the same constants that `outbreak` validates against, so the UI and the function stay in sync."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "gradio-interface-app",
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = [\n",
    "    gr.Dropdown(PLOT_TYPES, label=\"Plot Type\"),\n",
    "    gr.Slider(1, 4, 3.2, label=\"R\"),\n",
    "    gr.Dropdown(MONTHS, label=\"Month\"),\n",
    "    gr.CheckboxGroup(list(POP_COUNT), label=\"Countries\", value=[\"USA\", \"Canada\"]),\n",
    "    gr.Checkbox(label=\"Social Distancing?\"),\n",
    "]\n",
    "outputs = gr.Plot()\n",
    "\n",
    "demo = gr.Interface(\n",
    "    fn=outbreak,\n",
    "    inputs=inputs,\n",
    "    outputs=outputs,\n",
    "    examples=[\n",
    "        [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n",
    "        [\"Altair\", 2, \"March\", [\"Mexico\", \"Canada\"], True],\n",
    "        [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n",
    "    ],\n",
    "    cache_examples=True,\n",
    ")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch()"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}