File size: 1,326 Bytes
f6bf0dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr

from vega_datasets import data

# Load both demo datasets once, at import time, so every callback below reuses
# the same objects instead of reloading them (presumably pandas DataFrames, as
# returned by vega_datasets — confirm if the loader changes).
cars, iris = data.cars(), data.iris()


def scatter_plot_fn(dataset):
    """Return a ``gr.ScatterPlot`` configured for the chosen dataset.

    Args:
        dataset: Value selected in the dropdown. Exactly ``"iris"`` picks the
            iris measurements; every other value (normally ``"cars"``) falls
            through to the cars data.

    Returns:
        A new ``gr.ScatterPlot`` whose data, axes, colouring, tooltip, title
        and caption match the selected dataset.
    """
    show_iris = dataset == "iris"
    if show_iris:
        plot_options = {
            "value": iris,
            "x": "petalWidth",
            "y": "petalLength",
            "color": "species",
            "title": "Iris Dataset",
            "color_legend_title": "Species",
            "x_title": "Petal Width",
            "y_title": "Petal Length",
            "tooltip": ["petalWidth", "petalLength", "species"],
            # NOTE(review): an explicit empty caption presumably replaces the
            # cars caption when the user switches datasets — confirm.
            "caption": "",
        }
    else:
        # Fallback branch: anything that is not "iris" renders the cars data.
        plot_options = {
            "value": cars,
            "x": "Horsepower",
            "y": "Miles_per_Gallon",
            "color": "Origin",
            "tooltip": "Name",
            "title": "Car Data",
            "y_title": "Miles per Gallon",
            "color_legend_title": "Origin of Car",
            "caption": "MPG vs Horsepower of various cars",
        }
    return gr.ScatterPlot(**plot_options)


# Assemble the demo UI: a dataset selector next to a scatter plot, with the
# plot rebuilt by scatter_plot_fn whenever the selection changes.
with gr.Blocks() as scatter_plot:
    # One row with two columns: the selector on the left, the plot on the right.
    with gr.Row():
        with gr.Column():
            # Dataset selector; starts on "cars".
            dataset = gr.Dropdown(choices=["cars", "iris"], value="cars")
        with gr.Column():
            # Created empty; its contents come from scatter_plot_fn below.
            plot = gr.ScatterPlot(show_label=False)
    # Re-render the plot each time the dropdown value changes.
    dataset.change(scatter_plot_fn, inputs=dataset, outputs=plot)
    # Also run scatter_plot_fn when the app loads, so the plot is populated for
    # the dropdown's initial value before the user changes anything.
    scatter_plot.load(fn=scatter_plot_fn, inputs=dataset, outputs=plot)

# Start the Gradio app only when this file is executed directly, not on import.
if __name__ == "__main__":
    scatter_plot.launch()