File size: 3,742 Bytes
3f3337a
25f0c96
 
 
 
 
 
3f3337a
25f0c96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bfe93a
25f0c96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import streamlit as st
from utils.load_dataset import load_datasets
from utils.load_tasks import load_tasks
from utils.load_models import load_models
from trainer import train_estimtator
from datetime import datetime
import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def main():
    """Render the training form and launch a SageMaker job on demand.

    The page has three sections:

    1. Hugging Face model / dataset / task selection plus hyperparameters.
       These values are collected in ``parameter`` and mirrored into the
       URL query string.
    2. Amazon SageMaker job configuration (job name, IAM role, instance
       type/count, region, spot and distributed flags), collected in
       ``config``.
    3. AWS credentials, also stored in ``config``.

    When the "Start training on SageMaker" button is pressed,
    ``train_estimtator(parameter, config)`` is called with the collected
    values.
    """
    # Query parameters are handled as lists of values here, so every default
    # is wrapped in a list as well.
    parameter = st.experimental_get_query_params()
    query_defaults = {
        "model_name_or_path": ["none"],
        "dataset": ["none"],
        "task": ["none"],
        # hyperparameters
        "epochs": [3],
        "learning_rate": [5e-5],
        "per_device_train_batch_size": [8],
        "per_device_eval_batch_size": [8],
    }
    for key, default in query_defaults.items():
        parameter.setdefault(key, default)
    # Publish the defaults to the URL before the option lists are loaded.
    st.experimental_set_query_params(**parameter)

    dataset_list = load_datasets()
    task_list = load_tasks()
    model_list = load_models()

    st.header("Hugging Face model & dataset")
    col1, col2 = st.beta_columns(2)
    # The value taken from the query string is placed first in each option
    # list, which makes it the pre-selected entry of the select box.
    parameter["model_name_or_path"] = col1.selectbox(
        "Model ID:", parameter["model_name_or_path"] + model_list
    )
    parameter["dataset"] = col2.selectbox("Dataset:", parameter["dataset"] + dataset_list)
    parameter["task"] = col1.selectbox("Task:", parameter["task"] + task_list)

    # Masked input: the token is a credential and should not be shown on screen.
    use_auth_token = col2.text_input(
        "HF auth token to upload your model:",
        type="password",
        help="api_xxxxx",
    )

    my_expander = st.beta_expander("Hyperparameters")
    col1, col2 = my_expander.beta_columns(2)
    # The second positional argument of ``number_input`` is ``min_value``, not
    # the default value, so the default must be passed as ``value=``. Passing
    # it positionally would forbid e.g. fewer than 3 epochs.
    parameter["epochs"] = col1.number_input("Epoch", min_value=1, value=3)
    # Learning rate is entered as free text and is kept as the string the
    # widget returns.
    parameter["learning_rate"] = col2.text_input("Learning Rate", 5e-5)
    parameter["per_device_train_batch_size"] = col1.number_input(
        "Training Batch Size", min_value=1, value=8
    )
    parameter["per_device_eval_batch_size"] = col2.number_input(
        "Eval Batch Size", min_value=1, value=8
    )
    # Mirror the current selections into the URL. This happens before the auth
    # token is added to ``parameter`` so the token never ends up in the URL.
    st.experimental_set_query_params(**parameter)
    st.markdown("---")

    st.header("Amazon Sagemaker configuration")

    config = {}

    # Default job name: "<model>-job-<YYYY-MM-DD>".
    model_choice = parameter["model_name_or_path"]
    if isinstance(model_choice, list):
        model_choice = model_choice[0]
    today = datetime.today().date().isoformat()
    config["job_name"] = st.text_input("model name", f"{model_choice}-job-{today}")
    col1, col2 = st.beta_columns(2)

    config["aws_sagemaker_role"] = col1.text_input("AWS IAM role for sagemaker job")
    config["instance_type"] = col2.selectbox(
        "Instance type",
        [
            "single-gpu | ml.p3.2xlarge",
            "multi-gpu | ml.p3.16xlarge",
        ],
    )
    config["region"] = col1.selectbox(
        "AWS Region",
        ["eu-central-1", "eu-west-1", "us-east-1", "us-east-2", "us-west-1", "us-west-2"],
    )
    config["instance_count"] = col2.number_input("Instance count", min_value=1, value=1)
    config["use_spot"] = col1.selectbox("use spot instances", [False, True])
    config["distributed"] = col2.selectbox("distributed training", [False, True])
    st.markdown("---")

    st.header("Credentials")
    # sagemaker config
    col1, col2 = st.beta_columns(2)
    config["aws_access_key_id"] = col1.text_input("Aws Access Key ID")
    # Masked input for the secret half of the key pair.
    config["aws_secret_accesskey"] = col2.text_input("Aws Secret Access Key", type="password")

    # Only forward the token when the user actually entered one.
    if use_auth_token:
        parameter["use_auth_token"] = use_auth_token

    if st.button("Start training on SageMaker"):
        train_estimtator(parameter, config)


# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()