File size: 2,764 Bytes
b95b4ab
84b07f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6b1259
84b07f8
 
6bda1b3
84b07f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f238b6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# type: ignore
import gradio as gr
import xyzservices.providers as xyz
from bokeh.models import ColumnDataSource, Whisker
from bokeh.plotting import figure
from bokeh.sampledata.autompg2 import autompg2 as df
from bokeh.sampledata.penguins import data
from bokeh.transform import factor_cmap, jitter, factor_mark

def _build_map_plot():
    """Return a figure with mercator axes and the OpenStreetMap Mapnik tiles."""
    plot = figure(
        x_range=(-2000000, 6000000),
        y_range=(-1000000, 7000000),
        x_axis_type="mercator",
        y_axis_type="mercator",
    )
    plot.add_tile(xyz.OpenStreetMap.Mapnik)  # type: ignore
    return plot


def _build_whisker_plot():
    """Return jittered highway-mpg points per car class with 20%-80% whiskers."""
    classes = sorted(df["class"].unique())

    p = figure(
        height=400,
        x_range=classes,
        background_fill_color="#efefef",
        title="Car class vs HWY mpg with quintile ranges",
    )
    p.xgrid.grid_line_color = None

    # Relies on groupby's default key sorting so that the per-class quantile
    # values line up positionally with the sorted ``classes`` list.
    by_class = df.groupby("class")
    upper = by_class.hwy.quantile(0.80)
    lower = by_class.hwy.quantile(0.20)
    source = ColumnDataSource(data=dict(base=classes, upper=upper, lower=lower))

    error = Whisker(
        base="base",
        upper="upper",
        lower="lower",
        source=source,
        level="annotation",
        line_width=2,
    )
    error.upper_head.size = 20
    error.lower_head.size = 20
    p.add_layout(error)

    # ``scatter`` (default marker: circle) replaces ``circle(size=...)``, which
    # newer Bokeh releases deprecate. "Light7" instead of "Light6" so that each
    # of the seven car classes gets its own colour; factors beyond the palette
    # length would otherwise be drawn in the mapper's fallback colour.
    p.scatter(
        jitter("class", 0.3, range=p.x_range),
        "hwy",
        source=df,
        alpha=0.5,
        size=13,
        line_color="white",
        color=factor_cmap("class", "Light7", classes),
    )
    return p


def _build_scatter_plot():
    """Return a penguin flipper-length vs body-mass scatter, styled by species."""
    species = sorted(data.species.unique())
    markers = ["hex", "circle_x", "triangle"]

    p = figure(title="Penguin size", background_fill_color="#fafafa")
    p.xaxis.axis_label = "Flipper Length (mm)"
    p.yaxis.axis_label = "Body Mass (g)"

    p.scatter(
        "flipper_length_mm",
        "body_mass_g",
        source=data,
        legend_group="species",
        fill_alpha=0.4,
        size=12,
        marker=factor_mark("species", markers, species),
        color=factor_cmap("species", "Category10_3", species),
    )

    # The legend only exists after a glyph with ``legend_group`` was added.
    p.legend.location = "top_left"
    p.legend.title = "Species"
    return p


def get_plot(plot_type: str):
    """Build the Bokeh figure matching the selected plot type.

    Args:
        plot_type: One of ``"map"``, ``"whisker"`` or ``"scatter"``.

    Returns:
        A freshly built Bokeh figure for the requested type, or ``None`` when
        ``plot_type`` is not one of the supported values.
    """
    if plot_type == "map":
        return _build_map_plot()
    if plot_type == "whisker":
        return _build_whisker_plot()
    if plot_type == "scatter":
        return _build_scatter_plot()
    # Unknown selection: keep the historical behaviour of returning nothing.
    return None

# UI: a radio selector next to the plot it controls.
with gr.Blocks() as demo:
    with gr.Row():
        plot_type = gr.Radio(
            choices=["scatter", "whisker", "map"],
            value="scatter",
        )
        plot = gr.Plot()

    # Redraw whenever the selection changes ...
    plot_type.change(fn=get_plot, inputs=[plot_type], outputs=[plot])
    # ... and once on page load so the initial selection is rendered.
    demo.load(fn=get_plot, inputs=[plot_type], outputs=[plot])


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()