Spaces:
Sleeping
Sleeping
from typing import List, Optional

from PIL import Image

from constants import (
    GradioPredictModelChoice,
    GradioTrainDataChoice,
    GradioTrainModelChoice,
)
from data_manager import PregeneratedFinetuneData, UserFinetuneData, UserPredictData
from model import PredictionModel
from session import Session
from view import View
# Module-level handle to the active session. It is only declared here (no
# value); it becomes bound when `load_app_cb` or `new_session_cb` runs. The
# other callbacks read it without checking, so they assume one of those two
# has already fired.
SESSION: Session

# Define callbacks
def save_image_cb(data_sketchpad: Image.Image, digit_label: str) -> List[List[int]]:
    """Persist a sketchpad drawing as finetuning data and return updated stats.

    Args:
        data_sketchpad: Image drawn by the user on the sketchpad.
        digit_label: Label the user assigned to the drawing.

    Returns:
        A one-row table: a list wrapping the values of the finetune data's
        ``statistics`` mapping, in that mapping's iteration order.
    """
    SESSION.finetune_data.save_data(data_sketchpad, digit_label)
    # TODO: move `statistics` to the base data class so the type hint on the
    # return value is recognized by type checkers.
    stats = SESSION.finetune_data.statistics
    return [[*stats.values()]]
def init_new_session() -> Session:
    """Construct a new ``Session`` named "my_session".

    Each component (user finetune data, user predict data, pregenerated
    finetune data, prediction model) is freshly instantiated for the session.
    """
    finetune = UserFinetuneData()
    predict = UserPredictData()
    pregenerated = PregeneratedFinetuneData()
    prediction_model = PredictionModel()
    return Session(
        name="my_session",
        finetune_data=finetune,
        predict_data=predict,
        pregen_data=pregenerated,
        model=prediction_model,
    )
def new_session_cb() -> str:
    """Discard the active session and install a brand-new one.

    Returns:
        The new session's ``id`` converted to ``str``.
    """
    global SESSION
    fresh = init_new_session()
    SESSION = fresh
    return str(fresh.id)
def load_app_cb(local_storage_id: Optional[str]) -> List[object]:
    """Start a new session on app load, restoring stored state when available.

    Args:
        local_storage_id: Session id read from local storage, or ``None`` /
            empty string when nothing has been stored yet.

    Returns:
        ``[session_id, data_table]`` in every case. ``session_id`` is the
        stored id when one was given, otherwise the new session's id as a
        ``str``. ``data_table`` is a one-row table of the finetune data's
        ``statistics`` values. Both paths return the same shape so the load
        event always produces the same number of outputs.
    """
    global SESSION
    SESSION = init_new_session()
    if local_storage_id:
        # Restore previously stored state into the freshly created session.
        SESSION.load_local_storage(local_storage_id)
        session_id = local_storage_id
    else:
        # Nothing stored (None or empty): keep the newly generated id rather
        # than passing an empty id to `load_local_storage`.
        session_id = str(SESSION.id)
    # NOTE(review): assumes a freshly constructed UserFinetuneData exposes
    # `statistics` before any image is saved — confirm in data_manager.
    data_table_list = [list(SESSION.finetune_data.statistics.values())]
    return [session_id, data_table_list]
def show_batch_gallery_cb(data_choice, aug_choice):
    """Fetch a random batch to display in the gallery.

    Args:
        data_choice: Raw dataset selection; converted to ``GradioTrainDataChoice``.
        aug_choice: Augmentation selection; forwarded unchanged.

    Returns:
        Whatever ``SESSION.model.get_random_batch`` returns.
    """
    dataset = GradioTrainDataChoice(data_choice)
    model = SESSION.model
    return model.get_random_batch(dataset, aug_choice)
def train_cb(model_choice, data_choice, epocs, aug_choice) -> None:
    """Finetune the session's model with the options selected in the UI.

    Args:
        model_choice: Raw model selection; converted to ``GradioTrainModelChoice``.
        data_choice: Raw dataset selection; converted to ``GradioTrainDataChoice``.
        epocs: Forwarded unchanged to ``train`` (presumably the number of
            epochs; the parameter name is kept as-is for callers).
        aug_choice: Augmentation selection; forwarded unchanged.
    """
    selected_model = GradioTrainModelChoice(model_choice)
    selected_data = GradioTrainDataChoice(data_choice)
    SESSION.model.train(selected_model, selected_data, epocs, aug_choice)
def predict_cb(model_choice, img):
    """Classify ``img`` and record it together with its highest-scoring label.

    Args:
        model_choice: Raw model selection; converted to ``GradioPredictModelChoice``.
        img: Image to classify.

    Returns:
        The mapping returned by ``SESSION.model.predict``. The key with the
        largest value (first one wins on ties) is saved alongside ``img``.
    """
    chosen = GradioPredictModelChoice(model_choice)
    scores = SESSION.model.predict(img, chosen)
    best_label, _ = max(scores.items(), key=lambda item: item[1])  # type: ignore
    SESSION.predict_data.save_data(img, best_label)
    return scores
def app_launch():
    """Build the view, attach each UI handler to its callback, and run the GUI."""
    view = View()
    # (handler attribute on the view, callback to attach), in registration order.
    wiring = (
        ("load_app_handler", load_app_cb),
        ("save_img_btn_handler", save_image_cb),
        ("new_session_btn_handler", new_session_cb),
        ("show_sample_btn_handler", show_batch_gallery_cb),
        ("train_btn_handler", train_cb),
        ("predict_btn_handler", predict_cb),
    )
    for handler_name, callback in wiring:
        getattr(view, handler_name).register_callback(callback)
    view.run_gui()