shpr
Add user finetune data table population on app load
ced70f2
from typing import Any, Callable
import gradio as gr
from constants import (
NUMBERS_LIST,
SKETCHPAD_SIZE,
GradioPredictModelChoice,
GradioTrainDataChoice,
GradioTrainModelChoice,
)
# Module-level side effect executed on import.
# NOTE(review): presumably shuts down Gradio apps left running by an earlier
# run in the same process -- confirm against the Gradio docs.
gr.close_all()
# JavaScript snippet passed as `_js` to `Blocks.load` in
# `CallbackHandler.register_callback`. It:
#   * defines global `setStorage(key, value)`, which JSON-encodes `value` into
#     the browser's localStorage;
#   * defines global `getStorage(key)`, which JSON-decodes the stored entry
#     (its second parameter, `value`, is unused);
#   * returns the stored 'session_id' wrapped in a one-element array.
# The string is runtime content; keep it byte-for-byte.
GET_LOCALSTORAGE_JS = """
function() {
globalThis.setStorage = (key, value)=>{
localStorage.setItem(key, JSON.stringify(value))
}
globalThis.getStorage = (key, value)=>{
return JSON.parse(localStorage.getItem(key))
}
const sessionId = getStorage('session_id')
return [sessionId];
}
"""
class CallbackHandler:
    """Remember how one UI element is wired so a callback can be bound later.

    The handler stores the owning ``gr.Blocks`` app, the element that fires
    the event, and the input/output components. The Python function itself is
    supplied afterwards through ``register_callback``.
    """

    def __init__(
        self,
        block: gr.Blocks,
        element: gr.Button | gr.Blocks,
        inputs: list[Any],
        outputs: list[Any],
    ) -> None:
        """Store the wiring.

        Args:
            block: App whose context is re-entered when the callback is bound.
            element: Event source; a button (click) or the app itself (load).
            inputs: Components whose values are passed to the callback.
            outputs: Components updated with the callback's return values.
        """
        self.block = block
        self.element = element
        self.inputs = inputs
        self.outputs = outputs

    def register_callback(self, fn: Callable[..., Any]) -> None:
        """Bind ``fn`` to the stored element's event.

        A ``gr.Button`` gets ``fn`` as its click listener. A ``gr.Blocks``
        gets ``fn`` as its load listener, with ``GET_LOCALSTORAGE_JS`` run in
        the browser first. Any other element type is left without a listener.
        """
        wiring = {"fn": fn, "inputs": self.inputs, "outputs": self.outputs}
        source = self.element
        # Listeners are attached inside the app's context.
        with self.block:
            if isinstance(source, gr.Button):
                source.click(**wiring)
            elif isinstance(source, gr.Blocks):
                source.load(_js=GET_LOCALSTORAGE_JS, **wiring)
class View:
    """Gradio UI for the mouse-written digit recognition app.

    Constructing a ``View`` builds one ``gr.Blocks`` app (``self.app``) with
    four tabs -- Session, Generate Data, Finetune model, Predict -- and
    creates one ``CallbackHandler`` attribute per interactive element. Only
    the browser-side persistence of the session id is wired here; Python
    callbacks are attached afterwards via each handler's
    ``register_callback``.

    Handler attributes (inputs -> outputs):
        save_img_btn_handler: sketchpad, digit label -> image count table.
        train_btn_handler: model, dataset, epochs, augment flag -> nothing.
        show_sample_btn_handler: dataset, augment flag -> sample gallery.
        predict_btn_handler: model, sketchpad -> label.
        new_session_btn_handler: nothing -> session id text.
        load_app_handler (app load): session id text -> session id text,
            image count table.

    NOTE(review): the nesting of the layout containers was reconstructed from
    a copy of this file whose indentation had been lost; confirm row/column
    placement against the rendered app.
    """

    def __init__(self) -> None:
        """Lay out all tabs and create the callback handlers."""
        with gr.Blocks() as self.app:
            gr.Markdown(
                """
                # Mouse written digits recognition app
                Follow instructions in tabs below:
                """
            )

            with gr.Tab("Session"):
                with gr.Row():
                    with gr.Column():
                        with gr.Box():
                            gr.Markdown("## Current session:")
                            self.session_id_text = gr.Markdown("")
                    with gr.Column():
                        with gr.Box():
                            gr.Markdown(
                                "New session will reset generated data and model"
                            )
                            new_session_btn = gr.Button(
                                "New Session", variant="primary"
                            )
                # TODO: Currently new session id is stored on text field change,
                # should be stored on button press
                # The Python side is a no-op; the JS writes the displayed
                # session id into the browser's localStorage.
                self.session_id_text.change(
                    fn=lambda *args: None,
                    inputs=[self.session_id_text],
                    outputs=None,
                    _js="(v)=>{ setStorage('session_id',v) }",
                )

            with gr.Tab("Generate Data"):
                gr.Markdown(
                    """
                    Draw a digit with mouse, select the drawn digit and hit save
                    """
                )
                with gr.Row():
                    with gr.Column():
                        data_sketchpad = gr.Sketchpad(
                            shape=(SKETCHPAD_SIZE, SKETCHPAD_SIZE),
                            brush_radius=2,
                        )
                    with gr.Column():
                        digit_label = gr.Dropdown(
                            choices=NUMBERS_LIST, label="Select label"
                        )
                        save_btn = gr.Button("Save", variant="primary")
                # One zero per header so the row width always matches
                # NUMBERS_LIST instead of a hard-coded 10.
                data_count_table_init = [["0"] * len(NUMBERS_LIST)]
                data_count_table = gr.DataFrame(
                    headers=NUMBERS_LIST,
                    value=data_count_table_init,
                    label="Image Count",
                    interactive=False,
                )

            with gr.Tab("Finetune model"):
                gr.Markdown(
                    """
                    Fine-tune a new model with existing/new data.
                    Data split for validation is random at 20%
                    """
                )
                with gr.Row():
                    with gr.Column():
                        model_choice = gr.Dropdown(
                            choices=[choice.value for choice in GradioTrainModelChoice],
                            label="Select model",
                            value=GradioTrainModelChoice.MNIST_FINETUNED.value,
                            interactive=True,
                        )
                        data_choice = gr.Dropdown(
                            choices=[choice.value for choice in GradioTrainDataChoice],
                            label="Select DataSet",
                            value=GradioTrainDataChoice.PREDEFINED.value,
                            interactive=True,
                        )
                        epochs = gr.Number(
                            value=3, label="Epochs", precision=0, interactive=True
                        )
                        aug_choice = gr.Checkbox(
                            label="Augment Data",
                            info="With probability of 0.75, applies random affine transforms to selected dataset at batch level",
                        )
                        train_btn = gr.Button(value="Train", variant="primary")
                    with gr.Column():
                        show_sample_btn = gr.Button(value="Show random sample")
                        data_sample = gr.Gallery(visible=True, show_label=True)
                        data_sample.style(
                            columns=3, rows=3, object_fit="contain", height="auto"
                        )

            with gr.Tab("Predict"):
                model_choice_pred = gr.Dropdown(
                    choices=[choice.value for choice in GradioPredictModelChoice],
                    label="Select model",
                    value=GradioPredictModelChoice.PRETRAINED.value,
                    interactive=True,
                )
                predict_sketchpad = gr.Sketchpad(
                    shape=(SKETCHPAD_SIZE, SKETCHPAD_SIZE),
                    brush_radius=2,
                )
                predict_btn = gr.Button("Classify", variant="primary")
                label = gr.Label(num_top_classes=5)

        # Handlers only record the wiring; no listener exists until
        # register_callback is called on them.
        self.save_img_btn_handler = CallbackHandler(
            self.app, save_btn, [data_sketchpad, digit_label], [data_count_table]
        )
        self.train_btn_handler = CallbackHandler(
            self.app, train_btn, [model_choice, data_choice, epochs, aug_choice], []
        )
        self.show_sample_btn_handler = CallbackHandler(
            self.app,
            show_sample_btn,
            inputs=[data_choice, aug_choice],
            outputs=[data_sample],
        )
        self.predict_btn_handler = CallbackHandler(
            self.app,
            predict_btn,
            [model_choice_pred, predict_sketchpad],
            [label],
        )
        self.new_session_btn_handler = CallbackHandler(
            self.app, new_session_btn, [], [self.session_id_text]
        )
        # Passing the app itself as the element selects the load event.
        self.load_app_handler = CallbackHandler(
            self.app,
            self.app,
            [self.session_id_text],
            [self.session_id_text, data_count_table],
        )

    def run_gui(self) -> None:
        """Launch the Gradio app with default settings."""
        self.app.launch()