Spaces:
Sleeping
Sleeping
from typing import Any, Callable | |
import gradio as gr | |
from constants import ( | |
NUMBERS_LIST, | |
SKETCHPAD_SIZE, | |
GradioPredictModelChoice, | |
GradioTrainDataChoice, | |
GradioTrainModelChoice, | |
) | |
# Shut down any Gradio apps already running in this Python process before this
# module builds a new one.
# NOTE(review): presumably so re-importing the module (e.g. in a notebook) does
# not leave stale servers/ports behind — confirm.
gr.close_all()

# JavaScript passed as the `_js` hook of the app's page-load event
# (see CallbackHandler.register_callback).
# When run in the browser it:
#   * installs two global helpers: `setStorage(key, value)` stores
#     JSON.stringify(value) in localStorage, and `getStorage(key)` returns the
#     JSON-decoded value (null when the key is absent; its second `value`
#     parameter is unused);
#   * returns the stored 'session_id' wrapped in a one-element list, which
#     becomes the load callback's single input value.
# `setStorage` is also called by the session-id `.change` hook in View, so this
# script must have run (i.e. the page must have loaded) before that hook fires.
GET_LOCALSTORAGE_JS = """
function() {
globalThis.setStorage = (key, value)=>{
localStorage.setItem(key, JSON.stringify(value))
}
globalThis.getStorage = (key, value)=>{
return JSON.parse(localStorage.getItem(key))
}
const sessionId = getStorage('session_id')
return [sessionId];
}
"""
class CallbackHandler:
    """Deferred event binding for one Gradio trigger element.

    Stores the trigger (a ``gr.Button`` or the whole ``gr.Blocks`` app) together
    with the components that feed and receive its callback. The layout can be
    built first and the Python logic attached later through
    :meth:`register_callback`.
    """

    def __init__(
        self,
        block: gr.Blocks,
        element: gr.Button | gr.Blocks,
        inputs: list[Any],
        outputs: list[Any],
    ) -> None:
        """Record everything needed to register a callback later.

        Args:
            block: The ``gr.Blocks`` app the event listener is registered in.
            element: The trigger. A ``gr.Button`` is bound to its click event;
                a ``gr.Blocks`` is bound to its page-load event.
            inputs: Components whose values are passed to the callback.
            outputs: Components that receive the callback's return value(s).
        """
        self.inputs = inputs
        self.outputs = outputs
        self.element = element
        self.block = block

    def register_callback(self, fn: Callable[..., Any]) -> None:
        """Attach ``fn`` to the element's event.

        A ``gr.Button`` gets a click listener. A ``gr.Blocks`` gets a load
        listener whose ``_js`` hook is ``GET_LOCALSTORAGE_JS``. That hook returns
        the session id saved in the browser's localStorage as the callback's
        input.

        Args:
            fn: The Python function to run when the event fires.

        Raises:
            TypeError: If ``element`` is neither a ``gr.Button`` nor a
                ``gr.Blocks``. Such elements used to be ignored silently, which
                left the callback unregistered with no error.
        """
        # Validate before touching the Blocks context so an unsupported element
        # fails loudly instead of silently registering nothing.
        if not isinstance(self.element, (gr.Button, gr.Blocks)):
            raise TypeError(
                "CallbackHandler supports gr.Button or gr.Blocks elements, "
                f"got {type(self.element).__name__}"
            )
        # Gradio event listeners must be created inside a Blocks context. The
        # layout was built earlier, so re-enter the app's context here.
        with self.block:
            if isinstance(self.element, gr.Button):
                self.element.click(fn=fn, inputs=self.inputs, outputs=self.outputs)
            else:
                self.element.load(
                    fn=fn,
                    inputs=self.inputs,
                    outputs=self.outputs,
                    _js=GET_LOCALSTORAGE_JS,
                )
class View:
    """Gradio front end for the mouse-written digit recognition app.

    Builds a ``gr.Blocks`` app with four tabs: Session, Generate Data,
    Finetune model and Predict. No application logic lives here. Each trigger
    (a button, or the app's page-load event) is wrapped in a ``CallbackHandler``
    attribute, and the owner of this view attaches the actual Python functions
    with ``handler.register_callback(fn)``.

    Attributes:
        app: The assembled ``gr.Blocks`` application.
        session_id_text: Markdown component showing the current session id.
        save_img_btn_handler: "Save" button. Inputs: (sketch, digit label).
            Output: the per-digit image-count table.
        train_btn_handler: "Train" button. Inputs: (model, dataset, epochs,
            augmentation flag). No outputs.
        show_sample_btn_handler: "Show random sample" button. Inputs: (dataset,
            augmentation flag). Output: the sample gallery.
        predict_btn_handler: "Classify" button. Inputs: (model, sketch).
            Output: a label component showing the top 5 classes.
        new_session_btn_handler: "New Session" button. No inputs. Output: the
            session id text.
        load_app_handler: Page-load event. Input: the session id, supplied in
            the browser by ``GET_LOCALSTORAGE_JS`` from localStorage. Outputs:
            the session id text and the image-count table.
    """

    def __init__(self) -> None:
        with gr.Blocks() as self.app:
            gr.Markdown(
                """
# Mouse written digits recognition app
Follow instructions in tabs below:
"""
            )

            # --- Session tab: current session id plus a "New Session" button.
            with gr.Tab("Session"):
                with gr.Row():
                    with gr.Column():
                        with gr.Box():
                            gr.Markdown("## Current session:")
                            self.session_id_text = gr.Markdown("")
                    with gr.Column():
                        with gr.Box():
                            gr.Markdown(
                                "New session will reset generated data and model"
                            )
                            new_session_btn = gr.Button(
                                "New Session", variant="primary"
                            )
                # TODO: Currently new session id is stored on text field change,
                # should be stored on button press
                # Browser-only hook: whenever the displayed id changes, save it to
                # localStorage via the setStorage helper that GET_LOCALSTORAGE_JS
                # installs on page load. The Python side is a deliberate no-op.
                self.session_id_text.change(
                    fn=lambda *args: None,
                    inputs=[self.session_id_text],
                    outputs=None,
                    _js="(v)=>{ setStorage('session_id',v) }",
                )

            # --- Generate Data tab: draw a digit, label it, save it.
            with gr.Tab("Generate Data"):
                gr.Markdown(
                    """
Draw a digit with mouse, select the drawn digit and hit save
"""
                )
                with gr.Row():
                    with gr.Column():
                        data_sketchpad = gr.Sketchpad(
                            shape=(SKETCHPAD_SIZE, SKETCHPAD_SIZE),
                            brush_radius=2,
                        )
                    with gr.Column():
                        digit_label = gr.Dropdown(
                            choices=NUMBERS_LIST, label="Select label"
                        )
                        save_btn = gr.Button("Save", variant="primary")
                        # A single row with one "0" count per digit. Its width
                        # must match the headers, so derive it from NUMBERS_LIST
                        # instead of hard-coding 10.
                        data_count_table_init = [["0"] * len(NUMBERS_LIST)]
                        data_count_table = gr.DataFrame(
                            headers=NUMBERS_LIST,
                            value=data_count_table_init,
                            label="Image Count",
                            interactive=False,
                        )

            # --- Finetune model tab: training options and a data-sample preview.
            with gr.Tab("Finetune model"):
                gr.Markdown(
                    """
Fine-tune a new model with existing/new data.
Data split for validation is random at 20%
"""
                )
                with gr.Row():
                    with gr.Column():
                        model_choice = gr.Dropdown(
                            choices=[choice.value for choice in GradioTrainModelChoice],
                            label="Select model",
                            value=GradioTrainModelChoice.MNIST_FINETUNED.value,
                            interactive=True,
                        )
                        data_choice = gr.Dropdown(
                            choices=[choice.value for choice in GradioTrainDataChoice],
                            label="Select DataSet",
                            value=GradioTrainDataChoice.PREDEFINED.value,
                            interactive=True,
                        )
                        # precision=0 makes Gradio deliver an integer epoch count.
                        epochs = gr.Number(
                            value=3, label="Epochs", precision=0, interactive=True
                        )
                        aug_choice = gr.Checkbox(
                            label="Augment Data",
                            info="With probability of 0.75, applies random affine transforms to selected dataset at batch level",
                        )
                        train_btn = gr.Button(value="Train", variant="primary")
                    with gr.Column():
                        show_sample_btn = gr.Button(value="Show random sample")
                        data_sample = gr.Gallery(visible=True, show_label=True)
                        data_sample.style(
                            columns=3, rows=3, object_fit="contain", height="auto"
                        )

            # --- Predict tab: choose a model, draw a digit, classify it.
            with gr.Tab("Predict"):
                model_choice_pred = gr.Dropdown(
                    choices=[choice.value for choice in GradioPredictModelChoice],
                    label="Select model",
                    value=GradioPredictModelChoice.PRETRAINED.value,
                    interactive=True,
                )
                predict_sketchpad = gr.Sketchpad(
                    shape=(SKETCHPAD_SIZE, SKETCHPAD_SIZE),
                    brush_radius=2,
                )
                predict_btn = gr.Button("Classify", variant="primary")
                label = gr.Label(num_top_classes=5)

        # Event wiring placeholders. Each handler only records its trigger and
        # its input/output components; the logic is attached later through
        # register_callback(fn), which re-enters self.app itself.
        self.save_img_btn_handler = CallbackHandler(
            self.app, save_btn, [data_sketchpad, digit_label], [data_count_table]
        )
        self.train_btn_handler = CallbackHandler(
            self.app, train_btn, [model_choice, data_choice, epochs, aug_choice], []
        )
        self.show_sample_btn_handler = CallbackHandler(
            self.app,
            show_sample_btn,
            inputs=[data_choice, aug_choice],
            outputs=[data_sample],
        )
        self.predict_btn_handler = CallbackHandler(
            self.app,
            predict_btn,
            [model_choice_pred, predict_sketchpad],
            [label],
        )
        self.new_session_btn_handler = CallbackHandler(
            self.app, new_session_btn, [], [self.session_id_text]
        )
        # Passing the app itself as the element selects the page-load event.
        self.load_app_handler = CallbackHandler(
            self.app,
            self.app,
            [self.session_id_text],
            [self.session_id_text, data_count_table],
        )

    def run_gui(self) -> None:
        """Launch the Gradio app with its default launch settings."""
        self.app.launch()