File size: 7,263 Bytes
04b83c2
cc49484
a13fd2a
85f4ee8
166477b
 
 
33468ea
166477b
 
 
a13fd2a
361a4ae
04b83c2
44db048
f981c43
 
 
 
 
 
 
 
 
 
 
 
04b83c2
ced70f2
04b83c2
f981c43
 
ced70f2
f981c43
 
04b83c2
a4faf98
 
ced70f2
04b83c2
a4faf98
361a4ae
04b83c2
ced70f2
 
 
 
 
 
 
 
 
a4faf98
 
cc49484
04b83c2
361a4ae
cc49484
 
 
 
 
a13fd2a
cc49484
5a007fb
1971a6c
 
 
 
 
 
 
 
 
 
 
 
 
f981c43
62c0b96
 
 
 
 
f981c43
 
 
cc49484
 
 
 
 
a13fd2a
cc49484
 
 
166477b
cc49484
 
 
a4faf98
166477b
cc49484
1971a6c
361a4ae
 
166477b
361a4ae
 
 
 
cc49484
 
 
 
 
 
2a9dcfb
cc49484
 
44db048
164a1a8
cc49484
164a1a8
1971a6c
cc49484
44db048
164a1a8
cc49484
164a1a8
1971a6c
 
 
 
cc49484
44db048
290c9ec
 
 
1971a6c
a13fd2a
cc49484
 
164a1a8
 
 
 
cc49484
33468ea
 
 
 
 
 
 
 
cc49484
166477b
cc49484
e7395df
1971a6c
33468ea
04b83c2
ced70f2
361a4ae
04b83c2
ced70f2
44db048
04b83c2
ced70f2
290c9ec
 
44db048
290c9ec
04b83c2
ced70f2
33468ea
 
 
 
 
ced70f2
62c0b96
5a007fb
ced70f2
 
 
 
 
 
f981c43
62c0b96
1400d53
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
from typing import Any, Callable

import gradio as gr

from constants import (
    NUMBERS_LIST,
    SKETCHPAD_SIZE,
    GradioPredictModelChoice,
    GradioTrainDataChoice,
    GradioTrainModelChoice,
)

# Presumably closes any Gradio apps still running from an earlier import or
# reload (e.g. in a notebook) before this module builds a new one — see the
# Gradio `close_all` documentation. Runs as an import-time side effect.
gr.close_all()

# Browser-side JavaScript passed as `_js` to `Blocks.load` in
# CallbackHandler.register_callback, so it runs when the page loads. It:
#   * installs globalThis.setStorage(key, value), which JSON-encodes `value`
#     into localStorage under `key` (View's session-id `change` listener
#     calls it to persist the id);
#   * installs globalThis.getStorage(key), which JSON-decodes the stored
#     value (its second `value` parameter is unused);
#   * returns a one-element array holding the stored 'session_id'.
# NOTE(review): if 'session_id' was never stored, localStorage.getItem
# returns null and JSON.parse(null) yields null, so the load callback
# presumably receives None — confirm the Python side handles that.
GET_LOCALSTORAGE_JS = """
    function() {
      globalThis.setStorage = (key, value)=>{
        localStorage.setItem(key, JSON.stringify(value))
      }
       globalThis.getStorage = (key, value)=>{
        return JSON.parse(localStorage.getItem(key))
      }
       const sessionId =  getStorage('session_id')
       return [sessionId];
      }
    """


class CallbackHandler:
    """Deferred wiring of a Gradio event to a Python callback.

    The UI code records which element fires the event and which components
    are its inputs/outputs. The callback itself is supplied later through
    :meth:`register_callback`, so the layout can be built without knowing
    the application logic.
    """

    def __init__(
        self,
        block: gr.Blocks,
        element: gr.Button | gr.Blocks,
        inputs: list[Any],
        outputs: list[Any],
    ) -> None:
        """Store the wiring for a single event.

        Args:
            block: App whose context is entered while the listener is added.
            element: A ``gr.Button`` (bound to ``click``) or a ``gr.Blocks``
                (bound to ``load``).
            inputs: Components whose values are passed to the callback.
            outputs: Components updated from the callback's return value(s).
        """
        self.inputs = inputs
        self.outputs = outputs
        self.element = element
        self.block = block

    def register_callback(self, fn: Callable[..., Any]) -> None:
        """Attach ``fn`` to the stored element's event.

        A ``gr.Button`` gets a ``click`` listener. A ``gr.Blocks`` gets a
        ``load`` listener that first runs ``GET_LOCALSTORAGE_JS`` in the
        browser; that snippet returns ``[sessionId]`` read from
        ``localStorage``, which (per Gradio 3.x ``_js`` semantics) is what
        ``fn`` then receives in place of the inputs' current values.

        Args:
            fn: The Python callback to run when the event fires.

        Raises:
            TypeError: If ``element`` is neither a ``gr.Button`` nor a
                ``gr.Blocks``. Such elements used to be silently ignored,
                leaving the callback unregistered.
        """
        # Validate before entering the Blocks context so an unsupported
        # element fails loudly instead of registering nothing.
        if not isinstance(self.element, (gr.Button, gr.Blocks)):
            raise TypeError(
                "CallbackHandler supports gr.Button or gr.Blocks elements, "
                f"got {type(self.element).__name__}"
            )
        # Listeners must be added inside the app's context.
        with self.block:
            if isinstance(self.element, gr.Button):
                self.element.click(fn=fn, inputs=self.inputs, outputs=self.outputs)
            else:
                self.element.load(
                    fn=fn,
                    inputs=self.inputs,
                    outputs=self.outputs,
                    _js=GET_LOCALSTORAGE_JS,
                )


class View:
    """Gradio UI for drawing, labelling, training on and classifying digits.

    The constructor only lays out components. Behaviour is attached later by
    calling ``register_callback(fn)`` on each ``*_handler`` attribute
    (``CallbackHandler`` instances). Inputs and outputs for each handler:

    - ``save_img_btn_handler``: (sketch, label) -> image-count table
    - ``train_btn_handler``: (model, dataset, epochs, augment) -> nothing
    - ``show_sample_btn_handler``: (dataset, augment) -> gallery
    - ``predict_btn_handler``: (model, sketch) -> label
    - ``new_session_btn_handler``: () -> session id text
    - ``load_app_handler``: (session id read from localStorage) ->
      session id text, image-count table
    """

    def __init__(self) -> None:
        """Build the four tabs and create one CallbackHandler per event."""
        with gr.Blocks() as self.app:
            gr.Markdown(
                """
            # Mouse written digits recognition app
            Follow instructions in tabs below:
            """
            )
            with gr.Tab("Session"):
                with gr.Row():
                    with gr.Column():
                        with gr.Box():
                            gr.Markdown("## Current session:")
                            self.session_id_text = gr.Markdown("")
                    with gr.Column():
                        with gr.Box():
                            gr.Markdown(
                                "New session will reset generated data and model"
                            )
                            new_session_btn = gr.Button(
                                "New Session", variant="primary"
                            )

                # Persist every change of the displayed session id in the
                # browser's localStorage via the setStorage helper installed
                # by GET_LOCALSTORAGE_JS. The Python fn is a deliberate no-op;
                # all the work happens in the `_js` snippet.
                # TODO: Currently new session id is stored on text field change,
                # should be stored on button press
                self.session_id_text.change(
                    fn=lambda *args: None,
                    inputs=[self.session_id_text],
                    outputs=None,
                    _js="(v)=>{ setStorage('session_id',v) }",
                )
            with gr.Tab("Generate Data"):
                gr.Markdown(
                    """
                Draw a digit with mouse, select the drawn digit and hit save
                """
                )
                with gr.Row():
                    with gr.Column():
                        data_sketchpad = gr.Sketchpad(
                            shape=(SKETCHPAD_SIZE, SKETCHPAD_SIZE),
                            brush_radius=2,
                        )
                    with gr.Column():
                        digit_label = gr.Dropdown(
                            choices=NUMBERS_LIST, label="Select label"
                        )
                        save_btn = gr.Button("Save", variant="primary")
                        # One zero-count cell per label column. The width is
                        # derived from NUMBERS_LIST so the row always matches
                        # the headers.
                        data_count_table_init = [["0"] * len(NUMBERS_LIST)]
                        data_count_table = gr.DataFrame(
                            headers=NUMBERS_LIST,
                            value=data_count_table_init,
                            label="Image Count",
                            interactive=False,
                        )
            with gr.Tab("Finetune model"):
                gr.Markdown(
                    """
                Fine-tune a new model with existing/new data.
                Data split for validation is random at 20%
                """
                )
                with gr.Row():
                    with gr.Column():
                        model_choice = gr.Dropdown(
                            choices=[choice.value for choice in GradioTrainModelChoice],
                            label="Select model",
                            value=GradioTrainModelChoice.MNIST_FINETUNED.value,
                            interactive=True,
                        )
                        data_choice = gr.Dropdown(
                            choices=[choice.value for choice in GradioTrainDataChoice],
                            label="Select DataSet",
                            value=GradioTrainDataChoice.PREDEFINED.value,
                            interactive=True,
                        )
                        epochs = gr.Number(
                            value=3, label="Epochs", precision=0, interactive=True
                        )
                        aug_choice = gr.Checkbox(
                            label="Augment Data",
                            info=(
                                "With probability of 0.75, applies random affine "
                                "transforms to selected dataset at batch level"
                            ),
                        )
                        train_btn = gr.Button(value="Train", variant="primary")

                    with gr.Column():
                        show_sample_btn = gr.Button(value="Show random sample")
                        data_sample = gr.Gallery(visible=True, show_label=True)
                        data_sample.style(
                            columns=3, rows=3, object_fit="contain", height="auto"
                        )
            with gr.Tab("Predict"):
                model_choice_pred = gr.Dropdown(
                    choices=[choice.value for choice in GradioPredictModelChoice],
                    label="Select model",
                    value=GradioPredictModelChoice.PRETRAINED.value,
                    interactive=True,
                )

                predict_sketchpad = gr.Sketchpad(
                    shape=(SKETCHPAD_SIZE, SKETCHPAD_SIZE),
                    brush_radius=2,
                )
                predict_btn = gr.Button("Classify", variant="primary")
                label = gr.Label(num_top_classes=5)

        # Event wiring. The callbacks themselves are supplied later through
        # each handler's register_callback().
        self.save_img_btn_handler = CallbackHandler(
            self.app, save_btn, [data_sketchpad, digit_label], [data_count_table]
        )
        self.train_btn_handler = CallbackHandler(
            self.app, train_btn, [model_choice, data_choice, epochs, aug_choice], []
        )
        self.show_sample_btn_handler = CallbackHandler(
            self.app,
            show_sample_btn,
            inputs=[data_choice, aug_choice],
            outputs=[data_sample],
        )
        self.predict_btn_handler = CallbackHandler(
            self.app,
            predict_btn,
            [model_choice_pred, predict_sketchpad],
            [label],
        )
        self.new_session_btn_handler = CallbackHandler(
            self.app, new_session_btn, [], [self.session_id_text]
        )
        # Page-load event: CallbackHandler attaches GET_LOCALSTORAGE_JS here,
        # so the stored session id (not the empty Markdown value) reaches
        # the callback.
        self.load_app_handler = CallbackHandler(
            self.app,
            self.app,
            [self.session_id_text],
            [self.session_id_text, data_count_table],
        )

    def run_gui(self) -> None:
        """Launch the Gradio app (blocking call to ``Blocks.launch``)."""
        self.app.launch()