Spaces:
Sleeping
Sleeping
File size: 2,667 Bytes
42568a5 4fea8d0 c454738 33468ea 4fea8d0 94f05f1 4fea8d0 62c0b96 f981c43 4fea8d0 42568a5 4fea8d0 62c0b96 f981c43 62c0b96 2888a96 166477b 2888a96 94f05f1 2888a96 94f05f1 166477b 62c0b96 ced70f2 166477b 62c0b96 26e5209 ced70f2 4fea8d0 44db048 4fea8d0 44db048 4fea8d0 44db048 4fea8d0 44db048 4fea8d0 33468ea 4fea8d0 33468ea 85f4ee8 33468ea 4fea8d0 fe78e91 ced70f2 fe78e91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from typing import List
from PIL import Image
from constants import (
GradioPredictModelChoice,
GradioTrainDataChoice,
GradioTrainModelChoice,
)
from data_manager import PregeneratedFinetuneData, UserFinetuneData, UserPredictData
from model import PredictionModel
from session import Session
from view import View
# Module-global session used by every callback below.
# It is only annotated here, not assigned: load_app_cb or new_session_cb must
# rebind it (via `global SESSION`) before any other callback runs, otherwise
# reading SESSION raises NameError.
# NOTE(review): because this is a single module-level object, every client of
# the app appears to share it, and new_session_cb replaces it for all of them
# — confirm this is intended if the UI is served to several users at once.
SESSION: Session
# Define callbacks
def save_image_cb(data_sketchpad: Image.Image, digit_label: str) -> List[List[int]]:
    """Store a user-drawn sample in the fine-tune data and report its statistics.

    Args:
        data_sketchpad: The image drawn by the user.
        digit_label: The label the user assigned to the drawing.

    Returns:
        A one-row table: a list holding a single list of the values of
        ``SESSION.finetune_data.statistics``, in that dict's iteration order.
    """
    finetune_data = SESSION.finetune_data
    finetune_data.save_data(data_sketchpad, digit_label)
    # TODO: move statistics to base class so that the type checker recognizes it
    stats_row = list(finetune_data.statistics.values())
    return [stats_row]
def init_new_session() -> Session:
    """Build a new Session named "my_session" from freshly constructed parts.

    The data containers and the prediction model are instantiated here, in
    the order finetune data, predict data, pregenerated data, model, and
    passed to ``Session`` by keyword.
    """
    components = dict(
        finetune_data=UserFinetuneData(),
        predict_data=UserPredictData(),
        pregen_data=PregeneratedFinetuneData(),
        model=PredictionModel(),
    )
    return Session(name="my_session", **components)
def new_session_cb() -> str:
    """Replace the global SESSION with a brand-new one.

    Returns:
        The new session's ``id`` converted to ``str``.
    """
    global SESSION
    fresh_session = init_new_session()
    SESSION = fresh_session
    return str(fresh_session.id)
def load_app_cb(local_storage_id: str):
    """Initialise the global SESSION when the app loads.

    A new session is always created first. If a previously stored id is
    supplied, that session's state is restored into it through
    ``Session.load_local_storage``.

    Args:
        local_storage_id: Session id read from local storage. ``None`` or an
            empty string means nothing has been stored yet.

    Returns:
        ``[session_id, data_table]``. ``session_id`` is the id to keep in
        local storage: the stored one, or the new session's id when nothing
        was stored. ``data_table`` is a one-row table of
        ``SESSION.finetune_data.statistics`` values. Both paths return this
        same two-item shape, so the outputs bound to this callback always
        receive the same number of values.
    """
    global SESSION
    SESSION = init_new_session()
    if not local_storage_id:
        # Nothing stored yet: keep the fresh session and hand its id back
        # so it can be persisted.
        session_id = str(SESSION.id)
    else:
        SESSION.load_local_storage(local_storage_id)
        session_id = local_storage_id
    data_table_list = [list(SESSION.finetune_data.statistics.values())]
    return [session_id, data_table_list]
def show_batch_gallery_cb(data_choice, aug_choice):
    """Return a random sample batch to display in the gallery.

    ``data_choice`` is converted to ``GradioTrainDataChoice``. ``aug_choice``
    is forwarded unchanged, and the result of
    ``SESSION.model.get_random_batch`` is returned as-is.
    """
    dataset = GradioTrainDataChoice(data_choice)
    model = SESSION.model
    return model.get_random_batch(dataset, aug_choice)
def train_cb(model_choice, data_choice, epocs, aug_choice) -> None:
    """Fine-tune the session model with the options selected in the UI.

    Args:
        model_choice: Raw value, converted to ``GradioTrainModelChoice``.
        data_choice: Raw value, converted to ``GradioTrainDataChoice``.
        epocs: Forwarded unchanged to ``SESSION.model.train``. It is
            presumably the epoch count; the spelling is kept because it is
            part of the callback's signature.
        aug_choice: Forwarded unchanged to ``SESSION.model.train``.
    """
    selected_model = GradioTrainModelChoice(model_choice)
    selected_data = GradioTrainDataChoice(data_choice)
    SESSION.model.train(selected_model, selected_data, epocs, aug_choice)
def predict_cb(model_choice, img):
    """Run a prediction on ``img`` and record it with the top-scoring label.

    Args:
        model_choice: Raw value, converted to ``GradioPredictModelChoice``.
        img: The image to classify, passed straight to the model.

    Returns:
        The mapping returned by ``SESSION.model.predict``. Its key with the
        largest value (the first such key on ties) is saved alongside
        ``img`` via ``SESSION.predict_data.save_data``.
    """
    selected_model = GradioPredictModelChoice(model_choice)
    scores = SESSION.model.predict(img, selected_model)
    best_label = max(scores, key=scores.get)  # type: ignore
    SESSION.predict_data.save_data(img, best_label)
    return scores
def app_launch():
    """Create the View, attach each callback to its UI handler, then run the GUI."""
    view = View()
    # (View attribute holding the handler, callback to register on it),
    # registered in this order.
    handler_bindings = (
        ("load_app_handler", load_app_cb),
        ("save_img_btn_handler", save_image_cb),
        ("new_session_btn_handler", new_session_cb),
        ("show_sample_btn_handler", show_batch_gallery_cb),
        ("train_btn_handler", train_cb),
        ("predict_btn_handler", predict_cb),
    )
    for handler_name, callback in handler_bindings:
        getattr(view, handler_name).register_callback(callback)
    view.run_gui()
|