from typing import List

from PIL import Image

from constants import (
    GradioPredictModelChoice,
    GradioTrainDataChoice,
    GradioTrainModelChoice,
)
from data_manager import PregeneratedFinetuneData, UserFinetuneData, UserPredictData
from model import PredictionModel
from session import Session
from view import View

# Module-wide active session. Only declared here; it is bound by
# ``new_session_cb`` / ``load_app_cb`` before the other callbacks use it.
SESSION: Session


def _finetune_stats_table() -> List[List[int]]:
    """Return the finetune-data statistics values wrapped as a one-row table."""
    # TODO: move statistics to base class so that type hint below recognizes it
    stats_row = list(SESSION.finetune_data.statistics.values())
    return [stats_row]


# Define callbacks
def save_image_cb(data_sketchpad: Image.Image, digit_label: str) -> List[List[int]]:
    """Save a drawn image with its label into the session's finetune data.

    Args:
        data_sketchpad: Image handed over by the sketchpad component.
        digit_label: Label to store alongside the image.

    Returns:
        A single-row table holding the finetune-data statistics values
        as they are after the save.
    """
    SESSION.finetune_data.save_data(data_sketchpad, digit_label)
    return _finetune_stats_table()


def init_new_session() -> Session:
    """Build a ``Session`` wired with freshly constructed data stores and model."""
    # Components are constructed in the same order the Session receives them.
    user_finetune = UserFinetuneData()
    user_predict = UserPredictData()
    pregenerated = PregeneratedFinetuneData()
    prediction_model = PredictionModel()
    return Session(
        name="my_session",
        finetune_data=user_finetune,
        predict_data=user_predict,
        pregen_data=pregenerated,
        model=prediction_model,
    )


def new_session_cb() -> str:
    """Replace the global session with a new one and return its id as a string."""
    global SESSION
    fresh_session = init_new_session()
    SESSION = fresh_session
    return str(fresh_session.id)


def load_app_cb(local_storage_id: str):
    """Initialise the global session when the app loads.

    A new session is always created first. When ``local_storage_id`` is
    ``None`` the id of that new session is returned as a string. Otherwise
    the session loads the local storage identified by ``local_storage_id``
    and a two-element list ``[local_storage_id, statistics_table]`` is
    returned.

    NOTE(review): the two return paths have different shapes (a bare ``str``
    vs. a two-element list) -- confirm this matches the outputs that ``View``
    attaches to the load handler.
    """
    global SESSION
    SESSION = init_new_session()

    # Nothing stored on the client side: hand back the new session's id.
    if local_storage_id is None:
        return str(SESSION.id)

    SESSION.load_local_storage(local_storage_id)
    data_table_list = _finetune_stats_table()
    return [local_storage_id, data_table_list]


def show_batch_gallery_cb(data_choice, aug_choice):
    """Return a random batch from the model for the chosen training data.

    ``data_choice`` is converted to ``GradioTrainDataChoice`` before being
    passed on; ``aug_choice`` is forwarded unchanged.
    """
    return SESSION.model.get_random_batch(
        GradioTrainDataChoice(data_choice),
        aug_choice,
    )


def train_cb(model_choice, data_choice, epocs, aug_choice) -> None:
    """Finetune the session model with the parameters chosen in the UI.

    ``model_choice`` and ``data_choice`` are converted to their enum types;
    ``epocs`` and ``aug_choice`` are forwarded unchanged.
    """
    selected_model = GradioTrainModelChoice(model_choice)
    selected_data = GradioTrainDataChoice(data_choice)
    SESSION.model.train(selected_model, selected_data, epocs, aug_choice)


def predict_cb(model_choice, img):
    """Predict on ``img`` with the chosen model and store the image.

    The image is saved to the session's predict data under the key of
    ``pred`` that has the highest value. The prediction mapping itself is
    returned.
    """
    selected_model = GradioPredictModelChoice(model_choice)
    pred = SESSION.model.predict(img, selected_model)

    # Key with the largest associated value in the prediction mapping.
    top_label = max(pred, key=pred.get)  # type: ignore
    SESSION.predict_data.save_data(img, top_label)
    return pred


def app_launch():
    """Create the view, register every callback on its handler, and run the GUI."""
    view = View()

    # Session lifecycle.
    view.load_app_handler.register_callback(load_app_cb)
    view.save_img_btn_handler.register_callback(save_image_cb)
    view.new_session_btn_handler.register_callback(new_session_cb)

    # Data preview, training and prediction.
    view.show_sample_btn_handler.register_callback(show_batch_gallery_cb)
    view.train_btn_handler.register_callback(train_cb)
    view.predict_btn_handler.register_callback(predict_cb)

    view.run_gui()