"""Streamlit form that collects settings and starts a SageMaker training job."""
import logging
from datetime import datetime

import streamlit as st

from trainer import train_estimtator
from utils.load_dataset import load_datasets
from utils.load_models import load_models
from utils.load_tasks import load_tasks

logger = logging.getLogger(__name__)

# Values used for any query parameter that is missing from the URL.  Query
# parameters are lists, so every default is wrapped in a list as well.
_QUERY_DEFAULTS = {
    "model_name_or_path": ["none"],
    "dataset": ["none"],
    "task": ["none"],
    ### hyperparameter
    "epochs": [3],
    "learning_rate": [5e-5],
    "per_device_train_batch_size": [8],
    "per_device_eval_batch_size": [8],
}


def _load_query_parameters():
    """Read the URL query parameters, fill in defaults and write them back."""
    parameter = st.experimental_get_query_params()
    for key, default in _QUERY_DEFAULTS.items():
        # Copy the default so the module-level lists are never shared.
        parameter.setdefault(key, list(default))
    st.experimental_set_query_params(**parameter)
    return parameter


def _render_model_section(parameter):
    """Render the model/dataset/task selectors and return the entered HF token.

    The selected values replace the list-valued entries in ``parameter`` and
    are mirrored into the URL query string.  The token is returned instead of
    being stored in ``parameter`` here so that it is not written to the URL.
    """
    dataset_list = load_datasets()
    task_list = load_tasks()
    model_list = load_models()

    st.header("Hugging Face model & dataset")
    col1, col2 = st.beta_columns(2)
    # The value taken from the URL is placed first so it is the preselected
    # option of each selectbox.
    parameter["model_name_or_path"] = col1.selectbox(
        "Model ID:", parameter["model_name_or_path"] + model_list
    )
    parameter["dataset"] = col2.selectbox("Dataset:", parameter["dataset"] + dataset_list)
    parameter["task"] = col1.selectbox("Task:", parameter["task"] + task_list)
    st.experimental_set_query_params(**parameter)

    # Masked input: the token is a credential and should not be shown on screen.
    return col2.text_input(
        "HF auth token to upload your model:", type="password", help="api_xxxxx"
    )


def _render_hyperparameters(parameter):
    """Render the hyperparameter inputs and store their values in ``parameter``.

    NOTE: the widgets always start from the fixed defaults below; the
    hyperparameter values read from the URL are not used as initial values.
    """
    my_expander = st.beta_expander("Hyperparameters")
    col1, col2 = my_expander.beta_columns(2)
    # The second positional argument of ``number_input`` is ``min_value``, not
    # the initial value, so both are passed by keyword.  Passing the default
    # positionally made it impossible to pick a smaller number.
    parameter["epochs"] = col1.number_input("Epoch", min_value=1, value=3)
    # The learning rate is a free-text field, so its value is a string.
    parameter["learning_rate"] = col2.text_input("Learning Rate", 5e-5)
    parameter["per_device_train_batch_size"] = col1.number_input(
        "Training Batch Size", min_value=1, value=8
    )
    parameter["per_device_eval_batch_size"] = col2.number_input(
        "Eval Batch Size", min_value=1, value=8
    )
    st.experimental_set_query_params(**parameter)


def _render_sagemaker_section(parameter):
    """Render the SageMaker job settings and return them as a new dict."""
    st.markdown("---")
    st.header("Amazon Sagemaker configuration")

    model_name = parameter["model_name_or_path"]
    if isinstance(model_name, list):
        model_name = model_name[0]
    today = datetime.today().date().isoformat()

    config = {}
    config["job_name"] = st.text_input("model name", f"{model_name}-job-{today}")

    col1, col2 = st.beta_columns(2)
    config["aws_sagemaker_role"] = col1.text_input("AWS IAM role for sagemaker job")
    config["instance_type"] = col2.selectbox(
        "Instance type",
        [
            "single-gpu | ml.p3.2xlarge",
            "multi-gpu | ml.p3.16xlarge",
        ],
    )
    config["region"] = col1.selectbox(
        "AWS Region",
        # "us-east-1" used to be listed twice; the second entry is now
        # "us-east-2", following the pattern of the rest of the list.
        ["eu-central-1", "eu-west-1", "us-east-1", "us-east-2", "us-west-1", "us-west-2"],
    )
    config["instance_count"] = col2.number_input("Instance count", min_value=1, value=1)
    config["use_spot"] = col1.selectbox("use spot instances", [False, True])
    return config


def _render_credentials_section(config):
    """Render the AWS credential inputs and store their values in ``config``."""
    st.markdown("---")
    st.header("Credentials")
    col1, col2 = st.beta_columns(2)
    config["aws_access_key_id"] = col1.text_input("Aws Access Key ID")
    # Masked input so the secret key is not displayed in clear text.
    config["aws_secret_accesskey"] = col2.text_input(
        "Aws Secret Access Key", type="password"
    )


def main():
    """Render the whole form and start the training job on button press."""
    parameter = _load_query_parameters()
    use_auth_token = _render_model_section(parameter)
    _render_hyperparameters(parameter)

    config = _render_sagemaker_section(parameter)
    _render_credentials_section(config)

    # Added only after the last query-parameter update, so the token is passed
    # to the trainer without appearing in the URL.
    if use_auth_token:
        parameter["use_auth_token"] = use_auth_token

    if st.button("Start training on SageMaker"):
        train_estimtator(parameter, config)


if __name__ == "__main__":
    main()