File size: 3,561 Bytes
8b9ef60
 
 
 
 
 
 
 
 
 
 
 
 
 
95bab90
8b9ef60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba49958
 
 
 
8b9ef60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4492fe2
8b9ef60
 
 
 
 
4492fe2
8b9ef60
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from functools import partial
from typing import Dict

import gradio as gr
import numpy as np
import plotly.graph_objects as go
from huggingface_hub import from_pretrained_keras

# Location of the FordA train/test splits (tab-separated files).
ROOT_DATA_URL = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA"
TRAIN_DATA_URL = f"{ROOT_DATA_URL}/FordA_TRAIN.tsv"
TEST_DATA_URL = f"{ROOT_DATA_URL}/FordA_TEST.tsv"
# Number of samples per series; used to reshape the model input and to build
# the x-axis of the plot.
TIMESERIES_LEN = 500
# Class labels, paired positionally with the model's output scores (index 0
# first). This must be an ordered sequence: a set has no guaranteed iteration
# order (string hashing is randomized per process), so zipping a set with the
# prediction vector could attach scores to the wrong label.
CLASSES = ("Symptom does NOT exist", "Symptom exists")

# Load the pretrained Keras classifier from the Hugging Face Hub once, at
# import time, so every prediction request reuses the same model object.
model = from_pretrained_keras("keras-io/timeseries-classification-from-scratch")

# Read data
def read_data(file_url: str):
    """Load a tab-separated dataset and split it into features and labels.

    Args:
        file_url: Location of the TSV file; passed straight to ``np.loadtxt``.

    Returns:
        A tuple ``(x, y)`` where ``x`` is a 2-D array holding every column
        after the first, and ``y`` is the first column cast to ``int``.
    """
    # ndmin=2 keeps a single-row file two-dimensional; without it loadtxt
    # squeezes the result to 1-D and the column slicing below raises IndexError.
    data = np.loadtxt(file_url, delimiter="\t", ndmin=2)
    labels = data[:, 0].astype(int)
    features = data[:, 1:]
    return features, labels


# Fetch and parse both splits once at import time; the UI callbacks index
# into these module-level arrays.
x_train, y_train = read_data(file_url=TRAIN_DATA_URL)
x_test, y_test = read_data(file_url=TEST_DATA_URL)

# Helper functions
def get_prediction(row_index: int, data: np.ndarray) -> Dict[str, float]:
    """Classify one row of ``data`` and return a score for each class label.

    Args:
        row_index: Position of the row to classify. Integral floats (as a UI
            slider may deliver them) are accepted and truncated to ``int``.
        data: Array whose rows each hold ``TIMESERIES_LEN`` samples.

    Returns:
        A dict mapping each label in ``CLASSES`` to the corresponding model
        output, paired in ``CLASSES`` iteration order.
    """
    # Slider values may arrive as floats; NumPy rejects non-integer indices.
    row = data[int(row_index)]
    # The model is fed a batch of one series with a single channel.
    batch = row.reshape((1, TIMESERIES_LEN, 1))
    scores = model.predict(batch).flatten()
    return {label: float(score) for label, score in zip(CLASSES, scores)}


def create_plot(row_index: int, dataset_name: str) -> go.Figure:
    """Build a line plot of one time series from the named dataset.

    Args:
        row_index: Position of the row to plot. Integral floats (as a UI
            slider may deliver them) are accepted and truncated to ``int``.
        dataset_name: ``"Training"`` or ``"Test"``; selects which array the
            row is read from and is echoed in the figure title. Any other
            value falls back to the training data.

    Returns:
        A Plotly figure with the selected series drawn as lines and markers.
    """
    # Read from the array that matches the requested dataset. Previously the
    # training data was always plotted, so the "Test" tab showed a series that
    # did not correspond to the row being classified.
    datasets = {"Training": x_train, "Test": x_test}
    x = datasets.get(dataset_name, x_train)
    # Slider values may arrive as floats; NumPy rejects non-integer indices.
    row_index = int(row_index)
    row = x[row_index]
    scatter = go.Scatter(
        x=list(range(TIMESERIES_LEN)),
        y=row.flatten(),
        mode="lines+markers",
    )
    fig = go.Figure(data=scatter)
    fig.update_layout(title=f"Timeseries in row {row_index} of {dataset_name} set ")
    return fig


def show_tab_section(data: np.ndarray, dataset_name: str):
    """Create the widgets of one dataset tab and wire up the Predict button.

    Must be called inside an active Gradio layout context: the components
    created here are a row-index slider, a "Predict" button, a plot and a
    label. Clicking the button draws the selected series and shows the
    model's class scores for it.

    Args:
        data: Array whose rows are the series selectable in this tab.
        dataset_name: Name of the dataset, forwarded to ``create_plot``.
    """
    num_indexes = data.shape[0]
    index = gr.Slider(
        minimum=0,
        maximum=num_indexes - 1,
        # Without an explicit step Gradio derives one from the value range,
        # which can be coarser than 1 or fractional; row indices must be
        # consecutive integers so every row can be selected.
        step=1,
        label="Select the index of the row you want to classify:",
    )
    button = gr.Button("Predict")
    plot = gr.Plot()
    # Bind the tab-specific arguments so each callback only takes the slider value.
    create_plot_data = partial(create_plot, dataset_name=dataset_name)
    button.click(create_plot_data, inputs=[index], outputs=[plot])
    get_prediction_data = partial(get_prediction, data=data)
    label = gr.Label()
    button.click(get_prediction_data, inputs=[index], outputs=[label])


# Gradio Demo
# Page heading; rendered with gr.Markdown in the layout further down.
title = "# Timeseries classification from scratch"
# Introductory text (Markdown with inline HTML) shown below the heading,
# including links to the dataset paper, the model card and the Keras example.
description = """
Select a time series in the Training or Test dataset and ask the model to classify it!
<br />
<br />
The model was trained on the <a href="http://www.j-wichard.de/publications/FordPaper.pdf" target="_blank">FordA dataset</a>. Each row is a diagnostic session run on an automotive subsystem. In each session 500 samples were collected. Given a time series, the model was trained to identify if a specific symptom exists or it does not exist.
<br />
<br />
<p>
    <b>Model:</b> <a href="https://huggingface.co/keras-io/timeseries-classification-from-scratch" target="_blank">https://huggingface.co/keras-io/timeseries-classification-from-scratch</a>
    <br />
    <b>Keras Example:</b> <a href="https://keras.io/examples/timeseries/timeseries_classification_from_scratch/" target="_blank">https://keras.io/examples/timeseries/timeseries_classification_from_scratch/</a>
</p>
<br />
"""
# Footer credits; rendered with gr.Markdown after the tabs.
article = """
<div style="text-align: center;">
    Space by <a href="https://github.com/EdAbati" target="_blank">Edoardo Abati</a>
    <br />
    Keras example by <a href="https://github.com/hfawaz/" target="_blank">hfawaz</a>
</div>
"""

# Page layout: heading and description, one tab per dataset split, then credits.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Tabs():
        # (tab caption, backing array, dataset name shown in the plot title)
        for tab_caption, tab_data, tab_dataset in (
            ("Training set", x_train, "Training"),
            ("Test set", x_test, "Test"),
        ):
            with gr.TabItem(tab_caption):
                show_tab_section(data=tab_data, dataset_name=tab_dataset)
    gr.Markdown(article)

# NOTE(review): `enable_queue` looks tied to the Gradio release this Space was
# written for -- confirm it is still accepted before upgrading Gradio.
demo.launch(enable_queue=True)