File size: 2,667 Bytes
42568a5
4fea8d0
c454738
 
33468ea
 
 
 
 
4fea8d0
94f05f1
4fea8d0
 
 
62c0b96
f981c43
4fea8d0
 
42568a5
4fea8d0
62c0b96
f981c43
62c0b96
2888a96
 
166477b
 
2888a96
 
 
 
94f05f1
2888a96
94f05f1
 
 
 
166477b
62c0b96
 
 
ced70f2
166477b
 
 
62c0b96
26e5209
ced70f2
 
4fea8d0
 
44db048
4fea8d0
44db048
4fea8d0
44db048
 
4fea8d0
 
44db048
 
 
 
 
4fea8d0
 
33468ea
4fea8d0
33468ea
 
85f4ee8
33468ea
4fea8d0
 
fe78e91
 
ced70f2
fe78e91
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from typing import List, Optional

from PIL import Image

from constants import (
    GradioPredictModelChoice,
    GradioTrainDataChoice,
    GradioTrainModelChoice,
)
from data_manager import PregeneratedFinetuneData, UserFinetuneData, UserPredictData
from model import PredictionModel
from session import Session
from view import View

# Process-wide session object shared by every callback in this module.
# This line only declares the type; no value is bound until new_session_cb()
# or load_app_cb() assigns it, so any earlier access raises NameError.
# NOTE(review): as a module global, one SESSION is shared by all concurrent
# GUI users of this process — confirm that is intended.
SESSION: Session


# Define callbacks
def save_image_cb(data_sketchpad: Image.Image, digit_label: str) -> List[List[int]]:
    """Store a user-drawn sample in the finetune dataset and report statistics.

    Args:
        data_sketchpad: The image drawn by the user.
        digit_label: Label to associate with the image.

    Returns:
        A one-row table: the values of the finetune data's ``statistics`` mapping.
    """
    finetune_data = SESSION.finetune_data
    finetune_data.save_data(data_sketchpad, digit_label)
    # TODO: move statistics to base class so that type hint below recognizes it
    stats_row = list(finetune_data.statistics.values())
    return [stats_row]


def init_new_session() -> Session:
    """Build a brand-new Session with fresh data managers and a fresh model."""
    session_parts = {
        "finetune_data": UserFinetuneData(),
        "predict_data": UserPredictData(),
        "pregen_data": PregeneratedFinetuneData(),
        "model": PredictionModel(),
    }
    return Session(name="my_session", **session_parts)


def new_session_cb() -> str:
    """Replace the global session with a new one and return its id as a string."""
    global SESSION
    fresh_session = init_new_session()
    SESSION = fresh_session
    return str(fresh_session.id)


def load_app_cb(local_storage_id: Optional[str]) -> List[object]:
    """Create the global session on app load, restoring stored state if any.

    Args:
        local_storage_id: Previously stored session id, or ``None`` / empty
            string when there is nothing to restore.

    Returns:
        ``[session_id, data_table_list]``: the active session id as a string
        and a one-row table of the finetune data's ``statistics`` values.
        Both branches return this same two-element shape so the registered
        load handler always receives one value per output (previously the
        no-id branch returned a bare string).
    """
    global SESSION
    SESSION = init_new_session()
    if local_storage_id:
        # Restore previously saved state into the freshly created session.
        SESSION.load_local_storage(local_storage_id)
        session_id = local_storage_id
    else:
        # Nothing stored (None or ""): keep the new session and report its id.
        session_id = str(SESSION.id)
    data_table_list = [list(SESSION.finetune_data.statistics.values())]
    return [session_id, data_table_list]


def show_batch_gallery_cb(data_choice, aug_choice):
    """Return a random batch from the chosen training data for the gallery.

    Args:
        data_choice: Raw GUI value, converted to ``GradioTrainDataChoice``.
        aug_choice: Augmentation option passed through to the model unchanged.
    """
    chosen_data = GradioTrainDataChoice(data_choice)
    batch = SESSION.model.get_random_batch(chosen_data, aug_choice)
    return batch


def train_cb(model_choice, data_choice, epocs, aug_choice) -> None:
    """Finetune the session model using the options selected in the GUI.

    ``model_choice`` and ``data_choice`` are converted to their enum types.
    ``epocs`` and ``aug_choice`` are forwarded to ``model.train`` as-is.
    """
    chosen_model = GradioTrainModelChoice(model_choice)
    chosen_data = GradioTrainDataChoice(data_choice)
    trainer = SESSION.model
    trainer.train(chosen_model, chosen_data, epocs, aug_choice)


def predict_cb(model_choice, img):
    """Run a prediction, record the image with its top label, and return the scores.

    Returns:
        The mapping produced by ``SESSION.model.predict``. The key with the
        largest value (first one wins on ties) is saved as the image's label.
    """
    chosen_model = GradioPredictModelChoice(model_choice)
    scores = SESSION.model.predict(img, chosen_model)
    top_label, _ = max(scores.items(), key=lambda item: item[1])
    SESSION.predict_data.save_data(img, top_label)
    return scores


def app_launch():
    """Build the View, attach every callback to its handler, and start the GUI."""
    view = View()
    # (handler attribute on View, callback) pairs, registered in this order.
    callback_table = (
        ("load_app_handler", load_app_cb),
        ("save_img_btn_handler", save_image_cb),
        ("new_session_btn_handler", new_session_cb),
        ("show_sample_btn_handler", show_batch_gallery_cb),
        ("train_btn_handler", train_cb),
        ("predict_btn_handler", predict_cb),
    )
    for handler_attr, callback in callback_table:
        getattr(view, handler_attr).register_callback(callback)
    view.run_gui()