Spaces:
Sleeping
Sleeping
File size: 7,263 Bytes
04b83c2 cc49484 a13fd2a 85f4ee8 166477b 33468ea 166477b a13fd2a 361a4ae 04b83c2 44db048 f981c43 04b83c2 ced70f2 04b83c2 f981c43 ced70f2 f981c43 04b83c2 a4faf98 ced70f2 04b83c2 a4faf98 361a4ae 04b83c2 ced70f2 a4faf98 cc49484 04b83c2 361a4ae cc49484 a13fd2a cc49484 5a007fb 1971a6c f981c43 62c0b96 f981c43 cc49484 a13fd2a cc49484 166477b cc49484 a4faf98 166477b cc49484 1971a6c 361a4ae 166477b 361a4ae cc49484 2a9dcfb cc49484 44db048 164a1a8 cc49484 164a1a8 1971a6c cc49484 44db048 164a1a8 cc49484 164a1a8 1971a6c cc49484 44db048 290c9ec 1971a6c a13fd2a cc49484 164a1a8 cc49484 33468ea cc49484 166477b cc49484 e7395df 1971a6c 33468ea 04b83c2 ced70f2 361a4ae 04b83c2 ced70f2 44db048 04b83c2 ced70f2 290c9ec 44db048 290c9ec 04b83c2 ced70f2 33468ea ced70f2 62c0b96 5a007fb ced70f2 f981c43 62c0b96 1400d53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
from typing import Any, Callable
import gradio as gr
from constants import (
NUMBERS_LIST,
SKETCHPAD_SIZE,
GradioPredictModelChoice,
GradioTrainDataChoice,
GradioTrainModelChoice,
)
# Close any Gradio apps still running in this interpreter before a new one is
# built (see Gradio's `close_all`). NOTE: this is a module-level side effect
# and runs as soon as the module is imported.
gr.close_all()
# Browser-side JavaScript passed as `_js` to the app's load event.
# It installs two global helpers on `globalThis` that store values as JSON in
# the browser's localStorage:
#   * setStorage(key, value) - JSON-encode `value` and save it under `key`.
#   * getStorage(key)        - read and JSON-decode `key` (the second `value`
#                              parameter is declared but unused).
# It then returns the stored 'session_id' wrapped in a one-element list. With
# Gradio 3.x `_js`, that list replaces the load callback's input values.
# localStorage.getItem returns null for a missing key and JSON.parse(null) is
# null, so on a first visit the callback receives None as the session id.
GET_LOCALSTORAGE_JS = """
function() {
globalThis.setStorage = (key, value)=>{
localStorage.setItem(key, JSON.stringify(value))
}
globalThis.getStorage = (key, value)=>{
return JSON.parse(localStorage.getItem(key))
}
const sessionId = getStorage('session_id')
return [sessionId];
}
"""
class CallbackHandler:
    """Deferred binding of a Python callback to an event on a Gradio element.

    The view creates one handler per interactive element. Whoever owns the
    application logic later calls :meth:`register_callback` with the function
    to run. Supported elements:

    * ``gr.Button``: the callback is bound to the button's ``click`` event.
    * ``gr.Blocks``: the callback is bound to the app's ``load`` event, and
      ``GET_LOCALSTORAGE_JS`` runs in the browser first. With Gradio 3.x
      ``_js`` semantics, the callback's input values are replaced by what
      that snippet returns.
    """

    def __init__(
        self,
        block: gr.Blocks,
        element: gr.Button | gr.Blocks,
        inputs: list[Any],
        outputs: list[Any],
    ) -> None:
        """Store everything needed to register a callback later.

        Args:
            block: The ``gr.Blocks`` app whose context the listener is added in.
            element: The button, or the app itself, whose event triggers the
                callback.
            inputs: Components whose values are passed to the callback.
            outputs: Components updated with the callback's return value(s).
        """
        self.inputs = inputs
        self.outputs = outputs
        self.element = element
        self.block = block

    def register_callback(self, fn: Callable[..., Any]) -> None:
        """Bind ``fn`` to the event matching the type of ``self.element``.

        Args:
            fn: The function to run when the event fires.

        Raises:
            TypeError: If ``self.element`` is neither a ``gr.Button`` nor a
                ``gr.Blocks``. Such elements were previously ignored silently,
                so ``fn`` would never run.
        """
        # Validate before entering the Blocks context so a bad element never
        # leaves the context half-entered.
        if not isinstance(self.element, (gr.Button, gr.Blocks)):
            raise TypeError(
                "register_callback supports gr.Button or gr.Blocks elements, "
                f"got {type(self.element).__name__}"
            )
        # Re-enter the app's context. This is presumably needed because
        # callbacks are registered after the `with gr.Blocks()` body that
        # created the components has already exited.
        with self.block:
            if isinstance(self.element, gr.Button):
                self.element.click(fn=fn, inputs=self.inputs, outputs=self.outputs)
            else:
                # Page load: fetch the stored session id in the browser first.
                self.element.load(
                    fn=fn,
                    inputs=self.inputs,
                    outputs=self.outputs,
                    _js=GET_LOCALSTORAGE_JS,
                )
class View:
    """Gradio UI for collecting, training on and classifying mouse-drawn digits.

    Building the view only lays out components. No application logic is
    attached here. Each interactive element is exposed as a
    :class:`CallbackHandler` attribute, so a controller (presumably elsewhere
    in the project) can register the function to run for it:

    * ``save_img_btn_handler``:
      (sketch, digit label) -> image-count table.
    * ``train_btn_handler``:
      (model, dataset, epochs, augment flag) -> no outputs.
    * ``show_sample_btn_handler``:
      (dataset, augment flag) -> sample gallery.
    * ``predict_btn_handler``:
      (model, sketch) -> Label showing the top 5 classes.
    * ``new_session_btn_handler``:
      () -> session id text.
    * ``load_app_handler``:
      (session id read from browser localStorage) -> (session id text,
      image-count table).
    """

    def __init__(self) -> None:
        """Lay out all tabs and create the callback handlers.

        No callbacks are registered here; that happens later through the
        handler attributes.
        """
        with gr.Blocks() as self.app:
            gr.Markdown(
                """
# Mouse written digits recognition app
Follow instructions in tabs below:
"""
            )

            with gr.Tab("Session"):
                with gr.Row():
                    with gr.Column():
                        with gr.Box():
                            gr.Markdown("## Current session:")
                            self.session_id_text = gr.Markdown("")
                    with gr.Column():
                        with gr.Box():
                            gr.Markdown(
                                "New session will reset generated data and model"
                            )
                            new_session_btn = gr.Button(
                                "New Session", variant="primary"
                            )
                # TODO: Currently new session id is stored on text field change,
                # should be stored on button press
                # Persist every new session id in browser localStorage, using
                # the `setStorage` helper that GET_LOCALSTORAGE_JS installs on
                # page load. The Python fn is a no-op; the work happens in `_js`.
                self.session_id_text.change(
                    fn=lambda *args: None,
                    inputs=[self.session_id_text],
                    outputs=None,
                    _js="(v)=>{ setStorage('session_id',v) }",
                )

            with gr.Tab("Generate Data"):
                gr.Markdown(
                    """
Draw a digit with mouse, select the drawn digit and hit save
"""
                )
                with gr.Row():
                    with gr.Column():
                        data_sketchpad = gr.Sketchpad(
                            shape=(SKETCHPAD_SIZE, SKETCHPAD_SIZE),
                            brush_radius=2,
                        )
                    with gr.Column():
                        digit_label = gr.Dropdown(
                            choices=NUMBERS_LIST, label="Select label"
                        )
                        save_btn = gr.Button("Save", variant="primary")
                        # One row of per-digit image counts. It is sized from
                        # the header list (instead of a literal 10) so that the
                        # row and the headers can never disagree.
                        data_count_table_init = [["0"] * len(NUMBERS_LIST)]
                        data_count_table = gr.DataFrame(
                            headers=NUMBERS_LIST,
                            value=data_count_table_init,
                            label="Image Count",
                            interactive=False,
                        )

            with gr.Tab("Finetune model"):
                gr.Markdown(
                    """
Fine-tune a new model with existing/new data.
Data split for validation is random at 20%
"""
                )
                with gr.Row():
                    with gr.Column():
                        model_choice = gr.Dropdown(
                            choices=[choice.value for choice in GradioTrainModelChoice],
                            label="Select model",
                            value=GradioTrainModelChoice.MNIST_FINETUNED.value,
                            interactive=True,
                        )
                        data_choice = gr.Dropdown(
                            choices=[choice.value for choice in GradioTrainDataChoice],
                            label="Select DataSet",
                            value=GradioTrainDataChoice.PREDEFINED.value,
                            interactive=True,
                        )
                        # precision=0 makes the field yield whole numbers.
                        epochs = gr.Number(
                            value=3, label="Epochs", precision=0, interactive=True
                        )
                        aug_choice = gr.Checkbox(
                            label="Augment Data",
                            info="With probability of 0.75, applies random affine transforms to selected dataset at batch level",
                        )
                        train_btn = gr.Button(value="Train", variant="primary")
                    with gr.Column():
                        show_sample_btn = gr.Button(value="Show random sample")
                        data_sample = gr.Gallery(visible=True, show_label=True)
                        data_sample.style(
                            columns=3, rows=3, object_fit="contain", height="auto"
                        )

            with gr.Tab("Predict"):
                model_choice_pred = gr.Dropdown(
                    choices=[choice.value for choice in GradioPredictModelChoice],
                    label="Select model",
                    value=GradioPredictModelChoice.PRETRAINED.value,
                    interactive=True,
                )
                predict_sketchpad = gr.Sketchpad(
                    shape=(SKETCHPAD_SIZE, SKETCHPAD_SIZE),
                    brush_radius=2,
                )
                predict_btn = gr.Button("Classify", variant="primary")
                label = gr.Label(num_top_classes=5)

        # Handlers only store their wiring. CallbackHandler re-enters
        # self.app when a callback is actually registered.
        self.save_img_btn_handler = CallbackHandler(
            self.app, save_btn, [data_sketchpad, digit_label], [data_count_table]
        )
        self.train_btn_handler = CallbackHandler(
            self.app, train_btn, [model_choice, data_choice, epochs, aug_choice], []
        )
        self.show_sample_btn_handler = CallbackHandler(
            self.app,
            show_sample_btn,
            inputs=[data_choice, aug_choice],
            outputs=[data_sample],
        )
        self.predict_btn_handler = CallbackHandler(
            self.app,
            predict_btn,
            [model_choice_pred, predict_sketchpad],
            [label],
        )
        self.new_session_btn_handler = CallbackHandler(
            self.app, new_session_btn, [], [self.session_id_text]
        )
        # Page-load handler. Its single input value is replaced in the browser
        # by the session id stored in localStorage.
        self.load_app_handler = CallbackHandler(
            self.app,
            self.app,
            [self.session_id_text],
            [self.session_id_text, data_count_table],
        )

    def run_gui(self, **launch_kwargs: Any) -> None:
        """Launch the Gradio app.

        Args:
            **launch_kwargs: Optional keyword arguments forwarded unchanged to
                ``gr.Blocks.launch`` (for example ``server_name`` or
                ``share``). None are passed by default, which matches the
                previous behavior.
        """
        self.app.launch(**launch_kwargs)
|