philschmid (HF staff) committed
Commit 25f0c96
1 Parent(s): 3f3337a

online trainer
.gitignore ADDED
@@ -0,0 +1,138 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+ .vscode
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
app.py CHANGED
@@ -1,7 +1,96 @@
  import streamlit as st
- # To make things easier later, we're also importing numpy and pandas for
- # working with sample data.
- import numpy as np
- import pandas as pd

- st.title('My first app')
  import streamlit as st
+ from utils.load_dataset import load_datasets
+ from utils.load_tasks import load_tasks
+ from utils.load_models import load_models
+ from trainer import train_estimator
+ from datetime import datetime
+ import logging

+ logger = logging.getLogger(__name__)
+
+
+ def main():
+     parameter = st.experimental_get_query_params()
+     parameter["model_name_or_path"] = parameter.get("model_name_or_path", ["none"])
+     parameter["dataset"] = parameter.get("dataset", ["none"])
+     parameter["task"] = parameter.get("task", ["none"])
+     ### hyperparameters
+     parameter["epochs"] = parameter.get("epochs", [3])
+     parameter["learning_rate"] = parameter.get("learning_rate", [5e-5])
+     parameter["per_device_train_batch_size"] = parameter.get("per_device_train_batch_size", [8])
+     parameter["per_device_eval_batch_size"] = parameter.get("per_device_eval_batch_size", [8])
+     st.experimental_set_query_params(**parameter)
+
+     dataset_list = load_datasets()
+     task_list = load_tasks()
+     model_list = load_models()
+
+     st.header("Hugging Face model & dataset")
+     col1, col2 = st.beta_columns(2)
+     parameter["model_name_or_path"] = col1.selectbox("Model ID:", parameter["model_name_or_path"] + model_list)
+     st.experimental_set_query_params(**parameter)
+
+     parameter["dataset"] = col2.selectbox("Dataset:", parameter["dataset"] + dataset_list)
+     st.experimental_set_query_params(**parameter)
+
+     parameter["task"] = col1.selectbox("Task:", parameter["task"] + task_list)
+     st.experimental_set_query_params(**parameter)
+
+     use_auth_token = col2.text_input("HF auth token to upload your model:", help="api_xxxxx")
+
+     my_expander = st.beta_expander("Hyperparameters")
+     col1, col2 = my_expander.beta_columns(2)
+     parameter["epochs"] = col1.number_input("Epochs", 3)
+     st.experimental_set_query_params(**parameter)
+
+     parameter["learning_rate"] = col2.text_input("Learning Rate", 5e-5)
+     st.experimental_set_query_params(**parameter)
+
+     parameter["per_device_train_batch_size"] = col1.number_input("Training Batch Size", 8)
+     st.experimental_set_query_params(**parameter)
+
+     parameter["per_device_eval_batch_size"] = col2.number_input("Eval Batch Size", 8)
+     st.experimental_set_query_params(**parameter)
+     st.markdown("---")
+
+     st.header("Amazon SageMaker configuration")
+
+     config = {}
+
+     config["job_name"] = st.text_input(
+         "model name",
+         f"{parameter['model_name_or_path'][0] if isinstance(parameter['model_name_or_path'], list) else parameter['model_name_or_path']}-job-{str(datetime.today()).split()[0]}",
+     )
+     col1, col2 = st.beta_columns(2)
+
+     config["aws_sagemaker_role"] = col1.text_input("AWS IAM role for SageMaker job")
+     config["instance_type"] = col2.selectbox(
+         "Instance type",
+         [
+             "single-gpu | ml.p3.2xlarge",
+             "multi-gpu | ml.p3.16xlarge",
+         ],
+     )
+     config["region"] = col1.selectbox(
+         "AWS Region",
+         ["eu-central-1", "eu-west-1", "us-east-1", "us-east-2", "us-west-1", "us-west-2"],
+     )
+     config["instance_count"] = col2.number_input("Instance count", 1)
+     config["use_spot"] = col1.selectbox("use spot instances", [False, True])
+     st.markdown("---")
+
+     st.header("Credentials")
+     # sagemaker credentials
+     col1, col2 = st.beta_columns(2)
+     config["aws_access_key_id"] = col1.text_input("AWS Access Key ID")
+     config["aws_secret_access_key"] = col2.text_input("AWS Secret Access Key")
+
+     if use_auth_token:
+         parameter["use_auth_token"] = use_auth_token
+
+     if st.button("Start training on SageMaker"):
+         train_estimator(parameter, config)
+
+
+ if __name__ == "__main__":
+     main()
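A note on the query-parameter handling above: st.experimental_get_query_params returns every value as a list of strings, which is why the defaults are wrapped in lists and why model_name_or_path is unwrapped with [0] when building the job name. A minimal standalone sketch of the round-trip, assuming the same pre-1.0 Streamlit API the app uses:

    import streamlit as st

    # minimal sketch of the query-param round-trip used throughout main()
    params = st.experimental_get_query_params()     # e.g. {"task": ["text-classification"]}
    task = params.get("task", ["none"])             # defaults must be lists, matching the API
    st.experimental_set_query_params(task=task)     # writes ?task=... back into the URL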
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ sagemaker
+ transformers
+ datasets
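Note: app.py also imports streamlit, requests, and boto3, which are not pinned here; presumably streamlit is provided by the Spaces runtime, boto3 is pulled in as a sagemaker dependency, and requests as a transformers dependency.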
trainer.py ADDED
@@ -0,0 +1,139 @@
+ from sagemaker.huggingface import HuggingFace
+ import logging
+ import sys
+ from contextlib import contextmanager
+ from io import StringIO
+ from streamlit.report_thread import REPORT_CONTEXT_ATTR_NAME
+ from threading import current_thread
+ import streamlit as st
+ import sagemaker
+ import boto3
+
+
+ @contextmanager
+ def st_redirect(src, dst):
+     """Redirect writes on src (e.g. sys.stdout) into a Streamlit placeholder."""
+     placeholder = st.empty()
+     output_func = getattr(placeholder, dst)
+
+     with StringIO() as buffer:
+         old_write = src.write
+
+         def new_write(b):
+             # only capture writes that come from the Streamlit script thread
+             if getattr(current_thread(), REPORT_CONTEXT_ATTR_NAME, None):
+                 buffer.write(b)
+                 output_func(buffer.getvalue())
+             else:
+                 old_write(b)
+
+         try:
+             src.write = new_write
+             yield
+         finally:
+             src.write = old_write
+
+
+ @contextmanager
+ def st_stdout(dst):
+     with st_redirect(sys.stdout, dst):
+         yield
+
+
+ @contextmanager
+ def st_stderr(dst):
+     with st_redirect(sys.stderr, dst):
+         yield
+
+
+ task2script = {
+     "text-classification": {
+         "entry_point": "run_glue.py",
+         "source_dir": "examples/text-classification",
+     },
+     "token-classification": {
+         "entry_point": "run_ner.py",
+         "source_dir": "examples/token-classification",
+     },
+     "question-answering": {
+         "entry_point": "run_qa.py",
+         "source_dir": "examples/question-answering",
+     },
+     "summarization": {
+         "entry_point": "run_summarization.py",
+         "source_dir": "examples/seq2seq",
+     },
+     "translation": {
+         "entry_point": "run_translation.py",
+         "source_dir": "examples/seq2seq",
+     },
+     "causal-language-modeling": {
+         "entry_point": "run_clm.py",
+         "source_dir": "examples/language-modeling",
+     },
+     "masked-language-modeling": {
+         "entry_point": "run_mlm.py",
+         "source_dir": "examples/language-modeling",
+     },
+ }
+
+
+ def train_estimator(parameter, config):
+     with st_stdout("code"):
+         logger = logging.getLogger(__name__)
+
+         logging.basicConfig(
+             level=logging.getLevelName("INFO"),
+             handlers=[logging.StreamHandler(sys.stdout)],
+             format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+         )
+         logger.info = print
+
+         # git configuration to download our fine-tuning script
+         git_config = {"repo": "https://github.com/huggingface/transformers.git", "branch": "v4.4.2"}
+
+         # pick the fine-tuning script for the selected task
+         entry_point = task2script[parameter["task"]]["entry_point"]
+         source_dir = task2script[parameter["task"]]["source_dir"]
+         # iam configuration
+         session = boto3.session.Session(
+             aws_access_key_id=config["aws_access_key_id"],
+             aws_secret_access_key=config["aws_secret_access_key"],
+             region_name=config["region"],
+         )
+         sess = sagemaker.Session(boto_session=session)
+
+         iam = session.client(
+             "iam", aws_access_key_id=config["aws_access_key_id"], aws_secret_access_key=config["aws_secret_access_key"]
+         )
+         role = iam.get_role(RoleName=config["aws_sagemaker_role"])["Role"]["Arn"]
+
+         logger.info(f"role: {role}")
+         instance_type = config["instance_type"].split("|")[1].strip()
+         logger.info(f"instance_type: {instance_type}")
+
+         hyperparameters = {
+             "output_dir": "/opt/ml/model",
+             "do_train": True,
+             "do_eval": True,
+             "do_predict": True,
+             **parameter,
+         }
+         del hyperparameters["task"]
+         # create estimator
+         huggingface_estimator = HuggingFace(
+             entry_point=entry_point,
+             source_dir=source_dir,
+             git_config=git_config,
+             base_job_name=config["job_name"],
+             instance_type=instance_type,
+             sagemaker_session=sess,
+             instance_count=config["instance_count"],
+             role=role,
+             transformers_version="4.4",
+             pytorch_version="1.6",
+             py_version="py36",
+             hyperparameters=hyperparameters,
+         )
+         # train
+         huggingface_estimator.fit()
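For a concrete picture of what train_estimator expects, here is a hypothetical hand-written invocation; the dict keys mirror what app.py collects, and every value below is a placeholder rather than anything this commit defines:

    from trainer import train_estimator

    # hypothetical inputs; all values are placeholders
    parameter = {
        "model_name_or_path": "distilbert-base-uncased",
        "dataset": "imdb",
        "task": "text-classification",   # must be a key of task2script
        "epochs": 3,
        "learning_rate": 5e-5,
        "per_device_train_batch_size": 8,
        "per_device_eval_batch_size": 8,
    }
    config = {
        "job_name": "distilbert-base-uncased-job-2021-05-01",
        "aws_sagemaker_role": "my-sagemaker-execution-role",  # IAM role name; the ARN is resolved via iam.get_role
        "instance_type": "single-gpu | ml.p3.2xlarge",        # the part after "|" is what SageMaker receives
        "region": "eu-west-1",
        "instance_count": 1,
        "use_spot": False,
        "aws_access_key_id": "<access-key-id>",
        "aws_secret_access_key": "<secret-access-key>",
    }

    train_estimator(parameter, config)  # builds the HuggingFace estimator and calls .fit()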
utils/__init__.py ADDED
File without changes
utils/load_dataset.py ADDED
@@ -0,0 +1,7 @@
+ import streamlit as st
+ import datasets as ds
+
+
+ @st.cache
+ def load_datasets():
+     return ds.list_datasets(with_community_datasets=False)
utils/load_models.py ADDED
@@ -0,0 +1,8 @@
+ import requests
+ import streamlit as st
+
+
+ @st.cache
+ def load_models():
+     res = requests.get("https://huggingface.co/api/models").json()
+     return [model["modelId"] for model in res]
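load_models only reads the modelId field of each entry returned by the /api/models endpoint. A trimmed, illustrative response shape (the real payload carries many more fields per model):

    # illustrative excerpt of the https://huggingface.co/api/models payload
    res = [
        {"modelId": "bert-base-uncased"},
        {"modelId": "distilbert-base-uncased"},
    ]
    model_ids = [model["modelId"] for model in res]  # -> ["bert-base-uncased", "distilbert-base-uncased"]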
utils/load_tasks.py ADDED
@@ -0,0 +1,14 @@
+ import streamlit as st
+
+
+ @st.cache
+ def load_tasks():
+     return [
+         "causal-language-modeling",
+         "masked-language-modeling",
+         "question-answering",
+         "summarization",
+         "text-classification",
+         "token-classification",
+         "translation",
+     ]
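Note that this hard-coded task list mirrors the keys of task2script in trainer.py; the two must stay in sync for train_estimator to resolve a fine-tuning script for every selectable task.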