MNIST-ResNet-Demo / src /controller.py
shpr
Add user finetune data table population on app load
ced70f2
from typing import List, Optional

from PIL import Image

from constants import (
    GradioPredictModelChoice,
    GradioTrainDataChoice,
    GradioTrainModelChoice,
)
from data_manager import PregeneratedFinetuneData, UserFinetuneData, UserPredictData
from model import PredictionModel
from session import Session
from view import View
# Module-level session read by every callback below.
# This line is an annotation only: it does not bind the name. SESSION is first
# assigned by load_app_cb or new_session_cb (both use ``global SESSION``), so a
# callback that reads it before either has run raises NameError.
# NOTE(review): load_app_cb replaces this object on every app load; if one
# server process serves several browser clients they would share (and reset)
# the same session — confirm that is intended.
SESSION: Session
# Define callbacks
def save_image_cb(data_sketchpad: Image.Image, digit_label: str) -> List[List[int]]:
    """Store a labelled sketch in the user fine-tune dataset.

    Args:
        data_sketchpad: Image drawn by the user.
        digit_label: Label the user selected for the drawing.

    Returns:
        A single-row table: one list holding the values of
        ``SESSION.finetune_data.statistics`` in the dict's iteration order.
    """
    SESSION.finetune_data.save_data(data_sketchpad, digit_label)
    # TODO: move `statistics` to the fine-tune data base class so the
    # attribute's declared type exposes it to type checkers.
    stats = SESSION.finetune_data.statistics
    row = list(stats.values())
    return [row]
def init_new_session() -> Session:
    """Build a new ``Session`` named ``"my_session"``.

    Every call constructs fresh ``UserFinetuneData``, ``UserPredictData``,
    ``PregeneratedFinetuneData`` and ``PredictionModel`` instances (in that
    order) and hands them to the ``Session`` constructor.
    """
    finetune = UserFinetuneData()
    predict = UserPredictData()
    pregenerated = PregeneratedFinetuneData()
    prediction_model = PredictionModel()
    return Session(
        name="my_session",
        finetune_data=finetune,
        predict_data=predict,
        pregen_data=pregenerated,
        model=prediction_model,
    )
def new_session_cb() -> str:
    """Replace the module-level ``SESSION`` with a newly built session.

    Returns:
        The new session's ``id`` converted to ``str``.
    """
    global SESSION
    fresh_session = init_new_session()
    SESSION = fresh_session
    return str(fresh_session.id)
def load_app_cb(local_storage_id: Optional[str]) -> List[object]:
    """Initialise the module-level session when the app loads.

    A new session is always created. When the client supplies a previously
    stored id, the session state is restored through
    ``Session.load_local_storage``.

    Args:
        local_storage_id: Session id previously stored on the client, or
            ``None`` when nothing was stored.

    Returns:
        ``[session_id, data_table]``: the id the client should keep, and a
        single-row table holding the values of
        ``SESSION.finetune_data.statistics``. Both branches return this same
        two-element shape. Previously the ``None`` branch returned a bare
        string, which does not match the two values returned on the restore
        path. Presumably the view wires two output components to this
        callback — confirm against ``View``.
    """
    global SESSION
    SESSION = init_new_session()
    if local_storage_id is None:
        # Nothing to restore: the client should adopt the new session's id.
        session_id = str(SESSION.id)
    else:
        SESSION.load_local_storage(local_storage_id)
        session_id = local_storage_id
    data_table_list = [list(SESSION.finetune_data.statistics.values())]
    return [session_id, data_table_list]
def show_batch_gallery_cb(data_choice, aug_choice):
    """Fetch a random sample batch for the gallery widget.

    Args:
        data_choice: Raw UI value, converted to ``GradioTrainDataChoice``.
        aug_choice: Augmentation option, passed through unchanged.

    Returns:
        Whatever ``SESSION.model.get_random_batch`` returns.
    """
    # Convert first so an invalid choice fails before SESSION is touched.
    dataset = GradioTrainDataChoice(data_choice)
    model = SESSION.model
    return model.get_random_batch(dataset, aug_choice)
def train_cb(model_choice, data_choice, epocs, aug_choice) -> None:
    """Fine-tune the session model using the options chosen in the UI.

    Args:
        model_choice: Raw UI value, converted to ``GradioTrainModelChoice``.
        data_choice: Raw UI value, converted to ``GradioTrainDataChoice``.
        epocs: Passed through to ``SESSION.model.train`` unchanged
            (presumably the epoch count; name kept for caller compatibility).
        aug_choice: Augmentation option, passed through unchanged.
    """
    selected_model = GradioTrainModelChoice(model_choice)
    selected_data = GradioTrainDataChoice(data_choice)
    SESSION.model.train(selected_model, selected_data, epocs, aug_choice)
def predict_cb(model_choice, img):
    """Classify ``img`` and record it in the user prediction dataset.

    Args:
        model_choice: Raw UI value, converted to ``GradioPredictModelChoice``.
        img: Image passed to ``SESSION.model.predict``.

    Returns:
        The mapping returned by ``SESSION.model.predict``. The key with the
        largest value (first one on ties) is saved alongside ``img``.
    """
    chosen_model = GradioPredictModelChoice(model_choice)
    scores = SESSION.model.predict(img, chosen_model)
    # Arg-max over the mapping's keys becomes the stored label.
    SESSION.predict_data.save_data(
        img,
        max(scores, key=scores.get),  # type: ignore
    )
    return scores
def app_launch():
    """Create the view, attach every UI callback, then start the GUI."""
    view = View()
    # (handler attribute on View, callback) in registration order.
    wiring = (
        ("load_app_handler", load_app_cb),
        ("save_img_btn_handler", save_image_cb),
        ("new_session_btn_handler", new_session_cb),
        ("show_sample_btn_handler", show_batch_gallery_cb),
        ("train_btn_handler", train_cb),
        ("predict_btn_handler", predict_cb),
    )
    for handler_name, callback in wiring:
        getattr(view, handler_name).register_callback(callback)
    view.run_gui()