Upload 4315 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the full changeset.
- .gitattributes +22 -35
- Tensorflow/models/.github/ISSUE_TEMPLATE/00-official-bug-report-issue.md +59 -0
- Tensorflow/models/.github/ISSUE_TEMPLATE/10-official-documentation-issue.md +20 -0
- Tensorflow/models/.github/ISSUE_TEMPLATE/20-official-feature-request-issue.md +26 -0
- Tensorflow/models/.github/ISSUE_TEMPLATE/30-research-bug-report-issue.md +58 -0
- Tensorflow/models/.github/ISSUE_TEMPLATE/40-research-documentation-issue.md +20 -0
- Tensorflow/models/.github/ISSUE_TEMPLATE/50-research-feature-request-issue.md +26 -0
- Tensorflow/models/.github/ISSUE_TEMPLATE/60-questions-help-issue.md +14 -0
- Tensorflow/models/.github/ISSUE_TEMPLATE/config.yml +1 -0
- Tensorflow/models/.github/PULL_REQUEST_TEMPLATE.md +41 -0
- Tensorflow/models/.github/README_TEMPLATE.md +124 -0
- Tensorflow/models/.github/bot_config.yml +24 -0
- Tensorflow/models/.github/scripts/pylint.sh +178 -0
- Tensorflow/models/.github/workflows/ci.yml +32 -0
- Tensorflow/models/.github/workflows/stale.yaml +67 -0
- Tensorflow/models/.gitignore +98 -0
- Tensorflow/models/AUTHORS +10 -0
- Tensorflow/models/CODEOWNERS +29 -0
- Tensorflow/models/CODE_OF_CONDUCT.md +79 -0
- Tensorflow/models/CONTRIBUTING.md +10 -0
- Tensorflow/models/ISSUES.md +24 -0
- Tensorflow/models/LICENSE +212 -0
- Tensorflow/models/README.md +130 -0
- Tensorflow/models/SECURITY.md +251 -0
- Tensorflow/models/This PC - Shortcut.lnk +0 -0
- Tensorflow/models/community/README.md +60 -0
- Tensorflow/models/docs/README.md +17 -0
- Tensorflow/models/docs/index.md +140 -0
- Tensorflow/models/docs/nlp/_guide_toc.yaml +9 -0
- Tensorflow/models/docs/nlp/customize_encoder.ipynb +596 -0
- Tensorflow/models/docs/nlp/decoding_api.ipynb +482 -0
- Tensorflow/models/docs/nlp/fine_tune_bert.ipynb +1550 -0
- Tensorflow/models/docs/nlp/index.ipynb +545 -0
- Tensorflow/models/docs/nlp/load_lm_ckpts.ipynb +692 -0
- Tensorflow/models/docs/orbit/index.ipynb +898 -0
- Tensorflow/models/docs/vision/_toc.yaml +9 -0
- Tensorflow/models/docs/vision/image_classification.ipynb +692 -0
- Tensorflow/models/docs/vision/instance_segmentation.ipynb +1138 -0
- Tensorflow/models/docs/vision/object_detection.ipynb +902 -0
- Tensorflow/models/docs/vision/semantic_segmentation.ipynb +790 -0
- Tensorflow/models/official/README-TPU.md +32 -0
- Tensorflow/models/official/README.md +177 -0
- Tensorflow/models/official/__init__.py +14 -0
- Tensorflow/models/official/common/__init__.py +15 -0
- Tensorflow/models/official/common/dataset_fn.py +44 -0
- Tensorflow/models/official/common/distribute_utils.py +233 -0
- Tensorflow/models/official/common/distribute_utils_test.py +124 -0
- Tensorflow/models/official/common/flags.py +114 -0
- Tensorflow/models/official/common/registry_imports.py +20 -0
- Tensorflow/models/official/common/streamz_counters.py +27 -0
.gitattributes
CHANGED
@@ -1,35 +1,22 @@
- [deleted lines 1–22 not preserved in this view; line 1 begins "*."]
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.extension filter=lfs diff=lfs merge=lfs -text
+Tensorflow/models/official/projects/waste_identification_ml/pre_processing/config/sample_images/ffdeb4cd-43ba-4ca0-a1e6-aa5824005f44.jpg filter=lfs diff=lfs merge=lfs -text
+Tensorflow/models/official/projects/waste_identification_ml/pre_processing/config/sample_images/image_2.png filter=lfs diff=lfs merge=lfs -text
+Tensorflow/models/official/projects/waste_identification_ml/pre_processing/config/sample_images/image_3.jpg filter=lfs diff=lfs merge=lfs -text
+Tensorflow/models/official/projects/waste_identification_ml/pre_processing/config/sample_images/image_4.png filter=lfs diff=lfs merge=lfs -text
+Tensorflow/models/research/deeplab/testing/pascal_voc_seg/val-00000-of-00001.tfrecord filter=lfs diff=lfs merge=lfs -text
+Tensorflow/models/research/dist/object_detection-0.1-py3.8.egg filter=lfs diff=lfs merge=lfs -text
+Tensorflow/models/research/lfads/synth_data/trained_itb/model-65000.meta filter=lfs diff=lfs merge=lfs -text
+Tensorflow/models/research/object_detection/dataset_tools/densepose/UV_symmetry_transforms.mat filter=lfs diff=lfs merge=lfs -text
+Tensorflow/models/research/object_detection/g3doc/img/kites_with_segment_overlay.png filter=lfs diff=lfs merge=lfs -text
+Tensorflow/models/research/object_detection/test_images/image2.jpg filter=lfs diff=lfs merge=lfs -text
+Tensorflow/protoc/bin/protoc.exe filter=lfs diff=lfs merge=lfs -text
+Tensorflow/protoc/protoc-3.15.6-win64.zip filter=lfs diff=lfs merge=lfs -text
+Tensorflow/workspace/annotations/train.record filter=lfs diff=lfs merge=lfs -text
+Tensorflow/workspace/models/my_ssd_mobnet/ckpt-2.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+Tensorflow/workspace/models/my_ssd_mobnet/ckpt-3.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+Tensorflow/workspace/models/my_ssd_mobnet/ckpt-4.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+Tensorflow/workspace/models/my_ssd_mobnet/ckpt-5.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+Tensorflow/workspace/models/my_ssd_mobnet/ckpt-6.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+Tensorflow/workspace/models/my_ssd_mobnet/ckpt-7.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+Tensorflow/workspace/models/my_ssd_mobnet/ckpt-8.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+Tensorflow/workspace/models/my_ssd_mobnet/train/events.out.tfevents.1712697690.DESKTOP-GD8VO36.10056.0.v2 filter=lfs diff=lfs merge=lfs -text
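Each attribute line above routes matching paths through the Git LFS filter instead of storing their contents directly in Git. A minimal sketch of how such entries are typically generated, assuming the patterns below are illustrative examples rather than part of this commit:

```bash
# Hedged sketch: `git lfs track` appends a "filter=lfs diff=lfs merge=lfs -text"
# line for the given pattern to .gitattributes. Patterns here are examples only.
git lfs install                                # enable the LFS filters once per machine
git lfs track "*.tflite"                       # track every .tflite file by pattern
git lfs track "path/to/one-large-file.record"  # or track a single large file (hypothetical path)
git add .gitattributes                         # stage the updated attribute list for commit
```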
Tensorflow/models/.github/ISSUE_TEMPLATE/00-official-bug-report-issue.md
ADDED
@@ -0,0 +1,59 @@
+---
+name: "[Official Model] Bug Report"
+about: Use this template for reporting a bug for the “official” directory
+labels: type:bug,models:official
+
+---
+
+# Prerequisites
+
+Please answer the following questions for yourself before submitting an issue.
+
+- [ ] I am using the latest TensorFlow Model Garden release and TensorFlow 2.
+- [ ] I am reporting the issue to the correct repository. (Model Garden official or research directory)
+- [ ] I checked to make sure that this issue has not been filed already.
+
+## 1. The entire URL of the file you are using
+
+https://github.com/tensorflow/models/tree/master/official/...
+
+## 2. Describe the bug
+
+A clear and concise description of what the bug is.
+
+## 3. Steps to reproduce
+
+Steps to reproduce the behavior.
+
+## 4. Expected behavior
+
+A clear and concise description of what you expected to happen.
+
+## 5. Additional context
+
+Include any logs that would be helpful to diagnose the problem.
+
+## 6. System information
+
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- Mobile device name if the issue happens on a mobile device:
+- TensorFlow installed from (source or binary):
+- TensorFlow version (use command below):
+- Python version:
+- Bazel version (if compiling from source):
+- GCC/Compiler version (if compiling from source):
+- CUDA/cuDNN version:
+- GPU model and memory:
+
+<!--
+Collect system information using our environment capture script.
+https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
+
+You can also obtain the TensorFlow version with:
+
+1. TensorFlow 1.0
+`python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
+
+2. TensorFlow 2.0
+`python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
+-->
Tensorflow/models/.github/ISSUE_TEMPLATE/10-official-documentation-issue.md
ADDED
@@ -0,0 +1,20 @@
+---
+name: "[Official Model] Documentation Issue"
+about: Use this template for reporting a documentation issue for the “official” directory
+labels: type:docs,models:official
+
+---
+
+# Prerequisites
+
+Please answer the following question for yourself before submitting an issue.
+
+- [ ] I checked to make sure that this issue has not been filed already.
+
+## 1. The entire URL of the documentation with the issue
+
+https://github.com/tensorflow/models/tree/master/official/...
+
+## 2. Describe the issue
+
+A clear and concise description of what needs to be changed.
Tensorflow/models/.github/ISSUE_TEMPLATE/20-official-feature-request-issue.md
ADDED
@@ -0,0 +1,26 @@
+---
+name: "[Official Model] Feature request"
+about: Use this template for raising a feature request for the “official” directory
+labels: type:feature,models:official
+
+---
+
+# Prerequisites
+
+Please answer the following question for yourself before submitting an issue.
+
+- [ ] I checked to make sure that this feature has not been requested already.
+
+## 1. The entire URL of the file you are using
+
+https://github.com/tensorflow/models/tree/master/official/...
+
+## 2. Describe the feature you request
+
+A clear and concise description of what you want to happen.
+
+## 3. Additional context
+
+Add any other context about the feature request here.
+
+## 4. Are you willing to contribute it? (Yes or No)
Tensorflow/models/.github/ISSUE_TEMPLATE/30-research-bug-report-issue.md
ADDED
@@ -0,0 +1,58 @@
+---
+name: "[Research Model] Bug Report"
+about: Use this template for reporting a bug for the “research” directory
+labels: type:bug,models:research
+
+---
+# Prerequisites
+
+Please answer the following questions for yourself before submitting an issue.
+
+- [ ] I am using the latest TensorFlow Model Garden release and TensorFlow 2.
+- [ ] I am reporting the issue to the correct repository. (Model Garden official or research directory)
+- [ ] I checked to make sure that this issue has not already been filed.
+
+## 1. The entire URL of the file you are using
+
+https://github.com/tensorflow/models/tree/master/research/...
+
+## 2. Describe the bug
+
+A clear and concise description of what the bug is.
+
+## 3. Steps to reproduce
+
+Steps to reproduce the behavior.
+
+## 4. Expected behavior
+
+A clear and concise description of what you expected to happen.
+
+## 5. Additional context
+
+Include any logs that would be helpful to diagnose the problem.
+
+## 6. System information
+
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- Mobile device name if the issue happens on a mobile device:
+- TensorFlow installed from (source or binary):
+- TensorFlow version (use command below):
+- Python version:
+- Bazel version (if compiling from source):
+- GCC/Compiler version (if compiling from source):
+- CUDA/cuDNN version:
+- GPU model and memory:
+
+<!--
+Collect system information using our environment capture script.
+https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
+
+You can also obtain the TensorFlow version with:
+
+1. TensorFlow 1.0
+`python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
+
+2. TensorFlow 2.0
+`python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
+-->
Tensorflow/models/.github/ISSUE_TEMPLATE/40-research-documentation-issue.md
ADDED
@@ -0,0 +1,20 @@
+---
+name: "[Research Model] Documentation Issue"
+about: Use this template for reporting a documentation issue for the “research” directory
+labels: type:docs,models:research
+
+---
+
+# Prerequisites
+
+Please answer the following question for yourself before submitting an issue.
+
+- [ ] I checked to make sure that this issue has not been filed already.
+
+## 1. The entire URL of the documentation with the issue
+
+https://github.com/tensorflow/models/tree/master/research/...
+
+## 2. Describe the issue
+
+A clear and concise description of what needs to be changed.
Tensorflow/models/.github/ISSUE_TEMPLATE/50-research-feature-request-issue.md
ADDED
@@ -0,0 +1,26 @@
+---
+name: "[Research Model] Feature Request"
+about: Use this template for raising a feature request for the “research” directory
+labels: type:feature,models:research
+
+---
+
+# Prerequisites
+
+Please answer the following question for yourself before submitting an issue.
+
+- [ ] I checked to make sure that this feature has not been requested already.
+
+## 1. The entire URL of the file you are using
+
+https://github.com/tensorflow/models/tree/master/research/...
+
+## 2. Describe the feature you request
+
+A clear and concise description of what you want to happen.
+
+## 3. Additional context
+
+Add any other context about the feature request here.
+
+## 4. Are you willing to contribute it? (Yes or No)
Tensorflow/models/.github/ISSUE_TEMPLATE/60-questions-help-issue.md
ADDED
@@ -0,0 +1,14 @@
+---
+name: Questions and Help
+about: Use this template for Questions and Help.
+labels: type:support
+
+---
+<!--
+As per our GitHub Policy (https://github.com/tensorflow/models/blob/master/ISSUES.md), we only address code bugs, documentation issues, and feature requests on GitHub.
+
+We will automatically close questions and help related issues.
+
+Please go to Stack Overflow (http://stackoverflow.com/questions/tagged/tensorflow-model-garden) for questions and help.
+
+-->
Tensorflow/models/.github/ISSUE_TEMPLATE/config.yml
ADDED
@@ -0,0 +1 @@
+blank_issues_enabled: false
Tensorflow/models/.github/PULL_REQUEST_TEMPLATE.md
ADDED
@@ -0,0 +1,41 @@
+# Description
+
+> :memo: Please include a summary of the change.
+>
+> * Please also include relevant motivation and context.
+> * List any dependencies that are required for this change.
+
+## Type of change
+
+For a new feature or function, please create an issue first to discuss it
+with us before submitting a pull request.
+
+Note: Please delete options that are not relevant.
+
+- [ ] Bug fix (non-breaking change which fixes an issue)
+- [ ] Documentation update
+- [ ] TensorFlow 2 migration
+- [ ] New feature (non-breaking change which adds functionality)
+- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
+- [ ] A new research paper code implementation
+- [ ] Other (Specify)
+
+## Tests
+
+> :memo: Please describe the tests that you ran to verify your changes.
+>
+> * Provide instructions so we can reproduce.
+> * Please also list any relevant details for your test configuration.
+
+**Test Configuration**:
+
+## Checklist
+
+- [ ] I have signed the [Contributor License Agreement](https://github.com/tensorflow/models/wiki/Contributor-License-Agreements).
+- [ ] I have read [guidelines for pull request](https://github.com/tensorflow/models/wiki/Submitting-a-pull-request).
+- [ ] My code follows the [coding guidelines](https://github.com/tensorflow/models/wiki/Coding-guidelines).
+- [ ] I have performed a self [code review](https://github.com/tensorflow/models/wiki/Code-review) of my own code.
+- [ ] I have commented my code, particularly in hard-to-understand areas.
+- [ ] I have made corresponding changes to the documentation.
+- [ ] My changes generate no new warnings.
+- [ ] I have added tests that prove my fix is effective or that my feature works.
Tensorflow/models/.github/README_TEMPLATE.md
ADDED
@@ -0,0 +1,124 @@
+> :memo: A README.md template for releasing a paper code implementation to a GitHub repository.
+>
+> * Template version: 1.0.2020.170
+> * Please modify sections depending on needs.
+
+# Model name, Paper title, or Project Name
+
+> :memo: Add a badge for the ArXiv identifier of your paper (arXiv:YYMM.NNNNN)
+
+[![Paper](http://img.shields.io/badge/Paper-arXiv.YYMM.NNNNN-B3181B?logo=arXiv)](https://arxiv.org/abs/...)
+
+This repository is the official or unofficial implementation of the following paper.
+
+* Paper title: [Paper Title](https://arxiv.org/abs/YYMM.NNNNN)
+
+## Description
+
+> :memo: Provide description of the model.
+>
+> * Provide brief information of the algorithms used.
+> * Provide links for demos, blog posts, etc.
+
+## History
+
+> :memo: Provide a changelog.
+
+## Authors or Maintainers
+
+> :memo: Provide maintainer information.
+
+* Full name ([@GitHub username](https://github.com/username))
+* Full name ([@GitHub username](https://github.com/username))
+
+## Table of Contents
+
+> :memo: Provide a table of contents to help readers navigate a lengthy README document.
+
+## Requirements
+
+[![TensorFlow 2.1](https://img.shields.io/badge/TensorFlow-2.1-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0)
+[![Python 3.6](https://img.shields.io/badge/Python-3.6-3776AB)](https://www.python.org/downloads/release/python-360/)
+
+> :memo: Provide details of the software required.
+>
+> * Add a `requirements.txt` file to the root directory for installing the necessary dependencies.
+> * Describe how to install requirements using pip.
+> * Alternatively, create INSTALL.md.
+
+To install requirements:
+
+```setup
+pip install -r requirements.txt
+```
+
+## Results
+
+[![TensorFlow Hub](https://img.shields.io/badge/TF%20Hub-Models-FF6F00?logo=tensorflow)](https://tfhub.dev/...)
+
+> :memo: Provide a table with results. (e.g., accuracy, latency)
+>
+> * Provide links to the pre-trained models (checkpoint, SavedModel files).
+> * Publish TensorFlow SavedModel files on TensorFlow Hub (tfhub.dev) if possible.
+> * Add links to [TensorBoard.dev](https://tensorboard.dev/) for visualizing metrics.
+>
+> An example table for image classification results
+>
+> ### Image Classification
+>
+> | Model name | Download | Top 1 Accuracy | Top 5 Accuracy |
+> |------------|----------|----------------|----------------|
+> | Model name | [Checkpoint](https://drive.google.com/...), [SavedModel](https://tfhub.dev/...) | xx% | xx% |
+
+## Dataset
+
+> :memo: Provide information of the dataset used.
+
+## Training
+
+> :memo: Provide training information.
+>
+> * Provide details for preprocessing, hyperparameters, random seeds, and environment.
+> * Provide a command line example for training.
+
+Please run this command line for training.
+
+```shell
+python3 ...
+```
+
+## Evaluation
+
+> :memo: Provide an evaluation script with details of how to reproduce results.
+>
+> * Describe data preprocessing / postprocessing steps.
+> * Provide a command line example for evaluation.
+
+Please run this command line for evaluation.
+
+```shell
+python3 ...
+```
+
+## References
+
+> :memo: Provide links to references.
+
+## License
+
+[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
+
+> :memo: Place your license text in a file named LICENSE in the root of the repository.
+>
+> * Include information about your license.
+> * Reference: [Adding a license to a repository](https://help.github.com/en/github/building-a-strong-community/adding-a-license-to-a-repository)
+
+This project is licensed under the terms of the **Apache License 2.0**.
+
+## Citation
+
+> :memo: Make your repository citable.
+>
+> * Reference: [Making Your Code Citable](https://guides.github.com/activities/citable-code/)
+
+If you want to cite this repository in your research paper, please use the following information.
Tensorflow/models/.github/bot_config.yml
ADDED
@@ -0,0 +1,24 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# throughout. Please refer to the TensorFlow dockerfiles documentation
+# for more information.
+
+# A list of assignees
+assignees:
+  - laxmareddyp
Tensorflow/models/.github/scripts/pylint.sh
ADDED
@@ -0,0 +1,178 @@
+#!/bin/bash
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Pylint wrapper extracted from main TensorFlow, sharing same exceptions.
+# Specify --incremental to only check files touched since last commit on master,
+# otherwise will recursively check current directory (full repo takes long!).
+
+set -euo pipefail
+
+# Download latest configs from main TensorFlow repo.
+wget -q -O /tmp/pylintrc https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/tools/ci_build/pylintrc
+
+SCRIPT_DIR=/tmp
+
+num_cpus() {
+  # Get the number of CPUs
+  if [[ -f /proc/cpuinfo ]]; then
+    N_CPUS=$(grep -c ^processor /proc/cpuinfo)
+  else
+    # Fallback method
+    N_CPUS=`getconf _NPROCESSORS_ONLN`
+  fi
+  if [[ -z ${N_CPUS} ]]; then
+    die "ERROR: Unable to determine the number of CPUs"
+  fi
+
+  echo ${N_CPUS}
+}
+
+get_changed_files_in_last_non_merge_git_commit() {
+  git diff --name-only $(git merge-base master $(git branch --show-current))
+}
+
+# List Python files changed in the last non-merge git commit that still exist,
+# i.e., not removed.
+# Usage: get_py_files_to_check [--incremental]
+get_py_files_to_check() {
+  if [[ "$1" == "--incremental" ]]; then
+    CHANGED_PY_FILES=$(get_changed_files_in_last_non_merge_git_commit | \
+      grep '.*\.py$')
+
+    # Do not include files removed in the last non-merge commit.
+    PY_FILES=""
+    for PY_FILE in ${CHANGED_PY_FILES}; do
+      if [[ -f "${PY_FILE}" ]]; then
+        PY_FILES="${PY_FILES} ${PY_FILE}"
+      fi
+    done
+
+    echo "${PY_FILES}"
+  else
+    find . -name '*.py'
+  fi
+}
+
+do_pylint() {
+  if [[ $# == 1 ]] && [[ "$1" == "--incremental" ]]; then
+    PYTHON_SRC_FILES=$(get_py_files_to_check --incremental)
+
+    if [[ -z "${PYTHON_SRC_FILES}" ]]; then
+      echo "do_pylint will NOT run due to --incremental flag and due to the "\
+        "absence of Python code changes in the last commit."
+      return 0
+    fi
+  elif [[ $# != 0 ]]; then
+    echo "Invalid syntax for invoking do_pylint"
+    echo "Usage: do_pylint [--incremental]"
+    return 1
+  else
+    PYTHON_SRC_FILES=$(get_py_files_to_check)
+  fi
+
+  # Something happened. TF no longer has Python code if this branch is taken
+  if [[ -z ${PYTHON_SRC_FILES} ]]; then
+    echo "do_pylint found no Python files to check. Returning."
+    return 0
+  fi
+
+  # Now that we know we have to do work, check if `pylint` is installed
+  PYLINT_BIN="python3.8 -m pylint"
+
+  echo ""
+  echo "check whether pylint is available or not."
+  echo ""
+  ${PYLINT_BIN} --version
+  if [[ $? -eq 0 ]]
+  then
+    echo ""
+    echo "pylint available, proceeding with pylint sanity check."
+    echo ""
+  else
+    echo ""
+    echo "pylint not available."
+    echo ""
+    return 1
+  fi
+
+  # Configure pylint using the following file
+  PYLINTRC_FILE="${SCRIPT_DIR}/pylintrc"
+
+  if [[ ! -f "${PYLINTRC_FILE}" ]]; then
+    die "ERROR: Cannot find pylint rc file at ${PYLINTRC_FILE}"
+  fi
+
+  # Run pylint in parallel, after some disk setup
+  NUM_SRC_FILES=$(echo ${PYTHON_SRC_FILES} | wc -w)
+  NUM_CPUS=$(num_cpus)
+
+  echo "Running pylint on ${NUM_SRC_FILES} files with ${NUM_CPUS} "\
+    "parallel jobs..."
+  echo ""
+
+  PYLINT_START_TIME=$(date +'%s')
+  OUTPUT_FILE="$(mktemp)_pylint_output.log"
+  ERRORS_FILE="$(mktemp)_pylint_errors.log"
+
+  rm -rf ${OUTPUT_FILE}
+  rm -rf ${ERRORS_FILE}
+
+  set +e
+  # When running, filter to only contain the error code lines. Removes module
+  # header, removes lines of context that show up from some lines.
+  # Also, don't redirect stderr as this would hide pylint fatal errors.
+  ${PYLINT_BIN} --rcfile="${PYLINTRC_FILE}" --output-format=parseable \
+    --jobs=${NUM_CPUS} ${PYTHON_SRC_FILES} | grep '\[[CEFW]' > ${OUTPUT_FILE}
+  PYLINT_END_TIME=$(date +'%s')
+
+  echo ""
+  echo "pylint took $((PYLINT_END_TIME - PYLINT_START_TIME)) s"
+  echo ""
+
+  # Report only what we care about
+  # Ref https://pylint.readthedocs.io/en/latest/technical_reference/features.html
+  # E: all errors
+  # W0311 bad-indentation
+  # W0312 mixed-indentation
+  # C0330 bad-continuation
+  # C0301 line-too-long
+  # C0326 bad-whitespace
+  # W0611 unused-import
+  # W0622 redefined-builtin
+  grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301|\[C0326|\[W0611|\[W0622)' ${OUTPUT_FILE} > ${ERRORS_FILE}
+
+  # Determine counts of errors
+  N_FORBID_ERRORS=$(wc -l ${ERRORS_FILE} | cut -d' ' -f1)
+  set -e
+
+  # Now, print the errors we should fix
+  echo ""
+  if [[ ${N_FORBID_ERRORS} != 0 ]]; then
+    echo "Found ${N_FORBID_ERRORS} pylint errors:"
+    cat ${ERRORS_FILE}
+  fi
+
+  echo ""
+  if [[ ${N_FORBID_ERRORS} != 0 ]]; then
+    echo "FAIL: Found ${N_FORBID_ERRORS} errors"
+    return 1
+  else
+    echo "PASS: Found no errors"
+  fi
+}
+
+do_pylint "$@"
+
Tensorflow/models/.github/workflows/ci.yml
ADDED
@@ -0,0 +1,32 @@
+name: CI
+on: pull_request
+
+permissions:
+  contents: read
+
+jobs:
+  pylint:
+    runs-on: ubuntu-latest
+
+    steps:
+    - name: Set up Python 3.8
+      uses: actions/setup-python@v2
+      with:
+        python-version: 3.8
+
+    - name: Install pylint 2.4.4
+      run: |
+        python -m pip install --upgrade pip
+        pip install pylint==2.4.4
+
+    - name: Checkout code
+      uses: actions/checkout@v2
+      with:
+        ref: ${{ github.event.pull_request.head.sha }}
+        fetch-depth: 0
+
+    - name: Fetch master for diff
+      run: git fetch origin master:master
+
+    - name: Run pylint script
+      run: bash ./.github/scripts/pylint.sh --incremental
Tensorflow/models/.github/workflows/stale.yaml
ADDED
@@ -0,0 +1,67 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# This workflow alerts and then closes the stale issues/PRs after specific time
+# You can adjust the behavior by modifying this file.
+# For more information, see:
+# https://github.com/actions/stale
+
+name: 'Close stale issues and PRs'
+"on":
+  schedule:
+    - cron: "30 1 * * *"
+permissions:
+  contents: read
+  issues: write
+  pull-requests: write
+
+jobs:
+  stale:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: 'actions/stale@v7'
+        with:
+          #Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale
+          exempt-issue-labels: 'override-stale'
+          #Comma separated list of labels that can be assigned to PRs to exclude them from being marked as stale
+          exempt-pr-labels: "override-stale"
+          #Limit the No. of API calls in one run default value is 30.
+          operations-per-run: 1000
+          #Prevent to remove stale label when PRs or issues are updated.
+          remove-stale-when-updated: false
+          # comment on issue if not active for more then 7 days.
+          stale-issue-message: 'This issue has been marked stale because it has no recent activity since 7 days. It will be closed if no further activity occurs. Thank you.'
+          # comment on PR if not active for more then 14 days.
+          stale-pr-message: 'This PR has been marked stale because it has no recent activity since 14 days. It will be closed if no further activity occurs. Thank you.'
+          # comment on issue if stale for more then 7 days.
+          close-issue-message: This issue was closed due to lack of activity after being marked stale for past 7 days.
+          # comment on PR if stale for more then 14 days.
+          close-pr-message: This PR was closed due to lack of activity after being marked stale for past 14 days.
+          # Number of days of inactivity before an Issue Request becomes stale
+          days-before-issue-stale: 7
+          # Number of days of inactivity before a stale Issue is closed
+          days-before-issue-close: 7
+          # reason for closed the issue default value is not_planned
+          close-issue-reason: completed
+          # Number of days of inactivity before a stale PR is closed
+          days-before-pr-close: 14
+          # Number of days of inactivity before an PR Request becomes stale
+          days-before-pr-stale: 14
+          # Check for label to stale or close the issue/PR
+          any-of-labels: 'stat:awaiting response'
+          # override stale to stalled for PR
+          stale-pr-label: 'stale'
+          # override stale to stalled for Issue
+          stale-issue-label: "stale"
Tensorflow/models/.gitignore
ADDED
@@ -0,0 +1,98 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*,cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# IPython Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# mypy
+.mypy_cache
+
+# celery beat schedule file
+celerybeat-schedule
+
+# dotenv
+.env
+
+# virtualenv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+
+# Rope project settings
+.ropeproject
+
+# PyCharm
+.idea/
+
+# For mac
+.DS_Store
Tensorflow/models/AUTHORS
ADDED
@@ -0,0 +1,10 @@
+# This is the official list of authors for copyright purposes.
+# This file is distinct from the CONTRIBUTORS files.
+# See the latter for an explanation.
+
+# Names should be added to this file as:
+# Name or Organization <email address>
+# The email address is not required for organizations.
+
+Google Inc.
+David Dao <daviddao@broad.mit.edu>
Tensorflow/models/CODEOWNERS
ADDED
@@ -0,0 +1,29 @@
+* @tensorflow/tf-model-garden-team
+/official/ @rachellj218 @saberkun
+/official/nlp/ @saberkun @lehougoogle @rachellj218
+/official/recommendation/ranking/ @gagika
+/official/vision/ @yeqingli @arashwan @saberkun @rachellj218
+/official/vision/projects/assemblenet/ @mryoo @yeqingli
+/official/vision/projects/deepmac_maskrcnn/ @vighneshbirodkar
+/official/vision/projects/movinet/ @hyperparticle @yuanliangzhe @yeqingli
+/official/vision/projects/simclr/ @luotigerlsx @chentingpc @saxenasaurabh
+/official/vision/projects/video_ssl/ @richardaecn @yeqingli
+/research/adversarial_text/ @rsepassi @a-dai
+/research/attention_ocr/ @xavigibert
+/research/audioset/ @plakal @dpwe
+/research/autoaugment/ @barretzoph
+/research/cognitive_planning/ @s-gupta
+/research/cvt_text/ @clarkkev @lmthang
+/research/deep_speech/ @yhliang2018
+/research/deeplab/ @aquariusjay @yknzhu
+/research/delf/ @andrefaraujo
+/research/efficient-hrl/ @ofirnachum
+/research/lfads/ @jazcollins @sussillo
+/research/lstm_object_detection/ @yinxiaoli @yongzhe2160
+/research/marco/ @vincentvanhoucke
+/research/object_detection/ @jch1 @tombstone @pkulzc
+/research/pcl_rl/ @ofirnachum
+/research/rebar/ @gjtucker
+/research/seq_flow_lite/ @thunderfyc @karunreddy30
+/research/slim/ @sguada @marksandler2
+/research/vid2depth/ @rezama
Tensorflow/models/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,79 @@
+# TensorFlow-models Code of Conduct
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and our
+community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, gender identity and expression, level of
+experience, nationality, personal appearance, race, religion, or sexual identity
+and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment include:
+
+* Using welcoming and inclusive language.
+* Being respectful of differing viewpoints and experiences.
+* Gracefully accepting constructive criticism.
+* Focusing on what is best for the community.
+* Showing empathy towards other community members.
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances.
+* Trolling, insulting/derogatory comments, and personal or political attacks.
+* Public or private harassment.
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission.
+* Conduct which could reasonably be considered inappropriate for the forum in
+  which it occurs.
+
+All TensorFlow-models forums and spaces are meant for professional interactions, and any behavior which could reasonably be considered inappropriate in a professional setting is unacceptable.
+
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
+
+
+## Scope
+
+This Code of Conduct applies to all content on tensorflow.org, TensorFlow-models GitHub organization, or any other official TensorFlow-models web presence allowing for community interactions, as well as at all official TensorFlow-models events, whether offline or online.
+
+The Code of Conduct also applies within project spaces and in public spaces whenever an individual is representing TensorFlow-models or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed or de facto representative at an online or offline event.
+
+
+## Conflict Resolution
+
+Conflicts in an open source project can take many forms, from someone having a bad day and using harsh and hurtful language in the issue queue, to more serious instances such as sexist/racist statements or threats of violence, and everything in between.
+
+If the behavior is threatening or harassing, or for other reasons requires immediate escalation, please see below.
+
+However, for the vast majority of issues, we aim to empower individuals to first resolve conflicts themselves, asking for help when needed, and only after that fails to escalate further. This approach gives people more control over the outcome of their dispute.
+
+If you are experiencing or witnessing conflict, we ask you to use the following escalation strategy to address the conflict:
+
+1. Address the perceived conflict directly with those involved, preferably in a
+   real-time medium.
+2. If this fails, get a third party (e.g. a mutual friend, and/or someone with
+   background on the issue, but not involved in the conflict) to intercede.
+3. If you are still unable to resolve the conflict, and you believe it rises to
+   harassment or another code of conduct violation, report it.
+
+## Reporting Violations
+
+Violations of the Code of Conduct can be reported to TensorFlow’s Project Stewards, Thea Lamkin (thealamkin@google.com) and Joana Carrasqueira (joanafilipa@google.com). The Project Steward will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report.
+
+Violations of the Code of Conduct can occur in any setting, even those unrelated to the project. We will only consider complaints about conduct that has occurred within one year of the report.
+
+
+## Enforcement
+
+If the Project Stewards receive a report alleging a violation of the Code of Conduct, the Project Stewards will notify the accused of the report, and provide them an opportunity to discuss the report before a sanction is issued. The Project Stewards will do their utmost to keep the reporter anonymous. If the act is ongoing (such as someone engaging in harassment), or involves a threat to anyone's safety (e.g. threats of violence), the Project Stewards may issue sanctions without notice.
+
+
+## Attribution
+
+This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at https://contributor-covenant.org/version/1/4, and includes some aspects of the Geek Feminism Code of Conduct and the Drupal Code of Conduct.
Tensorflow/models/CONTRIBUTING.md
ADDED
@@ -0,0 +1,10 @@
+# How to contribute
+
+![Contributors](https://img.shields.io/github/contributors/tensorflow/models)
+
+We encourage you to contribute to the TensorFlow Model Garden.
+
+Please read our [guidelines](../../wiki/How-to-contribute) for details.
+
+**NOTE**: Only [code owners](./CODEOWNERS) are allowed to merge a pull request.
+Please contact the code owners of each model to merge your pull request.
Tensorflow/models/ISSUES.md
ADDED
@@ -0,0 +1,24 @@
+# If you open a GitHub issue, here is our policy.
+
+* It must be a **bug**, a **feature request**, or a significant problem
+  with **documentation**.
+  * Please send a pull request instead for small documentation fixes.
+* The required form must be filled out.
+* The issue should be related to the repository it is created in.
+
+General help and support should be sought on [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow-model-garden) or other non-GitHub channels.
+
+[![](https://img.shields.io/stackexchange/stackoverflow/t/tensorflow-model-garden)](https://stackoverflow.com/questions/tagged/tensorflow-model-garden)
+
+TensorFlow developers respond to issues.
+We want to focus on work that benefits the whole community such as fixing bugs
+and adding new features.
+It helps us to address bugs and feature requests in a timely manner.
+
+---
+
+Please understand that research models in the [research directory](https://github.com/tensorflow/models/tree/master/research)
+included in this repository are experimental and research-style code.
+They are not officially supported by the TensorFlow team.
+
+
Tensorflow/models/LICENSE
ADDED
@@ -0,0 +1,212 @@
+Copyright 2022 Google LLC. All rights reserved.
+
+All files in the following folders:
+/community
+/official
+/orbit
+/research
+/tensorflow_models
+
+Are licensed as follows:
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
|
151 |
+
except as required for reasonable and customary use in describing the
|
152 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
153 |
+
|
154 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
155 |
+
agreed to in writing, Licensor provides the Work (and each
|
156 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
157 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
158 |
+
implied, including, without limitation, any warranties or conditions
|
159 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
160 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
161 |
+
appropriateness of using or redistributing the Work and assume any
|
162 |
+
risks associated with Your exercise of permissions under this License.
|
163 |
+
|
164 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
165 |
+
whether in tort (including negligence), contract, or otherwise,
|
166 |
+
unless required by applicable law (such as deliberate and grossly
|
167 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
168 |
+
liable to You for damages, including any direct, indirect, special,
|
169 |
+
incidental, or consequential damages of any character arising as a
|
170 |
+
result of this License or out of the use or inability to use the
|
171 |
+
Work (including but not limited to damages for loss of goodwill,
|
172 |
+
work stoppage, computer failure or malfunction, or any and all
|
173 |
+
other commercial damages or losses), even if such Contributor
|
174 |
+
has been advised of the possibility of such damages.
|
175 |
+
|
176 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
177 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
178 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
179 |
+
or other liability obligations and/or rights consistent with this
|
180 |
+
License. However, in accepting such obligations, You may act only
|
181 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
182 |
+
of any other Contributor, and only if You agree to indemnify,
|
183 |
+
defend, and hold each Contributor harmless for any liability
|
184 |
+
incurred by, or claims asserted against, such Contributor by reason
|
185 |
+
of your accepting any such warranty or additional liability.
|
186 |
+
|
187 |
+
END OF TERMS AND CONDITIONS
|
188 |
+
|
189 |
+
APPENDIX: How to apply the Apache License to your work.
|
190 |
+
|
191 |
+
To apply the Apache License to your work, attach the following
|
192 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
193 |
+
replaced with your own identifying information. (Don't include
|
194 |
+
the brackets!) The text should be enclosed in the appropriate
|
195 |
+
comment syntax for the file format. We also recommend that a
|
196 |
+
file or class name and description of purpose be included on the
|
197 |
+
same "printed page" as the copyright notice for easier
|
198 |
+
identification within third-party archives.
|
199 |
+
|
200 |
+
Copyright 2016, The Authors.
|
201 |
+
|
202 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
203 |
+
you may not use this file except in compliance with the License.
|
204 |
+
You may obtain a copy of the License at
|
205 |
+
|
206 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
207 |
+
|
208 |
+
Unless required by applicable law or agreed to in writing, software
|
209 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
210 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
211 |
+
See the License for the specific language governing permissions and
|
212 |
+
limitations under the License.
|
Tensorflow/models/README.md
ADDED
@@ -0,0 +1,130 @@
<div align="center">
  <img src="https://storage.googleapis.com/tf_model_garden/tf_model_garden_logo.png">
</div>

[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow)
[![tf-models-official PyPI](https://badge.fury.io/py/tf-models-official.svg)](https://badge.fury.io/py/tf-models-official)


# Welcome to the Model Garden for TensorFlow

The TensorFlow Model Garden is a repository with a number of different
implementations of state-of-the-art (SOTA) models and modeling solutions for
TensorFlow users. We aim to demonstrate the best practices for modeling so that
TensorFlow users can take full advantage of TensorFlow for their research and
product development.

To improve the transparency and reproducibility of our models, training logs on
[TensorBoard.dev](https://tensorboard.dev) are also provided for models where
possible, though not all models are suitable.

| Directory | Description |
|-----------|-------------|
| [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read<br />• For more details on the capabilities, check the guide on the [Model-garden](https://www.tensorflow.org/tfmodels) |
| [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers |
| [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |
| [orbit](orbit) | • A flexible and lightweight library that users can easily use or fork when writing customized training loop code in TensorFlow 2.x. It seamlessly integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU). |

## Installation

To install the current release of tensorflow-models, please follow any one of the methods described below.

#### Method 1: Install the TensorFlow Model Garden pip package

<details>

**tf-models-official** is the stable Model Garden package. Please check out the [releases](https://github.com/tensorflow/models/releases) to see which modules are available.

pip3 will install all models and dependencies automatically.

```shell
pip3 install tf-models-official
```

To learn how to use the pip package, please check out our examples:
- [basic library import](https://github.com/tensorflow/models/blob/master/tensorflow_models/tensorflow_models_pypi.ipynb)
- [NLP model building](https://github.com/tensorflow/models/blob/master/docs/nlp/index.ipynb)

Note that **tf-models-official** may not include the latest changes in the master branch of this
GitHub repo. To include the latest changes, you may instead install **tf-models-nightly**,
which is the nightly Model Garden package created automatically every day.

```shell
pip3 install tf-models-nightly
```
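
After either install, a quick way to confirm the package is importable is to
load the library and reference one of its modeling classes (a minimal sanity
check; `BertEncoder` is just one example of the APIs the package exposes):

```python
import tensorflow_models as tfm

# If the install succeeded, the NLP modeling API is available:
print(tfm.nlp.networks.BertEncoder)
```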

</details>


#### Method 2: Clone the source

<details>

1. Clone the GitHub repository:

```shell
git clone https://github.com/tensorflow/models.git
```

2. Add the top-level ***/models*** folder to the Python path.

```shell
export PYTHONPATH=$PYTHONPATH:/path/to/models
```

If you are working in a Windows environment, you may need to use the following
command with PowerShell (note that Windows separates `PYTHONPATH` entries with
a semicolon rather than a colon):
```shell
$env:PYTHONPATH += ";\path\to\models"
```

If you are using a Colab notebook, please set the Python path with `os.environ`.

```python
import os
os.environ['PYTHONPATH'] += ":/path/to/models"
```
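
Equivalently, for just the current Python process, you can append the clone to
`sys.path` directly (a minimal sketch; substitute your own checkout location
for the placeholder path):

```python
import sys

# Makes `official/` and `orbit/` importable for this process only.
sys.path.append("/path/to/models")
```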

3. Install other dependencies

```shell
pip3 install --user -r models/official/requirements.txt
```

Finally, if you are using the NLP packages, please also install
**tensorflow-text-nightly**:

```shell
pip3 install tensorflow-text-nightly
```

</details>


## Announcements

Please check [this page](https://github.com/tensorflow/models/wiki/Announcements) for recent announcements.

## Contributions

[![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)

If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).

## License

[Apache License 2.0](LICENSE)

## Citing TensorFlow Model Garden

If you use TensorFlow Model Garden in your research, please cite this repository.

```
@misc{tensorflowmodelgarden2020,
  author = {Hongkun Yu and Chen Chen and Xianzhi Du and Yeqing Li and
            Abdullah Rashwan and Le Hou and Pengchong Jin and Fan Yang and
            Frederick Liu and Jaeyoun Kim and Jing Li},
  title = {{TensorFlow Model Garden}},
  howpublished = {\url{https://github.com/tensorflow/models}},
  year = {2020}
}
```
Tensorflow/models/SECURITY.md
ADDED
@@ -0,0 +1,251 @@
# Using TensorFlow Securely

This document discusses how to safely deal with untrusted programs (models or
model parameters) and input data. Below, we also provide guidelines on how to
report vulnerabilities in TensorFlow.

## TensorFlow models are programs

TensorFlow's runtime system interprets and executes programs. What machine
learning practitioners term
[**models**](https://developers.google.com/machine-learning/glossary/#model) are
expressed as programs that TensorFlow executes. TensorFlow programs are encoded
as computation
[**graphs**](https://developers.google.com/machine-learning/glossary/#graph).
The model's parameters are often stored separately in **checkpoints**.

At runtime, TensorFlow executes the computation graph using the parameters
provided. Note that the behavior of the computation graph may change
depending on the parameters provided. TensorFlow itself is not a sandbox. When
executing the computation graph, TensorFlow may read and write files, send and
receive data over the network, and even spawn additional processes. All these
tasks are performed with the permissions of the TensorFlow process. Allowing
for this flexibility makes for a powerful machine learning platform,
but it has implications for security.

The computation graph may also accept **inputs**. Those inputs are the
data you supply to TensorFlow to train a model, or to use a model to run
inference on the data.

**TensorFlow models are programs, and need to be treated as such from a security
perspective.**

## Running untrusted models

As a general rule: **Always** execute untrusted models inside a sandbox (e.g.,
[nsjail](https://github.com/google/nsjail)).

There are several ways in which a model could become untrusted. Obviously, if an
untrusted party supplies TensorFlow kernels, arbitrary code may be executed.
The same is true if the untrusted party provides Python code, such as the
Python code that generates TensorFlow graphs.

Even if the untrusted party only supplies the serialized computation
graph (in the form of a `GraphDef`, `SavedModel`, or equivalent on-disk format),
the set of computation primitives available to TensorFlow is powerful enough
that you should assume that the TensorFlow process effectively executes
arbitrary code. One common solution is to allow only a few safe Ops. While this
is possible in theory, we still recommend you sandbox the execution.

Whether a user-provided checkpoint is safe depends on the computation graph.
It is easy to create computation graphs in which malicious
checkpoints can trigger unsafe behavior. For example, consider a graph that
contains a `tf.cond` depending on the value of a `tf.Variable`. One branch of
the `tf.cond` is harmless, but the other is unsafe. Since the `tf.Variable` is
stored in the checkpoint, whoever provides the checkpoint now has the ability to
trigger unsafe behavior, even though the graph is not under their control.
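
The following is a hypothetical sketch of that pattern, added here only for
illustration. The graph is fixed and innocuous-looking, but the branch that
actually executes is chosen by a value restored from the checkpoint:

```python
import tensorflow as tf

flag = tf.Variable(0)  # value is restored from the (untrusted) checkpoint

def harmless():
    return tf.constant(0.0)

def unsafe():
    # Stand-in for any op with side effects outside the process.
    tf.io.write_file("/tmp/exfil", tf.constant("attacker-chosen data"))
    return tf.constant(1.0)

# Which branch runs is controlled entirely by the checkpoint's contents.
result = tf.cond(flag > 0, unsafe, harmless)
```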

In other words, graphs can contain vulnerabilities of their own. To allow users
to provide checkpoints to a model you run on their behalf (e.g., in order to
compare model quality for a fixed model architecture), you must carefully audit
your model, and we recommend you run the TensorFlow process in a sandbox.

## Accepting untrusted inputs

It is possible to write models that are secure in the sense that they can safely
process untrusted inputs, assuming there are no bugs. There are two main reasons
not to rely on this: first, it is easy to write models which must not be exposed
to untrusted inputs, and second, there are bugs in any software system of
sufficient complexity. Letting users control inputs could allow them to trigger
bugs either in TensorFlow or in dependent libraries.

In general, it is good practice to isolate parts of any system which are exposed
to untrusted (e.g., user-provided) inputs in a sandbox.

A useful analogy for how any TensorFlow graph is executed is any interpreted
programming language, such as Python. While it is possible to write secure
Python code which can be exposed to user-supplied inputs (by, e.g., carefully
quoting and sanitizing input strings, size-checking input blobs, etc.), it is
very easy to write Python programs which are insecure. Even secure Python code
could be rendered insecure by a bug in the Python interpreter, or by a bug in a
Python library it uses (e.g.,
[this one](https://www.cvedetails.com/cve/CVE-2017-12852/)).

## Running a TensorFlow server

TensorFlow is a platform for distributed computing, and as such there is a
TensorFlow server (`tf.train.Server`). **The TensorFlow server is meant for
internal communication only. It is not built for use in an untrusted network.**

For performance reasons, the default TensorFlow server does not include any
authorization protocol and sends messages unencrypted. It accepts connections
from anywhere, and executes the graphs it is sent without performing any checks.
Therefore, if you run a `tf.train.Server` in your network, anybody with
access to the network can execute what you should consider arbitrary code with
the privileges of the process running the `tf.train.Server`.

When running distributed TensorFlow, you must isolate the network in which the
cluster lives. Cloud providers provide instructions for setting up isolated
networks, which are sometimes branded as "virtual private cloud." Refer to the
instructions for
[GCP](https://cloud.google.com/compute/docs/networks-and-firewalls) and
[AWS](https://aws.amazon.com/vpc/) for details.

Note that `tf.train.Server` is different from the server created by
`tensorflow/serving` (the default binary for which is called `ModelServer`).
By default, `ModelServer` also has no built-in mechanism for authentication.
Connecting it to an untrusted network allows anyone on this network to run the
graphs known to the `ModelServer`. This means that an attacker may run
graphs using untrusted inputs as described above, but they would not be able to
execute arbitrary graphs. It is possible to safely expose a `ModelServer`
directly to an untrusted network, **but only if the graphs it is configured to
use have been carefully audited to be safe**.

Similar to best practices for other servers, we recommend running any
`ModelServer` with appropriate privileges (i.e., using a separate user with
reduced permissions). In the spirit of defense in depth, we recommend
authenticating requests to any TensorFlow server connected to an untrusted
network, as well as sandboxing the server to minimize the adverse effects of
any breach.

## Vulnerabilities in TensorFlow

TensorFlow is a large and complex system. It also depends on a large set of
third-party libraries (e.g., `numpy`, `libjpeg-turbo`, PNG parsers, `protobuf`).
It is possible that TensorFlow or its dependent libraries contain
vulnerabilities that would allow triggering unexpected or dangerous behavior
with specially crafted inputs.

### What is a vulnerability?

Given TensorFlow's flexibility, it is possible to specify computation graphs
which exhibit unexpected or unwanted behavior. The fact that TensorFlow models
can perform arbitrary computations means that they may read and write files,
communicate via the network, produce deadlocks and infinite loops, or run out
of memory. It is only when these behaviors are outside the specifications of the
operations involved that such behavior is a vulnerability.

A `FileWriter` writing a file is not unexpected behavior and therefore is not a
vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution
**is** a vulnerability.

This is more subtle from a system perspective. For example, it is easy to cause
a TensorFlow process to try to allocate more memory than is available by
specifying a computation graph containing an ill-considered `tf.tile` operation.
TensorFlow should exit cleanly in this case (it would raise an exception in
Python, or return an error `Status` in C++). However, if the surrounding system
is not expecting the possibility, such behavior could be used in a
denial-of-service attack (or worse). Because TensorFlow behaves correctly, this
is not a vulnerability in TensorFlow (although it would be a vulnerability of
this hypothetical system).
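
As a concrete illustration of that case (a hedged sketch; the exact error class
can vary by platform and TensorFlow version), a single op can request far more
memory than the host has, and the expected behavior is a Python exception
rather than a crash:

```python
import tensorflow as tf

try:
    # Asks for an absurdly large tensor; the allocation should fail cleanly.
    tf.tile(tf.zeros([1024, 1024]), multiples=[100000, 100000])
except tf.errors.OpError as e:
    print("Allocation rejected cleanly:", type(e).__name__)
```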

As a general rule, it is incorrect behavior for TensorFlow to access memory it
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
such behaviors constitute a vulnerability.

One of the most critical parts of any system is input handling. If malicious
input can trigger side effects or incorrect behavior, this is a bug, and likely
a vulnerability.

### Reporting vulnerabilities

Please email reports about any security-related issues you find to
`security@tensorflow.org`. This mail is delivered to a small security team. Your
email will be acknowledged within one business day, and you'll receive a more
detailed response to your email within 7 days indicating the next steps in
handling your report. For critical problems, you may encrypt your report (see
below).

Please use a descriptive subject line for your report email. After the initial
reply to your report, the security team will endeavor to keep you informed of
the progress being made towards a fix and announcement.

In addition, please include the following information along with your report:

* Your name and affiliation (if any).
* A description of the technical details of the vulnerabilities. It is very
  important to let us know how we can reproduce your findings.
* An explanation of who can exploit this vulnerability, and what they gain when
  doing so -- write an attack scenario. This will help us evaluate your report
  quickly, especially if the issue is complex.
* Whether this vulnerability is public or known to third parties. If it is,
  please provide details.

If you believe that an existing (public) issue is security-related, please send
an email to `security@tensorflow.org`. The email should include the issue ID and
a short description of why it should be handled according to this security
policy.

Once an issue is reported, TensorFlow uses the following disclosure process:

* When a report is received, we confirm the issue and determine its severity.
* If we know of specific third-party services or software based on TensorFlow
  that require mitigation before publication, those projects will be notified.
* An advisory is prepared (but not published) which details the problem and
  steps for mitigation.
* The vulnerability is fixed and potential workarounds are identified.
* Wherever possible, the fix is also prepared for the branches corresponding to
  all releases of TensorFlow at most one year old. We will attempt to commit
  these fixes as soon as possible, and as close together as possible.
* Patch releases are published for all fixed released versions, a
  notification is sent to discuss@tensorflow.org, and the advisory is published.

Note that we mostly do patch releases for security reasons, and each version of
TensorFlow is supported for only one year after its release.

Past security advisories are listed below. We credit reporters for identifying
security issues, although we keep your name confidential if you request it.

#### Encryption key for `security@tensorflow.org`

If your disclosure is extremely sensitive, you may choose to encrypt your
report using the key below. Please only use this for critical security
reports.

```
-----BEGIN PGP PUBLIC KEY BLOCK-----

mQENBFpqdzwBCADTeAHLNEe9Vm77AxhmGP+CdjlY84O6DouOCDSq00zFYdIU/7aI
LjYwhEmDEvLnRCYeFGdIHVtW9YrVktqYE9HXVQC7nULU6U6cvkQbwHCdrjaDaylP
aJUXkNrrxibhx9YYdy465CfusAaZ0aM+T9DpcZg98SmsSml/HAiiY4mbg/yNVdPs
SEp/Ui4zdIBNNs6at2gGZrd4qWhdM0MqGJlehqdeUKRICE/mdedXwsWLM8AfEA0e
OeTVhZ+EtYCypiF4fVl/NsqJ/zhBJpCx/1FBI1Uf/lu2TE4eOS1FgmIqb2j4T+jY
e+4C8kGB405PAC0n50YpOrOs6k7fiQDjYmbNABEBAAG0LVRlbnNvckZsb3cgU2Vj
dXJpdHkgPHNlY3VyaXR5QHRlbnNvcmZsb3cub3JnPokBTgQTAQgAOBYhBEkvXzHm
gOJBnwP4Wxnef3wVoM2yBQJaanc8AhsDBQsJCAcCBhUKCQgLAgQWAgMBAh4BAheA
AAoJEBnef3wVoM2yNlkIAICqetv33MD9W6mPAXH3eon+KJoeHQHYOuwWfYkUF6CC
o+X2dlPqBSqMG3bFuTrrcwjr9w1V8HkNuzzOJvCm1CJVKaxMzPuXhBq5+DeT67+a
T/wK1L2R1bF0gs7Pp40W3np8iAFEh8sgqtxXvLGJLGDZ1Lnfdprg3HciqaVAiTum
HBFwszszZZ1wAnKJs5KVteFN7GSSng3qBcj0E0ql2nPGEqCVh+6RG/TU5C8gEsEf
3DX768M4okmFDKTzLNBm+l08kkBFt+P43rNK8dyC4PXk7yJa93SmS/dlK6DZ16Yw
2FS1StiZSVqygTW59rM5XNwdhKVXy2mf/RtNSr84gSi5AQ0EWmp3PAEIALInfBLR
N6fAUGPFj+K3za3PeD0fWDijlC9f4Ety/icwWPkOBdYVBn0atzI21thPRbfuUxfe
zr76xNNrtRRlbDSAChA1J5T86EflowcQor8dNC6fS+oHFCGeUjfEAm16P6mGTo0p
osdG2XnnTHOOEFbEUeWOwR/zT0QRaGGknoy2pc4doWcJptqJIdTl1K8xyBieik/b
nSoClqQdZJa4XA3H9G+F4NmoZGEguC5GGb2P9NHYAJ3MLHBHywZip8g9oojIwda+
OCLL4UPEZ89cl0EyhXM0nIAmGn3Chdjfu3ebF0SeuToGN8E1goUs3qSE77ZdzIsR
BzZSDFrgmZH+uP0AEQEAAYkBNgQYAQgAIBYhBEkvXzHmgOJBnwP4Wxnef3wVoM2y
BQJaanc8AhsMAAoJEBnef3wVoM2yX4wIALcYZbQhSEzCsTl56UHofze6C3QuFQIH
J4MIKrkTfwiHlCujv7GASGU2Vtis5YEyOoMidUVLlwnebE388MmaJYRm0fhYq6lP
A3vnOCcczy1tbo846bRdv012zdUA+wY+mOITdOoUjAhYulUR0kiA2UdLSfYzbWwy
7Obq96Jb/cPRxk8jKUu2rqC/KDrkFDtAtjdIHh6nbbQhFuaRuWntISZgpIJxd8Bt
Gwi0imUVd9m9wZGuTbDGi6YTNk0GPpX5OMF5hjtM/objzTihSw9UN+65Y/oSQM81
v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
=CDME
-----END PGP PUBLIC KEY BLOCK-----
```

### Known Vulnerabilities

At this time there are no known vulnerabilities in TensorFlow Models. For a list of known vulnerabilities and security advisories for TensorFlow,
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).
Tensorflow/models/This PC - Shortcut.lnk
ADDED
Binary file (420 Bytes).
Tensorflow/models/community/README.md
ADDED
@@ -0,0 +1,60 @@
![Logo](https://storage.googleapis.com/tf_model_garden/tf_model_garden_logo.png)

# TensorFlow Community Models

This repository provides a curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2.

**Note**: Contributing companies or individuals are responsible for maintaining their repositories.

## Computer Vision

### Image Recognition

| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [DenseNet 169](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/densenet169) | [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993) | • FP32 Inference | [Intel](https://github.com/IntelAI) |
| [Inception V3](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/inceptionv3) | [Rethinking the Inception Architecture<br/>for Computer Vision](https://arxiv.org/pdf/1512.00567.pdf) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [Inception V4](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/inceptionv4) | [Inception-v4, Inception-ResNet and the Impact<br/>of Residual Connections on Learning](https://arxiv.org/pdf/1602.07261) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [MobileNet V1](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/mobilenet_v1) | [MobileNets: Efficient Convolutional Neural Networks<br/>for Mobile Vision Applications](https://arxiv.org/pdf/1704.04861) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [ResNet 101](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet101) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
| EfficientNet [v1](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Classification/ConvNets/efficientnet_v1) [v2](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Classification/ConvNets/efficientnet_v2) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/pdf/1905.11946.pdf) | • Automatic mixed precision<br/>• Horovod Multi-GPU training (NCCL)<br/>• Multi-node training on a Pyxis/Enroot Slurm cluster<br/>• XLA | [NVIDIA](https://github.com/NVIDIA) |

### Object Detection

| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [R-FCN](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/rfcn) | [R-FCN: Object Detection<br/>via Region-based Fully Convolutional Networks](https://arxiv.org/pdf/1605.06409) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-MobileNet](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-mobilenet) | [MobileNets: Efficient Convolutional Neural Networks<br/>for Mobile Vision Applications](https://arxiv.org/pdf/1704.04861) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-ResNet34](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-resnet34) | [SSD: Single Shot MultiBox Detector](https://arxiv.org/pdf/1512.02325) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |

### Segmentation

| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [Mask R-CNN](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/MaskRCNN) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) |
| [U-Net Medical Image Segmentation](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/UNet_Medical) | [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) |

## Natural Language Processing

| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [BERT](https://github.com/IntelAI/models/tree/master/benchmarks/language_modeling/tensorflow/bert_large) | [BERT: Pre-training of Deep Bidirectional Transformers<br/>for Language Understanding](https://arxiv.org/pdf/1810.04805) | • FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
| [BERT](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/LanguageModeling/BERT) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805) | • Horovod Multi-GPU<br/>• Multi-node with Horovod and Pyxis/Enroot Slurm cluster<br/>• XLA<br/>• Automatic mixed precision<br/>• LAMB | [NVIDIA](https://github.com/NVIDIA) |
| [ELECTRA](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/LanguageModeling/ELECTRA) | [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://openreview.net/forum?id=r1xMH1BtvB) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• Multi-node training on a Pyxis/Enroot Slurm cluster | [NVIDIA](https://github.com/NVIDIA) |
| [GNMT](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/mlperf_gnmt) | [Google’s Neural Machine Translation System:<br/>Bridging the Gap between Human and Machine Translation](https://arxiv.org/pdf/1609.08144) | • FP32 Inference | [Intel](https://github.com/IntelAI) |
| [Transformer-LT (Official)](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/transformer_lt_official) | [Attention Is All You Need](https://arxiv.org/pdf/1706.03762) | • FP32 Inference | [Intel](https://github.com/IntelAI) |
| [Transformer-LT (MLPerf)](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/transformer_mlperf) | [Attention Is All You Need](https://arxiv.org/pdf/1706.03762) | • FP32 Training | [Intel](https://github.com/IntelAI) |

## Recommendation Systems

| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [Wide & Deep](https://github.com/IntelAI/models/tree/master/benchmarks/recommendation/tensorflow/wide_deep_large_ds) | [Wide & Deep Learning for Recommender Systems](https://arxiv.org/pdf/1606.07792) | • FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
| [Wide & Deep](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Recommendation/WideAndDeep) | [Wide & Deep Learning for Recommender Systems](https://arxiv.org/pdf/1606.07792) | • Automatic mixed precision<br/>• Multi-GPU training support with Horovod<br/>• XLA | [NVIDIA](https://github.com/NVIDIA) |
| [DLRM](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Recommendation/DLRM) | [Deep Learning Recommendation Model for Personalization and Recommendation Systems](https://arxiv.org/pdf/1906.00091.pdf) | • Automatic Mixed Precision<br/>• Hybrid-parallel multi-GPU training using Horovod all2all<br/>• Multi-node training for Pyxis/Enroot Slurm clusters<br/>• XLA<br/>• Criteo dataset preprocessing with Spark on GPU | [NVIDIA](https://github.com/NVIDIA) |

## Contributions

If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
Tensorflow/models/docs/README.md
ADDED
@@ -0,0 +1,17 @@
# Public docs for TensorFlow Models

This directory contains the top-level public documentation for
[TensorFlow Models](https://github.com/tensorflow/models).

This directory is mirrored to https://tensorflow.org/tfmodels, and is mainly
concerned with documenting the tools provided in the `tensorflow_models` pip
package (including `orbit`).

API reference pages are
[available on the site](https://www.tensorflow.org/api_docs/more).

The
[Official Models](https://github.com/tensorflow/models/blob/master/official/projects)
and [Research Models](https://github.com/tensorflow/models/blob/master/research)
directories are not described in detail here; refer to the individual project
directories for more information.
Tensorflow/models/docs/index.md
ADDED
@@ -0,0 +1,140 @@
# Model Garden overview

The TensorFlow Model Garden provides implementations of many state-of-the-art
machine learning (ML) models for vision and natural language processing (NLP),
as well as workflow tools to let you quickly configure and run those models on
standard datasets. Whether you are looking to benchmark performance for a
well-known model, verify the results of recently released research, or extend
existing models, the Model Garden can help you drive your ML research and
applications forward.

The Model Garden includes the following resources for machine learning
developers:

- [**Official models**](#official) for vision and NLP, maintained by Google
  engineers
- [**Research models**](#research) published as part of ML research papers
- [**Training experiment framework**](#training_framework) for fast,
  declarative training configuration of official models
- [**Specialized ML operations**](#ops) for vision and natural language
  processing (NLP)
- [**Model training loop**](#orbit) management with Orbit

These resources are built to be used with the TensorFlow Core framework and
integrate with your existing TensorFlow development projects. Model
Garden resources are also provided under an [open
source](https://github.com/tensorflow/models/blob/master/LICENSE) license, so
you can freely extend and distribute the models and tools.

Practical ML models are computationally intensive to train and run, and may
require accelerators such as Graphics Processing Units (GPUs) and Tensor
Processing Units (TPUs). Most of the models in the Model Garden were trained on
large datasets using TPUs. However, you can also train and run these models on
GPU and CPU processors.

## Model Garden models

The machine learning models in the Model Garden include full code so you can
test, train, or re-train them for research and experimentation. The Model Garden
includes two primary categories of models: *official models* and *research
models*.

### Official models {:#official}

The [Official Models](https://github.com/tensorflow/models/tree/master/official)
repository is a collection of state-of-the-art models, with a focus on
vision and natural language processing (NLP).
These models are implemented using current TensorFlow 2.x high-level
APIs. Model libraries in this repository are optimized for fast performance and
actively maintained by Google engineers. The official models include additional
metadata you can use to quickly configure experiments using the Model Garden
[training experiment framework](#training_framework).

### Research models {:#research}

The [Research Models](https://github.com/tensorflow/models/tree/master/research)
repository is a collection of models published as code resources for research
papers. These models are implemented using both TensorFlow 1.x and 2.x. Model
libraries in the research folder are supported by the code owners and the
research community.

## Training experiment framework {:#training_framework}

The Model Garden training experiment framework lets you quickly assemble and run
training experiments using its official models and standard datasets. The
training framework uses additional metadata included with the Model Garden's
official models to allow you to configure models quickly using a declarative
programming model. You can define a training experiment using Python commands in
the
[TensorFlow Model library](https://www.tensorflow.org/api_docs/python/tfm/core)
or configure training using a YAML configuration file, like this
[example](https://github.com/tensorflow/models/blob/master/official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml).

The training framework uses
[`tfm.core.base_trainer.ExperimentConfig`](https://www.tensorflow.org/api_docs/python/tfm/core/base_trainer/ExperimentConfig)
as the configuration object, which contains the following top-level
configuration objects (see the sketch after this list):

- [`runtime`](https://www.tensorflow.org/api_docs/python/tfm/core/base_task/RuntimeConfig):
  Defines the processing hardware, distribution strategy, and other
  performance optimizations
- [`task`](https://www.tensorflow.org/api_docs/python/tfm/core/config_definitions/TaskConfig):
  Defines the model, training data, losses, and initialization
- [`trainer`](https://www.tensorflow.org/api_docs/python/tfm/core/base_trainer/TrainerConfig):
  Defines the optimizer, training loops, evaluation loops, summaries, and
  checkpoints
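
As a quick, hedged sketch of how those pieces fit together (the experiment name
`resnet_imagenet` is one of the configurations registered by the official
vision models; attribute names follow the config classes linked above):

```python
import tensorflow_models as tfm

# Look up a registered experiment, then override its top-level configs.
exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')

exp_config.runtime.distribution_strategy = 'mirrored'  # hardware setup
exp_config.task.train_data.global_batch_size = 128     # model and data
exp_config.trainer.train_steps = 10000                 # loops and checkpoints
```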

For a complete example using the Model Garden training experiment framework, see
the [Image classification with Model Garden](vision/image_classification.ipynb)
tutorial. For information on the training experiment framework, check out the
[TensorFlow Models API documentation](https://tensorflow.org/api_docs/python/tfm/core).
If you are looking for a solution to manage training loops for your model
training experiments, check out [Orbit](#orbit).

## Specialized ML operations {:#ops}

The Model Garden contains many vision and NLP operations specifically designed
to execute state-of-the-art models that run efficiently on GPUs and TPUs. Review
the TensorFlow Models Vision library API docs for a list of specialized
[vision operations](https://www.tensorflow.org/api_docs/python/tfm/vision).
Review the TensorFlow Models NLP Library API docs for a list of
[NLP operations](https://www.tensorflow.org/api_docs/python/tfm/nlp). These
libraries also include additional utility functions used for vision and NLP data
processing, training, and model execution.

## Training loops with Orbit {:#orbit}

There are two default options for training TensorFlow models:

* Use the high-level Keras
  [Model.fit](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit)
  function. If your model and training procedure fit the assumptions of the
  Keras `Model.fit` method (incremental gradient descent on batches of data),
  this can be very convenient.
* Write a custom training loop
  [with Keras](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch),
  or [without](https://www.tensorflow.org/guide/core/logistic_regression_core).
  You can write a custom training loop with low-level TensorFlow methods such as
  `tf.GradientTape` or `tf.function`. However, this approach requires a lot of
  boilerplate code (see the sketch after this list), and doesn't do anything to
  simplify distributed training.
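
A minimal sketch of that boilerplate, using a toy model and random data so the
snippet runs standalone:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(0.1)
loss_fn = tf.keras.losses.MeanSquaredError()
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([64, 4]), tf.random.normal([64, 1]))).batch(8)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

for x, y in dataset:
    train_step(x, y)  # checkpoints, metrics, distribution: all manual
```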

Orbit tries to provide a third option in between these two extremes.

Orbit is a flexible, lightweight library designed to make it easier to
write custom training loops in TensorFlow 2.x, and it works well with the Model
Garden [training experiment framework](#training_framework). Orbit handles
common model training tasks such as saving checkpoints, running model
evaluations, and setting up summary writing. It seamlessly integrates with
`tf.distribute` and supports running on different device types, including CPU,
GPU, and TPU hardware. The Orbit tool is also [open
source](https://github.com/tensorflow/models/blob/master/orbit/LICENSE), so you
can extend it and adapt it to your model training needs.

The Orbit guide is available [here](orbit/index.ipynb).

Note: You can customize how the Keras API executes training, mainly by
overriding the `Model.train_step` method or by using `keras.callbacks` such as
`callbacks.ModelCheckpoint` or `callbacks.TensorBoard`. For more information
about modifying the behavior of `train_step`, check out the
[Customize what happens in Model.fit](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit)
page.
Tensorflow/models/docs/nlp/_guide_toc.yaml
ADDED
@@ -0,0 +1,9 @@
toc:
- heading: TensorFlow Models - NLP
  style: divider
- title: "Overview"
  path: /tfmodels/nlp
- title: "Customize a transformer encoder"
  path: /tfmodels/nlp/customize_encoder
- title: "Load LM checkpoints"
  path: /tfmodels/nlp/load_lm_ckpts
Tensorflow/models/docs/nlp/customize_encoder.ipynb
ADDED
@@ -0,0 +1,596 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "Bp8t2AI8i7uP"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"##### Copyright 2022 The TensorFlow Authors."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"cellView": "form",
|
17 |
+
"id": "rxPj2Lsni9O4"
|
18 |
+
},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
22 |
+
"# you may not use this file except in compliance with the License.\n",
|
23 |
+
"# You may obtain a copy of the License at\n",
|
24 |
+
"#\n",
|
25 |
+
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
26 |
+
"#\n",
|
27 |
+
"# Unless required by applicable law or agreed to in writing, software\n",
|
28 |
+
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
29 |
+
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
30 |
+
"# See the License for the specific language governing permissions and\n",
|
31 |
+
"# limitations under the License."
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "markdown",
|
36 |
+
"metadata": {
|
37 |
+
"id": "6xS-9i5DrRvO"
|
38 |
+
},
|
39 |
+
"source": [
|
40 |
+
"# Customizing a Transformer Encoder"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"metadata": {
|
46 |
+
"id": "Mwb9uw1cDXsa"
|
47 |
+
},
|
48 |
+
"source": [
|
49 |
+
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
|
50 |
+
" <td>\n",
|
51 |
+
" <a target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/customize_encoder\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
|
52 |
+
" </td>\n",
|
53 |
+
" <td>\n",
|
54 |
+
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
55 |
+
" </td>\n",
|
56 |
+
" <td>\n",
|
57 |
+
" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
|
58 |
+
" </td>\n",
|
59 |
+
" <td>\n",
|
60 |
+
" <a href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
|
61 |
+
" </td>\n",
|
62 |
+
"</table>"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "markdown",
|
67 |
+
"metadata": {
|
68 |
+
"id": "iLrcV4IyrcGX"
|
69 |
+
},
|
70 |
+
"source": [
|
71 |
+
"## Learning objectives\n",
|
72 |
+
"\n",
|
73 |
+
"The [TensorFlow Models NLP library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling) is a collection of tools for building and training modern high performance natural language models.\n",
|
74 |
+
"\n",
|
75 |
+
"The `tfm.nlp.networks.EncoderScaffold` is the core of this library, and lots of new network architectures are proposed to improve the encoder. In this Colab notebook, we will learn how to customize the encoder to employ new network architectures."
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "markdown",
|
80 |
+
"metadata": {
|
81 |
+
"id": "YYxdyoWgsl8t"
|
82 |
+
},
|
83 |
+
"source": [
|
84 |
+
"## Install and import"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"cell_type": "markdown",
|
89 |
+
"metadata": {
|
90 |
+
"id": "fEJSFutUsn_h"
|
91 |
+
},
|
92 |
+
"source": [
|
93 |
+
"### Install the TensorFlow Model Garden pip package\n",
|
94 |
+
"\n",
|
95 |
+
"* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` github repo. To include latest changes, you may install `tf-models-nightly`,\n",
|
96 |
+
"which is the nightly Model Garden package created daily automatically.\n",
|
97 |
+
"* `pip` will install all models and dependencies automatically."
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "code",
|
102 |
+
"execution_count": null,
|
103 |
+
"metadata": {
|
104 |
+
"id": "mfHI5JyuJ1y9"
|
105 |
+
},
|
106 |
+
"outputs": [],
|
107 |
+
"source": [
|
108 |
+
"!pip install -q opencv-python"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"cell_type": "code",
|
113 |
+
"execution_count": null,
|
114 |
+
"metadata": {
|
115 |
+
"id": "thsKZDjhswhR"
|
116 |
+
},
|
117 |
+
"outputs": [],
|
118 |
+
"source": [
|
119 |
+
"!pip install -q tf-models-official"
|
120 |
+
]
|
121 |
+
},
|
122 |
+
{
|
123 |
+
"cell_type": "markdown",
|
124 |
+
"metadata": {
|
125 |
+
"id": "hpf7JPCVsqtv"
|
126 |
+
},
|
127 |
+
"source": [
|
128 |
+
"### Import Tensorflow and other libraries"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "code",
|
133 |
+
"execution_count": null,
|
134 |
+
"metadata": {
|
135 |
+
"id": "my4dp-RMssQe"
|
136 |
+
},
|
137 |
+
"outputs": [],
|
138 |
+
"source": [
|
139 |
+
"import numpy as np\n",
|
140 |
+
"import tensorflow as tf\n",
|
141 |
+
"\n",
|
142 |
+
"import tensorflow_models as tfm\n",
|
143 |
+
"nlp = tfm.nlp"
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "markdown",
|
148 |
+
"metadata": {
|
149 |
+
"id": "vjDmVsFfs85n"
|
150 |
+
},
|
151 |
+
"source": [
|
152 |
+
"## Canonical BERT encoder\n",
|
153 |
+
"\n",
|
154 |
+
"Before learning how to customize the encoder, let's firstly create a canonical BERT enoder and use it to instantiate a `bert_classifier.BertClassifier` for classification task."
|
155 |
+
]
|
156 |
+
},
|
157 |
+
{
|
158 |
+
"cell_type": "code",
|
159 |
+
"execution_count": null,
|
160 |
+
"metadata": {
|
161 |
+
"id": "Oav8sbgstWc-"
|
162 |
+
},
|
163 |
+
"outputs": [],
|
164 |
+
"source": [
|
165 |
+
"cfg = {\n",
|
166 |
+
" \"vocab_size\": 100,\n",
|
167 |
+
" \"hidden_size\": 32,\n",
|
168 |
+
" \"num_layers\": 3,\n",
|
169 |
+
" \"num_attention_heads\": 4,\n",
|
170 |
+
" \"intermediate_size\": 64,\n",
|
171 |
+
" \"activation\": tfm.utils.activations.gelu,\n",
|
172 |
+
" \"dropout_rate\": 0.1,\n",
|
173 |
+
" \"attention_dropout_rate\": 0.1,\n",
|
174 |
+
" \"max_sequence_length\": 16,\n",
|
175 |
+
" \"type_vocab_size\": 2,\n",
|
176 |
+
" \"initializer\": tf.keras.initializers.TruncatedNormal(stddev=0.02),\n",
|
177 |
+
"}\n",
|
178 |
+
"bert_encoder = nlp.networks.BertEncoder(**cfg)\n",
|
179 |
+
"\n",
|
180 |
+
"def build_classifier(bert_encoder):\n",
|
181 |
+
" return nlp.models.BertClassifier(bert_encoder, num_classes=2)\n",
|
182 |
+
"\n",
|
183 |
+
"canonical_classifier_model = build_classifier(bert_encoder)"
|
184 |
+
]
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"cell_type": "markdown",
|
188 |
+
"metadata": {
|
189 |
+
"id": "Qe2UWI6_tsHo"
|
190 |
+
},
|
191 |
+
"source": [
|
192 |
+
"`canonical_classifier_model` can be trained using the training data. For details about how to train the model, please see the [Fine tuning bert](https://www.tensorflow.org/text/tutorials/fine_tune_bert) notebook. We skip the code that trains the model here.\n",
|
193 |
+
"\n",
|
194 |
+
"After training, we can apply the model to do prediction.\n"
|
195 |
+
]
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"cell_type": "code",
|
199 |
+
"execution_count": null,
|
200 |
+
"metadata": {
|
201 |
+
"id": "csED2d-Yt5h6"
|
202 |
+
},
|
203 |
+
"outputs": [],
|
204 |
+
"source": [
|
205 |
+
"def predict(model):\n",
|
206 |
+
" batch_size = 3\n",
|
207 |
+
" np.random.seed(0)\n",
|
208 |
+
" word_ids = np.random.randint(\n",
|
209 |
+
" cfg[\"vocab_size\"], size=(batch_size, cfg[\"max_sequence_length\"]))\n",
|
210 |
+
" mask = np.random.randint(2, size=(batch_size, cfg[\"max_sequence_length\"]))\n",
|
211 |
+
" type_ids = np.random.randint(\n",
|
212 |
+
" cfg[\"type_vocab_size\"], size=(batch_size, cfg[\"max_sequence_length\"]))\n",
|
213 |
+
" print(model([word_ids, mask, type_ids], training=False))\n",
|
214 |
+
"\n",
|
215 |
+
"predict(canonical_classifier_model)"
|
216 |
+
]
|
217 |
+
},
|
218 |
+
{
|
219 |
+
"cell_type": "markdown",
|
220 |
+
"metadata": {
|
221 |
+
"id": "PzKStEK9t_Pb"
|
222 |
+
},
|
223 |
+
"source": [
|
224 |
+
"## Customize BERT encoder\n",
|
225 |
+
"\n",
|
226 |
+
"One BERT encoder consists of an embedding network and multiple transformer blocks, and each transformer block contains an attention layer and a feedforward layer."
|
227 |
+
]
|
228 |
+
},
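Schematically, the forward pass of one such transformer block can be sketched as follows. This is an illustrative sketch, not the library's implementation; `attention`, `feedforward`, `norm1`, and `norm2` are hypothetical stand-ins for the sub-layers, and it assumes BERT's post-layer-norm residual arrangement.

```python
# Illustrative sketch of one BERT-style transformer block (not library code).
# `attention`, `feedforward`, `norm1`, `norm2` are hypothetical sub-layer callables.
def transformer_block(x, attention_mask, attention, feedforward, norm1, norm2):
    # Self-attention sub-layer with a residual connection and layer norm.
    x = norm1(x + attention(x, attention_mask))
    # Feedforward sub-layer with a residual connection and layer norm.
    return norm2(x + feedforward(x))
```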
|
229 |
+
{
|
230 |
+
"cell_type": "markdown",
|
231 |
+
"metadata": {
|
232 |
+
"id": "rmwQfhj6fmKz"
|
233 |
+
},
|
234 |
+
"source": [
|
235 |
+
"We provide easy ways to customize each of those components via (1)\n",
|
236 |
+
"[EncoderScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py) and (2) [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py)."
|
237 |
+
]
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"cell_type": "markdown",
|
241 |
+
"metadata": {
|
242 |
+
"id": "xsMgEVHAui11"
|
243 |
+
},
|
244 |
+
"source": [
|
245 |
+
"### Use EncoderScaffold\n",
|
246 |
+
"\n",
|
247 |
+
"`networks.EncoderScaffold` allows users to provide a custom embedding subnetwork\n",
|
248 |
+
" (which will replace the standard embedding logic) and/or a custom hidden layer class (which will replace the `Transformer` instantiation in the encoder)."
|
249 |
+
]
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"cell_type": "markdown",
|
253 |
+
"metadata": {
|
254 |
+
"id": "-JBabpa2AOz8"
|
255 |
+
},
|
256 |
+
"source": [
|
257 |
+
"#### Without Customization\n",
|
258 |
+
"\n",
|
259 |
+
"Without any customization, `networks.EncoderScaffold` behaves the same the canonical `networks.BertEncoder`.\n",
|
260 |
+
"\n",
|
261 |
+
"As shown in the following example, `networks.EncoderScaffold` can load `networks.BertEncoder`'s weights and output the same values:"
|
262 |
+
]
|
263 |
+
},
|
264 |
+
{
|
265 |
+
"cell_type": "code",
|
266 |
+
"execution_count": null,
|
267 |
+
"metadata": {
|
268 |
+
"id": "ktNzKuVByZQf"
|
269 |
+
},
|
270 |
+
"outputs": [],
|
271 |
+
"source": [
|
272 |
+
"default_hidden_cfg = dict(\n",
|
273 |
+
" num_attention_heads=cfg[\"num_attention_heads\"],\n",
|
274 |
+
" intermediate_size=cfg[\"intermediate_size\"],\n",
|
275 |
+
" intermediate_activation=cfg[\"activation\"],\n",
|
276 |
+
" dropout_rate=cfg[\"dropout_rate\"],\n",
|
277 |
+
" attention_dropout_rate=cfg[\"attention_dropout_rate\"],\n",
|
278 |
+
" kernel_initializer=cfg[\"initializer\"],\n",
|
279 |
+
")\n",
|
280 |
+
"default_embedding_cfg = dict(\n",
|
281 |
+
" vocab_size=cfg[\"vocab_size\"],\n",
|
282 |
+
" type_vocab_size=cfg[\"type_vocab_size\"],\n",
|
283 |
+
" hidden_size=cfg[\"hidden_size\"],\n",
|
284 |
+
" initializer=cfg[\"initializer\"],\n",
|
285 |
+
" dropout_rate=cfg[\"dropout_rate\"],\n",
|
286 |
+
" max_seq_length=cfg[\"max_sequence_length\"]\n",
|
287 |
+
")\n",
|
288 |
+
"default_kwargs = dict(\n",
|
289 |
+
" hidden_cfg=default_hidden_cfg,\n",
|
290 |
+
" embedding_cfg=default_embedding_cfg,\n",
|
291 |
+
" num_hidden_instances=cfg[\"num_layers\"],\n",
|
292 |
+
" pooled_output_dim=cfg[\"hidden_size\"],\n",
|
293 |
+
" return_all_layer_outputs=True,\n",
|
294 |
+
" pooler_layer_initializer=cfg[\"initializer\"],\n",
|
295 |
+
")\n",
|
296 |
+
"\n",
|
297 |
+
"encoder_scaffold = nlp.networks.EncoderScaffold(**default_kwargs)\n",
|
298 |
+
"classifier_model_from_encoder_scaffold = build_classifier(encoder_scaffold)\n",
|
299 |
+
"classifier_model_from_encoder_scaffold.set_weights(\n",
|
300 |
+
" canonical_classifier_model.get_weights())\n",
|
301 |
+
"predict(classifier_model_from_encoder_scaffold)"
|
302 |
+
]
|
303 |
+
},
|
304 |
+
{
|
305 |
+
"cell_type": "markdown",
|
306 |
+
"metadata": {
|
307 |
+
"id": "sMaUmLyIuwcs"
|
308 |
+
},
|
309 |
+
"source": [
|
310 |
+
"#### Customize Embedding\n",
|
311 |
+
"\n",
|
312 |
+
"Next, we show how to use a customized embedding network.\n",
|
313 |
+
"\n",
|
314 |
+
"We first build an embedding network that would replace the default network. This one will have 2 inputs (`mask` and `word_ids`) instead of 3, and won't use positional embeddings."
|
315 |
+
]
|
316 |
+
},
|
317 |
+
{
|
318 |
+
"cell_type": "code",
|
319 |
+
"execution_count": null,
|
320 |
+
"metadata": {
|
321 |
+
"id": "LTinnaG6vcsw"
|
322 |
+
},
|
323 |
+
"outputs": [],
|
324 |
+
"source": [
|
325 |
+
"word_ids = tf.keras.layers.Input(\n",
|
326 |
+
" shape=(cfg['max_sequence_length'],), dtype=tf.int32, name=\"input_word_ids\")\n",
|
327 |
+
"mask = tf.keras.layers.Input(\n",
|
328 |
+
" shape=(cfg['max_sequence_length'],), dtype=tf.int32, name=\"input_mask\")\n",
|
329 |
+
"embedding_layer = nlp.layers.OnDeviceEmbedding(\n",
|
330 |
+
" vocab_size=cfg['vocab_size'],\n",
|
331 |
+
" embedding_width=cfg['hidden_size'],\n",
|
332 |
+
" initializer=cfg[\"initializer\"],\n",
|
333 |
+
" name=\"word_embeddings\")\n",
|
334 |
+
"word_embeddings = embedding_layer(word_ids)\n",
|
335 |
+
"attention_mask = nlp.layers.SelfAttentionMask()([word_embeddings, mask])\n",
|
336 |
+
"new_embedding_network = tf.keras.Model([word_ids, mask],\n",
|
337 |
+
" [word_embeddings, attention_mask])"
|
338 |
+
]
|
339 |
+
},
|
340 |
+
{
|
341 |
+
"cell_type": "markdown",
|
342 |
+
"metadata": {
|
343 |
+
"id": "HN7_yu-6O3qI"
|
344 |
+
},
|
345 |
+
"source": [
|
346 |
+
"Inspecting `new_embedding_network`, we can see it takes two inputs:\n",
|
347 |
+
"`input_word_ids` and `input_mask`."
|
348 |
+
]
|
349 |
+
},
|
350 |
+
{
|
351 |
+
"cell_type": "code",
|
352 |
+
"execution_count": null,
|
353 |
+
"metadata": {
|
354 |
+
"id": "fO9zKFE4OpHp"
|
355 |
+
},
|
356 |
+
"outputs": [],
|
357 |
+
"source": [
|
358 |
+
"tf.keras.utils.plot_model(new_embedding_network, show_shapes=True, dpi=48)"
|
359 |
+
]
|
360 |
+
},
|
361 |
+
{
|
362 |
+
"cell_type": "markdown",
|
363 |
+
"metadata": {
|
364 |
+
"id": "9cOaGQHLv12W"
|
365 |
+
},
|
366 |
+
"source": [
|
367 |
+
"We can then build a new encoder using the above `new_embedding_network`."
|
368 |
+
]
|
369 |
+
},
|
370 |
+
{
|
371 |
+
"cell_type": "code",
|
372 |
+
"execution_count": null,
|
373 |
+
"metadata": {
|
374 |
+
"id": "mtFDMNf2vIl9"
|
375 |
+
},
|
376 |
+
"outputs": [],
|
377 |
+
"source": [
|
378 |
+
"kwargs = dict(default_kwargs)\n",
|
379 |
+
"\n",
|
380 |
+
"# Use new embedding network.\n",
|
381 |
+
"kwargs['embedding_cls'] = new_embedding_network\n",
|
382 |
+
"kwargs['embedding_data'] = embedding_layer.embeddings\n",
|
383 |
+
"\n",
|
384 |
+
"encoder_with_customized_embedding = nlp.networks.EncoderScaffold(**kwargs)\n",
|
385 |
+
"classifier_model = build_classifier(encoder_with_customized_embedding)\n",
|
386 |
+
"# ... Train the model ...\n",
|
387 |
+
"print(classifier_model.inputs)\n",
|
388 |
+
"\n",
|
389 |
+
"# Assert that there are only two inputs.\n",
|
390 |
+
"assert len(classifier_model.inputs) == 2"
|
391 |
+
]
|
392 |
+
},
|
393 |
+
{
|
394 |
+
"cell_type": "markdown",
|
395 |
+
"metadata": {
|
396 |
+
"id": "Z73ZQDtmwg9K"
|
397 |
+
},
|
398 |
+
"source": [
|
399 |
+
"#### Customized Transformer\n",
|
400 |
+
"\n",
|
401 |
+
"Users can also override the `hidden_cls` argument in `networks.EncoderScaffold`'s constructor employ a customized Transformer layer.\n",
|
402 |
+
"\n",
|
403 |
+
"See [the source of `nlp.layers.ReZeroTransformer`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/rezero_transformer.py) for how to implement a customized Transformer layer.\n",
|
404 |
+
"\n",
|
405 |
+
"The following is an example of using `nlp.layers.ReZeroTransformer`:\n"
|
406 |
+
]
|
407 |
+
},
|
408 |
+
{
|
409 |
+
"cell_type": "code",
|
410 |
+
"execution_count": null,
|
411 |
+
"metadata": {
|
412 |
+
"id": "uAIarLZgw6pA"
|
413 |
+
},
|
414 |
+
"outputs": [],
|
415 |
+
"source": [
|
416 |
+
"kwargs = dict(default_kwargs)\n",
|
417 |
+
"\n",
|
418 |
+
"# Use ReZeroTransformer.\n",
|
419 |
+
"kwargs['hidden_cls'] = nlp.layers.ReZeroTransformer\n",
|
420 |
+
"\n",
|
421 |
+
"encoder_with_rezero_transformer = nlp.networks.EncoderScaffold(**kwargs)\n",
|
422 |
+
"classifier_model = build_classifier(encoder_with_rezero_transformer)\n",
|
423 |
+
"# ... Train the model ...\n",
|
424 |
+
"predict(classifier_model)\n",
|
425 |
+
"\n",
|
426 |
+
"# Assert that the variable `rezero_alpha` from ReZeroTransformer exists.\n",
|
427 |
+
"assert 'rezero_alpha' in ''.join([x.name for x in classifier_model.trainable_weights])"
|
428 |
+
]
|
429 |
+
},
|
430 |
+
{
|
431 |
+
"cell_type": "markdown",
|
432 |
+
"metadata": {
|
433 |
+
"id": "6PMHFdvnxvR0"
|
434 |
+
},
|
435 |
+
"source": [
|
436 |
+
"### Use `nlp.layers.TransformerScaffold`\n",
|
437 |
+
"\n",
|
438 |
+
"The above method of customizing the model requires rewriting the whole `nlp.layers.Transformer` layer, while sometimes you may only want to customize either attention layer or feedforward block. In this case, `nlp.layers.TransformerScaffold` can be used.\n"
|
439 |
+
]
|
440 |
+
},
|
441 |
+
{
|
442 |
+
"cell_type": "markdown",
|
443 |
+
"metadata": {
|
444 |
+
"id": "D6FejlgwyAy_"
|
445 |
+
},
|
446 |
+
"source": [
|
447 |
+
"#### Customize Attention Layer\n",
|
448 |
+
"\n",
|
449 |
+
"User can also override the `attention_cls` argument in `layers.TransformerScaffold`'s constructor to employ a customized Attention layer.\n",
|
450 |
+
"\n",
|
451 |
+
"See [the source of `nlp.layers.TalkingHeadsAttention`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py) for how to implement a customized `Attention` layer.\n",
|
452 |
+
"\n",
|
453 |
+
"Following is an example of using `nlp.layers.TalkingHeadsAttention`:"
|
454 |
+
]
|
455 |
+
},
|
456 |
+
{
|
457 |
+
"cell_type": "code",
|
458 |
+
"execution_count": null,
|
459 |
+
"metadata": {
|
460 |
+
"id": "nFrSMrZuyNeQ"
|
461 |
+
},
|
462 |
+
"outputs": [],
|
463 |
+
"source": [
|
464 |
+
"# Use TalkingHeadsAttention\n",
|
465 |
+
"hidden_cfg = dict(default_hidden_cfg)\n",
|
466 |
+
"hidden_cfg['attention_cls'] = nlp.layers.TalkingHeadsAttention\n",
|
467 |
+
"\n",
|
468 |
+
"kwargs = dict(default_kwargs)\n",
|
469 |
+
"kwargs['hidden_cls'] = nlp.layers.TransformerScaffold\n",
|
470 |
+
"kwargs['hidden_cfg'] = hidden_cfg\n",
|
471 |
+
"\n",
|
472 |
+
"encoder = nlp.networks.EncoderScaffold(**kwargs)\n",
|
473 |
+
"classifier_model = build_classifier(encoder)\n",
|
474 |
+
"# ... Train the model ...\n",
|
475 |
+
"predict(classifier_model)\n",
|
476 |
+
"\n",
|
477 |
+
"# Assert that the variable `pre_softmax_weight` from TalkingHeadsAttention exists.\n",
|
478 |
+
"assert 'pre_softmax_weight' in ''.join([x.name for x in classifier_model.trainable_weights])"
|
479 |
+
]
|
480 |
+
},
|
481 |
+
{
|
482 |
+
"cell_type": "code",
|
483 |
+
"execution_count": null,
|
484 |
+
"metadata": {
|
485 |
+
"id": "tKkZ8spzYmpc"
|
486 |
+
},
|
487 |
+
"outputs": [],
|
488 |
+
"source": [
|
489 |
+
"tf.keras.utils.plot_model(encoder_with_rezero_transformer, show_shapes=True, dpi=48)"
|
490 |
+
]
|
491 |
+
},
|
492 |
+
{
|
493 |
+
"cell_type": "markdown",
|
494 |
+
"metadata": {
|
495 |
+
"id": "kuEJcTyByVvI"
|
496 |
+
},
|
497 |
+
"source": [
|
498 |
+
"#### Customize Feedforward Layer\n",
|
499 |
+
"\n",
|
500 |
+
"Similiarly, one could also customize the feedforward layer.\n",
|
501 |
+
"\n",
|
502 |
+
"See [the source of `nlp.layers.GatedFeedforward`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/gated_feedforward.py) for how to implement a customized feedforward layer.\n",
|
503 |
+
"\n",
|
504 |
+
"Following is an example of using `nlp.layers.GatedFeedforward`:"
|
505 |
+
]
|
506 |
+
},
|
507 |
+
{
|
508 |
+
"cell_type": "code",
|
509 |
+
"execution_count": null,
|
510 |
+
"metadata": {
|
511 |
+
"id": "XAbKy_l4y_-i"
|
512 |
+
},
|
513 |
+
"outputs": [],
|
514 |
+
"source": [
|
515 |
+
"# Use GatedFeedforward\n",
|
516 |
+
"hidden_cfg = dict(default_hidden_cfg)\n",
|
517 |
+
"hidden_cfg['feedforward_cls'] = nlp.layers.GatedFeedforward\n",
|
518 |
+
"\n",
|
519 |
+
"kwargs = dict(default_kwargs)\n",
|
520 |
+
"kwargs['hidden_cls'] = nlp.layers.TransformerScaffold\n",
|
521 |
+
"kwargs['hidden_cfg'] = hidden_cfg\n",
|
522 |
+
"\n",
|
523 |
+
"encoder_with_gated_feedforward = nlp.networks.EncoderScaffold(**kwargs)\n",
|
524 |
+
"classifier_model = build_classifier(encoder_with_gated_feedforward)\n",
|
525 |
+
"# ... Train the model ...\n",
|
526 |
+
"predict(classifier_model)\n",
|
527 |
+
"\n",
|
528 |
+
"# Assert that the variable `gate` from GatedFeedforward exists.\n",
|
529 |
+
"assert 'gate' in ''.join([x.name for x in classifier_model.trainable_weights])"
|
530 |
+
]
|
531 |
+
},
|
532 |
+
{
|
533 |
+
"cell_type": "markdown",
|
534 |
+
"metadata": {
|
535 |
+
"id": "a_8NWUhkzeAq"
|
536 |
+
},
|
537 |
+
"source": [
|
538 |
+
"### Build a new Encoder\n",
|
539 |
+
"\n",
|
540 |
+
"Finally, you could also build a new encoder using building blocks in the modeling library.\n",
|
541 |
+
"\n",
|
542 |
+
"See [the source for `nlp.networks.AlbertEncoder`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/albert_encoder.py) as an example of how to do this. \n",
|
543 |
+
"\n",
|
544 |
+
"Here is an example using `nlp.networks.AlbertEncoder`:\n"
|
545 |
+
]
|
546 |
+
},
|
547 |
+
{
|
548 |
+
"cell_type": "code",
|
549 |
+
"execution_count": null,
|
550 |
+
"metadata": {
|
551 |
+
"id": "xsiA3RzUzmUM"
|
552 |
+
},
|
553 |
+
"outputs": [],
|
554 |
+
"source": [
|
555 |
+
"albert_encoder = nlp.networks.AlbertEncoder(**cfg)\n",
|
556 |
+
"classifier_model = build_classifier(albert_encoder)\n",
|
557 |
+
"# ... Train the model ...\n",
|
558 |
+
"predict(classifier_model)"
|
559 |
+
]
|
560 |
+
},
|
561 |
+
{
|
562 |
+
"cell_type": "markdown",
|
563 |
+
"metadata": {
|
564 |
+
"id": "MeidDfhlHKSO"
|
565 |
+
},
|
566 |
+
"source": [
|
567 |
+
"Inspecting the `albert_encoder`, we see it stacks the same `Transformer` layer multiple times (note the loop-back on the \"Transformer\" block below.."
|
568 |
+
]
|
569 |
+
},
|
570 |
+
{
|
571 |
+
"cell_type": "code",
|
572 |
+
"execution_count": null,
|
573 |
+
"metadata": {
|
574 |
+
"id": "Uv_juT22HERW"
|
575 |
+
},
|
576 |
+
"outputs": [],
|
577 |
+
"source": [
|
578 |
+
"tf.keras.utils.plot_model(albert_encoder, show_shapes=True, dpi=48)"
|
579 |
+
]
|
580 |
+
}
|
581 |
+
],
|
582 |
+
"metadata": {
|
583 |
+
"colab": {
|
584 |
+
"collapsed_sections": [],
|
585 |
+
"name": "customize_encoder.ipynb",
|
586 |
+
"provenance": [],
|
587 |
+
"toc_visible": true
|
588 |
+
},
|
589 |
+
"kernelspec": {
|
590 |
+
"display_name": "Python 3",
|
591 |
+
"name": "python3"
|
592 |
+
}
|
593 |
+
},
|
594 |
+
"nbformat": 4,
|
595 |
+
"nbformat_minor": 0
|
596 |
+
}
|
Tensorflow/models/docs/nlp/decoding_api.ipynb
ADDED
@@ -0,0 +1,482 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "vXLA5InzXydn"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"##### Copyright 2021 The TensorFlow Authors."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"cellView": "form",
|
17 |
+
"id": "RuRlpLL-X0R_"
|
18 |
+
},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
22 |
+
"# you may not use this file except in compliance with the License.\n",
|
23 |
+
"# You may obtain a copy of the License at\n",
|
24 |
+
"#\n",
|
25 |
+
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
26 |
+
"#\n",
|
27 |
+
"# Unless required by applicable law or agreed to in writing, software\n",
|
28 |
+
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
29 |
+
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
30 |
+
"# See the License for the specific language governing permissions and\n",
|
31 |
+
"# limitations under the License."
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "markdown",
|
36 |
+
"metadata": {
|
37 |
+
"id": "2X-XaMSVcLua"
|
38 |
+
},
|
39 |
+
"source": [
|
40 |
+
"# Decoding API"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"metadata": {
|
46 |
+
"id": "hYEwGTeCXnnX"
|
47 |
+
},
|
48 |
+
"source": [
|
49 |
+
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
|
50 |
+
" \u003ctd\u003e\n",
|
51 |
+
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/decoding_api\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
|
52 |
+
" \u003c/td\u003e\n",
|
53 |
+
" \u003ctd\u003e\n",
|
54 |
+
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/decoding_api.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
|
55 |
+
" \u003c/td\u003e\n",
|
56 |
+
" \u003ctd\u003e\n",
|
57 |
+
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/decoding_api.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
|
58 |
+
" \u003c/td\u003e\n",
|
59 |
+
" \u003ctd\u003e\n",
|
60 |
+
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/decoding_api.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
|
61 |
+
" \u003c/td\u003e\n",
|
62 |
+
"\u003c/table\u003e"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "markdown",
|
67 |
+
"metadata": {
|
68 |
+
"id": "fsACVQpVSifi"
|
69 |
+
},
|
70 |
+
"source": [
|
71 |
+
"### Install the TensorFlow Model Garden pip package\n",
|
72 |
+
"\n",
|
73 |
+
"* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` github repo. To include latest changes, you may install `tf-models-nightly`,\n",
|
74 |
+
"which is the nightly Model Garden package created daily automatically.\n",
|
75 |
+
"* pip will install all models and dependencies automatically."
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"execution_count": null,
|
81 |
+
"metadata": {
|
82 |
+
"id": "G4BhAu01HZcM"
|
83 |
+
},
|
84 |
+
"outputs": [],
|
85 |
+
"source": [
|
86 |
+
"!pip uninstall -y opencv-python"
|
87 |
+
]
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"cell_type": "code",
|
91 |
+
"execution_count": null,
|
92 |
+
"metadata": {
|
93 |
+
"id": "2j-xhrsVQOQT"
|
94 |
+
},
|
95 |
+
"outputs": [],
|
96 |
+
"source": [
|
97 |
+
"!pip install tf-models-official"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "code",
|
102 |
+
"execution_count": null,
|
103 |
+
"metadata": {
|
104 |
+
"id": "BjP7zwxmskpY"
|
105 |
+
},
|
106 |
+
"outputs": [],
|
107 |
+
"source": [
|
108 |
+
"import os\n",
|
109 |
+
"\n",
|
110 |
+
"import numpy as np\n",
|
111 |
+
"import matplotlib.pyplot as plt\n",
|
112 |
+
"\n",
|
113 |
+
"import tensorflow as tf\n",
|
114 |
+
"\n",
|
115 |
+
"from tensorflow_models import nlp"
|
116 |
+
]
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"cell_type": "code",
|
120 |
+
"execution_count": null,
|
121 |
+
"metadata": {
|
122 |
+
"id": "T92ccAzlnGqh"
|
123 |
+
},
|
124 |
+
"outputs": [],
|
125 |
+
"source": [
|
126 |
+
"def length_norm(length, dtype):\n",
|
127 |
+
" \"\"\"Return length normalization factor.\"\"\"\n",
|
128 |
+
" return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "markdown",
|
133 |
+
"metadata": {
|
134 |
+
"id": "0AWgyo-IQ5sP"
|
135 |
+
},
|
136 |
+
"source": [
|
137 |
+
"## Overview\n",
|
138 |
+
"\n",
|
139 |
+
"This API provides an interface to experiment with different decoding strategies used for auto-regressive models.\n",
|
140 |
+
"\n",
|
141 |
+
"1. The following sampling strategies are provided in sampling_module.py, which inherits from the base Decoding class:\n",
|
142 |
+
" * [top_p](https://arxiv.org/abs/1904.09751) : [github](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/sampling_module.py#L65) \n",
|
143 |
+
"\n",
|
144 |
+
" This implementation chooses the most probable logits with cumulative probabilities up to top_p.\n",
|
145 |
+
"\n",
|
146 |
+
" * [top_k](https://arxiv.org/pdf/1805.04833.pdf) : [github](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/sampling_module.py#L48)\n",
|
147 |
+
"\n",
|
148 |
+
" At each timestep, this implementation samples from top-k logits based on their probability distribution\n",
|
149 |
+
"\n",
|
150 |
+
" * Greedy : [github](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/sampling_module.py#L26)\n",
|
151 |
+
"\n",
|
152 |
+
" This implementation returns the top logits based on probabilities.\n",
|
153 |
+
"\n",
|
154 |
+
"2. Beam search is provided in beam_search.py. [github](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/beam_search.py)\n",
|
155 |
+
"\n",
|
156 |
+
" This implementation reduces the risk of missing hidden high probability logits by keeping the most likely num_beams of logits at each time step and eventually choosing the logits that has the overall highest probability."
|
157 |
+
]
|
158 |
+
},
|
159 |
+
{
|
160 |
+
"cell_type": "markdown",
|
161 |
+
"metadata": {
|
162 |
+
"id": "MfOj7oaBRQnS"
|
163 |
+
},
|
164 |
+
"source": [
|
165 |
+
"## Initialize Sampling Module in TF-NLP.\n",
|
166 |
+
"\n",
|
167 |
+
"\n",
|
168 |
+
"\u003e **symbols_to_logits_fn** : This is a closure implemented by the users of the API. The input to this closure will be \n",
|
169 |
+
"```\n",
|
170 |
+
"Args:\n",
|
171 |
+
" 1] ids [batch_size, .. (index + 1 or 1 if padded_decode is True)],\n",
|
172 |
+
" 2] index [scalar] : current decoded step,\n",
|
173 |
+
" 3] cache [nested dictionary of tensors].\n",
|
174 |
+
"Returns:\n",
|
175 |
+
" 1] tensor for next-step logits [batch_size, vocab]\n",
|
176 |
+
" 2] the updated_cache [nested dictionary of tensors].\n",
|
177 |
+
"```\n",
|
178 |
+
"This closure calls the model to predict the logits for the 'index+1' step. The cache is used for faster decoding.\n",
|
179 |
+
"Here is a [reference](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/beam_search_test.py#L88) implementation for the above closure.\n",
|
180 |
+
"\n",
|
181 |
+
"\n",
|
182 |
+
"\u003e **length_normalization_fn** : Closure for returning length normalization parameter.\n",
|
183 |
+
"```\n",
|
184 |
+
"Args: \n",
|
185 |
+
" 1] length : scalar for decoded step index.\n",
|
186 |
+
" 2] dtype : data-type of output tensor\n",
|
187 |
+
"Returns:\n",
|
188 |
+
" 1] value of length normalization factor.\n",
|
189 |
+
"Example :\n",
|
190 |
+
" def _length_norm(length, dtype):\n",
|
191 |
+
" return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)\n",
|
192 |
+
"```\n",
|
193 |
+
"\n",
|
194 |
+
"\u003e **vocab_size** : Output vocabulary size.\n",
|
195 |
+
"\n",
|
196 |
+
"\u003e **max_decode_length** : Scalar for total number of decoding steps.\n",
|
197 |
+
"\n",
|
198 |
+
"\u003e **eos_id** : Decoding will stop if all output decoded ids in the batch have this ID.\n",
|
199 |
+
"\n",
|
200 |
+
"\u003e **padded_decode** : Set this to True if running on TPU. Tensors are padded to max_decoding_length if this is True.\n",
|
201 |
+
"\n",
|
202 |
+
"\u003e **top_k** : top_k is enabled if this value is \u003e 1.\n",
|
203 |
+
"\n",
|
204 |
+
"\u003e **top_p** : top_p is enabled if this value is \u003e 0 and \u003c 1.0\n",
|
205 |
+
"\n",
|
206 |
+
"\u003e **sampling_temperature** : This is used to re-estimate the softmax output. Temperature skews the distribution towards high-probability tokens and lowers the mass in the tail distribution. Value has to be positive. Low temperature is equivalent to greedy and makes the distribution sharper, while high temperature makes it flatter.\n",
|
207 |
+
"\n",
|
208 |
+
"\u003e **enable_greedy** : By default, this is true and greedy decoding is enabled.\n"
|
209 |
+
]
|
210 |
+
},
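To make the `sampling_temperature` description above concrete, here is a minimal standalone sketch (toy logits, not part of the API) of how temperature rescales logits before the softmax:

```python
import tensorflow as tf

# Toy logits; lower temperatures sharpen the softmax output (closer to greedy),
# higher temperatures flatten it.
logits = tf.constant([2.0, 1.0, 0.1])
for temperature in [0.5, 1.0, 2.0]:
    probs = tf.nn.softmax(logits / temperature)
    print(temperature, probs.numpy())
```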
|
211 |
+
{
|
212 |
+
"cell_type": "markdown",
|
213 |
+
"metadata": {
|
214 |
+
"id": "lV1RRp6ihnGX"
|
215 |
+
},
|
216 |
+
"source": [
|
217 |
+
"## Initialize the Model Hyper-parameters"
|
218 |
+
]
|
219 |
+
},
|
220 |
+
{
|
221 |
+
"cell_type": "code",
|
222 |
+
"execution_count": null,
|
223 |
+
"metadata": {
|
224 |
+
"id": "eTsGp2gaKLdE"
|
225 |
+
},
|
226 |
+
"outputs": [],
|
227 |
+
"source": [
|
228 |
+
"params = {\n",
|
229 |
+
" 'num_heads': 2,\n",
|
230 |
+
" 'num_layers': 2,\n",
|
231 |
+
" 'batch_size': 2,\n",
|
232 |
+
" 'n_dims': 256,\n",
|
233 |
+
" 'max_decode_length': 4}"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "markdown",
|
238 |
+
"metadata": {
|
239 |
+
"id": "CYXkoplAij01"
|
240 |
+
},
|
241 |
+
"source": [
|
242 |
+
"## Initialize cache. "
|
243 |
+
]
|
244 |
+
},
|
245 |
+
{
|
246 |
+
"cell_type": "markdown",
|
247 |
+
"metadata": {
|
248 |
+
"id": "UGvmd0_dRFYI"
|
249 |
+
},
|
250 |
+
"source": [
|
251 |
+
"In auto-regressive architectures like Transformer based [Encoder-Decoder](https://arxiv.org/abs/1706.03762) models, \n",
|
252 |
+
"Cache is used for fast sequential decoding.\n",
|
253 |
+
"It is a nested dictionary storing pre-computed hidden-states (key and values in the self-attention blocks and the cross-attention blocks) for every layer."
|
254 |
+
]
|
255 |
+
},
|
256 |
+
{
|
257 |
+
"cell_type": "code",
|
258 |
+
"execution_count": null,
|
259 |
+
"metadata": {
|
260 |
+
"id": "D6kfZOOKgkm1"
|
261 |
+
},
|
262 |
+
"outputs": [],
|
263 |
+
"source": [
|
264 |
+
"cache = {\n",
|
265 |
+
" 'layer_%d' % layer: {\n",
|
266 |
+
" 'k': tf.zeros(\n",
|
267 |
+
" shape=[params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims'] // params['num_heads']],\n",
|
268 |
+
" dtype=tf.float32),\n",
|
269 |
+
" 'v': tf.zeros(\n",
|
270 |
+
" shape=[params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims'] // params['num_heads']],\n",
|
271 |
+
" dtype=tf.float32)\n",
|
272 |
+
" } for layer in range(params['num_layers'])\n",
|
273 |
+
" }\n",
|
274 |
+
"print(\"cache value shape for layer 1 :\", cache['layer_1']['k'].shape)"
|
275 |
+
]
|
276 |
+
},
|
277 |
+
{
|
278 |
+
"cell_type": "markdown",
|
279 |
+
"metadata": {
|
280 |
+
"id": "syl7I5nURPgW"
|
281 |
+
},
|
282 |
+
"source": [
|
283 |
+
"### Create model_fn\n",
|
284 |
+
" In practice, this will be replaced by an actual model implementation such as [here](https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer.py#L236)\n",
|
285 |
+
"```\n",
|
286 |
+
"Args:\n",
|
287 |
+
"i : Step that is being decoded.\n",
|
288 |
+
"Returns:\n",
|
289 |
+
" logit probabilities of size [batch_size, 1, vocab_size]\n",
|
290 |
+
"```\n"
|
291 |
+
]
|
292 |
+
},
|
293 |
+
{
|
294 |
+
"cell_type": "code",
|
295 |
+
"execution_count": null,
|
296 |
+
"metadata": {
|
297 |
+
"id": "AhzSkRisRdB6"
|
298 |
+
},
|
299 |
+
"outputs": [],
|
300 |
+
"source": [
|
301 |
+
"probabilities = tf.constant([[[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],\n",
|
302 |
+
" [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],\n",
|
303 |
+
" [[0.2, 0.5, 0.3], [0.2, 0.7, 0.1],\n",
|
304 |
+
" [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]])\n",
|
305 |
+
"def model_fn(i):\n",
|
306 |
+
" return probabilities[:, i, :]"
|
307 |
+
]
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"cell_type": "code",
|
311 |
+
"execution_count": null,
|
312 |
+
"metadata": {
|
313 |
+
"id": "FAJ4CpbfVdjr"
|
314 |
+
},
|
315 |
+
"outputs": [],
|
316 |
+
"source": [
|
317 |
+
"def _symbols_to_logits_fn():\n",
|
318 |
+
" \"\"\"Calculates logits of the next tokens.\"\"\"\n",
|
319 |
+
" def symbols_to_logits_fn(ids, i, temp_cache):\n",
|
320 |
+
" del ids\n",
|
321 |
+
" logits = tf.cast(tf.math.log(model_fn(i)), tf.float32)\n",
|
322 |
+
" return logits, temp_cache\n",
|
323 |
+
" return symbols_to_logits_fn"
|
324 |
+
]
|
325 |
+
},
|
326 |
+
{
|
327 |
+
"cell_type": "markdown",
|
328 |
+
"metadata": {
|
329 |
+
"id": "R_tV3jyWVL47"
|
330 |
+
},
|
331 |
+
"source": [
|
332 |
+
"## Greedy \n",
|
333 |
+
"Greedy decoding selects the token id with the highest probability as its next id: $id_t = argmax_{w}P(id | id_{1:t-1})$ at each timestep $t$. The following sketch shows greedy decoding. "
|
334 |
+
]
|
335 |
+
},
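As a tiny standalone illustration of the greedy rule (toy probabilities, unrelated to the API call below):

```python
import numpy as np

# Toy next-token distributions for three timesteps; greedy takes the argmax
# independently at each step.
step_probs = np.array([[0.3, 0.4, 0.3],
                       [0.3, 0.3, 0.4],
                       [0.1, 0.1, 0.8]])
print(step_probs.argmax(axis=-1))  # -> [1 2 2]
```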
|
336 |
+
{
|
337 |
+
"cell_type": "code",
|
338 |
+
"execution_count": null,
|
339 |
+
"metadata": {
|
340 |
+
"id": "aGt9idSkVQEJ"
|
341 |
+
},
|
342 |
+
"outputs": [],
|
343 |
+
"source": [
|
344 |
+
"greedy_obj = sampling_module.SamplingModule(\n",
|
345 |
+
" length_normalization_fn=None,\n",
|
346 |
+
" dtype=tf.float32,\n",
|
347 |
+
" symbols_to_logits_fn=_symbols_to_logits_fn(),\n",
|
348 |
+
" vocab_size=3,\n",
|
349 |
+
" max_decode_length=params['max_decode_length'],\n",
|
350 |
+
" eos_id=10,\n",
|
351 |
+
" padded_decode=False)\n",
|
352 |
+
"ids, _ = greedy_obj.generate(\n",
|
353 |
+
" initial_ids=tf.constant([9, 1]), initial_cache=cache)\n",
|
354 |
+
"print(\"Greedy Decoded Ids:\", ids)"
|
355 |
+
]
|
356 |
+
},
|
357 |
+
{
|
358 |
+
"cell_type": "markdown",
|
359 |
+
"metadata": {
|
360 |
+
"id": "s4pTTsQXVz5O"
|
361 |
+
},
|
362 |
+
"source": [
|
363 |
+
"## top_k sampling\n",
|
364 |
+
"In *Top-K* sampling, the *K* most likely next token ids are filtered and the probability mass is redistributed among only those *K* ids. "
|
365 |
+
]
|
366 |
+
},
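The filtering step can be sketched in a few lines of NumPy (toy values; the actual op lives in `sampling_module.py`):

```python
import numpy as np

# Keep the K most likely ids and renormalize the mass among them.
probs = np.array([0.5, 0.2, 0.15, 0.1, 0.05])
k = 2
top_ids = np.argsort(probs)[-k:]      # ids of the K largest probabilities
filtered = np.zeros_like(probs)
filtered[top_ids] = probs[top_ids]
filtered /= filtered.sum()            # redistribute mass among the K ids
print(top_ids, filtered)              # sample the next id from `filtered`
```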
|
367 |
+
{
|
368 |
+
"cell_type": "code",
|
369 |
+
"execution_count": null,
|
370 |
+
"metadata": {
|
371 |
+
"id": "pCLWIn6GV5_G"
|
372 |
+
},
|
373 |
+
"outputs": [],
|
374 |
+
"source": [
|
375 |
+
"top_k_obj = sampling_module.SamplingModule(\n",
|
376 |
+
" length_normalization_fn=length_norm,\n",
|
377 |
+
" dtype=tf.float32,\n",
|
378 |
+
" symbols_to_logits_fn=_symbols_to_logits_fn(),\n",
|
379 |
+
" vocab_size=3,\n",
|
380 |
+
" max_decode_length=params['max_decode_length'],\n",
|
381 |
+
" eos_id=10,\n",
|
382 |
+
" sample_temperature=tf.constant(1.0),\n",
|
383 |
+
" top_k=tf.constant(3),\n",
|
384 |
+
" padded_decode=False,\n",
|
385 |
+
" enable_greedy=False)\n",
|
386 |
+
"ids, _ = top_k_obj.generate(\n",
|
387 |
+
" initial_ids=tf.constant([9, 1]), initial_cache=cache)\n",
|
388 |
+
"print(\"top-k sampled Ids:\", ids)"
|
389 |
+
]
|
390 |
+
},
|
391 |
+
{
|
392 |
+
"cell_type": "markdown",
|
393 |
+
"metadata": {
|
394 |
+
"id": "Jp3G-eE_WI4Y"
|
395 |
+
},
|
396 |
+
"source": [
|
397 |
+
"## top_p sampling\n",
|
398 |
+
"Instead of sampling only from the most likely *K* token ids, in *Top-p* sampling chooses from the smallest possible set of ids whose cumulative probability exceeds the probability *p*."
|
399 |
+
]
|
400 |
+
},
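A minimal NumPy sketch of selecting the nucleus (toy values; the actual op lives in `sampling_module.py`):

```python
import numpy as np

# Smallest set of ids whose cumulative probability exceeds p.
probs = np.array([0.5, 0.2, 0.15, 0.1, 0.05])
p = 0.8
order = np.argsort(probs)[::-1]                       # ids by descending probability
cutoff = np.searchsorted(np.cumsum(probs[order]), p) + 1
nucleus = order[:cutoff]
print(nucleus)  # mass is renormalized among these ids before sampling
```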
|
401 |
+
{
|
402 |
+
"cell_type": "code",
|
403 |
+
"execution_count": null,
|
404 |
+
"metadata": {
|
405 |
+
"id": "rEGdIWcuWILO"
|
406 |
+
},
|
407 |
+
"outputs": [],
|
408 |
+
"source": [
|
409 |
+
"top_p_obj = sampling_module.SamplingModule(\n",
|
410 |
+
" length_normalization_fn=length_norm,\n",
|
411 |
+
" dtype=tf.float32,\n",
|
412 |
+
" symbols_to_logits_fn=_symbols_to_logits_fn(),\n",
|
413 |
+
" vocab_size=3,\n",
|
414 |
+
" max_decode_length=params['max_decode_length'],\n",
|
415 |
+
" eos_id=10,\n",
|
416 |
+
" sample_temperature=tf.constant(1.0),\n",
|
417 |
+
" top_p=tf.constant(0.9),\n",
|
418 |
+
" padded_decode=False,\n",
|
419 |
+
" enable_greedy=False)\n",
|
420 |
+
"ids, _ = top_p_obj.generate(\n",
|
421 |
+
" initial_ids=tf.constant([9, 1]), initial_cache=cache)\n",
|
422 |
+
"print(\"top-p sampled Ids:\", ids)"
|
423 |
+
]
|
424 |
+
},
|
425 |
+
{
|
426 |
+
"cell_type": "markdown",
|
427 |
+
"metadata": {
|
428 |
+
"id": "2hcuyJ2VWjDz"
|
429 |
+
},
|
430 |
+
"source": [
|
431 |
+
"## Beam search decoding\n",
|
432 |
+
"Beam search reduces the risk of missing hidden high probability token ids by keeping the most likely num_beams of hypotheses at each time step and eventually choosing the hypothesis that has the overall highest probability. "
|
433 |
+
]
|
434 |
+
},
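The idea can be sketched with a toy example (hypothetical per-step distributions; the real implementation is in `beam_search.py`):

```python
import numpy as np

# Toy beam search: keep the `num_beams` best hypotheses at every step.
log_probs = np.log(np.array([[0.6, 0.3, 0.1],    # step 0, vocab of 3
                             [0.2, 0.5, 0.3]]))  # step 1
num_beams = 2
beams = [((), 0.0)]  # (token ids so far, cumulative log-probability)
for step_scores in log_probs:
    candidates = [(ids + (tok,), score + step_scores[tok])
                  for ids, score in beams for tok in range(len(step_scores))]
    beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
print(beams)  # highest-scoring hypotheses after the final step
```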
|
435 |
+
{
|
436 |
+
"cell_type": "code",
|
437 |
+
"execution_count": null,
|
438 |
+
"metadata": {
|
439 |
+
"id": "cJ3WzvSrWmSA"
|
440 |
+
},
|
441 |
+
"outputs": [],
|
442 |
+
"source": [
|
443 |
+
"beam_size = 2\n",
|
444 |
+
"params['batch_size'] = 1\n",
|
445 |
+
"beam_cache = {\n",
|
446 |
+
" 'layer_%d' % layer: {\n",
|
447 |
+
" 'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32),\n",
|
448 |
+
" 'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32)\n",
|
449 |
+
" } for layer in range(params['num_layers'])\n",
|
450 |
+
" }\n",
|
451 |
+
"print(\"cache key shape for layer 1 :\", beam_cache['layer_1']['k'].shape)\n",
|
452 |
+
"ids, _ = beam_search.sequence_beam_search(\n",
|
453 |
+
" symbols_to_logits_fn=_symbols_to_logits_fn(),\n",
|
454 |
+
" initial_ids=tf.constant([9], tf.int32),\n",
|
455 |
+
" initial_cache=beam_cache,\n",
|
456 |
+
" vocab_size=3,\n",
|
457 |
+
" beam_size=beam_size,\n",
|
458 |
+
" alpha=0.6,\n",
|
459 |
+
" max_decode_length=params['max_decode_length'],\n",
|
460 |
+
" eos_id=10,\n",
|
461 |
+
" padded_decode=False,\n",
|
462 |
+
" dtype=tf.float32)\n",
|
463 |
+
"print(\"Beam search ids:\", ids)"
|
464 |
+
]
|
465 |
+
}
|
466 |
+
],
|
467 |
+
"metadata": {
|
468 |
+
"accelerator": "GPU",
|
469 |
+
"colab": {
|
470 |
+
"collapsed_sections": [],
|
471 |
+
"name": "decoding_api_in_tf_nlp.ipynb",
|
472 |
+
"provenance": [],
|
473 |
+
"toc_visible": true
|
474 |
+
},
|
475 |
+
"kernelspec": {
|
476 |
+
"display_name": "Python 3",
|
477 |
+
"name": "python3"
|
478 |
+
}
|
479 |
+
},
|
480 |
+
"nbformat": 4,
|
481 |
+
"nbformat_minor": 0
|
482 |
+
}
|
Tensorflow/models/docs/nlp/fine_tune_bert.ipynb
ADDED
@@ -0,0 +1,1550 @@
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "vXLA5InzXydn"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"##### Copyright 2019 The TensorFlow Authors."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"cellView": "form",
|
17 |
+
"id": "RuRlpLL-X0R_"
|
18 |
+
},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
22 |
+
"# you may not use this file except in compliance with the License.\n",
|
23 |
+
"# You may obtain a copy of the License at\n",
|
24 |
+
"#\n",
|
25 |
+
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
26 |
+
"#\n",
|
27 |
+
"# Unless required by applicable law or agreed to in writing, software\n",
|
28 |
+
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
29 |
+
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
30 |
+
"# See the License for the specific language governing permissions and\n",
|
31 |
+
"# limitations under the License."
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "markdown",
|
36 |
+
"metadata": {
|
37 |
+
"id": "1mLJmVotXs64"
|
38 |
+
},
|
39 |
+
"source": [
|
40 |
+
"# Fine-tuning a BERT model"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"metadata": {
|
46 |
+
"id": "hYEwGTeCXnnX"
|
47 |
+
},
|
48 |
+
"source": [
|
49 |
+
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
|
50 |
+
" \u003ctd\u003e\n",
|
51 |
+
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/fine_tune_bert\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
|
52 |
+
" \u003c/td\u003e\n",
|
53 |
+
" \u003ctd\u003e\n",
|
54 |
+
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
|
55 |
+
" \u003c/td\u003e\n",
|
56 |
+
" \u003ctd\u003e\n",
|
57 |
+
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
|
58 |
+
" \u003c/td\u003e\n",
|
59 |
+
" \u003ctd\u003e\n",
|
60 |
+
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
|
61 |
+
" \u003c/td\u003e\n",
|
62 |
+
" \u003ctd\u003e\n",
|
63 |
+
" \u003ca href=\"https://tfhub.dev/google/collections/bert\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub model\u003c/a\u003e\n",
|
64 |
+
" \u003c/td\u003e\n",
|
65 |
+
"\u003c/table\u003e"
|
66 |
+
]
|
67 |
+
},
|
68 |
+
{
|
69 |
+
"cell_type": "markdown",
|
70 |
+
"metadata": {
|
71 |
+
"id": "YN2ACivEPxgD"
|
72 |
+
},
|
73 |
+
"source": [
|
74 |
+
"This tutorial demonstrates how to fine-tune a [Bidirectional Encoder Representations from Transformers (BERT)](https://arxiv.org/abs/1810.04805) (Devlin et al., 2018) model using [TensorFlow Model Garden](https://github.com/tensorflow/models).\n",
|
75 |
+
"\n",
|
76 |
+
"You can also find the pre-trained BERT model used in this tutorial on [TensorFlow Hub (TF Hub)](https://tensorflow.org/hub). For concrete examples of how to use the models from TF Hub, refer to the [Solve Glue tasks using BERT](https://www.tensorflow.org/text/tutorials/bert_glue) tutorial. If you're just trying to fine-tune a model, the TF Hub tutorial is a good starting point.\n",
|
77 |
+
"\n",
|
78 |
+
"On the other hand, if you're interested in deeper customization, follow this tutorial. It shows how to do a lot of things manually, so you can learn how you can customize the workflow from data preprocessing to training, exporting and saving the model."
|
79 |
+
]
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"cell_type": "markdown",
|
83 |
+
"metadata": {
|
84 |
+
"id": "s2d9S2CSSO1z"
|
85 |
+
},
|
86 |
+
"source": [
|
87 |
+
"## Setup"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "markdown",
|
92 |
+
"metadata": {
|
93 |
+
"id": "69de3375e32a"
|
94 |
+
},
|
95 |
+
"source": [
|
96 |
+
"### Install pip packages"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "markdown",
|
101 |
+
"metadata": {
|
102 |
+
"id": "fsACVQpVSifi"
|
103 |
+
},
|
104 |
+
"source": [
|
105 |
+
"Start by installing the TensorFlow Text and Model Garden pip packages.\n",
|
106 |
+
"\n",
|
107 |
+
"* `tf-models-official` is the TensorFlow Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` GitHub repo. To include the latest changes, you may install `tf-models-nightly`, which is the nightly Model Garden package created daily automatically.\n",
|
108 |
+
"* pip will install all models and dependencies automatically."
|
109 |
+
]
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"cell_type": "code",
|
113 |
+
"execution_count": null,
|
114 |
+
"metadata": {
|
115 |
+
"id": "sE6XUxLOf1s-"
|
116 |
+
},
|
117 |
+
"outputs": [],
|
118 |
+
"source": [
|
119 |
+
"!pip install -q opencv-python"
|
120 |
+
]
|
121 |
+
},
|
122 |
+
{
|
123 |
+
"cell_type": "code",
|
124 |
+
"execution_count": null,
|
125 |
+
"metadata": {
|
126 |
+
"id": "yic2y7_o-BCC"
|
127 |
+
},
|
128 |
+
"outputs": [],
|
129 |
+
"source": [
|
130 |
+
"!pip install -q -U \"tensorflow-text==2.11.*\""
|
131 |
+
]
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"cell_type": "code",
|
135 |
+
"execution_count": null,
|
136 |
+
"metadata": {
|
137 |
+
"id": "NvNr2svBM-p3"
|
138 |
+
},
|
139 |
+
"outputs": [],
|
140 |
+
"source": [
|
141 |
+
"!pip install -q tf-models-official"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"cell_type": "markdown",
|
146 |
+
"metadata": {
|
147 |
+
"id": "U-7qPCjWUAyy"
|
148 |
+
},
|
149 |
+
"source": [
|
150 |
+
"### Import libraries"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": null,
|
156 |
+
"metadata": {
|
157 |
+
"id": "lXsXev5MNr20"
|
158 |
+
},
|
159 |
+
"outputs": [],
|
160 |
+
"source": [
|
161 |
+
"import os\n",
|
162 |
+
"\n",
|
163 |
+
"import numpy as np\n",
|
164 |
+
"import matplotlib.pyplot as plt\n",
|
165 |
+
"\n",
|
166 |
+
"import tensorflow as tf\n",
|
167 |
+
"import tensorflow_models as tfm\n",
|
168 |
+
"import tensorflow_hub as hub\n",
|
169 |
+
"import tensorflow_datasets as tfds\n",
|
170 |
+
"tfds.disable_progress_bar()"
|
171 |
+
]
|
172 |
+
},
|
173 |
+
{
|
174 |
+
"cell_type": "markdown",
|
175 |
+
"metadata": {
|
176 |
+
"id": "mbanlzTvJBsz"
|
177 |
+
},
|
178 |
+
"source": [
|
179 |
+
"### Resources"
|
180 |
+
]
|
181 |
+
},
|
182 |
+
{
|
183 |
+
"cell_type": "markdown",
|
184 |
+
"metadata": {
|
185 |
+
"id": "PpW0x8TpR8DT"
|
186 |
+
},
|
187 |
+
"source": [
|
188 |
+
"The following directory contains the BERT model's configuration, vocabulary, and a pre-trained checkpoint used in this tutorial:"
|
189 |
+
]
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"cell_type": "code",
|
193 |
+
"execution_count": null,
|
194 |
+
"metadata": {
|
195 |
+
"id": "vzRHOLciR8eq"
|
196 |
+
},
|
197 |
+
"outputs": [],
|
198 |
+
"source": [
|
199 |
+
"gs_folder_bert = \"gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12\"\n",
|
200 |
+
"tf.io.gfile.listdir(gs_folder_bert)"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "markdown",
|
205 |
+
"metadata": {
|
206 |
+
"id": "Qv6abtRvH4xO"
|
207 |
+
},
|
208 |
+
"source": [
|
209 |
+
"## Load and preprocess the dataset\n",
|
210 |
+
"\n",
|
211 |
+
"This example uses the GLUE (General Language Understanding Evaluation) MRPC (Microsoft Research Paraphrase Corpus) [dataset from TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc).\n",
|
212 |
+
"\n",
|
213 |
+
"This dataset is not set up such that it can be directly fed into the BERT model. The following section handles the necessary preprocessing."
|
214 |
+
]
|
215 |
+
},
|
216 |
+
{
|
217 |
+
"cell_type": "markdown",
|
218 |
+
"metadata": {
|
219 |
+
"id": "28DvUhC1YUiB"
|
220 |
+
},
|
221 |
+
"source": [
|
222 |
+
"### Get the dataset from TensorFlow Datasets\n",
|
223 |
+
"\n",
|
224 |
+
"The GLUE MRPC (Dolan and Brockett, 2005) dataset is a corpus of sentence pairs automatically extracted from online news sources, with human annotations for whether the sentences in the pair are semantically equivalent. It has the following attributes:\n",
|
225 |
+
"\n",
|
226 |
+
"* Number of labels: 2\n",
|
227 |
+
"* Size of training dataset: 3668\n",
|
228 |
+
"* Size of evaluation dataset: 408\n",
|
229 |
+
"* Maximum sequence length of training and evaluation dataset: 128\n",
|
230 |
+
"\n",
|
231 |
+
"Begin by loading the MRPC dataset from TFDS:"
|
232 |
+
]
|
233 |
+
},
|
234 |
+
{
|
235 |
+
"cell_type": "code",
|
236 |
+
"execution_count": null,
|
237 |
+
"metadata": {
|
238 |
+
"id": "Ijikx5OsH9AT"
|
239 |
+
},
|
240 |
+
"outputs": [],
|
241 |
+
"source": [
|
242 |
+
"batch_size=32\n",
|
243 |
+
"glue, info = tfds.load('glue/mrpc',\n",
|
244 |
+
" with_info=True,\n",
|
245 |
+
" batch_size=32)"
|
246 |
+
]
|
247 |
+
},
|
248 |
+
{
|
249 |
+
"cell_type": "code",
|
250 |
+
"execution_count": null,
|
251 |
+
"metadata": {
|
252 |
+
"id": "QcMTJU4N7VX-"
|
253 |
+
},
|
254 |
+
"outputs": [],
|
255 |
+
"source": [
|
256 |
+
"glue"
|
257 |
+
]
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"cell_type": "markdown",
|
261 |
+
"metadata": {
|
262 |
+
"id": "ZgBg2r2nYT-K"
|
263 |
+
},
|
264 |
+
"source": [
|
265 |
+
"The `info` object describes the dataset and its features:"
|
266 |
+
]
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"cell_type": "code",
|
270 |
+
"execution_count": null,
|
271 |
+
"metadata": {
|
272 |
+
"id": "IQrHxv7W7jH5"
|
273 |
+
},
|
274 |
+
"outputs": [],
|
275 |
+
"source": [
|
276 |
+
"info.features"
|
277 |
+
]
|
278 |
+
},
|
279 |
+
{
|
280 |
+
"cell_type": "markdown",
|
281 |
+
"metadata": {
|
282 |
+
"id": "vhsVWYNxazz5"
|
283 |
+
},
|
284 |
+
"source": [
|
285 |
+
"The two classes are:"
|
286 |
+
]
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"cell_type": "code",
|
290 |
+
"execution_count": null,
|
291 |
+
"metadata": {
|
292 |
+
"id": "n0gfc_VTayfQ"
|
293 |
+
},
|
294 |
+
"outputs": [],
|
295 |
+
"source": [
|
296 |
+
"info.features['label'].names"
|
297 |
+
]
|
298 |
+
},
|
299 |
+
{
|
300 |
+
"cell_type": "markdown",
|
301 |
+
"metadata": {
|
302 |
+
"id": "38zJcap6xkbC"
|
303 |
+
},
|
304 |
+
"source": [
|
305 |
+
"Here is one example from the training set:"
|
306 |
+
]
|
307 |
+
},
|
308 |
+
{
|
309 |
+
"cell_type": "code",
|
310 |
+
"execution_count": null,
|
311 |
+
"metadata": {
|
312 |
+
"id": "xON_i6SkwApW"
|
313 |
+
},
|
314 |
+
"outputs": [],
|
315 |
+
"source": [
|
316 |
+
"example_batch = next(iter(glue['train']))\n",
|
317 |
+
"\n",
|
318 |
+
"for key, value in example_batch.items():\n",
|
319 |
+
" print(f\"{key:9s}: {value[0].numpy()}\")"
|
320 |
+
]
|
321 |
+
},
|
322 |
+
{
|
323 |
+
"cell_type": "markdown",
|
324 |
+
"metadata": {
|
325 |
+
"id": "R9vEWgKA4SxV"
|
326 |
+
},
|
327 |
+
"source": [
|
328 |
+
"### Preprocess the data\n",
|
329 |
+
"\n",
|
330 |
+
"The keys `\"sentence1\"` and `\"sentence2\"` in the GLUE MRPC dataset contain two input sentences for each example.\n",
|
331 |
+
"\n",
|
332 |
+
"Because the BERT model from the Model Garden doesn't take raw text as input, two things need to happen first:\n",
|
333 |
+
"\n",
|
334 |
+
"1. The text needs to be _tokenized_ (split into word pieces) and converted to _indices_.\n",
|
335 |
+
"2. Then, the _indices_ need to be packed into the format that the model expects."
|
336 |
+
]
|
337 |
+
},
|
338 |
+
{
|
339 |
+
"cell_type": "markdown",
|
340 |
+
"metadata": {
|
341 |
+
"id": "9fbTyfJpNr7x"
|
342 |
+
},
|
343 |
+
"source": [
|
344 |
+
"#### The BERT tokenizer"
|
345 |
+
]
|
346 |
+
},
|
347 |
+
{
|
348 |
+
"cell_type": "markdown",
|
349 |
+
"metadata": {
|
350 |
+
"id": "wqeN54S61ZKQ"
|
351 |
+
},
|
352 |
+
"source": [
|
353 |
+
"To fine tune a pre-trained language model from the Model Garden, such as BERT, you need to make sure that you're using exactly the same tokenization, vocabulary, and index mapping as used during training.\n",
|
354 |
+
"\n",
|
355 |
+
"The following code rebuilds the tokenizer that was used by the base model using the Model Garden's `tfm.nlp.layers.FastWordpieceBertTokenizer` layer:"
|
356 |
+
]
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"cell_type": "code",
|
360 |
+
"execution_count": null,
|
361 |
+
"metadata": {
|
362 |
+
"id": "-DK4q5wEBmlB"
|
363 |
+
},
|
364 |
+
"outputs": [],
|
365 |
+
"source": [
|
366 |
+
"tokenizer = tfm.nlp.layers.FastWordpieceBertTokenizer(\n",
|
367 |
+
" vocab_file=os.path.join(gs_folder_bert, \"vocab.txt\"),\n",
|
368 |
+
" lower_case=True)"
|
369 |
+
]
|
370 |
+
},
|
371 |
+
{
|
372 |
+
"cell_type": "markdown",
|
373 |
+
"metadata": {
|
374 |
+
"id": "zYHDSquU2lDU"
|
375 |
+
},
|
376 |
+
"source": [
|
377 |
+
"Let's tokenize a test sentence:"
|
378 |
+
]
|
379 |
+
},
|
380 |
+
{
|
381 |
+
"cell_type": "code",
|
382 |
+
"execution_count": null,
|
383 |
+
"metadata": {
|
384 |
+
"id": "L_OfOYPg853R"
|
385 |
+
},
|
386 |
+
"outputs": [],
|
387 |
+
"source": [
|
388 |
+
"tokens = tokenizer(tf.constant([\"Hello TensorFlow!\"]))\n",
|
389 |
+
"tokens"
|
390 |
+
]
|
391 |
+
},
|
392 |
+
{
|
393 |
+
"cell_type": "markdown",
|
394 |
+
"metadata": {
|
395 |
+
"id": "MfjaaMYy5Gt8"
|
396 |
+
},
|
397 |
+
"source": [
|
398 |
+
"Learn more about the tokenization process in the [Subword tokenization](https://www.tensorflow.org/text/guide/subwords_tokenizer) and [Tokenizing with TensorFlow Text](https://www.tensorflow.org/text/guide/tokenizers) guides."
|
399 |
+
]
|
400 |
+
},
|
401 |
+
{
|
402 |
+
"cell_type": "markdown",
|
403 |
+
"metadata": {
|
404 |
+
"id": "wd1b09OO5GJl"
|
405 |
+
},
|
406 |
+
"source": [
|
407 |
+
"#### Pack the inputs"
|
408 |
+
]
|
409 |
+
},
|
410 |
+
{
|
411 |
+
"cell_type": "markdown",
|
412 |
+
"metadata": {
|
413 |
+
"id": "62UTWLQd9-LB"
|
414 |
+
},
|
415 |
+
"source": [
|
416 |
+
"TensorFlow Model Garden's BERT model doesn't just take the tokenized strings as input. It also expects these to be packed into a particular format. `tfm.nlp.layers.BertPackInputs` layer can handle the conversion from _a list of tokenized sentences_ to the input format expected by the Model Garden's BERT model.\n",
|
417 |
+
"\n",
|
418 |
+
"`tfm.nlp.layers.BertPackInputs` packs the two input sentences (per example in the MRCP dataset) concatenated together. This input is expected to start with a `[CLS]` \"This is a classification problem\" token, and each sentence should end with a `[SEP]` \"Separator\" token.\n",
|
419 |
+
"\n",
|
420 |
+
"Therefore, the `tfm.nlp.layers.BertPackInputs` layer's constructor takes the `tokenizer`'s special tokens as an argument. It also needs to know the indices of the tokenizer's special tokens."
|
421 |
+
]
|
422 |
+
},
|
423 |
+
{
|
424 |
+
"cell_type": "code",
|
425 |
+
"execution_count": null,
|
426 |
+
"metadata": {
|
427 |
+
"id": "5iroDlrFDRcF"
|
428 |
+
},
|
429 |
+
"outputs": [],
|
430 |
+
"source": [
|
431 |
+
"special = tokenizer.get_special_tokens_dict()\n",
|
432 |
+
"special"
|
433 |
+
]
|
434 |
+
},
|
435 |
+
{
|
436 |
+
"cell_type": "code",
|
437 |
+
"execution_count": null,
|
438 |
+
"metadata": {
|
439 |
+
"id": "b71HarkuG92H"
|
440 |
+
},
|
441 |
+
"outputs": [],
|
442 |
+
"source": [
|
443 |
+
"max_seq_length = 128\n",
|
444 |
+
"\n",
|
445 |
+
"packer = tfm.nlp.layers.BertPackInputs(\n",
|
446 |
+
" seq_length=max_seq_length,\n",
|
447 |
+
" special_tokens_dict = tokenizer.get_special_tokens_dict())"
|
448 |
+
]
|
449 |
+
},
|
450 |
+
{
|
451 |
+
"cell_type": "markdown",
|
452 |
+
"metadata": {
|
453 |
+
"id": "CZlSZbYd6liN"
|
454 |
+
},
|
455 |
+
"source": [
|
456 |
+
"The `packer` takes a list of tokenized sentences as input. For example:"
|
457 |
+
]
|
458 |
+
},
|
459 |
+
{
|
460 |
+
"cell_type": "code",
|
461 |
+
"execution_count": null,
|
462 |
+
"metadata": {
|
463 |
+
"id": "27dU_VkJHc9S"
|
464 |
+
},
|
465 |
+
"outputs": [],
|
466 |
+
"source": [
|
467 |
+
"sentences1 = [\"hello tensorflow\"]\n",
|
468 |
+
"tok1 = tokenizer(sentences1)\n",
|
469 |
+
"tok1"
|
470 |
+
]
|
471 |
+
},
|
472 |
+
{
|
473 |
+
"cell_type": "code",
|
474 |
+
"execution_count": null,
|
475 |
+
"metadata": {
|
476 |
+
"id": "LURHmNOSHnWN"
|
477 |
+
},
|
478 |
+
"outputs": [],
|
479 |
+
"source": [
|
480 |
+
"sentences2 = [\"goodbye tensorflow\"]\n",
|
481 |
+
"tok2 = tokenizer(sentences2)\n",
|
482 |
+
"tok2"
|
483 |
+
]
|
484 |
+
},
|
485 |
+
{
|
486 |
+
"cell_type": "markdown",
|
487 |
+
"metadata": {
|
488 |
+
"id": "r8bvB8gI8BqP"
|
489 |
+
},
|
490 |
+
"source": [
|
491 |
+
"Then, it returns a dictionary containing three outputs:\n",
|
492 |
+
"\n",
|
493 |
+
"- `input_word_ids`: The tokenized sentences packed together.\n",
|
494 |
+
"- `input_mask`: The mask indicating which locations are valid in the other outputs.\n",
|
495 |
+
"- `input_type_ids`: Indicating which sentence each token belongs to."
|
496 |
+
]
|
497 |
+
},
|
498 |
+
{
|
499 |
+
"cell_type": "code",
|
500 |
+
"execution_count": null,
|
501 |
+
"metadata": {
|
502 |
+
"id": "YsIDTOMJHrUQ"
|
503 |
+
},
|
504 |
+
"outputs": [],
|
505 |
+
"source": [
|
506 |
+
"packed = packer([tok1, tok2])\n",
|
507 |
+
"\n",
|
508 |
+
"for key, tensor in packed.items():\n",
|
509 |
+
" print(f\"{key:15s}: {tensor[:, :12]}\")"
|
510 |
+
]
|
511 |
+
},
|
512 |
+
{
|
513 |
+
"cell_type": "markdown",
|
514 |
+
"metadata": {
|
515 |
+
"id": "red4tRcq74Qc"
|
516 |
+
},
|
517 |
+
"source": [
|
518 |
+
"#### Put it all together\n",
|
519 |
+
"\n",
|
520 |
+
"Combine these two parts into a `keras.layers.Layer` that can be attached to your model:"
|
521 |
+
]
|
522 |
+
},
|
523 |
+
{
|
524 |
+
"cell_type": "code",
|
525 |
+
"execution_count": null,
|
526 |
+
"metadata": {
|
527 |
+
"id": "9Qtz-tv-6nz6"
|
528 |
+
},
|
529 |
+
"outputs": [],
|
530 |
+
"source": [
|
531 |
+
"class BertInputProcessor(tf.keras.layers.Layer):\n",
|
532 |
+
" def __init__(self, tokenizer, packer):\n",
|
533 |
+
" super().__init__()\n",
|
534 |
+
" self.tokenizer = tokenizer\n",
|
535 |
+
" self.packer = packer\n",
|
536 |
+
"\n",
|
537 |
+
" def call(self, inputs):\n",
|
538 |
+
" tok1 = self.tokenizer(inputs['sentence1'])\n",
|
539 |
+
" tok2 = self.tokenizer(inputs['sentence2'])\n",
|
540 |
+
"\n",
|
541 |
+
" packed = self.packer([tok1, tok2])\n",
|
542 |
+
"\n",
|
543 |
+
" if 'label' in inputs:\n",
|
544 |
+
" return packed, inputs['label']\n",
|
545 |
+
" else:\n",
|
546 |
+
" return packed"
|
547 |
+
]
|
548 |
+
},
|
549 |
+
{
|
550 |
+
"cell_type": "markdown",
|
551 |
+
"metadata": {
|
552 |
+
"id": "rdy9wp499btU"
|
553 |
+
},
|
554 |
+
"source": [
|
555 |
+
"But for now just apply it to the dataset using `Dataset.map`, since the dataset you loaded from TFDS is a `tf.data.Dataset` object:"
|
556 |
+
]
|
557 |
+
},
|
558 |
+
{
|
559 |
+
"cell_type": "code",
|
560 |
+
"execution_count": null,
|
561 |
+
"metadata": {
|
562 |
+
"id": "qmyh76AL7VAs"
|
563 |
+
},
|
564 |
+
"outputs": [],
|
565 |
+
"source": [
|
566 |
+
"bert_inputs_processor = BertInputProcessor(tokenizer, packer)"
|
567 |
+
]
|
568 |
+
},
|
569 |
+
{
|
570 |
+
"cell_type": "code",
|
571 |
+
"execution_count": null,
|
572 |
+
"metadata": {
|
573 |
+
"id": "B8SSCtDe9MCk"
|
574 |
+
},
|
575 |
+
"outputs": [],
|
576 |
+
"source": [
|
577 |
+
"glue_train = glue['train'].map(bert_inputs_processor).prefetch(1)"
|
578 |
+
]
|
579 |
+
},
|
580 |
+
{
|
581 |
+
"cell_type": "markdown",
|
582 |
+
"metadata": {
|
583 |
+
"id": "KXpiDosO9rkY"
|
584 |
+
},
|
585 |
+
"source": [
|
586 |
+
"Here is an example batch from the processed dataset:"
|
587 |
+
]
|
588 |
+
},
|
589 |
+
{
|
590 |
+
"cell_type": "code",
|
591 |
+
"execution_count": null,
|
592 |
+
"metadata": {
|
593 |
+
"id": "ffNvDE6t9rP-"
|
594 |
+
},
|
595 |
+
"outputs": [],
|
596 |
+
"source": [
|
597 |
+
"example_inputs, example_labels = next(iter(glue_train))"
|
598 |
+
]
|
599 |
+
},
|
600 |
+
{
|
601 |
+
"cell_type": "code",
|
602 |
+
"execution_count": null,
|
603 |
+
"metadata": {
|
604 |
+
"id": "5sxtTuUi-bXt"
|
605 |
+
},
|
606 |
+
"outputs": [],
|
607 |
+
"source": [
|
608 |
+
"example_inputs"
|
609 |
+
]
|
610 |
+
},
|
611 |
+
{
|
612 |
+
"cell_type": "code",
|
613 |
+
"execution_count": null,
|
614 |
+
"metadata": {
|
615 |
+
"id": "wP4z_-9a-dFk"
|
616 |
+
},
|
617 |
+
"outputs": [],
|
618 |
+
"source": [
|
619 |
+
"example_labels"
|
620 |
+
]
|
621 |
+
},
|
622 |
+
{
|
623 |
+
"cell_type": "code",
|
624 |
+
"execution_count": null,
|
625 |
+
"metadata": {
|
626 |
+
"id": "jyjTdGpFhO_1"
|
627 |
+
},
|
628 |
+
"outputs": [],
|
629 |
+
"source": [
|
630 |
+
"for key, value in example_inputs.items():\n",
|
631 |
+
" print(f'{key:15s} shape: {value.shape}')\n",
|
632 |
+
"\n",
|
633 |
+
"print(f'{\"labels\":15s} shape: {example_labels.shape}')"
|
634 |
+
]
|
635 |
+
},
|
636 |
+
{
|
637 |
+
"cell_type": "markdown",
|
638 |
+
"metadata": {
|
639 |
+
"id": "mkGHN_FK-50U"
|
640 |
+
},
|
641 |
+
"source": [
|
642 |
+
"The `input_word_ids` contain the token IDs:"
|
643 |
+
]
|
644 |
+
},
|
645 |
+
{
|
646 |
+
"cell_type": "code",
|
647 |
+
"execution_count": null,
|
648 |
+
"metadata": {
|
649 |
+
"id": "eGL1_ktWLcgF"
|
650 |
+
},
|
651 |
+
"outputs": [],
|
652 |
+
"source": [
|
653 |
+
"plt.pcolormesh(example_inputs['input_word_ids'])"
|
654 |
+
]
|
655 |
+
},
|
656 |
+
{
|
657 |
+
"cell_type": "markdown",
|
658 |
+
"metadata": {
|
659 |
+
"id": "ulNZ4U96-8JZ"
|
660 |
+
},
|
661 |
+
"source": [
|
662 |
+
"The mask allows the model to cleanly differentiate between the content and the padding. The mask has the same shape as the `input_word_ids`, and contains a `1` anywhere the `input_word_ids` is not padding."
|
663 |
+
]
|
664 |
+
},
|
665 |
+
{
|
666 |
+
"cell_type": "code",
|
667 |
+
"execution_count": null,
|
668 |
+
"metadata": {
|
669 |
+
"id": "zB7mW7DGK3rW"
|
670 |
+
},
|
671 |
+
"outputs": [],
|
672 |
+
"source": [
|
673 |
+
"plt.pcolormesh(example_inputs['input_mask'])"
|
674 |
+
]
|
675 |
+
},
|
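+
{
|
+
"cell_type": "markdown",
|
+
"metadata": {},
|
+
"source": [
|
+
"You can verify this relationship yourself. The next cell is a minimal sanity check, assuming the padding token ID is `0` (true for this vocabulary): the mask should be `1` exactly where `input_word_ids` holds a real token."
|
+
]
|
+
},
|
+
{
|
+
"cell_type": "code",
|
+
"execution_count": null,
|
+
"metadata": {},
|
+
"outputs": [],
|
+
"source": [
|
+
"# A minimal check (assumes the pad token ID is 0 in this vocabulary).\n",
|
+
"derived_mask = tf.cast(example_inputs['input_word_ids'] != 0, tf.int32)\n",
|
+
"tf.reduce_all(derived_mask == example_inputs['input_mask'])"
|
+
]
|
+
},
|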
676 |
+
{
|
677 |
+
"cell_type": "markdown",
|
678 |
+
"metadata": {
|
679 |
+
"id": "rxLenwAvCkBf"
|
680 |
+
},
|
681 |
+
"source": [
|
682 |
+
"The \"input type\" also has the same shape, but inside the non-padded region, contains a `0` or a `1` indicating which sentence the token is a part of."
|
683 |
+
]
|
684 |
+
},
|
685 |
+
{
|
686 |
+
"cell_type": "code",
|
687 |
+
"execution_count": null,
|
688 |
+
"metadata": {
|
689 |
+
"id": "2CetH_5C9P2m"
|
690 |
+
},
|
691 |
+
"outputs": [],
|
692 |
+
"source": [
|
693 |
+
"plt.pcolormesh(example_inputs['input_type_ids'])"
|
694 |
+
]
|
695 |
+
},
|
696 |
+
{
|
697 |
+
"cell_type": "markdown",
|
698 |
+
"metadata": {
|
699 |
+
"id": "pxHHeyei_sb9"
|
700 |
+
},
|
701 |
+
"source": [
|
702 |
+
"Apply the same preprocessing to the validation and test subsets of the GLUE MRPC dataset:"
|
703 |
+
]
|
704 |
+
},
|
705 |
+
{
|
706 |
+
"cell_type": "code",
|
707 |
+
"execution_count": null,
|
708 |
+
"metadata": {
|
709 |
+
"id": "yuLKxf6zHxw-"
|
710 |
+
},
|
711 |
+
"outputs": [],
|
712 |
+
"source": [
|
713 |
+
"glue_validation = glue['validation'].map(bert_inputs_processor).prefetch(1)\n",
|
714 |
+
"glue_test = glue['test'].map(bert_inputs_processor).prefetch(1)"
|
715 |
+
]
|
716 |
+
},
|
717 |
+
{
|
718 |
+
"cell_type": "markdown",
|
719 |
+
"metadata": {
|
720 |
+
"id": "FSwymsbkbLDA"
|
721 |
+
},
|
722 |
+
"source": [
|
723 |
+
"## Build, train and export the model"
|
724 |
+
]
|
725 |
+
},
|
726 |
+
{
|
727 |
+
"cell_type": "markdown",
|
728 |
+
"metadata": {
|
729 |
+
"id": "bxxO3pJCEM9p"
|
730 |
+
},
|
731 |
+
"source": [
|
732 |
+
"Now that you have formatted the data as expected, you can start working on building and training the model."
|
733 |
+
]
|
734 |
+
},
|
735 |
+
{
|
736 |
+
"cell_type": "markdown",
|
737 |
+
"metadata": {
|
738 |
+
"id": "Efrj3Cn1kLAp"
|
739 |
+
},
|
740 |
+
"source": [
|
741 |
+
"### Build the model\n"
|
742 |
+
]
|
743 |
+
},
|
744 |
+
{
|
745 |
+
"cell_type": "markdown",
|
746 |
+
"metadata": {
|
747 |
+
"id": "xxpOY5r2Ayq6"
|
748 |
+
},
|
749 |
+
"source": [
|
750 |
+
"The first step is to download the configuration file—`config_dict`—for the pre-trained BERT model:\n"
|
751 |
+
]
|
752 |
+
},
|
753 |
+
{
|
754 |
+
"cell_type": "code",
|
755 |
+
"execution_count": null,
|
756 |
+
"metadata": {
|
757 |
+
"id": "v7ap0BONSJuz"
|
758 |
+
},
|
759 |
+
"outputs": [],
|
760 |
+
"source": [
|
761 |
+
"import json\n",
|
762 |
+
"\n",
|
763 |
+
"bert_config_file = os.path.join(gs_folder_bert, \"bert_config.json\")\n",
|
764 |
+
"config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())\n",
|
765 |
+
"config_dict"
|
766 |
+
]
|
767 |
+
},
|
768 |
+
{
|
769 |
+
"cell_type": "code",
|
770 |
+
"execution_count": null,
|
771 |
+
"metadata": {
|
772 |
+
"id": "pKaEaKJSX85J"
|
773 |
+
},
|
774 |
+
"outputs": [],
|
775 |
+
"source": [
|
776 |
+
"encoder_config = tfm.nlp.encoders.EncoderConfig({\n",
|
777 |
+
" 'type':'bert',\n",
|
778 |
+
" 'bert': config_dict\n",
|
779 |
+
"})"
|
780 |
+
]
|
781 |
+
},
|
782 |
+
{
|
783 |
+
"cell_type": "code",
|
784 |
+
"execution_count": null,
|
785 |
+
"metadata": {
|
786 |
+
"id": "LbgzWukNSqOS"
|
787 |
+
},
|
788 |
+
"outputs": [],
|
789 |
+
"source": [
|
790 |
+
"bert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)\n",
|
791 |
+
"bert_encoder"
|
792 |
+
]
|
793 |
+
},
|
794 |
+
{
|
795 |
+
"cell_type": "markdown",
|
796 |
+
"metadata": {
|
797 |
+
"id": "96ldxDSwkVkj"
|
798 |
+
},
|
799 |
+
"source": [
|
800 |
+
"The configuration file defines the core BERT model from the Model Garden, which is a Keras model that predicts the outputs of `num_classes` from the inputs with maximum sequence length `max_seq_length`."
|
801 |
+
]
|
802 |
+
},
|
803 |
+
{
|
804 |
+
"cell_type": "code",
|
805 |
+
"execution_count": null,
|
806 |
+
"metadata": {
|
807 |
+
"id": "cH682__U0FBv"
|
808 |
+
},
|
809 |
+
"outputs": [],
|
810 |
+
"source": [
|
811 |
+
"bert_classifier = tfm.nlp.models.BertClassifier(network=bert_encoder, num_classes=2)"
|
812 |
+
]
|
813 |
+
},
|
814 |
+
{
|
815 |
+
"cell_type": "markdown",
|
816 |
+
"metadata": {
|
817 |
+
"id": "sFmVG4SKZAw8"
|
818 |
+
},
|
819 |
+
"source": [
|
820 |
+
"Run it on a test batch of data 10 examples from the training set. The output is the logits for the two classes:"
|
821 |
+
]
|
822 |
+
},
|
823 |
+
{
|
824 |
+
"cell_type": "code",
|
825 |
+
"execution_count": null,
|
826 |
+
"metadata": {
|
827 |
+
"id": "VTjgPbp4ZDKo"
|
828 |
+
},
|
829 |
+
"outputs": [],
|
830 |
+
"source": [
|
831 |
+
"bert_classifier(\n",
|
832 |
+
" example_inputs, training=True).numpy()[:10]"
|
833 |
+
]
|
834 |
+
},
|
835 |
+
{
|
836 |
+
"cell_type": "markdown",
|
837 |
+
"metadata": {
|
838 |
+
"id": "Q0NTdwZsQK8n"
|
839 |
+
},
|
840 |
+
"source": [
|
841 |
+
"The `TransformerEncoder` in the center of the classifier above **is** the `bert_encoder`.\n",
|
842 |
+
"\n",
|
843 |
+
"If you inspect the encoder, notice the stack of `Transformer` layers connected to those same three inputs:"
|
844 |
+
]
|
845 |
+
},
|
846 |
+
{
|
847 |
+
"cell_type": "code",
|
848 |
+
"execution_count": null,
|
849 |
+
"metadata": {
|
850 |
+
"id": "8L__-erBwLIQ"
|
851 |
+
},
|
852 |
+
"outputs": [],
|
853 |
+
"source": [
|
854 |
+
"tf.keras.utils.plot_model(bert_encoder, show_shapes=True, dpi=48)"
|
855 |
+
]
|
856 |
+
},
|
857 |
+
{
|
858 |
+
"cell_type": "markdown",
|
859 |
+
"metadata": {
|
860 |
+
"id": "mKAvkQc3heSy"
|
861 |
+
},
|
862 |
+
"source": [
|
863 |
+
"### Restore the encoder weights\n",
|
864 |
+
"\n",
|
865 |
+
"When built, the encoder is randomly initialized. Restore the encoder's weights from the checkpoint:"
|
866 |
+
]
|
867 |
+
},
|
868 |
+
{
|
869 |
+
"cell_type": "code",
|
870 |
+
"execution_count": null,
|
871 |
+
"metadata": {
|
872 |
+
"id": "97Ll2Gichd_Y"
|
873 |
+
},
|
874 |
+
"outputs": [],
|
875 |
+
"source": [
|
876 |
+
"checkpoint = tf.train.Checkpoint(encoder=bert_encoder)\n",
|
877 |
+
"checkpoint.read(\n",
|
878 |
+
" os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
|
879 |
+
]
|
880 |
+
},
|
881 |
+
{
|
882 |
+
"cell_type": "markdown",
|
883 |
+
"metadata": {
|
884 |
+
"id": "2oHOql35k3Dd"
|
885 |
+
},
|
886 |
+
"source": [
|
887 |
+
"Note: The pre-trained `TransformerEncoder` is also available on [TensorFlow Hub](https://tensorflow.org/hub). Go to the [TF Hub appendix](#hub_bert) for details."
|
888 |
+
]
|
889 |
+
},
|
890 |
+
{
|
891 |
+
"cell_type": "markdown",
|
892 |
+
"metadata": {
|
893 |
+
"id": "115caFLMk-_l"
|
894 |
+
},
|
895 |
+
"source": [
|
896 |
+
"### Set up the optimizer\n",
|
897 |
+
"\n",
|
898 |
+
"BERT typically uses the Adam optimizer with weight decay—[AdamW](https://arxiv.org/abs/1711.05101) (`tf.keras.optimizers.experimental.AdamW`).\n",
|
899 |
+
"It also employs a learning rate schedule that first warms up from 0 and then decays to 0:"
|
900 |
+
]
|
901 |
+
},
|
902 |
+
{
|
903 |
+
"cell_type": "code",
|
904 |
+
"execution_count": null,
|
905 |
+
"metadata": {
|
906 |
+
"id": "c0jBycPDtkxR"
|
907 |
+
},
|
908 |
+
"outputs": [],
|
909 |
+
"source": [
|
910 |
+
"# Set up epochs and steps\n",
|
911 |
+
"epochs = 5\n",
|
912 |
+
"batch_size = 32\n",
|
913 |
+
"eval_batch_size = 32\n",
|
914 |
+
"\n",
|
915 |
+
"train_data_size = info.splits['train'].num_examples\n",
|
916 |
+
"steps_per_epoch = int(train_data_size / batch_size)\n",
|
917 |
+
"num_train_steps = steps_per_epoch * epochs\n",
|
918 |
+
"warmup_steps = int(0.1 * num_train_steps)\n",
|
919 |
+
"initial_learning_rate=2e-5"
|
920 |
+
]
|
921 |
+
},
|
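+
{
|
+
"cell_type": "markdown",
|
+
"metadata": {},
|
+
"source": [
|
+
"As a quick check of the schedule arithmetic: MRPC has 3,668 training examples, so a batch size of 32 works out to 114 steps per epoch, 570 training steps over 5 epochs, and 57 warmup steps."
|
+
]
|
+
},
|
+
{
|
+
"cell_type": "code",
|
+
"execution_count": null,
|
+
"metadata": {},
|
+
"outputs": [],
|
+
"source": [
|
+
"# 3668 // 32 = 114 steps/epoch; 114 * 5 = 570 train steps; int(0.1 * 570) = 57.\n",
|
+
"print(steps_per_epoch, num_train_steps, warmup_steps)"
|
+
]
|
+
},
|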
922 |
+
{
|
923 |
+
"cell_type": "markdown",
|
924 |
+
"metadata": {
|
925 |
+
"id": "GFankgHK0Rvh"
|
926 |
+
},
|
927 |
+
"source": [
|
928 |
+
"Linear decay from `initial_learning_rate` to zero over `num_train_steps`."
|
929 |
+
]
|
930 |
+
},
|
931 |
+
{
|
932 |
+
"cell_type": "code",
|
933 |
+
"execution_count": null,
|
934 |
+
"metadata": {
|
935 |
+
"id": "qWSyT8P2j4mV"
|
936 |
+
},
|
937 |
+
"outputs": [],
|
938 |
+
"source": [
|
939 |
+
"linear_decay = tf.keras.optimizers.schedules.PolynomialDecay(\n",
|
940 |
+
" initial_learning_rate=initial_learning_rate,\n",
|
941 |
+
" end_learning_rate=0,\n",
|
942 |
+
" decay_steps=num_train_steps)"
|
943 |
+
]
|
944 |
+
},
|
945 |
+
{
|
946 |
+
"cell_type": "markdown",
|
947 |
+
"metadata": {
|
948 |
+
"id": "anZPZPAP0Y3n"
|
949 |
+
},
|
950 |
+
"source": [
|
951 |
+
"Warmup to that value over `warmup_steps`:"
|
952 |
+
]
|
953 |
+
},
|
954 |
+
{
|
955 |
+
"cell_type": "code",
|
956 |
+
"execution_count": null,
|
957 |
+
"metadata": {
|
958 |
+
"id": "z_AsVCiRkoN1"
|
959 |
+
},
|
960 |
+
"outputs": [],
|
961 |
+
"source": [
|
962 |
+
"warmup_schedule = tfm.optimization.lr_schedule.LinearWarmup(\n",
|
963 |
+
" warmup_learning_rate = 0,\n",
|
964 |
+
" after_warmup_lr_sched = linear_decay,\n",
|
965 |
+
" warmup_steps = warmup_steps\n",
|
966 |
+
")"
|
967 |
+
]
|
968 |
+
},
|
969 |
+
{
|
970 |
+
"cell_type": "markdown",
|
971 |
+
"metadata": {
|
972 |
+
"id": "arfbaK6t0kH_"
|
973 |
+
},
|
974 |
+
"source": [
|
975 |
+
"The overall schedule looks like this:"
|
976 |
+
]
|
977 |
+
},
|
978 |
+
{
|
979 |
+
"cell_type": "code",
|
980 |
+
"execution_count": null,
|
981 |
+
"metadata": {
|
982 |
+
"id": "rYZGunhqbGUZ"
|
983 |
+
},
|
984 |
+
"outputs": [],
|
985 |
+
"source": [
|
986 |
+
"x = tf.linspace(0, num_train_steps, 1001)\n",
|
987 |
+
"y = [warmup_schedule(xi) for xi in x]\n",
|
988 |
+
"plt.plot(x,y)\n",
|
989 |
+
"plt.xlabel('Train step')\n",
|
990 |
+
"plt.ylabel('Learning rate')"
|
991 |
+
]
|
992 |
+
},
|
993 |
+
{
|
994 |
+
"cell_type": "markdown",
|
995 |
+
"metadata": {
|
996 |
+
"id": "bjsmG_fm0opn"
|
997 |
+
},
|
998 |
+
"source": [
|
999 |
+
"Use `tf.keras.optimizers.experimental.AdamW` to instantiate the optimizer with that schedule:"
|
1000 |
+
]
|
1001 |
+
},
|
1002 |
+
{
|
1003 |
+
"cell_type": "code",
|
1004 |
+
"execution_count": null,
|
1005 |
+
"metadata": {
|
1006 |
+
"id": "R8pTNuKIw1dA"
|
1007 |
+
},
|
1008 |
+
"outputs": [],
|
1009 |
+
"source": [
|
1010 |
+
"optimizer = tf.keras.optimizers.experimental.Adam(\n",
|
1011 |
+
" learning_rate = warmup_schedule)"
|
1012 |
+
]
|
1013 |
+
},
|
1014 |
+
{
|
1015 |
+
"cell_type": "markdown",
|
1016 |
+
"metadata": {
|
1017 |
+
"id": "78FEUOOEkoP0"
|
1018 |
+
},
|
1019 |
+
"source": [
|
1020 |
+
"### Train the model"
|
1021 |
+
]
|
1022 |
+
},
|
1023 |
+
{
|
1024 |
+
"cell_type": "markdown",
|
1025 |
+
"metadata": {
|
1026 |
+
"id": "OTNcA0O0nSq9"
|
1027 |
+
},
|
1028 |
+
"source": [
|
1029 |
+
"Set the metric as accuracy and the loss as sparse categorical cross-entropy. Then, compile and train the BERT classifier:"
|
1030 |
+
]
|
1031 |
+
},
|
1032 |
+
{
|
1033 |
+
"cell_type": "code",
|
1034 |
+
"execution_count": null,
|
1035 |
+
"metadata": {
|
1036 |
+
"id": "d5FeL0b6j7ky"
|
1037 |
+
},
|
1038 |
+
"outputs": [],
|
1039 |
+
"source": [
|
1040 |
+
"metrics = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy', dtype=tf.float32)]\n",
|
1041 |
+
"loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
|
1042 |
+
"\n",
|
1043 |
+
"bert_classifier.compile(\n",
|
1044 |
+
" optimizer=optimizer,\n",
|
1045 |
+
" loss=loss,\n",
|
1046 |
+
" metrics=metrics)"
|
1047 |
+
]
|
1048 |
+
},
|
1049 |
+
{
|
1050 |
+
"cell_type": "code",
|
1051 |
+
"execution_count": null,
|
1052 |
+
"metadata": {
|
1053 |
+
"id": "CsrylctIj_Xy"
|
1054 |
+
},
|
1055 |
+
"outputs": [],
|
1056 |
+
"source": [
|
1057 |
+
"bert_classifier.evaluate(glue_validation)"
|
1058 |
+
]
|
1059 |
+
},
|
1060 |
+
{
|
1061 |
+
"cell_type": "code",
|
1062 |
+
"execution_count": null,
|
1063 |
+
"metadata": {
|
1064 |
+
"id": "hgPPc2oNmcVZ"
|
1065 |
+
},
|
1066 |
+
"outputs": [],
|
1067 |
+
"source": [
|
1068 |
+
"bert_classifier.fit(\n",
|
1069 |
+
" glue_train,\n",
|
1070 |
+
" validation_data=(glue_validation),\n",
|
1071 |
+
" batch_size=32,\n",
|
1072 |
+
" epochs=epochs)"
|
1073 |
+
]
|
1074 |
+
},
|
1075 |
+
{
|
1076 |
+
"cell_type": "markdown",
|
1077 |
+
"metadata": {
|
1078 |
+
"id": "IFtKFWbNKb0u"
|
1079 |
+
},
|
1080 |
+
"source": [
|
1081 |
+
"Now run the fine-tuned model on a custom example to see that it works.\n",
|
1082 |
+
"\n",
|
1083 |
+
"Start by encoding some sentence pairs:"
|
1084 |
+
]
|
1085 |
+
},
|
1086 |
+
{
|
1087 |
+
"cell_type": "code",
|
1088 |
+
"execution_count": null,
|
1089 |
+
"metadata": {
|
1090 |
+
"id": "S1sdW6lLWaEi"
|
1091 |
+
},
|
1092 |
+
"outputs": [],
|
1093 |
+
"source": [
|
1094 |
+
"my_examples = {\n",
|
1095 |
+
" 'sentence1':[\n",
|
1096 |
+
" 'The rain in Spain falls mainly on the plain.',\n",
|
1097 |
+
" 'Look I fine tuned BERT.'],\n",
|
1098 |
+
" 'sentence2':[\n",
|
1099 |
+
" 'It mostly rains on the flat lands of Spain.',\n",
|
1100 |
+
" 'Is it working? This does not match.']\n",
|
1101 |
+
" }"
|
1102 |
+
]
|
1103 |
+
},
|
1104 |
+
{
|
1105 |
+
"cell_type": "markdown",
|
1106 |
+
"metadata": {
|
1107 |
+
"id": "7ynJibkBRTJF"
|
1108 |
+
},
|
1109 |
+
"source": [
|
1110 |
+
"The model should report class `1` \"match\" for the first example and class `0` \"no-match\" for the second:"
|
1111 |
+
]
|
1112 |
+
},
|
1113 |
+
{
|
1114 |
+
"cell_type": "code",
|
1115 |
+
"execution_count": null,
|
1116 |
+
"metadata": {
|
1117 |
+
"id": "umo0ttrgRYIM"
|
1118 |
+
},
|
1119 |
+
"outputs": [],
|
1120 |
+
"source": [
|
1121 |
+
"ex_packed = bert_inputs_processor(my_examples)\n",
|
1122 |
+
"my_logits = bert_classifier(ex_packed, training=False)\n",
|
1123 |
+
"\n",
|
1124 |
+
"result_cls_ids = tf.argmax(my_logits)\n",
|
1125 |
+
"result_cls_ids"
|
1126 |
+
]
|
1127 |
+
},
|
1128 |
+
{
|
1129 |
+
"cell_type": "code",
|
1130 |
+
"execution_count": null,
|
1131 |
+
"metadata": {
|
1132 |
+
"id": "HNdmOEHKT7e8"
|
1133 |
+
},
|
1134 |
+
"outputs": [],
|
1135 |
+
"source": [
|
1136 |
+
"tf.gather(tf.constant(info.features['label'].names), result_cls_ids)"
|
1137 |
+
]
|
1138 |
+
},
|
1139 |
+
{
|
1140 |
+
"cell_type": "markdown",
|
1141 |
+
"metadata": {
|
1142 |
+
"id": "fVo_AnT0l26j"
|
1143 |
+
},
|
1144 |
+
"source": [
|
1145 |
+
"### Export the model\n",
|
1146 |
+
"\n",
|
1147 |
+
"Often the goal of training a model is to _use_ it for something outside of the Python process that created it. You can do this by exporting the model using `tf.saved_model`. (Learn more in the [Using the SavedModel format](https://www.tensorflow.org/guide/saved_model) guide and the [Save and load a model using a distribution strategy](https://www.tensorflow.org/tutorials/distribute/save_and_load) tutorial.)\n",
|
1148 |
+
"\n",
|
1149 |
+
"First, build a wrapper class to export the model. This wrapper does two things:\n",
|
1150 |
+
"\n",
|
1151 |
+
"- First, it packages `bert_inputs_processor` and `bert_classifier` together into a single `tf.Module`, so you can export all the functionalities.\n",
|
1152 |
+
"- Second, it defines a `tf.function` that implements the end-to-end execution of the model.\n",
|
1153 |
+
"\n",
|
1154 |
+
"Setting the `input_signature` argument of `tf.function` lets you define a fixed signature for the `tf.function`. This can be less surprising than the default automatic retracing behavior."
|
1155 |
+
]
|
1156 |
+
},
|
1157 |
+
{
|
1158 |
+
"cell_type": "code",
|
1159 |
+
"execution_count": null,
|
1160 |
+
"metadata": {
|
1161 |
+
"id": "78h83mlt9wpY"
|
1162 |
+
},
|
1163 |
+
"outputs": [],
|
1164 |
+
"source": [
|
1165 |
+
"class ExportModel(tf.Module):\n",
|
1166 |
+
" def __init__(self, input_processor, classifier):\n",
|
1167 |
+
" self.input_processor = input_processor\n",
|
1168 |
+
" self.classifier = classifier\n",
|
1169 |
+
"\n",
|
1170 |
+
" @tf.function(input_signature=[{\n",
|
1171 |
+
" 'sentence1': tf.TensorSpec(shape=[None], dtype=tf.string),\n",
|
1172 |
+
" 'sentence2': tf.TensorSpec(shape=[None], dtype=tf.string)}])\n",
|
1173 |
+
" def __call__(self, inputs):\n",
|
1174 |
+
" packed = self.input_processor(inputs)\n",
|
1175 |
+
" logits = self.classifier(packed, training=False)\n",
|
1176 |
+
" result_cls_ids = tf.argmax(logits)\n",
|
1177 |
+
" return {\n",
|
1178 |
+
" 'logits': logits,\n",
|
1179 |
+
" 'class_id': result_cls_ids,\n",
|
1180 |
+
" 'class': tf.gather(\n",
|
1181 |
+
" tf.constant(info.features['label'].names),\n",
|
1182 |
+
" result_cls_ids)\n",
|
1183 |
+
" }"
|
1184 |
+
]
|
1185 |
+
},
|
1186 |
+
{
|
1187 |
+
"cell_type": "markdown",
|
1188 |
+
"metadata": {
|
1189 |
+
"id": "qnxysGUfIgFQ"
|
1190 |
+
},
|
1191 |
+
"source": [
|
1192 |
+
"Create an instance of this exported model and save it:"
|
1193 |
+
]
|
1194 |
+
},
|
1195 |
+
{
|
1196 |
+
"cell_type": "code",
|
1197 |
+
"execution_count": null,
|
1198 |
+
"metadata": {
|
1199 |
+
"id": "TmHW9DEFUZ0X"
|
1200 |
+
},
|
1201 |
+
"outputs": [],
|
1202 |
+
"source": [
|
1203 |
+
"export_model = ExportModel(bert_inputs_processor, bert_classifier)"
|
1204 |
+
]
|
1205 |
+
},
|
1206 |
+
{
|
1207 |
+
"cell_type": "code",
|
1208 |
+
"execution_count": null,
|
1209 |
+
"metadata": {
|
1210 |
+
"id": "Nl5x6nElZqkP"
|
1211 |
+
},
|
1212 |
+
"outputs": [],
|
1213 |
+
"source": [
|
1214 |
+
"import tempfile\n",
|
1215 |
+
"export_dir=tempfile.mkdtemp(suffix='_saved_model')\n",
|
1216 |
+
"tf.saved_model.save(export_model, export_dir=export_dir,\n",
|
1217 |
+
" signatures={'serving_default': export_model.__call__})"
|
1218 |
+
]
|
1219 |
+
},
|
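+
{
|
+
"cell_type": "markdown",
|
+
"metadata": {},
|
+
"source": [
|
+
"The saved `serving_default` signature can also be called directly. The following is a minimal sketch: because the `tf.function` takes a dictionary of tensors, the exported signature exposes the dictionary keys as keyword arguments."
|
+
]
|
+
},
|
+
{
|
+
"cell_type": "code",
|
+
"execution_count": null,
|
+
"metadata": {},
|
+
"outputs": [],
|
+
"source": [
|
+
"# A minimal sketch of calling the exported signature directly.\n",
|
+
"imported = tf.saved_model.load(export_dir)\n",
|
+
"serving_fn = imported.signatures['serving_default']\n",
|
+
"serving_fn(sentence1=tf.constant(['The rain in Spain falls mainly on the plain.']),\n",
|
+
"           sentence2=tf.constant(['It mostly rains on the flat lands of Spain.']))"
|
+
]
|
+
},
|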
1220 |
+
{
|
1221 |
+
"cell_type": "markdown",
|
1222 |
+
"metadata": {
|
1223 |
+
"id": "Pd8B5dy-ImDJ"
|
1224 |
+
},
|
1225 |
+
"source": [
|
1226 |
+
"Reload the model and compare the results to the original:"
|
1227 |
+
]
|
1228 |
+
},
|
1229 |
+
{
|
1230 |
+
"cell_type": "code",
|
1231 |
+
"execution_count": null,
|
1232 |
+
"metadata": {
|
1233 |
+
"id": "9cAhHySVXHD5"
|
1234 |
+
},
|
1235 |
+
"outputs": [],
|
1236 |
+
"source": [
|
1237 |
+
"original_logits = export_model(my_examples)['logits']"
|
1238 |
+
]
|
1239 |
+
},
|
1240 |
+
{
|
1241 |
+
"cell_type": "code",
|
1242 |
+
"execution_count": null,
|
1243 |
+
"metadata": {
|
1244 |
+
"id": "H9cAcYwfW2fy"
|
1245 |
+
},
|
1246 |
+
"outputs": [],
|
1247 |
+
"source": [
|
1248 |
+
"reloaded = tf.saved_model.load(export_dir)\n",
|
1249 |
+
"reloaded_logits = reloaded(my_examples)['logits']"
|
1250 |
+
]
|
1251 |
+
},
|
1252 |
+
{
|
1253 |
+
"cell_type": "code",
|
1254 |
+
"execution_count": null,
|
1255 |
+
"metadata": {
|
1256 |
+
"id": "y_ACvKPsVUXC"
|
1257 |
+
},
|
1258 |
+
"outputs": [],
|
1259 |
+
"source": [
|
1260 |
+
"# The results are identical:\n",
|
1261 |
+
"print(original_logits.numpy())\n",
|
1262 |
+
"print()\n",
|
1263 |
+
"print(reloaded_logits.numpy())"
|
1264 |
+
]
|
1265 |
+
},
|
1266 |
+
{
|
1267 |
+
"cell_type": "code",
|
1268 |
+
"execution_count": null,
|
1269 |
+
"metadata": {
|
1270 |
+
"id": "lBlPP20dXPFR"
|
1271 |
+
},
|
1272 |
+
"outputs": [],
|
1273 |
+
"source": [
|
1274 |
+
"print(np.mean(abs(original_logits - reloaded_logits)))"
|
1275 |
+
]
|
1276 |
+
},
|
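+
{
|
+
"cell_type": "markdown",
|
+
"metadata": {},
|
+
"source": [
|
+
"As an optional, stricter check, assert that the reloaded model reproduces the original logits to within floating-point tolerance (the tolerance value here is an assumption, not a requirement):"
|
+
]
|
+
},
|
+
{
|
+
"cell_type": "code",
|
+
"execution_count": null,
|
+
"metadata": {},
|
+
"outputs": [],
|
+
"source": [
|
+
"# The tolerance is an assumed value; the logits are typically identical.\n",
|
+
"np.testing.assert_allclose(original_logits, reloaded_logits, atol=1e-5)"
|
+
]
|
+
},
|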
1277 |
+
{
|
1278 |
+
"cell_type": "markdown",
|
1279 |
+
"metadata": {
|
1280 |
+
"id": "CPsg7dZwfBM2"
|
1281 |
+
},
|
1282 |
+
"source": [
|
1283 |
+
"Congratulations! You've used `tensorflow_models` to build a BERT-classifier, train it, and export it for later use."
|
1284 |
+
]
|
1285 |
+
},
|
1286 |
+
{
|
1287 |
+
"cell_type": "markdown",
|
1288 |
+
"metadata": {
|
1289 |
+
"id": "eQceYqRFT_Eg"
|
1290 |
+
},
|
1291 |
+
"source": [
|
1292 |
+
"## Optional: BERT on TF Hub"
|
1293 |
+
]
|
1294 |
+
},
|
1295 |
+
{
|
1296 |
+
"cell_type": "markdown",
|
1297 |
+
"metadata": {
|
1298 |
+
"id": "QbklKt-w_CiI"
|
1299 |
+
},
|
1300 |
+
"source": [
|
1301 |
+
"\u003ca id=\"hub_bert\"\u003e\u003c/a\u003e\n",
|
1302 |
+
"\n",
|
1303 |
+
"\n",
|
1304 |
+
"You can get the BERT model off the shelf from [TF Hub](https://tfhub.dev/). There are [many versions available along with their input preprocessors](https://tfhub.dev/google/collections/bert/1).\n",
|
1305 |
+
"\n",
|
1306 |
+
"This example uses [a small version of BERT from TF Hub](https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2) that was pre-trained using the English Wikipedia and BooksCorpus datasets, similar to the [original implementation](https://arxiv.org/abs/1908.08962) (Turc et al., 2019).\n",
|
1307 |
+
"\n",
|
1308 |
+
"Start by importing TF Hub:"
|
1309 |
+
]
|
1310 |
+
},
|
1311 |
+
{
|
1312 |
+
"cell_type": "code",
|
1313 |
+
"execution_count": null,
|
1314 |
+
"metadata": {
|
1315 |
+
"id": "GDWrHm0BGpbX"
|
1316 |
+
},
|
1317 |
+
"outputs": [],
|
1318 |
+
"source": [
|
1319 |
+
"import tensorflow_hub as hub"
|
1320 |
+
]
|
1321 |
+
},
|
1322 |
+
{
|
1323 |
+
"cell_type": "markdown",
|
1324 |
+
"metadata": {
|
1325 |
+
"id": "f02f38f83ac4"
|
1326 |
+
},
|
1327 |
+
"source": [
|
1328 |
+
"Select the input preprocessor and the model from TF Hub and wrap them as `hub.KerasLayer` layers:"
|
1329 |
+
]
|
1330 |
+
},
|
1331 |
+
{
|
1332 |
+
"cell_type": "code",
|
1333 |
+
"execution_count": null,
|
1334 |
+
"metadata": {
|
1335 |
+
"id": "lo6479At4sP1"
|
1336 |
+
},
|
1337 |
+
"outputs": [],
|
1338 |
+
"source": [
|
1339 |
+
"# Always make sure you use the right preprocessor.\n",
|
1340 |
+
"hub_preprocessor = hub.KerasLayer(\n",
|
1341 |
+
" \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\")\n",
|
1342 |
+
"\n",
|
1343 |
+
"# This is a really small BERT.\n",
|
1344 |
+
"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2\",\n",
|
1345 |
+
" trainable=True)\n",
|
1346 |
+
"\n",
|
1347 |
+
"print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
|
1348 |
+
]
|
1349 |
+
},
|
1350 |
+
{
|
1351 |
+
"cell_type": "markdown",
|
1352 |
+
"metadata": {
|
1353 |
+
"id": "iTzF574wivQv"
|
1354 |
+
},
|
1355 |
+
"source": [
|
1356 |
+
"Test run the preprocessor on a batch of data:"
|
1357 |
+
]
|
1358 |
+
},
|
1359 |
+
{
|
1360 |
+
"cell_type": "code",
|
1361 |
+
"execution_count": null,
|
1362 |
+
"metadata": {
|
1363 |
+
"id": "GOASSKR5R3-N"
|
1364 |
+
},
|
1365 |
+
"outputs": [],
|
1366 |
+
"source": [
|
1367 |
+
"hub_inputs = hub_preprocessor(['Hello TensorFlow!'])\n",
|
1368 |
+
"{key: value[0, :10].numpy() for key, value in hub_inputs.items()} "
|
1369 |
+
]
|
1370 |
+
},
|
1371 |
+
{
|
1372 |
+
"cell_type": "code",
|
1373 |
+
"execution_count": null,
|
1374 |
+
"metadata": {
|
1375 |
+
"id": "XEcYrCR45Uwo"
|
1376 |
+
},
|
1377 |
+
"outputs": [],
|
1378 |
+
"source": [
|
1379 |
+
"result = hub_encoder(\n",
|
1380 |
+
" inputs=hub_inputs,\n",
|
1381 |
+
" training=False,\n",
|
1382 |
+
")\n",
|
1383 |
+
"\n",
|
1384 |
+
"print(\"Pooled output shape:\", result['pooled_output'].shape)\n",
|
1385 |
+
"print(\"Sequence output shape:\", result['sequence_output'].shape)"
|
1386 |
+
]
|
1387 |
+
},
|
1388 |
+
{
|
1389 |
+
"cell_type": "markdown",
|
1390 |
+
"metadata": {
|
1391 |
+
"id": "cjojn8SmLSRI"
|
1392 |
+
},
|
1393 |
+
"source": [
|
1394 |
+
"At this point, it would be simple to add a classification head yourself.\n",
|
1395 |
+
"\n",
|
1396 |
+
"The Model Garden `tfm.nlp.models.BertClassifier` class can also build a classifier onto the TF Hub encoder:"
|
1397 |
+
]
|
1398 |
+
},
|
1399 |
+
{
|
1400 |
+
"cell_type": "code",
|
1401 |
+
"execution_count": null,
|
1402 |
+
"metadata": {
|
1403 |
+
"id": "9nTDaApyLR70"
|
1404 |
+
},
|
1405 |
+
"outputs": [],
|
1406 |
+
"source": [
|
1407 |
+
"hub_classifier = tfm.nlp.models.BertClassifier(\n",
|
1408 |
+
" bert_encoder,\n",
|
1409 |
+
" num_classes=2,\n",
|
1410 |
+
" dropout_rate=0.1,\n",
|
1411 |
+
" initializer=tf.keras.initializers.TruncatedNormal(\n",
|
1412 |
+
" stddev=0.02))"
|
1413 |
+
]
|
1414 |
+
},
|
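+
{
|
+
"cell_type": "markdown",
|
+
"metadata": {},
|
+
"source": [
|
+
"For comparison, here is a minimal sketch of the \"add a classification head yourself\" option mentioned above: chain the TF Hub preprocessor, the TF Hub encoder, and a `Dense` head into a Keras model. The head size and the lack of dropout are simplifying assumptions for illustration."
|
+
]
|
+
},
|
+
{
|
+
"cell_type": "code",
|
+
"execution_count": null,
|
+
"metadata": {},
|
+
"outputs": [],
|
+
"source": [
|
+
"# A minimal sketch of a manual classification head (sizes are assumptions).\n",
|
+
"text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)\n",
|
+
"encoder_inputs = hub_preprocessor(text_input)\n",
|
+
"encoder_outputs = hub_encoder(encoder_inputs)\n",
|
+
"logits = tf.keras.layers.Dense(2)(encoder_outputs['pooled_output'])\n",
|
+
"manual_classifier = tf.keras.Model(text_input, logits)\n",
|
+
"manual_classifier(tf.constant(['Hello TensorFlow!']))"
|
+
]
|
+
},
|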
1415 |
+
{
|
1416 |
+
"cell_type": "markdown",
|
1417 |
+
"metadata": {
|
1418 |
+
"id": "xMJX3wV0_v7I"
|
1419 |
+
},
|
1420 |
+
"source": [
|
1421 |
+
"The one downside to loading this model from TF Hub is that the structure of internal Keras layers is not restored. This makes it more difficult to inspect or modify the model.\n",
|
1422 |
+
"\n",
|
1423 |
+
"The BERT encoder model—`hub_classifier`—is now a single layer."
|
1424 |
+
]
|
1425 |
+
},
|
1426 |
+
{
|
1427 |
+
"cell_type": "markdown",
|
1428 |
+
"metadata": {
|
1429 |
+
"id": "u_IqwXjRV1vd"
|
1430 |
+
},
|
1431 |
+
"source": [
|
1432 |
+
"For concrete examples of this approach, refer to [Solve Glue tasks using the BERT](https://www.tensorflow.org/text/tutorials/bert_glue)."
|
1433 |
+
]
|
1434 |
+
},
|
1435 |
+
{
|
1436 |
+
"cell_type": "markdown",
|
1437 |
+
"metadata": {
|
1438 |
+
"id": "ji3tdLz101km"
|
1439 |
+
},
|
1440 |
+
"source": [
|
1441 |
+
"## Optional: Optimizer `config`s\n",
|
1442 |
+
"\n",
|
1443 |
+
"The `tensorflow_models` package defines serializable `config` classes that describe how to build the live objects. Earlier in this tutorial, you built the optimizer manually.\n",
|
1444 |
+
"\n",
|
1445 |
+
"The configuration below describes an (almost) identical optimizer built by the `optimizer_factory.OptimizerFactory`:"
|
1446 |
+
]
|
1447 |
+
},
|
1448 |
+
{
|
1449 |
+
"cell_type": "code",
|
1450 |
+
"execution_count": null,
|
1451 |
+
"metadata": {
|
1452 |
+
"id": "Fdb9C1ontnH_"
|
1453 |
+
},
|
1454 |
+
"outputs": [],
|
1455 |
+
"source": [
|
1456 |
+
"optimization_config = tfm.optimization.OptimizationConfig(\n",
|
1457 |
+
" optimizer=tfm.optimization.OptimizerConfig(\n",
|
1458 |
+
" type = \"adam\"),\n",
|
1459 |
+
" learning_rate = tfm.optimization.LrConfig(\n",
|
1460 |
+
" type='polynomial',\n",
|
1461 |
+
" polynomial=tfm.optimization.PolynomialLrConfig(\n",
|
1462 |
+
" initial_learning_rate=2e-5,\n",
|
1463 |
+
" end_learning_rate=0.0,\n",
|
1464 |
+
" decay_steps=num_train_steps)),\n",
|
1465 |
+
" warmup = tfm.optimization.WarmupConfig(\n",
|
1466 |
+
" type='linear',\n",
|
1467 |
+
" linear=tfm.optimization.LinearWarmupConfig(warmup_steps=warmup_steps)\n",
|
1468 |
+
" ))\n",
|
1469 |
+
"\n",
|
1470 |
+
"\n",
|
1471 |
+
"fac = tfm.optimization.optimizer_factory.OptimizerFactory(optimization_config)\n",
|
1472 |
+
"lr = fac.build_learning_rate()\n",
|
1473 |
+
"optimizer = fac.build_optimizer(lr=lr)"
|
1474 |
+
]
|
1475 |
+
},
|
1476 |
+
{
|
1477 |
+
"cell_type": "code",
|
1478 |
+
"execution_count": null,
|
1479 |
+
"metadata": {
|
1480 |
+
"id": "Rp7R1hBfv5HG"
|
1481 |
+
},
|
1482 |
+
"outputs": [],
|
1483 |
+
"source": [
|
1484 |
+
"x = tf.linspace(0, num_train_steps, 1001).numpy()\n",
|
1485 |
+
"y = [lr(xi) for xi in x]\n",
|
1486 |
+
"plt.plot(x,y)\n",
|
1487 |
+
"plt.xlabel('Train step')\n",
|
1488 |
+
"plt.ylabel('Learning rate')"
|
1489 |
+
]
|
1490 |
+
},
|
1491 |
+
{
|
1492 |
+
"cell_type": "markdown",
|
1493 |
+
"metadata": {
|
1494 |
+
"id": "ywn5miD_dnuh"
|
1495 |
+
},
|
1496 |
+
"source": [
|
1497 |
+
"The advantage of using `config` objects is that they don't contain any complicated TensorFlow objects, and can be easily serialized to JSON, and rebuilt. Here's the JSON for the above `tfm.optimization.OptimizationConfig`:"
|
1498 |
+
]
|
1499 |
+
},
|
1500 |
+
{
|
1501 |
+
"cell_type": "code",
|
1502 |
+
"execution_count": null,
|
1503 |
+
"metadata": {
|
1504 |
+
"id": "zo5RV5lud81Y"
|
1505 |
+
},
|
1506 |
+
"outputs": [],
|
1507 |
+
"source": [
|
1508 |
+
"optimization_config = optimization_config.as_dict()\n",
|
1509 |
+
"optimization_config"
|
1510 |
+
]
|
1511 |
+
},
|
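+
{
|
+
"cell_type": "markdown",
|
+
"metadata": {},
|
+
"source": [
|
+
"Because this is now a plain dictionary, it survives a JSON round trip. A minimal sketch, reusing the `json` module imported earlier in this notebook:"
|
+
]
|
+
},
|
+
{
|
+
"cell_type": "code",
|
+
"execution_count": null,
|
+
"metadata": {},
|
+
"outputs": [],
|
+
"source": [
|
+
"config_json = json.dumps(optimization_config)  # serialize to a JSON string\n",
|
+
"optimization_config = json.loads(config_json)  # ...and rebuild the dict"
|
+
]
|
+
},
|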
1512 |
+
{
|
1513 |
+
"cell_type": "markdown",
|
1514 |
+
"metadata": {
|
1515 |
+
"id": "Z6qPXPEhekkd"
|
1516 |
+
},
|
1517 |
+
"source": [
|
1518 |
+
"The `tfm.optimization.optimizer_factory.OptimizerFactory` can just as easily build the optimizer from the JSON dictionary:"
|
1519 |
+
]
|
1520 |
+
},
|
1521 |
+
{
|
1522 |
+
"cell_type": "code",
|
1523 |
+
"execution_count": null,
|
1524 |
+
"metadata": {
|
1525 |
+
"id": "p-bYrvfMYsxp"
|
1526 |
+
},
|
1527 |
+
"outputs": [],
|
1528 |
+
"source": [
|
1529 |
+
"fac = tfm.optimization.optimizer_factory.OptimizerFactory(\n",
|
1530 |
+
" tfm.optimization.OptimizationConfig(optimization_config))\n",
|
1531 |
+
"lr = fac.build_learning_rate()\n",
|
1532 |
+
"optimizer = fac.build_optimizer(lr=lr)"
|
1533 |
+
]
|
1534 |
+
}
|
1535 |
+
],
|
1536 |
+
"metadata": {
|
1537 |
+
"accelerator": "GPU",
|
1538 |
+
"colab": {
|
1539 |
+
"name": "fine_tune_bert.ipynb",
|
1540 |
+
"private_outputs": true,
|
1541 |
+
"toc_visible": true
|
1542 |
+
},
|
1543 |
+
"kernelspec": {
|
1544 |
+
"display_name": "Python 3",
|
1545 |
+
"name": "python3"
|
1546 |
+
}
|
1547 |
+
},
|
1548 |
+
"nbformat": 4,
|
1549 |
+
"nbformat_minor": 0
|
1550 |
+
}
|
Tensorflow/models/docs/nlp/index.ipynb
ADDED
@@ -0,0 +1,545 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "80xnUmoI7fBX"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"##### Copyright 2020 The TensorFlow Authors."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"cellView": "form",
|
17 |
+
"id": "8nvTnfs6Q692"
|
18 |
+
},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
22 |
+
"# you may not use this file except in compliance with the License.\n",
|
23 |
+
"# You may obtain a copy of the License at\n",
|
24 |
+
"#\n",
|
25 |
+
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
26 |
+
"#\n",
|
27 |
+
"# Unless required by applicable law or agreed to in writing, software\n",
|
28 |
+
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
29 |
+
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
30 |
+
"# See the License for the specific language governing permissions and\n",
|
31 |
+
"# limitations under the License."
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "markdown",
|
36 |
+
"metadata": {
|
37 |
+
"id": "WmfcMK5P5C1G"
|
38 |
+
},
|
39 |
+
"source": [
|
40 |
+
"# Introduction to the TensorFlow Models NLP library"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"metadata": {
|
46 |
+
"id": "cH-oJ8R6AHMK"
|
47 |
+
},
|
48 |
+
"source": [
|
49 |
+
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
|
50 |
+
" \u003ctd\u003e\n",
|
51 |
+
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
|
52 |
+
" \u003c/td\u003e\n",
|
53 |
+
" \u003ctd\u003e\n",
|
54 |
+
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
|
55 |
+
" \u003c/td\u003e\n",
|
56 |
+
" \u003ctd\u003e\n",
|
57 |
+
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
|
58 |
+
" \u003c/td\u003e\n",
|
59 |
+
" \u003ctd\u003e\n",
|
60 |
+
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
|
61 |
+
" \u003c/td\u003e\n",
|
62 |
+
"\u003c/table\u003e"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "markdown",
|
67 |
+
"metadata": {
|
68 |
+
"id": "0H_EFIhq4-MJ"
|
69 |
+
},
|
70 |
+
"source": [
|
71 |
+
"## Learning objectives\n",
|
72 |
+
"\n",
|
73 |
+
"In this Colab notebook, you will learn how to build transformer-based models for common NLP tasks including pretraining, span labelling and classification using the building blocks from [NLP modeling library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling)."
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"cell_type": "markdown",
|
78 |
+
"metadata": {
|
79 |
+
"id": "2N97-dps_nUk"
|
80 |
+
},
|
81 |
+
"source": [
|
82 |
+
"## Install and import"
|
83 |
+
]
|
84 |
+
},
|
85 |
+
{
|
86 |
+
"cell_type": "markdown",
|
87 |
+
"metadata": {
|
88 |
+
"id": "459ygAVl_rg0"
|
89 |
+
},
|
90 |
+
"source": [
|
91 |
+
"### Install the TensorFlow Model Garden pip package\n",
|
92 |
+
"\n",
|
93 |
+
"* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` github repo. To include latest changes, you may install `tf-models-nightly`,\n",
|
94 |
+
"which is the nightly Model Garden package created daily automatically.\n",
|
95 |
+
"* `pip` will install all models and dependencies automatically."
|
96 |
+
]
|
97 |
+
},
|
98 |
+
{
|
99 |
+
"cell_type": "code",
|
100 |
+
"execution_count": null,
|
101 |
+
"metadata": {
|
102 |
+
"id": "Y-qGkdh6_sZc"
|
103 |
+
},
|
104 |
+
"outputs": [],
|
105 |
+
"source": [
|
106 |
+
"!pip install tf-models-official"
|
107 |
+
]
|
108 |
+
},
|
109 |
+
{
|
110 |
+
"cell_type": "markdown",
|
111 |
+
"metadata": {
|
112 |
+
"id": "e4huSSwyAG_5"
|
113 |
+
},
|
114 |
+
"source": [
|
115 |
+
"### Import Tensorflow and other libraries"
|
116 |
+
]
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"cell_type": "code",
|
120 |
+
"execution_count": null,
|
121 |
+
"metadata": {
|
122 |
+
"id": "jqYXqtjBAJd9"
|
123 |
+
},
|
124 |
+
"outputs": [],
|
125 |
+
"source": [
|
126 |
+
"import numpy as np\n",
|
127 |
+
"import tensorflow as tf\n",
|
128 |
+
"\n",
|
129 |
+
"from tensorflow_models import nlp"
|
130 |
+
]
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"cell_type": "markdown",
|
134 |
+
"metadata": {
|
135 |
+
"id": "djBQWjvy-60Y"
|
136 |
+
},
|
137 |
+
"source": [
|
138 |
+
"## BERT pretraining model\n",
|
139 |
+
"\n",
|
140 |
+
"BERT ([Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)) introduced the method of pre-training language representations on a large text corpus and then using that model for downstream NLP tasks.\n",
|
141 |
+
"\n",
|
142 |
+
"In this section, we will learn how to build a model to pretrain BERT on the masked language modeling task and next sentence prediction task. For simplicity, we only show the minimum example and use dummy data."
|
143 |
+
]
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"cell_type": "markdown",
|
147 |
+
"metadata": {
|
148 |
+
"id": "MKuHVlsCHmiq"
|
149 |
+
},
|
150 |
+
"source": [
|
151 |
+
"### Build a `BertPretrainer` model wrapping `BertEncoder`\n",
|
152 |
+
"\n",
|
153 |
+
"The `nlp.networks.BertEncoder` class implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers (`nlp.layers.TransformerEncoderBlock`), but not the masked language model or classification task networks.\n",
|
154 |
+
"\n",
|
155 |
+
"The `nlp.models.BertPretrainer` class allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives."
|
156 |
+
]
|
157 |
+
},
|
158 |
+
{
|
159 |
+
"cell_type": "code",
|
160 |
+
"execution_count": null,
|
161 |
+
"metadata": {
|
162 |
+
"id": "EXkcXz-9BwB3"
|
163 |
+
},
|
164 |
+
"outputs": [],
|
165 |
+
"source": [
|
166 |
+
"# Build a small transformer network.\n",
|
167 |
+
"vocab_size = 100\n",
|
168 |
+
"network = nlp.networks.BertEncoder(\n",
|
169 |
+
" vocab_size=vocab_size, \n",
|
170 |
+
" # The number of TransformerEncoderBlock layers\n",
|
171 |
+
" num_layers=3)"
|
172 |
+
]
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"cell_type": "markdown",
|
176 |
+
"metadata": {
|
177 |
+
"id": "0NH5irV5KTMS"
|
178 |
+
},
|
179 |
+
"source": [
|
180 |
+
"Inspecting the encoder, we see it contains few embedding layers, stacked `nlp.layers.TransformerEncoderBlock` layers and are connected to three input layers:\n",
|
181 |
+
"\n",
|
182 |
+
"`input_word_ids`, `input_type_ids` and `input_mask`.\n"
|
183 |
+
]
|
184 |
+
},
|
185 |
+
{
|
186 |
+
"cell_type": "code",
|
187 |
+
"execution_count": null,
|
188 |
+
"metadata": {
|
189 |
+
"id": "lZNoZkBrIoff"
|
190 |
+
},
|
191 |
+
"outputs": [],
|
192 |
+
"source": [
|
193 |
+
"tf.keras.utils.plot_model(network, show_shapes=True, expand_nested=True, dpi=48)"
|
194 |
+
]
|
195 |
+
},
|
196 |
+
{
|
197 |
+
"cell_type": "code",
|
198 |
+
"execution_count": null,
|
199 |
+
"metadata": {
|
200 |
+
"id": "o7eFOZXiIl-b"
|
201 |
+
},
|
202 |
+
"outputs": [],
|
203 |
+
"source": [
|
204 |
+
"# Create a BERT pretrainer with the created network.\n",
|
205 |
+
"num_token_predictions = 8\n",
|
206 |
+
"bert_pretrainer = nlp.models.BertPretrainer(\n",
|
207 |
+
" network, num_classes=2, num_token_predictions=num_token_predictions, output='predictions')"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "markdown",
|
212 |
+
"metadata": {
|
213 |
+
"id": "d5h5HT7gNHx_"
|
214 |
+
},
|
215 |
+
"source": [
|
216 |
+
"Inspecting the `bert_pretrainer`, we see it wraps the `encoder` with additional `MaskedLM` and `nlp.layers.ClassificationHead` heads."
|
217 |
+
]
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"cell_type": "code",
|
221 |
+
"execution_count": null,
|
222 |
+
"metadata": {
|
223 |
+
"id": "2tcNfm03IBF7"
|
224 |
+
},
|
225 |
+
"outputs": [],
|
226 |
+
"source": [
|
227 |
+
"tf.keras.utils.plot_model(bert_pretrainer, show_shapes=True, expand_nested=True, dpi=48)"
|
228 |
+
]
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"cell_type": "code",
|
232 |
+
"execution_count": null,
|
233 |
+
"metadata": {
|
234 |
+
"id": "F2oHrXGUIS0M"
|
235 |
+
},
|
236 |
+
"outputs": [],
|
237 |
+
"source": [
|
238 |
+
"# We can feed some dummy data to get masked language model and sentence output.\n",
|
239 |
+
"sequence_length = 16\n",
|
240 |
+
"batch_size = 2\n",
|
241 |
+
"\n",
|
242 |
+
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
|
243 |
+
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
|
244 |
+
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
|
245 |
+
"masked_lm_positions_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
|
246 |
+
"\n",
|
247 |
+
"outputs = bert_pretrainer(\n",
|
248 |
+
" [word_id_data, mask_data, type_id_data, masked_lm_positions_data])\n",
|
249 |
+
"lm_output = outputs[\"masked_lm\"]\n",
|
250 |
+
"sentence_output = outputs[\"classification\"]\n",
|
251 |
+
"print(f'lm_output: shape={lm_output.shape}, dtype={lm_output.dtype!r}')\n",
|
252 |
+
"print(f'sentence_output: shape={sentence_output.shape}, dtype={sentence_output.dtype!r}')"
|
253 |
+
]
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"cell_type": "markdown",
|
257 |
+
"metadata": {
|
258 |
+
"id": "bnx3UCHniCS5"
|
259 |
+
},
|
260 |
+
"source": [
|
261 |
+
"### Compute loss\n",
|
262 |
+
"Next, we can use `lm_output` and `sentence_output` to compute `loss`."
|
263 |
+
]
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"cell_type": "code",
|
267 |
+
"execution_count": null,
|
268 |
+
"metadata": {
|
269 |
+
"id": "k30H4Q86f52x"
|
270 |
+
},
|
271 |
+
"outputs": [],
|
272 |
+
"source": [
|
273 |
+
"masked_lm_ids_data = np.random.randint(vocab_size, size=(batch_size, num_token_predictions))\n",
|
274 |
+
"masked_lm_weights_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
|
275 |
+
"next_sentence_labels_data = np.random.randint(2, size=(batch_size))\n",
|
276 |
+
"\n",
|
277 |
+
"mlm_loss = nlp.losses.weighted_sparse_categorical_crossentropy_loss(\n",
|
278 |
+
" labels=masked_lm_ids_data,\n",
|
279 |
+
" predictions=lm_output,\n",
|
280 |
+
" weights=masked_lm_weights_data)\n",
|
281 |
+
"sentence_loss = nlp.losses.weighted_sparse_categorical_crossentropy_loss(\n",
|
282 |
+
" labels=next_sentence_labels_data,\n",
|
283 |
+
" predictions=sentence_output)\n",
|
284 |
+
"loss = mlm_loss + sentence_loss\n",
|
285 |
+
"\n",
|
286 |
+
"print(loss)"
|
287 |
+
]
|
288 |
+
},
|
289 |
+
{
|
290 |
+
"cell_type": "markdown",
|
291 |
+
"metadata": {
|
292 |
+
"id": "wrmSs8GjHxVw"
|
293 |
+
},
|
294 |
+
"source": [
|
295 |
+
"With the loss, you can optimize the model.\n",
|
296 |
+
"After training, we can save the weights of TransformerEncoder for the downstream fine-tuning tasks. Please see [run_pretraining.py](https://github.com/tensorflow/models/blob/master/official/legacy/bert/run_pretraining.py) for the full example.\n"
|
297 |
+
]
|
298 |
+
},
|
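+
{
|
+
"cell_type": "markdown",
|
+
"metadata": {},
|
+
"source": [
|
+
"As a minimal sketch of what a single optimization step could look like with this loss (the optimizer choice and learning rate below are assumptions for illustration):"
|
+
]
|
+
},
|
+
{
|
+
"cell_type": "code",
|
+
"execution_count": null,
|
+
"metadata": {},
|
+
"outputs": [],
|
+
"source": [
|
+
"# One gradient step on the dummy batch defined above (assumed optimizer settings).\n",
|
+
"optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)\n",
|
+
"inputs = [word_id_data, mask_data, type_id_data, masked_lm_positions_data]\n",
|
+
"with tf.GradientTape() as tape:\n",
|
+
"  step_outputs = bert_pretrainer(inputs)\n",
|
+
"  step_loss = nlp.losses.weighted_sparse_categorical_crossentropy_loss(\n",
|
+
"      labels=masked_lm_ids_data,\n",
|
+
"      predictions=step_outputs['masked_lm'],\n",
|
+
"      weights=masked_lm_weights_data)\n",
|
+
"  step_loss += nlp.losses.weighted_sparse_categorical_crossentropy_loss(\n",
|
+
"      labels=next_sentence_labels_data,\n",
|
+
"      predictions=step_outputs['classification'])\n",
|
+
"grads = tape.gradient(step_loss, bert_pretrainer.trainable_variables)\n",
|
+
"optimizer.apply_gradients(zip(grads, bert_pretrainer.trainable_variables))"
|
+
]
|
+
},
|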
299 |
+
{
|
300 |
+
"cell_type": "markdown",
|
301 |
+
"metadata": {
|
302 |
+
"id": "k8cQVFvBCV4s"
|
303 |
+
},
|
304 |
+
"source": [
|
305 |
+
"## Span labeling model\n",
|
306 |
+
"\n",
|
307 |
+
"Span labeling is the task to assign labels to a span of the text, for example, label a span of text as the answer of a given question.\n",
|
308 |
+
"\n",
|
309 |
+
"In this section, we will learn how to build a span labeling model. Again, we use dummy data for simplicity."
|
310 |
+
]
|
311 |
+
},
|
312 |
+
{
|
313 |
+
"cell_type": "markdown",
|
314 |
+
"metadata": {
|
315 |
+
"id": "xrLLEWpfknUW"
|
316 |
+
},
|
317 |
+
"source": [
|
318 |
+
"### Build a BertSpanLabeler wrapping BertEncoder\n",
|
319 |
+
"\n",
|
320 |
+
"The `nlp.models.BertSpanLabeler` class implements a simple single-span start-end predictor (that is, a model that predicts two values: a start token index and an end token index), suitable for SQuAD-style tasks.\n",
|
321 |
+
"\n",
|
322 |
+
"Note that `nlp.models.BertSpanLabeler` wraps a `nlp.networks.BertEncoder`, the weights of which can be restored from the above pretraining model.\n"
|
323 |
+
]
|
324 |
+
},
|
325 |
+
{
|
326 |
+
"cell_type": "code",
|
327 |
+
"execution_count": null,
|
328 |
+
"metadata": {
|
329 |
+
"id": "B941M4iUCejO"
|
330 |
+
},
|
331 |
+
"outputs": [],
|
332 |
+
"source": [
|
333 |
+
"network = nlp.networks.BertEncoder(\n",
|
334 |
+
" vocab_size=vocab_size, num_layers=2)\n",
|
335 |
+
"\n",
|
336 |
+
"# Create a BERT trainer with the created network.\n",
|
337 |
+
"bert_span_labeler = nlp.models.BertSpanLabeler(network)"
|
338 |
+
]
|
339 |
+
},
|
340 |
+
{
|
341 |
+
"cell_type": "markdown",
|
342 |
+
"metadata": {
|
343 |
+
"id": "QpB9pgj4PpMg"
|
344 |
+
},
|
345 |
+
"source": [
|
346 |
+
"Inspecting the `bert_span_labeler`, we see it wraps the encoder with additional `SpanLabeling` that outputs `start_position` and `end_position`."
|
347 |
+
]
|
348 |
+
},
|
349 |
+
{
|
350 |
+
"cell_type": "code",
|
351 |
+
"execution_count": null,
|
352 |
+
"metadata": {
|
353 |
+
"id": "RbqRNJCLJu4H"
|
354 |
+
},
|
355 |
+
"outputs": [],
|
356 |
+
"source": [
|
357 |
+
"tf.keras.utils.plot_model(bert_span_labeler, show_shapes=True, expand_nested=True, dpi=48)"
|
358 |
+
]
|
359 |
+
},
|
360 |
+
{
|
361 |
+
"cell_type": "code",
|
362 |
+
"execution_count": null,
|
363 |
+
"metadata": {
|
364 |
+
"id": "fUf1vRxZJwio"
|
365 |
+
},
|
366 |
+
"outputs": [],
|
367 |
+
"source": [
|
368 |
+
"# Create a set of 2-dimensional data tensors to feed into the model.\n",
|
369 |
+
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
|
370 |
+
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
|
371 |
+
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
|
372 |
+
"\n",
|
373 |
+
"# Feed the data to the model.\n",
|
374 |
+
"start_logits, end_logits = bert_span_labeler([word_id_data, mask_data, type_id_data])\n",
|
375 |
+
"\n",
|
376 |
+
"print(f'start_logits: shape={start_logits.shape}, dtype={start_logits.dtype!r}')\n",
|
377 |
+
"print(f'end_logits: shape={end_logits.shape}, dtype={end_logits.dtype!r}')"
|
378 |
+
]
|
379 |
+
},
|
380 |
+
{
|
381 |
+
"cell_type": "markdown",
|
382 |
+
"metadata": {
|
383 |
+
"id": "WqhgQaN1lt-G"
|
384 |
+
},
|
385 |
+
"source": [
|
386 |
+
"### Compute loss\n",
|
387 |
+
"With `start_logits` and `end_logits`, we can compute loss:"
|
388 |
+
]
|
389 |
+
},
|
390 |
+
{
|
391 |
+
"cell_type": "code",
|
392 |
+
"execution_count": null,
|
393 |
+
"metadata": {
|
394 |
+
"id": "waqs6azNl3Nn"
|
395 |
+
},
|
396 |
+
"outputs": [],
|
397 |
+
"source": [
|
398 |
+
"start_positions = np.random.randint(sequence_length, size=(batch_size))\n",
|
399 |
+
"end_positions = np.random.randint(sequence_length, size=(batch_size))\n",
|
400 |
+
"\n",
|
401 |
+
"start_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
|
402 |
+
" start_positions, start_logits, from_logits=True)\n",
|
403 |
+
"end_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
|
404 |
+
" end_positions, end_logits, from_logits=True)\n",
|
405 |
+
"\n",
|
406 |
+
"total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2\n",
|
407 |
+
"print(total_loss)"
|
408 |
+
]
|
409 |
+
},
|
410 |
+
{
|
411 |
+
"cell_type": "markdown",
|
412 |
+
"metadata": {
|
413 |
+
"id": "Zdf03YtZmd_d"
|
414 |
+
},
|
415 |
+
"source": [
|
416 |
+
"With the `loss`, you can optimize the model. Please see [run_squad.py](https://github.com/tensorflow/models/blob/master/official/legacy/bert/run_squad.py) for the full example."
|
417 |
+
]
|
418 |
+
},
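As a rough illustration (a minimal sketch over the dummy tensors defined above, not the `run_squad.py` pipeline), a single optimization step could look like this:

```python
# Minimal sketch of one gradient step for the span labeler; assumes the
# dummy word_id_data/mask_data/type_id_data and start/end positions above.
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)

with tf.GradientTape() as tape:
  start_logits, end_logits = bert_span_labeler(
      [word_id_data, mask_data, type_id_data], training=True)
  start_loss = tf.keras.losses.sparse_categorical_crossentropy(
      start_positions, start_logits, from_logits=True)
  end_loss = tf.keras.losses.sparse_categorical_crossentropy(
      end_positions, end_logits, from_logits=True)
  total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2

grads = tape.gradient(total_loss, bert_span_labeler.trainable_variables)
optimizer.apply_gradients(zip(grads, bert_span_labeler.trainable_variables))
```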
|
419 |
+
{
|
420 |
+
"cell_type": "markdown",
|
421 |
+
"metadata": {
|
422 |
+
"id": "0A1XnGSTChg9"
|
423 |
+
},
|
424 |
+
"source": [
|
425 |
+
"## Classification model\n",
|
426 |
+
"\n",
|
427 |
+
"In the last section, we show how to build a text classification model.\n"
|
428 |
+
]
|
429 |
+
},
|
430 |
+
{
|
431 |
+
"cell_type": "markdown",
|
432 |
+
"metadata": {
|
433 |
+
"id": "MSK8OpZgnQa9"
|
434 |
+
},
|
435 |
+
"source": [
|
436 |
+
"### Build a BertClassifier model wrapping BertEncoder\n",
|
437 |
+
"\n",
|
438 |
+
"`nlp.models.BertClassifier` implements a [CLS] token classification model containing a single classification head."
|
439 |
+
]
|
440 |
+
},
|
441 |
+
{
|
442 |
+
"cell_type": "code",
|
443 |
+
"execution_count": null,
|
444 |
+
"metadata": {
|
445 |
+
"id": "cXXCsffkCphk"
|
446 |
+
},
|
447 |
+
"outputs": [],
|
448 |
+
"source": [
|
449 |
+
"network = nlp.networks.BertEncoder(\n",
|
450 |
+
" vocab_size=vocab_size, num_layers=2)\n",
|
451 |
+
"\n",
|
452 |
+
"# Create a BERT trainer with the created network.\n",
|
453 |
+
"num_classes = 2\n",
|
454 |
+
"bert_classifier = nlp.models.BertClassifier(\n",
|
455 |
+
" network, num_classes=num_classes)"
|
456 |
+
]
|
457 |
+
},
|
458 |
+
{
|
459 |
+
"cell_type": "markdown",
|
460 |
+
"metadata": {
|
461 |
+
"id": "8tZKueKYP4bB"
|
462 |
+
},
|
463 |
+
"source": [
|
464 |
+
"Inspecting the `bert_classifier`, we see it wraps the `encoder` with additional `Classification` head."
|
465 |
+
]
|
466 |
+
},
|
467 |
+
{
|
468 |
+
"cell_type": "code",
|
469 |
+
"execution_count": null,
|
470 |
+
"metadata": {
|
471 |
+
"id": "snlutm9ZJgEZ"
|
472 |
+
},
|
473 |
+
"outputs": [],
|
474 |
+
"source": [
|
475 |
+
"tf.keras.utils.plot_model(bert_classifier, show_shapes=True, expand_nested=True, dpi=48)"
|
476 |
+
]
|
477 |
+
},
|
478 |
+
{
|
479 |
+
"cell_type": "code",
|
480 |
+
"execution_count": null,
|
481 |
+
"metadata": {
|
482 |
+
"id": "yyHPHsqBJkCz"
|
483 |
+
},
|
484 |
+
"outputs": [],
|
485 |
+
"source": [
|
486 |
+
"# Create a set of 2-dimensional data tensors to feed into the model.\n",
|
487 |
+
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
|
488 |
+
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
|
489 |
+
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
|
490 |
+
"\n",
|
491 |
+
"# Feed the data to the model.\n",
|
492 |
+
"logits = bert_classifier([word_id_data, mask_data, type_id_data])\n",
|
493 |
+
"print(f'logits: shape={logits.shape}, dtype={logits.dtype!r}')"
|
494 |
+
]
|
495 |
+
},
|
496 |
+
{
|
497 |
+
"cell_type": "markdown",
|
498 |
+
"metadata": {
|
499 |
+
"id": "w--a2mg4nzKm"
|
500 |
+
},
|
501 |
+
"source": [
|
502 |
+
"### Compute loss\n",
|
503 |
+
"\n",
|
504 |
+
"With `logits`, we can compute `loss`:"
|
505 |
+
]
|
506 |
+
},
|
507 |
+
{
|
508 |
+
"cell_type": "code",
|
509 |
+
"execution_count": null,
|
510 |
+
"metadata": {
|
511 |
+
"id": "9X0S1DoFn_5Q"
|
512 |
+
},
|
513 |
+
"outputs": [],
|
514 |
+
"source": [
|
515 |
+
"labels = np.random.randint(num_classes, size=(batch_size))\n",
|
516 |
+
"\n",
|
517 |
+
"loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
|
518 |
+
" labels, logits, from_logits=True)\n",
|
519 |
+
"print(loss)"
|
520 |
+
]
|
521 |
+
},
|
522 |
+
{
|
523 |
+
"cell_type": "markdown",
|
524 |
+
"metadata": {
|
525 |
+
"id": "mzBqOylZo3og"
|
526 |
+
},
|
527 |
+
"source": [
|
528 |
+
"With the `loss`, you can optimize the model. Please see the [Fine tune_bert](https://www.tensorflow.org/text/tutorials/fine_tune_bert) notebook or the [model training documentation](https://github.com/tensorflow/models/blob/master/official/nlp/docs/train.md) for the full example."
|
529 |
+
]
|
530 |
+
}
|
531 |
+
],
|
532 |
+
"metadata": {
|
533 |
+
"colab": {
|
534 |
+
"name": "nlp_modeling_library_intro.ipynb",
|
535 |
+
"provenance": [],
|
536 |
+
"toc_visible": true
|
537 |
+
},
|
538 |
+
"kernelspec": {
|
539 |
+
"display_name": "Python 3",
|
540 |
+
"name": "python3"
|
541 |
+
}
|
542 |
+
},
|
543 |
+
"nbformat": 4,
|
544 |
+
"nbformat_minor": 0
|
545 |
+
}
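Because `BertClassifier` is a Keras model, the simplest end-to-end route is `compile`/`fit`. The sketch below reuses the dummy tensors from the classification cells above and is only a smoke test under those assumptions, not a real training recipe:

```python
# Minimal sketch: fit the classifier on the dummy data for one epoch.
bert_classifier.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
bert_classifier.fit(
    x=[word_id_data, mask_data, type_id_data], y=labels,
    batch_size=batch_size, epochs=1)
```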
|
Tensorflow/models/docs/nlp/load_lm_ckpts.ipynb
ADDED
@@ -0,0 +1,692 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "30155835fc9f"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"##### Copyright 2022 The TensorFlow Authors."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"cellView": "form",
|
17 |
+
"id": "906e07f6e562"
|
18 |
+
},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
22 |
+
"# you may not use this file except in compliance with the License.\n",
|
23 |
+
"# You may obtain a copy of the License at\n",
|
24 |
+
"#\n",
|
25 |
+
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
26 |
+
"#\n",
|
27 |
+
"# Unless required by applicable law or agreed to in writing, software\n",
|
28 |
+
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
29 |
+
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
30 |
+
"# See the License for the specific language governing permissions and\n",
|
31 |
+
"# limitations under the License."
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "markdown",
|
36 |
+
"metadata": {
|
37 |
+
"id": "5hrbPTziJK15"
|
38 |
+
},
|
39 |
+
"source": [
|
40 |
+
"# Load LM Checkpoints using Model Garden"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"metadata": {
|
46 |
+
"id": "-PYqCW1II75I"
|
47 |
+
},
|
48 |
+
"source": [
|
49 |
+
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
|
50 |
+
" \u003ctd\u003e\n",
|
51 |
+
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/load_lm_ckpts\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
|
52 |
+
" \u003c/td\u003e\n",
|
53 |
+
" \u003ctd\u003e\n",
|
54 |
+
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/load_lm_ckpts.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
|
55 |
+
" \u003c/td\u003e\n",
|
56 |
+
" \u003ctd\u003e\n",
|
57 |
+
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/load_lm_ckpts.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
|
58 |
+
" \u003c/td\u003e\n",
|
59 |
+
" \u003ctd\u003e\n",
|
60 |
+
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/load_lm_ckpts.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
|
61 |
+
" \u003c/td\u003e\n",
|
62 |
+
"\u003c/table\u003e"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "markdown",
|
67 |
+
"metadata": {
|
68 |
+
"id": "yyyk1KMlJdWd"
|
69 |
+
},
|
70 |
+
"source": [
|
71 |
+
"This tutorial demonstrates how to load BERT, ALBERT and ELECTRA pretrained checkpoints and use them for downstream tasks.\n",
|
72 |
+
"\n",
|
73 |
+
"[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development."
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"cell_type": "markdown",
|
78 |
+
"metadata": {
|
79 |
+
"id": "uEG4RYHolQij"
|
80 |
+
},
|
81 |
+
"source": [
|
82 |
+
"## Install TF Model Garden package"
|
83 |
+
]
|
84 |
+
},
|
85 |
+
{
|
86 |
+
"cell_type": "code",
|
87 |
+
"execution_count": null,
|
88 |
+
"metadata": {
|
89 |
+
"id": "kPfC1NJZnJq1"
|
90 |
+
},
|
91 |
+
"outputs": [],
|
92 |
+
"source": [
|
93 |
+
"!pip install -U -q \"tf-models-official\""
|
94 |
+
]
|
95 |
+
},
|
96 |
+
{
|
97 |
+
"cell_type": "markdown",
|
98 |
+
"metadata": {
|
99 |
+
"id": "Op9R3zy3lUk8"
|
100 |
+
},
|
101 |
+
"source": [
|
102 |
+
"## Import necessary libraries"
|
103 |
+
]
|
104 |
+
},
|
105 |
+
{
|
106 |
+
"cell_type": "code",
|
107 |
+
"execution_count": null,
|
108 |
+
"metadata": {
|
109 |
+
"id": "6_y4Rfq23wK-"
|
110 |
+
},
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"import os\n",
|
114 |
+
"import yaml\n",
|
115 |
+
"import json\n",
|
116 |
+
"\n",
|
117 |
+
"import tensorflow as tf"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": null,
|
123 |
+
"metadata": {
|
124 |
+
"id": "xjgv3gllzbYQ"
|
125 |
+
},
|
126 |
+
"outputs": [],
|
127 |
+
"source": [
|
128 |
+
"import tensorflow_models as tfm\n",
|
129 |
+
"\n",
|
130 |
+
"from official.core import exp_factory"
|
131 |
+
]
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"cell_type": "markdown",
|
135 |
+
"metadata": {
|
136 |
+
"id": "J-t2mo6VQNfY"
|
137 |
+
},
|
138 |
+
"source": [
|
139 |
+
"## Load BERT model pretrained checkpoints"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "markdown",
|
144 |
+
"metadata": {
|
145 |
+
"id": "hdBsFnI20LDE"
|
146 |
+
},
|
147 |
+
"source": [
|
148 |
+
"### Select required BERT model"
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "code",
|
153 |
+
"execution_count": null,
|
154 |
+
"metadata": {
|
155 |
+
"id": "apn3VgxUlr5G"
|
156 |
+
},
|
157 |
+
"outputs": [],
|
158 |
+
"source": [
|
159 |
+
"# @title Download Checkpoint of the Selected Model { display-mode: \"form\", run: \"auto\" }\n",
|
160 |
+
"model_display_name = 'BERT-base cased English' # @param ['BERT-base uncased English','BERT-base cased English','BERT-large uncased English', 'BERT-large cased English', 'BERT-large, Uncased (Whole Word Masking)', 'BERT-large, Cased (Whole Word Masking)', 'BERT-base MultiLingual','BERT-base Chinese']\n",
|
161 |
+
"\n",
|
162 |
+
"if model_display_name == 'BERT-base uncased English':\n",
|
163 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-12_H-768_A-12.tar.gz\"\n",
|
164 |
+
" !tar -xvf \"uncased_L-12_H-768_A-12.tar.gz\"\n",
|
165 |
+
"elif model_display_name == 'BERT-base cased English':\n",
|
166 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/cased_L-12_H-768_A-12.tar.gz\"\n",
|
167 |
+
" !tar -xvf \"cased_L-12_H-768_A-12.tar.gz\"\n",
|
168 |
+
"elif model_display_name == \"BERT-large uncased English\":\n",
|
169 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-24_H-1024_A-16.tar.gz\"\n",
|
170 |
+
" !tar -xvf \"uncased_L-24_H-1024_A-16.tar.gz\"\n",
|
171 |
+
"elif model_display_name == \"BERT-large cased English\":\n",
|
172 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/cased_L-24_H-1024_A-16.tar.gz\"\n",
|
173 |
+
" !tar -xvf \"cased_L-24_H-1024_A-16.tar.gz\"\n",
|
174 |
+
"elif model_display_name == \"BERT-large, Uncased (Whole Word Masking)\":\n",
|
175 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/wwm_uncased_L-24_H-1024_A-16.tar.gz\"\n",
|
176 |
+
" !tar -xvf \"wwm_uncased_L-24_H-1024_A-16.tar.gz\"\n",
|
177 |
+
"elif model_display_name == \"BERT-large, Cased (Whole Word Masking)\":\n",
|
178 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/wwm_cased_L-24_H-1024_A-16.tar.gz\"\n",
|
179 |
+
" !tar -xvf \"wwm_cased_L-24_H-1024_A-16.tar.gz\"\n",
|
180 |
+
"elif model_display_name == \"BERT-base MultiLingual\":\n",
|
181 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/multi_cased_L-12_H-768_A-12.tar.gz\"\n",
|
182 |
+
" !tar -xvf \"multi_cased_L-12_H-768_A-12.tar.gz\"\n",
|
183 |
+
"elif model_display_name == \"BERT-base Chinese\":\n",
|
184 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/chinese_L-12_H-768_A-12.tar.gz\"\n",
|
185 |
+
" !tar -xvf \"chinese_L-12_H-768_A-12.tar.gz\""
|
186 |
+
]
|
187 |
+
},
|
188 |
+
{
|
189 |
+
"cell_type": "code",
|
190 |
+
"execution_count": null,
|
191 |
+
"metadata": {
|
192 |
+
"id": "jzxyziRuaC95"
|
193 |
+
},
|
194 |
+
"outputs": [],
|
195 |
+
"source": [
|
196 |
+
"# Lookup table of the directory name corresponding to each model checkpoint\n",
|
197 |
+
"folder_bert_dict = {\n",
|
198 |
+
" 'BERT-base uncased English': 'uncased_L-12_H-768_A-12',\n",
|
199 |
+
" 'BERT-base cased English': 'cased_L-12_H-768_A-12',\n",
|
200 |
+
" 'BERT-large uncased English': 'uncased_L-24_H-1024_A-16',\n",
|
201 |
+
" 'BERT-large cased English': 'cased_L-24_H-1024_A-16',\n",
|
202 |
+
" 'BERT-large, Uncased (Whole Word Masking)': 'wwm_uncased_L-24_H-1024_A-16',\n",
|
203 |
+
" 'BERT-large, Cased (Whole Word Masking)': 'wwm_cased_L-24_H-1024_A-16',\n",
|
204 |
+
" 'BERT-base MultiLingual': 'multi_cased_L-12_H-768_A-1',\n",
|
205 |
+
" 'BERT-base Chinese': 'chinese_L-12_H-768_A-12'\n",
|
206 |
+
"}\n",
|
207 |
+
"\n",
|
208 |
+
"folder_bert = folder_bert_dict.get(model_display_name)\n",
|
209 |
+
"folder_bert"
|
210 |
+
]
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"cell_type": "markdown",
|
214 |
+
"metadata": {
|
215 |
+
"id": "q1WrYswpZPlc"
|
216 |
+
},
|
217 |
+
"source": [
|
218 |
+
"### Construct BERT Model Using the New `params.yaml`\n",
|
219 |
+
"\n",
|
220 |
+
"params.yaml can be used for training with the bundled trainer in addition to constructing the BERT encoder here."
|
221 |
+
]
|
222 |
+
},
|
223 |
+
{
|
224 |
+
"cell_type": "code",
|
225 |
+
"execution_count": null,
|
226 |
+
"metadata": {
|
227 |
+
"id": "quu1s8Hi2szo"
|
228 |
+
},
|
229 |
+
"outputs": [],
|
230 |
+
"source": [
|
231 |
+
"config_file = os.path.join(folder_bert, \"params.yaml\")\n",
|
232 |
+
"config_dict = yaml.safe_load(tf.io.gfile.GFile(config_file).read())\n",
|
233 |
+
"config_dict"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "code",
|
238 |
+
"execution_count": null,
|
239 |
+
"metadata": {
|
240 |
+
"id": "3t8o0iG9v8ac"
|
241 |
+
},
|
242 |
+
"outputs": [],
|
243 |
+
"source": [
|
244 |
+
"# Method 1: pass encoder config dict into EncoderConfig\n",
|
245 |
+
"encoder_config = tfm.nlp.encoders.EncoderConfig(config_dict[\"task\"][\"model\"][\"encoder\"])\n",
|
246 |
+
"encoder_config.get().as_dict()"
|
247 |
+
]
|
248 |
+
},
|
249 |
+
{
|
250 |
+
"cell_type": "code",
|
251 |
+
"execution_count": null,
|
252 |
+
"metadata": {
|
253 |
+
"id": "2I5PetB6wPvb"
|
254 |
+
},
|
255 |
+
"outputs": [],
|
256 |
+
"source": [
|
257 |
+
"# Method 2: use override_params_dict function to override default Encoder params\n",
|
258 |
+
"encoder_config = tfm.nlp.encoders.EncoderConfig()\n",
|
259 |
+
"tfm.hyperparams.override_params_dict(encoder_config, config_dict[\"task\"][\"model\"][\"encoder\"], is_strict=True)\n",
|
260 |
+
"encoder_config.get().as_dict()"
|
261 |
+
]
|
262 |
+
},
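Both methods yield the same `EncoderConfig`. As a quick sanity check (a minimal sketch, assuming the `encoder_config` loaded above), you can build the encoder and confirm a field against the YAML:

```python
# Build the encoder from the loaded config and sanity-check one field.
bert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)
print('num_layers:', encoder_config.get().as_dict()['num_layers'])
print('trainable variables:', len(bert_encoder.trainable_variables))
```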
|
263 |
+
{
|
264 |
+
"cell_type": "markdown",
|
265 |
+
"metadata": {
|
266 |
+
"id": "5yHiG_9oS3Uw"
|
267 |
+
},
|
268 |
+
"source": [
|
269 |
+
"### Construct BERT Model Using the Old `bert_config.json`"
|
270 |
+
]
|
271 |
+
},
|
272 |
+
{
|
273 |
+
"cell_type": "code",
|
274 |
+
"execution_count": null,
|
275 |
+
"metadata": {
|
276 |
+
"id": "WEyaqLcW3nne"
|
277 |
+
},
|
278 |
+
"outputs": [],
|
279 |
+
"source": [
|
280 |
+
"bert_config_file = os.path.join(folder_bert, \"bert_config.json\")\n",
|
281 |
+
"config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())\n",
|
282 |
+
"config_dict"
|
283 |
+
]
|
284 |
+
},
|
285 |
+
{
|
286 |
+
"cell_type": "code",
|
287 |
+
"execution_count": null,
|
288 |
+
"metadata": {
|
289 |
+
"id": "xSIcaW9tdrl4"
|
290 |
+
},
|
291 |
+
"outputs": [],
|
292 |
+
"source": [
|
293 |
+
"encoder_config = tfm.nlp.encoders.EncoderConfig({\n",
|
294 |
+
" 'type':'bert',\n",
|
295 |
+
" 'bert': config_dict\n",
|
296 |
+
"})\n",
|
297 |
+
"\n",
|
298 |
+
"encoder_config.get().as_dict()"
|
299 |
+
]
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"cell_type": "markdown",
|
303 |
+
"metadata": {
|
304 |
+
"id": "yZznAP--TDLe"
|
305 |
+
},
|
306 |
+
"source": [
|
307 |
+
"### Construct a classifier with `encoder_config`\n",
|
308 |
+
"\n",
|
309 |
+
"Here, we construct a new BERT Classifier with 2 classes and plot its model architecture. A BERT Classifier consists of a BERT encoder using the selected encoder config, a Dropout layer and a MLP classification head."
|
310 |
+
]
|
311 |
+
},
|
312 |
+
{
|
313 |
+
"cell_type": "code",
|
314 |
+
"execution_count": null,
|
315 |
+
"metadata": {
|
316 |
+
"id": "Ny962I8nqs4n"
|
317 |
+
},
|
318 |
+
"outputs": [],
|
319 |
+
"source": [
|
320 |
+
"bert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)\n",
|
321 |
+
"bert_classifier = tfm.nlp.models.BertClassifier(network=bert_encoder, num_classes=2)\n",
|
322 |
+
"\n",
|
323 |
+
"tf.keras.utils.plot_model(bert_classifier)"
|
324 |
+
]
|
325 |
+
},
|
326 |
+
{
|
327 |
+
"cell_type": "markdown",
|
328 |
+
"metadata": {
|
329 |
+
"id": "IStKfxXkTJMu"
|
330 |
+
},
|
331 |
+
"source": [
|
332 |
+
"### Load Pretrained Weights into the BERT Classifier\n",
|
333 |
+
"\n",
|
334 |
+
"The provided pretrained checkpoint only contains weights for the BERT Encoder within the BERT Classifier. Weights for the Classification Head is still randomly initialized."
|
335 |
+
]
|
336 |
+
},
|
337 |
+
{
|
338 |
+
"cell_type": "code",
|
339 |
+
"execution_count": null,
|
340 |
+
"metadata": {
|
341 |
+
"id": "G9_XCBpOEo4y"
|
342 |
+
},
|
343 |
+
"outputs": [],
|
344 |
+
"source": [
|
345 |
+
"checkpoint = tf.train.Checkpoint(encoder=bert_encoder)\n",
|
346 |
+
"checkpoint.read(\n",
|
347 |
+
" os.path.join(folder_bert, 'bert_model.ckpt')).expect_partial().assert_existing_objects_matched()"
|
348 |
+
]
|
349 |
+
},
|
350 |
+
{
|
351 |
+
"cell_type": "markdown",
|
352 |
+
"metadata": {
|
353 |
+
"id": "E6Hu1FFgQWUU"
|
354 |
+
},
|
355 |
+
"source": [
|
356 |
+
"## Load ALBERT model pretrained checkpoints"
|
357 |
+
]
|
358 |
+
},
|
359 |
+
{
|
360 |
+
"cell_type": "code",
|
361 |
+
"execution_count": null,
|
362 |
+
"metadata": {
|
363 |
+
"id": "TWUtFeWxQn0V"
|
364 |
+
},
|
365 |
+
"outputs": [],
|
366 |
+
"source": [
|
367 |
+
"# @title Download Checkpoint of the Selected Model { display-mode: \"form\", run: \"auto\" }\n",
|
368 |
+
"albert_model_display_name = 'ALBERT-xxlarge English' # @param ['ALBERT-base English', 'ALBERT-large English', 'ALBERT-xlarge English', 'ALBERT-xxlarge English']\n",
|
369 |
+
"\n",
|
370 |
+
"if albert_model_display_name == 'ALBERT-base English':\n",
|
371 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_base.tar.gz\"\n",
|
372 |
+
" !tar -xvf \"albert_base.tar.gz\"\n",
|
373 |
+
"elif albert_model_display_name == 'ALBERT-large English':\n",
|
374 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_large.tar.gz\"\n",
|
375 |
+
" !tar -xvf \"albert_large.tar.gz\"\n",
|
376 |
+
"elif albert_model_display_name == \"ALBERT-xlarge English\":\n",
|
377 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_xlarge.tar.gz\"\n",
|
378 |
+
" !tar -xvf \"albert_xlarge.tar.gz\"\n",
|
379 |
+
"elif albert_model_display_name == \"ALBERT-xxlarge English\":\n",
|
380 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_xxlarge.tar.gz\"\n",
|
381 |
+
" !tar -xvf \"albert_xxlarge.tar.gz\""
|
382 |
+
]
|
383 |
+
},
|
384 |
+
{
|
385 |
+
"cell_type": "code",
|
386 |
+
"execution_count": null,
|
387 |
+
"metadata": {
|
388 |
+
"id": "5lZDWD7zUAAO"
|
389 |
+
},
|
390 |
+
"outputs": [],
|
391 |
+
"source": [
|
392 |
+
"# Lookup table of the directory name corresponding to each model checkpoint\n",
|
393 |
+
"folder_albert_dict = {\n",
|
394 |
+
" 'ALBERT-base English': 'albert_base',\n",
|
395 |
+
" 'ALBERT-large English': 'albert_large',\n",
|
396 |
+
" 'ALBERT-xlarge English': 'albert_xlarge',\n",
|
397 |
+
" 'ALBERT-xxlarge English': 'albert_xxlarge'\n",
|
398 |
+
"}\n",
|
399 |
+
"\n",
|
400 |
+
"folder_albert = folder_albert_dict.get(albert_model_display_name)\n",
|
401 |
+
"folder_albert"
|
402 |
+
]
|
403 |
+
},
|
404 |
+
{
|
405 |
+
"cell_type": "markdown",
|
406 |
+
"metadata": {
|
407 |
+
"id": "ftXwmObdU2fS"
|
408 |
+
},
|
409 |
+
"source": [
|
410 |
+
"### Construct ALBERT Model Using the New `params.yaml`\n",
|
411 |
+
"\n",
|
412 |
+
"params.yaml can be used for training with the bundled trainer in addition to constructing the BERT encoder here."
|
413 |
+
]
|
414 |
+
},
|
415 |
+
{
|
416 |
+
"cell_type": "code",
|
417 |
+
"execution_count": null,
|
418 |
+
"metadata": {
|
419 |
+
"id": "VXn20q2oU1UJ"
|
420 |
+
},
|
421 |
+
"outputs": [],
|
422 |
+
"source": [
|
423 |
+
"config_file = os.path.join(folder_albert, \"params.yaml\")\n",
|
424 |
+
"config_dict = yaml.safe_load(tf.io.gfile.GFile(config_file).read())\n",
|
425 |
+
"config_dict"
|
426 |
+
]
|
427 |
+
},
|
428 |
+
{
|
429 |
+
"cell_type": "code",
|
430 |
+
"execution_count": null,
|
431 |
+
"metadata": {
|
432 |
+
"id": "Uo_TSMSvWOX_"
|
433 |
+
},
|
434 |
+
"outputs": [],
|
435 |
+
"source": [
|
436 |
+
"# Method 1: pass encoder config dict into EncoderConfig\n",
|
437 |
+
"encoder_config = tfm.nlp.encoders.EncoderConfig(config_dict[\"task\"][\"model\"][\"encoder\"])\n",
|
438 |
+
"encoder_config.get().as_dict()"
|
439 |
+
]
|
440 |
+
},
|
441 |
+
{
|
442 |
+
"cell_type": "code",
|
443 |
+
"execution_count": null,
|
444 |
+
"metadata": {
|
445 |
+
"id": "u7oJe93uWcy0"
|
446 |
+
},
|
447 |
+
"outputs": [],
|
448 |
+
"source": [
|
449 |
+
"# Method 2: use override_params_dict function to override default Encoder params\n",
|
450 |
+
"encoder_config = tfm.nlp.encoders.EncoderConfig()\n",
|
451 |
+
"tfm.hyperparams.override_params_dict(encoder_config, config_dict[\"task\"][\"model\"][\"encoder\"], is_strict=True)\n",
|
452 |
+
"encoder_config.get().as_dict()"
|
453 |
+
]
|
454 |
+
},
|
455 |
+
{
|
456 |
+
"cell_type": "markdown",
|
457 |
+
"metadata": {
|
458 |
+
"id": "abpQFw80Wx6c"
|
459 |
+
},
|
460 |
+
"source": [
|
461 |
+
"### Construct ALBERT Model Using the Old `albert_config.json`"
|
462 |
+
]
|
463 |
+
},
|
464 |
+
{
|
465 |
+
"cell_type": "code",
|
466 |
+
"execution_count": null,
|
467 |
+
"metadata": {
|
468 |
+
"id": "Xb99qms6WuPa"
|
469 |
+
},
|
470 |
+
"outputs": [],
|
471 |
+
"source": [
|
472 |
+
"albert_config_file = os.path.join(folder_albert, \"albert_config.json\")\n",
|
473 |
+
"config_dict = json.loads(tf.io.gfile.GFile(albert_config_file).read())\n",
|
474 |
+
"config_dict"
|
475 |
+
]
|
476 |
+
},
|
477 |
+
{
|
478 |
+
"cell_type": "code",
|
479 |
+
"execution_count": null,
|
480 |
+
"metadata": {
|
481 |
+
"id": "mCW0RJHcEtVV"
|
482 |
+
},
|
483 |
+
"outputs": [],
|
484 |
+
"source": [
|
485 |
+
"encoder_config = tfm.nlp.encoders.EncoderConfig({\n",
|
486 |
+
" 'type':'albert',\n",
|
487 |
+
" 'albert': config_dict\n",
|
488 |
+
"})\n",
|
489 |
+
"\n",
|
490 |
+
"encoder_config.get().as_dict()"
|
491 |
+
]
|
492 |
+
},
|
493 |
+
{
|
494 |
+
"cell_type": "markdown",
|
495 |
+
"metadata": {
|
496 |
+
"id": "EIAMaOxdZw5u"
|
497 |
+
},
|
498 |
+
"source": [
|
499 |
+
"### Construct a Classifier with `encoder_config`\n",
|
500 |
+
"\n",
|
501 |
+
"Here, we construct a new BERT Classifier with 2 classes and plot its model architecture. A BERT Classifier consists of a BERT encoder using the selected encoder config, a Dropout layer and a MLP classification head."
|
502 |
+
]
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"cell_type": "code",
|
506 |
+
"execution_count": null,
|
507 |
+
"metadata": {
|
508 |
+
"id": "xTkUisEEFEey"
|
509 |
+
},
|
510 |
+
"outputs": [],
|
511 |
+
"source": [
|
512 |
+
"albert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)\n",
|
513 |
+
"albert_classifier = tfm.nlp.models.BertClassifier(network=albert_encoder, num_classes=2)\n",
|
514 |
+
"\n",
|
515 |
+
"tf.keras.utils.plot_model(albert_classifier)"
|
516 |
+
]
|
517 |
+
},
|
518 |
+
{
|
519 |
+
"cell_type": "markdown",
|
520 |
+
"metadata": {
|
521 |
+
"id": "m6EG_7CaZ2rI"
|
522 |
+
},
|
523 |
+
"source": [
|
524 |
+
"### Load Pretrained Weights into the Classifier\n",
|
525 |
+
"\n",
|
526 |
+
"The provided pretrained checkpoint only contains weights for the ALBERT Encoder within the ALBERT Classifier. Weights for the Classification Head is still randomly initialized."
|
527 |
+
]
|
528 |
+
},
|
529 |
+
{
|
530 |
+
"cell_type": "code",
|
531 |
+
"execution_count": null,
|
532 |
+
"metadata": {
|
533 |
+
"id": "7dOG3agXZ9Dx"
|
534 |
+
},
|
535 |
+
"outputs": [],
|
536 |
+
"source": [
|
537 |
+
"checkpoint = tf.train.Checkpoint(encoder=albert_encoder)\n",
|
538 |
+
"checkpoint.read(\n",
|
539 |
+
" os.path.join(folder_albert, 'bert_model.ckpt')).expect_partial().assert_existing_objects_matched()"
|
540 |
+
]
|
541 |
+
},
|
542 |
+
{
|
543 |
+
"cell_type": "markdown",
|
544 |
+
"metadata": {
|
545 |
+
"id": "6xsbeS-EcCqu"
|
546 |
+
},
|
547 |
+
"source": [
|
548 |
+
"## Load ELECTRA model pretrained checkpoints"
|
549 |
+
]
|
550 |
+
},
|
551 |
+
{
|
552 |
+
"cell_type": "code",
|
553 |
+
"execution_count": null,
|
554 |
+
"metadata": {
|
555 |
+
"id": "VpwIrAR4cIBF"
|
556 |
+
},
|
557 |
+
"outputs": [],
|
558 |
+
"source": [
|
559 |
+
"# @title Download Checkpoint of the Selected Model { display-mode: \"form\", run: \"auto\" }\n",
|
560 |
+
"electra_model_display_name = 'ELECTRA-small English' # @param ['ELECTRA-small English', 'ELECTRA-base English']\n",
|
561 |
+
"\n",
|
562 |
+
"if electra_model_display_name == 'ELECTRA-small English':\n",
|
563 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/electra/small.tar.gz\"\n",
|
564 |
+
" !tar -xvf \"small.tar.gz\"\n",
|
565 |
+
"elif electra_model_display_name == 'ELECTRA-base English':\n",
|
566 |
+
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/electra/base.tar.gz\"\n",
|
567 |
+
" !tar -xvf \"base.tar.gz\""
|
568 |
+
]
|
569 |
+
},
|
570 |
+
{
|
571 |
+
"cell_type": "code",
|
572 |
+
"execution_count": null,
|
573 |
+
"metadata": {
|
574 |
+
"id": "fy4FmsNOhlNa"
|
575 |
+
},
|
576 |
+
"outputs": [],
|
577 |
+
"source": [
|
578 |
+
"# Lookup table of the directory name corresponding to each model checkpoint\n",
|
579 |
+
"folder_electra_dict = {\n",
|
580 |
+
" 'ELECTRA-small English': 'small',\n",
|
581 |
+
" 'ELECTRA-base English': 'base'\n",
|
582 |
+
"}\n",
|
583 |
+
"\n",
|
584 |
+
"folder_electra = folder_electra_dict.get(electra_model_display_name)\n",
|
585 |
+
"folder_electra"
|
586 |
+
]
|
587 |
+
},
|
588 |
+
{
|
589 |
+
"cell_type": "markdown",
|
590 |
+
"metadata": {
|
591 |
+
"id": "rgAcf-Fl3RTG"
|
592 |
+
},
|
593 |
+
"source": [
|
594 |
+
"### Construct BERT Model Using the `params.yaml`\n",
|
595 |
+
"\n",
|
596 |
+
"params.yaml can be used for training with the bundled trainer in addition to constructing the BERT encoder here."
|
597 |
+
]
|
598 |
+
},
|
599 |
+
{
|
600 |
+
"cell_type": "code",
|
601 |
+
"execution_count": null,
|
602 |
+
"metadata": {
|
603 |
+
"id": "ZNBg5xzqh0Gr"
|
604 |
+
},
|
605 |
+
"outputs": [],
|
606 |
+
"source": [
|
607 |
+
"config_file = os.path.join(folder_electra, \"params.yaml\")\n",
|
608 |
+
"config_dict = yaml.safe_load(tf.io.gfile.GFile(config_file).read())\n",
|
609 |
+
"config_dict"
|
610 |
+
]
|
611 |
+
},
|
612 |
+
{
|
613 |
+
"cell_type": "code",
|
614 |
+
"execution_count": null,
|
615 |
+
"metadata": {
|
616 |
+
"id": "i-yX-KgJyduv"
|
617 |
+
},
|
618 |
+
"outputs": [],
|
619 |
+
"source": [
|
620 |
+
"disc_encoder_config = tfm.nlp.encoders.EncoderConfig(\n",
|
621 |
+
" config_dict['model']['discriminator_encoder']\n",
|
622 |
+
")\n",
|
623 |
+
"\n",
|
624 |
+
"disc_encoder_config.get().as_dict()"
|
625 |
+
]
|
626 |
+
},
|
627 |
+
{
|
628 |
+
"cell_type": "markdown",
|
629 |
+
"metadata": {
|
630 |
+
"id": "1AdrMkH73VYz"
|
631 |
+
},
|
632 |
+
"source": [
|
633 |
+
"### Construct a Classifier with `encoder_config`\n",
|
634 |
+
"\n",
|
635 |
+
"Here, we construct a Classifier with 2 classes and plot its model architecture. A Classifier consists of a ELECTRA discriminator encoder using the selected encoder config, a Dropout layer and a MLP classification head.\n",
|
636 |
+
"\n",
|
637 |
+
"**Note**: The generator is discarded and the discriminator is used for downstream tasks"
|
638 |
+
]
|
639 |
+
},
|
640 |
+
{
|
641 |
+
"cell_type": "code",
|
642 |
+
"execution_count": null,
|
643 |
+
"metadata": {
|
644 |
+
"id": "98Pt-SxszAvN"
|
645 |
+
},
|
646 |
+
"outputs": [],
|
647 |
+
"source": [
|
648 |
+
"disc_encoder = tfm.nlp.encoders.build_encoder(disc_encoder_config)\n",
|
649 |
+
"elctra_dic_classifier = tfm.nlp.models.BertClassifier(network=disc_encoder, num_classes=2)\n",
|
650 |
+
"tf.keras.utils.plot_model(elctra_dic_classifier)"
|
651 |
+
]
|
652 |
+
},
|
653 |
+
{
|
654 |
+
"cell_type": "markdown",
|
655 |
+
"metadata": {
|
656 |
+
"id": "aWQ2FKj64X5U"
|
657 |
+
},
|
658 |
+
"source": [
|
659 |
+
"### Load Pretrained Weights into the Classifier\n",
|
660 |
+
"\n",
|
661 |
+
"The provided pretrained checkpoint contains weights for the entire ELECTRA model. We are only loading its discriminator (conveninently named as `encoder`) wights within the Classifier. Weights for the Classification Head is still randomly initialized."
|
662 |
+
]
|
663 |
+
},
|
664 |
+
{
|
665 |
+
"cell_type": "code",
|
666 |
+
"execution_count": null,
|
667 |
+
"metadata": {
|
668 |
+
"id": "99pznFJszQfV"
|
669 |
+
},
|
670 |
+
"outputs": [],
|
671 |
+
"source": [
|
672 |
+
"checkpoint = tf.train.Checkpoint(encoder=disc_encoder)\n",
|
673 |
+
"checkpoint.read(\n",
|
674 |
+
" tf.train.latest_checkpoint(os.path.join(folder_electra))\n",
|
675 |
+
" ).expect_partial().assert_existing_objects_matched()"
|
676 |
+
]
|
677 |
+
}
|
678 |
+
],
|
679 |
+
"metadata": {
|
680 |
+
"colab": {
|
681 |
+
"name": "load_lm_ckpts.ipynb",
|
682 |
+
"provenance": [],
|
683 |
+
"toc_visible": true
|
684 |
+
},
|
685 |
+
"kernelspec": {
|
686 |
+
"display_name": "Python 3",
|
687 |
+
"name": "python3"
|
688 |
+
}
|
689 |
+
},
|
690 |
+
"nbformat": 4,
|
691 |
+
"nbformat_minor": 0
|
692 |
+
}
|
Tensorflow/models/docs/orbit/index.ipynb
ADDED
@@ -0,0 +1,898 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "Tce3stUlHN0L"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"##### Copyright 2020 The TensorFlow Authors."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"cellView": "form",
|
17 |
+
"id": "tuOe1ymfHZPu"
|
18 |
+
},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
22 |
+
"# you may not use this file except in compliance with the License.\n",
|
23 |
+
"# You may obtain a copy of the License at\n",
|
24 |
+
"#\n",
|
25 |
+
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
26 |
+
"#\n",
|
27 |
+
"# Unless required by applicable law or agreed to in writing, software\n",
|
28 |
+
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
29 |
+
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
30 |
+
"# See the License for the specific language governing permissions and\n",
|
31 |
+
"# limitations under the License."
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "markdown",
|
36 |
+
"metadata": {
|
37 |
+
"id": "qFdPvlXBOdUN"
|
38 |
+
},
|
39 |
+
"source": [
|
40 |
+
"# Training with Orbit"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"metadata": {
|
46 |
+
"id": "MfBg1C5NB3X0"
|
47 |
+
},
|
48 |
+
"source": [
|
49 |
+
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
|
50 |
+
" \u003ctd\u003e\n",
|
51 |
+
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/orbit\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
|
52 |
+
" \u003c/td\u003e\n",
|
53 |
+
" \u003ctd\u003e\n",
|
54 |
+
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
|
55 |
+
" \u003c/td\u003e\n",
|
56 |
+
" \u003ctd\u003e\n",
|
57 |
+
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
|
58 |
+
" \u003c/td\u003e\n",
|
59 |
+
" \u003ctd\u003e\n",
|
60 |
+
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
|
61 |
+
" \u003c/td\u003e\n",
|
62 |
+
"\n",
|
63 |
+
"\u003c/table\u003e"
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"cell_type": "markdown",
|
68 |
+
"metadata": {
|
69 |
+
"id": "456h0idS2Xcq"
|
70 |
+
},
|
71 |
+
"source": [
|
72 |
+
"This example will work through fine-tuning a BERT model using the [Orbit](https://www.tensorflow.org/api_docs/python/orbit) training library.\n",
|
73 |
+
"\n",
|
74 |
+
"Orbit is a flexible, lightweight library designed to make it easy to write [custom training loops](https://www.tensorflow.org/tutorials/distribute/custom_training) in TensorFlow. Orbit handles common model training tasks such as saving checkpoints, running model evaluations, and setting up summary writing, while giving users full control over implementing the inner training loop. It integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU).\n",
|
75 |
+
"\n",
|
76 |
+
"Most examples on [tensorflow.org](https://www.tensorflow.org/) use custom training loops or [model.fit()](https://www.tensorflow.org/api_docs/python/tf/keras/Model) from Keras. Orbit is a good alternative to `model.fit` if your model is complex and your training loop requires more flexibility, control, or customization. Also, using Orbit can simplify the code when there are many different model architectures that all use the same custom training loop.\n",
|
77 |
+
"\n",
|
78 |
+
"This tutorial focuses on setting up and using Orbit, rather than details about BERT, model construction, and data processing. For more in-depth tutorials on these topics, refer to the following tutorials:\n",
|
79 |
+
"\n",
|
80 |
+
"* [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) - which goes into detail on these sub-topics.\n",
|
81 |
+
"* [Fine tune BERT for GLUE on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) - which generalizes the code to run any BERT configuration on any [GLUE](https://www.tensorflow.org/datasets/catalog/glue) sub-task, and runs on TPU."
|
82 |
+
]
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"cell_type": "markdown",
|
86 |
+
"metadata": {
|
87 |
+
"id": "TJ4m3khW3p_W"
|
88 |
+
},
|
89 |
+
"source": [
|
90 |
+
"## Install the TensorFlow Models package\n",
|
91 |
+
"\n",
|
92 |
+
"Install and import the necessary packages, then configure all the objects necessary for training a model.\n",
|
93 |
+
"\n"
|
94 |
+
]
|
95 |
+
},
|
96 |
+
{
|
97 |
+
"cell_type": "code",
|
98 |
+
"execution_count": null,
|
99 |
+
"metadata": {
|
100 |
+
"id": "FZlj0U8Aq9Gt"
|
101 |
+
},
|
102 |
+
"outputs": [],
|
103 |
+
"source": [
|
104 |
+
"!pip install -q opencv-python\n",
|
105 |
+
"!pip install tensorflow>=2.9.0 tf-models-official"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "markdown",
|
110 |
+
"metadata": {
|
111 |
+
"id": "MEJkRrmapr16"
|
112 |
+
},
|
113 |
+
"source": [
|
114 |
+
"The `tf-models-official` package contains both the `orbit` and `tensorflow_models` modules."
|
115 |
+
]
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"cell_type": "code",
|
119 |
+
"execution_count": null,
|
120 |
+
"metadata": {
|
121 |
+
"id": "dUVPW84Zucuq"
|
122 |
+
},
|
123 |
+
"outputs": [],
|
124 |
+
"source": [
|
125 |
+
"import tensorflow_models as tfm\n",
|
126 |
+
"import orbit"
|
127 |
+
]
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"cell_type": "markdown",
|
131 |
+
"metadata": {
|
132 |
+
"id": "18Icocf3lwYD"
|
133 |
+
},
|
134 |
+
"source": [
|
135 |
+
"## Setup for training\n",
|
136 |
+
"\n",
|
137 |
+
"This tutorial does not focus on configuring the environment, building the model and optimizer, and loading data. All these techniques are covered in more detail in the [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) and [Fine tune BERT with GLUE](https://www.tensorflow.org/text/tutorials/bert_glue) tutorials.\n",
|
138 |
+
"\n",
|
139 |
+
"To view how the training is set up for this tutorial, expand the rest of this section.\n",
|
140 |
+
"\n",
|
141 |
+
" \u003c!-- \u003cdiv class=\"tfo-display-only-on-site\"\u003e\u003cdevsite-expandable\u003e\n",
|
142 |
+
" \u003cbutton type=\"button\" class=\"button-red button expand-control\"\u003eExpand Section\u003c/button\u003e --\u003e"
|
143 |
+
]
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"cell_type": "markdown",
|
147 |
+
"metadata": {
|
148 |
+
"id": "Ljy0z-i3okCS"
|
149 |
+
},
|
150 |
+
"source": [
|
151 |
+
"### Import the necessary packages\n",
|
152 |
+
"\n",
|
153 |
+
"Import the BERT model and dataset building library from [Tensorflow Model Garden](https://github.com/tensorflow/models)."
|
154 |
+
]
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"cell_type": "code",
|
158 |
+
"execution_count": null,
|
159 |
+
"metadata": {
|
160 |
+
"id": "gCBo6wxA2b5n"
|
161 |
+
},
|
162 |
+
"outputs": [],
|
163 |
+
"source": [
|
164 |
+
"import glob\n",
|
165 |
+
"import os\n",
|
166 |
+
"import pathlib\n",
|
167 |
+
"import tempfile\n",
|
168 |
+
"import time\n",
|
169 |
+
"\n",
|
170 |
+
"import numpy as np\n",
|
171 |
+
"\n",
|
172 |
+
"import tensorflow as tf"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "code",
|
177 |
+
"execution_count": null,
|
178 |
+
"metadata": {
|
179 |
+
"id": "PG1kwhnvq3VC"
|
180 |
+
},
|
181 |
+
"outputs": [],
|
182 |
+
"source": [
|
183 |
+
"from official.nlp.data import sentence_prediction_dataloader\n",
|
184 |
+
"from official.nlp import optimization"
|
185 |
+
]
|
186 |
+
},
|
187 |
+
{
|
188 |
+
"cell_type": "markdown",
|
189 |
+
"metadata": {
|
190 |
+
"id": "PsbhUV_p3wxN"
|
191 |
+
},
|
192 |
+
"source": [
|
193 |
+
"### Configure the distribution strategy\n",
|
194 |
+
"\n",
|
195 |
+
"While `tf.distribute` won't help the model's runtime if you're running on a single machine or GPU, it's necessary for TPUs. Setting up a distribution strategy allows you to use the same code regardless of the configuration."
|
196 |
+
]
|
197 |
+
},
|
198 |
+
{
|
199 |
+
"cell_type": "code",
|
200 |
+
"execution_count": null,
|
201 |
+
"metadata": {
|
202 |
+
"id": "PG702dqstXIk"
|
203 |
+
},
|
204 |
+
"outputs": [],
|
205 |
+
"source": [
|
206 |
+
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
|
207 |
+
"\n",
|
208 |
+
"if 'GPU' in ''.join(logical_device_names):\n",
|
209 |
+
" strategy = tf.distribute.MirroredStrategy()\n",
|
210 |
+
"elif 'TPU' in ''.join(logical_device_names):\n",
|
211 |
+
" resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n",
|
212 |
+
" tf.config.experimental_connect_to_cluster(resolver)\n",
|
213 |
+
" tf.tpu.experimental.initialize_tpu_system(resolver)\n",
|
214 |
+
" strategy = tf.distribute.TPUStrategy(resolver)\n",
|
215 |
+
"else:\n",
|
216 |
+
" strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n"
|
217 |
+
]
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"cell_type": "markdown",
|
221 |
+
"metadata": {
|
222 |
+
"id": "eaQgM98deAMu"
|
223 |
+
},
|
224 |
+
"source": [
|
225 |
+
"For more information about the TPU setup, refer to the [TPU guide](https://www.tensorflow.org/guide/tpu)."
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "markdown",
|
230 |
+
"metadata": {
|
231 |
+
"id": "7aOxMLLV32Zm"
|
232 |
+
},
|
233 |
+
"source": [
|
234 |
+
"### Create a model and an optimizer"
|
235 |
+
]
|
236 |
+
},
|
237 |
+
{
|
238 |
+
"cell_type": "code",
|
239 |
+
"execution_count": null,
|
240 |
+
"metadata": {
|
241 |
+
"id": "YRdWzOfK3_56"
|
242 |
+
},
|
243 |
+
"outputs": [],
|
244 |
+
"source": [
|
245 |
+
"max_seq_length = 128\n",
|
246 |
+
"learning_rate = 3e-5\n",
|
247 |
+
"num_train_epochs = 3\n",
|
248 |
+
"train_batch_size = 32\n",
|
249 |
+
"eval_batch_size = 64\n",
|
250 |
+
"\n",
|
251 |
+
"train_data_size = 3668\n",
|
252 |
+
"steps_per_epoch = int(train_data_size / train_batch_size)\n",
|
253 |
+
"\n",
|
254 |
+
"train_steps = steps_per_epoch * num_train_epochs\n",
|
255 |
+
"warmup_steps = int(train_steps * 0.1)\n",
|
256 |
+
"\n",
|
257 |
+
"print(\"train batch size: \", train_batch_size)\n",
|
258 |
+
"print(\"train epochs: \", num_train_epochs)\n",
|
259 |
+
"print(\"steps_per_epoch: \", steps_per_epoch)"
|
260 |
+
]
|
261 |
+
},
|
262 |
+
{
|
263 |
+
"cell_type": "code",
|
264 |
+
"execution_count": null,
|
265 |
+
"metadata": {
|
266 |
+
"id": "BVw3886Ysse6"
|
267 |
+
},
|
268 |
+
"outputs": [],
|
269 |
+
"source": [
|
270 |
+
"model_dir = pathlib.Path(tempfile.mkdtemp())\n",
|
271 |
+
"print(model_dir)"
|
272 |
+
]
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"cell_type": "markdown",
|
276 |
+
"metadata": {
|
277 |
+
"id": "mu9cV7ew-cVe"
|
278 |
+
},
|
279 |
+
"source": [
|
280 |
+
"\n",
|
281 |
+
"Create a BERT Classifier model and a simple optimizer. They must be created inside `strategy.scope` so that the variables can be distributed. "
|
282 |
+
]
|
283 |
+
},
|
284 |
+
{
|
285 |
+
"cell_type": "code",
|
286 |
+
"execution_count": null,
|
287 |
+
"metadata": {
|
288 |
+
"id": "gmwtX0cp-mj5"
|
289 |
+
},
|
290 |
+
"outputs": [],
|
291 |
+
"source": [
|
292 |
+
"with strategy.scope():\n",
|
293 |
+
" encoder_network = tfm.nlp.encoders.build_encoder(\n",
|
294 |
+
" tfm.nlp.encoders.EncoderConfig(type=\"bert\"))\n",
|
295 |
+
" classifier_model = tfm.nlp.models.BertClassifier(\n",
|
296 |
+
" network=encoder_network, num_classes=2)\n",
|
297 |
+
"\n",
|
298 |
+
" optimizer = optimization.create_optimizer(\n",
|
299 |
+
" init_lr=3e-5,\n",
|
300 |
+
" num_train_steps=steps_per_epoch * num_train_epochs,\n",
|
301 |
+
" num_warmup_steps=warmup_steps,\n",
|
302 |
+
" end_lr=0.0,\n",
|
303 |
+
" optimizer_type='adamw')"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"cell_type": "code",
|
308 |
+
"execution_count": null,
|
309 |
+
"metadata": {
|
310 |
+
"id": "jwJSfewG5jVV"
|
311 |
+
},
|
312 |
+
"outputs": [],
|
313 |
+
"source": [
|
314 |
+
"tf.keras.utils.plot_model(classifier_model)"
|
315 |
+
]
|
316 |
+
},
|
317 |
+
{
|
318 |
+
"cell_type": "markdown",
|
319 |
+
"metadata": {
|
320 |
+
"id": "IQy5pYgAf8Ft"
|
321 |
+
},
|
322 |
+
"source": [
|
323 |
+
"### Initialize from a Checkpoint"
|
324 |
+
]
|
325 |
+
},
|
326 |
+
{
|
327 |
+
"cell_type": "code",
|
328 |
+
"execution_count": null,
|
329 |
+
"metadata": {
|
330 |
+
"id": "6CE14GEybgRR"
|
331 |
+
},
|
332 |
+
"outputs": [],
|
333 |
+
"source": [
|
334 |
+
"bert_dir = 'gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12/'\n",
|
335 |
+
"tf.io.gfile.listdir(bert_dir)"
|
336 |
+
]
|
337 |
+
},
|
338 |
+
{
|
339 |
+
"cell_type": "code",
|
340 |
+
"execution_count": null,
|
341 |
+
"metadata": {
|
342 |
+
"id": "x7fwxz9xidKt"
|
343 |
+
},
|
344 |
+
"outputs": [],
|
345 |
+
"source": [
|
346 |
+
"bert_checkpoint = bert_dir + 'bert_model.ckpt'"
|
347 |
+
]
|
348 |
+
},
|
349 |
+
{
|
350 |
+
"cell_type": "code",
|
351 |
+
"execution_count": null,
|
352 |
+
"metadata": {
|
353 |
+
"id": "q7EfwVCRe7N_"
|
354 |
+
},
|
355 |
+
"outputs": [],
|
356 |
+
"source": [
|
357 |
+
"def init_from_ckpt_fn():\n",
|
358 |
+
" init_checkpoint = tf.train.Checkpoint(**classifier_model.checkpoint_items)\n",
|
359 |
+
" with strategy.scope():\n",
|
360 |
+
" (init_checkpoint\n",
|
361 |
+
" .read(bert_checkpoint)\n",
|
362 |
+
" .expect_partial()\n",
|
363 |
+
" .assert_existing_objects_matched())"
|
364 |
+
]
|
365 |
+
},
|
366 |
+
{
|
367 |
+
"cell_type": "code",
|
368 |
+
"execution_count": null,
|
369 |
+
"metadata": {
|
370 |
+
"id": "M0LUMlsde-2f"
|
371 |
+
},
|
372 |
+
"outputs": [],
|
373 |
+
"source": [
|
374 |
+
"with strategy.scope():\n",
|
375 |
+
" init_from_ckpt_fn()"
|
376 |
+
]
|
377 |
+
},
|
378 |
+
{
|
379 |
+
"cell_type": "markdown",
|
380 |
+
"metadata": {
|
381 |
+
"id": "gAuns4vN_IYV"
|
382 |
+
},
|
383 |
+
"source": [
|
384 |
+
"\n",
|
385 |
+
"To use Orbit, create a `tf.train.CheckpointManager` object."
|
386 |
+
]
|
387 |
+
},
|
388 |
+
{
|
389 |
+
"cell_type": "code",
|
390 |
+
"execution_count": null,
|
391 |
+
"metadata": {
|
392 |
+
"id": "i7NwM1Jq_MX7"
|
393 |
+
},
|
394 |
+
"outputs": [],
|
395 |
+
"source": [
|
396 |
+
"checkpoint = tf.train.Checkpoint(model=classifier_model, optimizer=optimizer)\n",
|
397 |
+
"checkpoint_manager = tf.train.CheckpointManager(\n",
|
398 |
+
" checkpoint,\n",
|
399 |
+
" directory=model_dir,\n",
|
400 |
+
" max_to_keep=5,\n",
|
401 |
+
" step_counter=optimizer.iterations,\n",
|
402 |
+
" checkpoint_interval=steps_per_epoch,\n",
|
403 |
+
" init_fn=init_from_ckpt_fn)"
|
404 |
+
]
|
405 |
+
},
|
406 |
+
{
|
407 |
+
"cell_type": "markdown",
|
408 |
+
"metadata": {
|
409 |
+
"id": "nzeiAFhcCOAo"
|
410 |
+
},
|
411 |
+
"source": [
|
412 |
+
"### Create distributed datasets\n",
|
413 |
+
"\n",
|
414 |
+
"As a shortcut for this tutorial, the [GLUE/MPRC dataset](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc) has been converted to a pair of [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) files containing serialized `tf.train.Example` protos.\n",
|
415 |
+
"\n",
|
416 |
+
"The data was converted using [this script](https://github.com/tensorflow/models/blob/r2.9.0/official/nlp/data/create_finetuning_data.py).\n",
|
417 |
+
"\n"
|
418 |
+
]
|
419 |
+
},
|
420 |
+
{
|
421 |
+
"cell_type": "code",
|
422 |
+
"execution_count": null,
|
423 |
+
"metadata": {
|
424 |
+
"id": "ZVfbiT1dCnDk"
|
425 |
+
},
|
426 |
+
"outputs": [],
|
427 |
+
"source": [
|
428 |
+
"train_data_path = \"gs://download.tensorflow.org/data/model_garden_colab/mrpc_train.tf_record\"\n",
|
429 |
+
"eval_data_path = \"gs://download.tensorflow.org/data/model_garden_colab/mrpc_eval.tf_record\"\n",
|
430 |
+
"\n",
|
431 |
+
"def _dataset_fn(input_file_pattern, \n",
|
432 |
+
" global_batch_size, \n",
|
433 |
+
" is_training, \n",
|
434 |
+
" input_context=None):\n",
|
435 |
+
" data_config = sentence_prediction_dataloader.SentencePredictionDataConfig(\n",
|
436 |
+
" input_path=input_file_pattern,\n",
|
437 |
+
" seq_length=max_seq_length,\n",
|
438 |
+
" global_batch_size=global_batch_size,\n",
|
439 |
+
" is_training=is_training)\n",
|
440 |
+
" return sentence_prediction_dataloader.SentencePredictionDataLoader(\n",
|
441 |
+
" data_config).load(input_context=input_context)\n",
|
442 |
+
"\n",
|
443 |
+
"train_dataset = orbit.utils.make_distributed_dataset(\n",
|
444 |
+
" strategy, _dataset_fn, input_file_pattern=train_data_path,\n",
|
445 |
+
" global_batch_size=train_batch_size, is_training=True)\n",
|
446 |
+
"eval_dataset = orbit.utils.make_distributed_dataset(\n",
|
447 |
+
" strategy, _dataset_fn, input_file_pattern=eval_data_path,\n",
|
448 |
+
" global_batch_size=eval_batch_size, is_training=False)\n",
|
449 |
+
"\n"
|
450 |
+
]
|
451 |
+
},
|
452 |
+
{
|
453 |
+
"cell_type": "markdown",
|
454 |
+
"metadata": {
|
455 |
+
"id": "dPgiDBQCjsXW"
|
456 |
+
},
|
457 |
+
"source": [
|
458 |
+
"### Create a loss function\n"
|
459 |
+
]
|
460 |
+
},
|
461 |
+
{
|
462 |
+
"cell_type": "code",
|
463 |
+
"execution_count": null,
|
464 |
+
"metadata": {
|
465 |
+
"id": "7MCUmmo2jvXl"
|
466 |
+
},
|
467 |
+
"outputs": [],
|
468 |
+
"source": [
|
469 |
+
"def loss_fn(labels, logits):\n",
|
470 |
+
" \"\"\"Classification loss.\"\"\"\n",
|
471 |
+
" labels = tf.squeeze(labels)\n",
|
472 |
+
" log_probs = tf.nn.log_softmax(logits, axis=-1)\n",
|
473 |
+
" one_hot_labels = tf.one_hot(\n",
|
474 |
+
" tf.cast(labels, dtype=tf.int32), depth=2, dtype=tf.float32)\n",
|
475 |
+
" per_example_loss = -tf.reduce_sum(\n",
|
476 |
+
" tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)\n",
|
477 |
+
" return tf.reduce_mean(per_example_loss)"
|
478 |
+
]
|
479 |
+
},
|
480 |
+
{
|
481 |
+
"cell_type": "markdown",
|
482 |
+
"metadata": {
|
483 |
+
"id": "ohlO-8FQkwsr"
|
484 |
+
},
|
485 |
+
"source": [
|
486 |
+
" \u003c/devsite-expandable\u003e\u003c/div\u003e"
|
487 |
+
]
|
488 |
+
},
|
489 |
+
{
|
490 |
+
"cell_type": "markdown",
|
491 |
+
"metadata": {
|
492 |
+
"id": "ymhbvPaEJ96T"
|
493 |
+
},
|
494 |
+
"source": [
|
495 |
+
"## Controllers, Trainers and Evaluators\n",
|
496 |
+
"\n",
|
497 |
+
"When using Orbit, the `orbit.Controller` class drives the training. The Controller handles the details of distribution strategies, step counting, TensorBoard summaries, and checkpointing.\n",
|
498 |
+
"\n",
|
499 |
+
"To implement the training and evaluation, pass a `trainer` and `evaluator`, which are subclass instances of `orbit.AbstractTrainer` and `orbit.AbstractEvaluator`. Keeping with Orbit's light-weight design, these two classes have a minimal interface.\n",
|
500 |
+
"\n",
|
501 |
+
"The Controller drives training and evaluation by calling `trainer.train(num_steps)` and `evaluator.evaluate(num_steps)`. These `train` and `evaluate` methods return a dictionary of results for logging.\n",
|
502 |
+
"\n"
|
503 |
+
]
|
504 |
+
},
|
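The contract is small enough to sketch in full. The toy trainer below is an illustration only (nothing in it is used later in this tutorial): it satisfies `orbit.AbstractTrainer` by implementing `train`, and the Controller would log whatever dict it returns.

```python
# A minimal sketch of the orbit.AbstractTrainer contract: the Controller
# repeatedly calls train(num_steps) and logs the returned dictionary.
class CountingTrainer(orbit.AbstractTrainer):

  def __init__(self):
    self.step = tf.Variable(0, dtype=tf.int64)

  def train(self, num_steps):
    # A real trainer runs `num_steps` optimizer updates here; this one
    # just advances a counter and reports it for logging.
    self.step.assign_add(tf.cast(num_steps, tf.int64))
    return {"global_step": self.step.numpy()}
```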
505 |
+
{
|
506 |
+
"cell_type": "markdown",
|
507 |
+
"metadata": {
|
508 |
+
"id": "a6sU2vBeyXtu"
|
509 |
+
},
|
510 |
+
"source": [
|
511 |
+
"Training is broken into chunks of length `num_steps`. This is set by the Controller's [`steps_per_loop`](https://tensorflow.org/api_docs/python/orbit/Controller#args) argument. With the trainer and evaluator abstract base classes, the meaning of `num_steps` is entirely determined by the implementer.\n",
|
512 |
+
"\n",
|
513 |
+
"Some common examples include:\n",
|
514 |
+
"\n",
|
515 |
+
"* Having the chunks represent dataset-epoch boundaries, like the default keras setup. \n",
|
516 |
+
"* Using it to more efficiently dispatch a number of training steps to an accelerator with a single `tf.function` call (like the `steps_per_execution` argument to `Model.compile`). \n",
|
517 |
+
"* Subdividing into smaller chunks as needed.\n"
|
518 |
+
]
|
519 |
+
},
|
520 |
+
{
|
521 |
+
"cell_type": "markdown",
|
522 |
+
"metadata": {
|
523 |
+
"id": "p4mXGIRJsf1j"
|
524 |
+
},
|
525 |
+
"source": [
|
526 |
+
"### StandardTrainer and StandardEvaluator\n",
|
527 |
+
"\n",
|
528 |
+
"Orbit provides two additional classes, `orbit.StandardTrainer` and `orbit.StandardEvaluator`, to give more structure around the training and evaluation loops.\n",
|
529 |
+
"\n",
|
530 |
+
"With StandardTrainer, you only need to set `train_loop_begin`, `train_step`, and `train_loop_end`. The base class handles the loops, dataset logic, and `tf.function` (according to the options set by their `orbit.StandardTrainerOptions`). This is simpler than `orbit.AbstractTrainer`, which requires you to handle the entire loop. StandardEvaluator has a similar structure and simplification to StandardTrainer.\n",
|
531 |
+
"\n",
|
532 |
+
"This is effectively an implementation of the `steps_per_execution` approach used by Keras."
|
533 |
+
]
|
534 |
+
},
|
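Those options are passed to the base-class constructor. A hedged sketch (the field names below match the Orbit release this tutorial targets; verify with `help(orbit.StandardTrainerOptions)` if they have moved):

```python
# A minimal sketch: keep tf.function tracing but run the inner loop in Python.
options = orbit.StandardTrainerOptions(
    use_tf_function=True,     # wrap train_step in a tf.function
    use_tf_while_loop=False)  # easier debugging, at some speed cost

# The options object is forwarded when initializing the base class, e.g.:
# orbit.StandardTrainer.__init__(self, train_dataset, options=options)
```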
535 |
+
{
|
536 |
+
"cell_type": "markdown",
|
537 |
+
"metadata": {
|
538 |
+
"id": "-hvZ8PvohmR5"
|
539 |
+
},
|
540 |
+
"source": [
|
541 |
+
"Contrast this with Keras, where training is divided both into epochs (a single pass over the dataset) and `steps_per_execution`(set within [`Model.compile`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile). In Keras, metric averages are typically accumulated over an epoch, and reported \u0026 reset between epochs. For efficiency, `steps_per_execution` only controls the number of training steps made per call.\n",
|
542 |
+
"\n",
|
543 |
+
"In this simple case, `steps_per_loop` (within `StandardTrainer`) will handle both the metric resets and the number of steps per call. \n"
|
544 |
+
]
|
545 |
+
},
|
546 |
+
{
|
547 |
+
"cell_type": "markdown",
|
548 |
+
"metadata": {
|
549 |
+
"id": "NoDFN1L-1jIu"
|
550 |
+
},
|
551 |
+
"source": [
|
552 |
+
"The minimal setup when using these base classes is to implement the methods as follows:\n",
|
553 |
+
"\n",
|
554 |
+
"1. `StandardTrainer.train_loop_begin` - Reset your training metrics.\n",
|
555 |
+
"2. `StandardTrainer.train_step` - Apply a single gradient update.\n",
|
556 |
+
"3. `StandardTrainer.train_loop_end` - Report your training metrics.\n",
|
557 |
+
"\n",
|
558 |
+
"and\n",
|
559 |
+
"\n",
|
560 |
+
"4. `StandardEvaluator.eval_begin` - Reset your evaluation metrics.\n",
|
561 |
+
"5. `StandardEvaluator.eval_step` - Run a single evaluation setep.\n",
|
562 |
+
"6. `StandardEvaluator.eval_reduce` - This is not necessary in this simple setup.\n",
|
563 |
+
"7. `StandardEvaluator.eval_end` - Report your evaluation metrics.\n",
|
564 |
+
"\n",
|
565 |
+
"Depending on the settings, the base class may wrap the `train_step` and `eval_step` code in `tf.function` or `tf.while_loop`, which has some limitations compared to standard python."
|
566 |
+
]
|
567 |
+
},
|
568 |
+
{
|
569 |
+
"cell_type": "markdown",
|
570 |
+
"metadata": {
|
571 |
+
"id": "3KPA0NDZt2JD"
|
572 |
+
},
|
573 |
+
"source": [
|
574 |
+
"### Define the trainer class"
|
575 |
+
]
|
576 |
+
},
|
577 |
+
{
|
578 |
+
"cell_type": "markdown",
|
579 |
+
"metadata": {
|
580 |
+
"id": "6LDPsvJwfuPR"
|
581 |
+
},
|
582 |
+
"source": [
|
583 |
+
"In this section you'll create a subclass of `orbit.StandardTrainer` for this task. \n",
|
584 |
+
"\n",
|
585 |
+
"Note: To better explain the `BertClassifierTrainer` class, this section defines each method as a stand-alone function and assembles them into a class at the end.\n",
|
586 |
+
"\n",
|
587 |
+
"The trainer needs access to the training data, model, optimizer, and distribution strategy. Pass these as arguments to the initializer.\n",
|
588 |
+
"\n",
|
589 |
+
"Define a single training metric, `training_loss`, using `tf.keras.metrics.Mean`. "
|
590 |
+
]
|
591 |
+
},
|
592 |
+
{
|
593 |
+
"cell_type": "code",
|
594 |
+
"execution_count": null,
|
595 |
+
"metadata": {
|
596 |
+
"id": "6DQYZN5ax-MG"
|
597 |
+
},
|
598 |
+
"outputs": [],
|
599 |
+
"source": [
|
600 |
+
"def trainer_init(self,\n",
|
601 |
+
" train_dataset,\n",
|
602 |
+
" model,\n",
|
603 |
+
" optimizer,\n",
|
604 |
+
" strategy):\n",
|
605 |
+
" self.strategy = strategy\n",
|
606 |
+
" with self.strategy.scope():\n",
|
607 |
+
" self.model = model\n",
|
608 |
+
" self.optimizer = optimizer\n",
|
609 |
+
" self.global_step = self.optimizer.iterations\n",
|
610 |
+
" \n",
|
611 |
+
"\n",
|
612 |
+
" self.train_loss = tf.keras.metrics.Mean(\n",
|
613 |
+
" 'training_loss', dtype=tf.float32)\n",
|
614 |
+
" orbit.StandardTrainer.__init__(self, train_dataset)\n"
|
615 |
+
]
|
616 |
+
},
|
617 |
+
{
|
618 |
+
"cell_type": "markdown",
|
619 |
+
"metadata": {
|
620 |
+
"id": "QOwHD7U5hVue"
|
621 |
+
},
|
622 |
+
"source": [
|
623 |
+
"Before starting a run of the training loop, the `train_loop_begin` method will reset the `train_loss` metric."
|
624 |
+
]
|
625 |
+
},
|
626 |
+
{
|
627 |
+
"cell_type": "code",
|
628 |
+
"execution_count": null,
|
629 |
+
"metadata": {
|
630 |
+
"id": "AkpcHqXShWL0"
|
631 |
+
},
|
632 |
+
"outputs": [],
|
633 |
+
"source": [
|
634 |
+
"def train_loop_begin(self):\n",
|
635 |
+
" self.train_loss.reset_states()"
|
636 |
+
]
|
637 |
+
},
|
638 |
+
{
|
639 |
+
"cell_type": "markdown",
|
640 |
+
"metadata": {
|
641 |
+
"id": "UjtFOFyxn2BB"
|
642 |
+
},
|
643 |
+
"source": [
|
644 |
+
"The `train_step` is a straight-forward loss-calculation and gradient update that is run by the distribution strategy. This is accomplished by defining the gradient step as a nested function (`step_fn`).\n",
|
645 |
+
"\n",
|
646 |
+
"The method receives `tf.distribute.DistributedIterator` to handle the [distributed input](https://www.tensorflow.org/tutorials/distribute/input). The method uses `Strategy.run` to execute `step_fn` and feeds it from the distributed iterator.\n",
|
647 |
+
"\n"
|
648 |
+
]
|
649 |
+
},
|
650 |
+
{
|
651 |
+
"cell_type": "code",
|
652 |
+
"execution_count": null,
|
653 |
+
"metadata": {
|
654 |
+
"id": "QuPwNnT5I-GP"
|
655 |
+
},
|
656 |
+
"outputs": [],
|
657 |
+
"source": [
|
658 |
+
"def train_step(self, iterator):\n",
|
659 |
+
"\n",
|
660 |
+
" def step_fn(inputs):\n",
|
661 |
+
" labels = inputs.pop(\"label_ids\")\n",
|
662 |
+
" with tf.GradientTape() as tape:\n",
|
663 |
+
" model_outputs = self.model(inputs, training=True)\n",
|
664 |
+
" # Raw loss is used for reporting in metrics/logs.\n",
|
665 |
+
" raw_loss = loss_fn(labels, model_outputs)\n",
|
666 |
+
" # Scales down the loss for gradients to be invariant from replicas.\n",
|
667 |
+
" loss = raw_loss / self.strategy.num_replicas_in_sync\n",
|
668 |
+
"\n",
|
669 |
+
" grads = tape.gradient(loss, self.model.trainable_variables)\n",
|
670 |
+
" optimizer.apply_gradients(zip(grads, self.model.trainable_variables))\n",
|
671 |
+
" # For reporting, the metric takes the mean of losses.\n",
|
672 |
+
" self.train_loss.update_state(raw_loss)\n",
|
673 |
+
"\n",
|
674 |
+
" self.strategy.run(step_fn, args=(next(iterator),))"
|
675 |
+
]
|
676 |
+
},
|
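Dividing by `num_replicas_in_sync` works here because the metric separately consumes the unscaled `raw_loss`. An equivalent, arguably more self-documenting scaling uses `tf.nn.compute_average_loss` on the per-example losses, sketched below; `train_batch_size` is the global batch size defined earlier, and `labels`/`model_outputs` are the tensors inside `step_fn`.

```python
# A minimal sketch: replica-aware scaling with tf.nn.compute_average_loss.
log_probs = tf.nn.log_softmax(model_outputs, axis=-1)
one_hot = tf.one_hot(tf.cast(tf.squeeze(labels), tf.int32), depth=2)
per_example_loss = -tf.reduce_sum(one_hot * log_probs, axis=-1)

# Divides by the GLOBAL batch size, so gradients summed across replicas
# average correctly without a manual num_replicas_in_sync division.
loss = tf.nn.compute_average_loss(
    per_example_loss, global_batch_size=train_batch_size)
```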
677 |
+
{
|
678 |
+
"cell_type": "markdown",
|
679 |
+
"metadata": {
|
680 |
+
"id": "VmQNwx5QpyDt"
|
681 |
+
},
|
682 |
+
"source": [
|
683 |
+
"The `orbit.StandardTrainer` handles the `@tf.function` and loops.\n",
|
684 |
+
"\n",
|
685 |
+
"After running through `num_steps` of training, `StandardTrainer` calls `train_loop_end`. The function returns the metric results:"
|
686 |
+
]
|
687 |
+
},
|
688 |
+
{
|
689 |
+
"cell_type": "code",
|
690 |
+
"execution_count": null,
|
691 |
+
"metadata": {
|
692 |
+
"id": "GqCyVk1zzGod"
|
693 |
+
},
|
694 |
+
"outputs": [],
|
695 |
+
"source": [
|
696 |
+
"def train_loop_end(self):\n",
|
697 |
+
" return {\n",
|
698 |
+
" self.train_loss.name: self.train_loss.result(),\n",
|
699 |
+
" }"
|
700 |
+
]
|
701 |
+
},
|
702 |
+
{
|
703 |
+
"cell_type": "markdown",
|
704 |
+
"metadata": {
|
705 |
+
"id": "xvmLONl80KUv"
|
706 |
+
},
|
707 |
+
"source": [
|
708 |
+
"Build a subclass of `orbit.StandardTrainer` with those methods."
|
709 |
+
]
|
710 |
+
},
|
711 |
+
{
|
712 |
+
"cell_type": "code",
|
713 |
+
"execution_count": null,
|
714 |
+
"metadata": {
|
715 |
+
"id": "oRoL7VE6xt1G"
|
716 |
+
},
|
717 |
+
"outputs": [],
|
718 |
+
"source": [
|
719 |
+
"class BertClassifierTrainer(orbit.StandardTrainer):\n",
|
720 |
+
" __init__ = trainer_init\n",
|
721 |
+
" train_loop_begin = train_loop_begin\n",
|
722 |
+
" train_step = train_step\n",
|
723 |
+
" train_loop_end = train_loop_end"
|
724 |
+
]
|
725 |
+
},
|
726 |
+
{
|
727 |
+
"cell_type": "markdown",
|
728 |
+
"metadata": {
|
729 |
+
"id": "yjG4QAWj1B00"
|
730 |
+
},
|
731 |
+
"source": [
|
732 |
+
"### Define the evaluator class\n",
|
733 |
+
"\n",
|
734 |
+
"Note: Like the previous section, this section defines each method as a stand-alone function and assembles them into a `BertClassifierEvaluator` class at the end.\n",
|
735 |
+
"\n",
|
736 |
+
"The evaluator is even simpler for this task. It needs access to the evaluation dataset, the model, and the strategy. After saving references to those objects, the constructor just needs to create the metrics."
|
737 |
+
]
|
738 |
+
},
|
739 |
+
{
|
740 |
+
"cell_type": "code",
|
741 |
+
"execution_count": null,
|
742 |
+
"metadata": {
|
743 |
+
"id": "cvX7seCY1CWj"
|
744 |
+
},
|
745 |
+
"outputs": [],
|
746 |
+
"source": [
|
747 |
+
"def evaluator_init(self,\n",
|
748 |
+
" eval_dataset,\n",
|
749 |
+
" model,\n",
|
750 |
+
" strategy):\n",
|
751 |
+
" self.strategy = strategy\n",
|
752 |
+
" with self.strategy.scope():\n",
|
753 |
+
" self.model = model\n",
|
754 |
+
" \n",
|
755 |
+
" self.eval_loss = tf.keras.metrics.Mean(\n",
|
756 |
+
" 'evaluation_loss', dtype=tf.float32)\n",
|
757 |
+
" self.eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
|
758 |
+
" name='accuracy', dtype=tf.float32)\n",
|
759 |
+
" orbit.StandardEvaluator.__init__(self, eval_dataset)"
|
760 |
+
]
|
761 |
+
},
|
762 |
+
{
|
763 |
+
"cell_type": "markdown",
|
764 |
+
"metadata": {
|
765 |
+
"id": "0r-z-XK7ybyX"
|
766 |
+
},
|
767 |
+
"source": [
|
768 |
+
"Similar to the trainer, the `eval_begin` and `eval_end` methods just need to reset the metrics before the loop and then report the results after the loop."
|
769 |
+
]
|
770 |
+
},
|
771 |
+
{
|
772 |
+
"cell_type": "code",
|
773 |
+
"execution_count": null,
|
774 |
+
"metadata": {
|
775 |
+
"id": "7VVb0Tg6yZjI"
|
776 |
+
},
|
777 |
+
"outputs": [],
|
778 |
+
"source": [
|
779 |
+
"def eval_begin(self):\n",
|
780 |
+
" self.eval_accuracy.reset_states()\n",
|
781 |
+
" self.eval_loss.reset_states()\n",
|
782 |
+
"\n",
|
783 |
+
"def eval_end(self):\n",
|
784 |
+
" return {\n",
|
785 |
+
" self.eval_accuracy.name: self.eval_accuracy.result(),\n",
|
786 |
+
" self.eval_loss.name: self.eval_loss.result(),\n",
|
787 |
+
" }"
|
788 |
+
]
|
789 |
+
},
|
790 |
+
{
|
791 |
+
"cell_type": "markdown",
|
792 |
+
"metadata": {
|
793 |
+
"id": "iDOZcQvttdmZ"
|
794 |
+
},
|
795 |
+
"source": [
|
796 |
+
"The `eval_step` method works like `train_step`. The inner `step_fn` defines the actual work of calculating the loss \u0026 accuracy and updating the metrics. The outer `eval_step` receives `tf.distribute.DistributedIterator` as input, and uses `Strategy.run` to launch the distributed execution to `step_fn`, feeding it from the distributed iterator."
|
797 |
+
]
|
798 |
+
},
|
799 |
+
{
|
800 |
+
"cell_type": "code",
|
801 |
+
"execution_count": null,
|
802 |
+
"metadata": {
|
803 |
+
"id": "JLJnYuuGJjvd"
|
804 |
+
},
|
805 |
+
"outputs": [],
|
806 |
+
"source": [
|
807 |
+
"def eval_step(self, iterator):\n",
|
808 |
+
"\n",
|
809 |
+
" def step_fn(inputs):\n",
|
810 |
+
" labels = inputs.pop(\"label_ids\")\n",
|
811 |
+
" model_outputs = self.model(inputs, training=True)\n",
|
812 |
+
" loss = loss_fn(labels, model_outputs)\n",
|
813 |
+
" self.eval_loss.update_state(loss)\n",
|
814 |
+
" self.eval_accuracy.update_state(labels, model_outputs)\n",
|
815 |
+
"\n",
|
816 |
+
" self.strategy.run(step_fn, args=(next(iterator),))"
|
817 |
+
]
|
818 |
+
},
|
819 |
+
{
|
820 |
+
"cell_type": "markdown",
|
821 |
+
"metadata": {
|
822 |
+
"id": "Gt3hh0V30QcP"
|
823 |
+
},
|
824 |
+
"source": [
|
825 |
+
"Build a subclass of `orbit.StandardEvaluator` with those methods."
|
826 |
+
]
|
827 |
+
},
|
828 |
+
{
|
829 |
+
"cell_type": "code",
|
830 |
+
"execution_count": null,
|
831 |
+
"metadata": {
|
832 |
+
"id": "3zqyLxfNyCgA"
|
833 |
+
},
|
834 |
+
"outputs": [],
|
835 |
+
"source": [
|
836 |
+
"class BertClassifierEvaluator(orbit.StandardEvaluator):\n",
|
837 |
+
" __init__ = evaluator_init\n",
|
838 |
+
" eval_begin = eval_begin\n",
|
839 |
+
" eval_end = eval_end\n",
|
840 |
+
" eval_step = eval_step"
|
841 |
+
]
|
842 |
+
},
|
843 |
+
{
|
844 |
+
"cell_type": "markdown",
|
845 |
+
"metadata": {
|
846 |
+
"id": "aK9gEja9qPOc"
|
847 |
+
},
|
848 |
+
"source": [
|
849 |
+
"### End-to-end training and evaluation\n",
|
850 |
+
"\n",
|
851 |
+
"To run the training and evaluation, simply create the trainer, evaluator, and `orbit.Controller` instances. Then call the `Controller.train_and_evaluate` method."
|
852 |
+
]
|
853 |
+
},
|
854 |
+
{
|
855 |
+
"cell_type": "code",
|
856 |
+
"execution_count": null,
|
857 |
+
"metadata": {
|
858 |
+
"id": "PqQetxyXqRA9"
|
859 |
+
},
|
860 |
+
"outputs": [],
|
861 |
+
"source": [
|
862 |
+
"trainer = BertClassifierTrainer(\n",
|
863 |
+
" train_dataset, classifier_model, optimizer, strategy)\n",
|
864 |
+
"\n",
|
865 |
+
"evaluator = BertClassifierEvaluator(\n",
|
866 |
+
" eval_dataset, classifier_model, strategy)\n",
|
867 |
+
"\n",
|
868 |
+
"controller = orbit.Controller(\n",
|
869 |
+
" trainer=trainer,\n",
|
870 |
+
" evaluator=evaluator,\n",
|
871 |
+
" global_step=trainer.global_step,\n",
|
872 |
+
" steps_per_loop=20,\n",
|
873 |
+
" checkpoint_manager=checkpoint_manager)\n",
|
874 |
+
"\n",
|
875 |
+
"result = controller.train_and_evaluate(\n",
|
876 |
+
" train_steps=steps_per_epoch * num_train_epochs,\n",
|
877 |
+
" eval_steps=-1,\n",
|
878 |
+
" eval_interval=steps_per_epoch)"
|
879 |
+
]
|
880 |
+
}
|
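`train_and_evaluate` already interleaves evaluation every `eval_interval` steps; a final standalone pass can also be run once training completes. As in the call above, `steps=-1` asks the evaluator to consume its (finite) dataset to exhaustion:

```python
# A minimal sketch: one more full evaluation pass after training.
final_metrics = controller.evaluate(steps=-1)
print({name: float(value) for name, value in final_metrics.items()})
```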
881 |
+
],
|
882 |
+
"metadata": {
|
883 |
+
"colab": {
|
884 |
+
"collapsed_sections": [
|
885 |
+
"Tce3stUlHN0L"
|
886 |
+
],
|
887 |
+
"name": "Orbit Tutorial.ipynb",
|
888 |
+
"provenance": [],
|
889 |
+
"toc_visible": true
|
890 |
+
},
|
891 |
+
"kernelspec": {
|
892 |
+
"display_name": "Python 3",
|
893 |
+
"name": "python3"
|
894 |
+
}
|
895 |
+
},
|
896 |
+
"nbformat": 4,
|
897 |
+
"nbformat_minor": 0
|
898 |
+
}
|
Tensorflow/models/docs/vision/_toc.yaml
ADDED
@@ -0,0 +1,9 @@
1 |
+
toc:
|
2 |
+
- title: "Example: Image classification"
|
3 |
+
  path: /tfmodels/vision/image_classification
|
4 |
+
- title: "Example: Object Detection"
|
5 |
+
  path: /tfmodels/vision/object_detection
|
6 |
+
- title: "Example: Semantic Segmentation"
|
7 |
+
  path: /tfmodels/vision/semantic_segmentation
|
8 |
+
- title: "Example: Instance Segmentation"
|
9 |
+
  path: /tfmodels/vision/instance_segmentation
|
Tensorflow/models/docs/vision/image_classification.ipynb
ADDED
@@ -0,0 +1,692 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "Tce3stUlHN0L"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"##### Copyright 2020 The TensorFlow Authors."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"cellView": "form",
|
17 |
+
"id": "tuOe1ymfHZPu"
|
18 |
+
},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
22 |
+
"# you may not use this file except in compliance with the License.\n",
|
23 |
+
"# You may obtain a copy of the License at\n",
|
24 |
+
"#\n",
|
25 |
+
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
26 |
+
"#\n",
|
27 |
+
"# Unless required by applicable law or agreed to in writing, software\n",
|
28 |
+
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
29 |
+
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
30 |
+
"# See the License for the specific language governing permissions and\n",
|
31 |
+
"# limitations under the License."
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "markdown",
|
36 |
+
"metadata": {
|
37 |
+
"id": "qFdPvlXBOdUN"
|
38 |
+
},
|
39 |
+
"source": [
|
40 |
+
"# Image classification with Model Garden"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"metadata": {
|
46 |
+
"id": "MfBg1C5NB3X0"
|
47 |
+
},
|
48 |
+
"source": [
|
49 |
+
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
|
50 |
+
" \u003ctd\u003e\n",
|
51 |
+
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/image_classification\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
|
52 |
+
" \u003c/td\u003e\n",
|
53 |
+
" \u003ctd\u003e\n",
|
54 |
+
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
|
55 |
+
" \u003c/td\u003e\n",
|
56 |
+
" \u003ctd\u003e\n",
|
57 |
+
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
|
58 |
+
" \u003c/td\u003e\n",
|
59 |
+
" \u003ctd\u003e\n",
|
60 |
+
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
|
61 |
+
" \u003c/td\u003e\n",
|
62 |
+
"\u003c/table\u003e"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "markdown",
|
67 |
+
"metadata": {
|
68 |
+
"id": "Ta_nFXaVAqLD"
|
69 |
+
},
|
70 |
+
"source": [
|
71 |
+
"This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow [Model Garden](https://github.com/tensorflow/models) package (`tensorflow-models`) to classify images in the [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.\n",
|
72 |
+
"\n",
|
73 |
+
"Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
|
74 |
+
"\n",
|
75 |
+
"This tutorial uses a [ResNet](https://arxiv.org/pdf/1512.03385.pdf) model, a state-of-the-art image classifier. This tutorial uses the ResNet-18 model, a convolutional neural network with 18 layers.\n",
|
76 |
+
"\n",
|
77 |
+
"This tutorial demonstrates how to:\n",
|
78 |
+
"1. Use models from the TensorFlow Models package.\n",
|
79 |
+
"2. Fine-tune a pre-built ResNet for image classification.\n",
|
80 |
+
"3. Export the tuned ResNet model."
|
81 |
+
]
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"cell_type": "markdown",
|
85 |
+
"metadata": {
|
86 |
+
"id": "G2FlaQcEPOER"
|
87 |
+
},
|
88 |
+
"source": [
|
89 |
+
"## Setup\n",
|
90 |
+
"\n",
|
91 |
+
"Install and import the necessary modules."
|
92 |
+
]
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"cell_type": "code",
|
96 |
+
"execution_count": null,
|
97 |
+
"metadata": {
|
98 |
+
"id": "XvWfdCrvrV5W"
|
99 |
+
},
|
100 |
+
"outputs": [],
|
101 |
+
"source": [
|
102 |
+
"!pip install -U -q \"tf-models-official\""
|
103 |
+
]
|
104 |
+
},
|
105 |
+
{
|
106 |
+
"cell_type": "markdown",
|
107 |
+
"metadata": {
|
108 |
+
"id": "CKYMTPjOE400"
|
109 |
+
},
|
110 |
+
"source": [
|
111 |
+
"Import TensorFlow, TensorFlow Datasets, and a few helper libraries."
|
112 |
+
]
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"cell_type": "code",
|
116 |
+
"execution_count": null,
|
117 |
+
"metadata": {
|
118 |
+
"id": "Wlon1uoIowmZ"
|
119 |
+
},
|
120 |
+
"outputs": [],
|
121 |
+
"source": [
|
122 |
+
"import pprint\n",
|
123 |
+
"import tempfile\n",
|
124 |
+
"\n",
|
125 |
+
"from IPython import display\n",
|
126 |
+
"import matplotlib.pyplot as plt\n",
|
127 |
+
"\n",
|
128 |
+
"import tensorflow as tf\n",
|
129 |
+
"import tensorflow_datasets as tfds"
|
130 |
+
]
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"cell_type": "markdown",
|
134 |
+
"metadata": {
|
135 |
+
"id": "AVTs0jDd1b24"
|
136 |
+
},
|
137 |
+
"source": [
|
138 |
+
"The `tensorflow_models` package contains the ResNet vision model, and the `official.vision.serving` model contains the function to save and export the tuned model."
|
139 |
+
]
|
140 |
+
},
|
141 |
+
{
|
142 |
+
"cell_type": "code",
|
143 |
+
"execution_count": null,
|
144 |
+
"metadata": {
|
145 |
+
"id": "NHT1iiIiBzlC"
|
146 |
+
},
|
147 |
+
"outputs": [],
|
148 |
+
"source": [
|
149 |
+
"import tensorflow_models as tfm\n",
|
150 |
+
"\n",
|
151 |
+
"# These are not in the tfm public API for v2.9. They will be available in v2.10\n",
|
152 |
+
"from official.vision.serving import export_saved_model_lib\n",
|
153 |
+
"import official.core.train_lib"
|
154 |
+
]
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"cell_type": "markdown",
|
158 |
+
"metadata": {
|
159 |
+
"id": "aKv3wdqkQ8FU"
|
160 |
+
},
|
161 |
+
"source": [
|
162 |
+
"## Configure the ResNet-18 model for the Cifar-10 dataset"
|
163 |
+
]
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "markdown",
|
167 |
+
"metadata": {
|
168 |
+
"id": "5iN8mHEJjKYE"
|
169 |
+
},
|
170 |
+
"source": [
|
171 |
+
"The CIFAR10 dataset contains 60,000 color images in mutually exclusive 10 classes, with 6,000 images in each class.\n",
|
172 |
+
"\n",
|
173 |
+
"In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
|
174 |
+
"\n",
|
175 |
+
"Use the `resnet_imagenet` factory configuration, as defined by `tfm.vision.configs.image_classification.image_classification_imagenet`. The configuration is set up to train ResNet to converge on [ImageNet](https://www.image-net.org/)."
|
176 |
+
]
|
177 |
+
},
|
178 |
+
{
|
179 |
+
"cell_type": "code",
|
180 |
+
"execution_count": null,
|
181 |
+
"metadata": {
|
182 |
+
"id": "1M77f88Dj2Td"
|
183 |
+
},
|
184 |
+
"outputs": [],
|
185 |
+
"source": [
|
186 |
+
"exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n",
|
187 |
+
"tfds_name = 'cifar10'\n",
|
188 |
+
"ds,ds_info = tfds.load(\n",
|
189 |
+
"tfds_name,\n",
|
190 |
+
"with_info=True)\n",
|
191 |
+
"ds_info"
|
192 |
+
]
|
193 |
+
},
|
194 |
+
{
|
195 |
+
"cell_type": "markdown",
|
196 |
+
"metadata": {
|
197 |
+
"id": "U6PVwXA-j3E7"
|
198 |
+
},
|
199 |
+
"source": [
|
200 |
+
"Adjust the model and dataset configurations so that it works with Cifar-10 (`cifar10`)."
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "code",
|
205 |
+
"execution_count": null,
|
206 |
+
"metadata": {
|
207 |
+
"id": "YWI7faVStQaV"
|
208 |
+
},
|
209 |
+
"outputs": [],
|
210 |
+
"source": [
|
211 |
+
"# Configure model\n",
|
212 |
+
"exp_config.task.model.num_classes = 10\n",
|
213 |
+
"exp_config.task.model.input_size = list(ds_info.features[\"image\"].shape)\n",
|
214 |
+
"exp_config.task.model.backbone.resnet.model_id = 18\n",
|
215 |
+
"\n",
|
216 |
+
"# Configure training and testing data\n",
|
217 |
+
"batch_size = 128\n",
|
218 |
+
"\n",
|
219 |
+
"exp_config.task.train_data.input_path = ''\n",
|
220 |
+
"exp_config.task.train_data.tfds_name = tfds_name\n",
|
221 |
+
"exp_config.task.train_data.tfds_split = 'train'\n",
|
222 |
+
"exp_config.task.train_data.global_batch_size = batch_size\n",
|
223 |
+
"\n",
|
224 |
+
"exp_config.task.validation_data.input_path = ''\n",
|
225 |
+
"exp_config.task.validation_data.tfds_name = tfds_name\n",
|
226 |
+
"exp_config.task.validation_data.tfds_split = 'test'\n",
|
227 |
+
"exp_config.task.validation_data.global_batch_size = batch_size\n"
|
228 |
+
]
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"cell_type": "markdown",
|
232 |
+
"metadata": {
|
233 |
+
"id": "DE3ggKzzTD56"
|
234 |
+
},
|
235 |
+
"source": [
|
236 |
+
"Adjust the trainer configuration."
|
237 |
+
]
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"cell_type": "code",
|
241 |
+
"execution_count": null,
|
242 |
+
"metadata": {
|
243 |
+
"id": "inE_-4UGkLud"
|
244 |
+
},
|
245 |
+
"outputs": [],
|
246 |
+
"source": [
|
247 |
+
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
|
248 |
+
"\n",
|
249 |
+
"if 'GPU' in ''.join(logical_device_names):\n",
|
250 |
+
" print('This may be broken in Colab.')\n",
|
251 |
+
" device = 'GPU'\n",
|
252 |
+
"elif 'TPU' in ''.join(logical_device_names):\n",
|
253 |
+
" print('This may be broken in Colab.')\n",
|
254 |
+
" device = 'TPU'\n",
|
255 |
+
"else:\n",
|
256 |
+
" print('Running on CPU is slow, so only train for a few steps.')\n",
|
257 |
+
" device = 'CPU'\n",
|
258 |
+
"\n",
|
259 |
+
"if device=='CPU':\n",
|
260 |
+
" train_steps = 20\n",
|
261 |
+
" exp_config.trainer.steps_per_loop = 5\n",
|
262 |
+
"else:\n",
|
263 |
+
" train_steps=5000\n",
|
264 |
+
" exp_config.trainer.steps_per_loop = 100\n",
|
265 |
+
"\n",
|
266 |
+
"exp_config.trainer.summary_interval = 100\n",
|
267 |
+
"exp_config.trainer.checkpoint_interval = train_steps\n",
|
268 |
+
"exp_config.trainer.validation_interval = 1000\n",
|
269 |
+
"exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size\n",
|
270 |
+
"exp_config.trainer.train_steps = train_steps\n",
|
271 |
+
"exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
|
272 |
+
"exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
|
273 |
+
"exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
|
274 |
+
"exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100"
|
275 |
+
]
|
276 |
+
},
|
277 |
+
{
|
278 |
+
"cell_type": "markdown",
|
279 |
+
"metadata": {
|
280 |
+
"id": "5mTcDnBiTOYD"
|
281 |
+
},
|
282 |
+
"source": [
|
283 |
+
"Print the modified configuration."
|
284 |
+
]
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"cell_type": "code",
|
288 |
+
"execution_count": null,
|
289 |
+
"metadata": {
|
290 |
+
"id": "tuVfxSBCTK-y"
|
291 |
+
},
|
292 |
+
"outputs": [],
|
293 |
+
"source": [
|
294 |
+
"pprint.pprint(exp_config.as_dict())\n",
|
295 |
+
"\n",
|
296 |
+
"display.Javascript(\"google.colab.output.setIframeHeight('300px');\")"
|
297 |
+
]
|
298 |
+
},
|
299 |
+
{
|
300 |
+
"cell_type": "markdown",
|
301 |
+
"metadata": {
|
302 |
+
"id": "w7_X0UHaRF2m"
|
303 |
+
},
|
304 |
+
"source": [
|
305 |
+
"Set up the distribution strategy."
|
306 |
+
]
|
307 |
+
},
|
308 |
+
{
|
309 |
+
"cell_type": "code",
|
310 |
+
"execution_count": null,
|
311 |
+
"metadata": {
|
312 |
+
"id": "ykL14FIbTaSt"
|
313 |
+
},
|
314 |
+
"outputs": [],
|
315 |
+
"source": [
|
316 |
+
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
|
317 |
+
"\n",
|
318 |
+
"if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
|
319 |
+
" tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
|
320 |
+
"\n",
|
321 |
+
"if 'GPU' in ''.join(logical_device_names):\n",
|
322 |
+
" distribution_strategy = tf.distribute.MirroredStrategy()\n",
|
323 |
+
"elif 'TPU' in ''.join(logical_device_names):\n",
|
324 |
+
" tf.tpu.experimental.initialize_tpu_system()\n",
|
325 |
+
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
|
326 |
+
" distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
|
327 |
+
"else:\n",
|
328 |
+
" print('Warning: this will be really slow.')\n",
|
329 |
+
" distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])"
|
330 |
+
]
|
331 |
+
},
|
332 |
+
{
|
333 |
+
"cell_type": "markdown",
|
334 |
+
"metadata": {
|
335 |
+
"id": "W4k5YH5pTjaK"
|
336 |
+
},
|
337 |
+
"source": [
|
338 |
+
"Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
|
339 |
+
"\n",
|
340 |
+
"The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
|
341 |
+
]
|
342 |
+
},
|
343 |
+
{
|
344 |
+
"cell_type": "code",
|
345 |
+
"execution_count": null,
|
346 |
+
"metadata": {
|
347 |
+
"id": "6MgYSH0PtUaW"
|
348 |
+
},
|
349 |
+
"outputs": [],
|
350 |
+
"source": [
|
351 |
+
"with distribution_strategy.scope():\n",
|
352 |
+
" model_dir = tempfile.mkdtemp()\n",
|
353 |
+
" task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)\n",
|
354 |
+
"\n",
|
355 |
+
"# tf.keras.utils.plot_model(task.build_model(), show_shapes=True)"
|
356 |
+
]
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"cell_type": "code",
|
360 |
+
"execution_count": null,
|
361 |
+
"metadata": {
|
362 |
+
"id": "IFXEZYdzBKoX"
|
363 |
+
},
|
364 |
+
"outputs": [],
|
365 |
+
"source": [
|
366 |
+
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
|
367 |
+
" print()\n",
|
368 |
+
" print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
|
369 |
+
" print(f'labels.shape: {str(labels.shape):16} labels.dtype: {labels.dtype!r}')"
|
370 |
+
]
|
371 |
+
},
|
372 |
+
{
|
373 |
+
"cell_type": "markdown",
|
374 |
+
"metadata": {
|
375 |
+
"id": "yrwxnGDaRU0U"
|
376 |
+
},
|
377 |
+
"source": [
|
378 |
+
"## Visualize the training data"
|
379 |
+
]
|
380 |
+
},
|
381 |
+
{
|
382 |
+
"cell_type": "markdown",
|
383 |
+
"metadata": {
|
384 |
+
"id": "683c255c6c52"
|
385 |
+
},
|
386 |
+
"source": [
|
387 |
+
"The dataloader applies a z-score normalization using \n",
|
388 |
+
"`preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range."
|
389 |
+
]
|
390 |
+
},
|
391 |
+
{
|
392 |
+
"cell_type": "code",
|
393 |
+
"execution_count": null,
|
394 |
+
"metadata": {
|
395 |
+
"id": "PdmOz2EC0Nx2"
|
396 |
+
},
|
397 |
+
"outputs": [],
|
398 |
+
"source": [
|
399 |
+
"plt.hist(images.numpy().flatten());"
|
400 |
+
]
|
401 |
+
},
|
402 |
+
{
|
403 |
+
"cell_type": "markdown",
|
404 |
+
"metadata": {
|
405 |
+
"id": "7a8582ebde7b"
|
406 |
+
},
|
407 |
+
"source": [
|
408 |
+
"Use `ds_info` (which is an instance of `tfds.core.DatasetInfo`) to lookup the text descriptions of each class ID."
|
409 |
+
]
|
410 |
+
},
|
411 |
+
{
|
412 |
+
"cell_type": "code",
|
413 |
+
"execution_count": null,
|
414 |
+
"metadata": {
|
415 |
+
"id": "Wq4Wq_CuDG3Q"
|
416 |
+
},
|
417 |
+
"outputs": [],
|
418 |
+
"source": [
|
419 |
+
"label_info = ds_info.features['label']\n",
|
420 |
+
"label_info.int2str(1)"
|
421 |
+
]
|
422 |
+
},
|
423 |
+
{
|
424 |
+
"cell_type": "markdown",
|
425 |
+
"metadata": {
|
426 |
+
"id": "8c652a6fdbcf"
|
427 |
+
},
|
428 |
+
"source": [
|
429 |
+
"Visualize a batch of the data."
|
430 |
+
]
|
431 |
+
},
|
432 |
+
{
|
433 |
+
"cell_type": "code",
|
434 |
+
"execution_count": null,
|
435 |
+
"metadata": {
|
436 |
+
"id": "ZKfTxytf1l0d"
|
437 |
+
},
|
438 |
+
"outputs": [],
|
439 |
+
"source": [
|
440 |
+
"def show_batch(images, labels, predictions=None):\n",
|
441 |
+
" plt.figure(figsize=(10, 10))\n",
|
442 |
+
" min = images.numpy().min()\n",
|
443 |
+
" max = images.numpy().max()\n",
|
444 |
+
" delta = max - min\n",
|
445 |
+
"\n",
|
446 |
+
" for i in range(12):\n",
|
447 |
+
" plt.subplot(6, 6, i + 1)\n",
|
448 |
+
" plt.imshow((images[i]-min) / delta)\n",
|
449 |
+
" if predictions is None:\n",
|
450 |
+
" plt.title(label_info.int2str(labels[i]))\n",
|
451 |
+
" else:\n",
|
452 |
+
" if labels[i] == predictions[i]:\n",
|
453 |
+
" color = 'g'\n",
|
454 |
+
" else:\n",
|
455 |
+
" color = 'r'\n",
|
456 |
+
" plt.title(label_info.int2str(predictions[i]), color=color)\n",
|
457 |
+
" plt.axis(\"off\")"
|
458 |
+
]
|
459 |
+
},
|
460 |
+
{
|
461 |
+
"cell_type": "code",
|
462 |
+
"execution_count": null,
|
463 |
+
"metadata": {
|
464 |
+
"id": "xkA5h_RBtYYU"
|
465 |
+
},
|
466 |
+
"outputs": [],
|
467 |
+
"source": [
|
468 |
+
"plt.figure(figsize=(10, 10))\n",
|
469 |
+
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
|
470 |
+
" show_batch(images, labels)"
|
471 |
+
]
|
472 |
+
},
|
473 |
+
{
|
474 |
+
"cell_type": "markdown",
|
475 |
+
"metadata": {
|
476 |
+
"id": "v_A9VnL2RbXP"
|
477 |
+
},
|
478 |
+
"source": [
|
479 |
+
"## Visualize the testing data"
|
480 |
+
]
|
481 |
+
},
|
482 |
+
{
|
483 |
+
"cell_type": "markdown",
|
484 |
+
"metadata": {
|
485 |
+
"id": "AXovuumW_I2z"
|
486 |
+
},
|
487 |
+
"source": [
|
488 |
+
"Visualize a batch of images from the validation dataset."
|
489 |
+
]
|
490 |
+
},
|
491 |
+
{
|
492 |
+
"cell_type": "code",
|
493 |
+
"execution_count": null,
|
494 |
+
"metadata": {
|
495 |
+
"id": "Ma-_Eb-nte9A"
|
496 |
+
},
|
497 |
+
"outputs": [],
|
498 |
+
"source": [
|
499 |
+
"plt.figure(figsize=(10, 10));\n",
|
500 |
+
"for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):\n",
|
501 |
+
" show_batch(images, labels)"
|
502 |
+
]
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"cell_type": "markdown",
|
506 |
+
"metadata": {
|
507 |
+
"id": "ihKJt2FHRi2N"
|
508 |
+
},
|
509 |
+
"source": [
|
510 |
+
"## Train and evaluate"
|
511 |
+
]
|
512 |
+
},
|
513 |
+
{
|
514 |
+
"cell_type": "code",
|
515 |
+
"execution_count": null,
|
516 |
+
"metadata": {
|
517 |
+
"id": "0AFMNvYxtjXx"
|
518 |
+
},
|
519 |
+
"outputs": [],
|
520 |
+
"source": [
|
521 |
+
"model, eval_logs = tfm.core.train_lib.run_experiment(\n",
|
522 |
+
" distribution_strategy=distribution_strategy,\n",
|
523 |
+
" task=task,\n",
|
524 |
+
" mode='train_and_eval',\n",
|
525 |
+
" params=exp_config,\n",
|
526 |
+
" model_dir=model_dir,\n",
|
527 |
+
" run_post_eval=True)"
|
528 |
+
]
|
529 |
+
},
|
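To re-score the latest checkpoint without further training, the same entry point can be reused. A hedged sketch, assuming this release of `run_experiment` accepts an `'eval'` mode (check its docstring if the mode names differ); all other arguments are the objects defined above:

```python
# A minimal sketch: evaluation-only pass over the checkpoint in model_dir.
_, eval_only_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='eval',          # assumption: eval-only mode name
    params=exp_config,
    model_dir=model_dir,  # reads the checkpoint written by the cell above
    run_post_eval=True)
```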
530 |
+
{
|
531 |
+
"cell_type": "code",
|
532 |
+
"execution_count": null,
|
533 |
+
"metadata": {
|
534 |
+
"id": "gCcHMQYhozmA"
|
535 |
+
},
|
536 |
+
"outputs": [],
|
537 |
+
"source": [
|
538 |
+
"# tf.keras.utils.plot_model(model, show_shapes=True)"
|
539 |
+
]
|
540 |
+
},
|
541 |
+
{
|
542 |
+
"cell_type": "markdown",
|
543 |
+
"metadata": {
|
544 |
+
"id": "L7nVfxlBA8Gb"
|
545 |
+
},
|
546 |
+
"source": [
|
547 |
+
"Print the `accuracy`, `top_5_accuracy`, and `validation_loss` evaluation metrics."
|
548 |
+
]
|
549 |
+
},
|
550 |
+
{
|
551 |
+
"cell_type": "code",
|
552 |
+
"execution_count": null,
|
553 |
+
"metadata": {
|
554 |
+
"id": "0124f938a1b9"
|
555 |
+
},
|
556 |
+
"outputs": [],
|
557 |
+
"source": [
|
558 |
+
"for key, value in eval_logs.items():\n",
|
559 |
+
" if isinstance(value, tf.Tensor):\n",
|
560 |
+
" value = value.numpy()\n",
|
561 |
+
" print(f'{key:20}: {value:.3f}')"
|
562 |
+
]
|
563 |
+
},
|
564 |
+
{
|
565 |
+
"cell_type": "markdown",
|
566 |
+
"metadata": {
|
567 |
+
"id": "TDys5bZ1zsml"
|
568 |
+
},
|
569 |
+
"source": [
|
570 |
+
"Run a batch of the processed training data through the model, and view the results"
|
571 |
+
]
|
572 |
+
},
|
573 |
+
{
|
574 |
+
"cell_type": "code",
|
575 |
+
"execution_count": null,
|
576 |
+
"metadata": {
|
577 |
+
"id": "GhI7zR-Uz1JT"
|
578 |
+
},
|
579 |
+
"outputs": [],
|
580 |
+
"source": [
|
581 |
+
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
|
582 |
+
" predictions = model.predict(images)\n",
|
583 |
+
" predictions = tf.argmax(predictions, axis=-1)\n",
|
584 |
+
"\n",
|
585 |
+
"show_batch(images, labels, tf.cast(predictions, tf.int32))\n",
|
586 |
+
"\n",
|
587 |
+
"if device=='CPU':\n",
|
588 |
+
" plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')"
|
589 |
+
]
|
590 |
+
},
|
591 |
+
{
|
592 |
+
"cell_type": "markdown",
|
593 |
+
"metadata": {
|
594 |
+
"id": "fkE9locGTBgt"
|
595 |
+
},
|
596 |
+
"source": [
|
597 |
+
"## Export a SavedModel"
|
598 |
+
]
|
599 |
+
},
|
600 |
+
{
|
601 |
+
"cell_type": "markdown",
|
602 |
+
"metadata": {
|
603 |
+
"id": "9669d08c91af"
|
604 |
+
},
|
605 |
+
"source": [
|
606 |
+
"The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results.\n"
|
607 |
+
]
|
608 |
+
},
|
609 |
+
{
|
610 |
+
"cell_type": "code",
|
611 |
+
"execution_count": null,
|
612 |
+
"metadata": {
|
613 |
+
"id": "AQCFa7BvtmDg"
|
614 |
+
},
|
615 |
+
"outputs": [],
|
616 |
+
"source": [
|
617 |
+
"# Saving and exporting the trained model\n",
|
618 |
+
"export_saved_model_lib.export_inference_graph(\n",
|
619 |
+
" input_type='image_tensor',\n",
|
620 |
+
" batch_size=1,\n",
|
621 |
+
" input_image_size=[32, 32],\n",
|
622 |
+
" params=exp_config,\n",
|
623 |
+
" checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
|
624 |
+
" export_dir='./export/')"
|
625 |
+
]
|
626 |
+
},
|
627 |
+
{
|
628 |
+
"cell_type": "markdown",
|
629 |
+
"metadata": {
|
630 |
+
"id": "vVr6DxNqTyLZ"
|
631 |
+
},
|
632 |
+
"source": [
|
633 |
+
"Test the exported model."
|
634 |
+
]
|
635 |
+
},
|
636 |
+
{
|
637 |
+
"cell_type": "code",
|
638 |
+
"execution_count": null,
|
639 |
+
"metadata": {
|
640 |
+
"id": "gP7nOvrftsB0"
|
641 |
+
},
|
642 |
+
"outputs": [],
|
643 |
+
"source": [
|
644 |
+
"# Importing SavedModel\n",
|
645 |
+
"imported = tf.saved_model.load('./export/')\n",
|
646 |
+
"model_fn = imported.signatures['serving_default']"
|
647 |
+
]
|
648 |
+
},
|
649 |
+
{
|
650 |
+
"cell_type": "markdown",
|
651 |
+
"metadata": {
|
652 |
+
"id": "GiOp2WVIUNUZ"
|
653 |
+
},
|
654 |
+
"source": [
|
655 |
+
"Visualize the predictions."
|
656 |
+
]
|
657 |
+
},
|
658 |
+
{
|
659 |
+
"cell_type": "code",
|
660 |
+
"execution_count": null,
|
661 |
+
"metadata": {
|
662 |
+
"id": "BTRMrZQAN4mk"
|
663 |
+
},
|
664 |
+
"outputs": [],
|
665 |
+
"source": [
|
666 |
+
"plt.figure(figsize=(10, 10))\n",
|
667 |
+
"for data in tfds.load('cifar10', split='test').batch(12).take(1):\n",
|
668 |
+
" predictions = []\n",
|
669 |
+
" for image in data['image']:\n",
|
670 |
+
" index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]\n",
|
671 |
+
" predictions.append(index)\n",
|
672 |
+
" show_batch(data['image'], data['label'], predictions)\n",
|
673 |
+
"\n",
|
674 |
+
" if device=='CPU':\n",
|
675 |
+
" plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')"
|
676 |
+
]
|
677 |
+
}
|
678 |
+
],
|
679 |
+
"metadata": {
|
680 |
+
"colab": {
|
681 |
+
"name": "classification_with_model_garden.ipynb",
|
682 |
+
"provenance": [],
|
683 |
+
"toc_visible": true
|
684 |
+
},
|
685 |
+
"kernelspec": {
|
686 |
+
"display_name": "Python 3",
|
687 |
+
"name": "python3"
|
688 |
+
}
|
689 |
+
},
|
690 |
+
"nbformat": 4,
|
691 |
+
"nbformat_minor": 0
|
692 |
+
}
|
Tensorflow/models/docs/vision/instance_segmentation.ipynb
ADDED
@@ -0,0 +1,1138 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "eCes7jVU8r08"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"##### Copyright 2023 The TensorFlow Authors."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"id": "pc1j3ZVF8mmG"
|
17 |
+
},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
21 |
+
"# you may not use this file except in compliance with the License.\n",
|
22 |
+
"# You may obtain a copy of the License at\n",
|
23 |
+
"#\n",
|
24 |
+
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
25 |
+
"#\n",
|
26 |
+
"# Unless required by applicable law or agreed to in writing, software\n",
|
27 |
+
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
28 |
+
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
29 |
+
"# See the License for the specific language governing permissions and\n",
|
30 |
+
"# limitations under the License."
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "markdown",
|
35 |
+
"metadata": {
|
36 |
+
"id": "SUUX9CnCYI9Y"
|
37 |
+
},
|
38 |
+
"source": [
|
39 |
+
"# Instance Segmentation with Model Garden\n",
|
40 |
+
"\n",
|
41 |
+
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
|
42 |
+
" <td>\n",
|
43 |
+
" <a target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/instance_segmentation\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
|
44 |
+
" </td>\n",
|
45 |
+
" <td>\n",
|
46 |
+
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/instance_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
47 |
+
" </td>\n",
|
48 |
+
" <td>\n",
|
49 |
+
" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/instance_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
|
50 |
+
" </td>\n",
|
51 |
+
" <td>\n",
|
52 |
+
" <a href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/instance_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
|
53 |
+
" </td>\n",
|
54 |
+
"</table>"
|
55 |
+
]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"cell_type": "markdown",
|
59 |
+
"metadata": {
|
60 |
+
"id": "UjP7bQUdTeFr"
|
61 |
+
},
|
62 |
+
"source": [
|
63 |
+
"This tutorial fine-tunes a [Mask R-CNN](https://arxiv.org/abs/1703.06870) with [Mobilenet V2](https://arxiv.org/abs/1801.04381) as backbone model from the [TensorFlow Model Garden](https://pypi.org/project/tf-models-official/) package (tensorflow-models).\n",
|
64 |
+
"\n",
|
65 |
+
"\n",
|
66 |
+
"[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
|
67 |
+
"\n",
|
68 |
+
"This tutorial demonstrates how to:\n",
|
69 |
+
"\n",
|
70 |
+
"1. Use models from the TensorFlow Models package.\n",
|
71 |
+
"2. Train/Fine-tune a pre-built Mask R-CNN with mobilenet as backbone for Object Detection and Instance Segmentation\n",
|
72 |
+
"3. Export the trained/tuned Mask R-CNN model"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "markdown",
|
77 |
+
"metadata": {
|
78 |
+
"id": "RDp6Kk1Baoi4"
|
79 |
+
},
|
80 |
+
"source": [
|
81 |
+
"## Install Necessary Dependencies"
|
82 |
+
]
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"cell_type": "code",
|
86 |
+
"execution_count": null,
|
87 |
+
"metadata": {
|
88 |
+
"id": "hcl98qUOxlL8"
|
89 |
+
},
|
90 |
+
"outputs": [],
|
91 |
+
"source": [
|
92 |
+
"!pip install -U -q \"tf-models-official\"\n",
|
93 |
+
"!pip install -U -q remotezip tqdm opencv-python einops"
|
94 |
+
]
|
95 |
+
},
|
96 |
+
{
|
97 |
+
"cell_type": "markdown",
|
98 |
+
"metadata": {
|
99 |
+
"id": "5-gCe_YTapey"
|
100 |
+
},
|
101 |
+
"source": [
|
102 |
+
"## Import required libraries"
|
103 |
+
]
|
104 |
+
},
|
105 |
+
{
|
106 |
+
"cell_type": "code",
|
107 |
+
"execution_count": null,
|
108 |
+
"metadata": {
|
109 |
+
"id": "Qa9552Ukgf3d"
|
110 |
+
},
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"import os\n",
|
114 |
+
"import io\n",
|
115 |
+
"import json\n",
|
116 |
+
"import tqdm\n",
|
117 |
+
"import shutil\n",
|
118 |
+
"import pprint\n",
|
119 |
+
"import pathlib\n",
|
120 |
+
"import tempfile\n",
|
121 |
+
"import requests\n",
|
122 |
+
"import collections\n",
|
123 |
+
"import matplotlib\n",
|
124 |
+
"import numpy as np\n",
|
125 |
+
"import tensorflow as tf\n",
|
126 |
+
"import matplotlib.pyplot as plt\n",
|
127 |
+
"\n",
|
128 |
+
"from PIL import Image\n",
|
129 |
+
"from six import BytesIO\n",
|
130 |
+
"from etils import epath\n",
|
131 |
+
"from IPython import display\n",
|
132 |
+
"from urllib.request import urlopen"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"cell_type": "code",
|
137 |
+
"execution_count": null,
|
138 |
+
"metadata": {
|
139 |
+
"id": "tSCMIDRDP2fV"
|
140 |
+
},
|
141 |
+
"outputs": [],
|
142 |
+
"source": [
|
143 |
+
"import orbit\n",
|
144 |
+
"import tensorflow as tf\n",
|
145 |
+
"import tensorflow_models as tfm\n",
|
146 |
+
"import tensorflow_datasets as tfds\n",
|
147 |
+
"\n",
|
148 |
+
"from official.core import exp_factory\n",
|
149 |
+
"from official.core import config_definitions as cfg\n",
|
150 |
+
"from official.vision.data import tfrecord_lib\n",
|
151 |
+
"from official.vision.serving import export_saved_model_lib\n",
|
152 |
+
"from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder\n",
|
153 |
+
"from official.vision.utils.object_detection import visualization_utils\n",
|
154 |
+
"from official.vision.ops.preprocess_ops import normalize_image, resize_and_crop_image\n",
|
155 |
+
"from official.vision.data.create_coco_tf_record import coco_annotations_to_lists\n",
|
156 |
+
"\n",
|
157 |
+
"pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation\n",
|
158 |
+
"print(tf.__version__) # Check the version of tensorflow used\n",
|
159 |
+
"\n",
|
160 |
+
"%matplotlib inline"
|
161 |
+
]
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"cell_type": "markdown",
|
165 |
+
"metadata": {
|
166 |
+
"id": "GIrXW8sp2bKa"
|
167 |
+
},
|
168 |
+
"source": [
|
169 |
+
"## Download subset of lvis dataset\n",
|
170 |
+
"\n",
|
171 |
+
"[LVIS](https://www.tensorflow.org/datasets/catalog/lvis): A dataset for large vocabulary instance segmentation.\n",
|
172 |
+
"\n",
|
173 |
+
"Note: LVIS uses the COCO 2017 train, validation, and test image sets. \n",
|
174 |
+
"If you have already downloaded the COCO images, you only need to download \n",
|
175 |
+
"the LVIS annotations. LVIS val set contains images from COCO 2017 train in \n",
|
176 |
+
"addition to the COCO 2017 val split."
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": null,
|
182 |
+
"metadata": {
|
183 |
+
"cellView": "form",
|
184 |
+
"id": "F_A9_cS310jf"
|
185 |
+
},
|
186 |
+
"outputs": [],
|
187 |
+
"source": [
|
188 |
+
"# @title Download annotation files\n",
|
189 |
+
"\n",
|
190 |
+
"!wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip\n",
|
191 |
+
"!unzip -q lvis_v1_train.json.zip\n",
|
192 |
+
"!rm lvis_v1_train.json.zip\n",
|
193 |
+
"\n",
|
194 |
+
"!wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip\n",
|
195 |
+
"!unzip -q lvis_v1_val.json.zip\n",
|
196 |
+
"!rm lvis_v1_val.json.zip\n",
|
197 |
+
"\n",
|
198 |
+
"!wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_image_info_test_dev.json.zip\n",
|
199 |
+
"!unzip -q lvis_v1_image_info_test_dev.json.zip\n",
|
200 |
+
"!rm lvis_v1_image_info_test_dev.json.zip"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "code",
|
205 |
+
"execution_count": null,
|
206 |
+
"metadata": {
|
207 |
+
"cellView": "form",
|
208 |
+
"id": "kB-C5Svj11S0"
|
209 |
+
},
|
210 |
+
"outputs": [],
|
211 |
+
"source": [
|
212 |
+
"# @title Lvis annotation parsing\n",
|
213 |
+
"\n",
|
214 |
+
"# Annotations with invalid bounding boxes. Will not be used.\n",
|
215 |
+
"_INVALID_ANNOTATIONS = [\n",
|
216 |
+
" # Train split.\n",
|
217 |
+
" 662101,\n",
|
218 |
+
" 81217,\n",
|
219 |
+
" 462924,\n",
|
220 |
+
" 227817,\n",
|
221 |
+
" 29381,\n",
|
222 |
+
" 601484,\n",
|
223 |
+
" 412185,\n",
|
224 |
+
" 504667,\n",
|
225 |
+
" 572573,\n",
|
226 |
+
" 91937,\n",
|
227 |
+
" 239022,\n",
|
228 |
+
" 181534,\n",
|
229 |
+
" 101685,\n",
|
230 |
+
" # Validation split.\n",
|
231 |
+
" 36668,\n",
|
232 |
+
" 57541,\n",
|
233 |
+
" 33126,\n",
|
234 |
+
" 10932,\n",
|
235 |
+
"]\n",
|
236 |
+
"\n",
|
237 |
+
"def get_category_map(annotation_path, num_classes):\n",
|
238 |
+
" with epath.Path(annotation_path).open() as f:\n",
|
239 |
+
" data = json.load(f)\n",
|
240 |
+
"\n",
|
241 |
+
" category_map = {id+1: {'id': cat_dict['id'],\n",
|
242 |
+
" 'name': cat_dict['name']}\n",
|
243 |
+
" for id, cat_dict in enumerate(data['categories'][:num_classes])}\n",
|
244 |
+
" return category_map\n",
|
245 |
+
"\n",
|
246 |
+
"class LvisAnnotation:\n",
|
247 |
+
" \"\"\"LVIS annotation helper class.\n",
|
248 |
+
" The format of the annations is explained on\n",
|
249 |
+
" https://www.lvisdataset.org/dataset.\n",
|
250 |
+
" \"\"\"\n",
|
251 |
+
"\n",
|
252 |
+
" def __init__(self, annotation_path):\n",
|
253 |
+
" with epath.Path(annotation_path).open() as f:\n",
|
254 |
+
" data = json.load(f)\n",
|
255 |
+
" self._data = data\n",
|
256 |
+
"\n",
|
257 |
+
" img_id2annotations = collections.defaultdict(list)\n",
|
258 |
+
" for a in self._data.get('annotations', []):\n",
|
259 |
+
" if a['category_id'] in category_ids:\n",
|
260 |
+
" img_id2annotations[a['image_id']].append(a)\n",
|
261 |
+
" self._img_id2annotations = {\n",
|
262 |
+
" k: list(sorted(v, key=lambda a: a['id']))\n",
|
263 |
+
" for k, v in img_id2annotations.items()\n",
|
264 |
+
" }\n",
|
265 |
+
"\n",
|
266 |
+
" @property\n",
|
267 |
+
" def categories(self):\n",
|
268 |
+
" \"\"\"Return the category dicts, as sorted in the file.\"\"\"\n",
|
269 |
+
" return self._data['categories']\n",
|
270 |
+
"\n",
|
271 |
+
" @property\n",
|
272 |
+
" def images(self):\n",
|
273 |
+
" \"\"\"Return the image dicts, as sorted in the file.\"\"\"\n",
|
274 |
+
" sub_images = []\n",
|
275 |
+
" for image_info in self._data['images']:\n",
|
276 |
+
" if image_info['id'] in self._img_id2annotations:\n",
|
277 |
+
" sub_images.append(image_info)\n",
|
278 |
+
" return sub_images\n",
|
279 |
+
"\n",
|
280 |
+
" def get_annotations(self, img_id):\n",
|
281 |
+
" \"\"\"Return all annotations associated with the image id string.\"\"\"\n",
|
282 |
+
" # Some images don't have any annotations. Return empty list instead.\n",
|
283 |
+
" return self._img_id2annotations.get(img_id, [])\n",
|
284 |
+
"\n",
|
285 |
+
"def _generate_tf_records(prefix, images_zip, annotation_file, num_shards=5):\n",
|
286 |
+
" \"\"\"Generate TFRecords.\"\"\"\n",
|
287 |
+
"\n",
|
288 |
+
" lvis_annotation = LvisAnnotation(annotation_file)\n",
|
289 |
+
"\n",
|
290 |
+
" def _process_example(prefix, image_info, id_to_name_map):\n",
|
291 |
+
" # Search image dirs.\n",
|
292 |
+
" filename = pathlib.Path(image_info['coco_url']).name\n",
|
293 |
+
" image = tf.io.read_file(os.path.join(IMGS_DIR, filename))\n",
|
294 |
+
" instances = lvis_annotation.get_annotations(img_id=image_info['id'])\n",
|
295 |
+
" instances = [x for x in instances if x['id'] not in _INVALID_ANNOTATIONS]\n",
|
296 |
+
" # print([x['category_id'] for x in instances])\n",
|
297 |
+
" is_crowd = {'iscrowd': 0}\n",
|
298 |
+
" instances = [dict(x, **is_crowd) for x in instances]\n",
|
299 |
+
" neg_category_ids = image_info.get('neg_category_ids', [])\n",
|
300 |
+
" not_exhaustive_category_ids = image_info.get(\n",
|
301 |
+
" 'not_exhaustive_category_ids', []\n",
|
302 |
+
" )\n",
|
303 |
+
" data, _ = coco_annotations_to_lists(instances,\n",
|
304 |
+
" id_to_name_map,\n",
|
305 |
+
" image_info['height'],\n",
|
306 |
+
" image_info['width'],\n",
|
307 |
+
" include_masks=True)\n",
|
308 |
+
" # data['category_id'] = [id-1 for id in data['category_id']]\n",
|
309 |
+
" keys_to_features = {\n",
|
310 |
+
" 'image/encoded':\n",
|
311 |
+
" tfrecord_lib.convert_to_feature(image.numpy()),\n",
|
312 |
+
" 'image/filename':\n",
|
313 |
+
" tfrecord_lib.convert_to_feature(filename.encode('utf8')),\n",
|
314 |
+
" 'image/format':\n",
|
315 |
+
" tfrecord_lib.convert_to_feature('jpg'.encode('utf8')),\n",
|
316 |
+
" 'image/height':\n",
|
317 |
+
" tfrecord_lib.convert_to_feature(image_info['height']),\n",
|
318 |
+
" 'image/width':\n",
|
319 |
+
" tfrecord_lib.convert_to_feature(image_info['width']),\n",
|
320 |
+
" 'image/source_id':\n",
|
321 |
+
" tfrecord_lib.convert_to_feature(str(image_info['id']).encode('utf8')),\n",
|
322 |
+
" 'image/object/bbox/xmin':\n",
|
323 |
+
" tfrecord_lib.convert_to_feature(data['xmin']),\n",
|
324 |
+
" 'image/object/bbox/xmax':\n",
|
325 |
+
" tfrecord_lib.convert_to_feature(data['xmax']),\n",
|
326 |
+
" 'image/object/bbox/ymin':\n",
|
327 |
+
" tfrecord_lib.convert_to_feature(data['ymin']),\n",
|
328 |
+
" 'image/object/bbox/ymax':\n",
|
329 |
+
" tfrecord_lib.convert_to_feature(data['ymax']),\n",
|
330 |
+
" 'image/object/class/text':\n",
|
331 |
+
" tfrecord_lib.convert_to_feature(data['category_names']),\n",
|
332 |
+
" 'image/object/class/label':\n",
|
333 |
+
" tfrecord_lib.convert_to_feature(data['category_id']),\n",
|
334 |
+
" 'image/object/is_crowd':\n",
|
335 |
+
" tfrecord_lib.convert_to_feature(data['is_crowd']),\n",
|
336 |
+
" 'image/object/area':\n",
|
337 |
+
" tfrecord_lib.convert_to_feature(data['area'], 'float_list'),\n",
|
338 |
+
" 'image/object/mask':\n",
|
339 |
+
" tfrecord_lib.convert_to_feature(data['encoded_mask_png'])\n",
|
340 |
+
" }\n",
|
341 |
+
" # print(keys_to_features['image/object/class/label'])\n",
|
342 |
+
" example = tf.train.Example(\n",
|
343 |
+
" features=tf.train.Features(feature=keys_to_features))\n",
|
344 |
+
" return example\n",
|
345 |
+
"\n",
|
346 |
+
"\n",
|
347 |
+
"\n",
|
348 |
+
" # file_names = [f\"{prefix}/{pathlib.Path(image_info['coco_url']).name}\"\n",
|
349 |
+
" # for image_info in lvis_annotation.images]\n",
|
350 |
+
" # _extract_images(images_zip, file_names)\n",
|
351 |
+
" writers = [\n",
|
352 |
+
" tf.io.TFRecordWriter(\n",
|
353 |
+
" tf_records_dir + prefix +'-%05d-of-%05d.tfrecord' % (i, num_shards))\n",
|
354 |
+
" for i in range(num_shards)\n",
|
355 |
+
" ]\n",
|
356 |
+
" id_to_name_map = {cat_dict['id']: cat_dict['name']\n",
|
357 |
+
" for cat_dict in lvis_annotation.categories[:NUM_CLASSES]}\n",
|
358 |
+
" # print(id_to_name_map)\n",
|
359 |
+
" for idx, image_info in enumerate(tqdm.tqdm(lvis_annotation.images)):\n",
|
360 |
+
" img_data = requests.get(image_info['coco_url'], stream=True).content\n",
|
361 |
+
" img_name = image_info['coco_url'].split('/')[-1]\n",
|
362 |
+
" with open(os.path.join(IMGS_DIR, img_name), 'wb') as handler:\n",
|
363 |
+
" handler.write(img_data)\n",
|
364 |
+
" tf_example = _process_example(prefix, image_info, id_to_name_map)\n",
|
365 |
+
" writers[idx % num_shards].write(tf_example.SerializeToString())\n",
|
366 |
+
"\n",
|
367 |
+
" del lvis_annotation"
|
368 |
+
]
|
369 |
+
},
|
370 |
+
{
|
371 |
+
"cell_type": "code",
|
372 |
+
"execution_count": null,
|
373 |
+
"metadata": {
|
374 |
+
"id": "5u2dwjIT2HZu"
|
375 |
+
},
|
376 |
+
"outputs": [],
|
377 |
+
"source": [
|
378 |
+
"_URLS = {\n",
|
379 |
+
" 'train_images': 'http://images.cocodataset.org/zips/train2017.zip',\n",
|
380 |
+
" 'validation_images': 'http://images.cocodataset.org/zips/val2017.zip',\n",
|
381 |
+
" 'test_images': 'http://images.cocodataset.org/zips/test2017.zip',\n",
|
382 |
+
"}\n",
|
383 |
+
"\n",
|
384 |
+
"train_prefix = 'train'\n",
|
385 |
+
"valid_prefix = 'val'\n",
|
386 |
+
"\n",
|
387 |
+
"train_annotation_path = './lvis_v1_train.json'\n",
|
388 |
+
"valid_annotation_path = './lvis_v1_val.json'\n",
|
389 |
+
"\n",
|
390 |
+
"IMGS_DIR = './lvis_sub_dataset/'\n",
|
391 |
+
"tf_records_dir = './lvis_tfrecords/'\n",
|
392 |
+
"\n",
|
393 |
+
"\n",
|
394 |
+
"if not os.path.exists(IMGS_DIR):\n",
|
395 |
+
" os.mkdir(IMGS_DIR)\n",
|
396 |
+
"\n",
|
397 |
+
"if not os.path.exists(tf_records_dir):\n",
|
398 |
+
" os.mkdir(tf_records_dir)\n",
|
399 |
+
"\n",
|
400 |
+
"\n",
|
401 |
+
"\n",
|
402 |
+
"NUM_CLASSES = 3\n",
|
403 |
+
"category_index = get_category_map(valid_annotation_path, NUM_CLASSES)\n",
|
404 |
+
"category_ids = list(category_index.keys())"
|
405 |
+
]
|
406 |
+
},
|
407 |
+
{
|
408 |
+
"cell_type": "code",
|
409 |
+
"execution_count": null,
|
410 |
+
"metadata": {
|
411 |
+
"id": "KBgl5fG42LpD"
|
412 |
+
},
|
413 |
+
"outputs": [],
|
414 |
+
"source": [
|
415 |
+
"# Below helper function are taken from github tensorflow dataset lvis\n",
|
416 |
+
"# https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/datasets/lvis/lvis_dataset_builder.py\n",
|
417 |
+
"_generate_tf_records(train_prefix,\n",
|
418 |
+
" _URLS['train_images'],\n",
|
419 |
+
" train_annotation_path)"
|
420 |
+
]
|
421 |
+
},
|
422 |
+
{
|
423 |
+
"cell_type": "code",
|
424 |
+
"execution_count": null,
|
425 |
+
"metadata": {
|
426 |
+
"id": "89O59u_H2NIJ"
|
427 |
+
},
|
428 |
+
"outputs": [],
|
429 |
+
"source": [
|
430 |
+
"_generate_tf_records(valid_prefix,\n",
|
431 |
+
" _URLS['validation_images'],\n",
|
432 |
+
" valid_annotation_path)"
|
433 |
+
]
|
434 |
+
},
|
435 |
+
{
|
436 |
+
"cell_type": "markdown",
|
437 |
+
"metadata": {
|
438 |
+
"id": "EREyevfIY4rz"
|
439 |
+
},
|
440 |
+
"source": [
|
441 |
+
"## Configure the MaskRCNN Resnet FPN COCO model for custom dataset"
|
442 |
+
]
|
443 |
+
},
|
444 |
+
{
|
445 |
+
"cell_type": "code",
|
446 |
+
"execution_count": null,
|
447 |
+
"metadata": {
|
448 |
+
"id": "5yGLLvXlPInP"
|
449 |
+
},
|
450 |
+
"outputs": [],
|
451 |
+
"source": [
|
452 |
+
"train_data_input_path = './lvis_tfrecords/train*'\n",
|
453 |
+
"valid_data_input_path = './lvis_tfrecords/val*'\n",
|
454 |
+
"test_data_input_path = './lvis_tfrecords/test*'\n",
|
455 |
+
"model_dir = './trained_model/'\n",
|
456 |
+
"export_dir ='./exported_model/'"
|
457 |
+
]
|
458 |
+
},
|
459 |
+
{
|
460 |
+
"cell_type": "code",
|
461 |
+
"execution_count": null,
|
462 |
+
"metadata": {
|
463 |
+
"id": "ms3wRQKAIORe"
|
464 |
+
},
|
465 |
+
"outputs": [],
|
466 |
+
"source": [
|
467 |
+
"if not os.path.exists(model_dir):\n",
|
468 |
+
" os.mkdir(model_dir)"
|
469 |
+
]
|
470 |
+
},
|
471 |
+
{
|
472 |
+
"cell_type": "markdown",
|
473 |
+
"metadata": {
|
474 |
+
"id": "EXA5NmvDblYP"
|
475 |
+
},
|
476 |
+
"source": [
|
477 |
+
"In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
|
478 |
+
"\n",
|
479 |
+
"\n",
|
480 |
+
"Use the `retinanet_mobilenet_coco` experiment configuration, as defined by `tfm.vision.configs.maskrcnn.maskrcnn_mobilenet_coco`.\n",
|
481 |
+
"\n",
|
482 |
+
"Please find all the registered experiements [here](https://www.tensorflow.org/api_docs/python/tfm/core/exp_factory/get_exp_config)\n",
|
483 |
+
"\n",
|
484 |
+
"The configuration defines an experiment to train a Mask R-CNN model with mobilenet as backbone and FPN as decoder. Default Congiguration is trained on [COCO](https://cocodataset.org/) train2017 and evaluated on [COCO](https://cocodataset.org/) val2017.\n",
|
485 |
+
"\n",
|
486 |
+
"There are also other alternative experiments available such as\n",
|
487 |
+
"`maskrcnn_resnetfpn_coco`,\n",
|
488 |
+
"`maskrcnn_spinenet_coco` and more. One can switch to them by changing the experiment name argument to the `get_exp_config` function."
|
489 |
+
]
|
490 |
+
},
|
491 |
+
{
|
492 |
+
"cell_type": "code",
|
493 |
+
"execution_count": null,
|
494 |
+
"metadata": {
|
495 |
+
"id": "Zi2F1qGgPWOH"
|
496 |
+
},
|
497 |
+
"outputs": [],
|
498 |
+
"source": [
|
499 |
+
"exp_config = exp_factory.get_exp_config('maskrcnn_mobilenet_coco')"
|
500 |
+
]
|
501 |
+
},
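{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch; the two fields printed below are just examples, and both are overridden later in this tutorial), you can peek at the freshly loaded config before adjusting it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect a couple of fields of the default `maskrcnn_mobilenet_coco` config.\n",
"print(exp_config.task.model.num_classes)\n",
"print(exp_config.task.train_data.global_batch_size)"
]
},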
|
502 |
+
{
|
503 |
+
"cell_type": "code",
|
504 |
+
"execution_count": null,
|
505 |
+
"metadata": {
|
506 |
+
"id": "zo-EaCdmn5j-"
|
507 |
+
},
|
508 |
+
"outputs": [],
|
509 |
+
"source": [
|
510 |
+
"model_ckpt_path = './model_ckpt/'\n",
|
511 |
+
"if not os.path.exists(model_ckpt_path):\n",
|
512 |
+
" os.mkdir(model_ckpt_path)\n",
|
513 |
+
"\n",
|
514 |
+
"!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.data-00000-of-00001 './model_ckpt/'\n",
|
515 |
+
"!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.index './model_ckpt/'"
|
516 |
+
]
|
517 |
+
},
|
518 |
+
{
|
519 |
+
"cell_type": "markdown",
|
520 |
+
"metadata": {
|
521 |
+
"id": "ymnwJYaFgHs2"
|
522 |
+
},
|
523 |
+
"source": [
|
524 |
+
"### Adjust the model and dataset configurations so that it works with custom dataset."
|
525 |
+
]
|
526 |
+
},
|
527 |
+
{
|
528 |
+
"cell_type": "code",
|
529 |
+
"execution_count": null,
|
530 |
+
"metadata": {
|
531 |
+
"id": "zyn9ieZyUbEJ"
|
532 |
+
},
|
533 |
+
"outputs": [],
|
534 |
+
"source": [
|
535 |
+
"BATCH_SIZE = 8\n",
|
536 |
+
"HEIGHT, WIDTH = 256, 256\n",
|
537 |
+
"IMG_SHAPE = [HEIGHT, WIDTH, 3]\n",
|
538 |
+
"\n",
|
539 |
+
"\n",
|
540 |
+
"# Backbone Config\n",
|
541 |
+
"exp_config.task.annotation_file = None\n",
|
542 |
+
"exp_config.task.freeze_backbone = True\n",
|
543 |
+
"exp_config.task.init_checkpoint = \"./model_ckpt/ckpt-180648\"\n",
|
544 |
+
"exp_config.task.init_checkpoint_modules = \"backbone\"\n",
|
545 |
+
"\n",
|
546 |
+
"# Model Config\n",
|
547 |
+
"exp_config.task.model.num_classes = NUM_CLASSES + 1\n",
|
548 |
+
"exp_config.task.model.input_size = IMG_SHAPE\n",
|
549 |
+
"\n",
|
550 |
+
"# Training Data Config\n",
|
551 |
+
"exp_config.task.train_data.input_path = train_data_input_path\n",
|
552 |
+
"exp_config.task.train_data.dtype = 'float32'\n",
|
553 |
+
"exp_config.task.train_data.global_batch_size = BATCH_SIZE\n",
|
554 |
+
"exp_config.task.train_data.shuffle_buffer_size = 64\n",
|
555 |
+
"exp_config.task.train_data.parser.aug_scale_max = 1.0\n",
|
556 |
+
"exp_config.task.train_data.parser.aug_scale_min = 1.0\n",
|
557 |
+
"\n",
|
558 |
+
"# Validation Data Config\n",
|
559 |
+
"exp_config.task.validation_data.input_path = valid_data_input_path\n",
|
560 |
+
"exp_config.task.validation_data.dtype = 'float32'\n",
|
561 |
+
"exp_config.task.validation_data.global_batch_size = BATCH_SIZE"
|
562 |
+
]
|
563 |
+
},
|
564 |
+
{
|
565 |
+
"cell_type": "markdown",
|
566 |
+
"metadata": {
|
567 |
+
"id": "0409ReANgKzF"
|
568 |
+
},
|
569 |
+
"source": [
|
570 |
+
"### Adjust the trainer configuration."
|
571 |
+
]
|
572 |
+
},
|
573 |
+
{
|
574 |
+
"cell_type": "code",
|
575 |
+
"execution_count": null,
|
576 |
+
"metadata": {
|
577 |
+
"id": "ne8t5AHRUd9g"
|
578 |
+
},
|
579 |
+
"outputs": [],
|
580 |
+
"source": [
|
581 |
+
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
|
582 |
+
"\n",
|
583 |
+
"if 'GPU' in ''.join(logical_device_names):\n",
|
584 |
+
" print('This may be broken in Colab.')\n",
|
585 |
+
" device = 'GPU'\n",
|
586 |
+
"elif 'TPU' in ''.join(logical_device_names):\n",
|
587 |
+
" print('This may be broken in Colab.')\n",
|
588 |
+
" device = 'TPU'\n",
|
589 |
+
"else:\n",
|
590 |
+
" print('Running on CPU is slow, so only train for a few steps.')\n",
|
591 |
+
" device = 'CPU'\n",
|
592 |
+
"\n",
|
593 |
+
"\n",
|
594 |
+
"train_steps = 2000\n",
|
595 |
+
"exp_config.trainer.steps_per_loop = 200 # steps_per_loop = num_of_training_examples // train_batch_size\n",
|
596 |
+
"\n",
|
597 |
+
"exp_config.trainer.summary_interval = 200\n",
|
598 |
+
"exp_config.trainer.checkpoint_interval = 200\n",
|
599 |
+
"exp_config.trainer.validation_interval = 200\n",
|
600 |
+
"exp_config.trainer.validation_steps = 200 # validation_steps = num_of_validation_examples // eval_batch_size\n",
|
601 |
+
"exp_config.trainer.train_steps = train_steps\n",
|
602 |
+
"exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 200\n",
|
603 |
+
"exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
|
604 |
+
"exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
|
605 |
+
"exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.07\n",
|
606 |
+
"exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05"
|
607 |
+
]
|
608 |
+
},
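{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, here is a minimal sketch of the standard cosine-decay formula configured above (the exact Model Garden schedule implementation, including how warmup is blended in, may differ in its details):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def cosine_lr(step, initial_lr=0.07, decay_steps=2000):\n",
"  \"\"\"Standard cosine decay: initial_lr * 0.5 * (1 + cos(pi * step / decay_steps)).\"\"\"\n",
"  step = min(step, decay_steps)\n",
"  return initial_lr * 0.5 * (1 + math.cos(math.pi * step / decay_steps))\n",
"\n",
"for s in (0, 500, 1000, 2000):\n",
"  print(s, round(cosine_lr(s), 4))  # 0.07 -> 0.0597 -> 0.035 -> 0.0"
]
},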
|
609 |
+
{
|
610 |
+
"cell_type": "markdown",
|
611 |
+
"metadata": {
|
612 |
+
"id": "k3I4X-bWgNm0"
|
613 |
+
},
|
614 |
+
"source": [
|
615 |
+
"### Print the modified configuration."
|
616 |
+
]
|
617 |
+
},
|
618 |
+
{
|
619 |
+
"cell_type": "code",
|
620 |
+
"execution_count": null,
|
621 |
+
"metadata": {
|
622 |
+
"id": "IsmxXNlyWBAK"
|
623 |
+
},
|
624 |
+
"outputs": [],
|
625 |
+
"source": [
|
626 |
+
"pp.pprint(exp_config.as_dict())\n",
|
627 |
+
"display.Javascript(\"google.colab.output.setIframeHeight('500px');\")"
|
628 |
+
]
|
629 |
+
},
|
630 |
+
{
|
631 |
+
"cell_type": "markdown",
|
632 |
+
"metadata": {
|
633 |
+
"id": "jxarWEHDgQSk"
|
634 |
+
},
|
635 |
+
"source": [
|
636 |
+
"### Set up the distribution strategy."
|
637 |
+
]
|
638 |
+
},
|
639 |
+
{
|
640 |
+
"cell_type": "code",
|
641 |
+
"execution_count": null,
|
642 |
+
"metadata": {
|
643 |
+
"id": "4JxhiGNwQRv2"
|
644 |
+
},
|
645 |
+
"outputs": [],
|
646 |
+
"source": [
|
647 |
+
"# Setting up the Strategy\n",
|
648 |
+
"if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
|
649 |
+
" tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
|
650 |
+
"\n",
|
651 |
+
"if 'GPU' in ''.join(logical_device_names):\n",
|
652 |
+
" distribution_strategy = tf.distribute.MirroredStrategy()\n",
|
653 |
+
"elif 'TPU' in ''.join(logical_device_names):\n",
|
654 |
+
" tf.tpu.experimental.initialize_tpu_system()\n",
|
655 |
+
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
|
656 |
+
" distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
|
657 |
+
"else:\n",
|
658 |
+
" print('Warning: this will be really slow.')\n",
|
659 |
+
" distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n",
|
660 |
+
"\n",
|
661 |
+
"print(\"Done\")"
|
662 |
+
]
|
663 |
+
},
|
664 |
+
{
|
665 |
+
"cell_type": "markdown",
|
666 |
+
"metadata": {
|
667 |
+
"id": "QqZU9f1ugS_A"
|
668 |
+
},
|
669 |
+
"source": [
|
670 |
+
"## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
|
671 |
+
"\n",
|
672 |
+
"The `Task` object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
|
673 |
+
]
|
674 |
+
},
|
675 |
+
{
|
676 |
+
"cell_type": "code",
|
677 |
+
"execution_count": null,
|
678 |
+
"metadata": {
|
679 |
+
"id": "N5R-7KzORB1n"
|
680 |
+
},
|
681 |
+
"outputs": [],
|
682 |
+
"source": [
|
683 |
+
"with distribution_strategy.scope():\n",
|
684 |
+
" task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)"
|
685 |
+
]
|
686 |
+
},
|
687 |
+
{
|
688 |
+
"cell_type": "markdown",
|
689 |
+
"metadata": {
|
690 |
+
"id": "Fmpz2R_cglIv"
|
691 |
+
},
|
692 |
+
"source": [
|
693 |
+
"## Visualize a batch of the data."
|
694 |
+
]
|
695 |
+
},
|
696 |
+
{
|
697 |
+
"cell_type": "code",
|
698 |
+
"execution_count": null,
|
699 |
+
"metadata": {
|
700 |
+
"id": "O82f_7A8gfnY"
|
701 |
+
},
|
702 |
+
"outputs": [],
|
703 |
+
"source": [
|
704 |
+
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
|
705 |
+
" print()\n",
|
706 |
+
" print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
|
707 |
+
" print(f'labels.keys: {labels.keys()}')"
|
708 |
+
]
|
709 |
+
},
|
710 |
+
{
|
711 |
+
"cell_type": "markdown",
|
712 |
+
"metadata": {
|
713 |
+
"id": "dLcSHWjqgl66"
|
714 |
+
},
|
715 |
+
"source": [
|
716 |
+
"### Create Category Index Dictionary to map the labels to coressponding label names"
|
717 |
+
]
|
718 |
+
},
|
719 |
+
{
|
720 |
+
"cell_type": "code",
|
721 |
+
"execution_count": null,
|
722 |
+
"metadata": {
|
723 |
+
"id": "ajF85r_6R9d9"
|
724 |
+
},
|
725 |
+
"outputs": [],
|
726 |
+
"source": [
|
727 |
+
"tf_ex_decoder = TfExampleDecoder(include_mask=True)"
|
728 |
+
]
|
729 |
+
},
|
730 |
+
{
|
731 |
+
"cell_type": "markdown",
|
732 |
+
"metadata": {
|
733 |
+
"id": "gRdveeYVgr7B"
|
734 |
+
},
|
735 |
+
"source": [
|
736 |
+
"### Helper Function for Visualizing the results from TFRecords\n",
|
737 |
+
"Use `visualize_boxes_and_labels_on_image_array` from `visualization_utils` to draw boudning boxes on the image."
|
738 |
+
]
|
739 |
+
},
|
740 |
+
{
|
741 |
+
"cell_type": "code",
|
742 |
+
"execution_count": null,
|
743 |
+
"metadata": {
|
744 |
+
"id": "uWEuOs8QStrz"
|
745 |
+
},
|
746 |
+
"outputs": [],
|
747 |
+
"source": [
|
748 |
+
"def show_batch(raw_records):\n",
|
749 |
+
" plt.figure(figsize=(20, 20))\n",
|
750 |
+
" use_normalized_coordinates=True\n",
|
751 |
+
" min_score_thresh = 0.30\n",
|
752 |
+
" for i, serialized_example in enumerate(raw_records):\n",
|
753 |
+
" plt.subplot(1, 3, i + 1)\n",
|
754 |
+
" decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
|
755 |
+
" image = decoded_tensors['image'].numpy().astype('uint8')\n",
|
756 |
+
" scores = np.ones(shape=(len(decoded_tensors['groundtruth_boxes'])))\n",
|
757 |
+
" # print(decoded_tensors['groundtruth_instance_masks'].numpy().shape)\n",
|
758 |
+
" # print(decoded_tensors.keys())\n",
|
759 |
+
" visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
|
760 |
+
" image,\n",
|
761 |
+
" decoded_tensors['groundtruth_boxes'].numpy(),\n",
|
762 |
+
" decoded_tensors['groundtruth_classes'].numpy().astype('int'),\n",
|
763 |
+
" scores,\n",
|
764 |
+
" category_index=category_index,\n",
|
765 |
+
" use_normalized_coordinates=use_normalized_coordinates,\n",
|
766 |
+
" min_score_thresh=min_score_thresh,\n",
|
767 |
+
" instance_masks=decoded_tensors['groundtruth_instance_masks'].numpy().astype('uint8'),\n",
|
768 |
+
" line_thickness=4)\n",
|
769 |
+
"\n",
|
770 |
+
" plt.imshow(image)\n",
|
771 |
+
" plt.axis(\"off\")\n",
|
772 |
+
" plt.title(f\"Image-{i+1}\")\n",
|
773 |
+
" plt.show()"
|
774 |
+
]
|
775 |
+
},
|
776 |
+
{
|
777 |
+
"cell_type": "markdown",
|
778 |
+
"metadata": {
|
779 |
+
"id": "FergQ2P5gv_j"
|
780 |
+
},
|
781 |
+
"source": [
|
782 |
+
"### Visualization of Train Data\n",
|
783 |
+
"\n",
|
784 |
+
"The bounding box detection has three components\n",
|
785 |
+
" 1. Class label of the object detected.\n",
|
786 |
+
" 2. Percentage of match between predicted and ground truth bounding boxes.\n",
|
787 |
+
" 3. Instance Segmentation Mask\n",
|
788 |
+
"\n",
|
789 |
+
"**Note**: The reason of everything is 100% is because we are visualising the groundtruth"
|
790 |
+
]
|
791 |
+
},
|
792 |
+
{
|
793 |
+
"cell_type": "code",
|
794 |
+
"execution_count": null,
|
795 |
+
"metadata": {
|
796 |
+
"id": "lN0zdBwxU5Z5"
|
797 |
+
},
|
798 |
+
"outputs": [],
|
799 |
+
"source": [
|
800 |
+
"buffer_size = 100\n",
|
801 |
+
"num_of_examples = 3\n",
|
802 |
+
"\n",
|
803 |
+
"train_tfrecords = tf.io.gfile.glob(exp_config.task.train_data.input_path)\n",
|
804 |
+
"raw_records = tf.data.TFRecordDataset(train_tfrecords).shuffle(buffer_size=buffer_size).take(num_of_examples)\n",
|
805 |
+
"show_batch(raw_records)"
|
806 |
+
]
|
807 |
+
},
|
808 |
+
{
|
809 |
+
"cell_type": "markdown",
|
810 |
+
"metadata": {
|
811 |
+
"id": "nn7IZSs5hQLg"
|
812 |
+
},
|
813 |
+
"source": [
|
814 |
+
"## Train and evaluate\n",
|
815 |
+
"\n",
|
816 |
+
"We follow the COCO challenge tradition to evaluate the accuracy of object detection based on mAP(mean Average Precision). Please check [here](https://cocodataset.org/#detection-eval) for detail explanation of how evaluation metrics for detection task is done.\n",
|
817 |
+
"\n",
|
818 |
+
"**IoU**: is defined as the area of the intersection divided by the area of the union of a predicted bounding box and ground truth bounding box."
|
819 |
+
]
|
820 |
+
},
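{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a concrete illustration (a minimal sketch, not part of the Model Garden pipeline), IoU for two boxes in `(ymin, xmin, ymax, xmax)` format can be computed like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def iou(box_a, box_b):\n",
"  \"\"\"Intersection-over-Union of two (ymin, xmin, ymax, xmax) boxes.\"\"\"\n",
"  inter_h = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))\n",
"  inter_w = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))\n",
"  inter = inter_h * inter_w\n",
"  area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])\n",
"  area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])\n",
"  return inter / (area_a + area_b - inter)\n",
"\n",
"print(iou((0, 0, 2, 2), (1, 1, 3, 3)))  # intersection 1, union 7 -> ~0.143"
]
},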
|
821 |
+
{
|
822 |
+
"cell_type": "code",
|
823 |
+
"execution_count": null,
|
824 |
+
"metadata": {
|
825 |
+
"id": "UTuIs4kFZGv_"
|
826 |
+
},
|
827 |
+
"outputs": [],
|
828 |
+
"source": [
|
829 |
+
"model, eval_logs = tfm.core.train_lib.run_experiment(\n",
|
830 |
+
" distribution_strategy=distribution_strategy,\n",
|
831 |
+
" task=task,\n",
|
832 |
+
" mode='train_and_eval',\n",
|
833 |
+
" params=exp_config,\n",
|
834 |
+
" model_dir=model_dir,\n",
|
835 |
+
" run_post_eval=True)"
|
836 |
+
]
|
837 |
+
},
|
838 |
+
{
|
839 |
+
"cell_type": "markdown",
|
840 |
+
"metadata": {
|
841 |
+
"id": "rfpH4QHkh1gI"
|
842 |
+
},
|
843 |
+
"source": [
|
844 |
+
"## Load logs in tensorboard"
|
845 |
+
]
|
846 |
+
},
|
847 |
+
{
|
848 |
+
"cell_type": "code",
|
849 |
+
"execution_count": null,
|
850 |
+
"metadata": {
|
851 |
+
"id": "wcdOvg6eNP6R"
|
852 |
+
},
|
853 |
+
"outputs": [],
|
854 |
+
"source": [
|
855 |
+
"%load_ext tensorboard\n",
|
856 |
+
"%tensorboard --logdir \"./trained_model\""
|
857 |
+
]
|
858 |
+
},
|
859 |
+
{
|
860 |
+
"cell_type": "markdown",
|
861 |
+
"metadata": {
|
862 |
+
"id": "hAo9lozJh2cV"
|
863 |
+
},
|
864 |
+
"source": [
|
865 |
+
"## Saving and exporting the trained model\n",
|
866 |
+
"\n",
|
867 |
+
"The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results."
|
868 |
+
]
|
869 |
+
},
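{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a minimal sketch of that normalization (the `MEAN_RGB`/`STDDEV_RGB` constants below are the usual Model Garden defaults; treat the exact values as an assumption):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# Assumed defaults, matching official.vision.ops.preprocess_ops.\n",
"MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)\n",
"STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)\n",
"\n",
"def normalize_uint8_image(image):\n",
"  \"\"\"Applies (image - mean) / stddev channel-wise to a uint8 RGB image.\"\"\"\n",
"  image = tf.cast(image, tf.float32)\n",
"  return (image - MEAN_RGB) / STDDEV_RGB"
]
},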
|
870 |
+
{
|
871 |
+
"cell_type": "code",
|
872 |
+
"execution_count": null,
|
873 |
+
"metadata": {
|
874 |
+
"id": "iZG1vPbTQqFh"
|
875 |
+
},
|
876 |
+
"outputs": [],
|
877 |
+
"source": [
|
878 |
+
"export_saved_model_lib.export_inference_graph(\n",
|
879 |
+
" input_type='image_tensor',\n",
|
880 |
+
" batch_size=1,\n",
|
881 |
+
" input_image_size=[HEIGHT, WIDTH],\n",
|
882 |
+
" params=exp_config,\n",
|
883 |
+
" checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
|
884 |
+
" export_dir=export_dir)"
|
885 |
+
]
|
886 |
+
},
|
887 |
+
{
|
888 |
+
"cell_type": "markdown",
|
889 |
+
"metadata": {
|
890 |
+
"id": "OHIfMeVXh7vJ"
|
891 |
+
},
|
892 |
+
"source": [
|
893 |
+
"## Inference from Trained Model"
|
894 |
+
]
|
895 |
+
},
|
896 |
+
{
|
897 |
+
"cell_type": "code",
|
898 |
+
"execution_count": null,
|
899 |
+
"metadata": {
|
900 |
+
"id": "uaXyzMvXROTd"
|
901 |
+
},
|
902 |
+
"outputs": [],
|
903 |
+
"source": [
|
904 |
+
"def load_image_into_numpy_array(path):\n",
|
905 |
+
" \"\"\"Load an image from file into a numpy array.\n",
|
906 |
+
"\n",
|
907 |
+
" Puts image into numpy array to feed into tensorflow graph.\n",
|
908 |
+
" Note that by convention we put it into a numpy array with shape\n",
|
909 |
+
" (height, width, channels), where channels=3 for RGB.\n",
|
910 |
+
"\n",
|
911 |
+
" Args:\n",
|
912 |
+
" path: the file path to the image\n",
|
913 |
+
"\n",
|
914 |
+
" Returns:\n",
|
915 |
+
" uint8 numpy array with shape (img_height, img_width, 3)\n",
|
916 |
+
" \"\"\"\n",
|
917 |
+
" image = None\n",
|
918 |
+
" if(path.startswith('http')):\n",
|
919 |
+
" response = urlopen(path)\n",
|
920 |
+
" image_data = response.read()\n",
|
921 |
+
" image_data = BytesIO(image_data)\n",
|
922 |
+
" image = Image.open(image_data)\n",
|
923 |
+
" else:\n",
|
924 |
+
" image_data = tf.io.gfile.GFile(path, 'rb').read()\n",
|
925 |
+
" image = Image.open(BytesIO(image_data))\n",
|
926 |
+
"\n",
|
927 |
+
" (im_width, im_height) = image.size\n",
|
928 |
+
" return np.array(image.getdata()).reshape(\n",
|
929 |
+
" (1, im_height, im_width, 3)).astype(np.uint8)\n",
|
930 |
+
"\n",
|
931 |
+
"\n",
|
932 |
+
"\n",
|
933 |
+
"def build_inputs_for_object_detection(image, input_image_size):\n",
|
934 |
+
" \"\"\"Builds Object Detection model inputs for serving.\"\"\"\n",
|
935 |
+
" image, _ = resize_and_crop_image(\n",
|
936 |
+
" image,\n",
|
937 |
+
" input_image_size,\n",
|
938 |
+
" padded_size=input_image_size,\n",
|
939 |
+
" aug_scale_min=1.0,\n",
|
940 |
+
" aug_scale_max=1.0)\n",
|
941 |
+
" return image"
|
942 |
+
]
|
943 |
+
},
|
944 |
+
{
|
945 |
+
"cell_type": "markdown",
|
946 |
+
"metadata": {
|
947 |
+
"id": "ZDI9zv_4h-7-"
|
948 |
+
},
|
949 |
+
"source": [
|
950 |
+
"## Visualize test data"
|
951 |
+
]
|
952 |
+
},
|
953 |
+
{
|
954 |
+
"cell_type": "code",
|
955 |
+
"execution_count": null,
|
956 |
+
"metadata": {
|
957 |
+
"id": "rdyIri-1RThk"
|
958 |
+
},
|
959 |
+
"outputs": [],
|
960 |
+
"source": [
|
961 |
+
"num_of_examples = 3\n",
|
962 |
+
"\n",
|
963 |
+
"test_tfrecords = tf.io.gfile.glob('./lvis_tfrecords/val*')\n",
|
964 |
+
"test_ds = tf.data.TFRecordDataset(test_tfrecords).take(num_of_examples)\n",
|
965 |
+
"show_batch(test_ds)"
|
966 |
+
]
|
967 |
+
},
|
968 |
+
{
|
969 |
+
"cell_type": "markdown",
|
970 |
+
"metadata": {
|
971 |
+
"id": "KkMZm4DtiAHO"
|
972 |
+
},
|
973 |
+
"source": [
|
974 |
+
"## Importing SavedModel"
|
975 |
+
]
|
976 |
+
},
|
977 |
+
{
|
978 |
+
"cell_type": "code",
|
979 |
+
"execution_count": null,
|
980 |
+
"metadata": {
|
981 |
+
"id": "rDozz4NXRZ7p"
|
982 |
+
},
|
983 |
+
"outputs": [],
|
984 |
+
"source": [
|
985 |
+
"imported = tf.saved_model.load(export_dir)\n",
|
986 |
+
"model_fn = imported.signatures['serving_default']"
|
987 |
+
]
|
988 |
+
},
|
989 |
+
{
|
990 |
+
"cell_type": "markdown",
|
991 |
+
"metadata": {
|
992 |
+
"id": "DUxk4-AjLAcO"
|
993 |
+
},
|
994 |
+
"source": [
|
995 |
+
"## Visualize predictions"
|
996 |
+
]
|
997 |
+
},
|
998 |
+
{
|
999 |
+
"cell_type": "code",
|
1000 |
+
"execution_count": null,
|
1001 |
+
"metadata": {
|
1002 |
+
"id": "Gez57T5ShYnM"
|
1003 |
+
},
|
1004 |
+
"outputs": [],
|
1005 |
+
"source": [
|
1006 |
+
"def reframe_image_corners_relative_to_boxes(boxes):\n",
|
1007 |
+
" \"\"\"Reframe the image corners ([0, 0, 1, 1]) to be relative to boxes.\n",
|
1008 |
+
" The local coordinate frame of each box is assumed to be relative to\n",
|
1009 |
+
" its own for corners.\n",
|
1010 |
+
" Args:\n",
|
1011 |
+
" boxes: A float tensor of [num_boxes, 4] of (ymin, xmin, ymax, xmax)\n",
|
1012 |
+
" coordinates in relative coordinate space of each bounding box.\n",
|
1013 |
+
" Returns:\n",
|
1014 |
+
" reframed_boxes: Reframes boxes with same shape as input.\n",
|
1015 |
+
" \"\"\"\n",
|
1016 |
+
" ymin, xmin, ymax, xmax = (boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3])\n",
|
1017 |
+
"\n",
|
1018 |
+
" height = tf.maximum(ymax - ymin, 1e-4)\n",
|
1019 |
+
" width = tf.maximum(xmax - xmin, 1e-4)\n",
|
1020 |
+
"\n",
|
1021 |
+
" ymin_out = (0 - ymin) / height\n",
|
1022 |
+
" xmin_out = (0 - xmin) / width\n",
|
1023 |
+
" ymax_out = (1 - ymin) / height\n",
|
1024 |
+
" xmax_out = (1 - xmin) / width\n",
|
1025 |
+
" return tf.stack([ymin_out, xmin_out, ymax_out, xmax_out], axis=1)\n",
|
1026 |
+
"\n",
|
1027 |
+
"def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,\n",
|
1028 |
+
" image_width, resize_method='bilinear'):\n",
|
1029 |
+
" \"\"\"Transforms the box masks back to full image masks.\n",
|
1030 |
+
" Embeds masks in bounding boxes of larger masks whose shapes correspond to\n",
|
1031 |
+
" image shape.\n",
|
1032 |
+
" Args:\n",
|
1033 |
+
" box_masks: A tensor of size [num_masks, mask_height, mask_width].\n",
|
1034 |
+
" boxes: A tf.float32 tensor of size [num_masks, 4] containing the box\n",
|
1035 |
+
" corners. Row i contains [ymin, xmin, ymax, xmax] of the box\n",
|
1036 |
+
" corresponding to mask i. Note that the box corners are in\n",
|
1037 |
+
" normalized coordinates.\n",
|
1038 |
+
" image_height: Image height. The output mask will have the same height as\n",
|
1039 |
+
" the image height.\n",
|
1040 |
+
" image_width: Image width. The output mask will have the same width as the\n",
|
1041 |
+
" image width.\n",
|
1042 |
+
" resize_method: The resize method, either 'bilinear' or 'nearest'. Note that\n",
|
1043 |
+
" 'bilinear' is only respected if box_masks is a float.\n",
|
1044 |
+
" Returns:\n",
|
1045 |
+
" A tensor of size [num_masks, image_height, image_width] with the same dtype\n",
|
1046 |
+
" as `box_masks`.\n",
|
1047 |
+
" \"\"\"\n",
|
1048 |
+
" resize_method = 'nearest' if box_masks.dtype == tf.uint8 else resize_method\n",
|
1049 |
+
" # TODO(rathodv): Make this a public function.\n",
|
1050 |
+
" def reframe_box_masks_to_image_masks_default():\n",
|
1051 |
+
" \"\"\"The default function when there are more than 0 box masks.\"\"\"\n",
|
1052 |
+
"\n",
|
1053 |
+
" num_boxes = tf.shape(box_masks)[0]\n",
|
1054 |
+
" box_masks_expanded = tf.expand_dims(box_masks, axis=3)\n",
|
1055 |
+
"\n",
|
1056 |
+
" resized_crops = tf.image.crop_and_resize(\n",
|
1057 |
+
" image=box_masks_expanded,\n",
|
1058 |
+
" boxes=reframe_image_corners_relative_to_boxes(boxes),\n",
|
1059 |
+
" box_indices=tf.range(num_boxes),\n",
|
1060 |
+
" crop_size=[image_height, image_width],\n",
|
1061 |
+
" method=resize_method,\n",
|
1062 |
+
" extrapolation_value=0)\n",
|
1063 |
+
" return tf.cast(resized_crops, box_masks.dtype)\n",
|
1064 |
+
"\n",
|
1065 |
+
" image_masks = tf.cond(\n",
|
1066 |
+
" tf.shape(box_masks)[0] > 0,\n",
|
1067 |
+
" reframe_box_masks_to_image_masks_default,\n",
|
1068 |
+
" lambda: tf.zeros([0, image_height, image_width, 1], box_masks.dtype))\n",
|
1069 |
+
" return tf.squeeze(image_masks, axis=3)"
|
1070 |
+
]
|
1071 |
+
},
|
1072 |
+
{
|
1073 |
+
"cell_type": "code",
|
1074 |
+
"execution_count": null,
|
1075 |
+
"metadata": {
|
1076 |
+
"id": "6EIRAlXcSQaA"
|
1077 |
+
},
|
1078 |
+
"outputs": [],
|
1079 |
+
"source": [
|
1080 |
+
"input_image_size = (HEIGHT, WIDTH)\n",
|
1081 |
+
"plt.figure(figsize=(20, 20))\n",
|
1082 |
+
"min_score_thresh = 0.40 # Change minimum score for threshold to see all bounding boxes confidences\n",
|
1083 |
+
"\n",
|
1084 |
+
"for i, serialized_example in enumerate(test_ds):\n",
|
1085 |
+
" plt.subplot(1, 3, i+1)\n",
|
1086 |
+
" decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
|
1087 |
+
" image = build_inputs_for_object_detection(decoded_tensors['image'], input_image_size)\n",
|
1088 |
+
" image = tf.expand_dims(image, axis=0)\n",
|
1089 |
+
" image = tf.cast(image, dtype = tf.uint8)\n",
|
1090 |
+
" image_np = image[0].numpy()\n",
|
1091 |
+
" result = model_fn(image)\n",
|
1092 |
+
" # Visualize detection and masks\n",
|
1093 |
+
" if 'detection_masks' in result:\n",
|
1094 |
+
" # we need to convert np.arrays to tensors\n",
|
1095 |
+
" detection_masks = tf.convert_to_tensor(result['detection_masks'][0])\n",
|
1096 |
+
" detection_boxes = tf.convert_to_tensor(result['detection_boxes'][0])\n",
|
1097 |
+
" detection_masks_reframed = reframe_box_masks_to_image_masks(\n",
|
1098 |
+
" detection_masks, detection_boxes/256.0,\n",
|
1099 |
+
" image_np.shape[0], image_np.shape[1])\n",
|
1100 |
+
" detection_masks_reframed = tf.cast(\n",
|
1101 |
+
" detection_masks_reframed > min_score_thresh,\n",
|
1102 |
+
" np.uint8)\n",
|
1103 |
+
"\n",
|
1104 |
+
" result['detection_masks_reframed'] = detection_masks_reframed.numpy()\n",
|
1105 |
+
" visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
|
1106 |
+
" image_np,\n",
|
1107 |
+
" result['detection_boxes'][0].numpy(),\n",
|
1108 |
+
" (result['detection_classes'][0] + 0).numpy().astype(int),\n",
|
1109 |
+
" result['detection_scores'][0].numpy(),\n",
|
1110 |
+
" category_index=category_index,\n",
|
1111 |
+
" use_normalized_coordinates=False,\n",
|
1112 |
+
" max_boxes_to_draw=200,\n",
|
1113 |
+
" min_score_thresh=min_score_thresh,\n",
|
1114 |
+
" instance_masks=result.get('detection_masks_reframed', None),\n",
|
1115 |
+
" line_thickness=4)\n",
|
1116 |
+
"\n",
|
1117 |
+
" plt.imshow(image_np)\n",
|
1118 |
+
" plt.axis(\"off\")\n",
|
1119 |
+
"\n",
|
1120 |
+
"plt.show()"
|
1121 |
+
]
|
1122 |
+
}
|
1123 |
+
],
|
1124 |
+
"metadata": {
|
1125 |
+
"accelerator": "GPU",
|
1126 |
+
"colab": {
|
1127 |
+
"name": "instance_segmentation.ipynb",
|
1128 |
+
"provenance": [],
|
1129 |
+
"toc_visible": true
|
1130 |
+
},
|
1131 |
+
"kernelspec": {
|
1132 |
+
"display_name": "Python 3",
|
1133 |
+
"name": "python3"
|
1134 |
+
}
|
1135 |
+
},
|
1136 |
+
"nbformat": 4,
|
1137 |
+
"nbformat_minor": 0
|
1138 |
+
}
|
Tensorflow/models/docs/vision/object_detection.ipynb
ADDED
@@ -0,0 +1,902 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "Cayt5nCXb3WG"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"##### Copyright 2022 The TensorFlow Authors."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"id": "DYL3CXHRb9-f"
|
17 |
+
},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
21 |
+
"# you may not use this file except in compliance with the License.\n",
|
22 |
+
"# You may obtain a copy of the License at\n",
|
23 |
+
"#\n",
|
24 |
+
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
25 |
+
"#\n",
|
26 |
+
"# Unless required by applicable law or agreed to in writing, software\n",
|
27 |
+
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
28 |
+
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
29 |
+
"# See the License for the specific language governing permissions and\n",
|
30 |
+
"# limitations under the License."
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "markdown",
|
35 |
+
"metadata": {
|
36 |
+
"id": "VYDmsvURYZjz"
|
37 |
+
},
|
38 |
+
"source": [
|
39 |
+
"# Object detection with Model Garden\n",
|
40 |
+
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
|
41 |
+
" <td>\n",
|
42 |
+
" <a target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/object_detection\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
|
43 |
+
" </td>\n",
|
44 |
+
" <td>\n",
|
45 |
+
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/object_detection.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
46 |
+
" </td>\n",
|
47 |
+
" <td>\n",
|
48 |
+
" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/object_detection.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
|
49 |
+
" </td>\n",
|
50 |
+
" <td>\n",
|
51 |
+
" <a href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/object_detection.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
|
52 |
+
" </td>\n",
|
53 |
+
"</table>"
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "markdown",
|
58 |
+
"metadata": {
|
59 |
+
"id": "69aQq_PXcUvL"
|
60 |
+
},
|
61 |
+
"source": [
|
62 |
+
"This tutorial fine-tunes a [RetinaNet](https://arxiv.org/abs/1708.02002) with ResNet-50 as backbone model from the [TensorFlow Model Garden](https://pypi.org/project/tf-models-official/) package (tensorflow-models) to detect three different Blood Cells in [BCCD](https://public.roboflow.com/object-detection/bccd) dataset. The RetinaNet is pretrained on [COCO](https://cocodataset.org/) train2017 and evaluated on [COCO](https://cocodataset.org/) val2017\n",
|
63 |
+
"\n",
|
64 |
+
"[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
|
65 |
+
"\n",
|
66 |
+
"This tutorial demonstrates how to:\n",
|
67 |
+
"\n",
|
68 |
+
"1. Use models from the Tensorflow Model Garden(TFM) package.\n",
|
69 |
+
"2. Fine-tune a pre-trained RetinanNet with ResNet-50 as backbone for object detection.\n",
|
70 |
+
"3. Export the tuned RetinaNet model"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "markdown",
|
75 |
+
"metadata": {
|
76 |
+
"id": "IeSHlZyUZl6f"
|
77 |
+
},
|
78 |
+
"source": [
|
79 |
+
"## Install necessary dependencies"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "code",
|
84 |
+
"execution_count": null,
|
85 |
+
"metadata": {
|
86 |
+
"id": "Pip0LHj3ZqgL"
|
87 |
+
},
|
88 |
+
"outputs": [],
|
89 |
+
"source": [
|
90 |
+
"!pip install -U -q \"tf-models-official\""
|
91 |
+
]
|
92 |
+
},
|
93 |
+
{
|
94 |
+
"cell_type": "markdown",
|
95 |
+
"metadata": {
|
96 |
+
"id": "H3kS7Y0sZsIj"
|
97 |
+
},
|
98 |
+
"source": [
|
99 |
+
"## Import required libraries"
|
100 |
+
]
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"cell_type": "code",
|
104 |
+
"execution_count": null,
|
105 |
+
"metadata": {
|
106 |
+
"id": "hFdVelJ2YbQz"
|
107 |
+
},
|
108 |
+
"outputs": [],
|
109 |
+
"source": [
|
110 |
+
"import os\n",
|
111 |
+
"import io\n",
|
112 |
+
"import pprint\n",
|
113 |
+
"import tempfile\n",
|
114 |
+
"import matplotlib\n",
|
115 |
+
"import numpy as np\n",
|
116 |
+
"import tensorflow as tf\n",
|
117 |
+
"import matplotlib.pyplot as plt\n",
|
118 |
+
"\n",
|
119 |
+
"from PIL import Image\n",
|
120 |
+
"from six import BytesIO\n",
|
121 |
+
"from IPython import display\n",
|
122 |
+
"from urllib.request import urlopen"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"cell_type": "markdown",
|
127 |
+
"metadata": {
|
128 |
+
"id": "TF77J-iMZn_u"
|
129 |
+
},
|
130 |
+
"source": [
|
131 |
+
"## Import required libraries from tensorflow models"
|
132 |
+
]
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"cell_type": "code",
|
136 |
+
"execution_count": null,
|
137 |
+
"metadata": {
|
138 |
+
"id": "iT27_SOTY1Dz"
|
139 |
+
},
|
140 |
+
"outputs": [],
|
141 |
+
"source": [
|
142 |
+
"import orbit\n",
|
143 |
+
"import tensorflow_models as tfm\n",
|
144 |
+
"\n",
|
145 |
+
"from official.core import exp_factory\n",
|
146 |
+
"from official.core import config_definitions as cfg\n",
|
147 |
+
"from official.vision.serving import export_saved_model_lib\n",
|
148 |
+
"from official.vision.ops.preprocess_ops import normalize_image\n",
|
149 |
+
"from official.vision.ops.preprocess_ops import resize_and_crop_image\n",
|
150 |
+
"from official.vision.utils.object_detection import visualization_utils\n",
|
151 |
+
"from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder\n",
|
152 |
+
"\n",
|
153 |
+
"pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation\n",
|
154 |
+
"print(tf.__version__) # Check the version of tensorflow used\n",
|
155 |
+
"\n",
|
156 |
+
"%matplotlib inline"
|
157 |
+
]
|
158 |
+
},
|
159 |
+
{
|
160 |
+
"cell_type": "markdown",
|
161 |
+
"metadata": {
|
162 |
+
"id": "WGbMG8cpZyKa"
|
163 |
+
},
|
164 |
+
"source": [
|
165 |
+
"## Custom dataset preparation for object detection\n",
|
166 |
+
"\n",
|
167 |
+
"Models in official repository(of model-garden) requires data in a TFRecords format.\n",
|
168 |
+
"\n",
|
169 |
+
"\n",
|
170 |
+
"Please check [this resource](https://www.tensorflow.org/tutorials/load_data/tfrecord) to learn more about TFRecords data format.\n"
|
171 |
+
]
|
172 |
+
},
|
173 |
+
{
|
174 |
+
"cell_type": "markdown",
|
175 |
+
"metadata": {
|
176 |
+
"id": "Uq5hcbJ8Z4th"
|
177 |
+
},
|
178 |
+
"source": [
|
179 |
+
"### Upload your custom data in drive or local disk of the notebook and unzip the data"
|
180 |
+
]
|
181 |
+
},
|
182 |
+
{
|
183 |
+
"cell_type": "code",
|
184 |
+
"execution_count": null,
|
185 |
+
"metadata": {
|
186 |
+
"id": "rDixpoqoY3Za"
|
187 |
+
},
|
188 |
+
"outputs": [],
|
189 |
+
"source": [
|
190 |
+
"!curl -L 'https://public.roboflow.com/ds/ZpYLqHeT0W?key=ZXfZLRnhsc' > './BCCD.v1-bccd.coco.zip'\n",
|
191 |
+
"!unzip -q -o './BCCD.v1-bccd.coco.zip' -d './BCC.v1-bccd.coco/'\n",
|
192 |
+
"!rm './BCCD.v1-bccd.coco.zip'"
|
193 |
+
]
|
194 |
+
},
|
195 |
+
{
|
196 |
+
"cell_type": "markdown",
|
197 |
+
"metadata": {
|
198 |
+
"id": "GI1h9UChZ8cC"
|
199 |
+
},
|
200 |
+
"source": [
|
201 |
+
"### CLI command to convert data(train data)."
|
202 |
+
]
|
203 |
+
},
|
204 |
+
{
|
205 |
+
"cell_type": "code",
|
206 |
+
"execution_count": null,
|
207 |
+
"metadata": {
|
208 |
+
"id": "x_8cmB82Y65O"
|
209 |
+
},
|
210 |
+
"outputs": [],
|
211 |
+
"source": [
|
212 |
+
"%%bash\n",
|
213 |
+
"\n",
|
214 |
+
"TRAIN_DATA_DIR='./BCC.v1-bccd.coco/train'\n",
|
215 |
+
"TRAIN_ANNOTATION_FILE_DIR='./BCC.v1-bccd.coco/train/_annotations.coco.json'\n",
|
216 |
+
"OUTPUT_TFRECORD_TRAIN='./bccd_coco_tfrecords/train'\n",
|
217 |
+
"\n",
|
218 |
+
"# Need to provide\n",
|
219 |
+
" # 1. image_dir: where images are present\n",
|
220 |
+
" # 2. object_annotations_file: where annotations are listed in json format\n",
|
221 |
+
" # 3. output_file_prefix: where to write output convered TFRecords files\n",
|
222 |
+
"python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
|
223 |
+
" --image_dir=${TRAIN_DATA_DIR} \\\n",
|
224 |
+
" --object_annotations_file=${TRAIN_ANNOTATION_FILE_DIR} \\\n",
|
225 |
+
" --output_file_prefix=$OUTPUT_TFRECORD_TRAIN \\\n",
|
226 |
+
" --num_shards=1"
|
227 |
+
]
|
228 |
+
},
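{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check (a minimal sketch, assuming the conversion above wrote `./bccd_coco_tfrecords/train-00000-of-00001.tfrecord`), you can inspect the feature keys of one converted record:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# Parse a single serialized tf.train.Example and list its feature keys.\n",
"raw_dataset = tf.data.TFRecordDataset('./bccd_coco_tfrecords/train-00000-of-00001.tfrecord')\n",
"for raw_record in raw_dataset.take(1):\n",
"  example = tf.train.Example()\n",
"  example.ParseFromString(raw_record.numpy())\n",
"  print(sorted(example.features.feature.keys()))"
]
},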
|
229 |
+
{
|
230 |
+
"cell_type": "markdown",
|
231 |
+
"metadata": {
|
232 |
+
"id": "VuwZpwUoaAKU"
|
233 |
+
},
|
234 |
+
"source": [
|
235 |
+
"### CLI command to convert data(validation data)."
|
236 |
+
]
|
237 |
+
},
|
238 |
+
{
|
239 |
+
"cell_type": "code",
|
240 |
+
"execution_count": null,
|
241 |
+
"metadata": {
|
242 |
+
"id": "q8mQ8prGY8kh"
|
243 |
+
},
|
244 |
+
"outputs": [],
|
245 |
+
"source": [
|
246 |
+
"%%bash\n",
|
247 |
+
"\n",
|
248 |
+
"VALID_DATA_DIR='./BCC.v1-bccd.coco/valid'\n",
|
249 |
+
"VALID_ANNOTATION_FILE_DIR='./BCC.v1-bccd.coco/valid/_annotations.coco.json'\n",
|
250 |
+
"OUTPUT_TFRECORD_VALID='./bccd_coco_tfrecords/valid'\n",
|
251 |
+
"\n",
|
252 |
+
"python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
|
253 |
+
" --image_dir=$VALID_DATA_DIR \\\n",
|
254 |
+
" --object_annotations_file=$VALID_ANNOTATION_FILE_DIR \\\n",
|
255 |
+
" --output_file_prefix=$OUTPUT_TFRECORD_VALID \\\n",
|
256 |
+
" --num_shards=1"
|
257 |
+
]
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"cell_type": "markdown",
|
261 |
+
"metadata": {
|
262 |
+
"id": "BYGxNNAXaCW6"
|
263 |
+
},
|
264 |
+
"source": [
|
265 |
+
"### CLI command to convert data(test data)."
|
266 |
+
]
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"cell_type": "code",
|
270 |
+
"execution_count": null,
|
271 |
+
"metadata": {
|
272 |
+
"id": "-K8qlfstY-Ua"
|
273 |
+
},
|
274 |
+
"outputs": [],
|
275 |
+
"source": [
|
276 |
+
"%%bash\n",
|
277 |
+
"\n",
|
278 |
+
"TEST_DATA_DIR='./BCC.v1-bccd.coco/test'\n",
|
279 |
+
"TEST_ANNOTATION_FILE_DIR='./BCC.v1-bccd.coco/test/_annotations.coco.json'\n",
|
280 |
+
"OUTPUT_TFRECORD_TEST='./bccd_coco_tfrecords/test'\n",
|
281 |
+
"\n",
|
282 |
+
"python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
|
283 |
+
" --image_dir=$TEST_DATA_DIR \\\n",
|
284 |
+
" --object_annotations_file=$TEST_ANNOTATION_FILE_DIR \\\n",
|
285 |
+
" --output_file_prefix=$OUTPUT_TFRECORD_TEST \\\n",
|
286 |
+
" --num_shards=1"
|
287 |
+
]
|
288 |
+
},
|
289 |
+
{
|
290 |
+
"cell_type": "markdown",
|
291 |
+
"metadata": {
|
292 |
+
"id": "cW7hQEJTaEtj"
|
293 |
+
},
|
294 |
+
"source": [
|
295 |
+
"## Configure the Retinanet Resnet FPN COCO model for custom dataset.\n",
|
296 |
+
"\n",
|
297 |
+
"Dataset used for fine tuning the checkpoint is Blood Cells Detection (BCCD)."
|
298 |
+
]
|
299 |
+
},
|
300 |
+
{
|
301 |
+
"cell_type": "code",
|
302 |
+
"execution_count": null,
|
303 |
+
"metadata": {
|
304 |
+
"id": "PMGEl7iXZAAF"
|
305 |
+
},
|
306 |
+
"outputs": [],
|
307 |
+
"source": [
|
308 |
+
"train_data_input_path = './bccd_coco_tfrecords/train-00000-of-00001.tfrecord'\n",
|
309 |
+
"valid_data_input_path = './bccd_coco_tfrecords/valid-00000-of-00001.tfrecord'\n",
|
310 |
+
"test_data_input_path = './bccd_coco_tfrecords/test-00000-of-00001.tfrecord'\n",
|
311 |
+
"model_dir = './trained_model/'\n",
|
312 |
+
"export_dir ='./exported_model/'"
|
313 |
+
]
|
314 |
+
},
|
315 |
+
{
|
316 |
+
"cell_type": "markdown",
|
317 |
+
"metadata": {
|
318 |
+
"id": "2DJpKvdeaHF3"
|
319 |
+
},
|
320 |
+
"source": [
|
321 |
+
"In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
|
322 |
+
"\n",
|
323 |
+
"\n",
|
324 |
+
"Use the `retinanet_resnetfpn_coco` experiment configuration, as defined by `tfm.vision.configs.retinanet.retinanet_resnetfpn_coco`.\n",
|
325 |
+
"\n",
|
326 |
+
"The configuration defines an experiment to train a RetinanNet with Resnet-50 as backbone, FPN as decoder. Default Configuration is trained on [COCO](https://cocodataset.org/) train2017 and evaluated on [COCO](https://cocodataset.org/) val2017.\n",
|
327 |
+
"\n",
|
328 |
+
"There are also other alternative experiments available such as\n",
|
329 |
+
"`retinanet_resnetfpn_coco`, `retinanet_spinenet_coco`, `fasterrcnn_resnetfpn_coco` and more. One can switch to them by changing the experiment name argument to the `get_exp_config` function.\n",
|
330 |
+
"\n",
|
331 |
+
"We are going to fine tune the Resnet-50 backbone checkpoint which is already present in the default configuration."
|
332 |
+
]
|
333 |
+
},
|
334 |
+
{
|
335 |
+
"cell_type": "code",
|
336 |
+
"execution_count": null,
|
337 |
+
"metadata": {
|
338 |
+
"id": "Ie1ObPH9ZBpa"
|
339 |
+
},
|
340 |
+
"outputs": [],
|
341 |
+
"source": [
|
342 |
+
"exp_config = exp_factory.get_exp_config('retinanet_resnetfpn_coco')"
|
343 |
+
]
|
344 |
+
},
|
345 |
+
{
|
346 |
+
"cell_type": "markdown",
|
347 |
+
"metadata": {
|
348 |
+
"id": "LFhjFkw-alba"
|
349 |
+
},
|
350 |
+
"source": [
|
351 |
+
"### Adjust the model and dataset configurations so that it works with custom dataset(in this case `BCCD`)."
|
352 |
+
]
|
353 |
+
},
|
354 |
+
{
|
355 |
+
"cell_type": "code",
|
356 |
+
"execution_count": null,
|
357 |
+
"metadata": {
|
358 |
+
"id": "ej7j6dvIZDQA"
|
359 |
+
},
|
360 |
+
"outputs": [],
|
361 |
+
"source": [
|
362 |
+
"batch_size = 8\n",
|
363 |
+
"num_classes = 3\n",
|
364 |
+
"\n",
|
365 |
+
"HEIGHT, WIDTH = 256, 256\n",
|
366 |
+
"IMG_SIZE = [HEIGHT, WIDTH, 3]\n",
|
367 |
+
"\n",
|
368 |
+
"# Backbone config.\n",
|
369 |
+
"exp_config.task.freeze_backbone = False\n",
|
370 |
+
"exp_config.task.annotation_file = ''\n",
|
371 |
+
"\n",
|
372 |
+
"# Model config.\n",
|
373 |
+
"exp_config.task.model.input_size = IMG_SIZE\n",
|
374 |
+
"exp_config.task.model.num_classes = num_classes + 1\n",
|
375 |
+
"exp_config.task.model.detection_generator.tflite_post_processing.max_classes_per_detection = exp_config.task.model.num_classes\n",
|
376 |
+
"\n",
|
377 |
+
"# Training data config.\n",
|
378 |
+
"exp_config.task.train_data.input_path = train_data_input_path\n",
|
379 |
+
"exp_config.task.train_data.dtype = 'float32'\n",
|
380 |
+
"exp_config.task.train_data.global_batch_size = batch_size\n",
|
381 |
+
"exp_config.task.train_data.parser.aug_scale_max = 1.0\n",
|
382 |
+
"exp_config.task.train_data.parser.aug_scale_min = 1.0\n",
|
383 |
+
"\n",
|
384 |
+
"# Validation data config.\n",
|
385 |
+
"exp_config.task.validation_data.input_path = valid_data_input_path\n",
|
386 |
+
"exp_config.task.validation_data.dtype = 'float32'\n",
|
387 |
+
"exp_config.task.validation_data.global_batch_size = batch_size"
|
388 |
+
]
|
389 |
+
},
|
390 |
+
{
|
391 |
+
"cell_type": "markdown",
|
392 |
+
"metadata": {
|
393 |
+
"id": "ROVc1rayaqI1"
|
394 |
+
},
|
395 |
+
"source": [
|
396 |
+
"### Adjust the trainer configuration."
|
397 |
+
]
|
398 |
+
},
|
399 |
+
{
|
400 |
+
"cell_type": "code",
|
401 |
+
"execution_count": null,
|
402 |
+
"metadata": {
|
403 |
+
"id": "BZsCVBafZFIE"
|
404 |
+
},
|
405 |
+
"outputs": [],
|
406 |
+
"source": [
|
407 |
+
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
|
408 |
+
"\n",
|
409 |
+
"if 'GPU' in ''.join(logical_device_names):\n",
|
410 |
+
" print('This may be broken in Colab.')\n",
|
411 |
+
" device = 'GPU'\n",
|
412 |
+
"elif 'TPU' in ''.join(logical_device_names):\n",
|
413 |
+
" print('This may be broken in Colab.')\n",
|
414 |
+
" device = 'TPU'\n",
|
415 |
+
"else:\n",
|
416 |
+
" print('Running on CPU is slow, so only train for a few steps.')\n",
|
417 |
+
" device = 'CPU'\n",
|
418 |
+
"\n",
|
419 |
+
"\n",
|
420 |
+
"train_steps = 1000\n",
|
421 |
+
"exp_config.trainer.steps_per_loop = 100 # steps_per_loop = num_of_training_examples // train_batch_size\n",
|
422 |
+
"\n",
|
423 |
+
"exp_config.trainer.summary_interval = 100\n",
|
424 |
+
"exp_config.trainer.checkpoint_interval = 100\n",
|
425 |
+
"exp_config.trainer.validation_interval = 100\n",
|
426 |
+
"exp_config.trainer.validation_steps = 100 # validation_steps = num_of_validation_examples // eval_batch_size\n",
|
427 |
+
"exp_config.trainer.train_steps = train_steps\n",
|
428 |
+
"exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100\n",
|
429 |
+
"exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
|
430 |
+
"exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
|
431 |
+
"exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
|
432 |
+
"exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05"
|
433 |
+
]
|
434 |
+
},
|
435 |
+
{
|
436 |
+
"cell_type": "markdown",
|
437 |
+
"metadata": {
|
438 |
+
"id": "XS6cJfs2atgI"
|
439 |
+
},
|
440 |
+
"source": [
|
441 |
+
"### Print the modified configuration."
|
442 |
+
]
|
443 |
+
},
|
444 |
+
{
|
445 |
+
"cell_type": "code",
|
446 |
+
"execution_count": null,
|
447 |
+
"metadata": {
|
448 |
+
"id": "IvfJlqI7ZIcD"
|
449 |
+
},
|
450 |
+
"outputs": [],
|
451 |
+
"source": [
|
452 |
+
"pp.pprint(exp_config.as_dict())\n",
|
453 |
+
"display.Javascript('google.colab.output.setIframeHeight(\"500px\");')"
|
454 |
+
]
|
455 |
+
},
|
456 |
+
{
|
457 |
+
"cell_type": "markdown",
|
458 |
+
"metadata": {
|
459 |
+
"id": "6o5mbpRBawbs"
|
460 |
+
},
|
461 |
+
"source": [
|
462 |
+
"### Set up the distribution strategy."
|
463 |
+
]
|
464 |
+
},
|
465 |
+
{
|
466 |
+
"cell_type": "code",
|
467 |
+
"execution_count": null,
|
468 |
+
"metadata": {
|
469 |
+
"id": "2NvY8QHOZKGr"
|
470 |
+
},
|
471 |
+
"outputs": [],
|
472 |
+
"source": [
|
473 |
+
"if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
|
474 |
+
" tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
|
475 |
+
"\n",
|
476 |
+
"if 'GPU' in ''.join(logical_device_names):\n",
|
477 |
+
" distribution_strategy = tf.distribute.MirroredStrategy()\n",
|
478 |
+
"elif 'TPU' in ''.join(logical_device_names):\n",
|
479 |
+
" tf.tpu.experimental.initialize_tpu_system()\n",
|
480 |
+
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
|
481 |
+
" distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
|
482 |
+
"else:\n",
|
483 |
+
" print('Warning: this will be really slow.')\n",
|
484 |
+
" distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n",
|
485 |
+
"\n",
|
486 |
+
"print('Done')"
|
487 |
+
]
|
488 |
+
},
|
489 |
+
{
|
490 |
+
"cell_type": "markdown",
|
491 |
+
"metadata": {
|
492 |
+
"id": "4wPtJgoOa33v"
|
493 |
+
},
|
494 |
+
"source": [
|
495 |
+
"## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
|
496 |
+
"\n",
|
497 |
+
"The `Task` object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
|
498 |
+
]
|
499 |
+
},
|
500 |
+
{
|
501 |
+
"cell_type": "code",
|
502 |
+
"execution_count": null,
|
503 |
+
"metadata": {
|
504 |
+
"id": "Ns9LAsiXZLuX"
|
505 |
+
},
|
506 |
+
"outputs": [],
|
507 |
+
"source": [
|
508 |
+
"with distribution_strategy.scope():\n",
|
509 |
+
" task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)"
|
510 |
+
]
|
511 |
+
},
|
512 |
+
{
|
513 |
+
"cell_type": "markdown",
|
514 |
+
"metadata": {
|
515 |
+
"id": "vTKbQxDkbArE"
|
516 |
+
},
|
517 |
+
"source": [
|
518 |
+
"## Visualize a batch of the data."
|
519 |
+
]
|
520 |
+
},
|
521 |
+
{
|
522 |
+
"cell_type": "code",
|
523 |
+
"execution_count": null,
|
524 |
+
"metadata": {
|
525 |
+
"id": "3RIlbhj0ZNvt"
|
526 |
+
},
|
527 |
+
"outputs": [],
|
528 |
+
"source": [
|
529 |
+
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
|
530 |
+
" print()\n",
|
531 |
+
" print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
|
532 |
+
" print(f'labels.keys: {labels.keys()}')"
|
533 |
+
]
|
534 |
+
},
|
535 |
+
{
|
536 |
+
"cell_type": "markdown",
|
537 |
+
"metadata": {
|
538 |
+
"id": "m-QW7DoKbD8z"
|
539 |
+
},
|
540 |
+
"source": [
|
541 |
+
"### Create category index dictionary to map the labels to coressponding label names."
|
542 |
+
]
|
543 |
+
},
|
544 |
+
{
|
545 |
+
"cell_type": "code",
|
546 |
+
"execution_count": null,
|
547 |
+
"metadata": {
|
548 |
+
"id": "MN0sSthbZR-s"
|
549 |
+
},
|
550 |
+
"outputs": [],
|
551 |
+
"source": [
|
552 |
+
"category_index={\n",
|
553 |
+
" 1: {\n",
|
554 |
+
" 'id': 1,\n",
|
555 |
+
" 'name': 'Platelets'\n",
|
556 |
+
" },\n",
|
557 |
+
" 2: {\n",
|
558 |
+
" 'id': 2,\n",
|
559 |
+
" 'name': 'RBC'\n",
|
560 |
+
" },\n",
|
561 |
+
" 3: {\n",
|
562 |
+
" 'id': 3,\n",
|
563 |
+
" 'name': 'WBC'\n",
|
564 |
+
" }\n",
|
565 |
+
"}\n",
|
566 |
+
"tf_ex_decoder = TfExampleDecoder()"
|
567 |
+
]
|
568 |
+
},
|
569 |
+
{
|
570 |
+
"cell_type": "markdown",
|
571 |
+
"metadata": {
|
572 |
+
"id": "AcbmD1pRbGcS"
|
573 |
+
},
|
574 |
+
"source": [
|
575 |
+
"### Helper function for visualizing the results from TFRecords.\n",
|
576 |
+
"Use `visualize_boxes_and_labels_on_image_array` from `visualization_utils` to draw boudning boxes on the image."
|
577 |
+
]
|
578 |
+
},
|
579 |
+
{
|
580 |
+
"cell_type": "code",
|
581 |
+
"execution_count": null,
|
582 |
+
"metadata": {
|
583 |
+
"id": "wWBeomMMZThI"
|
584 |
+
},
|
585 |
+
"outputs": [],
|
586 |
+
"source": [
|
587 |
+
"def show_batch(raw_records, num_of_examples):\n",
|
588 |
+
" plt.figure(figsize=(20, 20))\n",
|
589 |
+
" use_normalized_coordinates=True\n",
|
590 |
+
" min_score_thresh = 0.30\n",
|
591 |
+
" for i, serialized_example in enumerate(raw_records):\n",
|
592 |
+
" plt.subplot(1, 3, i + 1)\n",
|
593 |
+
" decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
|
594 |
+
" image = decoded_tensors['image'].numpy().astype('uint8')\n",
|
595 |
+
" scores = np.ones(shape=(len(decoded_tensors['groundtruth_boxes'])))\n",
|
596 |
+
" visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
|
597 |
+
" image,\n",
|
598 |
+
" decoded_tensors['groundtruth_boxes'].numpy(),\n",
|
599 |
+
" decoded_tensors['groundtruth_classes'].numpy().astype('int'),\n",
|
600 |
+
" scores,\n",
|
601 |
+
" category_index=category_index,\n",
|
602 |
+
" use_normalized_coordinates=use_normalized_coordinates,\n",
|
603 |
+
" max_boxes_to_draw=200,\n",
|
604 |
+
" min_score_thresh=min_score_thresh,\n",
|
605 |
+
" agnostic_mode=False,\n",
|
606 |
+
" instance_masks=None,\n",
|
607 |
+
" line_thickness=4)\n",
|
608 |
+
"\n",
|
609 |
+
" plt.imshow(image)\n",
|
610 |
+
" plt.axis('off')\n",
|
611 |
+
" plt.title(f'Image-{i+1}')\n",
|
612 |
+
" plt.show()"
|
613 |
+
]
|
614 |
+
},
|
615 |
+
{
|
616 |
+
"cell_type": "markdown",
|
617 |
+
"metadata": {
|
618 |
+
"id": "R3EgriELbJly"
|
619 |
+
},
|
620 |
+
"source": [
|
621 |
+
"### Visualization of train data\n",
|
622 |
+
"\n",
|
623 |
+
"The bounding box detection has two components\n",
|
624 |
+
" 1. Class label of the object detected (e.g.RBC)\n",
|
625 |
+
" 2. Percentage of match between predicted and ground truth bounding boxes.\n",
|
626 |
+
"\n",
|
627 |
+
"**Note**: The reason of everything is 100% is because we are visualising the groundtruth."
|
628 |
+
]
|
629 |
+
},
|
630 |
+
{
|
631 |
+
"cell_type": "code",
|
632 |
+
"execution_count": null,
|
633 |
+
"metadata": {
|
634 |
+
"id": "hdrsciGIZVNO"
|
635 |
+
},
|
636 |
+
"outputs": [],
|
637 |
+
"source": [
|
638 |
+
"buffer_size = 20\n",
|
639 |
+
"num_of_examples = 3\n",
|
640 |
+
"\n",
|
641 |
+
"raw_records = tf.data.TFRecordDataset(\n",
|
642 |
+
" exp_config.task.train_data.input_path).shuffle(\n",
|
643 |
+
" buffer_size=buffer_size).take(num_of_examples)\n",
|
644 |
+
"show_batch(raw_records, num_of_examples)"
|
645 |
+
]
|
646 |
+
},
|
647 |
+
{
|
648 |
+
"cell_type": "markdown",
|
649 |
+
"metadata": {
|
650 |
+
"id": "IrWkJPyEbMKg"
|
651 |
+
},
|
652 |
+
"source": [
|
653 |
+
"## Train and evaluate.\n",
|
654 |
+
"\n",
|
655 |
+
"We follow the COCO challenge tradition to evaluate the accuracy of object detection based on mAP(mean Average Precision). Please check [here](https://cocodataset.org/#detection-eval) for detail explanation of how evaluation metrics for detection task is done.\n",
|
656 |
+
"\n",
|
657 |
+
"**IoU**: is defined as the area of the intersection divided by the area of the union of a predicted bounding box and ground truth bounding box."
|
658 |
+
]
|
659 |
+
},
|
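As a quick, standalone illustration of the IoU definition above (a sketch for intuition only; the notebook's COCO evaluation is handled internally by the `Task`), the metric for two axis-aligned boxes in `[ymin, xmin, ymax, xmax]` order can be computed as:

```python
# Standalone IoU sketch; the box format [ymin, xmin, ymax, xmax] is assumed.
def box_iou(a, b):
    inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))  # overlap height
    inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))  # overlap width
    intersection = inter_h * inter_w
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - intersection)
    return intersection / union

print(box_iou([0, 0, 2, 2], [1, 1, 3, 3]))  # 1 / 7 ≈ 0.143
```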
660 |
+
{
|
661 |
+
"cell_type": "code",
|
662 |
+
"execution_count": 18,
|
663 |
+
"metadata": {
|
664 |
+
"id": "SCjHHXvfZXX1"
|
665 |
+
},
|
666 |
+
"outputs": [],
|
667 |
+
"source": [
|
668 |
+
"model, eval_logs = tfm.core.train_lib.run_experiment(\n",
|
669 |
+
" distribution_strategy=distribution_strategy,\n",
|
670 |
+
" task=task,\n",
|
671 |
+
" mode='train_and_eval',\n",
|
672 |
+
" params=exp_config,\n",
|
673 |
+
" model_dir=model_dir,\n",
|
674 |
+
" run_post_eval=True)"
|
675 |
+
]
|
676 |
+
},
|
677 |
+
{
|
678 |
+
"cell_type": "markdown",
|
679 |
+
"metadata": {
|
680 |
+
"id": "2Gd6uHLjbPKW"
|
681 |
+
},
|
682 |
+
"source": [
|
683 |
+
"## Load logs in tensorboard."
|
684 |
+
]
|
685 |
+
},
|
686 |
+
{
|
687 |
+
"cell_type": "code",
|
688 |
+
"execution_count": null,
|
689 |
+
"metadata": {
|
690 |
+
"id": "Q6iDRUVqZY86"
|
691 |
+
},
|
692 |
+
"outputs": [],
|
693 |
+
"source": [
|
694 |
+
"%load_ext tensorboard\n",
|
695 |
+
"%tensorboard --logdir './trained_model/'"
|
696 |
+
]
|
697 |
+
},
|
698 |
+
{
|
699 |
+
"cell_type": "markdown",
|
700 |
+
"metadata": {
|
701 |
+
"id": "AoL2MIJobReU"
|
702 |
+
},
|
703 |
+
"source": [
|
704 |
+
"## Saving and exporting the trained model.\n",
|
705 |
+
"\n",
|
706 |
+
"The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results."
|
707 |
+
]
|
708 |
+
},
|
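For intuition about the normalization that the exported graph folds in, here is a minimal sketch; the `MEAN_RGB`/`STDDEV_RGB` values below are the standard ImageNet statistics and are assumptions for illustration (the actual constants live in `preprocess_ops`):

```python
import tensorflow as tf

# Assumed illustrative constants; see official.vision.ops.preprocess_ops
# for the values actually used by the exported model.
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)

def normalize(image_uint8):
    # The export wraps an equivalent step, so callers can pass tf.uint8 images.
    image = tf.cast(image_uint8, tf.float32)
    return (image - MEAN_RGB) / STDDEV_RGB
```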
709 |
+
{
|
710 |
+
"cell_type": "code",
|
711 |
+
"execution_count": null,
|
712 |
+
"metadata": {
|
713 |
+
"id": "CmOBYXdXZah4"
|
714 |
+
},
|
715 |
+
"outputs": [],
|
716 |
+
"source": [
|
717 |
+
"export_saved_model_lib.export_inference_graph(\n",
|
718 |
+
" input_type='image_tensor',\n",
|
719 |
+
" batch_size=1,\n",
|
720 |
+
" input_image_size=[HEIGHT, WIDTH],\n",
|
721 |
+
" params=exp_config,\n",
|
722 |
+
" checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
|
723 |
+
" export_dir=export_dir)"
|
724 |
+
]
|
725 |
+
},
|
726 |
+
{
|
727 |
+
"cell_type": "markdown",
|
728 |
+
"metadata": {
|
729 |
+
"id": "_JhXopm8bU1g"
|
730 |
+
},
|
731 |
+
"source": [
|
732 |
+
"## Inference from trained model"
|
733 |
+
]
|
734 |
+
},
|
735 |
+
{
|
736 |
+
"cell_type": "code",
|
737 |
+
"execution_count": null,
|
738 |
+
"metadata": {
|
739 |
+
"id": "EbD4j1uCZcIV"
|
740 |
+
},
|
741 |
+
"outputs": [],
|
742 |
+
"source": [
|
743 |
+
"def load_image_into_numpy_array(path):\n",
|
744 |
+
" \"\"\"Load an image from file into a numpy array.\n",
|
745 |
+
"\n",
|
746 |
+
" Puts image into numpy array to feed into tensorflow graph.\n",
|
747 |
+
" Note that by convention we put it into a numpy array with shape\n",
|
748 |
+
" (height, width, channels), where channels=3 for RGB.\n",
|
749 |
+
"\n",
|
750 |
+
" Args:\n",
|
751 |
+
" path: the file path to the image\n",
|
752 |
+
"\n",
|
753 |
+
" Returns:\n",
|
754 |
+
" uint8 numpy array with shape (img_height, img_width, 3)\n",
|
755 |
+
" \"\"\"\n",
|
756 |
+
" image = None\n",
|
757 |
+
" if(path.startswith('http')):\n",
|
758 |
+
" response = urlopen(path)\n",
|
759 |
+
" image_data = response.read()\n",
|
760 |
+
" image_data = BytesIO(image_data)\n",
|
761 |
+
" image = Image.open(image_data)\n",
|
762 |
+
" else:\n",
|
763 |
+
" image_data = tf.io.gfile.GFile(path, 'rb').read()\n",
|
764 |
+
" image = Image.open(BytesIO(image_data))\n",
|
765 |
+
"\n",
|
766 |
+
" (im_width, im_height) = image.size\n",
|
767 |
+
" return np.array(image.getdata()).reshape(\n",
|
768 |
+
" (1, im_height, im_width, 3)).astype(np.uint8)\n",
|
769 |
+
"\n",
|
770 |
+
"\n",
|
771 |
+
"\n",
|
772 |
+
"def build_inputs_for_object_detection(image, input_image_size):\n",
|
773 |
+
" \"\"\"Builds Object Detection model inputs for serving.\"\"\"\n",
|
774 |
+
" image, _ = resize_and_crop_image(\n",
|
775 |
+
" image,\n",
|
776 |
+
" input_image_size,\n",
|
777 |
+
" padded_size=input_image_size,\n",
|
778 |
+
" aug_scale_min=1.0,\n",
|
779 |
+
" aug_scale_max=1.0)\n",
|
780 |
+
" return image"
|
781 |
+
]
|
782 |
+
},
|
783 |
+
{
|
784 |
+
"cell_type": "markdown",
|
785 |
+
"metadata": {
|
786 |
+
"id": "o8bguhK_batq"
|
787 |
+
},
|
788 |
+
"source": [
|
789 |
+
"### Visualize test data."
|
790 |
+
]
|
791 |
+
},
|
792 |
+
{
|
793 |
+
"cell_type": "code",
|
794 |
+
"execution_count": null,
|
795 |
+
"metadata": {
|
796 |
+
"id": "sOsDhYmyZd_m"
|
797 |
+
},
|
798 |
+
"outputs": [],
|
799 |
+
"source": [
|
800 |
+
"num_of_examples = 3\n",
|
801 |
+
"\n",
|
802 |
+
"test_ds = tf.data.TFRecordDataset(\n",
|
803 |
+
" './bccd_coco_tfrecords/test-00000-of-00001.tfrecord').take(\n",
|
804 |
+
" num_of_examples)\n",
|
805 |
+
"show_batch(test_ds, num_of_examples)"
|
806 |
+
]
|
807 |
+
},
|
808 |
+
{
|
809 |
+
"cell_type": "markdown",
|
810 |
+
"metadata": {
|
811 |
+
"id": "kcYnb1Zfbba9"
|
812 |
+
},
|
813 |
+
"source": [
|
814 |
+
"### Importing SavedModel."
|
815 |
+
]
|
816 |
+
},
|
817 |
+
{
|
818 |
+
"cell_type": "code",
|
819 |
+
"execution_count": null,
|
820 |
+
"metadata": {
|
821 |
+
"id": "nQ6waz9rZfhy"
|
822 |
+
},
|
823 |
+
"outputs": [],
|
824 |
+
"source": [
|
825 |
+
"imported = tf.saved_model.load(export_dir)\n",
|
826 |
+
"model_fn = imported.signatures['serving_default']"
|
827 |
+
]
|
828 |
+
},
|
829 |
+
{
|
830 |
+
"cell_type": "markdown",
|
831 |
+
"metadata": {
|
832 |
+
"id": "CtB4gfZ3bfiC"
|
833 |
+
},
|
834 |
+
"source": [
|
835 |
+
"### Visualize predictions."
|
836 |
+
]
|
837 |
+
},
|
838 |
+
{
|
839 |
+
"cell_type": "code",
|
840 |
+
"execution_count": null,
|
841 |
+
"metadata": {
|
842 |
+
"id": "UTSfNZ6yZhEV"
|
843 |
+
},
|
844 |
+
"outputs": [],
|
845 |
+
"source": [
|
846 |
+
"input_image_size = (HEIGHT, WIDTH)\n",
|
847 |
+
"plt.figure(figsize=(20, 20))\n",
|
848 |
+
"min_score_thresh = 0.30 # Change minimum score for threshold to see all bounding boxes confidences.\n",
|
849 |
+
"\n",
|
850 |
+
"for i, serialized_example in enumerate(test_ds):\n",
|
851 |
+
" plt.subplot(1, 3, i+1)\n",
|
852 |
+
" decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
|
853 |
+
" image = build_inputs_for_object_detection(decoded_tensors['image'], input_image_size)\n",
|
854 |
+
" image = tf.expand_dims(image, axis=0)\n",
|
855 |
+
" image = tf.cast(image, dtype = tf.uint8)\n",
|
856 |
+
" image_np = image[0].numpy()\n",
|
857 |
+
" result = model_fn(image)\n",
|
858 |
+
" visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
|
859 |
+
" image_np,\n",
|
860 |
+
" result['detection_boxes'][0].numpy(),\n",
|
861 |
+
" result['detection_classes'][0].numpy().astype(int),\n",
|
862 |
+
" result['detection_scores'][0].numpy(),\n",
|
863 |
+
" category_index=category_index,\n",
|
864 |
+
" use_normalized_coordinates=False,\n",
|
865 |
+
" max_boxes_to_draw=200,\n",
|
866 |
+
" min_score_thresh=min_score_thresh,\n",
|
867 |
+
" agnostic_mode=False,\n",
|
868 |
+
" instance_masks=None,\n",
|
869 |
+
" line_thickness=4)\n",
|
870 |
+
" plt.imshow(image_np)\n",
|
871 |
+
" plt.axis('off')\n",
|
872 |
+
"\n",
|
873 |
+
"plt.show()"
|
874 |
+
]
|
875 |
+
}
|
876 |
+
],
|
877 |
+
"metadata": {
|
878 |
+
"colab": {
|
879 |
+
"name": "object_detection.ipynb",
|
880 |
+
"provenance": [],
|
881 |
+
"toc_visible": true
|
882 |
+
},
|
883 |
+
"kernelspec": {
|
884 |
+
"display_name": "Python 3",
|
885 |
+
"name": "python3"
|
886 |
+
},
|
887 |
+
"language_info": {
|
888 |
+
"codemirror_mode": {
|
889 |
+
"name": "ipython",
|
890 |
+
"version": 3
|
891 |
+
},
|
892 |
+
"file_extension": ".py",
|
893 |
+
"mimetype": "text/x-python",
|
894 |
+
"name": "python",
|
895 |
+
"nbconvert_exporter": "python",
|
896 |
+
"pygments_lexer": "ipython3",
|
897 |
+
"version": "3.10.8"
|
898 |
+
}
|
899 |
+
},
|
900 |
+
"nbformat": 4,
|
901 |
+
"nbformat_minor": 0
|
902 |
+
}
|
Tensorflow/models/docs/vision/semantic_segmentation.ipynb
ADDED
@@ -0,0 +1,790 @@
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "uY4QMaQw9Yvi"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"##### Copyright 2022 The TensorFlow Authors."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"cellView": "form",
|
17 |
+
"id": "NM0OBLSN9heW"
|
18 |
+
},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
22 |
+
"# you may not use this file except in compliance with the License.\n",
|
23 |
+
"# You may obtain a copy of the License at\n",
|
24 |
+
"#\n",
|
25 |
+
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
26 |
+
"#\n",
|
27 |
+
"# Unless required by applicable law or agreed to in writing, software\n",
|
28 |
+
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
29 |
+
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
30 |
+
"# See the License for the specific language governing permissions and\n",
|
31 |
+
"# limitations under the License."
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "markdown",
|
36 |
+
"metadata": {
|
37 |
+
"id": "sg-GchQwFr_r"
|
38 |
+
},
|
39 |
+
"source": [
|
40 |
+
"# Semantic Segmentation with Model Garden\n",
|
41 |
+
"\n",
|
42 |
+
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
|
43 |
+
" <td>\n",
|
44 |
+
" <a target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/semantic_segmentation\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
|
45 |
+
" </td>\n",
|
46 |
+
" <td>\n",
|
47 |
+
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/semantic_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
48 |
+
" </td>\n",
|
49 |
+
" <td>\n",
|
50 |
+
" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/semantic_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
|
51 |
+
" </td>\n",
|
52 |
+
" <td>\n",
|
53 |
+
" <a href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/semantic_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
|
54 |
+
" </td>\n",
|
55 |
+
"</table>"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "markdown",
|
60 |
+
"metadata": {
|
61 |
+
"id": "c6J4IoNfN9jp"
|
62 |
+
},
|
63 |
+
"source": [
|
64 |
+
"This tutorial trains a [DeepLabV3](https://arxiv.org/pdf/1706.05587.pdf) with [Mobilenet V2](https://arxiv.org/abs/1801.04381) as backbone model from the [TensorFlow Model Garden](https://pypi.org/project/tf-models-official/) package (tensorflow-models).\n",
|
65 |
+
"\n",
|
66 |
+
"\n",
|
67 |
+
"[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
|
68 |
+
"\n",
|
69 |
+
"**Dataset**: [Oxford-IIIT Pets](https://www.tensorflow.org/datasets/catalog/oxford_iiit_pet)\n",
|
70 |
+
"\n",
|
71 |
+
"* The Oxford-IIIT pet dataset is a 37 category pet image dataset with roughly 200 images for each class. The images have large variations in scale, pose and lighting. All images have an associated ground truth annotation of breed.\n",
|
72 |
+
"\n",
|
73 |
+
"\n",
|
74 |
+
"**This tutorial demonstrates how to:**\n",
|
75 |
+
"\n",
|
76 |
+
"1. Use models from the TensorFlow Models package.\n",
|
77 |
+
"2. Train/Fine-tune a pre-built DeepLabV3 with mobilenet as backbone for Semantic Segmentation.\n",
|
78 |
+
"3. Export the trained/tuned DeepLabV3 model"
|
79 |
+
]
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"cell_type": "markdown",
|
83 |
+
"metadata": {
|
84 |
+
"id": "AlxYhP0XFnDn"
|
85 |
+
},
|
86 |
+
"source": [
|
87 |
+
"## Install necessary dependencies"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": null,
|
93 |
+
"metadata": {
|
94 |
+
"id": "pXWAySwgaWpN"
|
95 |
+
},
|
96 |
+
"outputs": [],
|
97 |
+
"source": [
|
98 |
+
"!pip install -U -q \"tf-models-official\""
|
99 |
+
]
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"cell_type": "markdown",
|
103 |
+
"metadata": {
|
104 |
+
"id": "uExUsXlgaPD6"
|
105 |
+
},
|
106 |
+
"source": [
|
107 |
+
"## Import required libraries"
|
108 |
+
]
|
109 |
+
},
|
110 |
+
{
|
111 |
+
"cell_type": "code",
|
112 |
+
"execution_count": null,
|
113 |
+
"metadata": {
|
114 |
+
"id": "mOmKZ3Vky5t9"
|
115 |
+
},
|
116 |
+
"outputs": [],
|
117 |
+
"source": [
|
118 |
+
"import os\n",
|
119 |
+
"import pprint\n",
|
120 |
+
"import numpy as np\n",
|
121 |
+
"import matplotlib.pyplot as plt\n",
|
122 |
+
"\n",
|
123 |
+
"from IPython import display"
|
124 |
+
]
|
125 |
+
},
|
126 |
+
{
|
127 |
+
"cell_type": "code",
|
128 |
+
"execution_count": null,
|
129 |
+
"metadata": {
|
130 |
+
"id": "nF8IHrXua_0b"
|
131 |
+
},
|
132 |
+
"outputs": [],
|
133 |
+
"source": [
|
134 |
+
"import tensorflow as tf\n",
|
135 |
+
"import tensorflow_datasets as tfds\n",
|
136 |
+
"\n",
|
137 |
+
"\n",
|
138 |
+
"import orbit\n",
|
139 |
+
"import tensorflow_models as tfm\n",
|
140 |
+
"from official.vision.data import tfrecord_lib\n",
|
141 |
+
"from official.vision.utils import summary_manager\n",
|
142 |
+
"from official.vision.serving import export_saved_model_lib\n",
|
143 |
+
"from official.vision.utils.object_detection import visualization_utils\n",
|
144 |
+
"\n",
|
145 |
+
"pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation\n",
|
146 |
+
"print(tf.__version__) # Check the version of tensorflow used\n",
|
147 |
+
"\n",
|
148 |
+
"%matplotlib inline"
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "markdown",
|
153 |
+
"metadata": {
|
154 |
+
"id": "gMs4l2dpaTd3"
|
155 |
+
},
|
156 |
+
"source": [
|
157 |
+
"## Custom dataset preparation for semantic segmentation\n",
|
158 |
+
"Models in Official repository (of model-garden) require models in a TFRecords dataformat.\n",
|
159 |
+
"\n",
|
160 |
+
"Please check [this resource](https://www.tensorflow.org/tutorials/load_data/tfrecord) to learn more about TFRecords data format.\n",
|
161 |
+
"\n",
|
162 |
+
"[Oxford_IIIT_pet:3](https://www.tensorflow.org/datasets/catalog/oxford_iiit_pet) dataset is taken from [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview)"
|
163 |
+
]
|
164 |
+
},
|
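For a minimal sense of the format (an illustrative round-trip sketch; the notebook's real encoding is done by the `process_record` helper below), a `tf.train.Example` can be written to and read back from a TFRecord file like this:

```python
import tensorflow as tf

# Write a single tiny Example to a TFRecord file (the path is arbitrary).
example = tf.train.Example(features=tf.train.Features(feature={
    'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[128])),
}))
with tf.io.TFRecordWriter('/tmp/demo.tfrecord') as writer:
    writer.write(example.SerializeToString())

# Read it back and parse the single feature.
for raw in tf.data.TFRecordDataset('/tmp/demo.tfrecord'):
    parsed = tf.io.parse_single_example(
        raw, {'image/height': tf.io.FixedLenFeature([], tf.int64)})
    print(parsed['image/height'].numpy())  # 128
```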
165 |
+
{
|
166 |
+
"cell_type": "code",
|
167 |
+
"execution_count": null,
|
168 |
+
"metadata": {
|
169 |
+
"id": "JpWK1Z-N3fHh"
|
170 |
+
},
|
171 |
+
"outputs": [],
|
172 |
+
"source": [
|
173 |
+
"(train_ds, val_ds, test_ds), info = tfds.load(\n",
|
174 |
+
" 'oxford_iiit_pet:3.*.*',\n",
|
175 |
+
" split=['train+test[:50%]', 'test[50%:80%]', 'test[80%:100%]'],\n",
|
176 |
+
" with_info=True)\n",
|
177 |
+
"info"
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"cell_type": "markdown",
|
182 |
+
"metadata": {
|
183 |
+
"id": "Sq6s11E1bMJB"
|
184 |
+
},
|
185 |
+
"source": [
|
186 |
+
"### Helper function to encode dataset as tfrecords"
|
187 |
+
]
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"cell_type": "code",
|
191 |
+
"execution_count": null,
|
192 |
+
"metadata": {
|
193 |
+
"id": "NlEf_C-DjDHG"
|
194 |
+
},
|
195 |
+
"outputs": [],
|
196 |
+
"source": [
|
197 |
+
"def process_record(record):\n",
|
198 |
+
" keys_to_features = {\n",
|
199 |
+
" 'image/encoded': tfrecord_lib.convert_to_feature(\n",
|
200 |
+
" tf.io.encode_jpeg(record['image']).numpy()),\n",
|
201 |
+
" 'image/height': tfrecord_lib.convert_to_feature(record['image'].shape[0]),\n",
|
202 |
+
" 'image/width': tfrecord_lib.convert_to_feature(record['image'].shape[1]),\n",
|
203 |
+
" 'image/segmentation/class/encoded':tfrecord_lib.convert_to_feature(\n",
|
204 |
+
" tf.io.encode_png(record['segmentation_mask'] - 1).numpy())\n",
|
205 |
+
" }\n",
|
206 |
+
" example = tf.train.Example(\n",
|
207 |
+
" features=tf.train.Features(feature=keys_to_features))\n",
|
208 |
+
" return example"
|
209 |
+
]
|
210 |
+
},
|
211 |
+
{
|
212 |
+
"cell_type": "markdown",
|
213 |
+
"metadata": {
|
214 |
+
"id": "FoapGlIebP9r"
|
215 |
+
},
|
216 |
+
"source": [
|
217 |
+
"### Write TFRecords to a folder"
|
218 |
+
]
|
219 |
+
},
|
220 |
+
{
|
221 |
+
"cell_type": "code",
|
222 |
+
"execution_count": null,
|
223 |
+
"metadata": {
|
224 |
+
"id": "dDbMn5q551LQ"
|
225 |
+
},
|
226 |
+
"outputs": [],
|
227 |
+
"source": [
|
228 |
+
"output_dir = './oxford_iiit_pet_tfrecords/'\n",
|
229 |
+
"LOG_EVERY = 100\n",
|
230 |
+
"if not os.path.exists(output_dir):\n",
|
231 |
+
" os.mkdir(output_dir)\n",
|
232 |
+
"\n",
|
233 |
+
"def write_tfrecords(dataset, output_path, num_shards=1):\n",
|
234 |
+
" writers = [\n",
|
235 |
+
" tf.io.TFRecordWriter(\n",
|
236 |
+
" output_path + '-%05d-of-%05d.tfrecord' % (i, num_shards))\n",
|
237 |
+
" for i in range(num_shards)\n",
|
238 |
+
" ]\n",
|
239 |
+
" for idx, record in enumerate(dataset):\n",
|
240 |
+
" if idx % LOG_EVERY == 0:\n",
|
241 |
+
" print('On image %d', idx)\n",
|
242 |
+
" tf_example = process_record(record)\n",
|
243 |
+
" writers[idx % num_shards].write(tf_example.SerializeToString())"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "markdown",
|
248 |
+
"metadata": {
|
249 |
+
"id": "QHDD-D7rbZj7"
|
250 |
+
},
|
251 |
+
"source": [
|
252 |
+
"### Write training data as TFRecords"
|
253 |
+
]
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"cell_type": "code",
|
257 |
+
"execution_count": null,
|
258 |
+
"metadata": {
|
259 |
+
"id": "qxJnVUfT0qBJ"
|
260 |
+
},
|
261 |
+
"outputs": [],
|
262 |
+
"source": [
|
263 |
+
"output_train_tfrecs = output_dir + 'train'\n",
|
264 |
+
"write_tfrecords(train_ds, output_train_tfrecs, num_shards=10)"
|
265 |
+
]
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"cell_type": "markdown",
|
269 |
+
"metadata": {
|
270 |
+
"id": "ap55RwVFbhtu"
|
271 |
+
},
|
272 |
+
"source": [
|
273 |
+
"### Write validation data as TFRecords"
|
274 |
+
]
|
275 |
+
},
|
276 |
+
{
|
277 |
+
"cell_type": "code",
|
278 |
+
"execution_count": null,
|
279 |
+
"metadata": {
|
280 |
+
"id": "Fgq-VxF79ucR"
|
281 |
+
},
|
282 |
+
"outputs": [],
|
283 |
+
"source": [
|
284 |
+
"output_val_tfrecs = output_dir + 'val'\n",
|
285 |
+
"write_tfrecords(val_ds, output_val_tfrecs, num_shards=5)"
|
286 |
+
]
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"cell_type": "markdown",
|
290 |
+
"metadata": {
|
291 |
+
"id": "0AZoIEzRbxZu"
|
292 |
+
},
|
293 |
+
"source": [
|
294 |
+
"### Write test data as TFRecords"
|
295 |
+
]
|
296 |
+
},
|
297 |
+
{
|
298 |
+
"cell_type": "code",
|
299 |
+
"execution_count": null,
|
300 |
+
"metadata": {
|
301 |
+
"id": "QmwFmbP69t0U"
|
302 |
+
},
|
303 |
+
"outputs": [],
|
304 |
+
"source": [
|
305 |
+
"output_test_tfrecs = output_dir + 'test'\n",
|
306 |
+
"write_tfrecords(test_ds, output_test_tfrecs, num_shards=5)"
|
307 |
+
]
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"cell_type": "markdown",
|
311 |
+
"metadata": {
|
312 |
+
"id": "uEFzV-6ZfBZW"
|
313 |
+
},
|
314 |
+
"source": [
|
315 |
+
"## Configure the DeepLabV3 Mobilenet model for custom dataset"
|
316 |
+
]
|
317 |
+
},
|
318 |
+
{
|
319 |
+
"cell_type": "code",
|
320 |
+
"execution_count": null,
|
321 |
+
"metadata": {
|
322 |
+
"id": "_LPEIvLsqSaG"
|
323 |
+
},
|
324 |
+
"outputs": [],
|
325 |
+
"source": [
|
326 |
+
"train_data_tfrecords = './oxford_iiit_pet_tfrecords/train*'\n",
|
327 |
+
"val_data_tfrecords = './oxford_iiit_pet_tfrecords/val*'\n",
|
328 |
+
"test_data_tfrecords = './oxford_iiit_pet_tfrecords/test*'\n",
|
329 |
+
"trained_model = './trained_model/'\n",
|
330 |
+
"export_dir = './exported_model/'"
|
331 |
+
]
|
332 |
+
},
|
333 |
+
{
|
334 |
+
"cell_type": "markdown",
|
335 |
+
"metadata": {
|
336 |
+
"id": "1ZlSiSRyb1Q6"
|
337 |
+
},
|
338 |
+
"source": [
|
339 |
+
"In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
|
340 |
+
"\n",
|
341 |
+
"\n",
|
342 |
+
"Use the `mnv2_deeplabv3_pascal` experiment configuration, as defined by `tfm.vision.configs.semantic_segmentation.mnv2_deeplabv3_pascal`.\n",
|
343 |
+
"\n",
|
344 |
+
"Please find all the registered experiements [here](https://www.tensorflow.org/api_docs/python/tfm/core/exp_factory/get_exp_config)\n",
|
345 |
+
"\n",
|
346 |
+
"The configuration defines an experiment to train a [DeepLabV3](https://arxiv.org/pdf/1706.05587.pdf) model with MobilenetV2 as backbone and [ASPP](https://arxiv.org/pdf/1606.00915v2.pdf) as decoder.\n",
|
347 |
+
"\n",
|
348 |
+
"There are also other alternative experiments available such as\n",
|
349 |
+
"\n",
|
350 |
+
"* `seg_deeplabv3_pascal`\n",
|
351 |
+
"* `seg_deeplabv3plus_pascal`\n",
|
352 |
+
"* `seg_resnetfpn_pascal`\n",
|
353 |
+
"* `mnv2_deeplabv3plus_cityscapes`\n",
|
354 |
+
"\n",
|
355 |
+
"and more. One can switch to them by changing the experiment name argument to the `get_exp_config` function.\n"
|
356 |
+
]
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"cell_type": "code",
|
360 |
+
"execution_count": null,
|
361 |
+
"metadata": {
|
362 |
+
"id": "bj5UZ6BkfJCX"
|
363 |
+
},
|
364 |
+
"outputs": [],
|
365 |
+
"source": [
|
366 |
+
"exp_config = tfm.core.exp_factory.get_exp_config('mnv2_deeplabv3_pascal')"
|
367 |
+
]
|
368 |
+
},
|
369 |
+
{
|
370 |
+
"cell_type": "code",
|
371 |
+
"execution_count": null,
|
372 |
+
"metadata": {
|
373 |
+
"id": "B8jyG-jGIdFs"
|
374 |
+
},
|
375 |
+
"outputs": [],
|
376 |
+
"source": [
|
377 |
+
"model_ckpt_path = './model_ckpt/'\n",
|
378 |
+
"if not os.path.exists(model_ckpt_path):\n",
|
379 |
+
" os.mkdir(model_ckpt_path)\n",
|
380 |
+
"\n",
|
381 |
+
"!gsutil cp gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63.data-00000-of-00001 './model_ckpt/'\n",
|
382 |
+
"!gsutil cp gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63.index './model_ckpt/'"
|
383 |
+
]
|
384 |
+
},
|
385 |
+
{
|
386 |
+
"cell_type": "markdown",
|
387 |
+
"metadata": {
|
388 |
+
"id": "QBYvVFZXhSGQ"
|
389 |
+
},
|
390 |
+
"source": [
|
391 |
+
"### Adjust the model and dataset configurations so that it works with custom dataset."
|
392 |
+
]
|
393 |
+
},
|
394 |
+
{
|
395 |
+
"cell_type": "code",
|
396 |
+
"execution_count": null,
|
397 |
+
"metadata": {
|
398 |
+
"id": "o_Z_vWW9-5Sy"
|
399 |
+
},
|
400 |
+
"outputs": [],
|
401 |
+
"source": [
|
402 |
+
"num_classes = 3\n",
|
403 |
+
"WIDTH, HEIGHT = 128, 128\n",
|
404 |
+
"input_size = [HEIGHT, WIDTH, 3]\n",
|
405 |
+
"BATCH_SIZE = 16\n",
|
406 |
+
"\n",
|
407 |
+
"# Backbone Config\n",
|
408 |
+
"exp_config.task.init_checkpoint = model_ckpt_path + 'best_ckpt-63'\n",
|
409 |
+
"exp_config.task.freeze_backbone = True\n",
|
410 |
+
"\n",
|
411 |
+
"# Model Config\n",
|
412 |
+
"exp_config.task.model.num_classes = num_classes\n",
|
413 |
+
"exp_config.task.model.input_size = input_size\n",
|
414 |
+
"\n",
|
415 |
+
"# Training Data Config\n",
|
416 |
+
"exp_config.task.train_data.aug_scale_min = 1.0\n",
|
417 |
+
"exp_config.task.train_data.aug_scale_max = 1.0\n",
|
418 |
+
"exp_config.task.train_data.input_path = train_data_tfrecords\n",
|
419 |
+
"exp_config.task.train_data.global_batch_size = BATCH_SIZE\n",
|
420 |
+
"exp_config.task.train_data.dtype = 'float32'\n",
|
421 |
+
"exp_config.task.train_data.output_size = [HEIGHT, WIDTH]\n",
|
422 |
+
"exp_config.task.train_data.preserve_aspect_ratio = False\n",
|
423 |
+
"exp_config.task.train_data.seed = 21 # Reproducable Training Data\n",
|
424 |
+
"\n",
|
425 |
+
"# Validation Data Config\n",
|
426 |
+
"exp_config.task.validation_data.input_path = val_data_tfrecords\n",
|
427 |
+
"exp_config.task.validation_data.global_batch_size = BATCH_SIZE\n",
|
428 |
+
"exp_config.task.validation_data.dtype = 'float32'\n",
|
429 |
+
"exp_config.task.validation_data.output_size = [HEIGHT, WIDTH]\n",
|
430 |
+
"exp_config.task.validation_data.preserve_aspect_ratio = False\n",
|
431 |
+
"exp_config.task.validation_data.groundtruth_padded_size = [HEIGHT, WIDTH]\n",
|
432 |
+
"exp_config.task.validation_data.seed = 21 # Reproducable Validation Data\n",
|
433 |
+
"exp_config.task.validation_data.resize_eval_groundtruth = True # To enable validation loss"
|
434 |
+
]
|
435 |
+
},
|
436 |
+
{
|
437 |
+
"cell_type": "markdown",
|
438 |
+
"metadata": {
|
439 |
+
"id": "0HDg5eKniMGJ"
|
440 |
+
},
|
441 |
+
"source": [
|
442 |
+
"### Adjust the trainer configuration."
|
443 |
+
]
|
444 |
+
},
|
445 |
+
{
|
446 |
+
"cell_type": "code",
|
447 |
+
"execution_count": null,
|
448 |
+
"metadata": {
|
449 |
+
"id": "WASJZ3gUH8ni"
|
450 |
+
},
|
451 |
+
"outputs": [],
|
452 |
+
"source": [
|
453 |
+
"logical_device_names = [logical_device.name\n",
|
454 |
+
" for logical_device in tf.config.list_logical_devices()]\n",
|
455 |
+
"\n",
|
456 |
+
"if 'GPU' in ''.join(logical_device_names):\n",
|
457 |
+
" print('This may be broken in Colab.')\n",
|
458 |
+
" device = 'GPU'\n",
|
459 |
+
"elif 'TPU' in ''.join(logical_device_names):\n",
|
460 |
+
" print('This may be broken in Colab.')\n",
|
461 |
+
" device = 'TPU'\n",
|
462 |
+
"else:\n",
|
463 |
+
" print('Running on CPU is slow, so only train for a few steps.')\n",
|
464 |
+
" device = 'CPU'\n",
|
465 |
+
"\n",
|
466 |
+
"\n",
|
467 |
+
"train_steps = 2000\n",
|
468 |
+
"exp_config.trainer.steps_per_loop = int(train_ds.__len__().numpy() // BATCH_SIZE)\n",
|
469 |
+
"\n",
|
470 |
+
"exp_config.trainer.summary_interval = exp_config.trainer.steps_per_loop # steps_per_loop = num_of_validation_examples // eval_batch_size\n",
|
471 |
+
"exp_config.trainer.checkpoint_interval = exp_config.trainer.steps_per_loop\n",
|
472 |
+
"exp_config.trainer.validation_interval = exp_config.trainer.steps_per_loop\n",
|
473 |
+
"exp_config.trainer.validation_steps = int(train_ds.__len__().numpy() // BATCH_SIZE) # validation_steps = num_of_validation_examples // eval_batch_size\n",
|
474 |
+
"exp_config.trainer.train_steps = train_steps\n",
|
475 |
+
"exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = exp_config.trainer.steps_per_loop\n",
|
476 |
+
"exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
|
477 |
+
"exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
|
478 |
+
"exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
|
479 |
+
"exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05"
|
480 |
+
]
|
481 |
+
},
|
482 |
+
{
|
483 |
+
"cell_type": "markdown",
|
484 |
+
"metadata": {
|
485 |
+
"id": "R66w5MwkiO8Z"
|
486 |
+
},
|
487 |
+
"source": [
|
488 |
+
"### Print the modified configuration."
|
489 |
+
]
|
490 |
+
},
|
491 |
+
{
|
492 |
+
"cell_type": "code",
|
493 |
+
"execution_count": null,
|
494 |
+
"metadata": {
|
495 |
+
"id": "ckpjzrqfhoSn"
|
496 |
+
},
|
497 |
+
"outputs": [],
|
498 |
+
"source": [
|
499 |
+
"pp.pprint(exp_config.as_dict())\n",
|
500 |
+
"display.Javascript('google.colab.output.setIframeHeight(\"500px\");')"
|
501 |
+
]
|
502 |
+
},
|
503 |
+
{
|
504 |
+
"cell_type": "markdown",
|
505 |
+
"metadata": {
|
506 |
+
"id": "FYwzdGKAiSOV"
|
507 |
+
},
|
508 |
+
"source": [
|
509 |
+
"### Set up the distribution strategy."
|
510 |
+
]
|
511 |
+
},
|
512 |
+
{
|
513 |
+
"cell_type": "code",
|
514 |
+
"execution_count": null,
|
515 |
+
"metadata": {
|
516 |
+
"id": "iwiOuYRRqdBi"
|
517 |
+
},
|
518 |
+
"outputs": [],
|
519 |
+
"source": [
|
520 |
+
"# Setting up the Strategy\n",
|
521 |
+
"if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
|
522 |
+
" tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
|
523 |
+
"\n",
|
524 |
+
"if 'GPU' in ''.join(logical_device_names):\n",
|
525 |
+
" distribution_strategy = tf.distribute.MirroredStrategy()\n",
|
526 |
+
"elif 'TPU' in ''.join(logical_device_names):\n",
|
527 |
+
" tf.tpu.experimental.initialize_tpu_system()\n",
|
528 |
+
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
|
529 |
+
" distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
|
530 |
+
"else:\n",
|
531 |
+
" print('Warning: this will be really slow.')\n",
|
532 |
+
" distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n",
|
533 |
+
"\n",
|
534 |
+
"print(\"Done\")"
|
535 |
+
]
|
536 |
+
},
|
537 |
+
{
|
538 |
+
"cell_type": "markdown",
|
539 |
+
"metadata": {
|
540 |
+
"id": "ZLtk1GIIiVR2"
|
541 |
+
},
|
542 |
+
"source": [
|
543 |
+
"## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
|
544 |
+
"\n",
|
545 |
+
"The `Task` object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
|
546 |
+
]
|
547 |
+
},
|
548 |
+
{
|
549 |
+
"cell_type": "code",
|
550 |
+
"execution_count": null,
|
551 |
+
"metadata": {
|
552 |
+
"id": "ASTB5D2UISSr"
|
553 |
+
},
|
554 |
+
"outputs": [],
|
555 |
+
"source": [
|
556 |
+
"model_dir = './trained_model/'\n",
|
557 |
+
"\n",
|
558 |
+
"with distribution_strategy.scope():\n",
|
559 |
+
" task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)"
|
560 |
+
]
|
561 |
+
},
|
562 |
+
{
|
563 |
+
"cell_type": "markdown",
|
564 |
+
"metadata": {
|
565 |
+
"id": "YIQ26TW-ihzA"
|
566 |
+
},
|
567 |
+
"source": [
|
568 |
+
"## Visualize a batch of the data."
|
569 |
+
]
|
570 |
+
},
|
571 |
+
{
|
572 |
+
"cell_type": "code",
|
573 |
+
"execution_count": null,
|
574 |
+
"metadata": {
|
575 |
+
"id": "412WyIUAIdCr"
|
576 |
+
},
|
577 |
+
"outputs": [],
|
578 |
+
"source": [
|
579 |
+
"for images, masks in task.build_inputs(exp_config.task.train_data).take(1):\n",
|
580 |
+
" print()\n",
|
581 |
+
" print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
|
582 |
+
" print(f'masks.shape: {str(masks[\"masks\"].shape):16} images.dtype: {masks[\"masks\"].dtype!r}')"
|
583 |
+
]
|
584 |
+
},
|
585 |
+
{
|
586 |
+
"cell_type": "markdown",
|
587 |
+
"metadata": {
|
588 |
+
"id": "3GgluDVJixMd"
|
589 |
+
},
|
590 |
+
"source": [
|
591 |
+
"### Helper function for visualizing the results from TFRecords"
|
592 |
+
]
|
593 |
+
},
|
594 |
+
{
|
595 |
+
"cell_type": "code",
|
596 |
+
"execution_count": null,
|
597 |
+
"metadata": {
|
598 |
+
"id": "1kueMMfERvLx"
|
599 |
+
},
|
600 |
+
"outputs": [],
|
601 |
+
"source": [
|
602 |
+
"def plot_masks(display_list):\n",
|
603 |
+
" plt.figure(figsize=(15, 15))\n",
|
604 |
+
"\n",
|
605 |
+
" title = ['Input Image', 'True Mask', 'Predicted Mask']\n",
|
606 |
+
"\n",
|
607 |
+
" for i in range(len(display_list)):\n",
|
608 |
+
" plt.subplot(1, len(display_list), i+1)\n",
|
609 |
+
" plt.title(title[i])\n",
|
610 |
+
" plt.imshow(tf.keras.utils.array_to_img(display_list[i]))\n",
|
611 |
+
"\n",
|
612 |
+
"\n",
|
613 |
+
" plt.axis('off')\n",
|
614 |
+
" plt.show()"
|
615 |
+
]
|
616 |
+
},
|
617 |
+
{
|
618 |
+
"cell_type": "markdown",
|
619 |
+
"metadata": {
|
620 |
+
"id": "ZCtt09G7i3dq"
|
621 |
+
},
|
622 |
+
"source": [
|
623 |
+
"### Visualization of training data\n",
|
624 |
+
"\n",
|
625 |
+
"Image Title represents what is depicted from the image.\n",
|
626 |
+
"\n",
|
627 |
+
"Same helper function can be used while visualizing predicted mask"
|
628 |
+
]
|
629 |
+
},
|
630 |
+
{
|
631 |
+
"cell_type": "code",
|
632 |
+
"execution_count": null,
|
633 |
+
"metadata": {
|
634 |
+
"id": "YwUPf9V2B6SR"
|
635 |
+
},
|
636 |
+
"outputs": [],
|
637 |
+
"source": [
|
638 |
+
"num_examples = 3\n",
|
639 |
+
"\n",
|
640 |
+
"for images, masks in task.build_inputs(exp_config.task.train_data).take(num_examples):\n",
|
641 |
+
" plot_masks([images[0], masks['masks'][0]])"
|
642 |
+
]
|
643 |
+
},
|
644 |
+
{
|
645 |
+
"cell_type": "markdown",
|
646 |
+
"metadata": {
|
647 |
+
"id": "MeJ5w8KfjMmP"
|
648 |
+
},
|
649 |
+
"source": [
|
650 |
+
"## Train and evaluate\n",
|
651 |
+
"**IoU**: is defined as the area of the intersection divided by the area of the union of a predicted mask and ground truth mask."
|
652 |
+
]
|
653 |
+
},
|
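As a standalone sketch of the mask-IoU definition above (for intuition only; the experiment computes mean IoU internally during evaluation):

```python
import numpy as np

def mask_iou(pred_mask, true_mask):
    # Binary masks as NumPy arrays; IoU = |intersection| / |union|.
    pred, true = pred_mask.astype(bool), true_mask.astype(bool)
    union = np.logical_or(pred, true).sum()
    return np.logical_and(pred, true).sum() / union if union else 0.0

print(mask_iou(np.array([[1, 1], [0, 0]]),
               np.array([[1, 0], [1, 0]])))  # 1 intersecting pixel / 3 union pixels
```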
654 |
+
{
|
655 |
+
"cell_type": "code",
|
656 |
+
"execution_count": null,
|
657 |
+
"metadata": {
|
658 |
+
"id": "ru3aHTCySHoH"
|
659 |
+
},
|
660 |
+
"outputs": [],
|
661 |
+
"source": [
|
662 |
+
"\n",
|
663 |
+
"model, eval_logs = tfm.core.train_lib.run_experiment(\n",
|
664 |
+
" distribution_strategy=distribution_strategy,\n",
|
665 |
+
" task=task,\n",
|
666 |
+
" mode='train_and_eval',\n",
|
667 |
+
" params=exp_config,\n",
|
668 |
+
" model_dir=model_dir,\n",
|
669 |
+
" eval_summary_manager=summary_manager.maybe_build_eval_summary_manager(\n",
|
670 |
+
" params=exp_config, model_dir=model_dir),\n",
|
671 |
+
" run_post_eval=True)"
|
672 |
+
]
|
673 |
+
},
|
674 |
+
{
|
675 |
+
"cell_type": "markdown",
|
676 |
+
"metadata": {
|
677 |
+
"id": "vt3WmtxhjfGe"
|
678 |
+
},
|
679 |
+
"source": [
|
680 |
+
"## Load logs in tensorboard"
|
681 |
+
]
|
682 |
+
},
|
683 |
+
{
|
684 |
+
"cell_type": "code",
|
685 |
+
"execution_count": null,
|
686 |
+
"metadata": {
|
687 |
+
"id": "A9rct_7BoJFb"
|
688 |
+
},
|
689 |
+
"outputs": [],
|
690 |
+
"source": [
|
691 |
+
"%load_ext tensorboard\n",
|
692 |
+
"%tensorboard --logdir './trained_model'"
|
693 |
+
]
|
694 |
+
},
|
695 |
+
{
|
696 |
+
"cell_type": "markdown",
|
697 |
+
"metadata": {
|
698 |
+
"id": "v6XaGoUuji7P"
|
699 |
+
},
|
700 |
+
"source": [
|
701 |
+
"## Saving and exporting the trained model\n",
|
702 |
+
"\n",
|
703 |
+
"The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results."
|
704 |
+
]
|
705 |
+
},
|
706 |
+
{
|
707 |
+
"cell_type": "code",
|
708 |
+
"execution_count": null,
|
709 |
+
"metadata": {
|
710 |
+
"id": "GVsnyqzdnxHd"
|
711 |
+
},
|
712 |
+
"outputs": [],
|
713 |
+
"source": [
|
714 |
+
"export_saved_model_lib.export_inference_graph(\n",
|
715 |
+
" input_type='image_tensor',\n",
|
716 |
+
" batch_size=1,\n",
|
717 |
+
" input_image_size=[HEIGHT, WIDTH],\n",
|
718 |
+
" params=exp_config,\n",
|
719 |
+
" checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
|
720 |
+
" export_dir=export_dir)"
|
721 |
+
]
|
722 |
+
},
|
723 |
+
{
|
724 |
+
"cell_type": "markdown",
|
725 |
+
"metadata": {
|
726 |
+
"id": "nM1S-tjIjvAr"
|
727 |
+
},
|
728 |
+
"source": [
|
729 |
+
"## Importing SavedModel"
|
730 |
+
]
|
731 |
+
},
|
732 |
+
{
|
733 |
+
"cell_type": "code",
|
734 |
+
"execution_count": null,
|
735 |
+
"metadata": {
|
736 |
+
"id": "Nxi9pEwluUcT"
|
737 |
+
},
|
738 |
+
"outputs": [],
|
739 |
+
"source": [
|
740 |
+
"imported = tf.saved_model.load(export_dir)\n",
|
741 |
+
"model_fn = imported.signatures['serving_default']"
|
742 |
+
]
|
743 |
+
},
|
744 |
+
{
|
745 |
+
"cell_type": "markdown",
|
746 |
+
"metadata": {
|
747 |
+
"id": "LbBfl6AUj_My"
|
748 |
+
},
|
749 |
+
"source": [
|
750 |
+
"## Visualize predictions"
|
751 |
+
]
|
752 |
+
},
|
753 |
+
{
|
754 |
+
"cell_type": "code",
|
755 |
+
"execution_count": null,
|
756 |
+
"metadata": {
|
757 |
+
"id": "qifGt_ohpFhn"
|
758 |
+
},
|
759 |
+
"outputs": [],
|
760 |
+
"source": [
|
761 |
+
"def create_mask(pred_mask):\n",
|
762 |
+
" pred_mask = tf.math.argmax(pred_mask, axis=-1)\n",
|
763 |
+
" pred_mask = pred_mask[..., tf.newaxis]\n",
|
764 |
+
" return pred_mask[0]\n",
|
765 |
+
"\n",
|
766 |
+
"\n",
|
767 |
+
"for record in test_ds.take(15):\n",
|
768 |
+
" image = tf.image.resize(record['image'], size=[HEIGHT, WIDTH])\n",
|
769 |
+
" image = tf.cast(image, dtype=tf.uint8)\n",
|
770 |
+
" mask = tf.image.resize(record['segmentation_mask'], size=[HEIGHT, WIDTH])\n",
|
771 |
+
" predicted_mask = model_fn(tf.expand_dims(record['image'], axis=0))\n",
|
772 |
+
" plot_masks([image, mask, create_mask(predicted_mask['logits'])])"
|
773 |
+
]
|
774 |
+
}
|
775 |
+
],
|
776 |
+
"metadata": {
|
777 |
+
"accelerator": "GPU",
|
778 |
+
"colab": {
|
779 |
+
"name": "semantic_segmentation.ipynb",
|
780 |
+
"provenance": [],
|
781 |
+
"toc_visible": true
|
782 |
+
},
|
783 |
+
"kernelspec": {
|
784 |
+
"display_name": "Python 3",
|
785 |
+
"name": "python3"
|
786 |
+
}
|
787 |
+
},
|
788 |
+
"nbformat": 4,
|
789 |
+
"nbformat_minor": 0
|
790 |
+
}
|
Tensorflow/models/official/README-TPU.md
ADDED
@@ -0,0 +1,32 @@
|
|
1 |
+
# Officially Supported TensorFlow 2.1+ Models on Cloud TPU
|
2 |
+
|
3 |
+
## Natural Language Processing
|
4 |
+
|
5 |
+
* [bert](https://arxiv.org/abs/1810.04805): A powerful pre-trained language representation model:
|
6 |
+
BERT, which stands for Bidirectional Encoder Representations from
|
7 |
+
Transformers.
|
8 |
+
[BERT FineTuning with Cloud TPU](https://cloud.google.com/ai-platform/training/docs/algorithms/bert-start) provides step-by-step instructions for Cloud TPU training. See [Bert MNLI Tensorboard.dev metrics](https://tensorboard.dev/experiment/LijZ1IrERxKALQfr76gndA) for the MNLI fine-tuning task.
|
9 |
+
* [transformer](nlp/transformer): A transformer model to translate the WMT
|
10 |
+
English to German dataset.
|
11 |
+
See [Training transformer on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/transformer-2.x) for step-by-step instructions on Cloud TPU training.
|
12 |
+
|
13 |
+
## Computer Vision
|
14 |
+
|
15 |
+
* [efficientnet](https://github.com/tensorflow/models/blob/master/official/vision/modeling/backbones/efficientnet.py): A family of convolutional
|
16 |
+
neural networks that scale by balancing network depth, width, and
|
17 |
+
resolution and can be used to classify ImageNet's dataset of 1000 classes.
|
18 |
+
See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/KnaWjrq5TXGfv0NW5m7rpg/#scalars).
|
19 |
+
* [mnist](https://www.tensorflow.org/datasets/catalog/mnist): A basic model to classify digits
|
20 |
+
from the MNIST dataset. See [Running MNIST on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/mnist-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/mIah5lppTASvrHqWrdr6NA).
|
21 |
+
* [mask-rcnn](https://www.tensorflow.org/api_docs/python/tfm/vision/configs/maskrcnn/MaskRCNN): An object detection and instance segmentation model. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/LH7k0fMsRwqUAcE09o9kPA).
|
22 |
+
* [resnet](https://www.tensorflow.org/api_docs/python/tfm/vision/configs/image_classification/image_classification_imagenet): A deep residual network that can
|
23 |
+
be used to classify ImageNet's dataset of 1000 classes.
|
24 |
+
See [Training ResNet on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/resnet-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/CxlDK8YMRrSpYEGtBRpOhg).
|
25 |
+
* [retinanet](https://www.tensorflow.org/api_docs/python/tfm/vision/retinanet): A fast and powerful object detector. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/b8NRnWU3TqG6Rw0UxueU6Q).
|
26 |
+
* [shapemask](https://cloud.google.com/tpu/docs/tutorials/shapemask-2.x): An object detection and instance segmentation model using shape priors. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/ZbXgVoc6Rf6mBRlPj0JpLA).
|
27 |
+
|
28 |
+
## Recommendation
|
29 |
+
* [dlrm](recommendation/ranking): [Deep Learning Recommendation Model for
|
30 |
+
Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091).
|
31 |
+
* [dcn v2](recommendation/ranking): [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535).
|
32 |
+
* [ncf](recommendation): Neural Collaborative Filtering. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/0k3gKjZlR1ewkVTRyLB6IQ).
|
Tensorflow/models/official/README.md
ADDED
@@ -0,0 +1,177 @@
|
1 |
+
<div align="center">
|
2 |
+
<img src="https://storage.googleapis.com/tf_model_garden/tf_model_garden_logo.png">
|
3 |
+
</div>
|
4 |
+
|
5 |
+
# TensorFlow Official Models
|
6 |
+
|
7 |
+
The TensorFlow official models are a collection of models
|
8 |
+
that use TensorFlow’s high-level APIs.
|
9 |
+
They are intended to be well-maintained, tested, and kept up to date
|
10 |
+
with the latest TensorFlow API.
|
11 |
+
|
12 |
+
They should also be reasonably optimized for fast performance while still
|
13 |
+
being easy to read.
|
14 |
+
These models are used as end-to-end tests, ensuring that the models run
|
15 |
+
with the same or improved speed and performance with each new TensorFlow build.
|
16 |
+
|
17 |
+
The API documentation of the latest stable release is published to
[tensorflow.org](https://www.tensorflow.org/api_docs/python/tfm).

## More models to come!

The team is actively developing new models.
In the near future, we will add:

* State-of-the-art language understanding models.
* State-of-the-art image classification models.
* State-of-the-art object detection and instance segmentation models.
* State-of-the-art video classification models.

## Table of Contents

- [Models and Implementations](#models-and-implementations)
  * [Computer Vision](#computer-vision)
    + [Image Classification](#image-classification)
    + [Object Detection and Segmentation](#object-detection-and-segmentation)
    + [Video Classification](#video-classification)
  * [Natural Language Processing](#natural-language-processing)
  * [Recommendation](#recommendation)
- [How to get started with the official models](#how-to-get-started-with-the-official-models)
- [Contributions](#contributions)

## Models and Implementations

### [Computer Vision](vision/README.md)

#### Image Classification

| Model | Reference (Paper) |
|-------|-------------------|
| [ResNet](vision/MODEL_GARDEN.md) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
| [ResNet-RS](vision/MODEL_GARDEN.md) | [Revisiting ResNets: Improved Training and Scaling Strategies](https://arxiv.org/abs/2103.07579) |
| [EfficientNet](vision/MODEL_GARDEN.md) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
| [Vision Transformer](vision/MODEL_GARDEN.md) | [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) |

#### Object Detection and Segmentation

| Model | Reference (Paper) |
|-------|-------------------|
| [RetinaNet](vision/MODEL_GARDEN.md) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
| [Mask R-CNN](vision/MODEL_GARDEN.md) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
| [YOLO](projects/yolo/README.md) | [YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors](https://arxiv.org/abs/2207.02696) |
| [SpineNet](vision/MODEL_GARDEN.md) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
| [Cascade RCNN-RS and RetinaNet-RS](vision/MODEL_GARDEN.md) | [Simple Training Strategies and Model Scaling for Object Detection](https://arxiv.org/abs/2107.00057) |

#### Video Classification

| Model | Reference (Paper) |
|-------|-------------------|
| [Mobile Video Networks (MoViNets)](projects/movinet) | [MoViNets: Mobile Video Networks for Efficient Video Recognition](https://arxiv.org/abs/2103.11511) |

### [Natural Language Processing](nlp/README.md)

#### Pre-trained Language Model

| Model | Reference (Paper) |
|-------|-------------------|
| [ALBERT](nlp/MODEL_GARDEN.md#available-model-configs) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
| [BERT](nlp/MODEL_GARDEN.md#available-model-configs) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
| [ELECTRA](nlp/tasks/electra_task.py) | [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555) |

#### Neural Machine Translation

| Model | Reference (Paper) |
|-------|-------------------|
| [Transformer](nlp/MODEL_GARDEN.md#available-model-configs) | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |

#### Natural Language Generation

| Model | Reference (Paper) |
|-------|-------------------|
| [NHNet (News Headline generation model)](projects/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |

#### Knowledge Distillation

| Model | Reference (Paper) |
|-------|-------------------|
| [MobileBERT](projects/mobilebert) | [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) |

### Recommendation

| Model | Reference (Paper) |
|-------|-------------------|
| [DLRM](recommendation/ranking) | [Deep Learning Recommendation Model for Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091) |
| [DCN v2](recommendation/ranking) | [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535) |
| [NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) |

## How to get started with the official models

* The official models in the master branch are developed using the
  [master branch of TensorFlow 2](https://github.com/tensorflow/tensorflow/tree/master).
  When you clone the repository or install the `pip` binary from the master
  branch of the official models, the master branch of TensorFlow is installed
  as a dependency. This is equivalent to the following:

  ```shell
  pip3 install tf-models-nightly
  pip3 install tensorflow-text-nightly  # when the model uses `nlp` packages
  ```

* For stable versions targeting a specific release, the TensorFlow Models
  repository version numbers match the target TensorFlow release. For
  example, [TensorFlow-models v2.8.x](https://github.com/tensorflow/models/releases/tag/v2.8.0)
  is compatible with [TensorFlow v2.8.x](https://github.com/tensorflow/tensorflow/releases/tag/v2.8.0).
  This is equivalent to the following:

  ```shell
  pip3 install tf-models-official==2.8.0
  pip3 install tensorflow-text==2.8.0  # when the model uses `nlp` packages
  ```

Starting from the 2.9.x release, we release the modeling library as the
`tensorflow_models` package, and users can `import tensorflow_models` directly
to access the exported symbols. If you are using the latest nightly version or
the GitHub code directly, please follow the docstrings on GitHub.
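
For instance, a minimal sketch of what this import looks like (assuming a
release that ships the package, e.g. `tf-models-official>=2.9.0`, is
installed):

```python
import tensorflow_models as tfm

# The vision and NLP modeling sub-libraries are exposed as sub-namespaces.
print(tfm.vision)
print(tfm.nlp)
```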

Please follow the steps below before running models in this repository.

### Requirements

* The latest TensorFlow Model Garden release and the latest TensorFlow 2
  * If you are on a version of TensorFlow earlier than 2.2, please
    upgrade your TensorFlow to [the latest TensorFlow 2](https://www.tensorflow.org/install/).
* Python 3.7+

Our integration tests run with Python 3.7. Although Python 3.6 should work, we
don't recommend earlier versions.

### Installation

Please check [here](https://github.com/tensorflow/models#Installation) for the
instructions.

Available PyPI packages:

* [tf-models-official](https://pypi.org/project/tf-models-official/)
* [tf-models-nightly](https://pypi.org/project/tf-models-nightly/): nightly
  release with the latest changes.
* [tf-models-no-deps](https://pypi.org/project/tf-models-no-deps/): without
  `tensorflow` and `tensorflow-text` in the `install_requires` list.

### Examples and Tutorials

Get started with TensorFlow Model Garden by exploring the provided examples and tutorials:

* [NLP](https://www.tensorflow.org/tfmodels/nlp)
* [Image classification](https://www.tensorflow.org/tfmodels/vision/image_classification)
* [Object detection](https://www.tensorflow.org/tfmodels/vision/object_detection)
* [Semantic Segmentation](https://www.tensorflow.org/tfmodels/vision/semantic_segmentation)
* [Instance Segmentation](https://www.tensorflow.org/tfmodels/vision/instance_segmentation)

## Contributions

If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
Tensorflow/models/official/__init__.py
ADDED
@@ -0,0 +1,14 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Tensorflow/models/official/common/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Tensorflow/models/official/common/dataset_fn.py
ADDED
@@ -0,0 +1,44 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility library for picking an appropriate dataset function."""

import functools
from typing import Any, Callable, Type, Union

import tensorflow as tf, tf_keras

PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]]


def pick_dataset_fn(file_type: str) -> PossibleDatasetType:
  """Returns a dataset reader for `file_type` ('tfrecord' or 'tfrecord_compressed')."""
  if file_type == 'tfrecord':
    return tf.data.TFRecordDataset
  if file_type == 'tfrecord_compressed':
    return functools.partial(tf.data.TFRecordDataset, compression_type='GZIP')
  raise ValueError('Unrecognized file_type: {}'.format(file_type))
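
A minimal usage sketch of `pick_dataset_fn` (assumes `official` is importable,
e.g. via `pip3 install tf-models-official`; `example.tfrecord` is a
hypothetical file path):

```python
from official.common import dataset_fn

reader = dataset_fn.pick_dataset_fn('tfrecord')
ds = reader('example.tfrecord')  # Returns a tf.data.TFRecordDataset.

compressed_reader = dataset_fn.pick_dataset_fn('tfrecord_compressed')
ds_gzip = compressed_reader('example.tfrecord.gz')  # GZIP-compressed records.
```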
Tensorflow/models/official/common/distribute_utils.py
ADDED
@@ -0,0 +1,233 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions for running models in a distributed setting."""

import json
import os
import tensorflow as tf, tf_keras


def _collective_communication(all_reduce_alg):
  """Return a CollectiveCommunication based on all_reduce_alg.

  Args:
    all_reduce_alg: a string specifying which collective communication to pick,
      or None.

  Returns:
    tf.distribute.experimental.CollectiveCommunication object

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
  """
  collective_communication_options = {
      None: tf.distribute.experimental.CollectiveCommunication.AUTO,
      "ring": tf.distribute.experimental.CollectiveCommunication.RING,
      "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
  }
  if all_reduce_alg not in collective_communication_options:
    raise ValueError(
        "When used with `multi_worker_mirrored`, valid values for "
        "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
            all_reduce_alg))
  return collective_communication_options[all_reduce_alg]


def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
  """Return a CrossDeviceOps based on all_reduce_alg and num_packs.

  Args:
    all_reduce_alg: a string specifying which cross device op to pick, or None.
    num_packs: an integer specifying number of packs for the cross device op.

  Returns:
    tf.distribute.CrossDeviceOps object or None.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
  """
  if all_reduce_alg is None:
    return None
  mirrored_all_reduce_options = {
      "nccl": tf.distribute.NcclAllReduce,
      "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
  }
  if all_reduce_alg not in mirrored_all_reduce_options:
    raise ValueError(
        "When used with `mirrored`, valid values for all_reduce_alg are "
        "[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
            all_reduce_alg))
  cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
  return cross_device_ops_class(num_packs=num_packs)


def tpu_initialize(tpu_address):
  """Initializes TPU for TF 2.x training.

  Args:
    tpu_address: string, bns address of master TPU worker.

  Returns:
    A TPUClusterResolver.
  """
  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu=tpu_address)
  if tpu_address not in ("", "local"):
    tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  return cluster_resolver


def get_distribution_strategy(distribution_strategy="mirrored",
                              num_gpus=0,
                              all_reduce_alg=None,
                              num_packs=1,
                              tpu_address=None,
                              **kwargs):
  """Return a Strategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are "off", "one_device", "mirrored",
      "parameter_server", "multi_worker_mirrored", and "tpu" -- case
      insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
      "off" means to use the default strategy which is obtained from
      tf.distribute.get_strategy (for details on the default strategy, see
      https://www.tensorflow.org/guide/distributed_training#default_strategy).
    num_gpus: Number of GPUs to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are "nccl" and
      "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
      "ring" and "nccl". If None, DistributionStrategy will choose based on
      device topology.
    num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
    tpu_address: Optional. String that represents TPU to connect to. Must not
      be None if `distribution_strategy` is set to `tpu`.
    **kwargs: Additional kwargs for internal usages.

  Returns:
    tf.distribute.Strategy object.
  Raises:
    ValueError: if `distribution_strategy` is "off" or "one_device" and
      `num_gpus` is larger than 1; or `num_gpus` is negative or if
      `distribution_strategy` is `tpu` but `tpu_address` is not specified.
  """
  del kwargs
  if num_gpus < 0:
    raise ValueError("`num_gpus` can not be negative.")

  if not isinstance(distribution_strategy, str):
    msg = ("distribution_strategy must be a string but got: %s." %
           (distribution_strategy,))
    if distribution_strategy == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
      msg += (" If you meant to pass the string 'off', make sure you add "
              "quotes around 'off' so that yaml interprets it as a string "
              "instead of a bool.")
    raise ValueError(msg)

  distribution_strategy = distribution_strategy.lower()
  if distribution_strategy == "off":
    if num_gpus > 1:
      raise ValueError(f"When {num_gpus} GPUs are specified, "
                       "distribution_strategy flag cannot be set to `off`.")
    # Return the default distribution strategy.
    return tf.distribute.get_strategy()

  if distribution_strategy == "tpu":
    # When tpu_address is an empty string, we communicate with local TPUs.
    cluster_resolver = tpu_initialize(tpu_address)
    return tf.distribute.TPUStrategy(cluster_resolver)

  if distribution_strategy == "multi_worker_mirrored":
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if distribution_strategy == "one_device":
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    if num_gpus > 1:
      raise ValueError("`OneDeviceStrategy` can not be used for more than "
                       "one device.")
    return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if distribution_strategy == "mirrored":
    if num_gpus == 0:
      devices = ["device:CPU:0"]
    else:
      devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  if distribution_strategy == "parameter_server":
    cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
    return tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

  raise ValueError("Unrecognized Distribution Strategy: %r" %
                   distribution_strategy)


def configure_cluster(worker_hosts=None, task_index=-1):
  """Set multi-worker cluster spec in TF_CONFIG environment variable.

  Args:
    worker_hosts: comma-separated list of worker ip:port pairs.
    task_index: index of the worker.

  Returns:
    Number of workers in the cluster.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if tf_config:
    num_workers = (
        len(tf_config["cluster"].get("chief", [])) +
        len(tf_config["cluster"].get("worker", [])))
  elif worker_hosts:
    workers = worker_hosts.split(",")
    num_workers = len(workers)
    if num_workers > 1 and task_index < 0:
      raise ValueError("Must specify task_index when number of workers > 1")
    task_index = 0 if num_workers == 1 else task_index
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {
            "worker": workers
        },
        "task": {
            "type": "worker",
            "index": task_index
        }
    })
  else:
    num_workers = 1
  return num_workers


def get_strategy_scope(strategy):
  """Returns `strategy.scope()`, or a no-op context if `strategy` is None."""
  if strategy:
    strategy_scope = strategy.scope()
  else:
    strategy_scope = DummyContextManager()

  return strategy_scope


class DummyContextManager(object):
  """No-op context manager used when no distribution strategy is provided."""

  def __enter__(self):
    pass

  def __exit__(self, *args):
    pass
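
A minimal usage sketch of these helpers (assumes `official` is importable;
`num_gpus=0` keeps the example runnable on a CPU-only machine):

```python
import tensorflow as tf

from official.common import distribute_utils

# Build a MirroredStrategy over the available devices (CPU here).
strategy = distribute_utils.get_distribution_strategy(
    distribution_strategy='mirrored', num_gpus=0)

# Variables created inside the scope are placed/replicated by the strategy.
with distribute_utils.get_strategy_scope(strategy):
  v = tf.Variable(1.0)

print(strategy.num_replicas_in_sync)  # 1 on a single CPU device.
```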
Tensorflow/models/official/common/distribute_utils_test.py
ADDED
@@ -0,0 +1,124 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for distribution util functions."""

import sys
import tensorflow as tf, tf_keras

from official.common import distribute_utils

TPU_TEST = 'test_tpu' in sys.argv[0]


class DistributeUtilsTest(tf.test.TestCase):
  """Tests for distribute util functions."""

  def test_invalid_args(self):
    with self.assertRaisesRegex(ValueError, '`num_gpus` can not be negative.'):
      _ = distribute_utils.get_distribution_strategy(num_gpus=-1)

    with self.assertRaisesRegex(ValueError,
                                '.*If you meant to pass the string .*'):
      _ = distribute_utils.get_distribution_strategy(
          distribution_strategy=False, num_gpus=0)
    with self.assertRaisesRegex(ValueError, 'When 2 GPUs are specified.*'):
      _ = distribute_utils.get_distribution_strategy(
          distribution_strategy='off', num_gpus=2)
    with self.assertRaisesRegex(ValueError,
                                '`OneDeviceStrategy` can not be used.*'):
      _ = distribute_utils.get_distribution_strategy(
          distribution_strategy='one_device', num_gpus=2)

  def test_one_device_strategy_cpu(self):
    ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
    self.assertEqual(ds.num_replicas_in_sync, 1)
    self.assertEqual(len(ds.extended.worker_devices), 1)
    self.assertIn('CPU', ds.extended.worker_devices[0])

  def test_one_device_strategy_gpu(self):
    ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=1)
    self.assertEqual(ds.num_replicas_in_sync, 1)
    self.assertEqual(len(ds.extended.worker_devices), 1)
    self.assertIn('GPU', ds.extended.worker_devices[0])

  def test_mirrored_strategy(self):
    # CPU only.
    _ = distribute_utils.get_distribution_strategy(num_gpus=0)
    # 5 GPUs.
    ds = distribute_utils.get_distribution_strategy(num_gpus=5)
    self.assertEqual(ds.num_replicas_in_sync, 5)
    self.assertEqual(len(ds.extended.worker_devices), 5)
    for device in ds.extended.worker_devices:
      self.assertIn('GPU', device)

    _ = distribute_utils.get_distribution_strategy(
        distribution_strategy='mirrored',
        num_gpus=2,
        all_reduce_alg='nccl',
        num_packs=2)
    with self.assertRaisesRegex(
        ValueError,
        'When used with `mirrored`, valid values for all_reduce_alg are.*'):
      _ = distribute_utils.get_distribution_strategy(
          distribution_strategy='mirrored',
          num_gpus=2,
          all_reduce_alg='dummy',
          num_packs=2)

  def test_mwms(self):
    distribute_utils.configure_cluster(worker_hosts=None, task_index=-1)
    ds = distribute_utils.get_distribution_strategy(
        'multi_worker_mirrored', all_reduce_alg='nccl')
    self.assertIsInstance(
        ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)

    with self.assertRaisesRegex(
        ValueError,
        'When used with `multi_worker_mirrored`, valid values.*'):
      _ = distribute_utils.get_distribution_strategy(
          'multi_worker_mirrored', all_reduce_alg='dummy')

  def test_no_strategy(self):
    ds = distribute_utils.get_distribution_strategy('off')
    self.assertIs(ds, tf.distribute.get_strategy())

  def test_tpu_strategy(self):
    if not TPU_TEST:
      self.skipTest('Only Cloud TPU VM instances can have local TPUs.')
    with self.assertRaises(ValueError):
      _ = distribute_utils.get_distribution_strategy('tpu')

    ds = distribute_utils.get_distribution_strategy('tpu', tpu_address='local')
    self.assertIsInstance(ds, tf.distribute.TPUStrategy)

  def test_invalid_strategy(self):
    with self.assertRaisesRegex(
        ValueError,
        'distribution_strategy must be a string but got: False. If'):
      distribute_utils.get_distribution_strategy(False)
    with self.assertRaisesRegex(
        ValueError, 'distribution_strategy must be a string but got: 1'):
      distribute_utils.get_distribution_strategy(1)

  def test_get_strategy_scope(self):
    ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
    with distribute_utils.get_strategy_scope(ds):
      self.assertIs(tf.distribute.get_strategy(), ds)
    with distribute_utils.get_strategy_scope(None):
      self.assertIsNot(tf.distribute.get_strategy(), ds)


if __name__ == '__main__':
  tf.test.main()
Tensorflow/models/official/common/flags.py
ADDED
@@ -0,0 +1,114 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The central place to define flags."""

from absl import flags


def define_flags():
  """Defines flags.

  All flags are defined as optional, but in practice most models use some of
  these flags and so mark_flags_as_required() should be called after calling
  this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
  For example:

  ```
  from absl import flags
  from official.common import flags as tfm_flags  # pylint: disable=line-too-long
  ...
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  ```

  The reason all flags are optional is because unit tests often do not set or
  use any of the flags.
  """
  flags.DEFINE_string(
      'experiment', default=None, help=
      'The experiment type registered, specifying an ExperimentConfig.')

  flags.DEFINE_enum(
      'mode',
      default=None,
      enum_values=[
          'train', 'eval', 'train_and_eval', 'continuous_eval',
          'continuous_train_and_eval', 'train_and_validate',
          'train_and_post_eval'
      ],
      help='Mode to run: `train`, `eval`, `train_and_eval`, '
      '`continuous_eval`, `continuous_train_and_eval` and '
      '`train_and_validate` (which is not implemented in '
      'the open source version).')

  flags.DEFINE_string(
      'model_dir',
      default=None,
      help='The directory where the model and training/evaluation summaries '
      'are stored.')

  flags.DEFINE_multi_string(
      'config_file',
      default=None,
      help='YAML/JSON files which specify overrides. The override order '
      'follows the order of args. Note that each file '
      'can be used as an override template to override the default parameters '
      'specified in Python. If the same parameter is specified in both '
      '`--config_file` and `--params_override`, `config_file` will be used '
      'first, followed by params_override.')

  flags.DEFINE_string(
      'params_override',
      default=None,
      help='A YAML/JSON string or a YAML file which specifies additional '
      'overrides over the default parameters and those specified in '
      '`--config_file`. Note that this is supposed to be used only to override '
      'the model parameters, but not the parameters like TPU specific flags. '
      'One canonical use case of `--config_file` and `--params_override` is '
      'users first define a template config file using `--config_file`, then '
      'use `--params_override` to adjust the minimal set of tuning parameters, '
      'for example setting up different `train_batch_size`. The final override '
      'order of parameters: default_model_params --> params from config_file '
      '--> params in params_override. See also the help message of '
      '`--config_file`.')

  # Libraries that rely on gin often make the mistake of defining flags
  # inside library files, which causes duplicate-flag conflicts.
  try:
    flags.DEFINE_multi_string(
        'gin_file', default=None, help='List of paths to the config files.')
  except flags.DuplicateFlagError:
    pass

  try:
    flags.DEFINE_multi_string(
        'gin_params',
        default=None,
        help='Newline separated list of Gin parameter bindings.')
  except flags.DuplicateFlagError:
    pass

  flags.DEFINE_string(
      'tpu',
      default=None,
      help='The Cloud TPU to use for training. This should be either the name '
      'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
      'url.')

  flags.DEFINE_string(
      'tf_data_service', default=None, help='The tf.data service address')

  flags.DEFINE_string(
      'tpu_platform', default=None, help='TPU platform type.')
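
A minimal sketch of wiring these flags into a binary, following the pattern
from the docstring above (`main` here is a hypothetical entry point, not part
of this module):

```python
from absl import app
from absl import flags

from official.common import flags as tfm_flags

FLAGS = flags.FLAGS


def main(_):
  # The required flags are guaranteed to be set once app.run() parses argv.
  print(FLAGS.experiment, FLAGS.mode, FLAGS.model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)
```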
Tensorflow/models/official/common/registry_imports.py
ADDED
@@ -0,0 +1,20 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All necessary imports for registration."""
# pylint: disable=unused-import
from official import vision
from official.nlp import tasks
from official.nlp.configs import experiment_configs
from official.utils.testing import mock_task
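
These imports register tasks and experiment configs as a side effect, so a
training binary typically pulls them in before looking up an experiment. A
hedged sketch of that pattern (the experiment name below is illustrative and
assumed to be registered by the NLP configs):

```python
from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory

# Look up a registered ExperimentConfig by name.
config = exp_factory.get_exp_config('bert/sentence_prediction')
```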
Tensorflow/models/official/common/streamz_counters.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Global streamz counters."""

from tensorflow.python.eager import monitoring


progressive_policy_creation_counter = monitoring.Counter(
    "/tensorflow/training/fast_training/progressive_policy_creation",
    "Counter for the number of ProgressivePolicy creations.")


stack_vars_to_vars_call_counter = monitoring.Counter(
    "/tensorflow/training/fast_training/tf_vars_to_vars",
    "Counter for the number of low-level stacking API calls.")
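
A minimal sketch of how such a counter is incremented (counter cells are
addressed by label values, and these counters define no label fields, so
`get_cell()` takes no arguments):

```python
from official.common import streamz_counters

# Bump the creation counter by one, e.g. when a ProgressivePolicy is built.
streamz_counters.progressive_policy_creation_counter.get_cell().increase_by(1)
```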