Bram Vanroy committed
Commit
68ddcf0
1 Parent(s): aa6a76b
Files changed (4)
  1. .gitignore +229 -0
  2. README.md +2 -2
  3. app.py +110 -0
  4. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1,229 @@
+ run-backend.ps
+ .eslintrc.js
+ .venv
+ .*credentials.json
+ .credentials.json
+ /*.ipynb
+ logbook.md
+
+ .transl_sysprompt_en-nl
+
+ # ignore compiled styles
+ *.css
+
+ # dependencies
+ **/node_modules/
+ **/.pnp
+ *.pnp.js
+
+ # testing
+ /coverage
+
+ # VSCode
+ **/.vscode/
+
+ # production
+ **/build/
+
+ # misc
+ .DS_Store
+ .env.local
+ .env.development.local
+ .env.test.local
+ .env.production.local
+
+ npm-debug.log*
+ yarn-debug.log*
+ yarn-error.log*
+
+
+ # python
+ data/
+ Pipfile*
+
+ # .idea (JetBrains)
+ **/.idea/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ test.py
+
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+ # User-specific stuff
+ .idea/**/workspace.xml
+ .idea/**/tasks.xml
+ .idea/**/usage.statistics.xml
+ .idea/**/dictionaries
+ .idea/**/shelf
+
+ # AWS User-specific
+ .idea/**/aws.xml
+
+ # Generated files
+ .idea/**/contentModel.xml
+
+ # Sensitive or high-churn files
+ .idea/**/dataSources/
+ .idea/**/dataSources.ids
+ .idea/**/dataSources.local.xml
+ .idea/**/sqlDataSources.xml
+ .idea/**/dynamic.xml
+ .idea/**/uiDesigner.xml
+ .idea/**/dbnavigator.xml
+
+ # Gradle
+ .idea/**/gradle.xml
+ .idea/**/libraries
+
+ # Gradle and Maven with auto-import
+ # When using Gradle or Maven with auto-import, you should exclude module files,
+ # since they will be recreated, and may cause churn. Uncomment if using
+ # auto-import.
+ # .idea/artifacts
+ # .idea/compiler.xml
+ # .idea/jarRepositories.xml
+ # .idea/modules.xml
+ # .idea/*.iml
+ # .idea/modules
+ # *.iml
+ # *.ipr
+
+ # CMake
+ cmake-build-*/
+
+ # Mongo Explorer plugin
+ .idea/**/mongoSettings.xml
+
+ # File-based project format
+ *.iws
+
+ # IntelliJ
+ out/
+
+ # mpeltonen/sbt-idea plugin
+ .idea_modules/
+
+ # JIRA plugin
+ atlassian-ide-plugin.xml
+
+ # Cursive Clojure plugin
+ .idea/replstate.xml
+
+ # SonarLint plugin
+ .idea/sonarlint/
+
+ # Crashlytics plugin (for Android Studio and IntelliJ)
+ com_crashlytics_export_strings.xml
+ crashlytics.properties
+ crashlytics-build.properties
+ fabric.properties
+
+ # Editor-based Rest Client
+ .idea/httpRequests
+
+ # Android studio 3.1+ serialized cache file
+ .idea/caches/build_file_checksums.ser
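
A quick way to verify which of the rules above matches a given path is `git check-ignore -v`, which prints the matching .gitignore line. A minimal sketch wrapping it from Python, run inside the repository (the test paths are purely illustrative):

import subprocess

# Hypothetical paths to test against the rules above
for path in ["notes.ipynb", "src/notes.ipynb", ".venv/bin/python"]:
    # -v prints the .gitignore source line that matched; empty output means not ignored
    result = subprocess.run(
        ["git", "check-ignore", "-v", path], capture_output=True, text=True
    )
    print(path, "->", result.stdout.strip() or "not ignored")

Note that `/*.ipynb` is anchored to the repository root, so `notes.ipynb` would be ignored while `src/notes.ipynb` would not.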
README.md CHANGED
@@ -1,7 +1,7 @@
  ---
  title: Steps Calculator
- emoji: 🐨
- colorFrom: red
+ emoji: 🦶
+ colorFrom: orange
  colorTo: yellow
  sdk: gradio
  sdk_version: 4.22.0
app.py ADDED
@@ -0,0 +1,110 @@
+ from math import ceil
+
+ import gradio as gr
+ from datasets import load_dataset, IterableDataset
+ from transformers import AutoTokenizer, PreTrainedTokenizer
+
+
+ def count_tokens(batch, tokenizer, text_column):
+     # Tokenize a batch of texts and record the token count per sample
+     encoded = tokenizer(batch[text_column])
+     return {"num_tokens": [len(input_ids) for input_ids in encoded["input_ids"]]}
+
+
+ def get_dataset_num_tokens(
+     dataset: IterableDataset, tokenizer: PreTrainedTokenizer, text_column: str, progress=gr.Progress()
+ ) -> int:
+     progress((0, None), desc="Counting tokens", unit="tokens")
+     ds = dataset.map(
+         count_tokens, batched=True, batch_size=1000, fn_kwargs={"tokenizer": tokenizer, "text_column": text_column}
+     )
+
+     # Iterating the streaming dataset triggers the lazy map above
+     total_num_tokens = 0
+     for sample in ds:
+         total_num_tokens += sample["num_tokens"]
+         progress((total_num_tokens, None), desc="Counting tokens", unit="tokens")
+
+     return total_num_tokens
+
+
+ def calculate_steps(
+     dataset_name: str,
+     dataset_split: str,
+     dataset_config: str | None,
+     tokenizer_name: str,
+     num_gpus_per_node: int,
+     num_nodes: int,
+     batch_size: int,
+     grad_accum: int,
+     block_size: int,
+     text_column: str = "text",
+     token: str | None = None,
+ ):
+     dataset_config = None if not dataset_config.strip() else dataset_config
+     text_column = "text" if not text_column.strip() else text_column
+     token = None if not token.strip() else token
+     try:
+         dataset = load_dataset(dataset_name, dataset_config, streaming=True, token=token, split=dataset_split)
+         tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=token)
+         total_num_tokens = get_dataset_num_tokens(dataset, tokenizer, text_column)
+     except Exception as exc:
+         raise gr.Error(str(exc))
+     else:
+         # Number of block-sized samples after packing all tokens, and the number
+         # of optimizer steps needed to consume them exactly once
+         dataset_size = ceil(total_num_tokens / block_size)
+         world_size = num_gpus_per_node * num_nodes
+         num_steps = ceil(dataset_size / (world_size * batch_size * grad_accum))
+         return dataset_size, num_steps
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         """# Steps Calculator
+
+ Calculate the number of optimizer steps required for one full pass over your dataset at a given sequence length. \
+ This is especially useful when training on a streaming dataset, where you cannot know in advance how many steps \
+ a single epoch takes for a given tokenizer and block size."""
+     )
+
+     with gr.Row():
+         dataset_name = gr.Text(label="Dataset name")
+         dataset_split = gr.Text(label="Dataset split", value="train")
+         dataset_config = gr.Text(label="Dataset config (optional)")
+         tokenizer_name = gr.Text(label="Tokenizer name")
+
+     with gr.Row():
+         num_gpus_per_node = gr.Number(value=1, minimum=1, label="Number of GPUs per node")
+         num_nodes = gr.Number(value=1, minimum=1, label="Number of nodes")
+         batch_size = gr.Number(value=8, minimum=1, label="Batch size")
+         grad_accum = gr.Number(value=1, minimum=1, label="Gradient accumulation steps")
+         block_size = gr.Number(value=2048, minimum=1, label="Block size")
+         text_column = gr.Text(value="text", label="Text column")
+         token = gr.Text(label="HF access token (optional)")
+
+     with gr.Row():
+         with gr.Column():
+             calculate_btn = gr.Button(value="Calculate")
+         with gr.Column():
+             samples = gr.Number(value=None, minimum=1, label="Total block-sized samples", interactive=False)
+             steps = gr.Number(value=None, minimum=1, label="Total steps needed", interactive=False)
+
+     calculate_btn.click(
+         calculate_steps,
+         inputs=[
+             dataset_name,
+             dataset_split,
+             dataset_config,
+             tokenizer_name,
+             num_gpus_per_node,
+             num_nodes,
+             batch_size,
+             grad_accum,
+             block_size,
+             text_column,
+             token,
+         ],
+         outputs=[samples, steps],
+         api_name="calculate-training-steps",
+     )
+
+
+ if __name__ == "__main__":
+     demo.queue().launch()
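
To sanity-check the arithmetic in `calculate_steps`, here is a minimal worked example using the same two formulas. All the numbers are assumptions for illustration (a 1B-token corpus on 2 nodes of 4 GPUs), not values from this commit:

from math import ceil

# Assumed example values, not taken from the commit
total_num_tokens = 1_000_000_000   # tokens counted over the streamed dataset
block_size = 2048                  # sequence length per packed sample
num_gpus_per_node, num_nodes = 4, 2
batch_size, grad_accum = 8, 4

# Same formulas as in calculate_steps above
dataset_size = ceil(total_num_tokens / block_size)       # 488_282 block-sized samples
world_size = num_gpus_per_node * num_nodes               # 8 processes
samples_per_step = world_size * batch_size * grad_accum  # 256 samples per optimizer step
num_steps = ceil(dataset_size / samples_per_step)        # 1_908 steps for one epoch

print(dataset_size, num_steps)

In other words, each optimizer step consumes world_size * batch_size * grad_accum packed samples, so halving the block size roughly doubles the number of steps per epoch.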
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ datasets
+ gradio
+ transformers
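
Since the button handler is exposed with `api_name="calculate-training-steps"`, the Space can also be called programmatically once deployed. A minimal sketch using `gradio_client` (which runs client-side and is therefore not in requirements.txt); the Space ID, dataset, and tokenizer names here are assumptions for illustration and need to be replaced with real values:

from gradio_client import Client  # pip install gradio_client

# Hypothetical Space ID; substitute the actual one
client = Client("BramVanroy/steps-calculator")

samples, steps = client.predict(
    "imdb",    # dataset_name (assumed example dataset with a "text" column)
    "train",   # dataset_split
    "",        # dataset_config (empty -> None)
    "gpt2",    # tokenizer_name (assumed example tokenizer)
    4,         # num_gpus_per_node
    2,         # num_nodes
    8,         # batch_size
    4,         # grad_accum
    2048,      # block_size
    "text",    # text_column
    "",        # HF access token (empty -> None)
    api_name="/calculate-training-steps",
)
print(f"{samples} block-sized samples -> {steps} steps")

Bear in mind that the app streams and tokenizes the entire split to count tokens, so a call on a large dataset can take a long time.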