Kasamuday committed
Commit f3507ef
1 parent: 8c9b51d

Upload 4260 files

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +11 -0
  2. models/.github/ISSUE_TEMPLATE/00-official-bug-report-issue.md +59 -0
  3. models/.github/ISSUE_TEMPLATE/10-official-documentation-issue.md +20 -0
  4. models/.github/ISSUE_TEMPLATE/20-official-feature-request-issue.md +26 -0
  5. models/.github/ISSUE_TEMPLATE/30-research-bug-report-issue.md +58 -0
  6. models/.github/ISSUE_TEMPLATE/40-research-documentation-issue.md +20 -0
  7. models/.github/ISSUE_TEMPLATE/50-research-feature-request-issue.md +26 -0
  8. models/.github/ISSUE_TEMPLATE/60-questions-help-issue.md +14 -0
  9. models/.github/ISSUE_TEMPLATE/config.yml +1 -0
  10. models/.github/PULL_REQUEST_TEMPLATE.md +41 -0
  11. models/.github/README_TEMPLATE.md +124 -0
  12. models/.github/bot_config.yml +24 -0
  13. models/.github/scripts/pylint.sh +178 -0
  14. models/.github/workflows/ci.yml +32 -0
  15. models/.github/workflows/stale.yaml +67 -0
  16. models/.gitignore +98 -0
  17. models/AUTHORS +10 -0
  18. models/CODEOWNERS +29 -0
  19. models/CODE_OF_CONDUCT.md +79 -0
  20. models/CONTRIBUTING.md +10 -0
  21. models/ISSUES.md +24 -0
  22. models/LICENSE +212 -0
  23. models/README.md +130 -0
  24. models/SECURITY.md +251 -0
  25. models/community/README.md +60 -0
  26. models/docs/README.md +17 -0
  27. models/docs/index.md +140 -0
  28. models/docs/nlp/_guide_toc.yaml +9 -0
  29. models/docs/nlp/customize_encoder.ipynb +596 -0
  30. models/docs/nlp/decoding_api.ipynb +482 -0
  31. models/docs/nlp/fine_tune_bert.ipynb +1550 -0
  32. models/docs/nlp/index.ipynb +545 -0
  33. models/docs/nlp/load_lm_ckpts.ipynb +692 -0
  34. models/docs/orbit/index.ipynb +898 -0
  35. models/docs/vision/_toc.yaml +9 -0
  36. models/docs/vision/image_classification.ipynb +692 -0
  37. models/docs/vision/instance_segmentation.ipynb +1138 -0
  38. models/docs/vision/object_detection.ipynb +902 -0
  39. models/docs/vision/semantic_segmentation.ipynb +785 -0
  40. models/official/README-TPU.md +32 -0
  41. models/official/README.md +166 -0
  42. models/official/__init__.py +14 -0
  43. models/official/common/__init__.py +15 -0
  44. models/official/common/dataset_fn.py +44 -0
  45. models/official/common/distribute_utils.py +233 -0
  46. models/official/common/distribute_utils_test.py +124 -0
  47. models/official/common/flags.py +114 -0
  48. models/official/common/registry_imports.py +20 -0
  49. models/official/common/streamz_counters.py +27 -0
  50. models/official/core/__init__.py +31 -0
.gitattributes CHANGED
@@ -54,3 +54,14 @@ Tensorflow/workspace/models/my_ssd_mobnet/ckpt-5.data-00000-of-00001 filter=lfs
  Tensorflow/workspace/models/my_ssd_mobnet/ckpt-6.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
  Tensorflow/workspace/models/my_ssd_mobnet/ckpt-7.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
  Tensorflow/workspace/models/my_ssd_mobnet/ckpt-8.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+ models/official/projects/waste_identification_ml/pre_processing/config/sample_images/ffdeb4cd-43ba-4ca0-a1e6-aa5824005f44.jpg filter=lfs diff=lfs merge=lfs -text
+ models/official/projects/waste_identification_ml/pre_processing/config/sample_images/image_2.png filter=lfs diff=lfs merge=lfs -text
+ models/official/projects/waste_identification_ml/pre_processing/config/sample_images/image_3.jpg filter=lfs diff=lfs merge=lfs -text
+ models/official/projects/waste_identification_ml/pre_processing/config/sample_images/image_4.png filter=lfs diff=lfs merge=lfs -text
+ models/research/deeplab/testing/pascal_voc_seg/val-00000-of-00001.tfrecord filter=lfs diff=lfs merge=lfs -text
+ models/research/dist/object_detection-0.1-py3.7.egg filter=lfs diff=lfs merge=lfs -text
+ models/research/dist/object_detection-0.1-py3.9.egg filter=lfs diff=lfs merge=lfs -text
+ models/research/lfads/synth_data/trained_itb/model-65000.meta filter=lfs diff=lfs merge=lfs -text
+ models/research/object_detection/dataset_tools/densepose/UV_symmetry_transforms.mat filter=lfs diff=lfs merge=lfs -text
+ models/research/object_detection/g3doc/img/kites_with_segment_overlay.png filter=lfs diff=lfs merge=lfs -text
+ models/research/object_detection/test_images/image2.jpg filter=lfs diff=lfs merge=lfs -text
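The new entries above are standard Git LFS tracking rules. For reference, lines like these are normally generated with `git lfs track` rather than written by hand; a minimal sketch, assuming Git LFS is installed and the command is run from the repository root:

```shell
# Track one of the newly added binary files with Git LFS; this appends the
# matching "filter=lfs diff=lfs merge=lfs -text" rule to .gitattributes.
git lfs track "models/research/deeplab/testing/pascal_voc_seg/val-00000-of-00001.tfrecord"
git add .gitattributes
```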
models/.github/ISSUE_TEMPLATE/00-official-bug-report-issue.md ADDED
@@ -0,0 +1,59 @@
1
+ ---
2
+ name: "[Official Model] Bug Report"
3
+ about: Use this template for reporting a bug for the “official” directory
4
+ labels: type:bug,models:official
5
+
6
+ ---
7
+
8
+ # Prerequisites
9
+
10
+ Please answer the following questions for yourself before submitting an issue.
11
+
12
+ - [ ] I am using the latest TensorFlow Model Garden release and TensorFlow 2.
13
+ - [ ] I am reporting the issue to the correct repository. (Model Garden official or research directory)
14
+ - [ ] I checked to make sure that this issue has not been filed already.
15
+
16
+ ## 1. The entire URL of the file you are using
17
+
18
+ https://github.com/tensorflow/models/tree/master/official/...
19
+
20
+ ## 2. Describe the bug
21
+
22
+ A clear and concise description of what the bug is.
23
+
24
+ ## 3. Steps to reproduce
25
+
26
+ Steps to reproduce the behavior.
27
+
28
+ ## 4. Expected behavior
29
+
30
+ A clear and concise description of what you expected to happen.
31
+
32
+ ## 5. Additional context
33
+
34
+ Include any logs that would be helpful to diagnose the problem.
35
+
36
+ ## 6. System information
37
+
38
+ - OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
39
+ - Mobile device name if the issue happens on a mobile device:
40
+ - TensorFlow installed from (source or binary):
41
+ - TensorFlow version (use command below):
42
+ - Python version:
43
+ - Bazel version (if compiling from source):
44
+ - GCC/Compiler version (if compiling from source):
45
+ - CUDA/cuDNN version:
46
+ - GPU model and memory:
47
+
48
+ <!--
49
+ Collect system information using our environment capture script.
50
+ https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
51
+
52
+ You can also obtain the TensorFlow version with:
53
+
54
+ 1. TensorFlow 1.0
55
+ `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
56
+
57
+ 2. TensorFlow 2.0
58
+ `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
59
+ -->
models/.github/ISSUE_TEMPLATE/10-official-documentation-issue.md ADDED
@@ -0,0 +1,20 @@
1
+ ---
2
+ name: "[Official Model] Documentation Issue"
3
+ about: Use this template for reporting a documentation issue for the “official” directory
4
+ labels: type:docs,models:official
5
+
6
+ ---
7
+
8
+ # Prerequisites
9
+
10
+ Please answer the following question for yourself before submitting an issue.
11
+
12
+ - [ ] I checked to make sure that this issue has not been filed already.
13
+
14
+ ## 1. The entire URL of the documentation with the issue
15
+
16
+ https://github.com/tensorflow/models/tree/master/official/...
17
+
18
+ ## 2. Describe the issue
19
+
20
+ A clear and concise description of what needs to be changed.
models/.github/ISSUE_TEMPLATE/20-official-feature-request-issue.md ADDED
@@ -0,0 +1,26 @@
1
+ ---
2
+ name: "[Official Model] Feature request"
3
+ about: Use this template for raising a feature request for the “official” directory
4
+ labels: type:feature,models:official
5
+
6
+ ---
7
+
8
+ # Prerequisites
9
+
10
+ Please answer the following question for yourself before submitting an issue.
11
+
12
+ - [ ] I checked to make sure that this feature has not been requested already.
13
+
14
+ ## 1. The entire URL of the file you are using
15
+
16
+ https://github.com/tensorflow/models/tree/master/official/...
17
+
18
+ ## 2. Describe the feature you request
19
+
20
+ A clear and concise description of what you want to happen.
21
+
22
+ ## 3. Additional context
23
+
24
+ Add any other context about the feature request here.
25
+
26
+ ## 4. Are you willing to contribute it? (Yes or No)
models/.github/ISSUE_TEMPLATE/30-research-bug-report-issue.md ADDED
@@ -0,0 +1,58 @@
1
+ ---
2
+ name: "[Research Model] Bug Report"
3
+ about: Use this template for reporting a bug for the “research” directory
4
+ labels: type:bug,models:research
5
+
6
+ ---
7
+ # Prerequisites
8
+
9
+ Please answer the following questions for yourself before submitting an issue.
10
+
11
+ - [ ] I am using the latest TensorFlow Model Garden release and TensorFlow 2.
12
+ - [ ] I am reporting the issue to the correct repository. (Model Garden official or research directory)
13
+ - [ ] I checked to make sure that this issue has not already been filed.
14
+
15
+ ## 1. The entire URL of the file you are using
16
+
17
+ https://github.com/tensorflow/models/tree/master/research/...
18
+
19
+ ## 2. Describe the bug
20
+
21
+ A clear and concise description of what the bug is.
22
+
23
+ ## 3. Steps to reproduce
24
+
25
+ Steps to reproduce the behavior.
26
+
27
+ ## 4. Expected behavior
28
+
29
+ A clear and concise description of what you expected to happen.
30
+
31
+ ## 5. Additional context
32
+
33
+ Include any logs that would be helpful to diagnose the problem.
34
+
35
+ ## 6. System information
36
+
37
+ - OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
38
+ - Mobile device name if the issue happens on a mobile device:
39
+ - TensorFlow installed from (source or binary):
40
+ - TensorFlow version (use command below):
41
+ - Python version:
42
+ - Bazel version (if compiling from source):
43
+ - GCC/Compiler version (if compiling from source):
44
+ - CUDA/cuDNN version:
45
+ - GPU model and memory:
46
+
47
+ <!--
48
+ Collect system information using our environment capture script.
49
+ https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
50
+
51
+ You can also obtain the TensorFlow version with:
52
+
53
+ 1. TensorFlow 1.0
54
+ `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
55
+
56
+ 2. TensorFlow 2.0
57
+ `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
58
+ -->
models/.github/ISSUE_TEMPLATE/40-research-documentation-issue.md ADDED
@@ -0,0 +1,20 @@
1
+ ---
2
+ name: "[Research Model] Documentation Issue"
3
+ about: Use this template for reporting a documentation issue for the “research” directory
4
+ labels: type:docs,models:research
5
+
6
+ ---
7
+
8
+ # Prerequisites
9
+
10
+ Please answer the following question for yourself before submitting an issue.
11
+
12
+ - [ ] I checked to make sure that this issue has not been filed already.
13
+
14
+ ## 1. The entire URL of the documentation with the issue
15
+
16
+ https://github.com/tensorflow/models/tree/master/research/...
17
+
18
+ ## 2. Describe the issue
19
+
20
+ A clear and concise description of what needs to be changed.
models/.github/ISSUE_TEMPLATE/50-research-feature-request-issue.md ADDED
@@ -0,0 +1,26 @@
1
+ ---
2
+ name: "[Research Model] Feature Request"
3
+ about: Use this template for raising a feature request for the “research” directory
4
+ labels: type:feature,models:research
5
+
6
+ ---
7
+
8
+ # Prerequisites
9
+
10
+ Please answer the following question for yourself before submitting an issue.
11
+
12
+ - [ ] I checked to make sure that this feature has not been requested already.
13
+
14
+ ## 1. The entire URL of the file you are using
15
+
16
+ https://github.com/tensorflow/models/tree/master/research/...
17
+
18
+ ## 2. Describe the feature you request
19
+
20
+ A clear and concise description of what you want to happen.
21
+
22
+ ## 3. Additional context
23
+
24
+ Add any other context about the feature request here.
25
+
26
+ ## 4. Are you willing to contribute it? (Yes or No)
models/.github/ISSUE_TEMPLATE/60-questions-help-issue.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ name: Questions and Help
3
+ about: Use this template for Questions and Help.
4
+ labels: type:support
5
+
6
+ ---
7
+ <!--
8
+ As per our GitHub Policy (https://github.com/tensorflow/models/blob/master/ISSUES.md), we only address code bugs, documentation issues, and feature requests on GitHub.
9
+
10
+ We will automatically close questions and help related issues.
11
+
12
+ Please go to Stack Overflow (http://stackoverflow.com/questions/tagged/tensorflow-model-garden) for questions and help.
13
+
14
+ -->
models/.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1 @@
1
+ blank_issues_enabled: false
models/.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,41 @@
1
+ # Description
2
+
3
+ > :memo: Please include a summary of the change.
4
+ >
5
+ > * Please also include relevant motivation and context.
6
+ > * List any dependencies that are required for this change.
7
+
8
+ ## Type of change
9
+
10
+ For a new feature or function, please create an issue first to discuss it
11
+ with us before submitting a pull request.
12
+
13
+ Note: Please delete options that are not relevant.
14
+
15
+ - [ ] Bug fix (non-breaking change which fixes an issue)
16
+ - [ ] Documentation update
17
+ - [ ] TensorFlow 2 migration
18
+ - [ ] New feature (non-breaking change which adds functionality)
19
+ - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
20
+ - [ ] A new research paper code implementation
21
+ - [ ] Other (Specify)
22
+
23
+ ## Tests
24
+
25
+ > :memo: Please describe the tests that you ran to verify your changes.
26
+ >
27
+ > * Provide instructions so we can reproduce.
28
+ > * Please also list any relevant details for your test configuration.
29
+
30
+ **Test Configuration**:
31
+
32
+ ## Checklist
33
+
34
+ - [ ] I have signed the [Contributor License Agreement](https://github.com/tensorflow/models/wiki/Contributor-License-Agreements).
35
+ - [ ] I have read [guidelines for pull request](https://github.com/tensorflow/models/wiki/Submitting-a-pull-request).
36
+ - [ ] My code follows the [coding guidelines](https://github.com/tensorflow/models/wiki/Coding-guidelines).
37
+ - [ ] I have performed a self [code review](https://github.com/tensorflow/models/wiki/Code-review) of my own code.
38
+ - [ ] I have commented my code, particularly in hard-to-understand areas.
39
+ - [ ] I have made corresponding changes to the documentation.
40
+ - [ ] My changes generate no new warnings.
41
+ - [ ] I have added tests that prove my fix is effective or that my feature works.
models/.github/README_TEMPLATE.md ADDED
@@ -0,0 +1,124 @@
1
+ > :memo: A README.md template for releasing a paper code implementation to a GitHub repository.
2
+ >
3
+ > * Template version: 1.0.2020.170
4
+ > * Please modify sections depending on needs.
5
+
6
+ # Model name, Paper title, or Project Name
7
+
8
+ > :memo: Add a badge for the ArXiv identifier of your paper (arXiv:YYMM.NNNNN)
9
+
10
+ [![Paper](http://img.shields.io/badge/Paper-arXiv.YYMM.NNNNN-B3181B?logo=arXiv)](https://arxiv.org/abs/...)
11
+
12
+ This repository is the official or unofficial implementation of the following paper.
13
+
14
+ * Paper title: [Paper Title](https://arxiv.org/abs/YYMM.NNNNN)
15
+
16
+ ## Description
17
+
18
+ > :memo: Provide description of the model.
19
+ >
20
+ > * Provide brief information of the algorithms used.
21
+ > * Provide links for demos, blog posts, etc.
22
+
23
+ ## History
24
+
25
+ > :memo: Provide a changelog.
26
+
27
+ ## Authors or Maintainers
28
+
29
+ > :memo: Provide maintainer information.
30
+
31
+ * Full name ([@GitHub username](https://github.com/username))
32
+ * Full name ([@GitHub username](https://github.com/username))
33
+
34
+ ## Table of Contents
35
+
36
+ > :memo: Provide a table of contents to help readers navigate a lengthy README document.
37
+
38
+ ## Requirements
39
+
40
+ [![TensorFlow 2.1](https://img.shields.io/badge/TensorFlow-2.1-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0)
41
+ [![Python 3.6](https://img.shields.io/badge/Python-3.6-3776AB)](https://www.python.org/downloads/release/python-360/)
42
+
43
+ > :memo: Provide details of the software required.
44
+ >
45
+ > * Add a `requirements.txt` file to the root directory for installing the necessary dependencies.
46
+ > * Describe how to install requirements using pip.
47
+ > * Alternatively, create INSTALL.md.
48
+
49
+ To install requirements:
50
+
51
+ ```setup
52
+ pip install -r requirements.txt
53
+ ```
54
+
55
+ ## Results
56
+
57
+ [![TensorFlow Hub](https://img.shields.io/badge/TF%20Hub-Models-FF6F00?logo=tensorflow)](https://tfhub.dev/...)
58
+
59
+ > :memo: Provide a table with results. (e.g., accuracy, latency)
60
+ >
61
+ > * Provide links to the pre-trained models (checkpoint, SavedModel files).
62
+ > * Publish TensorFlow SavedModel files on TensorFlow Hub (tfhub.dev) if possible.
63
+ > * Add links to [TensorBoard.dev](https://tensorboard.dev/) for visualizing metrics.
64
+ >
65
+ > An example table for image classification results
66
+ >
67
+ > ### Image Classification
68
+ >
69
+ > | Model name | Download | Top 1 Accuracy | Top 5 Accuracy |
70
+ > |------------|----------|----------------|----------------|
71
+ > | Model name | [Checkpoint](https://drive.google.com/...), [SavedModel](https://tfhub.dev/...) | xx% | xx% |
72
+
73
+ ## Dataset
74
+
75
+ > :memo: Provide information of the dataset used.
76
+
77
+ ## Training
78
+
79
+ > :memo: Provide training information.
80
+ >
81
+ > * Provide details for preprocessing, hyperparameters, random seeds, and environment.
82
+ > * Provide a command line example for training.
83
+
84
+ Please run this command line for training.
85
+
86
+ ```shell
87
+ python3 ...
88
+ ```
89
+
90
+ ## Evaluation
91
+
92
+ > :memo: Provide an evaluation script with details of how to reproduce results.
93
+ >
94
+ > * Describe data preprocessing / postprocessing steps.
95
+ > * Provide a command line example for evaluation.
96
+
97
+ Please run this command line for evaluation.
98
+
99
+ ```shell
100
+ python3 ...
101
+ ```
102
+
103
+ ## References
104
+
105
+ > :memo: Provide links to references.
106
+
107
+ ## License
108
+
109
+ [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
110
+
111
+ > :memo: Place your license text in a file named LICENSE in the root of the repository.
112
+ >
113
+ > * Include information about your license.
114
+ > * Reference: [Adding a license to a repository](https://help.github.com/en/github/building-a-strong-community/adding-a-license-to-a-repository)
115
+
116
+ This project is licensed under the terms of the **Apache License 2.0**.
117
+
118
+ ## Citation
119
+
120
+ > :memo: Make your repository citable.
121
+ >
122
+ > * Reference: [Making Your Code Citable](https://guides.github.com/activities/citable-code/)
123
+
124
+ If you want to cite this repository in your research paper, please use the following information.
models/.github/bot_config.yml ADDED
@@ -0,0 +1,24 @@
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ #
16
+ # THIS IS A GENERATED DOCKERFILE.
17
+ #
18
+ # This file was assembled from multiple pieces, whose use is documented
19
+ # throughout. Please refer to the TensorFlow dockerfiles documentation
20
+ # for more information.
21
+
22
+ # A list of assignees
23
+ assignees:
24
+ - laxmareddyp
models/.github/scripts/pylint.sh ADDED
@@ -0,0 +1,178 @@
1
+ #!/bin/bash
2
+ # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ #
17
+ # Pylint wrapper extracted from main TensorFlow, sharing same exceptions.
18
+ # Specify --incremental to only check files touched since last commit on master,
19
+ # otherwise will recursively check current directory (full repo takes long!).
20
+
21
+ set -euo pipefail
22
+
23
+ # Download latest configs from main TensorFlow repo.
24
+ wget -q -O /tmp/pylintrc https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/tools/ci_build/pylintrc
25
+
26
+ SCRIPT_DIR=/tmp
27
+
28
+ num_cpus() {
29
+ # Get the number of CPUs
30
+ if [[ -f /proc/cpuinfo ]]; then
31
+ N_CPUS=$(grep -c ^processor /proc/cpuinfo)
32
+ else
33
+ # Fallback method
34
+ N_CPUS=`getconf _NPROCESSORS_ONLN`
35
+ fi
36
+ if [[ -z ${N_CPUS} ]]; then
37
+ die "ERROR: Unable to determine the number of CPUs"
38
+ fi
39
+
40
+ echo ${N_CPUS}
41
+ }
42
+
43
+ get_changed_files_in_last_non_merge_git_commit() {
44
+ git diff --name-only $(git merge-base master $(git branch --show-current))
45
+ }
46
+
47
+ # List Python files changed in the last non-merge git commit that still exist,
48
+ # i.e., not removed.
49
+ # Usage: get_py_files_to_check [--incremental]
50
+ get_py_files_to_check() {
51
+ if [[ "$1" == "--incremental" ]]; then
52
+ CHANGED_PY_FILES=$(get_changed_files_in_last_non_merge_git_commit | \
53
+ grep '.*\.py$')
54
+
55
+ # Do not include files removed in the last non-merge commit.
56
+ PY_FILES=""
57
+ for PY_FILE in ${CHANGED_PY_FILES}; do
58
+ if [[ -f "${PY_FILE}" ]]; then
59
+ PY_FILES="${PY_FILES} ${PY_FILE}"
60
+ fi
61
+ done
62
+
63
+ echo "${PY_FILES}"
64
+ else
65
+ find . -name '*.py'
66
+ fi
67
+ }
68
+
69
+ do_pylint() {
70
+ if [[ $# == 1 ]] && [[ "$1" == "--incremental" ]]; then
71
+ PYTHON_SRC_FILES=$(get_py_files_to_check --incremental)
72
+
73
+ if [[ -z "${PYTHON_SRC_FILES}" ]]; then
74
+ echo "do_pylint will NOT run due to --incremental flag and due to the "\
75
+ "absence of Python code changes in the last commit."
76
+ return 0
77
+ fi
78
+ elif [[ $# != 0 ]]; then
79
+ echo "Invalid syntax for invoking do_pylint"
80
+ echo "Usage: do_pylint [--incremental]"
81
+ return 1
82
+ else
83
+ PYTHON_SRC_FILES=$(get_py_files_to_check)
84
+ fi
85
+
86
+ # Something happened. TF no longer has Python code if this branch is taken
87
+ if [[ -z ${PYTHON_SRC_FILES} ]]; then
88
+ echo "do_pylint found no Python files to check. Returning."
89
+ return 0
90
+ fi
91
+
92
+ # Now that we know we have to do work, check if `pylint` is installed
93
+ PYLINT_BIN="python3.8 -m pylint"
94
+
95
+ echo ""
96
+ echo "check whether pylint is available or not."
97
+ echo ""
98
+ ${PYLINT_BIN} --version
99
+ if [[ $? -eq 0 ]]
100
+ then
101
+ echo ""
102
+ echo "pylint available, proceeding with pylint sanity check."
103
+ echo ""
104
+ else
105
+ echo ""
106
+ echo "pylint not available."
107
+ echo ""
108
+ return 1
109
+ fi
110
+
111
+ # Configure pylint using the following file
112
+ PYLINTRC_FILE="${SCRIPT_DIR}/pylintrc"
113
+
114
+ if [[ ! -f "${PYLINTRC_FILE}" ]]; then
115
+ die "ERROR: Cannot find pylint rc file at ${PYLINTRC_FILE}"
116
+ fi
117
+
118
+ # Run pylint in parallel, after some disk setup
119
+ NUM_SRC_FILES=$(echo ${PYTHON_SRC_FILES} | wc -w)
120
+ NUM_CPUS=$(num_cpus)
121
+
122
+ echo "Running pylint on ${NUM_SRC_FILES} files with ${NUM_CPUS} "\
123
+ "parallel jobs..."
124
+ echo ""
125
+
126
+ PYLINT_START_TIME=$(date +'%s')
127
+ OUTPUT_FILE="$(mktemp)_pylint_output.log"
128
+ ERRORS_FILE="$(mktemp)_pylint_errors.log"
129
+
130
+ rm -rf ${OUTPUT_FILE}
131
+ rm -rf ${ERRORS_FILE}
132
+
133
+ set +e
134
+ # When running, filter to only contain the error code lines. Removes module
135
+ # header, removes lines of context that show up from some lines.
136
+ # Also, don't redirect stderr as this would hide pylint fatal errors.
137
+ ${PYLINT_BIN} --rcfile="${PYLINTRC_FILE}" --output-format=parseable \
138
+ --jobs=${NUM_CPUS} ${PYTHON_SRC_FILES} | grep '\[[CEFW]' > ${OUTPUT_FILE}
139
+ PYLINT_END_TIME=$(date +'%s')
140
+
141
+ echo ""
142
+ echo "pylint took $((PYLINT_END_TIME - PYLINT_START_TIME)) s"
143
+ echo ""
144
+
145
+ # Report only what we care about
146
+ # Ref https://pylint.readthedocs.io/en/latest/technical_reference/features.html
147
+ # E: all errors
148
+ # W0311 bad-indentation
149
+ # W0312 mixed-indentation
150
+ # C0330 bad-continuation
151
+ # C0301 line-too-long
152
+ # C0326 bad-whitespace
153
+ # W0611 unused-import
154
+ # W0622 redefined-builtin
155
+ grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301|\[C0326|\[W0611|\[W0622)' ${OUTPUT_FILE} > ${ERRORS_FILE}
156
+
157
+ # Determine counts of errors
158
+ N_FORBID_ERRORS=$(wc -l ${ERRORS_FILE} | cut -d' ' -f1)
159
+ set -e
160
+
161
+ # Now, print the errors we should fix
162
+ echo ""
163
+ if [[ ${N_FORBID_ERRORS} != 0 ]]; then
164
+ echo "Found ${N_FORBID_ERRORS} pylint errors:"
165
+ cat ${ERRORS_FILE}
166
+ fi
167
+
168
+ echo ""
169
+ if [[ ${N_FORBID_ERRORS} != 0 ]]; then
170
+ echo "FAIL: Found ${N_FORBID_ERRORS} errors"
171
+ return 1
172
+ else
173
+ echo "PASS: Found no errors"
174
+ fi
175
+ }
176
+
177
+ do_pylint "$@"
178
+
models/.github/workflows/ci.yml ADDED
@@ -0,0 +1,32 @@
1
+ name: CI
2
+ on: pull_request
3
+
4
+ permissions:
5
+ contents: read
6
+
7
+ jobs:
8
+ pylint:
9
+ runs-on: ubuntu-latest
10
+
11
+ steps:
12
+ - name: Set up Python 3.8
13
+ uses: actions/setup-python@v2
14
+ with:
15
+ python-version: 3.8
16
+
17
+ - name: Install pylint 2.4.4
18
+ run: |
19
+ python -m pip install --upgrade pip
20
+ pip install pylint==2.4.4
21
+
22
+ - name: Checkout code
23
+ uses: actions/checkout@v2
24
+ with:
25
+ ref: ${{ github.event.pull_request.head.sha }}
26
+ fetch-depth: 0
27
+
28
+ - name: Fetch master for diff
29
+ run: git fetch origin master:master
30
+
31
+ - name: Run pylint script
32
+ run: bash ./.github/scripts/pylint.sh --incremental
models/.github/workflows/stale.yaml ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # This workflow alerts and then closes the stale issues/PRs after specific time
17
+ # You can adjust the behavior by modifying this file.
18
+ # For more information, see:
19
+ # https://github.com/actions/stale
20
+
21
+ name: 'Close stale issues and PRs'
22
+ "on":
23
+ schedule:
24
+ - cron: "30 1 * * *"
25
+ permissions:
26
+ contents: read
27
+ issues: write
28
+ pull-requests: write
29
+
30
+ jobs:
31
+ stale:
32
+ runs-on: ubuntu-latest
33
+ steps:
34
+ - uses: 'actions/stale@v7'
35
+ with:
36
+ #Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale
37
+ exempt-issue-labels: 'override-stale'
38
+ #Comma separated list of labels that can be assigned to PRs to exclude them from being marked as stale
39
+ exempt-pr-labels: "override-stale"
40
+ # Limit the number of API calls in one run; the default value is 30.
41
+ operations-per-run: 1000
42
+ # Prevent the stale label from being removed when PRs or issues are updated.
43
+ remove-stale-when-updated: false
44
+ # comment on issue if not active for more than 7 days.
45
+ stale-issue-message: 'This issue has been marked stale because it has no recent activity since 7 days. It will be closed if no further activity occurs. Thank you.'
46
+ # comment on PR if not active for more than 14 days.
47
+ stale-pr-message: 'This PR has been marked stale because it has no recent activity since 14 days. It will be closed if no further activity occurs. Thank you.'
48
+ # comment on issue if stale for more than 7 days.
49
+ close-issue-message: This issue was closed due to lack of activity after being marked stale for past 7 days.
50
+ # comment on PR if stale for more than 14 days.
51
+ close-pr-message: This PR was closed due to lack of activity after being marked stale for past 14 days.
52
+ # Number of days of inactivity before an issue becomes stale
53
+ days-before-issue-stale: 7
54
+ # Number of days of inactivity before a stale Issue is closed
55
+ days-before-issue-close: 7
56
+ # reason for closing the issue; default value is not_planned
57
+ close-issue-reason: completed
58
+ # Number of days of inactivity before a stale PR is closed
59
+ days-before-pr-close: 14
60
+ # Number of days of inactivity before a PR becomes stale
61
+ days-before-pr-stale: 14
62
+ # Check for label to stale or close the issue/PR
63
+ any-of-labels: 'stat:awaiting response'
64
+ # override stale to stalled for PR
65
+ stale-pr-label: 'stale'
66
+ # override stale to stalled for Issue
67
+ stale-issue-label: "stale"
models/.gitignore ADDED
@@ -0,0 +1,98 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ env/
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+
27
+ # PyInstaller
28
+ # Usually these files are written by a python script from a template
29
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
30
+ *.manifest
31
+ *.spec
32
+
33
+ # Installer logs
34
+ pip-log.txt
35
+ pip-delete-this-directory.txt
36
+
37
+ # Unit test / coverage reports
38
+ htmlcov/
39
+ .tox/
40
+ .coverage
41
+ .coverage.*
42
+ .cache
43
+ nosetests.xml
44
+ coverage.xml
45
+ *,cover
46
+ .hypothesis/
47
+
48
+ # Translations
49
+ *.mo
50
+ *.pot
51
+
52
+ # Django stuff:
53
+ *.log
54
+ local_settings.py
55
+
56
+ # Flask stuff:
57
+ instance/
58
+ .webassets-cache
59
+
60
+ # Scrapy stuff:
61
+ .scrapy
62
+
63
+ # Sphinx documentation
64
+ docs/_build/
65
+
66
+ # PyBuilder
67
+ target/
68
+
69
+ # IPython Notebook
70
+ .ipynb_checkpoints
71
+
72
+ # pyenv
73
+ .python-version
74
+
75
+ # mypy
76
+ .mypy_cache
77
+
78
+ # celery beat schedule file
79
+ celerybeat-schedule
80
+
81
+ # dotenv
82
+ .env
83
+
84
+ # virtualenv
85
+ venv/
86
+ ENV/
87
+
88
+ # Spyder project settings
89
+ .spyderproject
90
+
91
+ # Rope project settings
92
+ .ropeproject
93
+
94
+ # PyCharm
95
+ .idea/
96
+
97
+ # For mac
98
+ .DS_Store
models/AUTHORS ADDED
@@ -0,0 +1,10 @@
1
+ # This is the official list of authors for copyright purposes.
2
+ # This file is distinct from the CONTRIBUTORS files.
3
+ # See the latter for an explanation.
4
+
5
+ # Names should be added to this file as:
6
+ # Name or Organization <email address>
7
+ # The email address is not required for organizations.
8
+
9
+ Google Inc.
10
+ David Dao <daviddao@broad.mit.edu>
models/CODEOWNERS ADDED
@@ -0,0 +1,29 @@
1
+ * @tensorflow/tf-model-garden-team
2
+ /official/ @rachellj218 @saberkun
3
+ /official/nlp/ @saberkun @lehougoogle @rachellj218
4
+ /official/recommendation/ranking/ @gagika
5
+ /official/vision/ @yeqingli @arashwan @saberkun @rachellj218
6
+ /official/vision/projects/assemblenet/ @mryoo @yeqingli
7
+ /official/vision/projects/deepmac_maskrcnn/ @vighneshbirodkar
8
+ /official/vision/projects/movinet/ @hyperparticle @yuanliangzhe @yeqingli
9
+ /official/vision/projects/simclr/ @luotigerlsx @chentingpc @saxenasaurabh
10
+ /official/vision/projects/video_ssl/ @richardaecn @yeqingli
11
+ /research/adversarial_text/ @rsepassi @a-dai
12
+ /research/attention_ocr/ @xavigibert
13
+ /research/audioset/ @plakal @dpwe
14
+ /research/autoaugment/ @barretzoph
15
+ /research/cognitive_planning/ @s-gupta
16
+ /research/cvt_text/ @clarkkev @lmthang
17
+ /research/deep_speech/ @yhliang2018
18
+ /research/deeplab/ @aquariusjay @yknzhu
19
+ /research/delf/ @andrefaraujo
20
+ /research/efficient-hrl/ @ofirnachum
21
+ /research/lfads/ @jazcollins @sussillo
22
+ /research/lstm_object_detection/ @yinxiaoli @yongzhe2160
23
+ /research/marco/ @vincentvanhoucke
24
+ /research/object_detection/ @jch1 @tombstone @pkulzc
25
+ /research/pcl_rl/ @ofirnachum
26
+ /research/rebar/ @gjtucker
27
+ /research/seq_flow_lite/ @thunderfyc @karunreddy30
28
+ /research/slim/ @sguada @marksandler2
29
+ /research/vid2depth/ @rezama
models/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,79 @@
1
+ # TensorFlow-models Code of Conduct
2
+
3
+ In the interest of fostering an open and welcoming environment, we as
4
+ contributors and maintainers pledge to make participation in our project and our
5
+ community a harassment-free experience for everyone, regardless of age, body
6
+ size, disability, ethnicity, gender identity and expression, level of
7
+ experience, nationality, personal appearance, race, religion, or sexual identity
8
+ and orientation.
9
+
10
+ ## Our Standards
11
+
12
+ Examples of behavior that contributes to creating a positive environment include:
13
+
14
+ * Using welcoming and inclusive language.
15
+ * Being respectful of differing viewpoints and experiences.
16
+ * Gracefully accepting constructive criticism.
17
+ * Focusing on what is best for the community.
18
+ * Showing empathy towards other community members.
19
+
20
+ Examples of unacceptable behavior by participants include:
21
+
22
+ * The use of sexualized language or imagery and unwelcome sexual attention or
23
+ advances.
24
+ * Trolling, insulting/derogatory comments, and personal or political attacks.
25
+ * Public or private harassment.
26
+ * Publishing others' private information, such as a physical or electronic
27
+ address, without explicit permission.
28
+ * Conduct which could reasonably be considered inappropriate for the forum in
29
+ which it occurs.
30
+
31
+ All TensorFlow-models forums and spaces are meant for professional interactions, and any behavior which could reasonably be considered inappropriate in a professional setting is unacceptable.
32
+
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
37
+
38
+ Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
39
+
40
+
41
+ ## Scope
42
+
43
+ This Code of Conduct applies to all content on tensorflow.org, TensorFlow-models GitHub organization, or any other official TensorFlow-models web presence allowing for community interactions, as well as at all official TensorFlow-models events, whether offline or online.
44
+
45
+ The Code of Conduct also applies within project spaces and in public spaces whenever an individual is representing TensorFlow-models or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed or de facto representative at an online or offline event.
46
+
47
+
48
+ ## Conflict Resolution
49
+
50
+ Conflicts in an open source project can take many forms, from someone having a bad day and using harsh and hurtful language in the issue queue, to more serious instances such as sexist/racist statements or threats of violence, and everything in between.
51
+
52
+ If the behavior is threatening or harassing, or for other reasons requires immediate escalation, please see below.
53
+
54
+ However, for the vast majority of issues, we aim to empower individuals to first resolve conflicts themselves, asking for help when needed, and only after that fails to escalate further. This approach gives people more control over the outcome of their dispute.
55
+
56
+ If you are experiencing or witnessing conflict, we ask you to use the following escalation strategy to address the conflict:
57
+
58
+ 1. Address the perceived conflict directly with those involved, preferably in a
59
+ real-time medium.
60
+ 2. If this fails, get a third party (e.g. a mutual friend, and/or someone with
61
+ background on the issue, but not involved in the conflict) to intercede.
62
+ 3. If you are still unable to resolve the conflict, and you believe it rises to
63
+ harassment or another code of conduct violation, report it.
64
+
65
+ ## Reporting Violations
66
+
67
+ Violations of the Code of Conduct can be reported to TensorFlow’s Project Stewards, Thea Lamkin (thealamkin@google.com) and Joana Carrasqueira (joanafilipa@google.com). The Project Steward will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report.
68
+
69
+ Violations of the Code of Conduct can occur in any setting, even those unrelated to the project. We will only consider complaints about conduct that has occurred within one year of the report.
70
+
71
+
72
+ ## Enforcement
73
+
74
+ If the Project Stewards receive a report alleging a violation of the Code of Conduct, the Project Stewards will notify the accused of the report, and provide them an opportunity to discuss the report before a sanction is issued. The Project Stewards will do their utmost to keep the reporter anonymous. If the act is ongoing (such as someone engaging in harassment), or involves a threat to anyone's safety (e.g. threats of violence), the Project Stewards may issue sanctions without notice.
75
+
76
+
77
+ ## Attribution
78
+
79
+ This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at https://contributor-covenant.org/version/1/4, and includes some aspects of the Geek Feminism Code of Conduct and the Drupal Code of Conduct.
models/CONTRIBUTING.md ADDED
@@ -0,0 +1,10 @@
1
+ # How to contribute
2
+
3
+ ![Contributors](https://img.shields.io/github/contributors/tensorflow/models)
4
+
5
+ We encourage you to contribute to the TensorFlow Model Garden.
6
+
7
+ Please read our [guidelines](../../wiki/How-to-contribute) for details.
8
+
9
+ **NOTE**: Only [code owners](./CODEOWNERS) are allowed to merge a pull request.
10
+ Please contact the code owners of each model to merge your pull request.
models/ISSUES.md ADDED
@@ -0,0 +1,24 @@
1
+ # If you open a GitHub issue, here is our policy.
2
+
3
+ * It must be a **bug**, a **feature request**, or a significant problem
4
+ with **documentation**.
5
+ * Please send a pull request instead for small documentation fixes.
6
+ * The required form must be filled out.
7
+ * The issue should be related to the repository it is created in.
8
+
9
+ General help and support should be sought on [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow-model-garden) or other non-GitHub channels.
10
+
11
+ [![](https://img.shields.io/stackexchange/stackoverflow/t/tensorflow-model-garden)](https://stackoverflow.com/questions/tagged/tensorflow-model-garden)
12
+
13
+ TensorFlow developers respond to issues.
14
+ We want to focus on work that benefits the whole community such as fixing bugs
15
+ and adding new features.
16
+ It helps us to address bugs and feature requests in a timely manner.
17
+
18
+ ---
19
+
20
+ Please understand that research models in the [research directory](https://github.com/tensorflow/models/tree/master/research)
21
+ included in this repository are experimental and research-style code.
22
+ They are not officially supported by the TensorFlow team.
23
+
24
+
models/LICENSE ADDED
@@ -0,0 +1,212 @@
1
+ Copyright 2022 Google LLC. All rights reserved.
2
+
3
+ All files in the following folders:
4
+ /community
5
+ /official
6
+ /orbit
7
+ /research
8
+ /tensorflow_models
9
+
10
+ Are licensed as follows:
11
+
12
+ Apache License
13
+ Version 2.0, January 2004
14
+ http://www.apache.org/licenses/
15
+
16
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
17
+
18
+ 1. Definitions.
19
+
20
+ "License" shall mean the terms and conditions for use, reproduction,
21
+ and distribution as defined by Sections 1 through 9 of this document.
22
+
23
+ "Licensor" shall mean the copyright owner or entity authorized by
24
+ the copyright owner that is granting the License.
25
+
26
+ "Legal Entity" shall mean the union of the acting entity and all
27
+ other entities that control, are controlled by, or are under common
28
+ control with that entity. For the purposes of this definition,
29
+ "control" means (i) the power, direct or indirect, to cause the
30
+ direction or management of such entity, whether by contract or
31
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
32
+ outstanding shares, or (iii) beneficial ownership of such entity.
33
+
34
+ "You" (or "Your") shall mean an individual or Legal Entity
35
+ exercising permissions granted by this License.
36
+
37
+ "Source" form shall mean the preferred form for making modifications,
38
+ including but not limited to software source code, documentation
39
+ source, and configuration files.
40
+
41
+ "Object" form shall mean any form resulting from mechanical
42
+ transformation or translation of a Source form, including but
43
+ not limited to compiled object code, generated documentation,
44
+ and conversions to other media types.
45
+
46
+ "Work" shall mean the work of authorship, whether in Source or
47
+ Object form, made available under the License, as indicated by a
48
+ copyright notice that is included in or attached to the work
49
+ (an example is provided in the Appendix below).
50
+
51
+ "Derivative Works" shall mean any work, whether in Source or Object
52
+ form, that is based on (or derived from) the Work and for which the
53
+ editorial revisions, annotations, elaborations, or other modifications
54
+ represent, as a whole, an original work of authorship. For the purposes
55
+ of this License, Derivative Works shall not include works that remain
56
+ separable from, or merely link (or bind by name) to the interfaces of,
57
+ the Work and Derivative Works thereof.
58
+
59
+ "Contribution" shall mean any work of authorship, including
60
+ the original version of the Work and any modifications or additions
61
+ to that Work or Derivative Works thereof, that is intentionally
62
+ submitted to Licensor for inclusion in the Work by the copyright owner
63
+ or by an individual or Legal Entity authorized to submit on behalf of
64
+ the copyright owner. For the purposes of this definition, "submitted"
65
+ means any form of electronic, verbal, or written communication sent
66
+ to the Licensor or its representatives, including but not limited to
67
+ communication on electronic mailing lists, source code control systems,
68
+ and issue tracking systems that are managed by, or on behalf of, the
69
+ Licensor for the purpose of discussing and improving the Work, but
70
+ excluding communication that is conspicuously marked or otherwise
71
+ designated in writing by the copyright owner as "Not a Contribution."
72
+
73
+ "Contributor" shall mean Licensor and any individual or Legal Entity
74
+ on behalf of whom a Contribution has been received by Licensor and
75
+ subsequently incorporated within the Work.
76
+
77
+ 2. Grant of Copyright License. Subject to the terms and conditions of
78
+ this License, each Contributor hereby grants to You a perpetual,
79
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
80
+ copyright license to reproduce, prepare Derivative Works of,
81
+ publicly display, publicly perform, sublicense, and distribute the
82
+ Work and such Derivative Works in Source or Object form.
83
+
84
+ 3. Grant of Patent License. Subject to the terms and conditions of
85
+ this License, each Contributor hereby grants to You a perpetual,
86
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
87
+ (except as stated in this section) patent license to make, have made,
88
+ use, offer to sell, sell, import, and otherwise transfer the Work,
89
+ where such license applies only to those patent claims licensable
90
+ by such Contributor that are necessarily infringed by their
91
+ Contribution(s) alone or by combination of their Contribution(s)
92
+ with the Work to which such Contribution(s) was submitted. If You
93
+ institute patent litigation against any entity (including a
94
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
95
+ or a Contribution incorporated within the Work constitutes direct
96
+ or contributory patent infringement, then any patent licenses
97
+ granted to You under this License for that Work shall terminate
98
+ as of the date such litigation is filed.
99
+
100
+ 4. Redistribution. You may reproduce and distribute copies of the
101
+ Work or Derivative Works thereof in any medium, with or without
102
+ modifications, and in Source or Object form, provided that You
103
+ meet the following conditions:
104
+
105
+ (a) You must give any other recipients of the Work or
106
+ Derivative Works a copy of this License; and
107
+
108
+ (b) You must cause any modified files to carry prominent notices
109
+ stating that You changed the files; and
110
+
111
+ (c) You must retain, in the Source form of any Derivative Works
112
+ that You distribute, all copyright, patent, trademark, and
113
+ attribution notices from the Source form of the Work,
114
+ excluding those notices that do not pertain to any part of
115
+ the Derivative Works; and
116
+
117
+ (d) If the Work includes a "NOTICE" text file as part of its
118
+ distribution, then any Derivative Works that You distribute must
119
+ include a readable copy of the attribution notices contained
120
+ within such NOTICE file, excluding those notices that do not
121
+ pertain to any part of the Derivative Works, in at least one
122
+ of the following places: within a NOTICE text file distributed
123
+ as part of the Derivative Works; within the Source form or
124
+ documentation, if provided along with the Derivative Works; or,
125
+ within a display generated by the Derivative Works, if and
126
+ wherever such third-party notices normally appear. The contents
127
+ of the NOTICE file are for informational purposes only and
128
+ do not modify the License. You may add Your own attribution
129
+ notices within Derivative Works that You distribute, alongside
130
+ or as an addendum to the NOTICE text from the Work, provided
131
+ that such additional attribution notices cannot be construed
132
+ as modifying the License.
133
+
134
+ You may add Your own copyright statement to Your modifications and
135
+ may provide additional or different license terms and conditions
136
+ for use, reproduction, or distribution of Your modifications, or
137
+ for any such Derivative Works as a whole, provided Your use,
138
+ reproduction, and distribution of the Work otherwise complies with
139
+ the conditions stated in this License.
140
+
141
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
142
+ any Contribution intentionally submitted for inclusion in the Work
143
+ by You to the Licensor shall be under the terms and conditions of
144
+ this License, without any additional terms or conditions.
145
+ Notwithstanding the above, nothing herein shall supersede or modify
146
+ the terms of any separate license agreement you may have executed
147
+ with Licensor regarding such Contributions.
148
+
149
+ 6. Trademarks. This License does not grant permission to use the trade
150
+ names, trademarks, service marks, or product names of the Licensor,
151
+ except as required for reasonable and customary use in describing the
152
+ origin of the Work and reproducing the content of the NOTICE file.
153
+
154
+ 7. Disclaimer of Warranty. Unless required by applicable law or
155
+ agreed to in writing, Licensor provides the Work (and each
156
+ Contributor provides its Contributions) on an "AS IS" BASIS,
157
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
158
+ implied, including, without limitation, any warranties or conditions
159
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
160
+ PARTICULAR PURPOSE. You are solely responsible for determining the
161
+ appropriateness of using or redistributing the Work and assume any
162
+ risks associated with Your exercise of permissions under this License.
163
+
164
+ 8. Limitation of Liability. In no event and under no legal theory,
165
+ whether in tort (including negligence), contract, or otherwise,
166
+ unless required by applicable law (such as deliberate and grossly
167
+ negligent acts) or agreed to in writing, shall any Contributor be
168
+ liable to You for damages, including any direct, indirect, special,
169
+ incidental, or consequential damages of any character arising as a
170
+ result of this License or out of the use or inability to use the
171
+ Work (including but not limited to damages for loss of goodwill,
172
+ work stoppage, computer failure or malfunction, or any and all
173
+ other commercial damages or losses), even if such Contributor
174
+ has been advised of the possibility of such damages.
175
+
176
+ 9. Accepting Warranty or Additional Liability. While redistributing
177
+ the Work or Derivative Works thereof, You may choose to offer,
178
+ and charge a fee for, acceptance of support, warranty, indemnity,
179
+ or other liability obligations and/or rights consistent with this
180
+ License. However, in accepting such obligations, You may act only
181
+ on Your own behalf and on Your sole responsibility, not on behalf
182
+ of any other Contributor, and only if You agree to indemnify,
183
+ defend, and hold each Contributor harmless for any liability
184
+ incurred by, or claims asserted against, such Contributor by reason
185
+ of your accepting any such warranty or additional liability.
186
+
187
+ END OF TERMS AND CONDITIONS
188
+
189
+ APPENDIX: How to apply the Apache License to your work.
190
+
191
+ To apply the Apache License to your work, attach the following
192
+ boilerplate notice, with the fields enclosed by brackets "[]"
193
+ replaced with your own identifying information. (Don't include
194
+ the brackets!) The text should be enclosed in the appropriate
195
+ comment syntax for the file format. We also recommend that a
196
+ file or class name and description of purpose be included on the
197
+ same "printed page" as the copyright notice for easier
198
+ identification within third-party archives.
199
+
200
+ Copyright 2016, The Authors.
201
+
202
+ Licensed under the Apache License, Version 2.0 (the "License");
203
+ you may not use this file except in compliance with the License.
204
+ You may obtain a copy of the License at
205
+
206
+ http://www.apache.org/licenses/LICENSE-2.0
207
+
208
+ Unless required by applicable law or agreed to in writing, software
209
+ distributed under the License is distributed on an "AS IS" BASIS,
210
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
211
+ See the License for the specific language governing permissions and
212
+ limitations under the License.
models/README.md ADDED
@@ -0,0 +1,130 @@
1
+ <div align="center">
2
+ <img src="https://storage.googleapis.com/tf_model_garden/tf_model_garden_logo.png">
3
+ </div>
4
+
5
+ [![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow)
6
+ [![tf-models-official PyPI](https://badge.fury.io/py/tf-models-official.svg)](https://badge.fury.io/py/tf-models-official)
7
+
8
+
9
+ # Welcome to the Model Garden for TensorFlow
10
+
11
+ The TensorFlow Model Garden is a repository with a number of different
12
+ implementations of state-of-the-art (SOTA) models and modeling solutions for
13
+ TensorFlow users. We aim to demonstrate the best practices for modeling so that
14
+ TensorFlow users can take full advantage of TensorFlow for their research and
15
+ product development.
16
+
17
+ To improve the transparency and reproducibility of our models, training logs on
18
+ [TensorBoard.dev](https://tensorboard.dev) are also provided for models to the
19
+ extent possible, though not all models are suitable.
20
+
21
+ | Directory | Description |
22
+ |-----------|-------------|
23
+ | [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read<br /> For more details on the capabilities, check the guide on the [Model-garden](https://www.tensorflow.org/tfmodels)|
24
+ | [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers |
25
+ | [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |
26
+ | [orbit](orbit) | • A flexible and lightweight library that users can easily use or fork when writing customized training loop code in TensorFlow 2.x. It seamlessly integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU). |
27
+
28
+ ## Installation
29
+
30
+ To install the current release of tensorflow-models, please follow any one of the methods described below.
31
+
32
+ #### Method 1: Install the TensorFlow Model Garden pip package
33
+
34
+ <details>
35
+
36
+ **tf-models-official** is the stable Model Garden package. Please check out the [releases](https://github.com/tensorflow/models/releases) to see which modules are available.
37
+
38
+ pip3 will install all models and dependencies automatically.
39
+
40
+ ```shell
41
+ pip3 install tf-models-official
42
+ ```
43
+
44
+ Please check out the following examples to learn how to use the pip package:
45
+ - [basic library import](https://github.com/tensorflow/models/blob/master/tensorflow_models/tensorflow_models_pypi.ipynb)
46
+ - [nlp model building](https://github.com/tensorflow/models/blob/master/docs/nlp/index.ipynb)
47
48
+
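+ After installation, a quick import check can confirm that the package resolved
+ (a minimal sketch; the exact submodules that are exposed depend on the release you installed):
+
+ ```python
+ # Quick sanity check for the pip package.
+ import tensorflow_models as tfm
+
+ # The package exposes the NLP and Vision modeling libraries.
+ print(tfm.nlp)
+ print(tfm.vision)
+ ```
+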
49
+ Note that **tf-models-official** may not include the latest changes in the master branch of this
50
+ GitHub repo. To include the latest changes, you may install **tf-models-nightly**,
51
+ which is the nightly Model Garden package created automatically every day.
52
+
53
+ ```shell
54
+ pip3 install tf-models-nightly
55
+ ```
56
+
57
+ </details>
58
+
59
+
60
+ #### Method 2: Clone the source
61
+
62
+ <details>
63
+
64
+ 1. Clone the GitHub repository:
65
+
66
+ ```shell
67
+ git clone https://github.com/tensorflow/models.git
68
+ ```
69
+
70
+ 2. Add the top-level ***/models*** folder to the Python path.
71
+
72
+ ```shell
73
+ export PYTHONPATH=$PYTHONPATH:/path/to/models
74
+ ```
75
+
76
+ If you are using a Windows environment, you may need to use the following command in PowerShell:
77
+ ```shell
78
+ $env:PYTHONPATH += ";\path\to\models"
79
+ ```
80
+
81
+ If you are using a Colab notebook, please set the Python path with `os.environ`:
82
+
83
+ ```python
84
+ import os
85
+ os.environ['PYTHONPATH'] += ":/path/to/models"
86
+ ```
87
+
88
+ 3. Install other dependencies:
89
+
90
+ ```shell
91
+ pip3 install --user -r models/official/requirements.txt
92
+ ```
93
+
94
+ Finally, if you are using the NLP packages, please also install
95
+ **tensorflow-text-nightly**:
96
+
97
+ ```shell
98
+ pip3 install tensorflow-text-nightly
99
+ ```
100
+
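+ Optionally, verify that the `models` folder is visible to Python. A minimal check
+ such as the following may help (assuming `/path/to/models` was added to
+ `PYTHONPATH` above; `official` is the top-level package inside the cloned repository):
+
+ ```python
+ # Should print the path to models/official/__init__.py if the setup worked.
+ import official
+ print(official.__file__)
+ ```
+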
101
+ </details>
102
+
103
+
104
+ ## Announcements
105
+
106
+ Please check [this page](https://github.com/tensorflow/models/wiki/Announcements) for recent announcements.
107
+
108
+ ## Contributions
109
+
110
+ [![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)
111
+
112
+ If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
113
+
114
+ ## License
115
+
116
+ [Apache License 2.0](LICENSE)
117
+
118
+ ## Citing TensorFlow Model Garden
119
+
120
+ If you use TensorFlow Model Garden in your research, please cite this repository.
121
+
122
+ ```
123
+ @misc{tensorflowmodelgarden2020,
124
+ author = {Hongkun Yu, Chen Chen, Xianzhi Du, Yeqing Li, Abdullah Rashwan, Le Hou, Pengchong Jin, Fan Yang,
125
+ Frederick Liu, Jaeyoun Kim, and Jing Li},
126
+ title = {{TensorFlow Model Garden}},
127
+ howpublished = {\url{https://github.com/tensorflow/models}},
128
+ year = {2020}
129
+ }
130
+ ```
models/SECURITY.md ADDED
@@ -0,0 +1,251 @@
1
+ # Using TensorFlow Securely
2
+
3
+ This document discusses how to safely deal with untrusted programs (models or
4
+ model parameters), and input data. Below, we also provide guidelines on how to
5
+ report vulnerabilities in TensorFlow.
6
+
7
+ ## TensorFlow models are programs
8
+
9
+ TensorFlow's runtime system interprets and executes programs. What machine
10
+ learning practitioners term
11
+ [**models**](https://developers.google.com/machine-learning/glossary/#model) are
12
+ expressed as programs that TensorFlow executes. TensorFlow programs are encoded
13
+ as computation
14
+ [**graphs**](https://developers.google.com/machine-learning/glossary/#graph).
15
+ The model's parameters are often stored separately in **checkpoints**.
16
+
17
+ At runtime, TensorFlow executes the computation graph using the parameters
18
+ provided. Note that the behavior of the computation graph may change
19
+ depending on the parameters provided. TensorFlow itself is not a sandbox. When
20
+ executing the computation graph, TensorFlow may read and write files, send and
21
+ receive data over the network, and even spawn additional processes. All these
22
+ tasks are performed with the permissions of the TensorFlow process. Allowing
23
+ for this flexibility makes for a powerful machine learning platform,
24
+ but it has implications for security.
25
+
26
+ The computation graph may also accept **inputs**. Those inputs are the
27
+ data you supply to TensorFlow to train a model, or to use a model to run
28
+ inference on the data.
29
+
30
+ **TensorFlow models are programs, and need to be treated as such from a security
31
+ perspective.**
32
+
33
+ ## Running untrusted models
34
+
35
+ As a general rule: **Always** execute untrusted models inside a sandbox (e.g.,
36
+ [nsjail](https://github.com/google/nsjail)).
37
+
38
+ There are several ways in which a model could become untrusted. Obviously, if an
39
+ untrusted party supplies TensorFlow kernels, arbitrary code may be executed.
40
+ The same is true if the untrusted party provides Python code, such as the
41
+ Python code that generates TensorFlow graphs.
42
+
43
+ Even if the untrusted party only supplies the serialized computation
44
+ graph (in form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the
45
+ set of computation primitives available to TensorFlow is powerful enough that
46
+ you should assume that the TensorFlow process effectively executes arbitrary
47
+ code. One common solution is to allow only a few safe Ops. While this is
48
+ possible in theory, we still recommend you sandbox the execution.
49
+
50
+ Whether a user-provided checkpoint is safe depends on the computation graph.
51
+ It is easy to create computation graphs in which malicious
52
+ checkpoints can trigger unsafe behavior. For example, consider a graph that
53
+ contains a `tf.cond` depending on the value of a `tf.Variable`. One branch of
54
+ the `tf.cond` is harmless, but the other is unsafe. Since the `tf.Variable` is
55
+ stored in the checkpoint, whoever provides the checkpoint now has the ability to
56
+ trigger unsafe behavior, even though the graph is not under their control.
57
+
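+ As a minimal sketch of the pattern described above (the variable, file path, and
+ function names here are purely illustrative):
+
+ ```python
+ import tensorflow as tf
+
+ # `flag` is restored from the checkpoint, so whoever supplies the checkpoint
+ # decides which branch of the `tf.cond` actually runs.
+ flag = tf.Variable(0, dtype=tf.int32)
+
+ @tf.function
+ def lookup():
+   return tf.cond(
+       flag > 0,
+       lambda: tf.io.read_file("/some/sensitive/file"),  # unsafe branch
+       lambda: tf.constant("harmless"))                   # harmless branch
+ ```
+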
58
+ In other words, graphs can contain vulnerabilities of their own. To allow users
59
+ to provide checkpoints to a model you run on their behalf (e.g., in order to
60
+ compare model quality for a fixed model architecture), you must carefully audit
61
+ your model, and we recommend you run the TensorFlow process in a sandbox.
62
+
63
+ ## Accepting untrusted Inputs
64
+
65
+ It is possible to write models that are secure in the sense that they can safely
66
+ process untrusted inputs assuming there are no bugs. There are two main reasons
67
+ to not rely on this: First, it is easy to write models which must not be exposed
68
+ to untrusted inputs, and second, there are bugs in any software system of
69
+ sufficient complexity. Letting users control inputs could allow them to trigger
70
+ bugs either in TensorFlow or in dependent libraries.
71
+
72
+ In general, it is good practice to isolate the parts of any system that are exposed
73
+ to untrusted (e.g., user-provided) inputs in a sandbox.
74
+
75
+ A useful analogy for how a TensorFlow graph is executed is an interpreted
76
+ programming language, such as Python. While it is possible to write secure
77
+ Python code which can be exposed to user supplied inputs (by, e.g., carefully
78
+ quoting and sanitizing input strings, size-checking input blobs, etc.), it is
79
+ very easy to write Python programs which are insecure. Even secure Python code
80
+ could be rendered insecure by a bug in the Python interpreter, or by a bug in a
81
+ Python library used (e.g.,
82
+ [this one](https://www.cvedetails.com/cve/CVE-2017-12852/)).
83
+
84
+ ## Running a TensorFlow server
85
+
86
+ TensorFlow is a platform for distributed computing, and as such there is a
87
+ TensorFlow server (`tf.train.Server`). **The TensorFlow server is meant for
88
+ internal communication only. It is not built for use in an untrusted network.**
89
+
90
+ For performance reasons, the default TensorFlow server does not include any
91
+ authorization protocol and sends messages unencrypted. It accepts connections
92
+ from anywhere, and executes the graphs it is sent without performing any checks.
93
+ Therefore, if you run a `tf.train.Server` in your network, anybody with
94
+ access to the network can execute what you should consider arbitrary code with
95
+ the privileges of the process running the `tf.train.Server`.
96
+
97
+ When running distributed TensorFlow, you must isolate the network in which the
98
+ cluster lives. Cloud providers provide instructions for setting up isolated
99
+ networks, which are sometimes branded as "virtual private cloud." Refer to the
100
+ instructions for
101
+ [GCP](https://cloud.google.com/compute/docs/networks-and-firewalls) and
102
+ [AWS](https://aws.amazon.com/vpc/) for details.
103
+
104
+ Note that `tf.train.Server` is different from the server created by
105
+ `tensorflow/serving` (the default binary for which is called `ModelServer`).
106
+ By default, `ModelServer` also has no built-in mechanism for authentication.
107
+ Connecting it to an untrusted network allows anyone on this network to run the
108
+ graphs known to the `ModelServer`. This means that an attacker may run
109
+ graphs using untrusted inputs as described above, but they would not be able to
110
+ execute arbitrary graphs. It is possible to safely expose a `ModelServer`
111
+ directly to an untrusted network, **but only if the graphs it is configured to
112
+ use have been carefully audited to be safe**.
113
+
114
+ Similar to best practices for other servers, we recommend running any
115
+ `ModelServer` with appropriate privileges (i.e., using a separate user with
116
+ reduced permissions). In the spirit of defense in depth, we recommend
117
+ authenticating requests to any TensorFlow server connected to an untrusted
118
+ network, as well as sandboxing the server to minimize the adverse effects of
119
+ any breach.
120
+
121
+ ## Vulnerabilities in TensorFlow
122
+
123
+ TensorFlow is a large and complex system. It also depends on a large set of
124
+ third party libraries (e.g., `numpy`, `libjpeg-turbo`, PNG parsers, `protobuf`).
125
+ It is possible that TensorFlow or its dependent libraries contain
126
+ vulnerabilities that would allow triggering unexpected or dangerous behavior
127
+ with specially crafted inputs.
128
+
129
+ ### What is a vulnerability?
130
+
131
+ Given TensorFlow's flexibility, it is possible to specify computation graphs
132
+ which exhibit unexpected or unwanted behavior. The fact that TensorFlow models
133
+ can perform arbitrary computations means that they may read and write files,
134
+ communicate via the network, produce deadlocks and infinite loops, or run out
135
+ of memory. It is only when these behaviors are outside the specifications of the
136
+ operations involved that such behavior is a vulnerability.
137
+
138
+ A `FileWriter` writing a file is not unexpected behavior and therefore is not a
139
+ vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution
140
+ **is** a vulnerability.
141
+
142
+ This is more subtle from a system perspective. For example, it is easy to cause
143
+ a TensorFlow process to try to allocate more memory than available by specifying
144
+ a computation graph containing an ill-considered `tf.tile` operation. TensorFlow
145
+ should exit cleanly in this case (it would raise an exception in Python, or
146
+ return an error `Status` in C++). However, if the surrounding system is not
147
+ expecting the possibility, such behavior could be used in a denial of service
148
+ attack (or worse). Because TensorFlow behaves correctly, this is not a
149
+ vulnerability in TensorFlow (although it would be a vulnerability of this
150
+ hypothetical system).
151
+
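+ For example, a hypothetical serving system could treat resource exhaustion as a
+ handled failure mode instead of letting the whole process die (a sketch only; the
+ surrounding service code is assumed, not part of TensorFlow):
+
+ ```python
+ import tensorflow as tf
+
+ def run_request(graph_fn, inputs):
+   """Runs one request, rejecting it cleanly if the graph exhausts memory."""
+   try:
+     return graph_fn(inputs)
+   except tf.errors.ResourceExhaustedError:
+     # TensorFlow reported the failure correctly; the surrounding system still
+     # has to decide how to fail without taking the whole service down.
+     return None
+ ```
+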
152
+ As a general rule, it is incorrect behavior for TensorFlow to access memory it
153
+ does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
154
+ such behaviors constitute a vulnerability.
155
+
156
+ One of the most critical parts of any system is input handling. If malicious
157
+ input can trigger side effects or incorrect behavior, this is a bug, and likely
158
+ a vulnerability.
159
+
160
+ ### Reporting vulnerabilities
161
+
162
+ Please email reports about any security related issues you find to
163
+ `security@tensorflow.org`. This mail is delivered to a small security team. Your
164
+ email will be acknowledged within one business day, and you'll receive a more
165
+ detailed response to your email within 7 days indicating the next steps in
166
+ handling your report. For critical problems, you may encrypt your report (see
167
+ below).
168
+
169
+ Please use a descriptive subject line for your report email. After the initial
170
+ reply to your report, the security team will endeavor to keep you informed of
171
+ the progress being made towards a fix and announcement.
172
+
173
+ In addition, please include the following information along with your report:
174
+
175
+ * Your name and affiliation (if any).
176
+ * A description of the technical details of the vulnerabilities. It is very
177
+ important to let us know how we can reproduce your findings.
178
+ * An explanation of who can exploit this vulnerability, and what they gain when
179
+ doing so -- write an attack scenario. This will help us evaluate your report
180
+ quickly, especially if the issue is complex.
181
+ * Whether this vulnerability is public or known to third parties. If it is, please
182
+ provide details.
183
+
184
+ If you believe that an existing (public) issue is security-related, please send
185
+ an email to `security@tensorflow.org`. The email should include the issue ID and
186
+ a short description of why it should be handled according to this security
187
+ policy.
188
+
189
+ Once an issue is reported, TensorFlow uses the following disclosure process:
190
+
191
+ * When a report is received, we confirm the issue and determine its severity.
192
+ * If we know of specific third-party services or software based on TensorFlow
193
+ that require mitigation before publication, those projects will be notified.
194
+ * An advisory is prepared (but not published) which details the problem and
195
+ steps for mitigation.
196
+ * The vulnerability is fixed and potential workarounds are identified.
197
+ * Wherever possible, the fix is also prepared for the branches corresponding to
198
+ all releases of TensorFlow at most one year old. We will attempt to commit
199
+ these fixes as soon as possible, and as close together as possible.
200
+ * Patch releases are published for all fixed released versions, a
201
+ notification is sent to discuss@tensorflow.org, and the advisory is published.
202
+
203
+ Note that we mostly do patch releases for security reasons and each version of
204
+ TensorFlow is supported for only 1 year after the release.
205
+
206
+ Past security advisories are listed below. We credit reporters for identifying
207
+ security issues, although we keep your name confidential if you request it.
208
+
209
+ #### Encryption key for `security@tensorflow.org`
210
+
211
+ If your disclosure is extremely sensitive, you may choose to encrypt your
212
+ report using the key below. Please only use this for critical security
213
+ reports.
214
+
215
+ ```
216
+ -----BEGIN PGP PUBLIC KEY BLOCK-----
217
+
218
+ mQENBFpqdzwBCADTeAHLNEe9Vm77AxhmGP+CdjlY84O6DouOCDSq00zFYdIU/7aI
219
+ LjYwhEmDEvLnRCYeFGdIHVtW9YrVktqYE9HXVQC7nULU6U6cvkQbwHCdrjaDaylP
220
+ aJUXkNrrxibhx9YYdy465CfusAaZ0aM+T9DpcZg98SmsSml/HAiiY4mbg/yNVdPs
221
+ SEp/Ui4zdIBNNs6at2gGZrd4qWhdM0MqGJlehqdeUKRICE/mdedXwsWLM8AfEA0e
222
+ OeTVhZ+EtYCypiF4fVl/NsqJ/zhBJpCx/1FBI1Uf/lu2TE4eOS1FgmIqb2j4T+jY
223
+ e+4C8kGB405PAC0n50YpOrOs6k7fiQDjYmbNABEBAAG0LVRlbnNvckZsb3cgU2Vj
224
+ dXJpdHkgPHNlY3VyaXR5QHRlbnNvcmZsb3cub3JnPokBTgQTAQgAOBYhBEkvXzHm
225
+ gOJBnwP4Wxnef3wVoM2yBQJaanc8AhsDBQsJCAcCBhUKCQgLAgQWAgMBAh4BAheA
226
+ AAoJEBnef3wVoM2yNlkIAICqetv33MD9W6mPAXH3eon+KJoeHQHYOuwWfYkUF6CC
227
+ o+X2dlPqBSqMG3bFuTrrcwjr9w1V8HkNuzzOJvCm1CJVKaxMzPuXhBq5+DeT67+a
228
+ T/wK1L2R1bF0gs7Pp40W3np8iAFEh8sgqtxXvLGJLGDZ1Lnfdprg3HciqaVAiTum
229
+ HBFwszszZZ1wAnKJs5KVteFN7GSSng3qBcj0E0ql2nPGEqCVh+6RG/TU5C8gEsEf
230
+ 3DX768M4okmFDKTzLNBm+l08kkBFt+P43rNK8dyC4PXk7yJa93SmS/dlK6DZ16Yw
231
+ 2FS1StiZSVqygTW59rM5XNwdhKVXy2mf/RtNSr84gSi5AQ0EWmp3PAEIALInfBLR
232
+ N6fAUGPFj+K3za3PeD0fWDijlC9f4Ety/icwWPkOBdYVBn0atzI21thPRbfuUxfe
233
+ zr76xNNrtRRlbDSAChA1J5T86EflowcQor8dNC6fS+oHFCGeUjfEAm16P6mGTo0p
234
+ osdG2XnnTHOOEFbEUeWOwR/zT0QRaGGknoy2pc4doWcJptqJIdTl1K8xyBieik/b
235
+ nSoClqQdZJa4XA3H9G+F4NmoZGEguC5GGb2P9NHYAJ3MLHBHywZip8g9oojIwda+
236
+ OCLL4UPEZ89cl0EyhXM0nIAmGn3Chdjfu3ebF0SeuToGN8E1goUs3qSE77ZdzIsR
237
+ BzZSDFrgmZH+uP0AEQEAAYkBNgQYAQgAIBYhBEkvXzHmgOJBnwP4Wxnef3wVoM2y
238
+ BQJaanc8AhsMAAoJEBnef3wVoM2yX4wIALcYZbQhSEzCsTl56UHofze6C3QuFQIH
239
+ J4MIKrkTfwiHlCujv7GASGU2Vtis5YEyOoMidUVLlwnebE388MmaJYRm0fhYq6lP
240
+ A3vnOCcczy1tbo846bRdv012zdUA+wY+mOITdOoUjAhYulUR0kiA2UdLSfYzbWwy
241
+ 7Obq96Jb/cPRxk8jKUu2rqC/KDrkFDtAtjdIHh6nbbQhFuaRuWntISZgpIJxd8Bt
242
+ Gwi0imUVd9m9wZGuTbDGi6YTNk0GPpX5OMF5hjtM/objzTihSw9UN+65Y/oSQM81
243
+ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
244
+ =CDME
245
+ -----END PGP PUBLIC KEY BLOCK-----
246
+ ```
247
+
248
+ ### Known Vulnerabilities
249
+
250
+ At this time there are no known vulnerabilities in TensorFlow Models. For a list of known vulnerabilities and security advisories for TensorFlow,
251
+ [click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).
models/community/README.md ADDED
@@ -0,0 +1,60 @@
1
+ ![Logo](https://storage.googleapis.com/tf_model_garden/tf_model_garden_logo.png)
2
+
3
+ # TensorFlow Community Models
4
+
5
+ This repository provides a curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2.
6
+
7
+ **Note**: Contributing companies or individuals are responsible for maintaining their repositories.
8
+
9
+ ## Computer Vision
10
+
11
+ ### Image Recognition
12
+
13
+ | Model | Paper | Features | Maintainer |
14
+ |-------|-------|----------|------------|
15
+ | [DenseNet 169](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/densenet169) | [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993) | • FP32 Inference | [Intel](https://github.com/IntelAI) |
16
+ | [Inception V3](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/inceptionv3) | [Rethinking the Inception Architecture<br/>for Computer Vision](https://arxiv.org/pdf/1512.00567.pdf) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
17
+ | [Inception V4](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/inceptionv4) | [Inception-v4, Inception-ResNet and the Impact<br/>of Residual Connections on Learning](https://arxiv.org/pdf/1602.07261) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
18
+ | [MobileNet V1](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/mobilenet_v1) | [MobileNets: Efficient Convolutional Neural Networks<br/>for Mobile Vision Applications](https://arxiv.org/pdf/1704.04861) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
19
+ | [ResNet 101](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet101) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
20
+ | [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
21
+ | [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
22
+ | EfficientNet [v1](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Classification/ConvNets/efficientnet_v1) [v2](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Classification/ConvNets/efficientnet_v2) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/pdf/1905.11946.pdf) | • Automatic mixed precision<br/>• Horovod Multi-GPU training (NCCL)<br/>• Multi-node training on a Pyxis/Enroot Slurm cluster<br/>• XLA | [NVIDIA](https://github.com/NVIDIA) |
23
+
24
+ ### Object Detection
25
+
26
+ | Model | Paper | Features | Maintainer |
27
+ |-------|-------|----------|------------|
28
+ | [R-FCN](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/rfcn) | [R-FCN: Object Detection<br/>via Region-based Fully Convolutional Networks](https://arxiv.org/pdf/1605.06409) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
29
+ | [SSD-MobileNet](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-mobilenet) | [MobileNets: Efficient Convolutional Neural Networks<br/>for Mobile Vision Applications](https://arxiv.org/pdf/1704.04861) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
30
+ | [SSD-ResNet34](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-resnet34) | [SSD: Single Shot MultiBox Detector](https://arxiv.org/pdf/1512.02325) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
31
+
32
+ ### Segmentation
33
+
34
+ | Model | Paper | Features | Maintainer |
35
+ |-------|-------|----------|------------|
36
+ | [Mask R-CNN](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/MaskRCNN) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) |
37
+ | [U-Net Medical Image Segmentation](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/UNet_Medical) | [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) |
38
+
39
+ ## Natural Language Processing
40
+
41
+ | Model | Paper | Features | Maintainer |
42
+ |-------|-------|----------|------------|
43
+ | [BERT](https://github.com/IntelAI/models/tree/master/benchmarks/language_modeling/tensorflow/bert_large) | [BERT: Pre-training of Deep Bidirectional Transformers<br/>for Language Understanding](https://arxiv.org/pdf/1810.04805) | • FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
44
+ | [BERT](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/LanguageModeling/BERT) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805) | • Horovod Multi-GPU<br/>• Multi-node with Horovod and Pyxis/Enroot Slurm cluster<br/>• XLA<br/>• Automatic mixed precision<br/>• LAMB | [NVIDIA](https://github.com/NVIDIA) |
45
+ | [ELECTRA](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/LanguageModeling/ELECTRA) | [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://openreview.net/forum?id=r1xMH1BtvB) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• Multi-node training on a Pyxis/Enroot Slurm cluster | [NVIDIA](https://github.com/NVIDIA) |
46
+ | [GNMT](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/mlperf_gnmt) | [Google’s Neural Machine Translation System:<br/>Bridging the Gap between Human and Machine Translation](https://arxiv.org/pdf/1609.08144) | • FP32 Inference | [Intel](https://github.com/IntelAI) |
47
+ | [Transformer-LT (Official)](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/transformer_lt_official) | [Attention Is All You Need](https://arxiv.org/pdf/1706.03762) | • FP32 Inference | [Intel](https://github.com/IntelAI) |
48
+ | [Transformer-LT (MLPerf)](https://github.com/IntelAI/models/tree/master/benchmarks/language_translation/tensorflow/transformer_mlperf) | [Attention Is All You Need](https://arxiv.org/pdf/1706.03762) | • FP32 Training | [Intel](https://github.com/IntelAI) |
49
+
50
+ ## Recommendation Systems
51
+
52
+ | Model | Paper | Features | Maintainer |
53
+ |-------|-------|----------|------------|
54
+ | [Wide & Deep](https://github.com/IntelAI/models/tree/master/benchmarks/recommendation/tensorflow/wide_deep_large_ds) | [Wide & Deep Learning for Recommender Systems](https://arxiv.org/pdf/1606.07792) | • FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
55
+ | [Wide & Deep](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Recommendation/WideAndDeep) | [Wide & Deep Learning for Recommender Systems](https://arxiv.org/pdf/1606.07792) | • Automatic mixed precision<br/>• Multi-GPU training support with Horovod<br/>• XLA | [NVIDIA](https://github.com/NVIDIA) |
56
+ | [DLRM](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Recommendation/DLRM) | [Deep Learning Recommendation Model for Personalization and Recommendation Systems](https://arxiv.org/pdf/1906.00091.pdf) | • Automatic Mixed Precision<br/>• Hybrid-parallel multiGPU training using Horovod all2all<br/>• Multinode training for Pyxis/Enroot Slurm clusters<br/>• XLA<br/>• Criteo dataset preprocessing with Spark on GPU | [NVIDIA](https://github.com/NVIDIA) |
57
+
58
+ ## Contributions
59
+
60
+ If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
models/docs/README.md ADDED
@@ -0,0 +1,17 @@
1
+ # Public docs for TensorFlow Models
2
+
3
+ This directory contains the top-level public documentation for
4
+ [TensorFlow Models](https://github.com/tensorflow/models).
5
+
6
+ This directory is mirrored to https://tensorflow.org/tfmodels, and is mainly
7
+ concerned with documenting the tools provided in the `tensorflow_models` pip
8
+ package (including `orbit`).
9
+
10
+ API reference pages are
11
+ [available on the site](https://www.tensorflow.org/api_docs/more).
12
+
13
+ The
14
+ [Official Models](https://github.com/tensorflow/models/blob/master/official/projects)
15
+ and [Research Models](https://github.com/tensorflow/models/blob/master/research)
16
+ directories are not described in detail here; refer to the individual project
17
+ directories for more information.
models/docs/index.md ADDED
@@ -0,0 +1,140 @@
1
+ # Model Garden overview
2
+
3
+ The TensorFlow Model Garden provides implementations of many state-of-the-art
4
+ machine learning (ML) models for vision and natural language processing (NLP),
5
+ as well as workflow tools to let you quickly configure and run those models on
6
+ standard datasets. Whether you are looking to benchmark performance for a
7
+ well-known model, verify the results of recently released research, or extend
8
+ existing models, the Model Garden can help you drive your ML research and
9
+ applications forward.
10
+
11
+ The Model Garden includes the following resources for machine learning
12
+ developers:
13
+
14
+ - [**Official models**](#official) for vision and NLP, maintained by Google
15
+ engineers
16
+ - [**Research models**](#research) published as part of ML research papers
17
+ - [**Training experiment framework**](#training_framework) for fast,
18
+ declarative training configuration of official models
19
+ - [**Specialized ML operations**](#ops) for vision and natural language
20
+ processing (NLP)
21
+ - [**Model training loop**](#orbit) management with Orbit
22
+
23
+ These resources are built to be used with the TensorFlow Core framework and
24
+ integrate with your existing TensorFlow development projects. Model
25
+ Garden resources are also provided under an [open
26
+ source](https://github.com/tensorflow/models/blob/master/LICENSE) license, so
27
+ you can freely extend and distribute the models and tools.
28
+
29
+ Practical ML models are computationally intensive to train and run, and may
30
+ require accelerators such as Graphical Processing Units (GPUs) and Tensor
31
+ Processing Units (TPUs). Most of the models in Model Garden were trained on
32
+ large datasets using TPUs. However, you can also train and run these models on
33
+ GPU and CPU processors.
34
+
35
+ ## Model Garden models
36
+
37
+ The machine learning models in the Model Garden include full code so you can
38
+ test, train, or re-train them for research and experimentation. The Model Garden
39
+ includes two primary categories of models: *official models* and *research
40
+ models*.
41
+
42
+ ### Official models {:#official}
43
+
44
+ The [Official Models](https://github.com/tensorflow/models/tree/master/official)
45
+ repository is a collection of state-of-the-art models, with a focus on
46
+ vision and natural language processing (NLP).
47
+ These models are implemented using current TensorFlow 2.x high-level
48
+ APIs. Model libraries in this repository are optimized for fast performance and
49
+ actively maintained by Google engineers. The official models include additional
50
+ metadata you can use to quickly configure experiments using the Model Garden
51
+ [training experiment framework](#training_framework).
52
+
53
+ ### Research models {:#research}
54
+
55
+ The [Research Models](https://github.com/tensorflow/models/tree/master/research)
56
+ repository is a collection of models published as code resources for research
57
+ papers. These models are implemented using both TensorFlow 1.x and 2.x. Model
58
+ libraries in the research folder are supported by the code owners and the
59
+ research community.
60
+
61
+ ## Training experiment framework {:#training_framework}
62
+
63
+ The Model Garden training experiment framework lets you quickly assemble and run
64
+ training experiments using its official models and standard datasets. The
65
+ training framework uses additional metadata included with the Model Garden's
66
+ official models to allow you to configure models quickly using a declarative
67
+ programming model. You can define a training experiment using Python commands in
68
+ the
69
+ [TensorFlow Model library](https://www.tensorflow.org/api_docs/python/tfm/core)
70
+ or configure training using a YAML configuration file, like this
71
+ [example](https://github.com/tensorflow/models/blob/master/official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml).
72
+
73
+ The training framework uses
74
+ [`tfm.core.base_trainer.ExperimentConfig`](https://www.tensorflow.org/api_docs/python/tfm/core/base_trainer/ExperimentConfig)
75
+ as the configuration object, which contains the following top-level
76
+ configuration objects:
77
+
78
+ - [`runtime`](https://www.tensorflow.org/api_docs/python/tfm/core/base_task/RuntimeConfig):
79
+ Defines the processing hardware, distribution strategy, and other
80
+ performance optimizations
81
+ - [`task`](https://www.tensorflow.org/api_docs/python/tfm/core/config_definitions/TaskConfig):
82
+ Defines the model, training data, losses, and initialization
83
+ - [`trainer`](https://www.tensorflow.org/api_docs/python/tfm/core/base_trainer/TrainerConfig):
84
+ Defines the optimizer, training loops, evaluation loops, summaries, and
85
+ checkpoints
86
+
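+ For instance, a registered experiment can be loaded and adjusted programmatically
+ (a minimal sketch; the `resnet_imagenet` experiment name and the
+ `tfm.core.exp_factory.get_exp_config` helper are assumed to be available in your
+ installed release):
+
+ ```python
+ import tensorflow_models as tfm
+
+ # Load the declarative configuration registered for an official model.
+ exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')
+
+ # Adjust the nested runtime/task/trainer fields before launching training.
+ exp_config.runtime.distribution_strategy = 'mirrored'
+ exp_config.task.train_data.global_batch_size = 64
+ exp_config.trainer.train_steps = 100
+ ```
+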
87
+ For a complete example using the Model Garden training experiment framework, see
88
+ the [Image classification with Model Garden](vision/image_classification.ipynb)
89
+ tutorial. For information on the training experiment framework, check out the
90
+ [TensorFlow Models API documentation](https://tensorflow.org/api_docs/python/tfm/core).
91
+ If you are looking for a solution to manage training loops for your model
92
+ training experiments, check out [Orbit](#orbit).
93
+
94
+ ## Specialized ML operations {:#ops}
95
+
96
+ The Model Garden contains many vision and NLP operations specifically designed
97
+ to execute state-of-the-art models that run efficiently on GPUs and TPUs. Review
98
+ the TensorFlow Models Vision library API docs for a list of specialized
99
+ [vision operations](https://www.tensorflow.org/api_docs/python/tfm/vision).
100
+ Review the TensorFlow Models NLP Library API docs for a list of
101
+ [NLP operations](https://www.tensorflow.org/api_docs/python/tfm/nlp). These
102
+ libraries also include additional utility functions used for vision and NLP data
103
+ processing, training, and model execution.
104
+
105
+ ## Training loops with Orbit {:#orbit}
106
+
107
+ There are two default options for training TensorFlow models:
108
+
109
+ * Use the high-level Keras
110
+ [Model.fit](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit)
111
+ function. If your model and training procedure fit the assumptions of Keras'
112
+ `Model.fit` method (incremental gradient descent on batches of data), this can
113
+ be very convenient.
114
+ * Write a custom training loop
115
+ [with keras](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch),
116
+ or [without](https://www.tensorflow.org/guide/core/logistic_regression_core).
117
+ You can write a custom training loop with low-level TensorFlow methods such as
118
+ `tf.GradientTape` or `tf.function`. However, this approach requires a lot of
119
+ boilerplate code, and doesn't do anything to simplify distributed training.
120
+
121
+ Orbit tries to provide a third option in between these two extremes.
122
+
123
+ Orbit is a flexible, lightweight library designed to make it easier to
124
+ write custom training loops in TensorFlow 2.x, and works well with the Model
125
+ Garden [training experiment framework](#training_framework). Orbit handles
126
+ common model training tasks such as saving checkpoints, running model
127
+ evaluations, and setting up summary writing. It seamlessly integrates with
128
+ `tf.distribute` and supports running on different device types, including CPU,
129
+ GPU, and TPU hardware. The Orbit tool is also [open
130
+ source](https://github.com/tensorflow/models/blob/master/orbit/LICENSE), so you
131
+ can extend and adapt it to your model training needs.
132
+
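+ As a rough sketch of what an Orbit training loop looks like (assuming the `orbit`
+ package's `StandardTrainer` and `Controller` interfaces; the model, dataset, and
+ loss below are placeholders, not a complete recipe):
+
+ ```python
+ import orbit
+ import tensorflow as tf
+
+ class MyTrainer(orbit.StandardTrainer):
+   """Orbit drives the outer loop; we only define what one training step does."""
+
+   def __init__(self, train_dataset, model, optimizer):
+     self.model = model
+     self.optimizer = optimizer
+     self.global_step = optimizer.iterations
+     self.train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
+     super().__init__(train_dataset)
+
+   def train_loop_begin(self):
+     self.train_loss.reset_states()
+
+   def train_step(self, iterator):
+     def step_fn(inputs):
+       features, labels = inputs
+       with tf.GradientTape() as tape:
+         logits = self.model(features, training=True)
+         loss = tf.reduce_mean(
+             tf.keras.losses.sparse_categorical_crossentropy(
+                 labels, logits, from_logits=True))
+       grads = tape.gradient(loss, self.model.trainable_variables)
+       self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
+       self.train_loss.update_state(loss)
+     tf.distribute.get_strategy().run(step_fn, args=(next(iterator),))
+
+   def train_loop_end(self):
+     return {'training_loss': self.train_loss.result()}
+
+ # trainer = MyTrainer(train_dataset, model, optimizer)
+ # controller = orbit.Controller(trainer=trainer,
+ #                               global_step=trainer.global_step,
+ #                               steps_per_loop=100)
+ # controller.train(steps=1000)
+ ```
+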
133
+ The Orbit guide is available [here](orbit/index.ipynb).
134
+
135
+ Note: You can customize how the Keras API executes training. To do so, you can
136
+ override the `Model.train_step` method or use `keras.callbacks` like
137
+ `callbacks.ModelCheckpoint` or `callbacks.TensorBoard`. For more information
138
+ about modifying the behavior of `train_step`, check out the
139
+ [Customize what happens in Model.fit](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit)
140
+ page.
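+
+ For reference, overriding `train_step` follows the pattern below (a generic Keras
+ sketch, not specific to Model Garden models):
+
+ ```python
+ import tensorflow as tf
+
+ class CustomModel(tf.keras.Model):
+
+   def train_step(self, data):
+     x, y = data
+     with tf.GradientTape() as tape:
+       y_pred = self(x, training=True)
+       loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
+     grads = tape.gradient(loss, self.trainable_variables)
+     self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
+     self.compiled_metrics.update_state(y, y_pred)
+     return {m.name: m.result() for m in self.metrics}
+ ```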
models/docs/nlp/_guide_toc.yaml ADDED
@@ -0,0 +1,9 @@
1
+ toc:
2
+ - heading: TensorFlow Models - NLP
3
+ style: divider
4
+ - title: "Overview"
5
+ path: /tfmodels/nlp
6
+ - title: "Customize a transformer encoder"
7
+ path: /tfmodels/nlp/customize_encoder
8
+ - title: "Load LM checkpoints"
9
+ path: /tfmodels/nlp/load_lm_ckpts
models/docs/nlp/customize_encoder.ipynb ADDED
@@ -0,0 +1,596 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "Bp8t2AI8i7uP"
7
+ },
8
+ "source": [
9
+ "##### Copyright 2022 The TensorFlow Authors."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "cellView": "form",
17
+ "id": "rxPj2Lsni9O4"
18
+ },
19
+ "outputs": [],
20
+ "source": [
21
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
22
+ "# you may not use this file except in compliance with the License.\n",
23
+ "# You may obtain a copy of the License at\n",
24
+ "#\n",
25
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
26
+ "#\n",
27
+ "# Unless required by applicable law or agreed to in writing, software\n",
28
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
29
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
30
+ "# See the License for the specific language governing permissions and\n",
31
+ "# limitations under the License."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {
37
+ "id": "6xS-9i5DrRvO"
38
+ },
39
+ "source": [
40
+ "# Customizing a Transformer Encoder"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "metadata": {
46
+ "id": "Mwb9uw1cDXsa"
47
+ },
48
+ "source": [
49
+ "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
50
+ " <td>\n",
51
+ " <a target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/customize_encoder\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
52
+ " </td>\n",
53
+ " <td>\n",
54
+ " <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
55
+ " </td>\n",
56
+ " <td>\n",
57
+ " <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
58
+ " </td>\n",
59
+ " <td>\n",
60
+ " <a href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
61
+ " </td>\n",
62
+ "</table>"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "markdown",
67
+ "metadata": {
68
+ "id": "iLrcV4IyrcGX"
69
+ },
70
+ "source": [
71
+ "## Learning objectives\n",
72
+ "\n",
73
+ "The [TensorFlow Models NLP library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling) is a collection of tools for building and training modern high performance natural language models.\n",
74
+ "\n",
75
+ "The `tfm.nlp.networks.EncoderScaffold` is the core of this library, and many new network architectures have been proposed to improve the encoder. In this Colab notebook, we will learn how to customize the encoder to employ new network architectures."
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "markdown",
80
+ "metadata": {
81
+ "id": "YYxdyoWgsl8t"
82
+ },
83
+ "source": [
84
+ "## Install and import"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "markdown",
89
+ "metadata": {
90
+ "id": "fEJSFutUsn_h"
91
+ },
92
+ "source": [
93
+ "### Install the TensorFlow Model Garden pip package\n",
94
+ "\n",
95
+ "* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` GitHub repo. To include the latest changes, you may install `tf-models-nightly`,\n",
96
+ "which is the nightly Model Garden package created daily automatically.\n",
97
+ "* `pip` will install all models and dependencies automatically."
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "metadata": {
104
+ "id": "mfHI5JyuJ1y9"
105
+ },
106
+ "outputs": [],
107
+ "source": [
108
+ "!pip install -q opencv-python"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "metadata": {
115
+ "id": "thsKZDjhswhR"
116
+ },
117
+ "outputs": [],
118
+ "source": [
119
+ "!pip install -q tf-models-official"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "markdown",
124
+ "metadata": {
125
+ "id": "hpf7JPCVsqtv"
126
+ },
127
+ "source": [
128
+ "### Import TensorFlow and other libraries"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "metadata": {
135
+ "id": "my4dp-RMssQe"
136
+ },
137
+ "outputs": [],
138
+ "source": [
139
+ "import numpy as np\n",
140
+ "import tensorflow as tf\n",
141
+ "\n",
142
+ "import tensorflow_models as tfm\n",
143
+ "nlp = tfm.nlp"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "markdown",
148
+ "metadata": {
149
+ "id": "vjDmVsFfs85n"
150
+ },
151
+ "source": [
152
+ "## Canonical BERT encoder\n",
153
+ "\n",
154
+ "Before learning how to customize the encoder, let's first create a canonical BERT encoder and use it to instantiate a `bert_classifier.BertClassifier` for a classification task."
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "metadata": {
161
+ "id": "Oav8sbgstWc-"
162
+ },
163
+ "outputs": [],
164
+ "source": [
165
+ "cfg = {\n",
166
+ " \"vocab_size\": 100,\n",
167
+ " \"hidden_size\": 32,\n",
168
+ " \"num_layers\": 3,\n",
169
+ " \"num_attention_heads\": 4,\n",
170
+ " \"intermediate_size\": 64,\n",
171
+ " \"activation\": tfm.utils.activations.gelu,\n",
172
+ " \"dropout_rate\": 0.1,\n",
173
+ " \"attention_dropout_rate\": 0.1,\n",
174
+ " \"max_sequence_length\": 16,\n",
175
+ " \"type_vocab_size\": 2,\n",
176
+ " \"initializer\": tf.keras.initializers.TruncatedNormal(stddev=0.02),\n",
177
+ "}\n",
178
+ "bert_encoder = nlp.networks.BertEncoder(**cfg)\n",
179
+ "\n",
180
+ "def build_classifier(bert_encoder):\n",
181
+ " return nlp.models.BertClassifier(bert_encoder, num_classes=2)\n",
182
+ "\n",
183
+ "canonical_classifier_model = build_classifier(bert_encoder)"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "markdown",
188
+ "metadata": {
189
+ "id": "Qe2UWI6_tsHo"
190
+ },
191
+ "source": [
192
+ "`canonical_classifier_model` can be trained using the training data. For details about how to train the model, please see the [Fine tuning bert](https://www.tensorflow.org/text/tutorials/fine_tune_bert) notebook. We skip the code that trains the model here.\n",
193
+ "\n",
194
+ "After training, we can apply the model to do prediction.\n"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": null,
200
+ "metadata": {
201
+ "id": "csED2d-Yt5h6"
202
+ },
203
+ "outputs": [],
204
+ "source": [
205
+ "def predict(model):\n",
206
+ " batch_size = 3\n",
207
+ " np.random.seed(0)\n",
208
+ " word_ids = np.random.randint(\n",
209
+ " cfg[\"vocab_size\"], size=(batch_size, cfg[\"max_sequence_length\"]))\n",
210
+ " mask = np.random.randint(2, size=(batch_size, cfg[\"max_sequence_length\"]))\n",
211
+ " type_ids = np.random.randint(\n",
212
+ " cfg[\"type_vocab_size\"], size=(batch_size, cfg[\"max_sequence_length\"]))\n",
213
+ " print(model([word_ids, mask, type_ids], training=False))\n",
214
+ "\n",
215
+ "predict(canonical_classifier_model)"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "metadata": {
221
+ "id": "PzKStEK9t_Pb"
222
+ },
223
+ "source": [
224
+ "## Customize BERT encoder\n",
225
+ "\n",
226
+ "One BERT encoder consists of an embedding network and multiple transformer blocks, and each transformer block contains an attention layer and a feedforward layer."
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "markdown",
231
+ "metadata": {
232
+ "id": "rmwQfhj6fmKz"
233
+ },
234
+ "source": [
235
+ "We provide easy ways to customize each of those components via (1)\n",
236
+ "[EncoderScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py) and (2) [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py)."
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "markdown",
241
+ "metadata": {
242
+ "id": "xsMgEVHAui11"
243
+ },
244
+ "source": [
245
+ "### Use EncoderScaffold\n",
246
+ "\n",
247
+ "`networks.EncoderScaffold` allows users to provide a custom embedding subnetwork\n",
248
+ " (which will replace the standard embedding logic) and/or a custom hidden layer class (which will replace the `Transformer` instantiation in the encoder)."
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "markdown",
253
+ "metadata": {
254
+ "id": "-JBabpa2AOz8"
255
+ },
256
+ "source": [
257
+ "#### Without Customization\n",
258
+ "\n",
259
+ "Without any customization, `networks.EncoderScaffold` behaves the same as the canonical `networks.BertEncoder`.\n",
260
+ "\n",
261
+ "As shown in the following example, `networks.EncoderScaffold` can load `networks.BertEncoder`'s weights and output the same values:"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": null,
267
+ "metadata": {
268
+ "id": "ktNzKuVByZQf"
269
+ },
270
+ "outputs": [],
271
+ "source": [
272
+ "default_hidden_cfg = dict(\n",
273
+ " num_attention_heads=cfg[\"num_attention_heads\"],\n",
274
+ " intermediate_size=cfg[\"intermediate_size\"],\n",
275
+ " intermediate_activation=cfg[\"activation\"],\n",
276
+ " dropout_rate=cfg[\"dropout_rate\"],\n",
277
+ " attention_dropout_rate=cfg[\"attention_dropout_rate\"],\n",
278
+ " kernel_initializer=cfg[\"initializer\"],\n",
279
+ ")\n",
280
+ "default_embedding_cfg = dict(\n",
281
+ " vocab_size=cfg[\"vocab_size\"],\n",
282
+ " type_vocab_size=cfg[\"type_vocab_size\"],\n",
283
+ " hidden_size=cfg[\"hidden_size\"],\n",
284
+ " initializer=cfg[\"initializer\"],\n",
285
+ " dropout_rate=cfg[\"dropout_rate\"],\n",
286
+ " max_seq_length=cfg[\"max_sequence_length\"]\n",
287
+ ")\n",
288
+ "default_kwargs = dict(\n",
289
+ " hidden_cfg=default_hidden_cfg,\n",
290
+ " embedding_cfg=default_embedding_cfg,\n",
291
+ " num_hidden_instances=cfg[\"num_layers\"],\n",
292
+ " pooled_output_dim=cfg[\"hidden_size\"],\n",
293
+ " return_all_layer_outputs=True,\n",
294
+ " pooler_layer_initializer=cfg[\"initializer\"],\n",
295
+ ")\n",
296
+ "\n",
297
+ "encoder_scaffold = nlp.networks.EncoderScaffold(**default_kwargs)\n",
298
+ "classifier_model_from_encoder_scaffold = build_classifier(encoder_scaffold)\n",
299
+ "classifier_model_from_encoder_scaffold.set_weights(\n",
300
+ " canonical_classifier_model.get_weights())\n",
301
+ "predict(classifier_model_from_encoder_scaffold)"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "markdown",
306
+ "metadata": {
307
+ "id": "sMaUmLyIuwcs"
308
+ },
309
+ "source": [
310
+ "#### Customize Embedding\n",
311
+ "\n",
312
+ "Next, we show how to use a customized embedding network.\n",
313
+ "\n",
314
+ "We first build an embedding network that would replace the default network. This one will have 2 inputs (`mask` and `word_ids`) instead of 3, and won't use positional embeddings."
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "metadata": {
321
+ "id": "LTinnaG6vcsw"
322
+ },
323
+ "outputs": [],
324
+ "source": [
325
+ "word_ids = tf.keras.layers.Input(\n",
326
+ " shape=(cfg['max_sequence_length'],), dtype=tf.int32, name=\"input_word_ids\")\n",
327
+ "mask = tf.keras.layers.Input(\n",
328
+ " shape=(cfg['max_sequence_length'],), dtype=tf.int32, name=\"input_mask\")\n",
329
+ "embedding_layer = nlp.layers.OnDeviceEmbedding(\n",
330
+ " vocab_size=cfg['vocab_size'],\n",
331
+ " embedding_width=cfg['hidden_size'],\n",
332
+ " initializer=cfg[\"initializer\"],\n",
333
+ " name=\"word_embeddings\")\n",
334
+ "word_embeddings = embedding_layer(word_ids)\n",
335
+ "attention_mask = nlp.layers.SelfAttentionMask()([word_embeddings, mask])\n",
336
+ "new_embedding_network = tf.keras.Model([word_ids, mask],\n",
337
+ " [word_embeddings, attention_mask])"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "markdown",
342
+ "metadata": {
343
+ "id": "HN7_yu-6O3qI"
344
+ },
345
+ "source": [
346
+ "Inspecting `new_embedding_network`, we can see it takes two inputs:\n",
347
+ "`input_word_ids` and `input_mask`."
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "code",
352
+ "execution_count": null,
353
+ "metadata": {
354
+ "id": "fO9zKFE4OpHp"
355
+ },
356
+ "outputs": [],
357
+ "source": [
358
+ "tf.keras.utils.plot_model(new_embedding_network, show_shapes=True, dpi=48)"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "markdown",
363
+ "metadata": {
364
+ "id": "9cOaGQHLv12W"
365
+ },
366
+ "source": [
367
+ "We can then build a new encoder using the above `new_embedding_network`."
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": null,
373
+ "metadata": {
374
+ "id": "mtFDMNf2vIl9"
375
+ },
376
+ "outputs": [],
377
+ "source": [
378
+ "kwargs = dict(default_kwargs)\n",
379
+ "\n",
380
+ "# Use new embedding network.\n",
381
+ "kwargs['embedding_cls'] = new_embedding_network\n",
382
+ "kwargs['embedding_data'] = embedding_layer.embeddings\n",
383
+ "\n",
384
+ "encoder_with_customized_embedding = nlp.networks.EncoderScaffold(**kwargs)\n",
385
+ "classifier_model = build_classifier(encoder_with_customized_embedding)\n",
386
+ "# ... Train the model ...\n",
387
+ "print(classifier_model.inputs)\n",
388
+ "\n",
389
+ "# Assert that there are only two inputs.\n",
390
+ "assert len(classifier_model.inputs) == 2"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "markdown",
395
+ "metadata": {
396
+ "id": "Z73ZQDtmwg9K"
397
+ },
398
+ "source": [
399
+ "#### Customized Transformer\n",
400
+ "\n",
401
+ "Users can also override the `hidden_cls` argument in `networks.EncoderScaffold`'s constructor to employ a customized Transformer layer.\n",
402
+ "\n",
403
+ "See [the source of `nlp.layers.ReZeroTransformer`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/rezero_transformer.py) for how to implement a customized Transformer layer.\n",
404
+ "\n",
405
+ "The following is an example of using `nlp.layers.ReZeroTransformer`:\n"
406
+ ]
407
+ },
408
+ {
409
+ "cell_type": "code",
410
+ "execution_count": null,
411
+ "metadata": {
412
+ "id": "uAIarLZgw6pA"
413
+ },
414
+ "outputs": [],
415
+ "source": [
416
+ "kwargs = dict(default_kwargs)\n",
417
+ "\n",
418
+ "# Use ReZeroTransformer.\n",
419
+ "kwargs['hidden_cls'] = nlp.layers.ReZeroTransformer\n",
420
+ "\n",
421
+ "encoder_with_rezero_transformer = nlp.networks.EncoderScaffold(**kwargs)\n",
422
+ "classifier_model = build_classifier(encoder_with_rezero_transformer)\n",
423
+ "# ... Train the model ...\n",
424
+ "predict(classifier_model)\n",
425
+ "\n",
426
+ "# Assert that the variable `rezero_alpha` from ReZeroTransformer exists.\n",
427
+ "assert 'rezero_alpha' in ''.join([x.name for x in classifier_model.trainable_weights])"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "markdown",
432
+ "metadata": {
433
+ "id": "6PMHFdvnxvR0"
434
+ },
435
+ "source": [
436
+ "### Use `nlp.layers.TransformerScaffold`\n",
437
+ "\n",
438
+ "The above method of customizing the model requires rewriting the whole `nlp.layers.Transformer` layer, while sometimes you may only want to customize either the attention layer or the feedforward block. In this case, `nlp.layers.TransformerScaffold` can be used.\n"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "markdown",
443
+ "metadata": {
444
+ "id": "D6FejlgwyAy_"
445
+ },
446
+ "source": [
447
+ "#### Customize Attention Layer\n",
448
+ "\n",
449
+ "Users can also override the `attention_cls` argument in `layers.TransformerScaffold`'s constructor to employ a customized Attention layer.\n",
450
+ "\n",
451
+ "See [the source of `nlp.layers.TalkingHeadsAttention`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py) for how to implement a customized `Attention` layer.\n",
452
+ "\n",
453
+ "Following is an example of using `nlp.layers.TalkingHeadsAttention`:"
454
+ ]
455
+ },
456
+ {
457
+ "cell_type": "code",
458
+ "execution_count": null,
459
+ "metadata": {
460
+ "id": "nFrSMrZuyNeQ"
461
+ },
462
+ "outputs": [],
463
+ "source": [
464
+ "# Use TalkingHeadsAttention\n",
465
+ "hidden_cfg = dict(default_hidden_cfg)\n",
466
+ "hidden_cfg['attention_cls'] = nlp.layers.TalkingHeadsAttention\n",
467
+ "\n",
468
+ "kwargs = dict(default_kwargs)\n",
469
+ "kwargs['hidden_cls'] = nlp.layers.TransformerScaffold\n",
470
+ "kwargs['hidden_cfg'] = hidden_cfg\n",
471
+ "\n",
472
+ "encoder = nlp.networks.EncoderScaffold(**kwargs)\n",
473
+ "classifier_model = build_classifier(encoder)\n",
474
+ "# ... Train the model ...\n",
475
+ "predict(classifier_model)\n",
476
+ "\n",
477
+ "# Assert that the variable `pre_softmax_weight` from TalkingHeadsAttention exists.\n",
478
+ "assert 'pre_softmax_weight' in ''.join([x.name for x in classifier_model.trainable_weights])"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": null,
484
+ "metadata": {
485
+ "id": "tKkZ8spzYmpc"
486
+ },
487
+ "outputs": [],
488
+ "source": [
489
+ "tf.keras.utils.plot_model(encoder_with_rezero_transformer, show_shapes=True, dpi=48)"
490
+ ]
491
+ },
492
+ {
493
+ "cell_type": "markdown",
494
+ "metadata": {
495
+ "id": "kuEJcTyByVvI"
496
+ },
497
+ "source": [
498
+ "#### Customize Feedforward Layer\n",
499
+ "\n",
500
+ "Similarly, one could also customize the feedforward layer.\n",
501
+ "\n",
502
+ "See [the source of `nlp.layers.GatedFeedforward`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/gated_feedforward.py) for how to implement a customized feedforward layer.\n",
503
+ "\n",
504
+ "Following is an example of using `nlp.layers.GatedFeedforward`:"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": null,
510
+ "metadata": {
511
+ "id": "XAbKy_l4y_-i"
512
+ },
513
+ "outputs": [],
514
+ "source": [
515
+ "# Use GatedFeedforward\n",
516
+ "hidden_cfg = dict(default_hidden_cfg)\n",
517
+ "hidden_cfg['feedforward_cls'] = nlp.layers.GatedFeedforward\n",
518
+ "\n",
519
+ "kwargs = dict(default_kwargs)\n",
520
+ "kwargs['hidden_cls'] = nlp.layers.TransformerScaffold\n",
521
+ "kwargs['hidden_cfg'] = hidden_cfg\n",
522
+ "\n",
523
+ "encoder_with_gated_feedforward = nlp.networks.EncoderScaffold(**kwargs)\n",
524
+ "classifier_model = build_classifier(encoder_with_gated_feedforward)\n",
525
+ "# ... Train the model ...\n",
526
+ "predict(classifier_model)\n",
527
+ "\n",
528
+ "# Assert that the variable `gate` from GatedFeedforward exists.\n",
529
+ "assert 'gate' in ''.join([x.name for x in classifier_model.trainable_weights])"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "markdown",
534
+ "metadata": {
535
+ "id": "a_8NWUhkzeAq"
536
+ },
537
+ "source": [
538
+ "### Build a new Encoder\n",
539
+ "\n",
540
+ "Finally, you could also build a new encoder using building blocks in the modeling library.\n",
541
+ "\n",
542
+ "See [the source for `nlp.networks.AlbertEncoder`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/albert_encoder.py) as an example of how to do this. \n",
543
+ "\n",
544
+ "Here is an example using `nlp.networks.AlbertEncoder`:\n"
545
+ ]
546
+ },
547
+ {
548
+ "cell_type": "code",
549
+ "execution_count": null,
550
+ "metadata": {
551
+ "id": "xsiA3RzUzmUM"
552
+ },
553
+ "outputs": [],
554
+ "source": [
555
+ "albert_encoder = nlp.networks.AlbertEncoder(**cfg)\n",
556
+ "classifier_model = build_classifier(albert_encoder)\n",
557
+ "# ... Train the model ...\n",
558
+ "predict(classifier_model)"
559
+ ]
560
+ },
561
+ {
562
+ "cell_type": "markdown",
563
+ "metadata": {
564
+ "id": "MeidDfhlHKSO"
565
+ },
566
+ "source": [
567
+ "Inspecting the `albert_encoder`, we see it stacks the same `Transformer` layer multiple times (note the loop-back on the \"Transformer\" block below.."
568
+ ]
569
+ },
570
+ {
571
+ "cell_type": "code",
572
+ "execution_count": null,
573
+ "metadata": {
574
+ "id": "Uv_juT22HERW"
575
+ },
576
+ "outputs": [],
577
+ "source": [
578
+ "tf.keras.utils.plot_model(albert_encoder, show_shapes=True, dpi=48)"
579
+ ]
580
+ }
581
+ ],
582
+ "metadata": {
583
+ "colab": {
584
+ "collapsed_sections": [],
585
+ "name": "customize_encoder.ipynb",
586
+ "provenance": [],
587
+ "toc_visible": true
588
+ },
589
+ "kernelspec": {
590
+ "display_name": "Python 3",
591
+ "name": "python3"
592
+ }
593
+ },
594
+ "nbformat": 4,
595
+ "nbformat_minor": 0
596
+ }
models/docs/nlp/decoding_api.ipynb ADDED
@@ -0,0 +1,482 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "vXLA5InzXydn"
7
+ },
8
+ "source": [
9
+ "##### Copyright 2021 The TensorFlow Authors."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "cellView": "form",
17
+ "id": "RuRlpLL-X0R_"
18
+ },
19
+ "outputs": [],
20
+ "source": [
21
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
22
+ "# you may not use this file except in compliance with the License.\n",
23
+ "# You may obtain a copy of the License at\n",
24
+ "#\n",
25
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
26
+ "#\n",
27
+ "# Unless required by applicable law or agreed to in writing, software\n",
28
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
29
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
30
+ "# See the License for the specific language governing permissions and\n",
31
+ "# limitations under the License."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {
37
+ "id": "2X-XaMSVcLua"
38
+ },
39
+ "source": [
40
+ "# Decoding API"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "metadata": {
46
+ "id": "hYEwGTeCXnnX"
47
+ },
48
+ "source": [
49
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
50
+ " \u003ctd\u003e\n",
51
+ " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/decoding_api\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
52
+ " \u003c/td\u003e\n",
53
+ " \u003ctd\u003e\n",
54
+ " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/decoding_api.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
55
+ " \u003c/td\u003e\n",
56
+ " \u003ctd\u003e\n",
57
+ " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/decoding_api.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
58
+ " \u003c/td\u003e\n",
59
+ " \u003ctd\u003e\n",
60
+ " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/decoding_api.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
61
+ " \u003c/td\u003e\n",
62
+ "\u003c/table\u003e"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "markdown",
67
+ "metadata": {
68
+ "id": "fsACVQpVSifi"
69
+ },
70
+ "source": [
71
+ "### Install the TensorFlow Model Garden pip package\n",
72
+ "\n",
73
+ "* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` github repo. To include latest changes, you may install `tf-models-nightly`,\n",
74
+ "which is the nightly Model Garden package created daily automatically.\n",
75
+ "* pip will install all models and dependencies automatically."
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {
82
+ "id": "G4BhAu01HZcM"
83
+ },
84
+ "outputs": [],
85
+ "source": [
86
+ "!pip uninstall -y opencv-python"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": null,
92
+ "metadata": {
93
+ "id": "2j-xhrsVQOQT"
94
+ },
95
+ "outputs": [],
96
+ "source": [
97
+ "!pip install tf-models-official"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "metadata": {
104
+ "id": "BjP7zwxmskpY"
105
+ },
106
+ "outputs": [],
107
+ "source": [
108
+ "import os\n",
109
+ "\n",
110
+ "import numpy as np\n",
111
+ "import matplotlib.pyplot as plt\n",
112
+ "\n",
113
+ "import tensorflow as tf\n",
114
+ "\n",
115
+ "from tensorflow_models import nlp"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "metadata": {
122
+ "id": "T92ccAzlnGqh"
123
+ },
124
+ "outputs": [],
125
+ "source": [
126
+ "def length_norm(length, dtype):\n",
127
+ " \"\"\"Return length normalization factor.\"\"\"\n",
128
+ " return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "markdown",
133
+ "metadata": {
134
+ "id": "0AWgyo-IQ5sP"
135
+ },
136
+ "source": [
137
+ "## Overview\n",
138
+ "\n",
139
+ "This API provides an interface to experiment with different decoding strategies used for auto-regressive models.\n",
140
+ "\n",
141
+ "1. The following sampling strategies are provided in sampling_module.py, which inherits from the base Decoding class:\n",
142
+ " * [top_p](https://arxiv.org/abs/1904.09751) : [github](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/sampling_module.py#L65) \n",
143
+ "\n",
144
+ " This implementation chooses the most probable logits with cumulative probabilities up to top_p.\n",
145
+ "\n",
146
+ " * [top_k](https://arxiv.org/pdf/1805.04833.pdf) : [github](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/sampling_module.py#L48)\n",
147
+ "\n",
148
+ " At each timestep, this implementation samples from top-k logits based on their probability distribution\n",
149
+ "\n",
150
+ " * Greedy : [github](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/sampling_module.py#L26)\n",
151
+ "\n",
152
+ " This implementation returns the top logits based on probabilities.\n",
153
+ "\n",
154
+ "2. Beam search is provided in beam_search.py. [github](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/beam_search.py)\n",
155
+ "\n",
156
+ " This implementation reduces the risk of missing hidden high probability logits by keeping the most likely num_beams of logits at each time step and eventually choosing the logits that has the overall highest probability."
157
+ ]
158
+ },
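To make the top-k and top-p descriptions above concrete, here is a small standalone sketch (plain TensorFlow, independent of the `sampling_module` API used later) of the filtering step applied to a single step's logits. The helper names `top_k_filter` and `top_p_filter` and the toy logit values are illustrative only, not part of the library.

```python
import tensorflow as tf

logits = tf.constant([2.0, 1.0, 0.1, -1.0])  # unnormalized scores for a toy 4-token vocabulary

def top_k_filter(logits, k):
    """Keep the k largest logits; mask the rest so they cannot be sampled."""
    kth_best = tf.math.top_k(logits, k=k).values[-1]
    return tf.where(logits < kth_best, tf.fill(tf.shape(logits), -1e9), logits)

def top_p_filter(logits, p):
    """Keep the smallest set of tokens whose cumulative probability exceeds p."""
    sorted_idx = tf.argsort(logits, direction='DESCENDING')
    sorted_logits = tf.gather(logits, sorted_idx)
    cumulative = tf.cumsum(tf.nn.softmax(sorted_logits))
    # Always keep at least the most probable token.
    keep_sorted = tf.concat([[True], cumulative[:-1] < p], axis=0)
    # Undo the sort so the mask lines up with the original token order.
    keep = tf.gather(keep_sorted, tf.argsort(sorted_idx))
    return tf.where(keep, logits, tf.fill(tf.shape(logits), -1e9))

print(tf.nn.softmax(top_k_filter(logits, k=2)))   # mass redistributed over the 2 best tokens
print(tf.nn.softmax(top_p_filter(logits, p=0.8)))  # mass over the smallest set covering > 80%
```

The real `sampling_module` applies the same idea per batch element and per decoding step; this sketch only shows what the filtering does to one logits vector.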
159
+ {
160
+ "cell_type": "markdown",
161
+ "metadata": {
162
+ "id": "MfOj7oaBRQnS"
163
+ },
164
+ "source": [
165
+ "## Initialize Sampling Module in TF-NLP.\n",
166
+ "\n",
167
+ "\n",
168
+ "\u003e **symbols_to_logits_fn** : This is a closure implemented by the users of the API. The input to this closure will be \n",
169
+ "```\n",
170
+ "Args:\n",
171
+ " 1] ids [batch_size, .. (index + 1 or 1 if padded_decode is True)],\n",
172
+ " 2] index [scalar] : current decoded step,\n",
173
+ " 3] cache [nested dictionary of tensors].\n",
174
+ "Returns:\n",
175
+ " 1] tensor for next-step logits [batch_size, vocab]\n",
176
+ " 2] the updated_cache [nested dictionary of tensors].\n",
177
+ "```\n",
178
+ "This closure calls the model to predict the logits for the 'index+1' step. The cache is used for faster decoding.\n",
179
+ "Here is a [reference](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/beam_search_test.py#L88) implementation for the above closure.\n",
180
+ "\n",
181
+ "\n",
182
+ "\u003e **length_normalization_fn** : Closure for returning length normalization parameter.\n",
183
+ "```\n",
184
+ "Args: \n",
185
+ " 1] length : scalar for decoded step index.\n",
186
+ " 2] dtype : data-type of output tensor\n",
187
+ "Returns:\n",
188
+ " 1] value of length normalization factor.\n",
189
+ "Example :\n",
190
+ " def _length_norm(length, dtype):\n",
191
+ " return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)\n",
192
+ "```\n",
193
+ "\n",
194
+ "\u003e **vocab_size** : Output vocabulary size.\n",
195
+ "\n",
196
+ "\u003e **max_decode_length** : Scalar for total number of decoding steps.\n",
197
+ "\n",
198
+ "\u003e **eos_id** : Decoding will stop if all output decoded ids in the batch have this ID.\n",
199
+ "\n",
200
+ "\u003e **padded_decode** : Set this to True if running on TPU. Tensors are padded to max_decoding_length if this is True.\n",
201
+ "\n",
202
+ "\u003e **top_k** : top_k is enabled if this value is \u003e 1.\n",
203
+ "\n",
204
+ "\u003e **top_p** : top_p is enabled if this value is \u003e 0 and \u003c 1.0\n",
205
+ "\n",
206
+ "\u003e **sampling_temperature** : This is used to re-estimate the softmax output. Temperature skews the distribution towards high-probability tokens and lowers the mass in the tail distribution. Value has to be positive. Low temperature is equivalent to greedy and makes the distribution sharper, while high temperature makes it flatter.\n",
207
+ "\n",
208
+ "\u003e **enable_greedy** : By default, this is true and greedy decoding is enabled.\n"
209
+ ]
210
+ },
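The effect of `sampling_temperature` can be seen with a short standalone sketch (the logit values here are made up): dividing the logits by the temperature before the softmax sharpens the distribution for T < 1 and flattens it for T > 1.

```python
import tensorflow as tf

logits = tf.constant([2.0, 1.0, 0.1])

for temperature in [0.5, 1.0, 2.0]:
    # Scale the logits by 1/T before normalizing.
    probs = tf.nn.softmax(logits / temperature)
    print(f"T={temperature}: {probs.numpy().round(3)}")

# T < 1 concentrates mass on the highest logit (closer to greedy);
# T > 1 flattens the distribution and gives tail tokens more chance of being sampled.
```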
211
+ {
212
+ "cell_type": "markdown",
213
+ "metadata": {
214
+ "id": "lV1RRp6ihnGX"
215
+ },
216
+ "source": [
217
+ "## Initialize the Model Hyper-parameters"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": null,
223
+ "metadata": {
224
+ "id": "eTsGp2gaKLdE"
225
+ },
226
+ "outputs": [],
227
+ "source": [
228
+ "params = {\n",
229
+ " 'num_heads': 2,\n",
230
+ " 'num_layers': 2,\n",
231
+ " 'batch_size': 2,\n",
232
+ " 'n_dims': 256,\n",
233
+ " 'max_decode_length': 4}"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "markdown",
238
+ "metadata": {
239
+ "id": "CYXkoplAij01"
240
+ },
241
+ "source": [
242
+ "## Initialize cache. "
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "markdown",
247
+ "metadata": {
248
+ "id": "UGvmd0_dRFYI"
249
+ },
250
+ "source": [
251
+ "In auto-regressive architectures like Transformer based [Encoder-Decoder](https://arxiv.org/abs/1706.03762) models, \n",
252
+ "Cache is used for fast sequential decoding.\n",
253
+ "It is a nested dictionary storing pre-computed hidden-states (key and values in the self-attention blocks and the cross-attention blocks) for every layer."
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "metadata": {
260
+ "id": "D6kfZOOKgkm1"
261
+ },
262
+ "outputs": [],
263
+ "source": [
264
+ "cache = {\n",
265
+ " 'layer_%d' % layer: {\n",
266
+ " 'k': tf.zeros(\n",
267
+ " shape=[params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims'] // params['num_heads']],\n",
268
+ " dtype=tf.float32),\n",
269
+ " 'v': tf.zeros(\n",
270
+ " shape=[params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims'] // params['num_heads']],\n",
271
+ " dtype=tf.float32)\n",
272
+ " } for layer in range(params['num_layers'])\n",
273
+ " }\n",
274
+ "print(\"cache value shape for layer 1 :\", cache['layer_1']['k'].shape)"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "markdown",
279
+ "metadata": {
280
+ "id": "syl7I5nURPgW"
281
+ },
282
+ "source": [
283
+ "### Create model_fn\n",
284
+ " In practice, this will be replaced by an actual model implementation such as [here](https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer.py#L236)\n",
285
+ "```\n",
286
+ "Args:\n",
287
+ "i : Step that is being decoded.\n",
288
+ "Returns:\n",
289
+ " logit probabilities of size [batch_size, 1, vocab_size]\n",
290
+ "```\n"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": null,
296
+ "metadata": {
297
+ "id": "AhzSkRisRdB6"
298
+ },
299
+ "outputs": [],
300
+ "source": [
301
+ "probabilities = tf.constant([[[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],\n",
302
+ " [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],\n",
303
+ " [[0.2, 0.5, 0.3], [0.2, 0.7, 0.1],\n",
304
+ " [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]])\n",
305
+ "def model_fn(i):\n",
306
+ " return probabilities[:, i, :]"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "metadata": {
313
+ "id": "FAJ4CpbfVdjr"
314
+ },
315
+ "outputs": [],
316
+ "source": [
317
+ "def _symbols_to_logits_fn():\n",
318
+ " \"\"\"Calculates logits of the next tokens.\"\"\"\n",
319
+ " def symbols_to_logits_fn(ids, i, temp_cache):\n",
320
+ " del ids\n",
321
+ " logits = tf.cast(tf.math.log(model_fn(i)), tf.float32)\n",
322
+ " return logits, temp_cache\n",
323
+ " return symbols_to_logits_fn"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "markdown",
328
+ "metadata": {
329
+ "id": "R_tV3jyWVL47"
330
+ },
331
+ "source": [
332
+ "## Greedy \n",
333
+ "Greedy decoding selects the token id with the highest probability as its next id: $id_t = argmax_{w}P(id | id_{1:t-1})$ at each timestep $t$. The following sketch shows greedy decoding. "
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "metadata": {
340
+ "id": "aGt9idSkVQEJ"
341
+ },
342
+ "outputs": [],
343
+ "source": [
344
+ "greedy_obj = sampling_module.SamplingModule(\n",
345
+ " length_normalization_fn=None,\n",
346
+ " dtype=tf.float32,\n",
347
+ " symbols_to_logits_fn=_symbols_to_logits_fn(),\n",
348
+ " vocab_size=3,\n",
349
+ " max_decode_length=params['max_decode_length'],\n",
350
+ " eos_id=10,\n",
351
+ " padded_decode=False)\n",
352
+ "ids, _ = greedy_obj.generate(\n",
353
+ " initial_ids=tf.constant([9, 1]), initial_cache=cache)\n",
354
+ "print(\"Greedy Decoded Ids:\", ids)"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "markdown",
359
+ "metadata": {
360
+ "id": "s4pTTsQXVz5O"
361
+ },
362
+ "source": [
363
+ "## top_k sampling\n",
364
+ "In *Top-K* sampling, the *K* most likely next token ids are filtered and the probability mass is redistributed among only those *K* ids. "
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "code",
369
+ "execution_count": null,
370
+ "metadata": {
371
+ "id": "pCLWIn6GV5_G"
372
+ },
373
+ "outputs": [],
374
+ "source": [
375
+ "top_k_obj = sampling_module.SamplingModule(\n",
376
+ " length_normalization_fn=length_norm,\n",
377
+ " dtype=tf.float32,\n",
378
+ " symbols_to_logits_fn=_symbols_to_logits_fn(),\n",
379
+ " vocab_size=3,\n",
380
+ " max_decode_length=params['max_decode_length'],\n",
381
+ " eos_id=10,\n",
382
+ " sample_temperature=tf.constant(1.0),\n",
383
+ " top_k=tf.constant(3),\n",
384
+ " padded_decode=False,\n",
385
+ " enable_greedy=False)\n",
386
+ "ids, _ = top_k_obj.generate(\n",
387
+ " initial_ids=tf.constant([9, 1]), initial_cache=cache)\n",
388
+ "print(\"top-k sampled Ids:\", ids)"
389
+ ]
390
+ },
391
+ {
392
+ "cell_type": "markdown",
393
+ "metadata": {
394
+ "id": "Jp3G-eE_WI4Y"
395
+ },
396
+ "source": [
397
+ "## top_p sampling\n",
398
+ "Instead of sampling only from the most likely *K* token ids, in *Top-p* sampling chooses from the smallest possible set of ids whose cumulative probability exceeds the probability *p*."
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "execution_count": null,
404
+ "metadata": {
405
+ "id": "rEGdIWcuWILO"
406
+ },
407
+ "outputs": [],
408
+ "source": [
409
+ "top_p_obj = sampling_module.SamplingModule(\n",
410
+ " length_normalization_fn=length_norm,\n",
411
+ " dtype=tf.float32,\n",
412
+ " symbols_to_logits_fn=_symbols_to_logits_fn(),\n",
413
+ " vocab_size=3,\n",
414
+ " max_decode_length=params['max_decode_length'],\n",
415
+ " eos_id=10,\n",
416
+ " sample_temperature=tf.constant(1.0),\n",
417
+ " top_p=tf.constant(0.9),\n",
418
+ " padded_decode=False,\n",
419
+ " enable_greedy=False)\n",
420
+ "ids, _ = top_p_obj.generate(\n",
421
+ " initial_ids=tf.constant([9, 1]), initial_cache=cache)\n",
422
+ "print(\"top-p sampled Ids:\", ids)"
423
+ ]
424
+ },
425
+ {
426
+ "cell_type": "markdown",
427
+ "metadata": {
428
+ "id": "2hcuyJ2VWjDz"
429
+ },
430
+ "source": [
431
+ "## Beam search decoding\n",
432
+ "Beam search reduces the risk of missing hidden high probability token ids by keeping the most likely num_beams of hypotheses at each time step and eventually choosing the hypothesis that has the overall highest probability. "
433
+ ]
434
+ },
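As a conceptual aid only (not the library implementation), the following toy sketch shows the bookkeeping beam search performs: extend every kept hypothesis with every vocabulary token, accumulate log-probabilities, and prune back to `beam_size`. The per-step probabilities below are made up, and the sketch skips the caching, batching, and length normalization that `beam_search.sequence_beam_search` handles for real models.

```python
import numpy as np

log_probs_per_step = np.log([[0.6, 0.3, 0.1],   # step 0: P(token | prefix)
                             [0.1, 0.5, 0.4]])  # step 1 (same for every prefix, for simplicity)
beam_size = 2

# Step 0: start from an empty hypothesis with score 0 and keep the best `beam_size` tokens.
beams = sorted([([tok], score) for tok, score in enumerate(log_probs_per_step[0])],
               key=lambda b: b[1], reverse=True)[:beam_size]

# Step 1: extend every surviving beam with every token, then prune back to `beam_size`.
candidates = [(seq + [tok], score + log_probs_per_step[1][tok])
              for seq, score in beams
              for tok in range(len(log_probs_per_step[1]))]
beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_size]

for seq, score in beams:
    print(seq, round(float(np.exp(score)), 3))  # surviving hypotheses and their probabilities
```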
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": null,
438
+ "metadata": {
439
+ "id": "cJ3WzvSrWmSA"
440
+ },
441
+ "outputs": [],
442
+ "source": [
443
+ "beam_size = 2\n",
444
+ "params['batch_size'] = 1\n",
445
+ "beam_cache = {\n",
446
+ " 'layer_%d' % layer: {\n",
447
+ " 'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32),\n",
448
+ " 'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32)\n",
449
+ " } for layer in range(params['num_layers'])\n",
450
+ " }\n",
451
+ "print(\"cache key shape for layer 1 :\", beam_cache['layer_1']['k'].shape)\n",
452
+ "ids, _ = beam_search.sequence_beam_search(\n",
453
+ " symbols_to_logits_fn=_symbols_to_logits_fn(),\n",
454
+ " initial_ids=tf.constant([9], tf.int32),\n",
455
+ " initial_cache=beam_cache,\n",
456
+ " vocab_size=3,\n",
457
+ " beam_size=beam_size,\n",
458
+ " alpha=0.6,\n",
459
+ " max_decode_length=params['max_decode_length'],\n",
460
+ " eos_id=10,\n",
461
+ " padded_decode=False,\n",
462
+ " dtype=tf.float32)\n",
463
+ "print(\"Beam search ids:\", ids)"
464
+ ]
465
+ }
466
+ ],
467
+ "metadata": {
468
+ "accelerator": "GPU",
469
+ "colab": {
470
+ "collapsed_sections": [],
471
+ "name": "decoding_api_in_tf_nlp.ipynb",
472
+ "provenance": [],
473
+ "toc_visible": true
474
+ },
475
+ "kernelspec": {
476
+ "display_name": "Python 3",
477
+ "name": "python3"
478
+ }
479
+ },
480
+ "nbformat": 4,
481
+ "nbformat_minor": 0
482
+ }
models/docs/nlp/fine_tune_bert.ipynb ADDED
@@ -0,0 +1,1550 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "vXLA5InzXydn"
7
+ },
8
+ "source": [
9
+ "##### Copyright 2019 The TensorFlow Authors."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "cellView": "form",
17
+ "id": "RuRlpLL-X0R_"
18
+ },
19
+ "outputs": [],
20
+ "source": [
21
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
22
+ "# you may not use this file except in compliance with the License.\n",
23
+ "# You may obtain a copy of the License at\n",
24
+ "#\n",
25
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
26
+ "#\n",
27
+ "# Unless required by applicable law or agreed to in writing, software\n",
28
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
29
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
30
+ "# See the License for the specific language governing permissions and\n",
31
+ "# limitations under the License."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {
37
+ "id": "1mLJmVotXs64"
38
+ },
39
+ "source": [
40
+ "# Fine-tuning a BERT model"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "metadata": {
46
+ "id": "hYEwGTeCXnnX"
47
+ },
48
+ "source": [
49
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
50
+ " \u003ctd\u003e\n",
51
+ " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/fine_tune_bert\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
52
+ " \u003c/td\u003e\n",
53
+ " \u003ctd\u003e\n",
54
+ " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
55
+ " \u003c/td\u003e\n",
56
+ " \u003ctd\u003e\n",
57
+ " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
58
+ " \u003c/td\u003e\n",
59
+ " \u003ctd\u003e\n",
60
+ " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
61
+ " \u003c/td\u003e\n",
62
+ " \u003ctd\u003e\n",
63
+ " \u003ca href=\"https://tfhub.dev/google/collections/bert\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub model\u003c/a\u003e\n",
64
+ " \u003c/td\u003e\n",
65
+ "\u003c/table\u003e"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "metadata": {
71
+ "id": "YN2ACivEPxgD"
72
+ },
73
+ "source": [
74
+ "This tutorial demonstrates how to fine-tune a [Bidirectional Encoder Representations from Transformers (BERT)](https://arxiv.org/abs/1810.04805) (Devlin et al., 2018) model using [TensorFlow Model Garden](https://github.com/tensorflow/models).\n",
75
+ "\n",
76
+ "You can also find the pre-trained BERT model used in this tutorial on [TensorFlow Hub (TF Hub)](https://tensorflow.org/hub). For concrete examples of how to use the models from TF Hub, refer to the [Solve Glue tasks using BERT](https://www.tensorflow.org/text/tutorials/bert_glue) tutorial. If you're just trying to fine-tune a model, the TF Hub tutorial is a good starting point.\n",
77
+ "\n",
78
+ "On the other hand, if you're interested in deeper customization, follow this tutorial. It shows how to do a lot of things manually, so you can learn how you can customize the workflow from data preprocessing to training, exporting and saving the model."
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "metadata": {
84
+ "id": "s2d9S2CSSO1z"
85
+ },
86
+ "source": [
87
+ "## Setup"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "metadata": {
93
+ "id": "69de3375e32a"
94
+ },
95
+ "source": [
96
+ "### Install pip packages"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "markdown",
101
+ "metadata": {
102
+ "id": "fsACVQpVSifi"
103
+ },
104
+ "source": [
105
+ "Start by installing the TensorFlow Text and Model Garden pip packages.\n",
106
+ "\n",
107
+ "* `tf-models-official` is the TensorFlow Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` GitHub repo. To include the latest changes, you may install `tf-models-nightly`, which is the nightly Model Garden package created daily automatically.\n",
108
+ "* pip will install all models and dependencies automatically."
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "metadata": {
115
+ "id": "sE6XUxLOf1s-"
116
+ },
117
+ "outputs": [],
118
+ "source": [
119
+ "!pip install -q opencv-python"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": null,
125
+ "metadata": {
126
+ "id": "yic2y7_o-BCC"
127
+ },
128
+ "outputs": [],
129
+ "source": [
130
+ "!pip install -q -U \"tensorflow-text==2.11.*\""
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "metadata": {
137
+ "id": "NvNr2svBM-p3"
138
+ },
139
+ "outputs": [],
140
+ "source": [
141
+ "!pip install -q tf-models-official"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "markdown",
146
+ "metadata": {
147
+ "id": "U-7qPCjWUAyy"
148
+ },
149
+ "source": [
150
+ "### Import libraries"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "metadata": {
157
+ "id": "lXsXev5MNr20"
158
+ },
159
+ "outputs": [],
160
+ "source": [
161
+ "import os\n",
162
+ "\n",
163
+ "import numpy as np\n",
164
+ "import matplotlib.pyplot as plt\n",
165
+ "\n",
166
+ "import tensorflow as tf\n",
167
+ "import tensorflow_models as tfm\n",
168
+ "import tensorflow_hub as hub\n",
169
+ "import tensorflow_datasets as tfds\n",
170
+ "tfds.disable_progress_bar()"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "markdown",
175
+ "metadata": {
176
+ "id": "mbanlzTvJBsz"
177
+ },
178
+ "source": [
179
+ "### Resources"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "markdown",
184
+ "metadata": {
185
+ "id": "PpW0x8TpR8DT"
186
+ },
187
+ "source": [
188
+ "The following directory contains the BERT model's configuration, vocabulary, and a pre-trained checkpoint used in this tutorial:"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "metadata": {
195
+ "id": "vzRHOLciR8eq"
196
+ },
197
+ "outputs": [],
198
+ "source": [
199
+ "gs_folder_bert = \"gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12\"\n",
200
+ "tf.io.gfile.listdir(gs_folder_bert)"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "markdown",
205
+ "metadata": {
206
+ "id": "Qv6abtRvH4xO"
207
+ },
208
+ "source": [
209
+ "## Load and preprocess the dataset\n",
210
+ "\n",
211
+ "This example uses the GLUE (General Language Understanding Evaluation) MRPC (Microsoft Research Paraphrase Corpus) [dataset from TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc).\n",
212
+ "\n",
213
+ "This dataset is not set up such that it can be directly fed into the BERT model. The following section handles the necessary preprocessing."
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "markdown",
218
+ "metadata": {
219
+ "id": "28DvUhC1YUiB"
220
+ },
221
+ "source": [
222
+ "### Get the dataset from TensorFlow Datasets\n",
223
+ "\n",
224
+ "The GLUE MRPC (Dolan and Brockett, 2005) dataset is a corpus of sentence pairs automatically extracted from online news sources, with human annotations for whether the sentences in the pair are semantically equivalent. It has the following attributes:\n",
225
+ "\n",
226
+ "* Number of labels: 2\n",
227
+ "* Size of training dataset: 3668\n",
228
+ "* Size of evaluation dataset: 408\n",
229
+ "* Maximum sequence length of training and evaluation dataset: 128\n",
230
+ "\n",
231
+ "Begin by loading the MRPC dataset from TFDS:"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": null,
237
+ "metadata": {
238
+ "id": "Ijikx5OsH9AT"
239
+ },
240
+ "outputs": [],
241
+ "source": [
242
+ "batch_size=32\n",
243
+ "glue, info = tfds.load('glue/mrpc',\n",
244
+ " with_info=True,\n",
245
+ " batch_size=32)"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "metadata": {
252
+ "id": "QcMTJU4N7VX-"
253
+ },
254
+ "outputs": [],
255
+ "source": [
256
+ "glue"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "markdown",
261
+ "metadata": {
262
+ "id": "ZgBg2r2nYT-K"
263
+ },
264
+ "source": [
265
+ "The `info` object describes the dataset and its features:"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "metadata": {
272
+ "id": "IQrHxv7W7jH5"
273
+ },
274
+ "outputs": [],
275
+ "source": [
276
+ "info.features"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "markdown",
281
+ "metadata": {
282
+ "id": "vhsVWYNxazz5"
283
+ },
284
+ "source": [
285
+ "The two classes are:"
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": null,
291
+ "metadata": {
292
+ "id": "n0gfc_VTayfQ"
293
+ },
294
+ "outputs": [],
295
+ "source": [
296
+ "info.features['label'].names"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "markdown",
301
+ "metadata": {
302
+ "id": "38zJcap6xkbC"
303
+ },
304
+ "source": [
305
+ "Here is one example from the training set:"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "code",
310
+ "execution_count": null,
311
+ "metadata": {
312
+ "id": "xON_i6SkwApW"
313
+ },
314
+ "outputs": [],
315
+ "source": [
316
+ "example_batch = next(iter(glue['train']))\n",
317
+ "\n",
318
+ "for key, value in example_batch.items():\n",
319
+ " print(f\"{key:9s}: {value[0].numpy()}\")"
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "markdown",
324
+ "metadata": {
325
+ "id": "R9vEWgKA4SxV"
326
+ },
327
+ "source": [
328
+ "### Preprocess the data\n",
329
+ "\n",
330
+ "The keys `\"sentence1\"` and `\"sentence2\"` in the GLUE MRPC dataset contain two input sentences for each example.\n",
331
+ "\n",
332
+ "Because the BERT model from the Model Garden doesn't take raw text as input, two things need to happen first:\n",
333
+ "\n",
334
+ "1. The text needs to be _tokenized_ (split into word pieces) and converted to _indices_.\n",
335
+ "2. Then, the _indices_ need to be packed into the format that the model expects."
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "markdown",
340
+ "metadata": {
341
+ "id": "9fbTyfJpNr7x"
342
+ },
343
+ "source": [
344
+ "#### The BERT tokenizer"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "markdown",
349
+ "metadata": {
350
+ "id": "wqeN54S61ZKQ"
351
+ },
352
+ "source": [
353
+ "To fine tune a pre-trained language model from the Model Garden, such as BERT, you need to make sure that you're using exactly the same tokenization, vocabulary, and index mapping as used during training.\n",
354
+ "\n",
355
+ "The following code rebuilds the tokenizer that was used by the base model using the Model Garden's `tfm.nlp.layers.FastWordpieceBertTokenizer` layer:"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "metadata": {
362
+ "id": "-DK4q5wEBmlB"
363
+ },
364
+ "outputs": [],
365
+ "source": [
366
+ "tokenizer = tfm.nlp.layers.FastWordpieceBertTokenizer(\n",
367
+ " vocab_file=os.path.join(gs_folder_bert, \"vocab.txt\"),\n",
368
+ " lower_case=True)"
369
+ ]
370
+ },
371
+ {
372
+ "cell_type": "markdown",
373
+ "metadata": {
374
+ "id": "zYHDSquU2lDU"
375
+ },
376
+ "source": [
377
+ "Let's tokenize a test sentence:"
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "code",
382
+ "execution_count": null,
383
+ "metadata": {
384
+ "id": "L_OfOYPg853R"
385
+ },
386
+ "outputs": [],
387
+ "source": [
388
+ "tokens = tokenizer(tf.constant([\"Hello TensorFlow!\"]))\n",
389
+ "tokens"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "markdown",
394
+ "metadata": {
395
+ "id": "MfjaaMYy5Gt8"
396
+ },
397
+ "source": [
398
+ "Learn more about the tokenization process in the [Subword tokenization](https://www.tensorflow.org/text/guide/subwords_tokenizer) and [Tokenizing with TensorFlow Text](https://www.tensorflow.org/text/guide/tokenizers) guides."
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "markdown",
403
+ "metadata": {
404
+ "id": "wd1b09OO5GJl"
405
+ },
406
+ "source": [
407
+ "#### Pack the inputs"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "markdown",
412
+ "metadata": {
413
+ "id": "62UTWLQd9-LB"
414
+ },
415
+ "source": [
416
+ "TensorFlow Model Garden's BERT model doesn't just take the tokenized strings as input. It also expects these to be packed into a particular format. `tfm.nlp.layers.BertPackInputs` layer can handle the conversion from _a list of tokenized sentences_ to the input format expected by the Model Garden's BERT model.\n",
417
+ "\n",
418
+ "`tfm.nlp.layers.BertPackInputs` packs the two input sentences (per example in the MRCP dataset) concatenated together. This input is expected to start with a `[CLS]` \"This is a classification problem\" token, and each sentence should end with a `[SEP]` \"Separator\" token.\n",
419
+ "\n",
420
+ "Therefore, the `tfm.nlp.layers.BertPackInputs` layer's constructor takes the `tokenizer`'s special tokens as an argument. It also needs to know the indices of the tokenizer's special tokens."
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": null,
426
+ "metadata": {
427
+ "id": "5iroDlrFDRcF"
428
+ },
429
+ "outputs": [],
430
+ "source": [
431
+ "special = tokenizer.get_special_tokens_dict()\n",
432
+ "special"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": null,
438
+ "metadata": {
439
+ "id": "b71HarkuG92H"
440
+ },
441
+ "outputs": [],
442
+ "source": [
443
+ "max_seq_length = 128\n",
444
+ "\n",
445
+ "packer = tfm.nlp.layers.BertPackInputs(\n",
446
+ " seq_length=max_seq_length,\n",
447
+ " special_tokens_dict = tokenizer.get_special_tokens_dict())"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "markdown",
452
+ "metadata": {
453
+ "id": "CZlSZbYd6liN"
454
+ },
455
+ "source": [
456
+ "The `packer` takes a list of tokenized sentences as input. For example:"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": null,
462
+ "metadata": {
463
+ "id": "27dU_VkJHc9S"
464
+ },
465
+ "outputs": [],
466
+ "source": [
467
+ "sentences1 = [\"hello tensorflow\"]\n",
468
+ "tok1 = tokenizer(sentences1)\n",
469
+ "tok1"
470
+ ]
471
+ },
472
+ {
473
+ "cell_type": "code",
474
+ "execution_count": null,
475
+ "metadata": {
476
+ "id": "LURHmNOSHnWN"
477
+ },
478
+ "outputs": [],
479
+ "source": [
480
+ "sentences2 = [\"goodbye tensorflow\"]\n",
481
+ "tok2 = tokenizer(sentences2)\n",
482
+ "tok2"
483
+ ]
484
+ },
485
+ {
486
+ "cell_type": "markdown",
487
+ "metadata": {
488
+ "id": "r8bvB8gI8BqP"
489
+ },
490
+ "source": [
491
+ "Then, it returns a dictionary containing three outputs:\n",
492
+ "\n",
493
+ "- `input_word_ids`: The tokenized sentences packed together.\n",
494
+ "- `input_mask`: The mask indicating which locations are valid in the other outputs.\n",
495
+ "- `input_type_ids`: Indicating which sentence each token belongs to."
496
+ ]
497
+ },
498
+ {
499
+ "cell_type": "code",
500
+ "execution_count": null,
501
+ "metadata": {
502
+ "id": "YsIDTOMJHrUQ"
503
+ },
504
+ "outputs": [],
505
+ "source": [
506
+ "packed = packer([tok1, tok2])\n",
507
+ "\n",
508
+ "for key, tensor in packed.items():\n",
509
+ " print(f\"{key:15s}: {tensor[:, :12]}\")"
510
+ ]
511
+ },
512
+ {
513
+ "cell_type": "markdown",
514
+ "metadata": {
515
+ "id": "red4tRcq74Qc"
516
+ },
517
+ "source": [
518
+ "#### Put it all together\n",
519
+ "\n",
520
+ "Combine these two parts into a `keras.layers.Layer` that can be attached to your model:"
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "code",
525
+ "execution_count": null,
526
+ "metadata": {
527
+ "id": "9Qtz-tv-6nz6"
528
+ },
529
+ "outputs": [],
530
+ "source": [
531
+ "class BertInputProcessor(tf.keras.layers.Layer):\n",
532
+ " def __init__(self, tokenizer, packer):\n",
533
+ " super().__init__()\n",
534
+ " self.tokenizer = tokenizer\n",
535
+ " self.packer = packer\n",
536
+ "\n",
537
+ " def call(self, inputs):\n",
538
+ " tok1 = self.tokenizer(inputs['sentence1'])\n",
539
+ " tok2 = self.tokenizer(inputs['sentence2'])\n",
540
+ "\n",
541
+ " packed = self.packer([tok1, tok2])\n",
542
+ "\n",
543
+ " if 'label' in inputs:\n",
544
+ " return packed, inputs['label']\n",
545
+ " else:\n",
546
+ " return packed"
547
+ ]
548
+ },
549
+ {
550
+ "cell_type": "markdown",
551
+ "metadata": {
552
+ "id": "rdy9wp499btU"
553
+ },
554
+ "source": [
555
+ "But for now just apply it to the dataset using `Dataset.map`, since the dataset you loaded from TFDS is a `tf.data.Dataset` object:"
556
+ ]
557
+ },
558
+ {
559
+ "cell_type": "code",
560
+ "execution_count": null,
561
+ "metadata": {
562
+ "id": "qmyh76AL7VAs"
563
+ },
564
+ "outputs": [],
565
+ "source": [
566
+ "bert_inputs_processor = BertInputProcessor(tokenizer, packer)"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "code",
571
+ "execution_count": null,
572
+ "metadata": {
573
+ "id": "B8SSCtDe9MCk"
574
+ },
575
+ "outputs": [],
576
+ "source": [
577
+ "glue_train = glue['train'].map(bert_inputs_processor).prefetch(1)"
578
+ ]
579
+ },
580
+ {
581
+ "cell_type": "markdown",
582
+ "metadata": {
583
+ "id": "KXpiDosO9rkY"
584
+ },
585
+ "source": [
586
+ "Here is an example batch from the processed dataset:"
587
+ ]
588
+ },
589
+ {
590
+ "cell_type": "code",
591
+ "execution_count": null,
592
+ "metadata": {
593
+ "id": "ffNvDE6t9rP-"
594
+ },
595
+ "outputs": [],
596
+ "source": [
597
+ "example_inputs, example_labels = next(iter(glue_train))"
598
+ ]
599
+ },
600
+ {
601
+ "cell_type": "code",
602
+ "execution_count": null,
603
+ "metadata": {
604
+ "id": "5sxtTuUi-bXt"
605
+ },
606
+ "outputs": [],
607
+ "source": [
608
+ "example_inputs"
609
+ ]
610
+ },
611
+ {
612
+ "cell_type": "code",
613
+ "execution_count": null,
614
+ "metadata": {
615
+ "id": "wP4z_-9a-dFk"
616
+ },
617
+ "outputs": [],
618
+ "source": [
619
+ "example_labels"
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "code",
624
+ "execution_count": null,
625
+ "metadata": {
626
+ "id": "jyjTdGpFhO_1"
627
+ },
628
+ "outputs": [],
629
+ "source": [
630
+ "for key, value in example_inputs.items():\n",
631
+ " print(f'{key:15s} shape: {value.shape}')\n",
632
+ "\n",
633
+ "print(f'{\"labels\":15s} shape: {example_labels.shape}')"
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "markdown",
638
+ "metadata": {
639
+ "id": "mkGHN_FK-50U"
640
+ },
641
+ "source": [
642
+ "The `input_word_ids` contain the token IDs:"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "code",
647
+ "execution_count": null,
648
+ "metadata": {
649
+ "id": "eGL1_ktWLcgF"
650
+ },
651
+ "outputs": [],
652
+ "source": [
653
+ "plt.pcolormesh(example_inputs['input_word_ids'])"
654
+ ]
655
+ },
656
+ {
657
+ "cell_type": "markdown",
658
+ "metadata": {
659
+ "id": "ulNZ4U96-8JZ"
660
+ },
661
+ "source": [
662
+ "The mask allows the model to cleanly differentiate between the content and the padding. The mask has the same shape as the `input_word_ids`, and contains a `1` anywhere the `input_word_ids` is not padding."
663
+ ]
664
+ },
665
+ {
666
+ "cell_type": "code",
667
+ "execution_count": null,
668
+ "metadata": {
669
+ "id": "zB7mW7DGK3rW"
670
+ },
671
+ "outputs": [],
672
+ "source": [
673
+ "plt.pcolormesh(example_inputs['input_mask'])"
674
+ ]
675
+ },
676
+ {
677
+ "cell_type": "markdown",
678
+ "metadata": {
679
+ "id": "rxLenwAvCkBf"
680
+ },
681
+ "source": [
682
+ "The \"input type\" also has the same shape, but inside the non-padded region, contains a `0` or a `1` indicating which sentence the token is a part of."
683
+ ]
684
+ },
685
+ {
686
+ "cell_type": "code",
687
+ "execution_count": null,
688
+ "metadata": {
689
+ "id": "2CetH_5C9P2m"
690
+ },
691
+ "outputs": [],
692
+ "source": [
693
+ "plt.pcolormesh(example_inputs['input_type_ids'])"
694
+ ]
695
+ },
696
+ {
697
+ "cell_type": "markdown",
698
+ "metadata": {
699
+ "id": "pxHHeyei_sb9"
700
+ },
701
+ "source": [
702
+ "Apply the same preprocessing to the validation and test subsets of the GLUE MRPC dataset:"
703
+ ]
704
+ },
705
+ {
706
+ "cell_type": "code",
707
+ "execution_count": null,
708
+ "metadata": {
709
+ "id": "yuLKxf6zHxw-"
710
+ },
711
+ "outputs": [],
712
+ "source": [
713
+ "glue_validation = glue['validation'].map(bert_inputs_processor).prefetch(1)\n",
714
+ "glue_test = glue['test'].map(bert_inputs_processor).prefetch(1)"
715
+ ]
716
+ },
717
+ {
718
+ "cell_type": "markdown",
719
+ "metadata": {
720
+ "id": "FSwymsbkbLDA"
721
+ },
722
+ "source": [
723
+ "## Build, train and export the model"
724
+ ]
725
+ },
726
+ {
727
+ "cell_type": "markdown",
728
+ "metadata": {
729
+ "id": "bxxO3pJCEM9p"
730
+ },
731
+ "source": [
732
+ "Now that you have formatted the data as expected, you can start working on building and training the model."
733
+ ]
734
+ },
735
+ {
736
+ "cell_type": "markdown",
737
+ "metadata": {
738
+ "id": "Efrj3Cn1kLAp"
739
+ },
740
+ "source": [
741
+ "### Build the model\n"
742
+ ]
743
+ },
744
+ {
745
+ "cell_type": "markdown",
746
+ "metadata": {
747
+ "id": "xxpOY5r2Ayq6"
748
+ },
749
+ "source": [
750
+ "The first step is to download the configuration file—`config_dict`—for the pre-trained BERT model:\n"
751
+ ]
752
+ },
753
+ {
754
+ "cell_type": "code",
755
+ "execution_count": null,
756
+ "metadata": {
757
+ "id": "v7ap0BONSJuz"
758
+ },
759
+ "outputs": [],
760
+ "source": [
761
+ "import json\n",
762
+ "\n",
763
+ "bert_config_file = os.path.join(gs_folder_bert, \"bert_config.json\")\n",
764
+ "config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())\n",
765
+ "config_dict"
766
+ ]
767
+ },
768
+ {
769
+ "cell_type": "code",
770
+ "execution_count": null,
771
+ "metadata": {
772
+ "id": "pKaEaKJSX85J"
773
+ },
774
+ "outputs": [],
775
+ "source": [
776
+ "encoder_config = tfm.nlp.encoders.EncoderConfig({\n",
777
+ " 'type':'bert',\n",
778
+ " 'bert': config_dict\n",
779
+ "})"
780
+ ]
781
+ },
782
+ {
783
+ "cell_type": "code",
784
+ "execution_count": null,
785
+ "metadata": {
786
+ "id": "LbgzWukNSqOS"
787
+ },
788
+ "outputs": [],
789
+ "source": [
790
+ "bert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)\n",
791
+ "bert_encoder"
792
+ ]
793
+ },
794
+ {
795
+ "cell_type": "markdown",
796
+ "metadata": {
797
+ "id": "96ldxDSwkVkj"
798
+ },
799
+ "source": [
800
+ "The configuration file defines the core BERT model from the Model Garden, which is a Keras model that predicts the outputs of `num_classes` from the inputs with maximum sequence length `max_seq_length`."
801
+ ]
802
+ },
803
+ {
804
+ "cell_type": "code",
805
+ "execution_count": null,
806
+ "metadata": {
807
+ "id": "cH682__U0FBv"
808
+ },
809
+ "outputs": [],
810
+ "source": [
811
+ "bert_classifier = tfm.nlp.models.BertClassifier(network=bert_encoder, num_classes=2)"
812
+ ]
813
+ },
814
+ {
815
+ "cell_type": "markdown",
816
+ "metadata": {
817
+ "id": "sFmVG4SKZAw8"
818
+ },
819
+ "source": [
820
+ "Run it on a test batch of data 10 examples from the training set. The output is the logits for the two classes:"
821
+ ]
822
+ },
823
+ {
824
+ "cell_type": "code",
825
+ "execution_count": null,
826
+ "metadata": {
827
+ "id": "VTjgPbp4ZDKo"
828
+ },
829
+ "outputs": [],
830
+ "source": [
831
+ "bert_classifier(\n",
832
+ " example_inputs, training=True).numpy()[:10]"
833
+ ]
834
+ },
835
+ {
836
+ "cell_type": "markdown",
837
+ "metadata": {
838
+ "id": "Q0NTdwZsQK8n"
839
+ },
840
+ "source": [
841
+ "The `TransformerEncoder` in the center of the classifier above **is** the `bert_encoder`.\n",
842
+ "\n",
843
+ "If you inspect the encoder, notice the stack of `Transformer` layers connected to those same three inputs:"
844
+ ]
845
+ },
846
+ {
847
+ "cell_type": "code",
848
+ "execution_count": null,
849
+ "metadata": {
850
+ "id": "8L__-erBwLIQ"
851
+ },
852
+ "outputs": [],
853
+ "source": [
854
+ "tf.keras.utils.plot_model(bert_encoder, show_shapes=True, dpi=48)"
855
+ ]
856
+ },
857
+ {
858
+ "cell_type": "markdown",
859
+ "metadata": {
860
+ "id": "mKAvkQc3heSy"
861
+ },
862
+ "source": [
863
+ "### Restore the encoder weights\n",
864
+ "\n",
865
+ "When built, the encoder is randomly initialized. Restore the encoder's weights from the checkpoint:"
866
+ ]
867
+ },
868
+ {
869
+ "cell_type": "code",
870
+ "execution_count": null,
871
+ "metadata": {
872
+ "id": "97Ll2Gichd_Y"
873
+ },
874
+ "outputs": [],
875
+ "source": [
876
+ "checkpoint = tf.train.Checkpoint(encoder=bert_encoder)\n",
877
+ "checkpoint.read(\n",
878
+ " os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
879
+ ]
880
+ },
881
+ {
882
+ "cell_type": "markdown",
883
+ "metadata": {
884
+ "id": "2oHOql35k3Dd"
885
+ },
886
+ "source": [
887
+ "Note: The pretrained `TransformerEncoder` is also available on [TensorFlow Hub](https://tensorflow.org/hub). Go to the [TF Hub appendix](#hub_bert) for details."
888
+ ]
889
+ },
890
+ {
891
+ "cell_type": "markdown",
892
+ "metadata": {
893
+ "id": "115caFLMk-_l"
894
+ },
895
+ "source": [
896
+ "### Set up the optimizer\n",
897
+ "\n",
898
+ "BERT typically uses the Adam optimizer with weight decay—[AdamW](https://arxiv.org/abs/1711.05101) (`tf.keras.optimizers.experimental.AdamW`).\n",
899
+ "It also employs a learning rate schedule that first warms up from 0 and then decays to 0:"
900
+ ]
901
+ },
902
+ {
903
+ "cell_type": "code",
904
+ "execution_count": null,
905
+ "metadata": {
906
+ "id": "c0jBycPDtkxR"
907
+ },
908
+ "outputs": [],
909
+ "source": [
910
+ "# Set up epochs and steps\n",
911
+ "epochs = 5\n",
912
+ "batch_size = 32\n",
913
+ "eval_batch_size = 32\n",
914
+ "\n",
915
+ "train_data_size = info.splits['train'].num_examples\n",
916
+ "steps_per_epoch = int(train_data_size / batch_size)\n",
917
+ "num_train_steps = steps_per_epoch * epochs\n",
918
+ "warmup_steps = int(0.1 * num_train_steps)\n",
919
+ "initial_learning_rate=2e-5"
920
+ ]
921
+ },
922
+ {
923
+ "cell_type": "markdown",
924
+ "metadata": {
925
+ "id": "GFankgHK0Rvh"
926
+ },
927
+ "source": [
928
+ "Linear decay from `initial_learning_rate` to zero over `num_train_steps`."
929
+ ]
930
+ },
931
+ {
932
+ "cell_type": "code",
933
+ "execution_count": null,
934
+ "metadata": {
935
+ "id": "qWSyT8P2j4mV"
936
+ },
937
+ "outputs": [],
938
+ "source": [
939
+ "linear_decay = tf.keras.optimizers.schedules.PolynomialDecay(\n",
940
+ " initial_learning_rate=initial_learning_rate,\n",
941
+ " end_learning_rate=0,\n",
942
+ " decay_steps=num_train_steps)"
943
+ ]
944
+ },
945
+ {
946
+ "cell_type": "markdown",
947
+ "metadata": {
948
+ "id": "anZPZPAP0Y3n"
949
+ },
950
+ "source": [
951
+ "Warmup to that value over `warmup_steps`:"
952
+ ]
953
+ },
954
+ {
955
+ "cell_type": "code",
956
+ "execution_count": null,
957
+ "metadata": {
958
+ "id": "z_AsVCiRkoN1"
959
+ },
960
+ "outputs": [],
961
+ "source": [
962
+ "warmup_schedule = tfm.optimization.lr_schedule.LinearWarmup(\n",
963
+ " warmup_learning_rate = 0,\n",
964
+ " after_warmup_lr_sched = linear_decay,\n",
965
+ " warmup_steps = warmup_steps\n",
966
+ ")"
967
+ ]
968
+ },
969
+ {
970
+ "cell_type": "markdown",
971
+ "metadata": {
972
+ "id": "arfbaK6t0kH_"
973
+ },
974
+ "source": [
975
+ "The overall schedule looks like this:"
976
+ ]
977
+ },
978
+ {
979
+ "cell_type": "code",
980
+ "execution_count": null,
981
+ "metadata": {
982
+ "id": "rYZGunhqbGUZ"
983
+ },
984
+ "outputs": [],
985
+ "source": [
986
+ "x = tf.linspace(0, num_train_steps, 1001)\n",
987
+ "y = [warmup_schedule(xi) for xi in x]\n",
988
+ "plt.plot(x,y)\n",
989
+ "plt.xlabel('Train step')\n",
990
+ "plt.ylabel('Learning rate')"
991
+ ]
992
+ },
993
+ {
994
+ "cell_type": "markdown",
995
+ "metadata": {
996
+ "id": "bjsmG_fm0opn"
997
+ },
998
+ "source": [
999
+ "Use `tf.keras.optimizers.experimental.AdamW` to instantiate the optimizer with that schedule:"
1000
+ ]
1001
+ },
1002
+ {
1003
+ "cell_type": "code",
1004
+ "execution_count": null,
1005
+ "metadata": {
1006
+ "id": "R8pTNuKIw1dA"
1007
+ },
1008
+ "outputs": [],
1009
+ "source": [
1010
+ "optimizer = tf.keras.optimizers.experimental.Adam(\n",
1011
+ " learning_rate = warmup_schedule)"
1012
+ ]
1013
+ },
1014
+ {
1015
+ "cell_type": "markdown",
1016
+ "metadata": {
1017
+ "id": "78FEUOOEkoP0"
1018
+ },
1019
+ "source": [
1020
+ "### Train the model"
1021
+ ]
1022
+ },
1023
+ {
1024
+ "cell_type": "markdown",
1025
+ "metadata": {
1026
+ "id": "OTNcA0O0nSq9"
1027
+ },
1028
+ "source": [
1029
+ "Set the metric as accuracy and the loss as sparse categorical cross-entropy. Then, compile and train the BERT classifier:"
1030
+ ]
1031
+ },
1032
+ {
1033
+ "cell_type": "code",
1034
+ "execution_count": null,
1035
+ "metadata": {
1036
+ "id": "d5FeL0b6j7ky"
1037
+ },
1038
+ "outputs": [],
1039
+ "source": [
1040
+ "metrics = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy', dtype=tf.float32)]\n",
1041
+ "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
1042
+ "\n",
1043
+ "bert_classifier.compile(\n",
1044
+ " optimizer=optimizer,\n",
1045
+ " loss=loss,\n",
1046
+ " metrics=metrics)"
1047
+ ]
1048
+ },
1049
+ {
1050
+ "cell_type": "code",
1051
+ "execution_count": null,
1052
+ "metadata": {
1053
+ "id": "CsrylctIj_Xy"
1054
+ },
1055
+ "outputs": [],
1056
+ "source": [
1057
+ "bert_classifier.evaluate(glue_validation)"
1058
+ ]
1059
+ },
1060
+ {
1061
+ "cell_type": "code",
1062
+ "execution_count": null,
1063
+ "metadata": {
1064
+ "id": "hgPPc2oNmcVZ"
1065
+ },
1066
+ "outputs": [],
1067
+ "source": [
1068
+ "bert_classifier.fit(\n",
1069
+ " glue_train,\n",
1070
+ " validation_data=(glue_validation),\n",
1071
+ " batch_size=32,\n",
1072
+ " epochs=epochs)"
1073
+ ]
1074
+ },
1075
+ {
1076
+ "cell_type": "markdown",
1077
+ "metadata": {
1078
+ "id": "IFtKFWbNKb0u"
1079
+ },
1080
+ "source": [
1081
+ "Now run the fine-tuned model on a custom example to see that it works.\n",
1082
+ "\n",
1083
+ "Start by encoding some sentence pairs:"
1084
+ ]
1085
+ },
1086
+ {
1087
+ "cell_type": "code",
1088
+ "execution_count": null,
1089
+ "metadata": {
1090
+ "id": "S1sdW6lLWaEi"
1091
+ },
1092
+ "outputs": [],
1093
+ "source": [
1094
+ "my_examples = {\n",
1095
+ " 'sentence1':[\n",
1096
+ " 'The rain in Spain falls mainly on the plain.',\n",
1097
+ " 'Look I fine tuned BERT.'],\n",
1098
+ " 'sentence2':[\n",
1099
+ " 'It mostly rains on the flat lands of Spain.',\n",
1100
+ " 'Is it working? This does not match.']\n",
1101
+ " }"
1102
+ ]
1103
+ },
1104
+ {
1105
+ "cell_type": "markdown",
1106
+ "metadata": {
1107
+ "id": "7ynJibkBRTJF"
1108
+ },
1109
+ "source": [
1110
+ "The model should report class `1` \"match\" for the first example and class `0` \"no-match\" for the second:"
1111
+ ]
1112
+ },
1113
+ {
1114
+ "cell_type": "code",
1115
+ "execution_count": null,
1116
+ "metadata": {
1117
+ "id": "umo0ttrgRYIM"
1118
+ },
1119
+ "outputs": [],
1120
+ "source": [
1121
+ "ex_packed = bert_inputs_processor(my_examples)\n",
1122
+ "my_logits = bert_classifier(ex_packed, training=False)\n",
1123
+ "\n",
1124
+ "result_cls_ids = tf.argmax(my_logits)\n",
1125
+ "result_cls_ids"
1126
+ ]
1127
+ },
1128
+ {
1129
+ "cell_type": "code",
1130
+ "execution_count": null,
1131
+ "metadata": {
1132
+ "id": "HNdmOEHKT7e8"
1133
+ },
1134
+ "outputs": [],
1135
+ "source": [
1136
+ "tf.gather(tf.constant(info.features['label'].names), result_cls_ids)"
1137
+ ]
1138
+ },
1139
+ {
1140
+ "cell_type": "markdown",
1141
+ "metadata": {
1142
+ "id": "fVo_AnT0l26j"
1143
+ },
1144
+ "source": [
1145
+ "### Export the model\n",
1146
+ "\n",
1147
+ "Often the goal of training a model is to _use_ it for something outside of the Python process that created it. You can do this by exporting the model using `tf.saved_model`. (Learn more in the [Using the SavedModel format](https://www.tensorflow.org/guide/saved_model) guide and the [Save and load a model using a distribution strategy](https://www.tensorflow.org/tutorials/distribute/save_and_load) tutorial.)\n",
1148
+ "\n",
1149
+ "First, build a wrapper class to export the model. This wrapper does two things:\n",
1150
+ "\n",
1151
+ "- First it packages `bert_inputs_processor` and `bert_classifier` together into a single `tf.Module`, so you can export all the functionalities.\n",
1152
+ "- Second it defines a `tf.function` that implements the end-to-end execution of the model.\n",
1153
+ "\n",
1154
+ "Setting the `input_signature` argument of `tf.function` lets you define a fixed signature for the `tf.function`. This can be less surprising than the default automatic retracing behavior."
1155
+ ]
1156
+ },
1157
+ {
1158
+ "cell_type": "code",
1159
+ "execution_count": null,
1160
+ "metadata": {
1161
+ "id": "78h83mlt9wpY"
1162
+ },
1163
+ "outputs": [],
1164
+ "source": [
1165
+ "class ExportModel(tf.Module):\n",
1166
+ " def __init__(self, input_processor, classifier):\n",
1167
+ " self.input_processor = input_processor\n",
1168
+ " self.classifier = classifier\n",
1169
+ "\n",
1170
+ " @tf.function(input_signature=[{\n",
1171
+ " 'sentence1': tf.TensorSpec(shape=[None], dtype=tf.string),\n",
1172
+ " 'sentence2': tf.TensorSpec(shape=[None], dtype=tf.string)}])\n",
1173
+ " def __call__(self, inputs):\n",
1174
+ " packed = self.input_processor(inputs)\n",
1175
+ " logits = self.classifier(packed, training=False)\n",
1176
+ " result_cls_ids = tf.argmax(logits)\n",
1177
+ " return {\n",
1178
+ " 'logits': logits,\n",
1179
+ " 'class_id': result_cls_ids,\n",
1180
+ " 'class': tf.gather(\n",
1181
+ " tf.constant(info.features['label'].names),\n",
1182
+ " result_cls_ids)\n",
1183
+ " }"
1184
+ ]
1185
+ },
1186
+ {
1187
+ "cell_type": "markdown",
1188
+ "metadata": {
1189
+ "id": "qnxysGUfIgFQ"
1190
+ },
1191
+ "source": [
1192
+ "Create an instance of this export-model and save it:"
1193
+ ]
1194
+ },
1195
+ {
1196
+ "cell_type": "code",
1197
+ "execution_count": null,
1198
+ "metadata": {
1199
+ "id": "TmHW9DEFUZ0X"
1200
+ },
1201
+ "outputs": [],
1202
+ "source": [
1203
+ "export_model = ExportModel(bert_inputs_processor, bert_classifier)"
1204
+ ]
1205
+ },
1206
+ {
1207
+ "cell_type": "code",
1208
+ "execution_count": null,
1209
+ "metadata": {
1210
+ "id": "Nl5x6nElZqkP"
1211
+ },
1212
+ "outputs": [],
1213
+ "source": [
1214
+ "import tempfile\n",
1215
+ "export_dir=tempfile.mkdtemp(suffix='_saved_model')\n",
1216
+ "tf.saved_model.save(export_model, export_dir=export_dir,\n",
1217
+ " signatures={'serving_default': export_model.__call__})"
1218
+ ]
1219
+ },
1220
+ {
1221
+ "cell_type": "markdown",
1222
+ "metadata": {
1223
+ "id": "Pd8B5dy-ImDJ"
1224
+ },
1225
+ "source": [
1226
+ "Reload the model and compare the results to the original:"
1227
+ ]
1228
+ },
1229
+ {
1230
+ "cell_type": "code",
1231
+ "execution_count": null,
1232
+ "metadata": {
1233
+ "id": "9cAhHySVXHD5"
1234
+ },
1235
+ "outputs": [],
1236
+ "source": [
1237
+ "original_logits = export_model(my_examples)['logits']"
1238
+ ]
1239
+ },
1240
+ {
1241
+ "cell_type": "code",
1242
+ "execution_count": null,
1243
+ "metadata": {
1244
+ "id": "H9cAcYwfW2fy"
1245
+ },
1246
+ "outputs": [],
1247
+ "source": [
1248
+ "reloaded = tf.saved_model.load(export_dir)\n",
1249
+ "reloaded_logits = reloaded(my_examples)['logits']"
1250
+ ]
1251
+ },
1252
+ {
1253
+ "cell_type": "code",
1254
+ "execution_count": null,
1255
+ "metadata": {
1256
+ "id": "y_ACvKPsVUXC"
1257
+ },
1258
+ "outputs": [],
1259
+ "source": [
1260
+ "# The results are identical:\n",
1261
+ "print(original_logits.numpy())\n",
1262
+ "print()\n",
1263
+ "print(reloaded_logits.numpy())"
1264
+ ]
1265
+ },
1266
+ {
1267
+ "cell_type": "code",
1268
+ "execution_count": null,
1269
+ "metadata": {
1270
+ "id": "lBlPP20dXPFR"
1271
+ },
1272
+ "outputs": [],
1273
+ "source": [
1274
+ "print(np.mean(abs(original_logits - reloaded_logits)))"
1275
+ ]
1276
+ },
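+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Because the model was saved with a `serving_default` signature, it can also be called through `reloaded.signatures`. A minimal sketch (the signature takes flat, named string tensors rather than a single dict):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch: call the exported model through its serving signature.\n",
+ "# The output keys come from the dict returned by ExportModel.__call__ above.\n",
+ "serving_fn = reloaded.signatures['serving_default']\n",
+ "serving_fn(sentence1=tf.constant(my_examples['sentence1']),\n",
+ "           sentence2=tf.constant(my_examples['sentence2']))['class']"
+ ]
+ },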
1277
+ {
1278
+ "cell_type": "markdown",
1279
+ "metadata": {
1280
+ "id": "CPsg7dZwfBM2"
1281
+ },
1282
+ "source": [
1283
+ "Congratulations! You've used `tensorflow_models` to build a BERT-classifier, train it, and export for later use."
1284
+ ]
1285
+ },
1286
+ {
1287
+ "cell_type": "markdown",
1288
+ "metadata": {
1289
+ "id": "eQceYqRFT_Eg"
1290
+ },
1291
+ "source": [
1292
+ "## Optional: BERT on TF Hub"
1293
+ ]
1294
+ },
1295
+ {
1296
+ "cell_type": "markdown",
1297
+ "metadata": {
1298
+ "id": "QbklKt-w_CiI"
1299
+ },
1300
+ "source": [
1301
+ "\u003ca id=\"hub_bert\"\u003e\u003c/a\u003e\n",
1302
+ "\n",
1303
+ "\n",
1304
+ "You can get the BERT model off the shelf from [TF Hub](https://tfhub.dev/). There are [many versions available along with their input preprocessors](https://tfhub.dev/google/collections/bert/1).\n",
1305
+ "\n",
1306
+ "This example uses [a small version of BERT from TF Hub](https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2) that was pre-trained using the English Wikipedia and BooksCorpus datasets, similar to the [original implementation](https://arxiv.org/abs/1908.08962) (Turc et al., 2019).\n",
1307
+ "\n",
1308
+ "Start by importing TF Hub:"
1309
+ ]
1310
+ },
1311
+ {
1312
+ "cell_type": "code",
1313
+ "execution_count": null,
1314
+ "metadata": {
1315
+ "id": "GDWrHm0BGpbX"
1316
+ },
1317
+ "outputs": [],
1318
+ "source": [
1319
+ "import tensorflow_hub as hub"
1320
+ ]
1321
+ },
1322
+ {
1323
+ "cell_type": "markdown",
1324
+ "metadata": {
1325
+ "id": "f02f38f83ac4"
1326
+ },
1327
+ "source": [
1328
+ "Select the input preprocessor and the model from TF Hub and wrap them as `hub.KerasLayer` layers:"
1329
+ ]
1330
+ },
1331
+ {
1332
+ "cell_type": "code",
1333
+ "execution_count": null,
1334
+ "metadata": {
1335
+ "id": "lo6479At4sP1"
1336
+ },
1337
+ "outputs": [],
1338
+ "source": [
1339
+ "# Always make sure you use the right preprocessor.\n",
1340
+ "hub_preprocessor = hub.KerasLayer(\n",
1341
+ " \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\")\n",
1342
+ "\n",
1343
+ "# This is a really small BERT.\n",
1344
+ "hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2\",\n",
1345
+ " trainable=True)\n",
1346
+ "\n",
1347
+ "print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
1348
+ ]
1349
+ },
1350
+ {
1351
+ "cell_type": "markdown",
1352
+ "metadata": {
1353
+ "id": "iTzF574wivQv"
1354
+ },
1355
+ "source": [
1356
+ "Test run the preprocessor on a batch of data:"
1357
+ ]
1358
+ },
1359
+ {
1360
+ "cell_type": "code",
1361
+ "execution_count": null,
1362
+ "metadata": {
1363
+ "id": "GOASSKR5R3-N"
1364
+ },
1365
+ "outputs": [],
1366
+ "source": [
1367
+ "hub_inputs = hub_preprocessor(['Hello TensorFlow!'])\n",
1368
+ "{key: value[0, :10].numpy() for key, value in hub_inputs.items()} "
1369
+ ]
1370
+ },
1371
+ {
1372
+ "cell_type": "code",
1373
+ "execution_count": null,
1374
+ "metadata": {
1375
+ "id": "XEcYrCR45Uwo"
1376
+ },
1377
+ "outputs": [],
1378
+ "source": [
1379
+ "result = hub_encoder(\n",
1380
+ " inputs=hub_inputs,\n",
1381
+ " training=False,\n",
1382
+ ")\n",
1383
+ "\n",
1384
+ "print(\"Pooled output shape:\", result['pooled_output'].shape)\n",
1385
+ "print(\"Sequence output shape:\", result['sequence_output'].shape)"
1386
+ ]
1387
+ },
1388
+ {
1389
+ "cell_type": "markdown",
1390
+ "metadata": {
1391
+ "id": "cjojn8SmLSRI"
1392
+ },
1393
+ "source": [
1394
+ "At this point it would be simple to add a classification head yourself.\n",
1395
+ "\n",
1396
+ "The Model Garden `tfm.nlp.models.BertClassifier` class can also build a classifier onto the TF Hub encoder:"
1397
+ ]
1398
+ },
1399
+ {
1400
+ "cell_type": "code",
1401
+ "execution_count": null,
1402
+ "metadata": {
1403
+ "id": "9nTDaApyLR70"
1404
+ },
1405
+ "outputs": [],
1406
+ "source": [
1407
+ "hub_classifier = tfm.nlp.models.BertClassifier(\n",
1408
+ " bert_encoder,\n",
1409
+ " num_classes=2,\n",
1410
+ " dropout_rate=0.1,\n",
1411
+ " initializer=tf.keras.initializers.TruncatedNormal(\n",
1412
+ " stddev=0.02))"
1413
+ ]
1414
+ },
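+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For reference, here is a minimal sketch of a hand-rolled classification head on the TF Hub encoder. It classifies single sentences for simplicity, and the dropout rate and 2-class output size are assumed example values; packing sentence pairs would require the preprocessor's tokenize-and-pack API."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch: a hand-rolled classification head on top of the TF Hub encoder.\n",
+ "# The 0.1 dropout rate and the 2-class Dense layer are assumed example values.\n",
+ "text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')\n",
+ "encoder_inputs = hub_preprocessor(text_input)\n",
+ "encoder_outputs = hub_encoder(encoder_inputs)\n",
+ "x = tf.keras.layers.Dropout(0.1)(encoder_outputs['pooled_output'])\n",
+ "manual_logits = tf.keras.layers.Dense(2, name='classifier')(x)\n",
+ "manual_classifier = tf.keras.Model(text_input, manual_logits)\n",
+ "manual_classifier.summary()"
+ ]
+ },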
1415
+ {
1416
+ "cell_type": "markdown",
1417
+ "metadata": {
1418
+ "id": "xMJX3wV0_v7I"
1419
+ },
1420
+ "source": [
1421
+ "The one downside to loading this model from TF Hub is that the structure of internal Keras layers is not restored. This makes it more difficult to inspect or modify the model.\n",
1422
+ "\n",
1423
+ "The BERT encoder model—`hub_classifier`—is now a single layer."
1424
+ ]
1425
+ },
1426
+ {
1427
+ "cell_type": "markdown",
1428
+ "metadata": {
1429
+ "id": "u_IqwXjRV1vd"
1430
+ },
1431
+ "source": [
1432
+ "For concrete examples of this approach, refer to [Solve Glue tasks using BERT](https://www.tensorflow.org/text/tutorials/bert_glue)."
1433
+ ]
1434
+ },
1435
+ {
1436
+ "cell_type": "markdown",
1437
+ "metadata": {
1438
+ "id": "ji3tdLz101km"
1439
+ },
1440
+ "source": [
1441
+ "## Optional: Optimizer `config`s\n",
1442
+ "\n",
1443
+ "The `tensorflow_models` package defines serializable `config` classes that describe how to build the live objects. Earlier in this tutorial, you built the optimizer manually.\n",
1444
+ "\n",
1445
+ "The configuration below describes an (almost) identical optimizer built by the `optimizer_factory.OptimizerFactory`:"
1446
+ ]
1447
+ },
1448
+ {
1449
+ "cell_type": "code",
1450
+ "execution_count": null,
1451
+ "metadata": {
1452
+ "id": "Fdb9C1ontnH_"
1453
+ },
1454
+ "outputs": [],
1455
+ "source": [
1456
+ "optimization_config = tfm.optimization.OptimizationConfig(\n",
1457
+ " optimizer=tfm.optimization.OptimizerConfig(\n",
1458
+ " type = \"adam\"),\n",
1459
+ " learning_rate = tfm.optimization.LrConfig(\n",
1460
+ " type='polynomial',\n",
1461
+ " polynomial=tfm.optimization.PolynomialLrConfig(\n",
1462
+ " initial_learning_rate=2e-5,\n",
1463
+ " end_learning_rate=0.0,\n",
1464
+ " decay_steps=num_train_steps)),\n",
1465
+ " warmup = tfm.optimization.WarmupConfig(\n",
1466
+ " type='linear',\n",
1467
+ " linear=tfm.optimization.LinearWarmupConfig(warmup_steps=warmup_steps)\n",
1468
+ " ))\n",
1469
+ "\n",
1470
+ "\n",
1471
+ "fac = tfm.optimization.optimizer_factory.OptimizerFactory(optimization_config)\n",
1472
+ "lr = fac.build_learning_rate()\n",
1473
+ "optimizer = fac.build_optimizer(lr=lr)"
1474
+ ]
1475
+ },
1476
+ {
1477
+ "cell_type": "code",
1478
+ "execution_count": null,
1479
+ "metadata": {
1480
+ "id": "Rp7R1hBfv5HG"
1481
+ },
1482
+ "outputs": [],
1483
+ "source": [
1484
+ "x = tf.linspace(0, num_train_steps, 1001).numpy()\n",
1485
+ "y = [lr(xi) for xi in x]\n",
1486
+ "plt.plot(x,y)\n",
1487
+ "plt.xlabel('Train step')\n",
1488
+ "plt.ylabel('Learning rate')"
1489
+ ]
1490
+ },
1491
+ {
1492
+ "cell_type": "markdown",
1493
+ "metadata": {
1494
+ "id": "ywn5miD_dnuh"
1495
+ },
1496
+ "source": [
1497
+ "The advantage to using `config` objects is that they don't contain any complicated TensorFlow objects, and can be easily serialized to JSON, and rebuilt. Here's the JSON for the above `tfm.optimization.OptimizationConfig`:"
1498
+ ]
1499
+ },
1500
+ {
1501
+ "cell_type": "code",
1502
+ "execution_count": null,
1503
+ "metadata": {
1504
+ "id": "zo5RV5lud81Y"
1505
+ },
1506
+ "outputs": [],
1507
+ "source": [
1508
+ "optimization_config = optimization_config.as_dict()\n",
1509
+ "optimization_config"
1510
+ ]
1511
+ },
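+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Since this is now a plain Python dict, it can be written straight to JSON, for example with `json.dumps` (a minimal sketch):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "# Minimal sketch: serialize the config dict to JSON text; it can be loaded back with json.loads.\n",
+ "config_json = json.dumps(optimization_config, indent=2)\n",
+ "print(config_json[:250])"
+ ]
+ },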
1512
+ {
1513
+ "cell_type": "markdown",
1514
+ "metadata": {
1515
+ "id": "Z6qPXPEhekkd"
1516
+ },
1517
+ "source": [
1518
+ "The `tfm.optimization.optimizer_factory.OptimizerFactory` can just as easily build the optimizer from the JSON dictionary:"
1519
+ ]
1520
+ },
1521
+ {
1522
+ "cell_type": "code",
1523
+ "execution_count": null,
1524
+ "metadata": {
1525
+ "id": "p-bYrvfMYsxp"
1526
+ },
1527
+ "outputs": [],
1528
+ "source": [
1529
+ "fac = tfm.optimization.optimizer_factory.OptimizerFactory(\n",
1530
+ " tfm.optimization.OptimizationConfig(optimization_config))\n",
1531
+ "lr = fac.build_learning_rate()\n",
1532
+ "optimizer = fac.build_optimizer(lr=lr)"
1533
+ ]
1534
+ }
1535
+ ],
1536
+ "metadata": {
1537
+ "accelerator": "GPU",
1538
+ "colab": {
1539
+ "name": "fine_tune_bert.ipynb",
1540
+ "private_outputs": true,
1541
+ "toc_visible": true
1542
+ },
1543
+ "kernelspec": {
1544
+ "display_name": "Python 3",
1545
+ "name": "python3"
1546
+ }
1547
+ },
1548
+ "nbformat": 4,
1549
+ "nbformat_minor": 0
1550
+ }
models/docs/nlp/index.ipynb ADDED
@@ -0,0 +1,545 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "80xnUmoI7fBX"
7
+ },
8
+ "source": [
9
+ "##### Copyright 2020 The TensorFlow Authors."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "cellView": "form",
17
+ "id": "8nvTnfs6Q692"
18
+ },
19
+ "outputs": [],
20
+ "source": [
21
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
22
+ "# you may not use this file except in compliance with the License.\n",
23
+ "# You may obtain a copy of the License at\n",
24
+ "#\n",
25
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
26
+ "#\n",
27
+ "# Unless required by applicable law or agreed to in writing, software\n",
28
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
29
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
30
+ "# See the License for the specific language governing permissions and\n",
31
+ "# limitations under the License."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {
37
+ "id": "WmfcMK5P5C1G"
38
+ },
39
+ "source": [
40
+ "# Introduction to the TensorFlow Models NLP library"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "metadata": {
46
+ "id": "cH-oJ8R6AHMK"
47
+ },
48
+ "source": [
49
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
50
+ " \u003ctd\u003e\n",
51
+ " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
52
+ " \u003c/td\u003e\n",
53
+ " \u003ctd\u003e\n",
54
+ " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
55
+ " \u003c/td\u003e\n",
56
+ " \u003ctd\u003e\n",
57
+ " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
58
+ " \u003c/td\u003e\n",
59
+ " \u003ctd\u003e\n",
60
+ " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
61
+ " \u003c/td\u003e\n",
62
+ "\u003c/table\u003e"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "markdown",
67
+ "metadata": {
68
+ "id": "0H_EFIhq4-MJ"
69
+ },
70
+ "source": [
71
+ "## Learning objectives\n",
72
+ "\n",
73
+ "In this Colab notebook, you will learn how to build transformer-based models for common NLP tasks including pretraining, span labelling and classification using the building blocks from [NLP modeling library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling)."
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "markdown",
78
+ "metadata": {
79
+ "id": "2N97-dps_nUk"
80
+ },
81
+ "source": [
82
+ "## Install and import"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "metadata": {
88
+ "id": "459ygAVl_rg0"
89
+ },
90
+ "source": [
91
+ "### Install the TensorFlow Model Garden pip package\n",
92
+ "\n",
93
+ "* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` github repo. To include latest changes, you may install `tf-models-nightly`,\n",
94
+ "which is the nightly Model Garden package created daily automatically.\n",
95
+ "* `pip` will install all models and dependencies automatically."
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "metadata": {
102
+ "id": "Y-qGkdh6_sZc"
103
+ },
104
+ "outputs": [],
105
+ "source": [
106
+ "!pip install tf-models-official"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "markdown",
111
+ "metadata": {
112
+ "id": "e4huSSwyAG_5"
113
+ },
114
+ "source": [
115
+ "### Import Tensorflow and other libraries"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "metadata": {
122
+ "id": "jqYXqtjBAJd9"
123
+ },
124
+ "outputs": [],
125
+ "source": [
126
+ "import numpy as np\n",
127
+ "import tensorflow as tf\n",
128
+ "\n",
129
+ "from tensorflow_models import nlp"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "markdown",
134
+ "metadata": {
135
+ "id": "djBQWjvy-60Y"
136
+ },
137
+ "source": [
138
+ "## BERT pretraining model\n",
139
+ "\n",
140
+ "BERT ([Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)) introduced the method of pre-training language representations on a large text corpus and then using that model for downstream NLP tasks.\n",
141
+ "\n",
142
+ "In this section, we will learn how to build a model to pretrain BERT on the masked language modeling task and next sentence prediction task. For simplicity, we only show the minimum example and use dummy data."
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "markdown",
147
+ "metadata": {
148
+ "id": "MKuHVlsCHmiq"
149
+ },
150
+ "source": [
151
+ "### Build a `BertPretrainer` model wrapping `BertEncoder`\n",
152
+ "\n",
153
+ "The `nlp.networks.BertEncoder` class implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers (`nlp.layers.TransformerEncoderBlock`), but not the masked language model or classification task networks.\n",
154
+ "\n",
155
+ "The `nlp.models.BertPretrainer` class allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives."
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "metadata": {
162
+ "id": "EXkcXz-9BwB3"
163
+ },
164
+ "outputs": [],
165
+ "source": [
166
+ "# Build a small transformer network.\n",
167
+ "vocab_size = 100\n",
168
+ "network = nlp.networks.BertEncoder(\n",
169
+ " vocab_size=vocab_size, \n",
170
+ " # The number of TransformerEncoderBlock layers\n",
171
+ " num_layers=3)"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "markdown",
176
+ "metadata": {
177
+ "id": "0NH5irV5KTMS"
178
+ },
179
+ "source": [
180
+ "Inspecting the encoder, we see it contains few embedding layers, stacked `nlp.layers.TransformerEncoderBlock` layers and are connected to three input layers:\n",
181
+ "\n",
182
+ "`input_word_ids`, `input_type_ids` and `input_mask`.\n"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "metadata": {
189
+ "id": "lZNoZkBrIoff"
190
+ },
191
+ "outputs": [],
192
+ "source": [
193
+ "tf.keras.utils.plot_model(network, show_shapes=True, expand_nested=True, dpi=48)"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": null,
199
+ "metadata": {
200
+ "id": "o7eFOZXiIl-b"
201
+ },
202
+ "outputs": [],
203
+ "source": [
204
+ "# Create a BERT pretrainer with the created network.\n",
205
+ "num_token_predictions = 8\n",
206
+ "bert_pretrainer = nlp.models.BertPretrainer(\n",
207
+ " network, num_classes=2, num_token_predictions=num_token_predictions, output='predictions')"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "markdown",
212
+ "metadata": {
213
+ "id": "d5h5HT7gNHx_"
214
+ },
215
+ "source": [
216
+ "Inspecting the `bert_pretrainer`, we see it wraps the `encoder` with additional `MaskedLM` and `nlp.layers.ClassificationHead` heads."
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": null,
222
+ "metadata": {
223
+ "id": "2tcNfm03IBF7"
224
+ },
225
+ "outputs": [],
226
+ "source": [
227
+ "tf.keras.utils.plot_model(bert_pretrainer, show_shapes=True, expand_nested=True, dpi=48)"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": null,
233
+ "metadata": {
234
+ "id": "F2oHrXGUIS0M"
235
+ },
236
+ "outputs": [],
237
+ "source": [
238
+ "# We can feed some dummy data to get masked language model and sentence output.\n",
239
+ "sequence_length = 16\n",
240
+ "batch_size = 2\n",
241
+ "\n",
242
+ "word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
243
+ "mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
244
+ "type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
245
+ "masked_lm_positions_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
246
+ "\n",
247
+ "outputs = bert_pretrainer(\n",
248
+ " [word_id_data, mask_data, type_id_data, masked_lm_positions_data])\n",
249
+ "lm_output = outputs[\"masked_lm\"]\n",
250
+ "sentence_output = outputs[\"classification\"]\n",
251
+ "print(f'lm_output: shape={lm_output.shape}, dtype={lm_output.dtype!r}')\n",
252
+ "print(f'sentence_output: shape={sentence_output.shape}, dtype={sentence_output.dtype!r}')"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "markdown",
257
+ "metadata": {
258
+ "id": "bnx3UCHniCS5"
259
+ },
260
+ "source": [
261
+ "### Compute loss\n",
262
+ "Next, we can use `lm_output` and `sentence_output` to compute `loss`."
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": null,
268
+ "metadata": {
269
+ "id": "k30H4Q86f52x"
270
+ },
271
+ "outputs": [],
272
+ "source": [
273
+ "masked_lm_ids_data = np.random.randint(vocab_size, size=(batch_size, num_token_predictions))\n",
274
+ "masked_lm_weights_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
275
+ "next_sentence_labels_data = np.random.randint(2, size=(batch_size))\n",
276
+ "\n",
277
+ "mlm_loss = nlp.losses.weighted_sparse_categorical_crossentropy_loss(\n",
278
+ " labels=masked_lm_ids_data,\n",
279
+ " predictions=lm_output,\n",
280
+ " weights=masked_lm_weights_data)\n",
281
+ "sentence_loss = nlp.losses.weighted_sparse_categorical_crossentropy_loss(\n",
282
+ " labels=next_sentence_labels_data,\n",
283
+ " predictions=sentence_output)\n",
284
+ "loss = mlm_loss + sentence_loss\n",
285
+ "\n",
286
+ "print(loss)"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "markdown",
291
+ "metadata": {
292
+ "id": "wrmSs8GjHxVw"
293
+ },
294
+ "source": [
295
+ "With the loss, you can optimize the model.\n",
296
+ "After training, we can save the weights of TransformerEncoder for the downstream fine-tuning tasks. Please see [run_pretraining.py](https://github.com/tensorflow/models/blob/master/official/legacy/bert/run_pretraining.py) for the full example.\n"
297
+ ]
298
+ },
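+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch of that save step, assuming an example checkpoint directory `./pretrained_encoder`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch: save the pretrained encoder weights for later fine-tuning.\n",
+ "# './pretrained_encoder' is an example path, not one used elsewhere in this notebook.\n",
+ "checkpoint = tf.train.Checkpoint(encoder=network)\n",
+ "saved_path = checkpoint.save('./pretrained_encoder/ckpt')\n",
+ "saved_path"
+ ]
+ },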
299
+ {
300
+ "cell_type": "markdown",
301
+ "metadata": {
302
+ "id": "k8cQVFvBCV4s"
303
+ },
304
+ "source": [
305
+ "## Span labeling model\n",
306
+ "\n",
307
+ "Span labeling is the task to assign labels to a span of the text, for example, label a span of text as the answer of a given question.\n",
308
+ "\n",
309
+ "In this section, we will learn how to build a span labeling model. Again, we use dummy data for simplicity."
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "markdown",
314
+ "metadata": {
315
+ "id": "xrLLEWpfknUW"
316
+ },
317
+ "source": [
318
+ "### Build a BertSpanLabeler wrapping BertEncoder\n",
319
+ "\n",
320
+ "The `nlp.models.BertSpanLabeler` class implements a simple single-span start-end predictor (that is, a model that predicts two values: a start token index and an end token index), suitable for SQuAD-style tasks.\n",
321
+ "\n",
322
+ "Note that `nlp.models.BertSpanLabeler` wraps a `nlp.networks.BertEncoder`, the weights of which can be restored from the above pretraining model.\n"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "code",
327
+ "execution_count": null,
328
+ "metadata": {
329
+ "id": "B941M4iUCejO"
330
+ },
331
+ "outputs": [],
332
+ "source": [
333
+ "network = nlp.networks.BertEncoder(\n",
334
+ " vocab_size=vocab_size, num_layers=2)\n",
335
+ "\n",
336
+ "# Create a BERT trainer with the created network.\n",
337
+ "bert_span_labeler = nlp.models.BertSpanLabeler(network)"
338
+ ]
339
+ },
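+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch of restoring the saved encoder weights with `tf.train.Checkpoint`. In practice the downstream encoder should be built with the same configuration as the pretraining encoder; here the layer counts differ, so this only illustrates the restore pattern."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch: restore pretrained encoder weights (path matches the save sketch above).\n",
+ "# expect_partial() silences warnings about checkpoint values this encoder does not use.\n",
+ "ckpt = tf.train.Checkpoint(encoder=network)\n",
+ "ckpt.restore(tf.train.latest_checkpoint('./pretrained_encoder')).expect_partial()"
+ ]
+ },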
340
+ {
341
+ "cell_type": "markdown",
342
+ "metadata": {
343
+ "id": "QpB9pgj4PpMg"
344
+ },
345
+ "source": [
346
+ "Inspecting the `bert_span_labeler`, we see it wraps the encoder with additional `SpanLabeling` that outputs `start_position` and `end_position`."
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": null,
352
+ "metadata": {
353
+ "id": "RbqRNJCLJu4H"
354
+ },
355
+ "outputs": [],
356
+ "source": [
357
+ "tf.keras.utils.plot_model(bert_span_labeler, show_shapes=True, expand_nested=True, dpi=48)"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "execution_count": null,
363
+ "metadata": {
364
+ "id": "fUf1vRxZJwio"
365
+ },
366
+ "outputs": [],
367
+ "source": [
368
+ "# Create a set of 2-dimensional data tensors to feed into the model.\n",
369
+ "word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
370
+ "mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
371
+ "type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
372
+ "\n",
373
+ "# Feed the data to the model.\n",
374
+ "start_logits, end_logits = bert_span_labeler([word_id_data, mask_data, type_id_data])\n",
375
+ "\n",
376
+ "print(f'start_logits: shape={start_logits.shape}, dtype={start_logits.dtype!r}')\n",
377
+ "print(f'end_logits: shape={end_logits.shape}, dtype={end_logits.dtype!r}')"
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "markdown",
382
+ "metadata": {
383
+ "id": "WqhgQaN1lt-G"
384
+ },
385
+ "source": [
386
+ "### Compute loss\n",
387
+ "With `start_logits` and `end_logits`, we can compute loss:"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": null,
393
+ "metadata": {
394
+ "id": "waqs6azNl3Nn"
395
+ },
396
+ "outputs": [],
397
+ "source": [
398
+ "start_positions = np.random.randint(sequence_length, size=(batch_size))\n",
399
+ "end_positions = np.random.randint(sequence_length, size=(batch_size))\n",
400
+ "\n",
401
+ "start_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
402
+ " start_positions, start_logits, from_logits=True)\n",
403
+ "end_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
404
+ " end_positions, end_logits, from_logits=True)\n",
405
+ "\n",
406
+ "total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2\n",
407
+ "print(total_loss)"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "markdown",
412
+ "metadata": {
413
+ "id": "Zdf03YtZmd_d"
414
+ },
415
+ "source": [
416
+ "With the `loss`, you can optimize the model. Please see [run_squad.py](https://github.com/tensorflow/models/blob/master/official/legacy/bert/run_squad.py) for the full example."
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "markdown",
421
+ "metadata": {
422
+ "id": "0A1XnGSTChg9"
423
+ },
424
+ "source": [
425
+ "## Classification model\n",
426
+ "\n",
427
+ "In the last section, we show how to build a text classification model.\n"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "markdown",
432
+ "metadata": {
433
+ "id": "MSK8OpZgnQa9"
434
+ },
435
+ "source": [
436
+ "### Build a BertClassifier model wrapping BertEncoder\n",
437
+ "\n",
438
+ "`nlp.models.BertClassifier` implements a [CLS] token classification model containing a single classification head."
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": null,
444
+ "metadata": {
445
+ "id": "cXXCsffkCphk"
446
+ },
447
+ "outputs": [],
448
+ "source": [
449
+ "network = nlp.networks.BertEncoder(\n",
450
+ " vocab_size=vocab_size, num_layers=2)\n",
451
+ "\n",
452
+ "# Create a BERT trainer with the created network.\n",
453
+ "num_classes = 2\n",
454
+ "bert_classifier = nlp.models.BertClassifier(\n",
455
+ " network, num_classes=num_classes)"
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "markdown",
460
+ "metadata": {
461
+ "id": "8tZKueKYP4bB"
462
+ },
463
+ "source": [
464
+ "Inspecting the `bert_classifier`, we see it wraps the `encoder` with additional `Classification` head."
465
+ ]
466
+ },
467
+ {
468
+ "cell_type": "code",
469
+ "execution_count": null,
470
+ "metadata": {
471
+ "id": "snlutm9ZJgEZ"
472
+ },
473
+ "outputs": [],
474
+ "source": [
475
+ "tf.keras.utils.plot_model(bert_classifier, show_shapes=True, expand_nested=True, dpi=48)"
476
+ ]
477
+ },
478
+ {
479
+ "cell_type": "code",
480
+ "execution_count": null,
481
+ "metadata": {
482
+ "id": "yyHPHsqBJkCz"
483
+ },
484
+ "outputs": [],
485
+ "source": [
486
+ "# Create a set of 2-dimensional data tensors to feed into the model.\n",
487
+ "word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
488
+ "mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
489
+ "type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
490
+ "\n",
491
+ "# Feed the data to the model.\n",
492
+ "logits = bert_classifier([word_id_data, mask_data, type_id_data])\n",
493
+ "print(f'logits: shape={logits.shape}, dtype={logits.dtype!r}')"
494
+ ]
495
+ },
496
+ {
497
+ "cell_type": "markdown",
498
+ "metadata": {
499
+ "id": "w--a2mg4nzKm"
500
+ },
501
+ "source": [
502
+ "### Compute loss\n",
503
+ "\n",
504
+ "With `logits`, we can compute `loss`:"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": null,
510
+ "metadata": {
511
+ "id": "9X0S1DoFn_5Q"
512
+ },
513
+ "outputs": [],
514
+ "source": [
515
+ "labels = np.random.randint(num_classes, size=(batch_size))\n",
516
+ "\n",
517
+ "loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
518
+ " labels, logits, from_logits=True)\n",
519
+ "print(loss)"
520
+ ]
521
+ },
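+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a minimal sketch of one optimization step with this loss, the cell below applies a single Adam update; the learning rate is an arbitrary example value."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch: one manual gradient step; 2e-5 is an arbitrary example learning rate.\n",
+ "optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5)\n",
+ "\n",
+ "with tf.GradientTape() as tape:\n",
+ "  logits = bert_classifier([word_id_data, mask_data, type_id_data])\n",
+ "  step_loss = tf.reduce_mean(\n",
+ "      tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True))\n",
+ "\n",
+ "grads = tape.gradient(step_loss, bert_classifier.trainable_variables)\n",
+ "optimizer.apply_gradients(zip(grads, bert_classifier.trainable_variables))\n",
+ "print(step_loss.numpy())"
+ ]
+ },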
522
+ {
523
+ "cell_type": "markdown",
524
+ "metadata": {
525
+ "id": "mzBqOylZo3og"
526
+ },
527
+ "source": [
528
+ "With the `loss`, you can optimize the model. Please see the [Fine tune_bert](https://www.tensorflow.org/text/tutorials/fine_tune_bert) notebook or the [model training documentation](https://github.com/tensorflow/models/blob/master/official/nlp/docs/train.md) for the full example."
529
+ ]
530
+ }
531
+ ],
532
+ "metadata": {
533
+ "colab": {
534
+ "name": "nlp_modeling_library_intro.ipynb",
535
+ "provenance": [],
536
+ "toc_visible": true
537
+ },
538
+ "kernelspec": {
539
+ "display_name": "Python 3",
540
+ "name": "python3"
541
+ }
542
+ },
543
+ "nbformat": 4,
544
+ "nbformat_minor": 0
545
+ }
models/docs/nlp/load_lm_ckpts.ipynb ADDED
@@ -0,0 +1,692 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "30155835fc9f"
7
+ },
8
+ "source": [
9
+ "##### Copyright 2022 The TensorFlow Authors."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "cellView": "form",
17
+ "id": "906e07f6e562"
18
+ },
19
+ "outputs": [],
20
+ "source": [
21
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
22
+ "# you may not use this file except in compliance with the License.\n",
23
+ "# You may obtain a copy of the License at\n",
24
+ "#\n",
25
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
26
+ "#\n",
27
+ "# Unless required by applicable law or agreed to in writing, software\n",
28
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
29
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
30
+ "# See the License for the specific language governing permissions and\n",
31
+ "# limitations under the License."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {
37
+ "id": "5hrbPTziJK15"
38
+ },
39
+ "source": [
40
+ "# Load LM Checkpoints using Model Garden"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "metadata": {
46
+ "id": "-PYqCW1II75I"
47
+ },
48
+ "source": [
49
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
50
+ " \u003ctd\u003e\n",
51
+ " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/load_lm_ckpts\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
52
+ " \u003c/td\u003e\n",
53
+ " \u003ctd\u003e\n",
54
+ " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/load_lm_ckpts.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
55
+ " \u003c/td\u003e\n",
56
+ " \u003ctd\u003e\n",
57
+ " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/load_lm_ckpts.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
58
+ " \u003c/td\u003e\n",
59
+ " \u003ctd\u003e\n",
60
+ " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/load_lm_ckpts.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
61
+ " \u003c/td\u003e\n",
62
+ "\u003c/table\u003e"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "markdown",
67
+ "metadata": {
68
+ "id": "yyyk1KMlJdWd"
69
+ },
70
+ "source": [
71
+ "This tutorial demonstrates how to load BERT, ALBERT and ELECTRA pretrained checkpoints and use them for downstream tasks.\n",
72
+ "\n",
73
+ "[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development."
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "markdown",
78
+ "metadata": {
79
+ "id": "uEG4RYHolQij"
80
+ },
81
+ "source": [
82
+ "## Install TF Model Garden package"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {
89
+ "id": "kPfC1NJZnJq1"
90
+ },
91
+ "outputs": [],
92
+ "source": [
93
+ "!pip install -U -q \"tf-models-official\""
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "markdown",
98
+ "metadata": {
99
+ "id": "Op9R3zy3lUk8"
100
+ },
101
+ "source": [
102
+ "## Import necessary libraries"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "metadata": {
109
+ "id": "6_y4Rfq23wK-"
110
+ },
111
+ "outputs": [],
112
+ "source": [
113
+ "import os\n",
114
+ "import yaml\n",
115
+ "import json\n",
116
+ "\n",
117
+ "import tensorflow as tf"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "metadata": {
124
+ "id": "xjgv3gllzbYQ"
125
+ },
126
+ "outputs": [],
127
+ "source": [
128
+ "import tensorflow_models as tfm\n",
129
+ "\n",
130
+ "from official.core import exp_factory"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "markdown",
135
+ "metadata": {
136
+ "id": "J-t2mo6VQNfY"
137
+ },
138
+ "source": [
139
+ "## Load BERT model pretrained checkpoints"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "markdown",
144
+ "metadata": {
145
+ "id": "hdBsFnI20LDE"
146
+ },
147
+ "source": [
148
+ "### Select required BERT model"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "metadata": {
155
+ "id": "apn3VgxUlr5G"
156
+ },
157
+ "outputs": [],
158
+ "source": [
159
+ "# @title Download Checkpoint of the Selected Model { display-mode: \"form\", run: \"auto\" }\n",
160
+ "model_display_name = 'BERT-base cased English' # @param ['BERT-base uncased English','BERT-base cased English','BERT-large uncased English', 'BERT-large cased English', 'BERT-large, Uncased (Whole Word Masking)', 'BERT-large, Cased (Whole Word Masking)', 'BERT-base MultiLingual','BERT-base Chinese']\n",
161
+ "\n",
162
+ "if model_display_name == 'BERT-base uncased English':\n",
163
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-12_H-768_A-12.tar.gz\"\n",
164
+ " !tar -xvf \"uncased_L-12_H-768_A-12.tar.gz\"\n",
165
+ "elif model_display_name == 'BERT-base cased English':\n",
166
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/cased_L-12_H-768_A-12.tar.gz\"\n",
167
+ " !tar -xvf \"cased_L-12_H-768_A-12.tar.gz\"\n",
168
+ "elif model_display_name == \"BERT-large uncased English\":\n",
169
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-24_H-1024_A-16.tar.gz\"\n",
170
+ " !tar -xvf \"uncased_L-24_H-1024_A-16.tar.gz\"\n",
171
+ "elif model_display_name == \"BERT-large cased English\":\n",
172
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/cased_L-24_H-1024_A-16.tar.gz\"\n",
173
+ " !tar -xvf \"cased_L-24_H-1024_A-16.tar.gz\"\n",
174
+ "elif model_display_name == \"BERT-large, Uncased (Whole Word Masking)\":\n",
175
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/wwm_uncased_L-24_H-1024_A-16.tar.gz\"\n",
176
+ " !tar -xvf \"wwm_uncased_L-24_H-1024_A-16.tar.gz\"\n",
177
+ "elif model_display_name == \"BERT-large, Cased (Whole Word Masking)\":\n",
178
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/wwm_cased_L-24_H-1024_A-16.tar.gz\"\n",
179
+ " !tar -xvf \"wwm_cased_L-24_H-1024_A-16.tar.gz\"\n",
180
+ "elif model_display_name == \"BERT-base MultiLingual\":\n",
181
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/multi_cased_L-12_H-768_A-12.tar.gz\"\n",
182
+ " !tar -xvf \"multi_cased_L-12_H-768_A-12.tar.gz\"\n",
183
+ "elif model_display_name == \"BERT-base Chinese\":\n",
184
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/chinese_L-12_H-768_A-12.tar.gz\"\n",
185
+ " !tar -xvf \"chinese_L-12_H-768_A-12.tar.gz\""
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "metadata": {
192
+ "id": "jzxyziRuaC95"
193
+ },
194
+ "outputs": [],
195
+ "source": [
196
+ "# Lookup table of the directory name corresponding to each model checkpoint\n",
197
+ "folder_bert_dict = {\n",
198
+ " 'BERT-base uncased English': 'uncased_L-12_H-768_A-12',\n",
199
+ " 'BERT-base cased English': 'cased_L-12_H-768_A-12',\n",
200
+ " 'BERT-large uncased English': 'uncased_L-24_H-1024_A-16',\n",
201
+ " 'BERT-large cased English': 'cased_L-24_H-1024_A-16',\n",
202
+ " 'BERT-large, Uncased (Whole Word Masking)': 'wwm_uncased_L-24_H-1024_A-16',\n",
203
+ " 'BERT-large, Cased (Whole Word Masking)': 'wwm_cased_L-24_H-1024_A-16',\n",
204
+ " 'BERT-base MultiLingual': 'multi_cased_L-12_H-768_A-1',\n",
205
+ " 'BERT-base Chinese': 'chinese_L-12_H-768_A-12'\n",
206
+ "}\n",
207
+ "\n",
208
+ "folder_bert = folder_bert_dict.get(model_display_name)\n",
209
+ "folder_bert"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "markdown",
214
+ "metadata": {
215
+ "id": "q1WrYswpZPlc"
216
+ },
217
+ "source": [
218
+ "### Construct BERT Model Using the New `params.yaml`\n",
219
+ "\n",
220
+ "params.yaml can be used for training with the bundled trainer in addition to constructing the BERT encoder here."
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "execution_count": null,
226
+ "metadata": {
227
+ "id": "quu1s8Hi2szo"
228
+ },
229
+ "outputs": [],
230
+ "source": [
231
+ "config_file = os.path.join(folder_bert, \"params.yaml\")\n",
232
+ "config_dict = yaml.safe_load(tf.io.gfile.GFile(config_file).read())\n",
233
+ "config_dict"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": null,
239
+ "metadata": {
240
+ "id": "3t8o0iG9v8ac"
241
+ },
242
+ "outputs": [],
243
+ "source": [
244
+ "# Method 1: pass encoder config dict into EncoderConfig\n",
245
+ "encoder_config = tfm.nlp.encoders.EncoderConfig(config_dict[\"task\"][\"model\"][\"encoder\"])\n",
246
+ "encoder_config.get().as_dict()"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "code",
251
+ "execution_count": null,
252
+ "metadata": {
253
+ "id": "2I5PetB6wPvb"
254
+ },
255
+ "outputs": [],
256
+ "source": [
257
+ "# Method 2: use override_params_dict function to override default Encoder params\n",
258
+ "encoder_config = tfm.nlp.encoders.EncoderConfig()\n",
259
+ "tfm.hyperparams.override_params_dict(encoder_config, config_dict[\"task\"][\"model\"][\"encoder\"], is_strict=True)\n",
260
+ "encoder_config.get().as_dict()"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "markdown",
265
+ "metadata": {
266
+ "id": "5yHiG_9oS3Uw"
267
+ },
268
+ "source": [
269
+ "### Construct BERT Model Using the Old `bert_config.json`"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": null,
275
+ "metadata": {
276
+ "id": "WEyaqLcW3nne"
277
+ },
278
+ "outputs": [],
279
+ "source": [
280
+ "bert_config_file = os.path.join(folder_bert, \"bert_config.json\")\n",
281
+ "config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())\n",
282
+ "config_dict"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": null,
288
+ "metadata": {
289
+ "id": "xSIcaW9tdrl4"
290
+ },
291
+ "outputs": [],
292
+ "source": [
293
+ "encoder_config = tfm.nlp.encoders.EncoderConfig({\n",
294
+ " 'type':'bert',\n",
295
+ " 'bert': config_dict\n",
296
+ "})\n",
297
+ "\n",
298
+ "encoder_config.get().as_dict()"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "markdown",
303
+ "metadata": {
304
+ "id": "yZznAP--TDLe"
305
+ },
306
+ "source": [
307
+ "### Construct a classifier with `encoder_config`\n",
308
+ "\n",
309
+ "Here, we construct a new BERT Classifier with 2 classes and plot its model architecture. A BERT Classifier consists of a BERT encoder using the selected encoder config, a Dropout layer and a MLP classification head."
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": null,
315
+ "metadata": {
316
+ "id": "Ny962I8nqs4n"
317
+ },
318
+ "outputs": [],
319
+ "source": [
320
+ "bert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)\n",
321
+ "bert_classifier = tfm.nlp.models.BertClassifier(network=bert_encoder, num_classes=2)\n",
322
+ "\n",
323
+ "tf.keras.utils.plot_model(bert_classifier)"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "markdown",
328
+ "metadata": {
329
+ "id": "IStKfxXkTJMu"
330
+ },
331
+ "source": [
332
+ "### Load Pretrained Weights into the BERT Classifier\n",
333
+ "\n",
334
+ "The provided pretrained checkpoint only contains weights for the BERT Encoder within the BERT Classifier. Weights for the Classification Head is still randomly initialized."
335
+ ]
336
+ },
337
+ {
338
+ "cell_type": "code",
339
+ "execution_count": null,
340
+ "metadata": {
341
+ "id": "G9_XCBpOEo4y"
342
+ },
343
+ "outputs": [],
344
+ "source": [
345
+ "checkpoint = tf.train.Checkpoint(encoder=bert_encoder)\n",
346
+ "checkpoint.read(\n",
347
+ " os.path.join(folder_bert, 'bert_model.ckpt')).expect_partial().assert_existing_objects_matched()"
348
+ ]
349
+ },
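+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check, you can run the classifier on a batch of dummy token IDs and confirm the output shape; the values below are arbitrary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np  # numpy is not imported earlier in this notebook\n",
+ "\n",
+ "# Minimal sketch: dummy inputs just to confirm the classifier's output shape.\n",
+ "batch_size, seq_len = 2, 16\n",
+ "dummy_word_ids = np.random.randint(100, size=(batch_size, seq_len))\n",
+ "dummy_mask = np.ones((batch_size, seq_len), dtype=np.int32)\n",
+ "dummy_type_ids = np.zeros((batch_size, seq_len), dtype=np.int32)\n",
+ "\n",
+ "logits = bert_classifier([dummy_word_ids, dummy_mask, dummy_type_ids])\n",
+ "print(logits.shape)  # (batch_size, num_classes)"
+ ]
+ },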
350
+ {
351
+ "cell_type": "markdown",
352
+ "metadata": {
353
+ "id": "E6Hu1FFgQWUU"
354
+ },
355
+ "source": [
356
+ "## Load ALBERT model pretrained checkpoints"
357
+ ]
358
+ },
359
+ {
360
+ "cell_type": "code",
361
+ "execution_count": null,
362
+ "metadata": {
363
+ "id": "TWUtFeWxQn0V"
364
+ },
365
+ "outputs": [],
366
+ "source": [
367
+ "# @title Download Checkpoint of the Selected Model { display-mode: \"form\", run: \"auto\" }\n",
368
+ "albert_model_display_name = 'ALBERT-xxlarge English' # @param ['ALBERT-base English', 'ALBERT-large English', 'ALBERT-xlarge English', 'ALBERT-xxlarge English']\n",
369
+ "\n",
370
+ "if albert_model_display_name == 'ALBERT-base English':\n",
371
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_base.tar.gz\"\n",
372
+ " !tar -xvf \"albert_base.tar.gz\"\n",
373
+ "elif albert_model_display_name == 'ALBERT-large English':\n",
374
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_large.tar.gz\"\n",
375
+ " !tar -xvf \"albert_large.tar.gz\"\n",
376
+ "elif albert_model_display_name == \"ALBERT-xlarge English\":\n",
377
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_xlarge.tar.gz\"\n",
378
+ " !tar -xvf \"albert_xlarge.tar.gz\"\n",
379
+ "elif albert_model_display_name == \"ALBERT-xxlarge English\":\n",
380
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_xxlarge.tar.gz\"\n",
381
+ " !tar -xvf \"albert_xxlarge.tar.gz\""
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "code",
386
+ "execution_count": null,
387
+ "metadata": {
388
+ "id": "5lZDWD7zUAAO"
389
+ },
390
+ "outputs": [],
391
+ "source": [
392
+ "# Lookup table of the directory name corresponding to each model checkpoint\n",
393
+ "folder_albert_dict = {\n",
394
+ " 'ALBERT-base English': 'albert_base',\n",
395
+ " 'ALBERT-large English': 'albert_large',\n",
396
+ " 'ALBERT-xlarge English': 'albert_xlarge',\n",
397
+ " 'ALBERT-xxlarge English': 'albert_xxlarge'\n",
398
+ "}\n",
399
+ "\n",
400
+ "folder_albert = folder_albert_dict.get(albert_model_display_name)\n",
401
+ "folder_albert"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "markdown",
406
+ "metadata": {
407
+ "id": "ftXwmObdU2fS"
408
+ },
409
+ "source": [
410
+ "### Construct ALBERT Model Using the New `params.yaml`\n",
411
+ "\n",
412
+ "params.yaml can be used for training with the bundled trainer in addition to constructing the BERT encoder here."
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": null,
418
+ "metadata": {
419
+ "id": "VXn20q2oU1UJ"
420
+ },
421
+ "outputs": [],
422
+ "source": [
423
+ "config_file = os.path.join(folder_albert, \"params.yaml\")\n",
424
+ "config_dict = yaml.safe_load(tf.io.gfile.GFile(config_file).read())\n",
425
+ "config_dict"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "execution_count": null,
431
+ "metadata": {
432
+ "id": "Uo_TSMSvWOX_"
433
+ },
434
+ "outputs": [],
435
+ "source": [
436
+ "# Method 1: pass encoder config dict into EncoderConfig\n",
437
+ "encoder_config = tfm.nlp.encoders.EncoderConfig(config_dict[\"task\"][\"model\"][\"encoder\"])\n",
438
+ "encoder_config.get().as_dict()"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": null,
444
+ "metadata": {
445
+ "id": "u7oJe93uWcy0"
446
+ },
447
+ "outputs": [],
448
+ "source": [
449
+ "# Method 2: use override_params_dict function to override default Encoder params\n",
450
+ "encoder_config = tfm.nlp.encoders.EncoderConfig()\n",
451
+ "tfm.hyperparams.override_params_dict(encoder_config, config_dict[\"task\"][\"model\"][\"encoder\"], is_strict=True)\n",
452
+ "encoder_config.get().as_dict()"
453
+ ]
454
+ },
455
+ {
456
+ "cell_type": "markdown",
457
+ "metadata": {
458
+ "id": "abpQFw80Wx6c"
459
+ },
460
+ "source": [
461
+ "### Construct ALBERT Model Using the Old `albert_config.json`"
462
+ ]
463
+ },
464
+ {
465
+ "cell_type": "code",
466
+ "execution_count": null,
467
+ "metadata": {
468
+ "id": "Xb99qms6WuPa"
469
+ },
470
+ "outputs": [],
471
+ "source": [
472
+ "albert_config_file = os.path.join(folder_albert, \"albert_config.json\")\n",
473
+ "config_dict = json.loads(tf.io.gfile.GFile(albert_config_file).read())\n",
474
+ "config_dict"
475
+ ]
476
+ },
477
+ {
478
+ "cell_type": "code",
479
+ "execution_count": null,
480
+ "metadata": {
481
+ "id": "mCW0RJHcEtVV"
482
+ },
483
+ "outputs": [],
484
+ "source": [
485
+ "encoder_config = tfm.nlp.encoders.EncoderConfig({\n",
486
+ " 'type':'albert',\n",
487
+ " 'albert': config_dict\n",
488
+ "})\n",
489
+ "\n",
490
+ "encoder_config.get().as_dict()"
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "markdown",
495
+ "metadata": {
496
+ "id": "EIAMaOxdZw5u"
497
+ },
498
+ "source": [
499
+ "### Construct a Classifier with `encoder_config`\n",
500
+ "\n",
501
+ "Here, we construct a new BERT Classifier with 2 classes and plot its model architecture. A BERT Classifier consists of a BERT encoder using the selected encoder config, a Dropout layer and a MLP classification head."
502
+ ]
503
+ },
504
+ {
505
+ "cell_type": "code",
506
+ "execution_count": null,
507
+ "metadata": {
508
+ "id": "xTkUisEEFEey"
509
+ },
510
+ "outputs": [],
511
+ "source": [
512
+ "albert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)\n",
513
+ "albert_classifier = tfm.nlp.models.BertClassifier(network=albert_encoder, num_classes=2)\n",
514
+ "\n",
515
+ "tf.keras.utils.plot_model(albert_classifier)"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "markdown",
520
+ "metadata": {
521
+ "id": "m6EG_7CaZ2rI"
522
+ },
523
+ "source": [
524
+ "### Load Pretrained Weights into the Classifier\n",
525
+ "\n",
526
+ "The provided pretrained checkpoint only contains weights for the ALBERT Encoder within the ALBERT Classifier. Weights for the Classification Head is still randomly initialized."
527
+ ]
528
+ },
529
+ {
530
+ "cell_type": "code",
531
+ "execution_count": null,
532
+ "metadata": {
533
+ "id": "7dOG3agXZ9Dx"
534
+ },
535
+ "outputs": [],
536
+ "source": [
537
+ "checkpoint = tf.train.Checkpoint(encoder=albert_encoder)\n",
538
+ "checkpoint.read(\n",
539
+ " os.path.join(folder_albert, 'bert_model.ckpt')).expect_partial().assert_existing_objects_matched()"
540
+ ]
541
+ },
542
+ {
543
+ "cell_type": "markdown",
544
+ "metadata": {
545
+ "id": "6xsbeS-EcCqu"
546
+ },
547
+ "source": [
548
+ "## Load ELECTRA model pretrained checkpoints"
549
+ ]
550
+ },
551
+ {
552
+ "cell_type": "code",
553
+ "execution_count": null,
554
+ "metadata": {
555
+ "id": "VpwIrAR4cIBF"
556
+ },
557
+ "outputs": [],
558
+ "source": [
559
+ "# @title Download Checkpoint of the Selected Model { display-mode: \"form\", run: \"auto\" }\n",
560
+ "electra_model_display_name = 'ELECTRA-small English' # @param ['ELECTRA-small English', 'ELECTRA-base English']\n",
561
+ "\n",
562
+ "if electra_model_display_name == 'ELECTRA-small English':\n",
563
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/electra/small.tar.gz\"\n",
564
+ " !tar -xvf \"small.tar.gz\"\n",
565
+ "elif electra_model_display_name == 'ELECTRA-base English':\n",
566
+ " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/electra/base.tar.gz\"\n",
567
+ " !tar -xvf \"base.tar.gz\""
568
+ ]
569
+ },
570
+ {
571
+ "cell_type": "code",
572
+ "execution_count": null,
573
+ "metadata": {
574
+ "id": "fy4FmsNOhlNa"
575
+ },
576
+ "outputs": [],
577
+ "source": [
578
+ "# Lookup table of the directory name corresponding to each model checkpoint\n",
579
+ "folder_electra_dict = {\n",
580
+ " 'ELECTRA-small English': 'small',\n",
581
+ " 'ELECTRA-base English': 'base'\n",
582
+ "}\n",
583
+ "\n",
584
+ "folder_electra = folder_electra_dict.get(electra_model_display_name)\n",
585
+ "folder_electra"
586
+ ]
587
+ },
588
+ {
589
+ "cell_type": "markdown",
590
+ "metadata": {
591
+ "id": "rgAcf-Fl3RTG"
592
+ },
593
+ "source": [
594
+ "### Construct BERT Model Using the `params.yaml`\n",
595
+ "\n",
596
+ "params.yaml can be used for training with the bundled trainer in addition to constructing the BERT encoder here."
597
+ ]
598
+ },
599
+ {
600
+ "cell_type": "code",
601
+ "execution_count": null,
602
+ "metadata": {
603
+ "id": "ZNBg5xzqh0Gr"
604
+ },
605
+ "outputs": [],
606
+ "source": [
607
+ "config_file = os.path.join(folder_electra, \"params.yaml\")\n",
608
+ "config_dict = yaml.safe_load(tf.io.gfile.GFile(config_file).read())\n",
609
+ "config_dict"
610
+ ]
611
+ },
612
+ {
613
+ "cell_type": "code",
614
+ "execution_count": null,
615
+ "metadata": {
616
+ "id": "i-yX-KgJyduv"
617
+ },
618
+ "outputs": [],
619
+ "source": [
620
+ "disc_encoder_config = tfm.nlp.encoders.EncoderConfig(\n",
621
+ " config_dict['model']['discriminator_encoder']\n",
622
+ ")\n",
623
+ "\n",
624
+ "disc_encoder_config.get().as_dict()"
625
+ ]
626
+ },
627
+ {
628
+ "cell_type": "markdown",
629
+ "metadata": {
630
+ "id": "1AdrMkH73VYz"
631
+ },
632
+ "source": [
633
+ "### Construct a Classifier with `encoder_config`\n",
634
+ "\n",
635
+ "Here, we construct a Classifier with 2 classes and plot its model architecture. A Classifier consists of a ELECTRA discriminator encoder using the selected encoder config, a Dropout layer and a MLP classification head.\n",
636
+ "\n",
637
+ "**Note**: The generator is discarded and the discriminator is used for downstream tasks"
638
+ ]
639
+ },
640
+ {
641
+ "cell_type": "code",
642
+ "execution_count": null,
643
+ "metadata": {
644
+ "id": "98Pt-SxszAvN"
645
+ },
646
+ "outputs": [],
647
+ "source": [
648
+ "disc_encoder = tfm.nlp.encoders.build_encoder(disc_encoder_config)\n",
649
+ "elctra_dic_classifier = tfm.nlp.models.BertClassifier(network=disc_encoder, num_classes=2)\n",
650
+ "tf.keras.utils.plot_model(elctra_dic_classifier)"
651
+ ]
652
+ },
653
+ {
654
+ "cell_type": "markdown",
655
+ "metadata": {
656
+ "id": "aWQ2FKj64X5U"
657
+ },
658
+ "source": [
659
+ "### Load Pretrained Weights into the Classifier\n",
660
+ "\n",
661
+ "The provided pretrained checkpoint contains weights for the entire ELECTRA model. We are only loading its discriminator (conveninently named as `encoder`) wights within the Classifier. Weights for the Classification Head is still randomly initialized."
662
+ ]
663
+ },
664
+ {
665
+ "cell_type": "code",
666
+ "execution_count": null,
667
+ "metadata": {
668
+ "id": "99pznFJszQfV"
669
+ },
670
+ "outputs": [],
671
+ "source": [
672
+ "checkpoint = tf.train.Checkpoint(encoder=disc_encoder)\n",
673
+ "checkpoint.read(\n",
674
+ " tf.train.latest_checkpoint(os.path.join(folder_electra))\n",
675
+ " ).expect_partial().assert_existing_objects_matched()"
676
+ ]
677
+ }
678
+ ],
679
+ "metadata": {
680
+ "colab": {
681
+ "name": "load_lm_ckpts.ipynb",
682
+ "provenance": [],
683
+ "toc_visible": true
684
+ },
685
+ "kernelspec": {
686
+ "display_name": "Python 3",
687
+ "name": "python3"
688
+ }
689
+ },
690
+ "nbformat": 4,
691
+ "nbformat_minor": 0
692
+ }
models/docs/orbit/index.ipynb ADDED
@@ -0,0 +1,898 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "Tce3stUlHN0L"
7
+ },
8
+ "source": [
9
+ "##### Copyright 2020 The TensorFlow Authors."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "cellView": "form",
17
+ "id": "tuOe1ymfHZPu"
18
+ },
19
+ "outputs": [],
20
+ "source": [
21
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
22
+ "# you may not use this file except in compliance with the License.\n",
23
+ "# You may obtain a copy of the License at\n",
24
+ "#\n",
25
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
26
+ "#\n",
27
+ "# Unless required by applicable law or agreed to in writing, software\n",
28
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
29
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
30
+ "# See the License for the specific language governing permissions and\n",
31
+ "# limitations under the License."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {
37
+ "id": "qFdPvlXBOdUN"
38
+ },
39
+ "source": [
40
+ "# Training with Orbit"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "metadata": {
46
+ "id": "MfBg1C5NB3X0"
47
+ },
48
+ "source": [
49
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
50
+ " \u003ctd\u003e\n",
51
+ " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/orbit\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
52
+ " \u003c/td\u003e\n",
53
+ " \u003ctd\u003e\n",
54
+ " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
55
+ " \u003c/td\u003e\n",
56
+ " \u003ctd\u003e\n",
57
+ " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
58
+ " \u003c/td\u003e\n",
59
+ " \u003ctd\u003e\n",
60
+ " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
61
+ " \u003c/td\u003e\n",
62
+ "\n",
63
+ "\u003c/table\u003e"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "markdown",
68
+ "metadata": {
69
+ "id": "456h0idS2Xcq"
70
+ },
71
+ "source": [
72
+ "This example will work through fine-tuning a BERT model using the [Orbit](https://www.tensorflow.org/api_docs/python/orbit) training library.\n",
73
+ "\n",
74
+ "Orbit is a flexible, lightweight library designed to make it easy to write [custom training loops](https://www.tensorflow.org/tutorials/distribute/custom_training) in TensorFlow. Orbit handles common model training tasks such as saving checkpoints, running model evaluations, and setting up summary writing, while giving users full control over implementing the inner training loop. It integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU).\n",
75
+ "\n",
76
+ "Most examples on [tensorflow.org](https://www.tensorflow.org/) use custom training loops or [model.fit()](https://www.tensorflow.org/api_docs/python/tf/keras/Model) from Keras. Orbit is a good alternative to `model.fit` if your model is complex and your training loop requires more flexibility, control, or customization. Also, using Orbit can simplify the code when there are many different model architectures that all use the same custom training loop.\n",
77
+ "\n",
78
+ "This tutorial focuses on setting up and using Orbit, rather than details about BERT, model construction, and data processing. For more in-depth tutorials on these topics, refer to the following tutorials:\n",
79
+ "\n",
80
+ "* [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) - which goes into detail on these sub-topics.\n",
81
+ "* [Fine tune BERT for GLUE on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) - which generalizes the code to run any BERT configuration on any [GLUE](https://www.tensorflow.org/datasets/catalog/glue) sub-task, and runs on TPU."
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "markdown",
86
+ "metadata": {
87
+ "id": "TJ4m3khW3p_W"
88
+ },
89
+ "source": [
90
+ "## Install the TensorFlow Models package\n",
91
+ "\n",
92
+ "Install and import the necessary packages, then configure all the objects necessary for training a model.\n",
93
+ "\n"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "metadata": {
100
+ "id": "FZlj0U8Aq9Gt"
101
+ },
102
+ "outputs": [],
103
+ "source": [
104
+ "!pip install -q opencv-python\n",
105
+ "!pip install tensorflow>=2.9.0 tf-models-official"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "metadata": {
111
+ "id": "MEJkRrmapr16"
112
+ },
113
+ "source": [
114
+ "The `tf-models-official` package contains both the `orbit` and `tensorflow_models` modules."
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "metadata": {
121
+ "id": "dUVPW84Zucuq"
122
+ },
123
+ "outputs": [],
124
+ "source": [
125
+ "import tensorflow_models as tfm\n",
126
+ "import orbit"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "markdown",
131
+ "metadata": {
132
+ "id": "18Icocf3lwYD"
133
+ },
134
+ "source": [
135
+ "## Setup for training\n",
136
+ "\n",
137
+ "This tutorial does not focus on configuring the environment, building the model and optimizer, and loading data. All these techniques are covered in more detail in the [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) and [Fine tune BERT with GLUE](https://www.tensorflow.org/text/tutorials/bert_glue) tutorials.\n",
138
+ "\n",
139
+ "To view how the training is set up for this tutorial, expand the rest of this section.\n",
140
+ "\n",
141
+ " \u003c!-- \u003cdiv class=\"tfo-display-only-on-site\"\u003e\u003cdevsite-expandable\u003e\n",
142
+ " \u003cbutton type=\"button\" class=\"button-red button expand-control\"\u003eExpand Section\u003c/button\u003e --\u003e"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "markdown",
147
+ "metadata": {
148
+ "id": "Ljy0z-i3okCS"
149
+ },
150
+ "source": [
151
+ "### Import the necessary packages\n",
152
+ "\n",
153
+ "Import the BERT model and dataset building library from [Tensorflow Model Garden](https://github.com/tensorflow/models)."
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": null,
159
+ "metadata": {
160
+ "id": "gCBo6wxA2b5n"
161
+ },
162
+ "outputs": [],
163
+ "source": [
164
+ "import glob\n",
165
+ "import os\n",
166
+ "import pathlib\n",
167
+ "import tempfile\n",
168
+ "import time\n",
169
+ "\n",
170
+ "import numpy as np\n",
171
+ "\n",
172
+ "import tensorflow as tf"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": null,
178
+ "metadata": {
179
+ "id": "PG1kwhnvq3VC"
180
+ },
181
+ "outputs": [],
182
+ "source": [
183
+ "from official.nlp.data import sentence_prediction_dataloader\n",
184
+ "from official.nlp import optimization"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "markdown",
189
+ "metadata": {
190
+ "id": "PsbhUV_p3wxN"
191
+ },
192
+ "source": [
193
+ "### Configure the distribution strategy\n",
194
+ "\n",
195
+ "While `tf.distribute` won't help the model's runtime if you're running on a single machine or GPU, it's necessary for TPUs. Setting up a distribution strategy allows you to use the same code regardless of the configuration."
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": null,
201
+ "metadata": {
202
+ "id": "PG702dqstXIk"
203
+ },
204
+ "outputs": [],
205
+ "source": [
206
+ "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
207
+ "\n",
208
+ "if 'GPU' in ''.join(logical_device_names):\n",
209
+ " strategy = tf.distribute.MirroredStrategy()\n",
210
+ "elif 'TPU' in ''.join(logical_device_names):\n",
211
+ " resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n",
212
+ " tf.config.experimental_connect_to_cluster(resolver)\n",
213
+ " tf.tpu.experimental.initialize_tpu_system(resolver)\n",
214
+ " strategy = tf.distribute.TPUStrategy(resolver)\n",
215
+ "else:\n",
216
+ " strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "markdown",
221
+ "metadata": {
222
+ "id": "eaQgM98deAMu"
223
+ },
224
+ "source": [
225
+ "For more information about the TPU setup, refer to the [TPU guide](https://www.tensorflow.org/guide/tpu)."
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "markdown",
230
+ "metadata": {
231
+ "id": "7aOxMLLV32Zm"
232
+ },
233
+ "source": [
234
+ "### Create a model and an optimizer"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": null,
240
+ "metadata": {
241
+ "id": "YRdWzOfK3_56"
242
+ },
243
+ "outputs": [],
244
+ "source": [
245
+ "max_seq_length = 128\n",
246
+ "learning_rate = 3e-5\n",
247
+ "num_train_epochs = 3\n",
248
+ "train_batch_size = 32\n",
249
+ "eval_batch_size = 64\n",
250
+ "\n",
251
+ "train_data_size = 3668\n",
252
+ "steps_per_epoch = int(train_data_size / train_batch_size)\n",
253
+ "\n",
254
+ "train_steps = steps_per_epoch * num_train_epochs\n",
255
+ "warmup_steps = int(train_steps * 0.1)\n",
256
+ "\n",
257
+ "print(\"train batch size: \", train_batch_size)\n",
258
+ "print(\"train epochs: \", num_train_epochs)\n",
259
+ "print(\"steps_per_epoch: \", steps_per_epoch)"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": null,
265
+ "metadata": {
266
+ "id": "BVw3886Ysse6"
267
+ },
268
+ "outputs": [],
269
+ "source": [
270
+ "model_dir = pathlib.Path(tempfile.mkdtemp())\n",
271
+ "print(model_dir)"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "markdown",
276
+ "metadata": {
277
+ "id": "mu9cV7ew-cVe"
278
+ },
279
+ "source": [
280
+ "\n",
281
+ "Create a BERT Classifier model and a simple optimizer. They must be created inside `strategy.scope` so that the variables can be distributed. "
282
+ ]
283
+ },
284
+ {
285
+ "cell_type": "code",
286
+ "execution_count": null,
287
+ "metadata": {
288
+ "id": "gmwtX0cp-mj5"
289
+ },
290
+ "outputs": [],
291
+ "source": [
292
+ "with strategy.scope():\n",
293
+ " encoder_network = tfm.nlp.encoders.build_encoder(\n",
294
+ " tfm.nlp.encoders.EncoderConfig(type=\"bert\"))\n",
295
+ " classifier_model = tfm.nlp.models.BertClassifier(\n",
296
+ " network=encoder_network, num_classes=2)\n",
297
+ "\n",
298
+ " optimizer = optimization.create_optimizer(\n",
299
+ " init_lr=3e-5,\n",
300
+ " num_train_steps=steps_per_epoch * num_train_epochs,\n",
301
+ " num_warmup_steps=warmup_steps,\n",
302
+ " end_lr=0.0,\n",
303
+ " optimizer_type='adamw')"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": null,
309
+ "metadata": {
310
+ "id": "jwJSfewG5jVV"
311
+ },
312
+ "outputs": [],
313
+ "source": [
314
+ "tf.keras.utils.plot_model(classifier_model)"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "markdown",
319
+ "metadata": {
320
+ "id": "IQy5pYgAf8Ft"
321
+ },
322
+ "source": [
323
+ "### Initialize from a Checkpoint"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "code",
328
+ "execution_count": null,
329
+ "metadata": {
330
+ "id": "6CE14GEybgRR"
331
+ },
332
+ "outputs": [],
333
+ "source": [
334
+ "bert_dir = 'gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12/'\n",
335
+ "tf.io.gfile.listdir(bert_dir)"
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "execution_count": null,
341
+ "metadata": {
342
+ "id": "x7fwxz9xidKt"
343
+ },
344
+ "outputs": [],
345
+ "source": [
346
+ "bert_checkpoint = bert_dir + 'bert_model.ckpt'"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": null,
352
+ "metadata": {
353
+ "id": "q7EfwVCRe7N_"
354
+ },
355
+ "outputs": [],
356
+ "source": [
357
+ "def init_from_ckpt_fn():\n",
358
+ " init_checkpoint = tf.train.Checkpoint(**classifier_model.checkpoint_items)\n",
359
+ " with strategy.scope():\n",
360
+ " (init_checkpoint\n",
361
+ " .read(bert_checkpoint)\n",
362
+ " .expect_partial()\n",
363
+ " .assert_existing_objects_matched())"
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "code",
368
+ "execution_count": null,
369
+ "metadata": {
370
+ "id": "M0LUMlsde-2f"
371
+ },
372
+ "outputs": [],
373
+ "source": [
374
+ "with strategy.scope():\n",
375
+ " init_from_ckpt_fn()"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "markdown",
380
+ "metadata": {
381
+ "id": "gAuns4vN_IYV"
382
+ },
383
+ "source": [
384
+ "\n",
385
+ "To use Orbit, create a `tf.train.CheckpointManager` object."
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "code",
390
+ "execution_count": null,
391
+ "metadata": {
392
+ "id": "i7NwM1Jq_MX7"
393
+ },
394
+ "outputs": [],
395
+ "source": [
396
+ "checkpoint = tf.train.Checkpoint(model=classifier_model, optimizer=optimizer)\n",
397
+ "checkpoint_manager = tf.train.CheckpointManager(\n",
398
+ " checkpoint,\n",
399
+ " directory=model_dir,\n",
400
+ " max_to_keep=5,\n",
401
+ " step_counter=optimizer.iterations,\n",
402
+ " checkpoint_interval=steps_per_epoch,\n",
403
+ " init_fn=init_from_ckpt_fn)"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "markdown",
408
+ "metadata": {
409
+ "id": "nzeiAFhcCOAo"
410
+ },
411
+ "source": [
412
+ "### Create distributed datasets\n",
413
+ "\n",
414
+ "As a shortcut for this tutorial, the [GLUE/MPRC dataset](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc) has been converted to a pair of [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) files containing serialized `tf.train.Example` protos.\n",
415
+ "\n",
416
+ "The data was converted using [this script](https://github.com/tensorflow/models/blob/r2.9.0/official/nlp/data/create_finetuning_data.py).\n",
417
+ "\n"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": null,
423
+ "metadata": {
424
+ "id": "ZVfbiT1dCnDk"
425
+ },
426
+ "outputs": [],
427
+ "source": [
428
+ "train_data_path = \"gs://download.tensorflow.org/data/model_garden_colab/mrpc_train.tf_record\"\n",
429
+ "eval_data_path = \"gs://download.tensorflow.org/data/model_garden_colab/mrpc_eval.tf_record\"\n",
430
+ "\n",
431
+ "def _dataset_fn(input_file_pattern, \n",
432
+ " global_batch_size, \n",
433
+ " is_training, \n",
434
+ " input_context=None):\n",
435
+ " data_config = sentence_prediction_dataloader.SentencePredictionDataConfig(\n",
436
+ " input_path=input_file_pattern,\n",
437
+ " seq_length=max_seq_length,\n",
438
+ " global_batch_size=global_batch_size,\n",
439
+ " is_training=is_training)\n",
440
+ " return sentence_prediction_dataloader.SentencePredictionDataLoader(\n",
441
+ " data_config).load(input_context=input_context)\n",
442
+ "\n",
443
+ "train_dataset = orbit.utils.make_distributed_dataset(\n",
444
+ " strategy, _dataset_fn, input_file_pattern=train_data_path,\n",
445
+ " global_batch_size=train_batch_size, is_training=True)\n",
446
+ "eval_dataset = orbit.utils.make_distributed_dataset(\n",
447
+ " strategy, _dataset_fn, input_file_pattern=eval_data_path,\n",
448
+ " global_batch_size=eval_batch_size, is_training=False)\n",
449
+ "\n"
450
+ ]
451
+ },
452
+ {
453
+ "cell_type": "markdown",
454
+ "metadata": {
455
+ "id": "dPgiDBQCjsXW"
456
+ },
457
+ "source": [
458
+ "### Create a loss function\n"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": null,
464
+ "metadata": {
465
+ "id": "7MCUmmo2jvXl"
466
+ },
467
+ "outputs": [],
468
+ "source": [
469
+ "def loss_fn(labels, logits):\n",
470
+ " \"\"\"Classification loss.\"\"\"\n",
471
+ " labels = tf.squeeze(labels)\n",
472
+ " log_probs = tf.nn.log_softmax(logits, axis=-1)\n",
473
+ " one_hot_labels = tf.one_hot(\n",
474
+ " tf.cast(labels, dtype=tf.int32), depth=2, dtype=tf.float32)\n",
475
+ " per_example_loss = -tf.reduce_sum(\n",
476
+ " tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)\n",
477
+ " return tf.reduce_mean(per_example_loss)"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "markdown",
482
+ "metadata": {
483
+ "id": "ohlO-8FQkwsr"
484
+ },
485
+ "source": [
486
+ " \u003c/devsite-expandable\u003e\u003c/div\u003e"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "markdown",
491
+ "metadata": {
492
+ "id": "ymhbvPaEJ96T"
493
+ },
494
+ "source": [
495
+ "## Controllers, Trainers and Evaluators\n",
496
+ "\n",
497
+ "When using Orbit, the `orbit.Controller` class drives the training. The Controller handles the details of distribution strategies, step counting, TensorBoard summaries, and checkpointing.\n",
498
+ "\n",
499
+ "To implement the training and evaluation, pass a `trainer` and `evaluator`, which are subclass instances of `orbit.AbstractTrainer` and `orbit.AbstractEvaluator`. Keeping with Orbit's light-weight design, these two classes have a minimal interface.\n",
500
+ "\n",
501
+ "The Controller drives training and evaluation by calling `trainer.train(num_steps)` and `evaluator.evaluate(num_steps)`. These `train` and `evaluate` methods return a dictionary of results for logging.\n",
502
+ "\n"
503
+ ]
504
+ },
505
+ {
506
+ "cell_type": "markdown",
507
+ "metadata": {
508
+ "id": "a6sU2vBeyXtu"
509
+ },
510
+ "source": [
511
+ "Training is broken into chunks of length `num_steps`. This is set by the Controller's [`steps_per_loop`](https://tensorflow.org/api_docs/python/orbit/Controller#args) argument. With the trainer and evaluator abstract base classes, the meaning of `num_steps` is entirely determined by the implementer.\n",
512
+ "\n",
513
+ "Some common examples include:\n",
514
+ "\n",
515
+ "* Having the chunks represent dataset-epoch boundaries, like the default keras setup. \n",
516
+ "* Using it to more efficiently dispatch a number of training steps to an accelerator with a single `tf.function` call (like the `steps_per_execution` argument to `Model.compile`). \n",
517
+ "* Subdividing into smaller chunks as needed.\n"
518
+ ]
519
+ },
520
+ {
521
+ "cell_type": "markdown",
522
+ "metadata": {
523
+ "id": "p4mXGIRJsf1j"
524
+ },
525
+ "source": [
526
+ "### StandardTrainer and StandardEvaluator\n",
527
+ "\n",
528
+ "Orbit provides two additional classes, `orbit.StandardTrainer` and `orbit.StandardEvaluator`, to give more structure around the training and evaluation loops.\n",
529
+ "\n",
530
+ "With StandardTrainer, you only need to set `train_loop_begin`, `train_step`, and `train_loop_end`. The base class handles the loops, dataset logic, and `tf.function` (according to the options set by their `orbit.StandardTrainerOptions`). This is simpler than `orbit.AbstractTrainer`, which requires you to handle the entire loop. StandardEvaluator has a similar structure and simplification to StandardTrainer.\n",
531
+ "\n",
532
+ "This is effectively an implementation of the `steps_per_execution` approach used by Keras."
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "markdown",
537
+ "metadata": {
538
+ "id": "-hvZ8PvohmR5"
539
+ },
540
+ "source": [
541
+ "Contrast this with Keras, where training is divided both into epochs (a single pass over the dataset) and `steps_per_execution`(set within [`Model.compile`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile). In Keras, metric averages are typically accumulated over an epoch, and reported \u0026 reset between epochs. For efficiency, `steps_per_execution` only controls the number of training steps made per call.\n",
542
+ "\n",
543
+ "In this simple case, `steps_per_loop` (within `StandardTrainer`) will handle both the metric resets and the number of steps per call. \n"
544
+ ]
545
+ },
546
+ {
547
+ "cell_type": "markdown",
548
+ "metadata": {
549
+ "id": "NoDFN1L-1jIu"
550
+ },
551
+ "source": [
552
+ "The minimal setup when using these base classes is to implement the methods as follows:\n",
553
+ "\n",
554
+ "1. `StandardTrainer.train_loop_begin` - Reset your training metrics.\n",
555
+ "2. `StandardTrainer.train_step` - Apply a single gradient update.\n",
556
+ "3. `StandardTrainer.train_loop_end` - Report your training metrics.\n",
557
+ "\n",
558
+ "and\n",
559
+ "\n",
560
+ "4. `StandardEvaluator.eval_begin` - Reset your evaluation metrics.\n",
561
+ "5. `StandardEvaluator.eval_step` - Run a single evaluation setep.\n",
562
+ "6. `StandardEvaluator.eval_reduce` - This is not necessary in this simple setup.\n",
563
+ "7. `StandardEvaluator.eval_end` - Report your evaluation metrics.\n",
564
+ "\n",
565
+ "Depending on the settings, the base class may wrap the `train_step` and `eval_step` code in `tf.function` or `tf.while_loop`, which has some limitations compared to standard python."
566
+ ]
567
+ },
568
+ {
569
+ "cell_type": "markdown",
570
+ "metadata": {
571
+ "id": "3KPA0NDZt2JD"
572
+ },
573
+ "source": [
574
+ "### Define the trainer class"
575
+ ]
576
+ },
577
+ {
578
+ "cell_type": "markdown",
579
+ "metadata": {
580
+ "id": "6LDPsvJwfuPR"
581
+ },
582
+ "source": [
583
+ "In this section you'll create a subclass of `orbit.StandardTrainer` for this task. \n",
584
+ "\n",
585
+ "Note: To better explain the `BertClassifierTrainer` class, this section defines each method as a stand-alone function and assembles them into a class at the end.\n",
586
+ "\n",
587
+ "The trainer needs access to the training data, model, optimizer, and distribution strategy. Pass these as arguments to the initializer.\n",
588
+ "\n",
589
+ "Define a single training metric, `training_loss`, using `tf.keras.metrics.Mean`. "
590
+ ]
591
+ },
592
+ {
593
+ "cell_type": "code",
594
+ "execution_count": null,
595
+ "metadata": {
596
+ "id": "6DQYZN5ax-MG"
597
+ },
598
+ "outputs": [],
599
+ "source": [
600
+ "def trainer_init(self,\n",
601
+ " train_dataset,\n",
602
+ " model,\n",
603
+ " optimizer,\n",
604
+ " strategy):\n",
605
+ " self.strategy = strategy\n",
606
+ " with self.strategy.scope():\n",
607
+ " self.model = model\n",
608
+ " self.optimizer = optimizer\n",
609
+ " self.global_step = self.optimizer.iterations\n",
610
+ " \n",
611
+ "\n",
612
+ " self.train_loss = tf.keras.metrics.Mean(\n",
613
+ " 'training_loss', dtype=tf.float32)\n",
614
+ " orbit.StandardTrainer.__init__(self, train_dataset)\n"
615
+ ]
616
+ },
617
+ {
618
+ "cell_type": "markdown",
619
+ "metadata": {
620
+ "id": "QOwHD7U5hVue"
621
+ },
622
+ "source": [
623
+ "Before starting a run of the training loop, the `train_loop_begin` method will reset the `train_loss` metric."
624
+ ]
625
+ },
626
+ {
627
+ "cell_type": "code",
628
+ "execution_count": null,
629
+ "metadata": {
630
+ "id": "AkpcHqXShWL0"
631
+ },
632
+ "outputs": [],
633
+ "source": [
634
+ "def train_loop_begin(self):\n",
635
+ " self.train_loss.reset_states()"
636
+ ]
637
+ },
638
+ {
639
+ "cell_type": "markdown",
640
+ "metadata": {
641
+ "id": "UjtFOFyxn2BB"
642
+ },
643
+ "source": [
644
+ "The `train_step` is a straight-forward loss-calculation and gradient update that is run by the distribution strategy. This is accomplished by defining the gradient step as a nested function (`step_fn`).\n",
645
+ "\n",
646
+ "The method receives `tf.distribute.DistributedIterator` to handle the [distributed input](https://www.tensorflow.org/tutorials/distribute/input). The method uses `Strategy.run` to execute `step_fn` and feeds it from the distributed iterator.\n",
647
+ "\n"
648
+ ]
649
+ },
650
+ {
651
+ "cell_type": "code",
652
+ "execution_count": null,
653
+ "metadata": {
654
+ "id": "QuPwNnT5I-GP"
655
+ },
656
+ "outputs": [],
657
+ "source": [
658
+ "def train_step(self, iterator):\n",
659
+ "\n",
660
+ " def step_fn(inputs):\n",
661
+ " labels = inputs.pop(\"label_ids\")\n",
662
+ " with tf.GradientTape() as tape:\n",
663
+ " model_outputs = self.model(inputs, training=True)\n",
664
+ " # Raw loss is used for reporting in metrics/logs.\n",
665
+ " raw_loss = loss_fn(labels, model_outputs)\n",
666
+ " # Scales down the loss for gradients to be invariant from replicas.\n",
667
+ " loss = raw_loss / self.strategy.num_replicas_in_sync\n",
668
+ "\n",
669
+ " grads = tape.gradient(loss, self.model.trainable_variables)\n",
670
+ " optimizer.apply_gradients(zip(grads, self.model.trainable_variables))\n",
671
+ " # For reporting, the metric takes the mean of losses.\n",
672
+ " self.train_loss.update_state(raw_loss)\n",
673
+ "\n",
674
+ " self.strategy.run(step_fn, args=(next(iterator),))"
675
+ ]
676
+ },
677
+ {
678
+ "cell_type": "markdown",
679
+ "metadata": {
680
+ "id": "VmQNwx5QpyDt"
681
+ },
682
+ "source": [
683
+ "The `orbit.StandardTrainer` handles the `@tf.function` and loops.\n",
684
+ "\n",
685
+ "After running through `num_steps` of training, `StandardTrainer` calls `train_loop_end`. The function returns the metric results:"
686
+ ]
687
+ },
688
+ {
689
+ "cell_type": "code",
690
+ "execution_count": null,
691
+ "metadata": {
692
+ "id": "GqCyVk1zzGod"
693
+ },
694
+ "outputs": [],
695
+ "source": [
696
+ "def train_loop_end(self):\n",
697
+ " return {\n",
698
+ " self.train_loss.name: self.train_loss.result(),\n",
699
+ " }"
700
+ ]
701
+ },
702
+ {
703
+ "cell_type": "markdown",
704
+ "metadata": {
705
+ "id": "xvmLONl80KUv"
706
+ },
707
+ "source": [
708
+ "Build a subclass of `orbit.StandardTrainer` with those methods."
709
+ ]
710
+ },
711
+ {
712
+ "cell_type": "code",
713
+ "execution_count": null,
714
+ "metadata": {
715
+ "id": "oRoL7VE6xt1G"
716
+ },
717
+ "outputs": [],
718
+ "source": [
719
+ "class BertClassifierTrainer(orbit.StandardTrainer):\n",
720
+ " __init__ = trainer_init\n",
721
+ " train_loop_begin = train_loop_begin\n",
722
+ " train_step = train_step\n",
723
+ " train_loop_end = train_loop_end"
724
+ ]
725
+ },
726
+ {
727
+ "cell_type": "markdown",
728
+ "metadata": {
729
+ "id": "yjG4QAWj1B00"
730
+ },
731
+ "source": [
732
+ "### Define the evaluator class\n",
733
+ "\n",
734
+ "Note: Like the previous section, this section defines each method as a stand-alone function and assembles them into a `BertClassifierEvaluator` class at the end.\n",
735
+ "\n",
736
+ "The evaluator is even simpler for this task. It needs access to the evaluation dataset, the model, and the strategy. After saving references to those objects, the constructor just needs to create the metrics."
737
+ ]
738
+ },
739
+ {
740
+ "cell_type": "code",
741
+ "execution_count": null,
742
+ "metadata": {
743
+ "id": "cvX7seCY1CWj"
744
+ },
745
+ "outputs": [],
746
+ "source": [
747
+ "def evaluator_init(self,\n",
748
+ " eval_dataset,\n",
749
+ " model,\n",
750
+ " strategy):\n",
751
+ " self.strategy = strategy\n",
752
+ " with self.strategy.scope():\n",
753
+ " self.model = model\n",
754
+ " \n",
755
+ " self.eval_loss = tf.keras.metrics.Mean(\n",
756
+ " 'evaluation_loss', dtype=tf.float32)\n",
757
+ " self.eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
758
+ " name='accuracy', dtype=tf.float32)\n",
759
+ " orbit.StandardEvaluator.__init__(self, eval_dataset)"
760
+ ]
761
+ },
762
+ {
763
+ "cell_type": "markdown",
764
+ "metadata": {
765
+ "id": "0r-z-XK7ybyX"
766
+ },
767
+ "source": [
768
+ "Similar to the trainer, the `eval_begin` and `eval_end` methods just need to reset the metrics before the loop and then report the results after the loop."
769
+ ]
770
+ },
771
+ {
772
+ "cell_type": "code",
773
+ "execution_count": null,
774
+ "metadata": {
775
+ "id": "7VVb0Tg6yZjI"
776
+ },
777
+ "outputs": [],
778
+ "source": [
779
+ "def eval_begin(self):\n",
780
+ " self.eval_accuracy.reset_states()\n",
781
+ " self.eval_loss.reset_states()\n",
782
+ "\n",
783
+ "def eval_end(self):\n",
784
+ " return {\n",
785
+ " self.eval_accuracy.name: self.eval_accuracy.result(),\n",
786
+ " self.eval_loss.name: self.eval_loss.result(),\n",
787
+ " }"
788
+ ]
789
+ },
790
+ {
791
+ "cell_type": "markdown",
792
+ "metadata": {
793
+ "id": "iDOZcQvttdmZ"
794
+ },
795
+ "source": [
796
+ "The `eval_step` method works like `train_step`. The inner `step_fn` defines the actual work of calculating the loss \u0026 accuracy and updating the metrics. The outer `eval_step` receives `tf.distribute.DistributedIterator` as input, and uses `Strategy.run` to launch the distributed execution to `step_fn`, feeding it from the distributed iterator."
797
+ ]
798
+ },
799
+ {
800
+ "cell_type": "code",
801
+ "execution_count": null,
802
+ "metadata": {
803
+ "id": "JLJnYuuGJjvd"
804
+ },
805
+ "outputs": [],
806
+ "source": [
807
+ "def eval_step(self, iterator):\n",
808
+ "\n",
809
+ " def step_fn(inputs):\n",
810
+ " labels = inputs.pop(\"label_ids\")\n",
811
+ " model_outputs = self.model(inputs, training=True)\n",
812
+ " loss = loss_fn(labels, model_outputs)\n",
813
+ " self.eval_loss.update_state(loss)\n",
814
+ " self.eval_accuracy.update_state(labels, model_outputs)\n",
815
+ "\n",
816
+ " self.strategy.run(step_fn, args=(next(iterator),))"
817
+ ]
818
+ },
819
+ {
820
+ "cell_type": "markdown",
821
+ "metadata": {
822
+ "id": "Gt3hh0V30QcP"
823
+ },
824
+ "source": [
825
+ "Build a subclass of `orbit.StandardEvaluator` with those methods."
826
+ ]
827
+ },
828
+ {
829
+ "cell_type": "code",
830
+ "execution_count": null,
831
+ "metadata": {
832
+ "id": "3zqyLxfNyCgA"
833
+ },
834
+ "outputs": [],
835
+ "source": [
836
+ "class BertClassifierEvaluator(orbit.StandardEvaluator):\n",
837
+ " __init__ = evaluator_init\n",
838
+ " eval_begin = eval_begin\n",
839
+ " eval_end = eval_end\n",
840
+ " eval_step = eval_step"
841
+ ]
842
+ },
843
+ {
844
+ "cell_type": "markdown",
845
+ "metadata": {
846
+ "id": "aK9gEja9qPOc"
847
+ },
848
+ "source": [
849
+ "### End-to-end training and evaluation\n",
850
+ "\n",
851
+ "To run the training and evaluation, simply create the trainer, evaluator, and `orbit.Controller` instances. Then call the `Controller.train_and_evaluate` method."
852
+ ]
853
+ },
854
+ {
855
+ "cell_type": "code",
856
+ "execution_count": null,
857
+ "metadata": {
858
+ "id": "PqQetxyXqRA9"
859
+ },
860
+ "outputs": [],
861
+ "source": [
862
+ "trainer = BertClassifierTrainer(\n",
863
+ " train_dataset, classifier_model, optimizer, strategy)\n",
864
+ "\n",
865
+ "evaluator = BertClassifierEvaluator(\n",
866
+ " eval_dataset, classifier_model, strategy)\n",
867
+ "\n",
868
+ "controller = orbit.Controller(\n",
869
+ " trainer=trainer,\n",
870
+ " evaluator=evaluator,\n",
871
+ " global_step=trainer.global_step,\n",
872
+ " steps_per_loop=20,\n",
873
+ " checkpoint_manager=checkpoint_manager)\n",
874
+ "\n",
875
+ "result = controller.train_and_evaluate(\n",
876
+ " train_steps=steps_per_epoch * num_train_epochs,\n",
877
+ " eval_steps=-1,\n",
878
+ " eval_interval=steps_per_epoch)"
879
+ ]
880
+ }
881
+ ],
882
+ "metadata": {
883
+ "colab": {
884
+ "collapsed_sections": [
885
+ "Tce3stUlHN0L"
886
+ ],
887
+ "name": "Orbit Tutorial.ipynb",
888
+ "provenance": [],
889
+ "toc_visible": true
890
+ },
891
+ "kernelspec": {
892
+ "display_name": "Python 3",
893
+ "name": "python3"
894
+ }
895
+ },
896
+ "nbformat": 4,
897
+ "nbformat_minor": 0
898
+ }
models/docs/vision/_toc.yaml ADDED
@@ -0,0 +1,9 @@
1
+ toc:
2
+ - title: "Example: Image classification"
3
+ path: /tfmodels/vision/image_classification
4
+ - title: "Example: Object Detection"
5
+ path: /tfmodels/vision/object_detection
6
+ - title: "Example: Semantic Segmentation"
7
+ path: /tfmodels/vision/semantic_segmentation
8
+ - title: "Example: Instance Segmentation"
9
+ path: /tfmodels/vision/instance_segmentation
models/docs/vision/image_classification.ipynb ADDED
@@ -0,0 +1,692 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "Tce3stUlHN0L"
7
+ },
8
+ "source": [
9
+ "##### Copyright 2020 The TensorFlow Authors."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "cellView": "form",
17
+ "id": "tuOe1ymfHZPu"
18
+ },
19
+ "outputs": [],
20
+ "source": [
21
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
22
+ "# you may not use this file except in compliance with the License.\n",
23
+ "# You may obtain a copy of the License at\n",
24
+ "#\n",
25
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
26
+ "#\n",
27
+ "# Unless required by applicable law or agreed to in writing, software\n",
28
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
29
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
30
+ "# See the License for the specific language governing permissions and\n",
31
+ "# limitations under the License."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {
37
+ "id": "qFdPvlXBOdUN"
38
+ },
39
+ "source": [
40
+ "# Image classification with Model Garden"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "metadata": {
46
+ "id": "MfBg1C5NB3X0"
47
+ },
48
+ "source": [
49
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
50
+ " \u003ctd\u003e\n",
51
+ " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/image_classification\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
52
+ " \u003c/td\u003e\n",
53
+ " \u003ctd\u003e\n",
54
+ " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
55
+ " \u003c/td\u003e\n",
56
+ " \u003ctd\u003e\n",
57
+ " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
58
+ " \u003c/td\u003e\n",
59
+ " \u003ctd\u003e\n",
60
+ " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
61
+ " \u003c/td\u003e\n",
62
+ "\u003c/table\u003e"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "markdown",
67
+ "metadata": {
68
+ "id": "Ta_nFXaVAqLD"
69
+ },
70
+ "source": [
71
+ "This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow [Model Garden](https://github.com/tensorflow/models) package (`tensorflow-models`) to classify images in the [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.\n",
72
+ "\n",
73
+ "Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
74
+ "\n",
75
+ "This tutorial uses a [ResNet](https://arxiv.org/pdf/1512.03385.pdf) model, a state-of-the-art image classifier. This tutorial uses the ResNet-18 model, a convolutional neural network with 18 layers.\n",
76
+ "\n",
77
+ "This tutorial demonstrates how to:\n",
78
+ "1. Use models from the TensorFlow Models package.\n",
79
+ "2. Fine-tune a pre-built ResNet for image classification.\n",
80
+ "3. Export the tuned ResNet model."
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "markdown",
85
+ "metadata": {
86
+ "id": "G2FlaQcEPOER"
87
+ },
88
+ "source": [
89
+ "## Setup\n",
90
+ "\n",
91
+ "Install and import the necessary modules."
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "metadata": {
98
+ "id": "XvWfdCrvrV5W"
99
+ },
100
+ "outputs": [],
101
+ "source": [
102
+ "!pip install -U -q \"tf-models-official\""
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "markdown",
107
+ "metadata": {
108
+ "id": "CKYMTPjOE400"
109
+ },
110
+ "source": [
111
+ "Import TensorFlow, TensorFlow Datasets, and a few helper libraries."
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {
118
+ "id": "Wlon1uoIowmZ"
119
+ },
120
+ "outputs": [],
121
+ "source": [
122
+ "import pprint\n",
123
+ "import tempfile\n",
124
+ "\n",
125
+ "from IPython import display\n",
126
+ "import matplotlib.pyplot as plt\n",
127
+ "\n",
128
+ "import tensorflow as tf\n",
129
+ "import tensorflow_datasets as tfds"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "markdown",
134
+ "metadata": {
135
+ "id": "AVTs0jDd1b24"
136
+ },
137
+ "source": [
138
+ "The `tensorflow_models` package contains the ResNet vision model, and the `official.vision.serving` model contains the function to save and export the tuned model."
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "metadata": {
145
+ "id": "NHT1iiIiBzlC"
146
+ },
147
+ "outputs": [],
148
+ "source": [
149
+ "import tensorflow_models as tfm\n",
150
+ "\n",
151
+ "# These are not in the tfm public API for v2.9. They will be available in v2.10\n",
152
+ "from official.vision.serving import export_saved_model_lib\n",
153
+ "import official.core.train_lib"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "markdown",
158
+ "metadata": {
159
+ "id": "aKv3wdqkQ8FU"
160
+ },
161
+ "source": [
162
+ "## Configure the ResNet-18 model for the Cifar-10 dataset"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "markdown",
167
+ "metadata": {
168
+ "id": "5iN8mHEJjKYE"
169
+ },
170
+ "source": [
171
+ "The CIFAR10 dataset contains 60,000 color images in mutually exclusive 10 classes, with 6,000 images in each class.\n",
172
+ "\n",
173
+ "In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
174
+ "\n",
175
+ "Use the `resnet_imagenet` factory configuration, as defined by `tfm.vision.configs.image_classification.image_classification_imagenet`. The configuration is set up to train ResNet to converge on [ImageNet](https://www.image-net.org/)."
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": null,
181
+ "metadata": {
182
+ "id": "1M77f88Dj2Td"
183
+ },
184
+ "outputs": [],
185
+ "source": [
186
+ "exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n",
187
+ "tfds_name = 'cifar10'\n",
188
+ "ds,ds_info = tfds.load(\n",
189
+ "tfds_name,\n",
190
+ "with_info=True)\n",
191
+ "ds_info"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "markdown",
196
+ "metadata": {
197
+ "id": "U6PVwXA-j3E7"
198
+ },
199
+ "source": [
200
+ "Adjust the model and dataset configurations so that it works with Cifar-10 (`cifar10`)."
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "metadata": {
207
+ "id": "YWI7faVStQaV"
208
+ },
209
+ "outputs": [],
210
+ "source": [
211
+ "# Configure model\n",
212
+ "exp_config.task.model.num_classes = 10\n",
213
+ "exp_config.task.model.input_size = list(ds_info.features[\"image\"].shape)\n",
214
+ "exp_config.task.model.backbone.resnet.model_id = 18\n",
215
+ "\n",
216
+ "# Configure training and testing data\n",
217
+ "batch_size = 128\n",
218
+ "\n",
219
+ "exp_config.task.train_data.input_path = ''\n",
220
+ "exp_config.task.train_data.tfds_name = tfds_name\n",
221
+ "exp_config.task.train_data.tfds_split = 'train'\n",
222
+ "exp_config.task.train_data.global_batch_size = batch_size\n",
223
+ "\n",
224
+ "exp_config.task.validation_data.input_path = ''\n",
225
+ "exp_config.task.validation_data.tfds_name = tfds_name\n",
226
+ "exp_config.task.validation_data.tfds_split = 'test'\n",
227
+ "exp_config.task.validation_data.global_batch_size = batch_size\n"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "markdown",
232
+ "metadata": {
233
+ "id": "DE3ggKzzTD56"
234
+ },
235
+ "source": [
236
+ "Adjust the trainer configuration."
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": null,
242
+ "metadata": {
243
+ "id": "inE_-4UGkLud"
244
+ },
245
+ "outputs": [],
246
+ "source": [
247
+ "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
248
+ "\n",
249
+ "if 'GPU' in ''.join(logical_device_names):\n",
250
+ " print('This may be broken in Colab.')\n",
251
+ " device = 'GPU'\n",
252
+ "elif 'TPU' in ''.join(logical_device_names):\n",
253
+ " print('This may be broken in Colab.')\n",
254
+ " device = 'TPU'\n",
255
+ "else:\n",
256
+ " print('Running on CPU is slow, so only train for a few steps.')\n",
257
+ " device = 'CPU'\n",
258
+ "\n",
259
+ "if device=='CPU':\n",
260
+ " train_steps = 20\n",
261
+ " exp_config.trainer.steps_per_loop = 5\n",
262
+ "else:\n",
263
+ " train_steps=5000\n",
264
+ " exp_config.trainer.steps_per_loop = 100\n",
265
+ "\n",
266
+ "exp_config.trainer.summary_interval = 100\n",
267
+ "exp_config.trainer.checkpoint_interval = train_steps\n",
268
+ "exp_config.trainer.validation_interval = 1000\n",
269
+ "exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size\n",
270
+ "exp_config.trainer.train_steps = train_steps\n",
271
+ "exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
272
+ "exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
273
+ "exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
274
+ "exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "markdown",
279
+ "metadata": {
280
+ "id": "5mTcDnBiTOYD"
281
+ },
282
+ "source": [
283
+ "Print the modified configuration."
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": null,
289
+ "metadata": {
290
+ "id": "tuVfxSBCTK-y"
291
+ },
292
+ "outputs": [],
293
+ "source": [
294
+ "pprint.pprint(exp_config.as_dict())\n",
295
+ "\n",
296
+ "display.Javascript(\"google.colab.output.setIframeHeight('300px');\")"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "markdown",
301
+ "metadata": {
302
+ "id": "w7_X0UHaRF2m"
303
+ },
304
+ "source": [
305
+ "Set up the distribution strategy."
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "code",
310
+ "execution_count": null,
311
+ "metadata": {
312
+ "id": "ykL14FIbTaSt"
313
+ },
314
+ "outputs": [],
315
+ "source": [
316
+ "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
317
+ "\n",
318
+ "if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
319
+ " tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
320
+ "\n",
321
+ "if 'GPU' in ''.join(logical_device_names):\n",
322
+ " distribution_strategy = tf.distribute.MirroredStrategy()\n",
323
+ "elif 'TPU' in ''.join(logical_device_names):\n",
324
+ " tf.tpu.experimental.initialize_tpu_system()\n",
325
+ " tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
326
+ " distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
327
+ "else:\n",
328
+ " print('Warning: this will be really slow.')\n",
329
+ " distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "markdown",
334
+ "metadata": {
335
+ "id": "W4k5YH5pTjaK"
336
+ },
337
+ "source": [
338
+ "Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
339
+ "\n",
340
+ "The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "execution_count": null,
346
+ "metadata": {
347
+ "id": "6MgYSH0PtUaW"
348
+ },
349
+ "outputs": [],
350
+ "source": [
351
+ "with distribution_strategy.scope():\n",
352
+ " model_dir = tempfile.mkdtemp()\n",
353
+ " task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)\n",
354
+ "\n",
355
+ "# tf.keras.utils.plot_model(task.build_model(), show_shapes=True)"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "metadata": {
362
+ "id": "IFXEZYdzBKoX"
363
+ },
364
+ "outputs": [],
365
+ "source": [
366
+ "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
367
+ " print()\n",
368
+ " print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
369
+ " print(f'labels.shape: {str(labels.shape):16} labels.dtype: {labels.dtype!r}')"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "markdown",
374
+ "metadata": {
375
+ "id": "yrwxnGDaRU0U"
376
+ },
377
+ "source": [
378
+ "## Visualize the training data"
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "markdown",
383
+ "metadata": {
384
+ "id": "683c255c6c52"
385
+ },
386
+ "source": [
387
+ "The dataloader applies a z-score normalization using \n",
388
+ "`preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range."
389
+ ]
390
+ },
391
+ {
392
+ "cell_type": "code",
393
+ "execution_count": null,
394
+ "metadata": {
395
+ "id": "PdmOz2EC0Nx2"
396
+ },
397
+ "outputs": [],
398
+ "source": [
399
+ "plt.hist(images.numpy().flatten());"
400
+ ]
401
+ },
402
+ {
403
+ "cell_type": "markdown",
404
+ "metadata": {
405
+ "id": "7a8582ebde7b"
406
+ },
407
+ "source": [
408
+ "Use `ds_info` (which is an instance of `tfds.core.DatasetInfo`) to lookup the text descriptions of each class ID."
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "code",
413
+ "execution_count": null,
414
+ "metadata": {
415
+ "id": "Wq4Wq_CuDG3Q"
416
+ },
417
+ "outputs": [],
418
+ "source": [
419
+ "label_info = ds_info.features['label']\n",
420
+ "label_info.int2str(1)"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "markdown",
425
+ "metadata": {
426
+ "id": "8c652a6fdbcf"
427
+ },
428
+ "source": [
429
+ "Visualize a batch of the data."
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": null,
435
+ "metadata": {
436
+ "id": "ZKfTxytf1l0d"
437
+ },
438
+ "outputs": [],
439
+ "source": [
440
+ "def show_batch(images, labels, predictions=None):\n",
441
+ " plt.figure(figsize=(10, 10))\n",
442
+ " min = images.numpy().min()\n",
443
+ " max = images.numpy().max()\n",
444
+ " delta = max - min\n",
445
+ "\n",
446
+ " for i in range(12):\n",
447
+ " plt.subplot(6, 6, i + 1)\n",
448
+ " plt.imshow((images[i]-min) / delta)\n",
449
+ " if predictions is None:\n",
450
+ " plt.title(label_info.int2str(labels[i]))\n",
451
+ " else:\n",
452
+ " if labels[i] == predictions[i]:\n",
453
+ " color = 'g'\n",
454
+ " else:\n",
455
+ " color = 'r'\n",
456
+ " plt.title(label_info.int2str(predictions[i]), color=color)\n",
457
+ " plt.axis(\"off\")"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "code",
462
+ "execution_count": null,
463
+ "metadata": {
464
+ "id": "xkA5h_RBtYYU"
465
+ },
466
+ "outputs": [],
467
+ "source": [
468
+ "plt.figure(figsize=(10, 10))\n",
469
+ "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
470
+ " show_batch(images, labels)"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "markdown",
475
+ "metadata": {
476
+ "id": "v_A9VnL2RbXP"
477
+ },
478
+ "source": [
479
+ "## Visualize the testing data"
480
+ ]
481
+ },
482
+ {
483
+ "cell_type": "markdown",
484
+ "metadata": {
485
+ "id": "AXovuumW_I2z"
486
+ },
487
+ "source": [
488
+ "Visualize a batch of images from the validation dataset."
489
+ ]
490
+ },
491
+ {
492
+ "cell_type": "code",
493
+ "execution_count": null,
494
+ "metadata": {
495
+ "id": "Ma-_Eb-nte9A"
496
+ },
497
+ "outputs": [],
498
+ "source": [
499
+ "plt.figure(figsize=(10, 10));\n",
500
+ "for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):\n",
501
+ " show_batch(images, labels)"
502
+ ]
503
+ },
504
+ {
505
+ "cell_type": "markdown",
506
+ "metadata": {
507
+ "id": "ihKJt2FHRi2N"
508
+ },
509
+ "source": [
510
+ "## Train and evaluate"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": null,
516
+ "metadata": {
517
+ "id": "0AFMNvYxtjXx"
518
+ },
519
+ "outputs": [],
520
+ "source": [
521
+ "model, eval_logs = tfm.core.train_lib.run_experiment(\n",
522
+ " distribution_strategy=distribution_strategy,\n",
523
+ " task=task,\n",
524
+ " mode='train_and_eval',\n",
525
+ " params=exp_config,\n",
526
+ " model_dir=model_dir,\n",
527
+ " run_post_eval=True)"
528
+ ]
529
+ },
530
+ {
531
+ "cell_type": "code",
532
+ "execution_count": null,
533
+ "metadata": {
534
+ "id": "gCcHMQYhozmA"
535
+ },
536
+ "outputs": [],
537
+ "source": [
538
+ "# tf.keras.utils.plot_model(model, show_shapes=True)"
539
+ ]
540
+ },
541
+ {
542
+ "cell_type": "markdown",
543
+ "metadata": {
544
+ "id": "L7nVfxlBA8Gb"
545
+ },
546
+ "source": [
547
+ "Print the `accuracy`, `top_5_accuracy`, and `validation_loss` evaluation metrics."
548
+ ]
549
+ },
550
+ {
551
+ "cell_type": "code",
552
+ "execution_count": null,
553
+ "metadata": {
554
+ "id": "0124f938a1b9"
555
+ },
556
+ "outputs": [],
557
+ "source": [
558
+ "for key, value in eval_logs.items():\n",
559
+ " if isinstance(value, tf.Tensor):\n",
560
+ " value = value.numpy()\n",
561
+ " print(f'{key:20}: {value:.3f}')"
562
+ ]
563
+ },
564
+ {
565
+ "cell_type": "markdown",
566
+ "metadata": {
567
+ "id": "TDys5bZ1zsml"
568
+ },
569
+ "source": [
570
+ "Run a batch of the processed training data through the model, and view the results"
571
+ ]
572
+ },
573
+ {
574
+ "cell_type": "code",
575
+ "execution_count": null,
576
+ "metadata": {
577
+ "id": "GhI7zR-Uz1JT"
578
+ },
579
+ "outputs": [],
580
+ "source": [
581
+ "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
582
+ " predictions = model.predict(images)\n",
583
+ " predictions = tf.argmax(predictions, axis=-1)\n",
584
+ "\n",
585
+ "show_batch(images, labels, tf.cast(predictions, tf.int32))\n",
586
+ "\n",
587
+ "if device=='CPU':\n",
588
+ " plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')"
589
+ ]
590
+ },
591
+ {
592
+ "cell_type": "markdown",
593
+ "metadata": {
594
+ "id": "fkE9locGTBgt"
595
+ },
596
+ "source": [
597
+ "## Export a SavedModel"
598
+ ]
599
+ },
600
+ {
601
+ "cell_type": "markdown",
602
+ "metadata": {
603
+ "id": "9669d08c91af"
604
+ },
605
+ "source": [
606
+ "The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results.\n"
607
+ ]
608
+ },
609
+ {
610
+ "cell_type": "code",
611
+ "execution_count": null,
612
+ "metadata": {
613
+ "id": "AQCFa7BvtmDg"
614
+ },
615
+ "outputs": [],
616
+ "source": [
617
+ "# Saving and exporting the trained model\n",
618
+ "export_saved_model_lib.export_inference_graph(\n",
619
+ " input_type='image_tensor',\n",
620
+ " batch_size=1,\n",
621
+ " input_image_size=[32, 32],\n",
622
+ " params=exp_config,\n",
623
+ " checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
624
+ " export_dir='./export/')"
625
+ ]
626
+ },
627
+ {
628
+ "cell_type": "markdown",
629
+ "metadata": {
630
+ "id": "vVr6DxNqTyLZ"
631
+ },
632
+ "source": [
633
+ "Test the exported model."
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "code",
638
+ "execution_count": null,
639
+ "metadata": {
640
+ "id": "gP7nOvrftsB0"
641
+ },
642
+ "outputs": [],
643
+ "source": [
644
+ "# Importing SavedModel\n",
645
+ "imported = tf.saved_model.load('./export/')\n",
646
+ "model_fn = imported.signatures['serving_default']"
647
+ ]
648
+ },
649
+ {
650
+ "cell_type": "markdown",
651
+ "metadata": {
652
+ "id": "GiOp2WVIUNUZ"
653
+ },
654
+ "source": [
655
+ "Visualize the predictions."
656
+ ]
657
+ },
658
+ {
659
+ "cell_type": "code",
660
+ "execution_count": null,
661
+ "metadata": {
662
+ "id": "BTRMrZQAN4mk"
663
+ },
664
+ "outputs": [],
665
+ "source": [
666
+ "plt.figure(figsize=(10, 10))\n",
667
+ "for data in tfds.load('cifar10', split='test').batch(12).take(1):\n",
668
+ " predictions = []\n",
669
+ " for image in data['image']:\n",
670
+ " index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]\n",
671
+ " predictions.append(index)\n",
672
+ " show_batch(data['image'], data['label'], predictions)\n",
673
+ "\n",
674
+ " if device=='CPU':\n",
675
+ " plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')"
676
+ ]
677
+ }
678
+ ],
679
+ "metadata": {
680
+ "colab": {
681
+ "name": "classification_with_model_garden.ipynb",
682
+ "provenance": [],
683
+ "toc_visible": true
684
+ },
685
+ "kernelspec": {
686
+ "display_name": "Python 3",
687
+ "name": "python3"
688
+ }
689
+ },
690
+ "nbformat": 4,
691
+ "nbformat_minor": 0
692
+ }
models/docs/vision/instance_segmentation.ipynb ADDED
@@ -0,0 +1,1138 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "eCes7jVU8r08"
7
+ },
8
+ "source": [
9
+ "##### Copyright 2023 The TensorFlow Authors."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "id": "pc1j3ZVF8mmG"
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
21
+ "# you may not use this file except in compliance with the License.\n",
22
+ "# You may obtain a copy of the License at\n",
23
+ "#\n",
24
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
25
+ "#\n",
26
+ "# Unless required by applicable law or agreed to in writing, software\n",
27
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
28
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
29
+ "# See the License for the specific language governing permissions and\n",
30
+ "# limitations under the License."
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "markdown",
35
+ "metadata": {
36
+ "id": "SUUX9CnCYI9Y"
37
+ },
38
+ "source": [
39
+ "# Instance Segmentation with Model Garden\n",
40
+ "\n",
41
+ "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
42
+ " <td>\n",
43
+ " <a target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/instance_segmentation\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
44
+ " </td>\n",
45
+ " <td>\n",
46
+ " <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/instance_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
47
+ " </td>\n",
48
+ " <td>\n",
49
+ " <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/instance_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
50
+ " </td>\n",
51
+ " <td>\n",
52
+ " <a href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/instance_segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
53
+ " </td>\n",
54
+ "</table>"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "markdown",
59
+ "metadata": {
60
+ "id": "UjP7bQUdTeFr"
61
+ },
62
+ "source": [
63
+ "This tutorial fine-tunes a [Mask R-CNN](https://arxiv.org/abs/1703.06870) with [Mobilenet V2](https://arxiv.org/abs/1801.04381) as backbone model from the [TensorFlow Model Garden](https://pypi.org/project/tf-models-official/) package (tensorflow-models).\n",
64
+ "\n",
65
+ "\n",
66
+ "[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
67
+ "\n",
68
+ "This tutorial demonstrates how to:\n",
69
+ "\n",
70
+ "1. Use models from the TensorFlow Models package.\n",
71
+ "2. Train/Fine-tune a pre-built Mask R-CNN with mobilenet as backbone for Object Detection and Instance Segmentation\n",
72
+ "3. Export the trained/tuned Mask R-CNN model"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "metadata": {
78
+ "id": "RDp6Kk1Baoi4"
79
+ },
80
+ "source": [
81
+ "## Install Necessary Dependencies"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": null,
87
+ "metadata": {
88
+ "id": "hcl98qUOxlL8"
89
+ },
90
+ "outputs": [],
91
+ "source": [
92
+ "!pip install -U -q \"tf-models-official\"\n",
93
+ "!pip install -U -q remotezip tqdm opencv-python einops"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "markdown",
98
+ "metadata": {
99
+ "id": "5-gCe_YTapey"
100
+ },
101
+ "source": [
102
+ "## Import required libraries"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "metadata": {
109
+ "id": "Qa9552Ukgf3d"
110
+ },
111
+ "outputs": [],
112
+ "source": [
113
+ "import os\n",
114
+ "import io\n",
115
+ "import json\n",
116
+ "import tqdm\n",
117
+ "import shutil\n",
118
+ "import pprint\n",
119
+ "import pathlib\n",
120
+ "import tempfile\n",
121
+ "import requests\n",
122
+ "import collections\n",
123
+ "import matplotlib\n",
124
+ "import numpy as np\n",
125
+ "import tensorflow as tf\n",
126
+ "import matplotlib.pyplot as plt\n",
127
+ "\n",
128
+ "from PIL import Image\n",
129
+ "from six import BytesIO\n",
130
+ "from etils import epath\n",
131
+ "from IPython import display\n",
132
+ "from urllib.request import urlopen"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": null,
138
+ "metadata": {
139
+ "id": "tSCMIDRDP2fV"
140
+ },
141
+ "outputs": [],
142
+ "source": [
143
+ "import orbit\n",
144
+ "import tensorflow as tf\n",
145
+ "import tensorflow_models as tfm\n",
146
+ "import tensorflow_datasets as tfds\n",
147
+ "\n",
148
+ "from official.core import exp_factory\n",
149
+ "from official.core import config_definitions as cfg\n",
150
+ "from official.vision.data import tfrecord_lib\n",
151
+ "from official.vision.serving import export_saved_model_lib\n",
152
+ "from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder\n",
153
+ "from official.vision.utils.object_detection import visualization_utils\n",
154
+ "from official.vision.ops.preprocess_ops import normalize_image, resize_and_crop_image\n",
155
+ "from official.vision.data.create_coco_tf_record import coco_annotations_to_lists\n",
156
+ "\n",
157
+ "pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation\n",
158
+ "print(tf.__version__) # Check the version of tensorflow used\n",
159
+ "\n",
160
+ "%matplotlib inline"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "metadata": {
166
+ "id": "GIrXW8sp2bKa"
167
+ },
168
+ "source": [
169
+ "## Download subset of lvis dataset\n",
170
+ "\n",
171
+ "[LVIS](https://www.tensorflow.org/datasets/catalog/lvis): A dataset for large vocabulary instance segmentation.\n",
172
+ "\n",
173
+ "Note: LVIS uses the COCO 2017 train, validation, and test image sets. \n",
174
+ "If you have already downloaded the COCO images, you only need to download \n",
175
+ "the LVIS annotations. LVIS val set contains images from COCO 2017 train in \n",
176
+ "addition to the COCO 2017 val split."
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {
183
+ "cellView": "form",
184
+ "id": "F_A9_cS310jf"
185
+ },
186
+ "outputs": [],
187
+ "source": [
188
+ "# @title Download annotation files\n",
189
+ "\n",
190
+ "!wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip\n",
191
+ "!unzip -q lvis_v1_train.json.zip\n",
192
+ "!rm lvis_v1_train.json.zip\n",
193
+ "\n",
194
+ "!wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip\n",
195
+ "!unzip -q lvis_v1_val.json.zip\n",
196
+ "!rm lvis_v1_val.json.zip\n",
197
+ "\n",
198
+ "!wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_image_info_test_dev.json.zip\n",
199
+ "!unzip -q lvis_v1_image_info_test_dev.json.zip\n",
200
+ "!rm lvis_v1_image_info_test_dev.json.zip"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "metadata": {
207
+ "cellView": "form",
208
+ "id": "kB-C5Svj11S0"
209
+ },
210
+ "outputs": [],
211
+ "source": [
212
+ "# @title Lvis annotation parsing\n",
213
+ "\n",
214
+ "# Annotations with invalid bounding boxes. Will not be used.\n",
215
+ "_INVALID_ANNOTATIONS = [\n",
216
+ " # Train split.\n",
217
+ " 662101,\n",
218
+ " 81217,\n",
219
+ " 462924,\n",
220
+ " 227817,\n",
221
+ " 29381,\n",
222
+ " 601484,\n",
223
+ " 412185,\n",
224
+ " 504667,\n",
225
+ " 572573,\n",
226
+ " 91937,\n",
227
+ " 239022,\n",
228
+ " 181534,\n",
229
+ " 101685,\n",
230
+ " # Validation split.\n",
231
+ " 36668,\n",
232
+ " 57541,\n",
233
+ " 33126,\n",
234
+ " 10932,\n",
235
+ "]\n",
236
+ "\n",
237
+ "def get_category_map(annotation_path, num_classes):\n",
238
+ " with epath.Path(annotation_path).open() as f:\n",
239
+ " data = json.load(f)\n",
240
+ "\n",
241
+ " category_map = {id+1: {'id': cat_dict['id'],\n",
242
+ " 'name': cat_dict['name']}\n",
243
+ " for id, cat_dict in enumerate(data['categories'][:num_classes])}\n",
244
+ " return category_map\n",
245
+ "\n",
246
+ "class LvisAnnotation:\n",
247
+ " \"\"\"LVIS annotation helper class.\n",
248
+ " The format of the annations is explained on\n",
249
+ " https://www.lvisdataset.org/dataset.\n",
250
+ " \"\"\"\n",
251
+ "\n",
252
+ " def __init__(self, annotation_path):\n",
253
+ " with epath.Path(annotation_path).open() as f:\n",
254
+ " data = json.load(f)\n",
255
+ " self._data = data\n",
256
+ "\n",
257
+ " img_id2annotations = collections.defaultdict(list)\n",
258
+ " for a in self._data.get('annotations', []):\n",
259
+ " if a['category_id'] in category_ids:\n",
260
+ " img_id2annotations[a['image_id']].append(a)\n",
261
+ " self._img_id2annotations = {\n",
262
+ " k: list(sorted(v, key=lambda a: a['id']))\n",
263
+ " for k, v in img_id2annotations.items()\n",
264
+ " }\n",
265
+ "\n",
266
+ " @property\n",
267
+ " def categories(self):\n",
268
+ " \"\"\"Return the category dicts, as sorted in the file.\"\"\"\n",
269
+ " return self._data['categories']\n",
270
+ "\n",
271
+ " @property\n",
272
+ " def images(self):\n",
273
+ " \"\"\"Return the image dicts, as sorted in the file.\"\"\"\n",
274
+ " sub_images = []\n",
275
+ " for image_info in self._data['images']:\n",
276
+ " if image_info['id'] in self._img_id2annotations:\n",
277
+ " sub_images.append(image_info)\n",
278
+ " return sub_images\n",
279
+ "\n",
280
+ " def get_annotations(self, img_id):\n",
281
+ " \"\"\"Return all annotations associated with the image id string.\"\"\"\n",
282
+ " # Some images don't have any annotations. Return empty list instead.\n",
283
+ " return self._img_id2annotations.get(img_id, [])\n",
284
+ "\n",
285
+ "def _generate_tf_records(prefix, images_zip, annotation_file, num_shards=5):\n",
286
+ " \"\"\"Generate TFRecords.\"\"\"\n",
287
+ "\n",
288
+ " lvis_annotation = LvisAnnotation(annotation_file)\n",
289
+ "\n",
290
+ " def _process_example(prefix, image_info, id_to_name_map):\n",
291
+ " # Search image dirs.\n",
292
+ " filename = pathlib.Path(image_info['coco_url']).name\n",
293
+ " image = tf.io.read_file(os.path.join(IMGS_DIR, filename))\n",
294
+ " instances = lvis_annotation.get_annotations(img_id=image_info['id'])\n",
295
+ " instances = [x for x in instances if x['id'] not in _INVALID_ANNOTATIONS]\n",
296
+ " # print([x['category_id'] for x in instances])\n",
297
+ " is_crowd = {'iscrowd': 0}\n",
298
+ " instances = [dict(x, **is_crowd) for x in instances]\n",
299
+ " neg_category_ids = image_info.get('neg_category_ids', [])\n",
300
+ " not_exhaustive_category_ids = image_info.get(\n",
301
+ " 'not_exhaustive_category_ids', []\n",
302
+ " )\n",
303
+ " data, _ = coco_annotations_to_lists(instances,\n",
304
+ " id_to_name_map,\n",
305
+ " image_info['height'],\n",
306
+ " image_info['width'],\n",
307
+ " include_masks=True)\n",
308
+ " # data['category_id'] = [id-1 for id in data['category_id']]\n",
309
+ " keys_to_features = {\n",
310
+ " 'image/encoded':\n",
311
+ " tfrecord_lib.convert_to_feature(image.numpy()),\n",
312
+ " 'image/filename':\n",
313
+ " tfrecord_lib.convert_to_feature(filename.encode('utf8')),\n",
314
+ " 'image/format':\n",
315
+ " tfrecord_lib.convert_to_feature('jpg'.encode('utf8')),\n",
316
+ " 'image/height':\n",
317
+ " tfrecord_lib.convert_to_feature(image_info['height']),\n",
318
+ " 'image/width':\n",
319
+ " tfrecord_lib.convert_to_feature(image_info['width']),\n",
320
+ " 'image/source_id':\n",
321
+ " tfrecord_lib.convert_to_feature(str(image_info['id']).encode('utf8')),\n",
322
+ " 'image/object/bbox/xmin':\n",
323
+ " tfrecord_lib.convert_to_feature(data['xmin']),\n",
324
+ " 'image/object/bbox/xmax':\n",
325
+ " tfrecord_lib.convert_to_feature(data['xmax']),\n",
326
+ " 'image/object/bbox/ymin':\n",
327
+ " tfrecord_lib.convert_to_feature(data['ymin']),\n",
328
+ " 'image/object/bbox/ymax':\n",
329
+ " tfrecord_lib.convert_to_feature(data['ymax']),\n",
330
+ " 'image/object/class/text':\n",
331
+ " tfrecord_lib.convert_to_feature(data['category_names']),\n",
332
+ " 'image/object/class/label':\n",
333
+ " tfrecord_lib.convert_to_feature(data['category_id']),\n",
334
+ " 'image/object/is_crowd':\n",
335
+ " tfrecord_lib.convert_to_feature(data['is_crowd']),\n",
336
+ " 'image/object/area':\n",
337
+ " tfrecord_lib.convert_to_feature(data['area'], 'float_list'),\n",
338
+ " 'image/object/mask':\n",
339
+ " tfrecord_lib.convert_to_feature(data['encoded_mask_png'])\n",
340
+ " }\n",
341
+ " # print(keys_to_features['image/object/class/label'])\n",
342
+ " example = tf.train.Example(\n",
343
+ " features=tf.train.Features(feature=keys_to_features))\n",
344
+ " return example\n",
345
+ "\n",
346
+ "\n",
347
+ "\n",
348
+ " # file_names = [f\"{prefix}/{pathlib.Path(image_info['coco_url']).name}\"\n",
349
+ " # for image_info in lvis_annotation.images]\n",
350
+ " # _extract_images(images_zip, file_names)\n",
351
+ " writers = [\n",
352
+ " tf.io.TFRecordWriter(\n",
353
+ " tf_records_dir + prefix +'-%05d-of-%05d.tfrecord' % (i, num_shards))\n",
354
+ " for i in range(num_shards)\n",
355
+ " ]\n",
356
+ " id_to_name_map = {cat_dict['id']: cat_dict['name']\n",
357
+ " for cat_dict in lvis_annotation.categories[:NUM_CLASSES]}\n",
358
+ " # print(id_to_name_map)\n",
359
+ " for idx, image_info in enumerate(tqdm.tqdm(lvis_annotation.images)):\n",
360
+ " img_data = requests.get(image_info['coco_url'], stream=True).content\n",
361
+ " img_name = image_info['coco_url'].split('/')[-1]\n",
362
+ " with open(os.path.join(IMGS_DIR, img_name), 'wb') as handler:\n",
363
+ " handler.write(img_data)\n",
364
+ " tf_example = _process_example(prefix, image_info, id_to_name_map)\n",
365
+ " writers[idx % num_shards].write(tf_example.SerializeToString())\n",
366
+ "\n",
367
+ " del lvis_annotation"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": null,
373
+ "metadata": {
374
+ "id": "5u2dwjIT2HZu"
375
+ },
376
+ "outputs": [],
377
+ "source": [
378
+ "_URLS = {\n",
379
+ " 'train_images': 'http://images.cocodataset.org/zips/train2017.zip',\n",
380
+ " 'validation_images': 'http://images.cocodataset.org/zips/val2017.zip',\n",
381
+ " 'test_images': 'http://images.cocodataset.org/zips/test2017.zip',\n",
382
+ "}\n",
383
+ "\n",
384
+ "train_prefix = 'train'\n",
385
+ "valid_prefix = 'val'\n",
386
+ "\n",
387
+ "train_annotation_path = './lvis_v1_train.json'\n",
388
+ "valid_annotation_path = './lvis_v1_val.json'\n",
389
+ "\n",
390
+ "IMGS_DIR = './lvis_sub_dataset/'\n",
391
+ "tf_records_dir = './lvis_tfrecords/'\n",
392
+ "\n",
393
+ "\n",
394
+ "if not os.path.exists(IMGS_DIR):\n",
395
+ " os.mkdir(IMGS_DIR)\n",
396
+ "\n",
397
+ "if not os.path.exists(tf_records_dir):\n",
398
+ " os.mkdir(tf_records_dir)\n",
399
+ "\n",
400
+ "\n",
401
+ "\n",
402
+ "NUM_CLASSES = 3\n",
403
+ "category_index = get_category_map(valid_annotation_path, NUM_CLASSES)\n",
404
+ "category_ids = list(category_index.keys())"
405
+ ]
406
+ },
407
+ {
408
+ "cell_type": "code",
409
+ "execution_count": null,
410
+ "metadata": {
411
+ "id": "KBgl5fG42LpD"
412
+ },
413
+ "outputs": [],
414
+ "source": [
415
+ "# Below helper function are taken from github tensorflow dataset lvis\n",
416
+ "# https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/datasets/lvis/lvis_dataset_builder.py\n",
417
+ "_generate_tf_records(train_prefix,\n",
418
+ " _URLS['train_images'],\n",
419
+ " train_annotation_path)"
420
+ ]
421
+ },
422
+ {
423
+ "cell_type": "code",
424
+ "execution_count": null,
425
+ "metadata": {
426
+ "id": "89O59u_H2NIJ"
427
+ },
428
+ "outputs": [],
429
+ "source": [
430
+ "_generate_tf_records(valid_prefix,\n",
431
+ " _URLS['validation_images'],\n",
432
+ " valid_annotation_path)"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "markdown",
437
+ "metadata": {
438
+ "id": "EREyevfIY4rz"
439
+ },
440
+ "source": [
441
+ "## Configure the MaskRCNN Resnet FPN COCO model for custom dataset"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": null,
447
+ "metadata": {
448
+ "id": "5yGLLvXlPInP"
449
+ },
450
+ "outputs": [],
451
+ "source": [
452
+ "train_data_input_path = './lvis_tfrecords/train*'\n",
453
+ "valid_data_input_path = './lvis_tfrecords/val*'\n",
454
+ "test_data_input_path = './lvis_tfrecords/test*'\n",
455
+ "model_dir = './trained_model/'\n",
456
+ "export_dir ='./exported_model/'"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": null,
462
+ "metadata": {
463
+ "id": "ms3wRQKAIORe"
464
+ },
465
+ "outputs": [],
466
+ "source": [
467
+ "if not os.path.exists(model_dir):\n",
468
+ " os.mkdir(model_dir)"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "markdown",
473
+ "metadata": {
474
+ "id": "EXA5NmvDblYP"
475
+ },
476
+ "source": [
477
+ "In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
478
+ "\n",
479
+ "\n",
480
+ "Use the `retinanet_mobilenet_coco` experiment configuration, as defined by `tfm.vision.configs.maskrcnn.maskrcnn_mobilenet_coco`.\n",
481
+ "\n",
482
+ "Please find all the registered experiements [here](https://www.tensorflow.org/api_docs/python/tfm/core/exp_factory/get_exp_config)\n",
483
+ "\n",
484
+ "The configuration defines an experiment to train a Mask R-CNN model with mobilenet as backbone and FPN as decoder. Default Congiguration is trained on [COCO](https://cocodataset.org/) train2017 and evaluated on [COCO](https://cocodataset.org/) val2017.\n",
485
+ "\n",
486
+ "There are also other alternative experiments available such as\n",
487
+ "`maskrcnn_resnetfpn_coco`,\n",
488
+ "`maskrcnn_spinenet_coco` and more. One can switch to them by changing the experiment name argument to the `get_exp_config` function."
489
+ ]
490
+ },
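As a side note, switching to one of the alternative registered experiments listed above is only a name change. A minimal sketch (assumption: `maskrcnn_resnetfpn_coco` is used here purely to illustrate; the next code cell keeps `maskrcnn_mobilenet_coco`):

```python
# Sketch only - switch experiments by changing the name passed to the factory.
from official.core import exp_factory

alt_config = exp_factory.get_exp_config('maskrcnn_resnetfpn_coco')
# The backbone type field should now report a ResNet backbone instead of MobileNet.
print(alt_config.task.model.backbone.type)
```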
491
+ {
492
+ "cell_type": "code",
493
+ "execution_count": null,
494
+ "metadata": {
495
+ "id": "Zi2F1qGgPWOH"
496
+ },
497
+ "outputs": [],
498
+ "source": [
499
+ "exp_config = exp_factory.get_exp_config('maskrcnn_mobilenet_coco')"
500
+ ]
501
+ },
502
+ {
503
+ "cell_type": "code",
504
+ "execution_count": null,
505
+ "metadata": {
506
+ "id": "zo-EaCdmn5j-"
507
+ },
508
+ "outputs": [],
509
+ "source": [
510
+ "model_ckpt_path = './model_ckpt/'\n",
511
+ "if not os.path.exists(model_ckpt_path):\n",
512
+ " os.mkdir(model_ckpt_path)\n",
513
+ "\n",
514
+ "!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.data-00000-of-00001 './model_ckpt/'\n",
515
+ "!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.index './model_ckpt/'"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "markdown",
520
+ "metadata": {
521
+ "id": "ymnwJYaFgHs2"
522
+ },
523
+ "source": [
524
+ "### Adjust the model and dataset configurations so that it works with custom dataset."
525
+ ]
526
+ },
527
+ {
528
+ "cell_type": "code",
529
+ "execution_count": null,
530
+ "metadata": {
531
+ "id": "zyn9ieZyUbEJ"
532
+ },
533
+ "outputs": [],
534
+ "source": [
535
+ "BATCH_SIZE = 8\n",
536
+ "HEIGHT, WIDTH = 256, 256\n",
537
+ "IMG_SHAPE = [HEIGHT, WIDTH, 3]\n",
538
+ "\n",
539
+ "\n",
540
+ "# Backbone Config\n",
541
+ "exp_config.task.annotation_file = None\n",
542
+ "exp_config.task.freeze_backbone = True\n",
543
+ "exp_config.task.init_checkpoint = \"./model_ckpt/ckpt-180648\"\n",
544
+ "exp_config.task.init_checkpoint_modules = \"backbone\"\n",
545
+ "\n",
546
+ "# Model Config\n",
547
+ "exp_config.task.model.num_classes = NUM_CLASSES + 1\n",
548
+ "exp_config.task.model.input_size = IMG_SHAPE\n",
549
+ "\n",
550
+ "# Training Data Config\n",
551
+ "exp_config.task.train_data.input_path = train_data_input_path\n",
552
+ "exp_config.task.train_data.dtype = 'float32'\n",
553
+ "exp_config.task.train_data.global_batch_size = BATCH_SIZE\n",
554
+ "exp_config.task.train_data.shuffle_buffer_size = 64\n",
555
+ "exp_config.task.train_data.parser.aug_scale_max = 1.0\n",
556
+ "exp_config.task.train_data.parser.aug_scale_min = 1.0\n",
557
+ "\n",
558
+ "# Validation Data Config\n",
559
+ "exp_config.task.validation_data.input_path = valid_data_input_path\n",
560
+ "exp_config.task.validation_data.dtype = 'float32'\n",
561
+ "exp_config.task.validation_data.global_batch_size = BATCH_SIZE"
562
+ ]
563
+ },
564
+ {
565
+ "cell_type": "markdown",
566
+ "metadata": {
567
+ "id": "0409ReANgKzF"
568
+ },
569
+ "source": [
570
+ "### Adjust the trainer configuration."
571
+ ]
572
+ },
573
+ {
574
+ "cell_type": "code",
575
+ "execution_count": null,
576
+ "metadata": {
577
+ "id": "ne8t5AHRUd9g"
578
+ },
579
+ "outputs": [],
580
+ "source": [
581
+ "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
582
+ "\n",
583
+ "if 'GPU' in ''.join(logical_device_names):\n",
584
+ " print('This may be broken in Colab.')\n",
585
+ " device = 'GPU'\n",
586
+ "elif 'TPU' in ''.join(logical_device_names):\n",
587
+ " print('This may be broken in Colab.')\n",
588
+ " device = 'TPU'\n",
589
+ "else:\n",
590
+ " print('Running on CPU is slow, so only train for a few steps.')\n",
591
+ " device = 'CPU'\n",
592
+ "\n",
593
+ "\n",
594
+ "train_steps = 2000\n",
595
+ "exp_config.trainer.steps_per_loop = 200 # steps_per_loop = num_of_training_examples // train_batch_size\n",
596
+ "\n",
597
+ "exp_config.trainer.summary_interval = 200\n",
598
+ "exp_config.trainer.checkpoint_interval = 200\n",
599
+ "exp_config.trainer.validation_interval = 200\n",
600
+ "exp_config.trainer.validation_steps = 200 # validation_steps = num_of_validation_examples // eval_batch_size\n",
601
+ "exp_config.trainer.train_steps = train_steps\n",
602
+ "exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 200\n",
603
+ "exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
604
+ "exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
605
+ "exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.07\n",
606
+ "exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05"
607
+ ]
608
+ },
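The `steps_per_loop` and `validation_steps` comments above describe how those values are normally derived from the dataset size. A small illustrative calculation (the example counts below are placeholders, not the actual size of the LVIS subset written earlier):

```python
# Illustrative calculation only - substitute your real example counts.
num_train_examples = 1600       # placeholder value
num_validation_examples = 800   # placeholder value
batch_size = 8                  # matches BATCH_SIZE above

steps_per_loop = num_train_examples // batch_size          # -> 200
validation_steps = num_validation_examples // batch_size   # -> 100
print(steps_per_loop, validation_steps)
```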
609
+ {
610
+ "cell_type": "markdown",
611
+ "metadata": {
612
+ "id": "k3I4X-bWgNm0"
613
+ },
614
+ "source": [
615
+ "### Print the modified configuration."
616
+ ]
617
+ },
618
+ {
619
+ "cell_type": "code",
620
+ "execution_count": null,
621
+ "metadata": {
622
+ "id": "IsmxXNlyWBAK"
623
+ },
624
+ "outputs": [],
625
+ "source": [
626
+ "pp.pprint(exp_config.as_dict())\n",
627
+ "display.Javascript(\"google.colab.output.setIframeHeight('500px');\")"
628
+ ]
629
+ },
630
+ {
631
+ "cell_type": "markdown",
632
+ "metadata": {
633
+ "id": "jxarWEHDgQSk"
634
+ },
635
+ "source": [
636
+ "### Set up the distribution strategy."
637
+ ]
638
+ },
639
+ {
640
+ "cell_type": "code",
641
+ "execution_count": null,
642
+ "metadata": {
643
+ "id": "4JxhiGNwQRv2"
644
+ },
645
+ "outputs": [],
646
+ "source": [
647
+ "# Setting up the Strategy\n",
648
+ "if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
649
+ " tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
650
+ "\n",
651
+ "if 'GPU' in ''.join(logical_device_names):\n",
652
+ " distribution_strategy = tf.distribute.MirroredStrategy()\n",
653
+ "elif 'TPU' in ''.join(logical_device_names):\n",
654
+ " tf.tpu.experimental.initialize_tpu_system()\n",
655
+ " tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
656
+ " distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
657
+ "else:\n",
658
+ " print('Warning: this will be really slow.')\n",
659
+ " distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n",
660
+ "\n",
661
+ "print(\"Done\")"
662
+ ]
663
+ },
664
+ {
665
+ "cell_type": "markdown",
666
+ "metadata": {
667
+ "id": "QqZU9f1ugS_A"
668
+ },
669
+ "source": [
670
+ "## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
671
+ "\n",
672
+ "The `Task` object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
673
+ ]
674
+ },
675
+ {
676
+ "cell_type": "code",
677
+ "execution_count": null,
678
+ "metadata": {
679
+ "id": "N5R-7KzORB1n"
680
+ },
681
+ "outputs": [],
682
+ "source": [
683
+ "with distribution_strategy.scope():\n",
684
+ " task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)"
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "markdown",
689
+ "metadata": {
690
+ "id": "Fmpz2R_cglIv"
691
+ },
692
+ "source": [
693
+ "## Visualize a batch of the data."
694
+ ]
695
+ },
696
+ {
697
+ "cell_type": "code",
698
+ "execution_count": null,
699
+ "metadata": {
700
+ "id": "O82f_7A8gfnY"
701
+ },
702
+ "outputs": [],
703
+ "source": [
704
+ "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
705
+ " print()\n",
706
+ " print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
707
+ " print(f'labels.keys: {labels.keys()}')"
708
+ ]
709
+ },
710
+ {
711
+ "cell_type": "markdown",
712
+ "metadata": {
713
+ "id": "dLcSHWjqgl66"
714
+ },
715
+ "source": [
716
+ "### Create Category Index Dictionary to map the labels to coressponding label names"
717
+ ]
718
+ },
719
+ {
720
+ "cell_type": "code",
721
+ "execution_count": null,
722
+ "metadata": {
723
+ "id": "ajF85r_6R9d9"
724
+ },
725
+ "outputs": [],
726
+ "source": [
727
+ "tf_ex_decoder = TfExampleDecoder(include_mask=True)"
728
+ ]
729
+ },
730
+ {
731
+ "cell_type": "markdown",
732
+ "metadata": {
733
+ "id": "gRdveeYVgr7B"
734
+ },
735
+ "source": [
736
+ "### Helper Function for Visualizing the results from TFRecords\n",
737
+ "Use `visualize_boxes_and_labels_on_image_array` from `visualization_utils` to draw boudning boxes on the image."
738
+ ]
739
+ },
740
+ {
741
+ "cell_type": "code",
742
+ "execution_count": null,
743
+ "metadata": {
744
+ "id": "uWEuOs8QStrz"
745
+ },
746
+ "outputs": [],
747
+ "source": [
748
+ "def show_batch(raw_records):\n",
749
+ " plt.figure(figsize=(20, 20))\n",
750
+ " use_normalized_coordinates=True\n",
751
+ " min_score_thresh = 0.30\n",
752
+ " for i, serialized_example in enumerate(raw_records):\n",
753
+ " plt.subplot(1, 3, i + 1)\n",
754
+ " decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
755
+ " image = decoded_tensors['image'].numpy().astype('uint8')\n",
756
+ " scores = np.ones(shape=(len(decoded_tensors['groundtruth_boxes'])))\n",
757
+ " # print(decoded_tensors['groundtruth_instance_masks'].numpy().shape)\n",
758
+ " # print(decoded_tensors.keys())\n",
759
+ " visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
760
+ " image,\n",
761
+ " decoded_tensors['groundtruth_boxes'].numpy(),\n",
762
+ " decoded_tensors['groundtruth_classes'].numpy().astype('int'),\n",
763
+ " scores,\n",
764
+ " category_index=category_index,\n",
765
+ " use_normalized_coordinates=use_normalized_coordinates,\n",
766
+ " min_score_thresh=min_score_thresh,\n",
767
+ " instance_masks=decoded_tensors['groundtruth_instance_masks'].numpy().astype('uint8'),\n",
768
+ " line_thickness=4)\n",
769
+ "\n",
770
+ " plt.imshow(image)\n",
771
+ " plt.axis(\"off\")\n",
772
+ " plt.title(f\"Image-{i+1}\")\n",
773
+ " plt.show()"
774
+ ]
775
+ },
776
+ {
777
+ "cell_type": "markdown",
778
+ "metadata": {
779
+ "id": "FergQ2P5gv_j"
780
+ },
781
+ "source": [
782
+ "### Visualization of Train Data\n",
783
+ "\n",
784
+ "The bounding box detection has three components\n",
785
+ " 1. Class label of the object detected.\n",
786
+ " 2. Percentage of match between predicted and ground truth bounding boxes.\n",
787
+ " 3. Instance Segmentation Mask\n",
788
+ "\n",
789
+ "**Note**: The reason of everything is 100% is because we are visualising the groundtruth"
790
+ ]
791
+ },
792
+ {
793
+ "cell_type": "code",
794
+ "execution_count": null,
795
+ "metadata": {
796
+ "id": "lN0zdBwxU5Z5"
797
+ },
798
+ "outputs": [],
799
+ "source": [
800
+ "buffer_size = 100\n",
801
+ "num_of_examples = 3\n",
802
+ "\n",
803
+ "train_tfrecords = tf.io.gfile.glob(exp_config.task.train_data.input_path)\n",
804
+ "raw_records = tf.data.TFRecordDataset(train_tfrecords).shuffle(buffer_size=buffer_size).take(num_of_examples)\n",
805
+ "show_batch(raw_records)"
806
+ ]
807
+ },
808
+ {
809
+ "cell_type": "markdown",
810
+ "metadata": {
811
+ "id": "nn7IZSs5hQLg"
812
+ },
813
+ "source": [
814
+ "## Train and evaluate\n",
815
+ "\n",
816
+ "We follow the COCO challenge tradition to evaluate the accuracy of object detection based on mAP(mean Average Precision). Please check [here](https://cocodataset.org/#detection-eval) for detail explanation of how evaluation metrics for detection task is done.\n",
817
+ "\n",
818
+ "**IoU**: is defined as the area of the intersection divided by the area of the union of a predicted bounding box and ground truth bounding box."
819
+ ]
820
+ },
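For intuition, a small self-contained IoU computation (illustrative only; the COCO evaluation used by Model Garden handles this internally):

```python
# Illustrative IoU for two boxes in [ymin, xmin, ymax, xmax] order.
def iou(box_a, box_b):
    ymin = max(box_a[0], box_b[0])
    xmin = max(box_a[1], box_b[1])
    ymax = min(box_a[2], box_b[2])
    xmax = min(box_a[3], box_b[3])
    intersection = max(0.0, ymax - ymin) * max(0.0, xmax - xmin)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return intersection / (area_a + area_b - intersection)

print(iou([0, 0, 10, 10], [5, 5, 15, 15]))  # 25 / 175, roughly 0.143
```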
821
+ {
822
+ "cell_type": "code",
823
+ "execution_count": null,
824
+ "metadata": {
825
+ "id": "UTuIs4kFZGv_"
826
+ },
827
+ "outputs": [],
828
+ "source": [
829
+ "model, eval_logs = tfm.core.train_lib.run_experiment(\n",
830
+ " distribution_strategy=distribution_strategy,\n",
831
+ " task=task,\n",
832
+ " mode='train_and_eval',\n",
833
+ " params=exp_config,\n",
834
+ " model_dir=model_dir,\n",
835
+ " run_post_eval=True)"
836
+ ]
837
+ },
838
+ {
839
+ "cell_type": "markdown",
840
+ "metadata": {
841
+ "id": "rfpH4QHkh1gI"
842
+ },
843
+ "source": [
844
+ "## Load logs in tensorboard"
845
+ ]
846
+ },
847
+ {
848
+ "cell_type": "code",
849
+ "execution_count": null,
850
+ "metadata": {
851
+ "id": "wcdOvg6eNP6R"
852
+ },
853
+ "outputs": [],
854
+ "source": [
855
+ "%load_ext tensorboard\n",
856
+ "%tensorboard --logdir \"./trained_model\""
857
+ ]
858
+ },
859
+ {
860
+ "cell_type": "markdown",
861
+ "metadata": {
862
+ "id": "hAo9lozJh2cV"
863
+ },
864
+ "source": [
865
+ "## Saving and exporting the trained model\n",
866
+ "\n",
867
+ "The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results."
868
+ ]
869
+ },
870
+ {
871
+ "cell_type": "code",
872
+ "execution_count": null,
873
+ "metadata": {
874
+ "id": "iZG1vPbTQqFh"
875
+ },
876
+ "outputs": [],
877
+ "source": [
878
+ "export_saved_model_lib.export_inference_graph(\n",
879
+ " input_type='image_tensor',\n",
880
+ " batch_size=1,\n",
881
+ " input_image_size=[HEIGHT, WIDTH],\n",
882
+ " params=exp_config,\n",
883
+ " checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
884
+ " export_dir=export_dir)"
885
+ ]
886
+ },
887
+ {
888
+ "cell_type": "markdown",
889
+ "metadata": {
890
+ "id": "OHIfMeVXh7vJ"
891
+ },
892
+ "source": [
893
+ "## Inference from Trained Model"
894
+ ]
895
+ },
896
+ {
897
+ "cell_type": "code",
898
+ "execution_count": null,
899
+ "metadata": {
900
+ "id": "uaXyzMvXROTd"
901
+ },
902
+ "outputs": [],
903
+ "source": [
904
+ "def load_image_into_numpy_array(path):\n",
905
+ " \"\"\"Load an image from file into a numpy array.\n",
906
+ "\n",
907
+ " Puts image into numpy array to feed into tensorflow graph.\n",
908
+ " Note that by convention we put it into a numpy array with shape\n",
909
+ " (height, width, channels), where channels=3 for RGB.\n",
910
+ "\n",
911
+ " Args:\n",
912
+ " path: the file path to the image\n",
913
+ "\n",
914
+ " Returns:\n",
915
+ " uint8 numpy array with shape (img_height, img_width, 3)\n",
916
+ " \"\"\"\n",
917
+ " image = None\n",
918
+ " if(path.startswith('http')):\n",
919
+ " response = urlopen(path)\n",
920
+ " image_data = response.read()\n",
921
+ " image_data = BytesIO(image_data)\n",
922
+ " image = Image.open(image_data)\n",
923
+ " else:\n",
924
+ " image_data = tf.io.gfile.GFile(path, 'rb').read()\n",
925
+ " image = Image.open(BytesIO(image_data))\n",
926
+ "\n",
927
+ " (im_width, im_height) = image.size\n",
928
+ " return np.array(image.getdata()).reshape(\n",
929
+ " (1, im_height, im_width, 3)).astype(np.uint8)\n",
930
+ "\n",
931
+ "\n",
932
+ "\n",
933
+ "def build_inputs_for_object_detection(image, input_image_size):\n",
934
+ " \"\"\"Builds Object Detection model inputs for serving.\"\"\"\n",
935
+ " image, _ = resize_and_crop_image(\n",
936
+ " image,\n",
937
+ " input_image_size,\n",
938
+ " padded_size=input_image_size,\n",
939
+ " aug_scale_min=1.0,\n",
940
+ " aug_scale_max=1.0)\n",
941
+ " return image"
942
+ ]
943
+ },
944
+ {
945
+ "cell_type": "markdown",
946
+ "metadata": {
947
+ "id": "ZDI9zv_4h-7-"
948
+ },
949
+ "source": [
950
+ "## Visualize test data"
951
+ ]
952
+ },
953
+ {
954
+ "cell_type": "code",
955
+ "execution_count": null,
956
+ "metadata": {
957
+ "id": "rdyIri-1RThk"
958
+ },
959
+ "outputs": [],
960
+ "source": [
961
+ "num_of_examples = 3\n",
962
+ "\n",
963
+ "test_tfrecords = tf.io.gfile.glob('./lvis_tfrecords/val*')\n",
964
+ "test_ds = tf.data.TFRecordDataset(test_tfrecords).take(num_of_examples)\n",
965
+ "show_batch(test_ds)"
966
+ ]
967
+ },
968
+ {
969
+ "cell_type": "markdown",
970
+ "metadata": {
971
+ "id": "KkMZm4DtiAHO"
972
+ },
973
+ "source": [
974
+ "## Importing SavedModel"
975
+ ]
976
+ },
977
+ {
978
+ "cell_type": "code",
979
+ "execution_count": null,
980
+ "metadata": {
981
+ "id": "rDozz4NXRZ7p"
982
+ },
983
+ "outputs": [],
984
+ "source": [
985
+ "imported = tf.saved_model.load(export_dir)\n",
986
+ "model_fn = imported.signatures['serving_default']"
987
+ ]
988
+ },
989
+ {
990
+ "cell_type": "markdown",
991
+ "metadata": {
992
+ "id": "DUxk4-AjLAcO"
993
+ },
994
+ "source": [
995
+ "## Visualize predictions"
996
+ ]
997
+ },
998
+ {
999
+ "cell_type": "code",
1000
+ "execution_count": null,
1001
+ "metadata": {
1002
+ "id": "Gez57T5ShYnM"
1003
+ },
1004
+ "outputs": [],
1005
+ "source": [
1006
+ "def reframe_image_corners_relative_to_boxes(boxes):\n",
1007
+ " \"\"\"Reframe the image corners ([0, 0, 1, 1]) to be relative to boxes.\n",
1008
+ " The local coordinate frame of each box is assumed to be relative to\n",
1009
+ " its own for corners.\n",
1010
+ " Args:\n",
1011
+ " boxes: A float tensor of [num_boxes, 4] of (ymin, xmin, ymax, xmax)\n",
1012
+ " coordinates in relative coordinate space of each bounding box.\n",
1013
+ " Returns:\n",
1014
+ " reframed_boxes: Reframes boxes with same shape as input.\n",
1015
+ " \"\"\"\n",
1016
+ " ymin, xmin, ymax, xmax = (boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3])\n",
1017
+ "\n",
1018
+ " height = tf.maximum(ymax - ymin, 1e-4)\n",
1019
+ " width = tf.maximum(xmax - xmin, 1e-4)\n",
1020
+ "\n",
1021
+ " ymin_out = (0 - ymin) / height\n",
1022
+ " xmin_out = (0 - xmin) / width\n",
1023
+ " ymax_out = (1 - ymin) / height\n",
1024
+ " xmax_out = (1 - xmin) / width\n",
1025
+ " return tf.stack([ymin_out, xmin_out, ymax_out, xmax_out], axis=1)\n",
1026
+ "\n",
1027
+ "def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,\n",
1028
+ " image_width, resize_method='bilinear'):\n",
1029
+ " \"\"\"Transforms the box masks back to full image masks.\n",
1030
+ " Embeds masks in bounding boxes of larger masks whose shapes correspond to\n",
1031
+ " image shape.\n",
1032
+ " Args:\n",
1033
+ " box_masks: A tensor of size [num_masks, mask_height, mask_width].\n",
1034
+ " boxes: A tf.float32 tensor of size [num_masks, 4] containing the box\n",
1035
+ " corners. Row i contains [ymin, xmin, ymax, xmax] of the box\n",
1036
+ " corresponding to mask i. Note that the box corners are in\n",
1037
+ " normalized coordinates.\n",
1038
+ " image_height: Image height. The output mask will have the same height as\n",
1039
+ " the image height.\n",
1040
+ " image_width: Image width. The output mask will have the same width as the\n",
1041
+ " image width.\n",
1042
+ " resize_method: The resize method, either 'bilinear' or 'nearest'. Note that\n",
1043
+ " 'bilinear' is only respected if box_masks is a float.\n",
1044
+ " Returns:\n",
1045
+ " A tensor of size [num_masks, image_height, image_width] with the same dtype\n",
1046
+ " as `box_masks`.\n",
1047
+ " \"\"\"\n",
1048
+ " resize_method = 'nearest' if box_masks.dtype == tf.uint8 else resize_method\n",
1049
+ " # TODO(rathodv): Make this a public function.\n",
1050
+ " def reframe_box_masks_to_image_masks_default():\n",
1051
+ " \"\"\"The default function when there are more than 0 box masks.\"\"\"\n",
1052
+ "\n",
1053
+ " num_boxes = tf.shape(box_masks)[0]\n",
1054
+ " box_masks_expanded = tf.expand_dims(box_masks, axis=3)\n",
1055
+ "\n",
1056
+ " resized_crops = tf.image.crop_and_resize(\n",
1057
+ " image=box_masks_expanded,\n",
1058
+ " boxes=reframe_image_corners_relative_to_boxes(boxes),\n",
1059
+ " box_indices=tf.range(num_boxes),\n",
1060
+ " crop_size=[image_height, image_width],\n",
1061
+ " method=resize_method,\n",
1062
+ " extrapolation_value=0)\n",
1063
+ " return tf.cast(resized_crops, box_masks.dtype)\n",
1064
+ "\n",
1065
+ " image_masks = tf.cond(\n",
1066
+ " tf.shape(box_masks)[0] > 0,\n",
1067
+ " reframe_box_masks_to_image_masks_default,\n",
1068
+ " lambda: tf.zeros([0, image_height, image_width, 1], box_masks.dtype))\n",
1069
+ " return tf.squeeze(image_masks, axis=3)"
1070
+ ]
1071
+ },
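A tiny smoke test of the helper above with dummy tensors (shapes are illustrative; real detections come from `model_fn` in the next cell):

```python
# Dummy data only - verifies the expected output shape of the helper above.
import tensorflow as tf

dummy_masks = tf.ones([2, 28, 28], tf.float32)                  # two 28x28 box masks
dummy_boxes = tf.constant([[0.1, 0.1, 0.5, 0.5],
                           [0.4, 0.4, 0.9, 0.9]], tf.float32)   # normalized corners
full_masks = reframe_box_masks_to_image_masks(dummy_masks, dummy_boxes, 256, 256)
print(full_masks.shape)  # expected: (2, 256, 256)
```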
1072
+ {
1073
+ "cell_type": "code",
1074
+ "execution_count": null,
1075
+ "metadata": {
1076
+ "id": "6EIRAlXcSQaA"
1077
+ },
1078
+ "outputs": [],
1079
+ "source": [
1080
+ "input_image_size = (HEIGHT, WIDTH)\n",
1081
+ "plt.figure(figsize=(20, 20))\n",
1082
+ "min_score_thresh = 0.40 # Change minimum score for threshold to see all bounding boxes confidences\n",
1083
+ "\n",
1084
+ "for i, serialized_example in enumerate(test_ds):\n",
1085
+ " plt.subplot(1, 3, i+1)\n",
1086
+ " decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
1087
+ " image = build_inputs_for_object_detection(decoded_tensors['image'], input_image_size)\n",
1088
+ " image = tf.expand_dims(image, axis=0)\n",
1089
+ " image = tf.cast(image, dtype = tf.uint8)\n",
1090
+ " image_np = image[0].numpy()\n",
1091
+ " result = model_fn(image)\n",
1092
+ " # Visualize detection and masks\n",
1093
+ " if 'detection_masks' in result:\n",
1094
+ " # we need to convert np.arrays to tensors\n",
1095
+ " detection_masks = tf.convert_to_tensor(result['detection_masks'][0])\n",
1096
+ " detection_boxes = tf.convert_to_tensor(result['detection_boxes'][0])\n",
1097
+ " detection_masks_reframed = reframe_box_masks_to_image_masks(\n",
1098
+ " detection_masks, detection_boxes/256.0,\n",
1099
+ " image_np.shape[0], image_np.shape[1])\n",
1100
+ " detection_masks_reframed = tf.cast(\n",
1101
+ " detection_masks_reframed > min_score_thresh,\n",
1102
+ " np.uint8)\n",
1103
+ "\n",
1104
+ " result['detection_masks_reframed'] = detection_masks_reframed.numpy()\n",
1105
+ " visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
1106
+ " image_np,\n",
1107
+ " result['detection_boxes'][0].numpy(),\n",
1108
+ " (result['detection_classes'][0] + 0).numpy().astype(int),\n",
1109
+ " result['detection_scores'][0].numpy(),\n",
1110
+ " category_index=category_index,\n",
1111
+ " use_normalized_coordinates=False,\n",
1112
+ " max_boxes_to_draw=200,\n",
1113
+ " min_score_thresh=min_score_thresh,\n",
1114
+ " instance_masks=result.get('detection_masks_reframed', None),\n",
1115
+ " line_thickness=4)\n",
1116
+ "\n",
1117
+ " plt.imshow(image_np)\n",
1118
+ " plt.axis(\"off\")\n",
1119
+ "\n",
1120
+ "plt.show()"
1121
+ ]
1122
+ }
1123
+ ],
1124
+ "metadata": {
1125
+ "accelerator": "GPU",
1126
+ "colab": {
1127
+ "name": "instance_segmentation.ipynb",
1128
+ "provenance": [],
1129
+ "toc_visible": true
1130
+ },
1131
+ "kernelspec": {
1132
+ "display_name": "Python 3",
1133
+ "name": "python3"
1134
+ }
1135
+ },
1136
+ "nbformat": 4,
1137
+ "nbformat_minor": 0
1138
+ }
models/docs/vision/object_detection.ipynb ADDED
@@ -0,0 +1,902 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "Cayt5nCXb3WG"
7
+ },
8
+ "source": [
9
+ "##### Copyright 2022 The TensorFlow Authors."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "id": "DYL3CXHRb9-f"
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
21
+ "# you may not use this file except in compliance with the License.\n",
22
+ "# You may obtain a copy of the License at\n",
23
+ "#\n",
24
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
25
+ "#\n",
26
+ "# Unless required by applicable law or agreed to in writing, software\n",
27
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
28
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
29
+ "# See the License for the specific language governing permissions and\n",
30
+ "# limitations under the License."
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "markdown",
35
+ "metadata": {
36
+ "id": "VYDmsvURYZjz"
37
+ },
38
+ "source": [
39
+ "# Object detection with Model Garden\n",
40
+ "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
41
+ " <td>\n",
42
+ " <a target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/object_detection\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
43
+ " </td>\n",
44
+ " <td>\n",
45
+ " <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/object_detection.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
46
+ " </td>\n",
47
+ " <td>\n",
48
+ " <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/object_detection.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
49
+ " </td>\n",
50
+ " <td>\n",
51
+ " <a href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/object_detection.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
52
+ " </td>\n",
53
+ "</table>"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "markdown",
58
+ "metadata": {
59
+ "id": "69aQq_PXcUvL"
60
+ },
61
+ "source": [
62
+ "This tutorial fine-tunes a [RetinaNet](https://arxiv.org/abs/1708.02002) with ResNet-50 as backbone model from the [TensorFlow Model Garden](https://pypi.org/project/tf-models-official/) package (tensorflow-models) to detect three different Blood Cells in [BCCD](https://public.roboflow.com/object-detection/bccd) dataset. The RetinaNet is pretrained on [COCO](https://cocodataset.org/) train2017 and evaluated on [COCO](https://cocodataset.org/) val2017\n",
63
+ "\n",
64
+ "[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
65
+ "\n",
66
+ "This tutorial demonstrates how to:\n",
67
+ "\n",
68
+ "1. Use models from the Tensorflow Model Garden(TFM) package.\n",
69
+ "2. Fine-tune a pre-trained RetinanNet with ResNet-50 as backbone for object detection.\n",
70
+ "3. Export the tuned RetinaNet model"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "markdown",
75
+ "metadata": {
76
+ "id": "IeSHlZyUZl6f"
77
+ },
78
+ "source": [
79
+ "## Install necessary dependencies"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {
86
+ "id": "Pip0LHj3ZqgL"
87
+ },
88
+ "outputs": [],
89
+ "source": [
90
+ "!pip install -U -q \"tf-models-official\""
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "markdown",
95
+ "metadata": {
96
+ "id": "H3kS7Y0sZsIj"
97
+ },
98
+ "source": [
99
+ "## Import required libraries"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": null,
105
+ "metadata": {
106
+ "id": "hFdVelJ2YbQz"
107
+ },
108
+ "outputs": [],
109
+ "source": [
110
+ "import os\n",
111
+ "import io\n",
112
+ "import pprint\n",
113
+ "import tempfile\n",
114
+ "import matplotlib\n",
115
+ "import numpy as np\n",
116
+ "import tensorflow as tf\n",
117
+ "import matplotlib.pyplot as plt\n",
118
+ "\n",
119
+ "from PIL import Image\n",
120
+ "from six import BytesIO\n",
121
+ "from IPython import display\n",
122
+ "from urllib.request import urlopen"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "metadata": {
128
+ "id": "TF77J-iMZn_u"
129
+ },
130
+ "source": [
131
+ "## Import required libraries from tensorflow models"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "metadata": {
138
+ "id": "iT27_SOTY1Dz"
139
+ },
140
+ "outputs": [],
141
+ "source": [
142
+ "import orbit\n",
143
+ "import tensorflow_models as tfm\n",
144
+ "\n",
145
+ "from official.core import exp_factory\n",
146
+ "from official.core import config_definitions as cfg\n",
147
+ "from official.vision.serving import export_saved_model_lib\n",
148
+ "from official.vision.ops.preprocess_ops import normalize_image\n",
149
+ "from official.vision.ops.preprocess_ops import resize_and_crop_image\n",
150
+ "from official.vision.utils.object_detection import visualization_utils\n",
151
+ "from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder\n",
152
+ "\n",
153
+ "pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation\n",
154
+ "print(tf.__version__) # Check the version of tensorflow used\n",
155
+ "\n",
156
+ "%matplotlib inline"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "markdown",
161
+ "metadata": {
162
+ "id": "WGbMG8cpZyKa"
163
+ },
164
+ "source": [
165
+ "## Custom dataset preparation for object detection\n",
166
+ "\n",
167
+ "Models in official repository(of model-garden) requires data in a TFRecords format.\n",
168
+ "\n",
169
+ "\n",
170
+ "Please check [this resource](https://www.tensorflow.org/tutorials/load_data/tfrecord) to learn more about TFRecords data format.\n"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "markdown",
175
+ "metadata": {
176
+ "id": "Uq5hcbJ8Z4th"
177
+ },
178
+ "source": [
179
+ "### Upload your custom data in drive or local disk of the notebook and unzip the data"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "metadata": {
186
+ "id": "rDixpoqoY3Za"
187
+ },
188
+ "outputs": [],
189
+ "source": [
190
+ "!curl -L 'https://public.roboflow.com/ds/ZpYLqHeT0W?key=ZXfZLRnhsc' > './BCCD.v1-bccd.coco.zip'\n",
191
+ "!unzip -q -o './BCCD.v1-bccd.coco.zip' -d './BCC.v1-bccd.coco/'\n",
192
+ "!rm './BCCD.v1-bccd.coco.zip'"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "markdown",
197
+ "metadata": {
198
+ "id": "GI1h9UChZ8cC"
199
+ },
200
+ "source": [
201
+ "### CLI command to convert data(train data)."
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "metadata": {
208
+ "id": "x_8cmB82Y65O"
209
+ },
210
+ "outputs": [],
211
+ "source": [
212
+ "%%bash\n",
213
+ "\n",
214
+ "TRAIN_DATA_DIR='./BCC.v1-bccd.coco/train'\n",
215
+ "TRAIN_ANNOTATION_FILE_DIR='./BCC.v1-bccd.coco/train/_annotations.coco.json'\n",
216
+ "OUTPUT_TFRECORD_TRAIN='./bccd_coco_tfrecords/train'\n",
217
+ "\n",
218
+ "# Need to provide\n",
219
+ " # 1. image_dir: where images are present\n",
220
+ " # 2. object_annotations_file: where annotations are listed in json format\n",
221
+ " # 3. output_file_prefix: where to write output convered TFRecords files\n",
222
+ "python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
223
+ " --image_dir=${TRAIN_DATA_DIR} \\\n",
224
+ " --object_annotations_file=${TRAIN_ANNOTATION_FILE_DIR} \\\n",
225
+ " --output_file_prefix=$OUTPUT_TFRECORD_TRAIN \\\n",
226
+ " --num_shards=1"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "markdown",
231
+ "metadata": {
232
+ "id": "VuwZpwUoaAKU"
233
+ },
234
+ "source": [
235
+ "### CLI command to convert data(validation data)."
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": null,
241
+ "metadata": {
242
+ "id": "q8mQ8prGY8kh"
243
+ },
244
+ "outputs": [],
245
+ "source": [
246
+ "%%bash\n",
247
+ "\n",
248
+ "VALID_DATA_DIR='./BCC.v1-bccd.coco/valid'\n",
249
+ "VALID_ANNOTATION_FILE_DIR='./BCC.v1-bccd.coco/valid/_annotations.coco.json'\n",
250
+ "OUTPUT_TFRECORD_VALID='./bccd_coco_tfrecords/valid'\n",
251
+ "\n",
252
+ "python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
253
+ " --image_dir=$VALID_DATA_DIR \\\n",
254
+ " --object_annotations_file=$VALID_ANNOTATION_FILE_DIR \\\n",
255
+ " --output_file_prefix=$OUTPUT_TFRECORD_VALID \\\n",
256
+ " --num_shards=1"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "markdown",
261
+ "metadata": {
262
+ "id": "BYGxNNAXaCW6"
263
+ },
264
+ "source": [
265
+ "### CLI command to convert data(test data)."
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "metadata": {
272
+ "id": "-K8qlfstY-Ua"
273
+ },
274
+ "outputs": [],
275
+ "source": [
276
+ "%%bash\n",
277
+ "\n",
278
+ "TEST_DATA_DIR='./BCC.v1-bccd.coco/test'\n",
279
+ "TEST_ANNOTATION_FILE_DIR='./BCC.v1-bccd.coco/test/_annotations.coco.json'\n",
280
+ "OUTPUT_TFRECORD_TEST='./bccd_coco_tfrecords/test'\n",
281
+ "\n",
282
+ "python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
283
+ " --image_dir=$TEST_DATA_DIR \\\n",
284
+ " --object_annotations_file=$TEST_ANNOTATION_FILE_DIR \\\n",
285
+ " --output_file_prefix=$OUTPUT_TFRECORD_TEST \\\n",
286
+ " --num_shards=1"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "markdown",
291
+ "metadata": {
292
+ "id": "cW7hQEJTaEtj"
293
+ },
294
+ "source": [
295
+ "## Configure the Retinanet Resnet FPN COCO model for custom dataset.\n",
296
+ "\n",
297
+ "Dataset used for fine tuning the checkpoint is Blood Cells Detection (BCCD)."
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "code",
302
+ "execution_count": null,
303
+ "metadata": {
304
+ "id": "PMGEl7iXZAAF"
305
+ },
306
+ "outputs": [],
307
+ "source": [
308
+ "train_data_input_path = './bccd_coco_tfrecords/train-00000-of-00001.tfrecord'\n",
309
+ "valid_data_input_path = './bccd_coco_tfrecords/valid-00000-of-00001.tfrecord'\n",
310
+ "test_data_input_path = './bccd_coco_tfrecords/test-00000-of-00001.tfrecord'\n",
311
+ "model_dir = './trained_model/'\n",
312
+ "export_dir ='./exported_model/'"
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "markdown",
317
+ "metadata": {
318
+ "id": "2DJpKvdeaHF3"
319
+ },
320
+ "source": [
321
+ "In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
322
+ "\n",
323
+ "\n",
324
+ "Use the `retinanet_resnetfpn_coco` experiment configuration, as defined by `tfm.vision.configs.retinanet.retinanet_resnetfpn_coco`.\n",
325
+ "\n",
326
+ "The configuration defines an experiment to train a RetinanNet with Resnet-50 as backbone, FPN as decoder. Default Configuration is trained on [COCO](https://cocodataset.org/) train2017 and evaluated on [COCO](https://cocodataset.org/) val2017.\n",
327
+ "\n",
328
+ "There are also other alternative experiments available such as\n",
329
+ "`retinanet_resnetfpn_coco`, `retinanet_spinenet_coco`, `fasterrcnn_resnetfpn_coco` and more. One can switch to them by changing the experiment name argument to the `get_exp_config` function.\n",
330
+ "\n",
331
+ "We are going to fine tune the Resnet-50 backbone checkpoint which is already present in the default configuration."
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "code",
336
+ "execution_count": null,
337
+ "metadata": {
338
+ "id": "Ie1ObPH9ZBpa"
339
+ },
340
+ "outputs": [],
341
+ "source": [
342
+ "exp_config = exp_factory.get_exp_config('retinanet_resnetfpn_coco')"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "markdown",
347
+ "metadata": {
348
+ "id": "LFhjFkw-alba"
349
+ },
350
+ "source": [
351
+ "### Adjust the model and dataset configurations so that it works with custom dataset(in this case `BCCD`)."
352
+ ]
353
+ },
354
+ {
355
+ "cell_type": "code",
356
+ "execution_count": null,
357
+ "metadata": {
358
+ "id": "ej7j6dvIZDQA"
359
+ },
360
+ "outputs": [],
361
+ "source": [
362
+ "batch_size = 8\n",
363
+ "num_classes = 3\n",
364
+ "\n",
365
+ "HEIGHT, WIDTH = 256, 256\n",
366
+ "IMG_SIZE = [HEIGHT, WIDTH, 3]\n",
367
+ "\n",
368
+ "# Backbone config.\n",
369
+ "exp_config.task.freeze_backbone = False\n",
370
+ "exp_config.task.annotation_file = ''\n",
371
+ "\n",
372
+ "# Model config.\n",
373
+ "exp_config.task.model.input_size = IMG_SIZE\n",
374
+ "exp_config.task.model.num_classes = num_classes + 1\n",
375
+ "exp_config.task.model.detection_generator.tflite_post_processing.max_classes_per_detection = exp_config.task.model.num_classes\n",
376
+ "\n",
377
+ "# Training data config.\n",
378
+ "exp_config.task.train_data.input_path = train_data_input_path\n",
379
+ "exp_config.task.train_data.dtype = 'float32'\n",
380
+ "exp_config.task.train_data.global_batch_size = batch_size\n",
381
+ "exp_config.task.train_data.parser.aug_scale_max = 1.0\n",
382
+ "exp_config.task.train_data.parser.aug_scale_min = 1.0\n",
383
+ "\n",
384
+ "# Validation data config.\n",
385
+ "exp_config.task.validation_data.input_path = valid_data_input_path\n",
386
+ "exp_config.task.validation_data.dtype = 'float32'\n",
387
+ "exp_config.task.validation_data.global_batch_size = batch_size"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "markdown",
392
+ "metadata": {
393
+ "id": "ROVc1rayaqI1"
394
+ },
395
+ "source": [
396
+ "### Adjust the trainer configuration."
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "code",
401
+ "execution_count": null,
402
+ "metadata": {
403
+ "id": "BZsCVBafZFIE"
404
+ },
405
+ "outputs": [],
406
+ "source": [
407
+ "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
408
+ "\n",
409
+ "if 'GPU' in ''.join(logical_device_names):\n",
410
+ " print('This may be broken in Colab.')\n",
411
+ " device = 'GPU'\n",
412
+ "elif 'TPU' in ''.join(logical_device_names):\n",
413
+ " print('This may be broken in Colab.')\n",
414
+ " device = 'TPU'\n",
415
+ "else:\n",
416
+ " print('Running on CPU is slow, so only train for a few steps.')\n",
417
+ " device = 'CPU'\n",
418
+ "\n",
419
+ "\n",
420
+ "train_steps = 1000\n",
421
+ "exp_config.trainer.steps_per_loop = 100 # steps_per_loop = num_of_training_examples // train_batch_size\n",
422
+ "\n",
423
+ "exp_config.trainer.summary_interval = 100\n",
424
+ "exp_config.trainer.checkpoint_interval = 100\n",
425
+ "exp_config.trainer.validation_interval = 100\n",
426
+ "exp_config.trainer.validation_steps = 100 # validation_steps = num_of_validation_examples // eval_batch_size\n",
427
+ "exp_config.trainer.train_steps = train_steps\n",
428
+ "exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100\n",
429
+ "exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
430
+ "exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
431
+ "exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
432
+ "exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05"
433
+ ]
434
+ },
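As a quick sanity check on the formulas in the comments above, the sketch below works through the arithmetic with illustrative example counts. The counts are assumptions for illustration only, not values read from the BCCD TFRecords.

```python
# Illustrative arithmetic only; the example counts are assumed placeholders.
num_train_examples = 800
num_val_examples = 100

steps_per_loop = num_train_examples // batch_size   # 800 // 8 = 100
validation_steps = num_val_examples // batch_size   # 100 // 8 = 12
print(steps_per_loop, validation_steps)
```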
435
+ {
436
+ "cell_type": "markdown",
437
+ "metadata": {
438
+ "id": "XS6cJfs2atgI"
439
+ },
440
+ "source": [
441
+ "### Print the modified configuration."
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": null,
447
+ "metadata": {
448
+ "id": "IvfJlqI7ZIcD"
449
+ },
450
+ "outputs": [],
451
+ "source": [
452
+ "pp.pprint(exp_config.as_dict())\n",
453
+ "display.Javascript('google.colab.output.setIframeHeight(\"500px\");')"
454
+ ]
455
+ },
456
+ {
457
+ "cell_type": "markdown",
458
+ "metadata": {
459
+ "id": "6o5mbpRBawbs"
460
+ },
461
+ "source": [
462
+ "### Set up the distribution strategy."
463
+ ]
464
+ },
465
+ {
466
+ "cell_type": "code",
467
+ "execution_count": null,
468
+ "metadata": {
469
+ "id": "2NvY8QHOZKGr"
470
+ },
471
+ "outputs": [],
472
+ "source": [
473
+ "if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
474
+ " tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
475
+ "\n",
476
+ "if 'GPU' in ''.join(logical_device_names):\n",
477
+ " distribution_strategy = tf.distribute.MirroredStrategy()\n",
478
+ "elif 'TPU' in ''.join(logical_device_names):\n",
479
+ " tf.tpu.experimental.initialize_tpu_system()\n",
480
+ " tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
481
+ " distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
482
+ "else:\n",
483
+ " print('Warning: this will be really slow.')\n",
484
+ " distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n",
485
+ "\n",
486
+ "print('Done')"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "markdown",
491
+ "metadata": {
492
+ "id": "4wPtJgoOa33v"
493
+ },
494
+ "source": [
495
+ "## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
496
+ "\n",
497
+ "The `Task` object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
498
+ ]
499
+ },
500
+ {
501
+ "cell_type": "code",
502
+ "execution_count": null,
503
+ "metadata": {
504
+ "id": "Ns9LAsiXZLuX"
505
+ },
506
+ "outputs": [],
507
+ "source": [
508
+ "with distribution_strategy.scope():\n",
509
+ " task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)"
510
+ ]
511
+ },
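Beyond being passed to `run_experiment`, the task object can also be used directly. A minimal sketch (variable names here are illustrative, not part of the original notebook):

```python
# Sketch: the Task exposes the building blocks used by run_experiment.
with distribution_strategy.scope():
  detection_model = task.build_model()  # a tf.keras.Model for this task
train_dataset = task.build_inputs(exp_config.task.train_data)
print(type(detection_model).__name__)
```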
512
+ {
513
+ "cell_type": "markdown",
514
+ "metadata": {
515
+ "id": "vTKbQxDkbArE"
516
+ },
517
+ "source": [
518
+ "## Visualize a batch of the data."
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "code",
523
+ "execution_count": null,
524
+ "metadata": {
525
+ "id": "3RIlbhj0ZNvt"
526
+ },
527
+ "outputs": [],
528
+ "source": [
529
+ "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
530
+ " print()\n",
531
+ " print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
532
+ " print(f'labels.keys: {labels.keys()}')"
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "markdown",
537
+ "metadata": {
538
+ "id": "m-QW7DoKbD8z"
539
+ },
540
+ "source": [
541
+ "### Create a category index dictionary to map the labels to the corresponding label names."
542
+ ]
543
+ },
544
+ {
545
+ "cell_type": "code",
546
+ "execution_count": null,
547
+ "metadata": {
548
+ "id": "MN0sSthbZR-s"
549
+ },
550
+ "outputs": [],
551
+ "source": [
552
+ "category_index={\n",
553
+ " 1: {\n",
554
+ " 'id': 1,\n",
555
+ " 'name': 'Platelets'\n",
556
+ " },\n",
557
+ " 2: {\n",
558
+ " 'id': 2,\n",
559
+ " 'name': 'RBC'\n",
560
+ " },\n",
561
+ " 3: {\n",
562
+ " 'id': 3,\n",
563
+ " 'name': 'WBC'\n",
564
+ " }\n",
565
+ "}\n",
566
+ "tf_ex_decoder = TfExampleDecoder()"
567
+ ]
568
+ },
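The same mapping can also be built programmatically from a list of class names; a small equivalent sketch:

```python
# Equivalent construction of category_index from class names.
# Ids are 1-based to match the labels written into the TFRecords.
class_names = ['Platelets', 'RBC', 'WBC']
category_index = {
    i + 1: {'id': i + 1, 'name': name} for i, name in enumerate(class_names)
}
```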
569
+ {
570
+ "cell_type": "markdown",
571
+ "metadata": {
572
+ "id": "AcbmD1pRbGcS"
573
+ },
574
+ "source": [
575
+ "### Helper function for visualizing the results from TFRecords.\n",
576
+ "Use `visualize_boxes_and_labels_on_image_array` from `visualization_utils` to draw bounding boxes on the image."
577
+ ]
578
+ },
579
+ {
580
+ "cell_type": "code",
581
+ "execution_count": null,
582
+ "metadata": {
583
+ "id": "wWBeomMMZThI"
584
+ },
585
+ "outputs": [],
586
+ "source": [
587
+ "def show_batch(raw_records, num_of_examples):\n",
588
+ " plt.figure(figsize=(20, 20))\n",
589
+ " use_normalized_coordinates=True\n",
590
+ " min_score_thresh = 0.30\n",
591
+ " for i, serialized_example in enumerate(raw_records):\n",
592
+ " plt.subplot(1, 3, i + 1)\n",
593
+ " decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
594
+ " image = decoded_tensors['image'].numpy().astype('uint8')\n",
595
+ " scores = np.ones(shape=(len(decoded_tensors['groundtruth_boxes'])))\n",
596
+ " visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
597
+ " image,\n",
598
+ " decoded_tensors['groundtruth_boxes'].numpy(),\n",
599
+ " decoded_tensors['groundtruth_classes'].numpy().astype('int'),\n",
600
+ " scores,\n",
601
+ " category_index=category_index,\n",
602
+ " use_normalized_coordinates=use_normalized_coordinates,\n",
603
+ " max_boxes_to_draw=200,\n",
604
+ " min_score_thresh=min_score_thresh,\n",
605
+ " agnostic_mode=False,\n",
606
+ " instance_masks=None,\n",
607
+ " line_thickness=4)\n",
608
+ "\n",
609
+ " plt.imshow(image)\n",
610
+ " plt.axis('off')\n",
611
+ " plt.title(f'Image-{i+1}')\n",
612
+ " plt.show()"
613
+ ]
614
+ },
615
+ {
616
+ "cell_type": "markdown",
617
+ "metadata": {
618
+ "id": "R3EgriELbJly"
619
+ },
620
+ "source": [
621
+ "### Visualization of train data\n",
622
+ "\n",
623
+ "Each bounding box annotation has two components:\n",
624
+ " 1. The class label of the detected object (e.g. RBC)\n",
625
+ " 2. The percentage of overlap between the predicted and ground truth bounding boxes.\n",
626
+ "\n",
627
+ "**Note**: Every score shows 100% because we are visualizing the ground truth boxes."
628
+ ]
629
+ },
630
+ {
631
+ "cell_type": "code",
632
+ "execution_count": null,
633
+ "metadata": {
634
+ "id": "hdrsciGIZVNO"
635
+ },
636
+ "outputs": [],
637
+ "source": [
638
+ "buffer_size = 20\n",
639
+ "num_of_examples = 3\n",
640
+ "\n",
641
+ "raw_records = tf.data.TFRecordDataset(\n",
642
+ " exp_config.task.train_data.input_path).shuffle(\n",
643
+ " buffer_size=buffer_size).take(num_of_examples)\n",
644
+ "show_batch(raw_records, num_of_examples)"
645
+ ]
646
+ },
647
+ {
648
+ "cell_type": "markdown",
649
+ "metadata": {
650
+ "id": "IrWkJPyEbMKg"
651
+ },
652
+ "source": [
653
+ "## Train and evaluate.\n",
654
+ "\n",
655
+ "Following the COCO challenge convention, we evaluate object detection accuracy using mAP (mean Average Precision). See [the COCO detection evaluation page](https://cocodataset.org/#detection-eval) for a detailed explanation of how the detection metrics are computed.\n",
656
+ "\n",
657
+ "**IoU** (intersection over union) is defined as the area of the intersection divided by the area of the union of a predicted bounding box and a ground truth bounding box."
658
+ ]
659
+ },
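To make the IoU definition concrete, here is a minimal numeric sketch for two axis-aligned boxes in `[ymin, xmin, ymax, xmax]` format. It is not part of the original notebook, and the box values are made up for illustration.

```python
def box_iou(box_a, box_b):
  """IoU of two boxes given as [ymin, xmin, ymax, xmax]."""
  ymin = max(box_a[0], box_b[0])
  xmin = max(box_a[1], box_b[1])
  ymax = min(box_a[2], box_b[2])
  xmax = min(box_a[3], box_b[3])
  inter = max(0.0, ymax - ymin) * max(0.0, xmax - xmin)
  area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
  area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
  return inter / (area_a + area_b - inter)

print(box_iou([0, 0, 10, 10], [5, 5, 15, 15]))  # 25 / 175 ≈ 0.143
```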
660
+ {
661
+ "cell_type": "code",
662
+ "execution_count": 18,
663
+ "metadata": {
664
+ "id": "SCjHHXvfZXX1"
665
+ },
666
+ "outputs": [],
667
+ "source": [
668
+ "model, eval_logs = tfm.core.train_lib.run_experiment(\n",
669
+ " distribution_strategy=distribution_strategy,\n",
670
+ " task=task,\n",
671
+ " mode='train_and_eval',\n",
672
+ " params=exp_config,\n",
673
+ " model_dir=model_dir,\n",
674
+ " run_post_eval=True)"
675
+ ]
676
+ },
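With `run_post_eval=True`, `eval_logs` is a dictionary of evaluation metrics (for this experiment, the COCO detection metrics). A hedged sketch for inspecting it; the exact key names depend on the evaluator, so none are assumed here.

```python
# Sketch: print whatever metrics the post-training evaluation returned.
for name, value in eval_logs.items():
  try:
    print(f'{name:30s} {float(value):.4f}')
  except (TypeError, ValueError):
    print(name, value)
```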
677
+ {
678
+ "cell_type": "markdown",
679
+ "metadata": {
680
+ "id": "2Gd6uHLjbPKW"
681
+ },
682
+ "source": [
683
+ "## Load logs in TensorBoard."
684
+ ]
685
+ },
686
+ {
687
+ "cell_type": "code",
688
+ "execution_count": null,
689
+ "metadata": {
690
+ "id": "Q6iDRUVqZY86"
691
+ },
692
+ "outputs": [],
693
+ "source": [
694
+ "%load_ext tensorboard\n",
695
+ "%tensorboard --logdir './trained_model/'"
696
+ ]
697
+ },
698
+ {
699
+ "cell_type": "markdown",
700
+ "metadata": {
701
+ "id": "AoL2MIJobReU"
702
+ },
703
+ "source": [
704
+ "## Saving and exporting the trained model.\n",
705
+ "\n",
706
+ "The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statistics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. The export function below handles those details, so you can pass `tf.uint8` images and get the correct results."
707
+ ]
708
+ },
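For reference, the per-channel normalization applied by the training input pipeline looks roughly like the sketch below. The constants mirror the usual ImageNet statistics used by `preprocess_ops`; treat them as assumptions rather than the library's exact values.

```python
import numpy as np

# Assumed ImageNet RGB statistics (scaled to 0-255), for illustration only.
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)

def normalize_image(image):
  """Sketch of the per-channel normalization applied by the dataset loader."""
  image = np.asarray(image, dtype=np.float32)
  return (image - MEAN_RGB) / STDDEV_RGB
```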
709
+ {
710
+ "cell_type": "code",
711
+ "execution_count": null,
712
+ "metadata": {
713
+ "id": "CmOBYXdXZah4"
714
+ },
715
+ "outputs": [],
716
+ "source": [
717
+ "export_saved_model_lib.export_inference_graph(\n",
718
+ " input_type='image_tensor',\n",
719
+ " batch_size=1,\n",
720
+ " input_image_size=[HEIGHT, WIDTH],\n",
721
+ " params=exp_config,\n",
722
+ " checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
723
+ " export_dir=export_dir)"
724
+ ]
725
+ },
726
+ {
727
+ "cell_type": "markdown",
728
+ "metadata": {
729
+ "id": "_JhXopm8bU1g"
730
+ },
731
+ "source": [
732
+ "## Inference from trained model"
733
+ ]
734
+ },
735
+ {
736
+ "cell_type": "code",
737
+ "execution_count": null,
738
+ "metadata": {
739
+ "id": "EbD4j1uCZcIV"
740
+ },
741
+ "outputs": [],
742
+ "source": [
743
+ "def load_image_into_numpy_array(path):\n",
744
+ " \"\"\"Load an image from file into a numpy array.\n",
745
+ "\n",
746
+ " Puts image into numpy array to feed into tensorflow graph.\n",
747
+ " Note that by convention we put it into a numpy array with shape\n",
748
+ " (height, width, channels), where channels=3 for RGB.\n",
749
+ "\n",
750
+ " Args:\n",
751
+ " path: the file path to the image\n",
752
+ "\n",
753
+ " Returns:\n",
754
+ " uint8 numpy array with shape (img_height, img_width, 3)\n",
755
+ " \"\"\"\n",
756
+ " image = None\n",
757
+ " if path.startswith('http'):\n",
758
+ " response = urlopen(path)\n",
759
+ " image_data = response.read()\n",
760
+ " image_data = BytesIO(image_data)\n",
761
+ " image = Image.open(image_data)\n",
762
+ " else:\n",
763
+ " image_data = tf.io.gfile.GFile(path, 'rb').read()\n",
764
+ " image = Image.open(BytesIO(image_data))\n",
765
+ "\n",
766
+ " (im_width, im_height) = image.size\n",
767
+ " return np.array(image.getdata()).reshape(\n",
768
+ " (1, im_height, im_width, 3)).astype(np.uint8)\n",
769
+ "\n",
770
+ "\n",
771
+ "\n",
772
+ "def build_inputs_for_object_detection(image, input_image_size):\n",
773
+ " \"\"\"Builds Object Detection model inputs for serving.\"\"\"\n",
774
+ " image, _ = resize_and_crop_image(\n",
775
+ " image,\n",
776
+ " input_image_size,\n",
777
+ " padded_size=input_image_size,\n",
778
+ " aug_scale_min=1.0,\n",
779
+ " aug_scale_max=1.0)\n",
780
+ " return image"
781
+ ]
782
+ },
783
+ {
784
+ "cell_type": "markdown",
785
+ "metadata": {
786
+ "id": "o8bguhK_batq"
787
+ },
788
+ "source": [
789
+ "### Visualize test data."
790
+ ]
791
+ },
792
+ {
793
+ "cell_type": "code",
794
+ "execution_count": null,
795
+ "metadata": {
796
+ "id": "sOsDhYmyZd_m"
797
+ },
798
+ "outputs": [],
799
+ "source": [
800
+ "num_of_examples = 3\n",
801
+ "\n",
802
+ "test_ds = tf.data.TFRecordDataset(\n",
803
+ " './bccd_coco_tfrecords/test-00000-of-00001.tfrecord').take(\n",
804
+ " num_of_examples)\n",
805
+ "show_batch(test_ds, num_of_examples)"
806
+ ]
807
+ },
808
+ {
809
+ "cell_type": "markdown",
810
+ "metadata": {
811
+ "id": "kcYnb1Zfbba9"
812
+ },
813
+ "source": [
814
+ "### Importing SavedModel."
815
+ ]
816
+ },
817
+ {
818
+ "cell_type": "code",
819
+ "execution_count": null,
820
+ "metadata": {
821
+ "id": "nQ6waz9rZfhy"
822
+ },
823
+ "outputs": [],
824
+ "source": [
825
+ "imported = tf.saved_model.load(export_dir)\n",
826
+ "model_fn = imported.signatures['serving_default']"
827
+ ]
828
+ },
829
+ {
830
+ "cell_type": "markdown",
831
+ "metadata": {
832
+ "id": "CtB4gfZ3bfiC"
833
+ },
834
+ "source": [
835
+ "### Visualize predictions."
836
+ ]
837
+ },
838
+ {
839
+ "cell_type": "code",
840
+ "execution_count": null,
841
+ "metadata": {
842
+ "id": "UTSfNZ6yZhEV"
843
+ },
844
+ "outputs": [],
845
+ "source": [
846
+ "input_image_size = (HEIGHT, WIDTH)\n",
847
+ "plt.figure(figsize=(20, 20))\n",
848
+ "min_score_thresh = 0.30 # Lower this threshold to see more (lower-confidence) bounding boxes.\n",
849
+ "\n",
850
+ "for i, serialized_example in enumerate(test_ds):\n",
851
+ " plt.subplot(1, 3, i+1)\n",
852
+ " decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
853
+ " image = build_inputs_for_object_detection(decoded_tensors['image'], input_image_size)\n",
854
+ " image = tf.expand_dims(image, axis=0)\n",
855
+ " image = tf.cast(image, dtype = tf.uint8)\n",
856
+ " image_np = image[0].numpy()\n",
857
+ " result = model_fn(image)\n",
858
+ " visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
859
+ " image_np,\n",
860
+ " result['detection_boxes'][0].numpy(),\n",
861
+ " result['detection_classes'][0].numpy().astype(int),\n",
862
+ " result['detection_scores'][0].numpy(),\n",
863
+ " category_index=category_index,\n",
864
+ " use_normalized_coordinates=False,\n",
865
+ " max_boxes_to_draw=200,\n",
866
+ " min_score_thresh=min_score_thresh,\n",
867
+ " agnostic_mode=False,\n",
868
+ " instance_masks=None,\n",
869
+ " line_thickness=4)\n",
870
+ " plt.imshow(image_np)\n",
871
+ " plt.axis('off')\n",
872
+ "\n",
873
+ "plt.show()"
874
+ ]
875
+ }
876
+ ],
877
+ "metadata": {
878
+ "colab": {
879
+ "name": "object_detection.ipynb",
880
+ "provenance": [],
881
+ "toc_visible": true
882
+ },
883
+ "kernelspec": {
884
+ "display_name": "Python 3",
885
+ "name": "python3"
886
+ },
887
+ "language_info": {
888
+ "codemirror_mode": {
889
+ "name": "ipython",
890
+ "version": 3
891
+ },
892
+ "file_extension": ".py",
893
+ "mimetype": "text/x-python",
894
+ "name": "python",
895
+ "nbconvert_exporter": "python",
896
+ "pygments_lexer": "ipython3",
897
+ "version": "3.10.8"
898
+ }
899
+ },
900
+ "nbformat": 4,
901
+ "nbformat_minor": 0
902
+ }
models/docs/vision/semantic_segmentation.ipynb ADDED
@@ -0,0 +1,785 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "uY4QMaQw9Yvi"
7
+ },
8
+ "source": [
9
+ "##### Copyright 2022 The TensorFlow Authors."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "cellView": "form",
17
+ "id": "NM0OBLSN9heW"
18
+ },
19
+ "outputs": [],
20
+ "source": [
21
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
22
+ "# you may not use this file except in compliance with the License.\n",
23
+ "# You may obtain a copy of the License at\n",
24
+ "#\n",
25
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
26
+ "#\n",
27
+ "# Unless required by applicable law or agreed to in writing, software\n",
28
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
29
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
30
+ "# See the License for the specific language governing permissions and\n",
31
+ "# limitations under the License."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {
37
+ "id": "sg-GchQwFr_r"
38
+ },
39
+ "source": [
40
+ "# Semantic Segmentation with Model Garden\n",
41
+ "\n",
42
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
43
+ " \u003ctd\u003e\n",
44
+ " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/semantic_segmentation\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
45
+ " \u003c/td\u003e\n",
46
+ " \u003ctd\u003e\n",
47
+ " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/semantic_segmentation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
48
+ " \u003c/td\u003e\n",
49
+ " \u003ctd\u003e\n",
50
+ " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/semantic_segmentation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
51
+ " \u003c/td\u003e\n",
52
+ " \u003ctd\u003e\n",
53
+ " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/semantic_segmentation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
54
+ " \u003c/td\u003e\n",
55
+ "\u003c/table\u003e"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "markdown",
60
+ "metadata": {
61
+ "id": "c6J4IoNfN9jp"
62
+ },
63
+ "source": [
64
+ "This tutorial trains a [DeepLabV3](https://arxiv.org/pdf/1706.05587.pdf) model with a [MobileNet V2](https://arxiv.org/abs/1801.04381) backbone from the [TensorFlow Model Garden](https://pypi.org/project/tf-models-official/) package (tensorflow-models).\n",
65
+ "\n",
66
+ "\n",
67
+ "[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.\n",
68
+ "\n",
69
+ "**Dataset**: [Oxford-IIIT Pets](https://www.tensorflow.org/datasets/catalog/oxford_iiit_pet)\n",
70
+ "\n",
71
+ "* The Oxford-IIIT pet dataset is a 37 category pet image dataset with roughly 200 images for each class. The images have large variations in scale, pose and lighting. All images have an associated ground truth annotation of breed.\n",
72
+ "\n",
73
+ "\n",
74
+ "**This tutorial demonstrates how to:**\n",
75
+ "\n",
76
+ "1. Use models from the TensorFlow Models package.\n",
77
+ "2. Train/fine-tune a pre-built DeepLabV3 with a MobileNet backbone for semantic segmentation.\n",
78
+ "3. Export the trained/tuned DeepLabV3 model"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "metadata": {
84
+ "id": "AlxYhP0XFnDn"
85
+ },
86
+ "source": [
87
+ "## Install necessary dependencies"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {
94
+ "id": "pXWAySwgaWpN"
95
+ },
96
+ "outputs": [],
97
+ "source": [
98
+ "!pip install -U -q \"tensorflow\u003e=2.9.2\" \"tf-models-official\""
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "markdown",
103
+ "metadata": {
104
+ "id": "uExUsXlgaPD6"
105
+ },
106
+ "source": [
107
+ "## Import required libraries"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "metadata": {
114
+ "id": "mOmKZ3Vky5t9"
115
+ },
116
+ "outputs": [],
117
+ "source": [
118
+ "import os\n",
119
+ "import pprint\n",
120
+ "import numpy as np\n",
121
+ "import matplotlib.pyplot as plt\n",
122
+ "\n",
123
+ "from IPython import display"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": null,
129
+ "metadata": {
130
+ "id": "nF8IHrXua_0b"
131
+ },
132
+ "outputs": [],
133
+ "source": [
134
+ "import tensorflow as tf\n",
135
+ "import tensorflow_datasets as tfds\n",
136
+ "\n",
137
+ "\n",
138
+ "import orbit\n",
139
+ "import tensorflow_models as tfm\n",
140
+ "from official.vision.data import tfrecord_lib\n",
141
+ "from official.vision.serving import export_saved_model_lib\n",
142
+ "from official.vision.utils.object_detection import visualization_utils\n",
143
+ "\n",
144
+ "pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation\n",
145
+ "print(tf.__version__) # Check the version of tensorflow used\n",
146
+ "\n",
147
+ "%matplotlib inline"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "markdown",
152
+ "metadata": {
153
+ "id": "gMs4l2dpaTd3"
154
+ },
155
+ "source": [
156
+ "## Custom dataset preparation for semantic segmentation\n",
157
+ "Models in the official repository (of Model Garden) require the dataset to be in the TFRecord data format.\n",
158
+ "\n",
159
+ "Please check [this resource](https://www.tensorflow.org/tutorials/load_data/tfrecord) to learn more about TFRecords data format.\n",
160
+ "\n",
161
+ "The [Oxford-IIIT pet:3](https://www.tensorflow.org/datasets/catalog/oxford_iiit_pet) dataset is loaded from [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/overview)."
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": null,
167
+ "metadata": {
168
+ "id": "JpWK1Z-N3fHh"
169
+ },
170
+ "outputs": [],
171
+ "source": [
172
+ "(train_ds, val_ds, test_ds), info = tfds.load(\n",
173
+ " 'oxford_iiit_pet:3.*.*',\n",
174
+ " split=['train+test[:50%]', 'test[50%:80%]', 'test[80%:100%]'],\n",
175
+ " with_info=True)\n",
176
+ "info"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "markdown",
181
+ "metadata": {
182
+ "id": "Sq6s11E1bMJB"
183
+ },
184
+ "source": [
185
+ "### Helper function to encode the dataset as TFRecords"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "metadata": {
192
+ "id": "NlEf_C-DjDHG"
193
+ },
194
+ "outputs": [],
195
+ "source": [
196
+ "def process_record(record):\n",
197
+ " keys_to_features = {\n",
198
+ " 'image/encoded': tfrecord_lib.convert_to_feature(\n",
199
+ " tf.io.encode_jpeg(record['image']).numpy()),\n",
200
+ " 'image/height': tfrecord_lib.convert_to_feature(record['image'].shape[0]),\n",
201
+ " 'image/width': tfrecord_lib.convert_to_feature(record['image'].shape[1]),\n",
202
+ " 'image/segmentation/class/encoded':tfrecord_lib.convert_to_feature(\n",
203
+ " tf.io.encode_png(record['segmentation_mask'] - 1).numpy())\n",
204
+ " }\n",
205
+ " example = tf.train.Example(\n",
206
+ " features=tf.train.Features(feature=keys_to_features))\n",
207
+ " return example"
208
+ ]
209
+ },
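As a quick sanity check, a serialized example can be parsed back with `tf.io.parse_single_example`. The sketch below is not part of the original notebook; it assumes the feature keys written by `process_record` above, with the integer features stored as int64 (the default behavior assumed for `convert_to_feature`).

```python
# Sketch: round-trip one record through serialization and parsing.
sample = next(iter(train_ds))
serialized = process_record(sample).SerializeToString()

parsed = tf.io.parse_single_example(
    serialized,
    features={
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/height': tf.io.FixedLenFeature([], tf.int64),
        'image/width': tf.io.FixedLenFeature([], tf.int64),
        'image/segmentation/class/encoded': tf.io.FixedLenFeature([], tf.string),
    })
print(parsed['image/height'].numpy(), parsed['image/width'].numpy())
```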
210
+ {
211
+ "cell_type": "markdown",
212
+ "metadata": {
213
+ "id": "FoapGlIebP9r"
214
+ },
215
+ "source": [
216
+ "### Write TFRecords to a folder"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": null,
222
+ "metadata": {
223
+ "id": "dDbMn5q551LQ"
224
+ },
225
+ "outputs": [],
226
+ "source": [
227
+ "output_dir = './oxford_iiit_pet_tfrecords/'\n",
228
+ "LOG_EVERY = 100\n",
229
+ "if not os.path.exists(output_dir):\n",
230
+ " os.mkdir(output_dir)\n",
231
+ "\n",
232
+ "def write_tfrecords(dataset, output_path, num_shards=1):\n",
233
+ " writers = [\n",
234
+ " tf.io.TFRecordWriter(\n",
235
+ " output_path + '-%05d-of-%05d.tfrecord' % (i, num_shards))\n",
236
+ " for i in range(num_shards)\n",
237
+ " ]\n",
238
+ " for idx, record in enumerate(dataset):\n",
239
+ " if idx % LOG_EVERY == 0:\n",
240
+ " print('On image %d' % idx)\n",
241
+ " tf_example = process_record(record)\n",
242
+ " writers[idx % num_shards].write(tf_example.SerializeToString())"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "markdown",
247
+ "metadata": {
248
+ "id": "QHDD-D7rbZj7"
249
+ },
250
+ "source": [
251
+ "### Write training data as TFRecords"
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": null,
257
+ "metadata": {
258
+ "id": "qxJnVUfT0qBJ"
259
+ },
260
+ "outputs": [],
261
+ "source": [
262
+ "output_train_tfrecs = output_dir + 'train'\n",
263
+ "write_tfrecords(train_ds, output_train_tfrecs, num_shards=10)"
264
+ ]
265
+ },
266
+ {
267
+ "cell_type": "markdown",
268
+ "metadata": {
269
+ "id": "ap55RwVFbhtu"
270
+ },
271
+ "source": [
272
+ "### Write validation data as TFRecords"
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "execution_count": null,
278
+ "metadata": {
279
+ "id": "Fgq-VxF79ucR"
280
+ },
281
+ "outputs": [],
282
+ "source": [
283
+ "output_val_tfrecs = output_dir + 'val'\n",
284
+ "write_tfrecords(val_ds, output_val_tfrecs, num_shards=5)"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "markdown",
289
+ "metadata": {
290
+ "id": "0AZoIEzRbxZu"
291
+ },
292
+ "source": [
293
+ "### Write test data as TFRecords"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "execution_count": null,
299
+ "metadata": {
300
+ "id": "QmwFmbP69t0U"
301
+ },
302
+ "outputs": [],
303
+ "source": [
304
+ "output_test_tfrecs = output_dir + 'test'\n",
305
+ "write_tfrecords(test_ds, output_test_tfrecs, num_shards=5)"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "markdown",
310
+ "metadata": {
311
+ "id": "uEFzV-6ZfBZW"
312
+ },
313
+ "source": [
314
+ "## Configure the DeepLabV3 MobileNet model for the custom dataset"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "metadata": {
321
+ "id": "_LPEIvLsqSaG"
322
+ },
323
+ "outputs": [],
324
+ "source": [
325
+ "train_data_tfrecords = './oxford_iiit_pet_tfrecords/train*'\n",
326
+ "val_data_tfrecords = './oxford_iiit_pet_tfrecords/val*'\n",
327
+ "test_data_tfrecords = './oxford_iiit_pet_tfrecords/test*'\n",
328
+ "trained_model = './trained_model/'\n",
329
+ "export_dir = './exported_model/'"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "markdown",
334
+ "metadata": {
335
+ "id": "1ZlSiSRyb1Q6"
336
+ },
337
+ "source": [
338
+ "In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
339
+ "\n",
340
+ "\n",
341
+ "Use the `mnv2_deeplabv3_pascal` experiment configuration, as defined by `tfm.vision.configs.semantic_segmentation.mnv2_deeplabv3_pascal`.\n",
342
+ "\n",
343
+ "You can find all the registered experiments [here](https://www.tensorflow.org/api_docs/python/tfm/core/exp_factory/get_exp_config).\n",
344
+ "\n",
345
+ "The configuration defines an experiment to train a [DeepLabV3](https://arxiv.org/pdf/1706.05587.pdf) model with a MobileNetV2 backbone and an [ASPP](https://arxiv.org/pdf/1606.00915v2.pdf) decoder.\n",
346
+ "\n",
347
+ "There are also other alternative experiments available such as\n",
348
+ "\n",
349
+ "* `seg_deeplabv3_pascal`\n",
350
+ "* `seg_deeplabv3plus_pascal`\n",
351
+ "* `seg_resnetfpn_pascal`\n",
352
+ "* `mnv2_deeplabv3plus_cityscapes`\n",
353
+ "\n",
354
+ "and more. You can switch to them by changing the experiment name argument passed to the `get_exp_config` function.\n"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": null,
360
+ "metadata": {
361
+ "id": "bj5UZ6BkfJCX"
362
+ },
363
+ "outputs": [],
364
+ "source": [
365
+ "exp_config = tfm.core.exp_factory.get_exp_config('mnv2_deeplabv3_pascal')"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": null,
371
+ "metadata": {
372
+ "id": "B8jyG-jGIdFs"
373
+ },
374
+ "outputs": [],
375
+ "source": [
376
+ "model_ckpt_path = './model_ckpt/'\n",
377
+ "if not os.path.exists(model_ckpt_path):\n",
378
+ " os.mkdir(model_ckpt_path)\n",
379
+ "\n",
380
+ "!gsutil cp gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63.data-00000-of-00001 './model_ckpt/'\n",
381
+ "!gsutil cp gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63.index './model_ckpt/'"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "markdown",
386
+ "metadata": {
387
+ "id": "QBYvVFZXhSGQ"
388
+ },
389
+ "source": [
390
+ "### Adjust the model and dataset configurations so that they work with the custom dataset."
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "code",
395
+ "execution_count": null,
396
+ "metadata": {
397
+ "id": "o_Z_vWW9-5Sy"
398
+ },
399
+ "outputs": [],
400
+ "source": [
401
+ "num_classes = 3\n",
402
+ "WIDTH, HEIGHT = 128, 128\n",
403
+ "input_size = [HEIGHT, WIDTH, 3]\n",
404
+ "BATCH_SIZE = 16\n",
405
+ "\n",
406
+ "# Backbone Config\n",
407
+ "exp_config.task.init_checkpoint = model_ckpt_path + 'best_ckpt-63'\n",
408
+ "exp_config.task.freeze_backbone = True\n",
409
+ "\n",
410
+ "# Model Config\n",
411
+ "exp_config.task.model.num_classes = num_classes\n",
412
+ "exp_config.task.model.input_size = input_size\n",
413
+ "\n",
414
+ "# Training Data Config\n",
415
+ "exp_config.task.train_data.aug_scale_min = 1.0\n",
416
+ "exp_config.task.train_data.aug_scale_max = 1.0\n",
417
+ "exp_config.task.train_data.input_path = train_data_tfrecords\n",
418
+ "exp_config.task.train_data.global_batch_size = BATCH_SIZE\n",
419
+ "exp_config.task.train_data.dtype = 'float32'\n",
420
+ "exp_config.task.train_data.output_size = [HEIGHT, WIDTH]\n",
421
+ "exp_config.task.train_data.preserve_aspect_ratio = False\n",
422
+ "exp_config.task.train_data.seed = 21 # Reproducible training data\n",
423
+ "\n",
424
+ "# Validation Data Config\n",
425
+ "exp_config.task.validation_data.input_path = val_data_tfrecords\n",
426
+ "exp_config.task.validation_data.global_batch_size = BATCH_SIZE\n",
427
+ "exp_config.task.validation_data.dtype = 'float32'\n",
428
+ "exp_config.task.validation_data.output_size = [HEIGHT, WIDTH]\n",
429
+ "exp_config.task.validation_data.preserve_aspect_ratio = False\n",
430
+ "exp_config.task.validation_data.groundtruth_padded_size = [HEIGHT, WIDTH]\n",
431
+ "exp_config.task.validation_data.seed = 21 # Reproducible validation data"
432
+ ]
433
+ },
434
+ {
435
+ "cell_type": "markdown",
436
+ "metadata": {
437
+ "id": "0HDg5eKniMGJ"
438
+ },
439
+ "source": [
440
+ "### Adjust the trainer configuration."
441
+ ]
442
+ },
443
+ {
444
+ "cell_type": "code",
445
+ "execution_count": null,
446
+ "metadata": {
447
+ "id": "WASJZ3gUH8ni"
448
+ },
449
+ "outputs": [],
450
+ "source": [
451
+ "logical_device_names = [logical_device.name\n",
452
+ " for logical_device in tf.config.list_logical_devices()]\n",
453
+ "\n",
454
+ "if 'GPU' in ''.join(logical_device_names):\n",
455
+ " print('This may be broken in Colab.')\n",
456
+ " device = 'GPU'\n",
457
+ "elif 'TPU' in ''.join(logical_device_names):\n",
458
+ " print('This may be broken in Colab.')\n",
459
+ " device = 'TPU'\n",
460
+ "else:\n",
461
+ " print('Running on CPU is slow, so only train for a few steps.')\n",
462
+ " device = 'CPU'\n",
463
+ "\n",
464
+ "\n",
465
+ "train_steps = 2000\n",
466
+ "exp_config.trainer.steps_per_loop = int(train_ds.__len__().numpy() // BATCH_SIZE)\n",
467
+ "\n",
468
+ "exp_config.trainer.summary_interval = exp_config.trainer.steps_per_loop # write summaries once per training loop\n",
469
+ "exp_config.trainer.checkpoint_interval = exp_config.trainer.steps_per_loop\n",
470
+ "exp_config.trainer.validation_interval = exp_config.trainer.steps_per_loop\n",
471
+ "exp_config.trainer.validation_steps = int(val_ds.__len__().numpy() // BATCH_SIZE) # validation_steps = num_of_validation_examples // eval_batch_size\n",
472
+ "exp_config.trainer.train_steps = train_steps\n",
473
+ "exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = exp_config.trainer.steps_per_loop\n",
474
+ "exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
475
+ "exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
476
+ "exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
477
+ "exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "markdown",
482
+ "metadata": {
483
+ "id": "R66w5MwkiO8Z"
484
+ },
485
+ "source": [
486
+ "### Print the modified configuration."
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": null,
492
+ "metadata": {
493
+ "id": "ckpjzrqfhoSn"
494
+ },
495
+ "outputs": [],
496
+ "source": [
497
+ "pp.pprint(exp_config.as_dict())\n",
498
+ "display.Javascript('google.colab.output.setIframeHeight(\"500px\");')"
499
+ ]
500
+ },
501
+ {
502
+ "cell_type": "markdown",
503
+ "metadata": {
504
+ "id": "FYwzdGKAiSOV"
505
+ },
506
+ "source": [
507
+ "### Set up the distribution strategy."
508
+ ]
509
+ },
510
+ {
511
+ "cell_type": "code",
512
+ "execution_count": null,
513
+ "metadata": {
514
+ "id": "iwiOuYRRqdBi"
515
+ },
516
+ "outputs": [],
517
+ "source": [
518
+ "# Setting up the Strategy\n",
519
+ "if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
520
+ " tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
521
+ "\n",
522
+ "if 'GPU' in ''.join(logical_device_names):\n",
523
+ " distribution_strategy = tf.distribute.MirroredStrategy()\n",
524
+ "elif 'TPU' in ''.join(logical_device_names):\n",
525
+ " tf.tpu.experimental.initialize_tpu_system()\n",
526
+ " tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
527
+ " distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
528
+ "else:\n",
529
+ " print('Warning: this will be really slow.')\n",
530
+ " distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n",
531
+ "\n",
532
+ "print(\"Done\")"
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "markdown",
537
+ "metadata": {
538
+ "id": "ZLtk1GIIiVR2"
539
+ },
540
+ "source": [
541
+ "## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
542
+ "\n",
543
+ "The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
544
+ ]
545
+ },
546
+ {
547
+ "cell_type": "code",
548
+ "execution_count": null,
549
+ "metadata": {
550
+ "id": "ASTB5D2UISSr"
551
+ },
552
+ "outputs": [],
553
+ "source": [
554
+ "model_dir = './trained_model/'\n",
555
+ "\n",
556
+ "with distribution_strategy.scope():\n",
557
+ " task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)"
558
+ ]
559
+ },
560
+ {
561
+ "cell_type": "markdown",
562
+ "metadata": {
563
+ "id": "YIQ26TW-ihzA"
564
+ },
565
+ "source": [
566
+ "## Visualize a batch of the data."
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "code",
571
+ "execution_count": null,
572
+ "metadata": {
573
+ "id": "412WyIUAIdCr"
574
+ },
575
+ "outputs": [],
576
+ "source": [
577
+ "for images, masks in task.build_inputs(exp_config.task.train_data).take(1):\n",
578
+ " print()\n",
579
+ " print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
580
+ " print(f'masks.shape: {str(masks[\"masks\"].shape):16} images.dtype: {masks[\"masks\"].dtype!r}')"
581
+ ]
582
+ },
583
+ {
584
+ "cell_type": "markdown",
585
+ "metadata": {
586
+ "id": "3GgluDVJixMd"
587
+ },
588
+ "source": [
589
+ "### Helper function for visualizing the results from TFRecords"
590
+ ]
591
+ },
592
+ {
593
+ "cell_type": "code",
594
+ "execution_count": null,
595
+ "metadata": {
596
+ "id": "1kueMMfERvLx"
597
+ },
598
+ "outputs": [],
599
+ "source": [
600
+ "def display(display_list):\n",
601
+ " plt.figure(figsize=(15, 15))\n",
602
+ "\n",
603
+ " title = ['Input Image', 'True Mask', 'Predicted Mask']\n",
604
+ "\n",
605
+ " for i in range(len(display_list)):\n",
606
+ " plt.subplot(1, len(display_list), i+1)\n",
607
+ " plt.title(title[i])\n",
608
+ " plt.imshow(tf.keras.utils.array_to_img(display_list[i]))\n",
609
+ "\n",
610
+ "\n",
611
+ " plt.axis('off')\n",
612
+ " plt.show()"
613
+ ]
614
+ },
615
+ {
616
+ "cell_type": "markdown",
617
+ "metadata": {
618
+ "id": "ZCtt09G7i3dq"
619
+ },
620
+ "source": [
621
+ "### Visualization of training data\n",
622
+ "\n",
623
+ "The image title describes what each panel depicts.\n",
624
+ "\n",
625
+ "The same helper function can be used when visualizing the predicted mask."
626
+ ]
627
+ },
628
+ {
629
+ "cell_type": "code",
630
+ "execution_count": null,
631
+ "metadata": {
632
+ "id": "YwUPf9V2B6SR"
633
+ },
634
+ "outputs": [],
635
+ "source": [
636
+ "num_examples = 3\n",
637
+ "\n",
638
+ "for images, masks in task.build_inputs(exp_config.task.train_data).take(num_examples):\n",
639
+ " display([images[0], masks['masks'][0]])"
640
+ ]
641
+ },
642
+ {
643
+ "cell_type": "markdown",
644
+ "metadata": {
645
+ "id": "MeJ5w8KfjMmP"
646
+ },
647
+ "source": [
648
+ "## Train and evaluate\n",
649
+ "**IoU** (intersection over union) is defined as the area of the intersection divided by the area of the union of a predicted mask and a ground truth mask."
650
+ ]
651
+ },
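To make the definition concrete for masks, here is a small sketch (not part of the original notebook) computing per-class IoU between two integer label maps; the arrays are toy values.

```python
import numpy as np

def mask_iou(pred, target, class_id):
  """IoU between the binary masks of one class; pred/target are integer label maps."""
  pred_m = (pred == class_id)
  target_m = (target == class_id)
  inter = np.logical_and(pred_m, target_m).sum()
  union = np.logical_or(pred_m, target_m).sum()
  return inter / union if union else 0.0

pred = np.array([[0, 1], [1, 1]])
target = np.array([[1, 1], [0, 1]])
print(mask_iou(pred, target, class_id=1))  # 2 / 4 = 0.5
```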
652
+ {
653
+ "cell_type": "code",
654
+ "execution_count": null,
655
+ "metadata": {
656
+ "id": "ru3aHTCySHoH"
657
+ },
658
+ "outputs": [],
659
+ "source": [
660
+ "model, eval_logs = tfm.core.train_lib.run_experiment(\n",
661
+ " distribution_strategy=distribution_strategy,\n",
662
+ " task=task,\n",
663
+ " mode='train_and_eval',\n",
664
+ " params=exp_config,\n",
665
+ " model_dir=model_dir,\n",
666
+ " run_post_eval=True)"
667
+ ]
668
+ },
669
+ {
670
+ "cell_type": "markdown",
671
+ "metadata": {
672
+ "id": "vt3WmtxhjfGe"
673
+ },
674
+ "source": [
675
+ "## Load logs in TensorBoard"
676
+ ]
677
+ },
678
+ {
679
+ "cell_type": "code",
680
+ "execution_count": null,
681
+ "metadata": {
682
+ "id": "A9rct_7BoJFb"
683
+ },
684
+ "outputs": [],
685
+ "source": [
686
+ "%load_ext tensorboard\n",
687
+ "%tensorboard --logdir './trained_model'"
688
+ ]
689
+ },
690
+ {
691
+ "cell_type": "markdown",
692
+ "metadata": {
693
+ "id": "v6XaGoUuji7P"
694
+ },
695
+ "source": [
696
+ "## Saving and exporting the trained model\n",
697
+ "\n",
698
+ "The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statistics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. The export function below handles those details, so you can pass `tf.uint8` images and get the correct results."
699
+ ]
700
+ },
701
+ {
702
+ "cell_type": "code",
703
+ "execution_count": null,
704
+ "metadata": {
705
+ "id": "GVsnyqzdnxHd"
706
+ },
707
+ "outputs": [],
708
+ "source": [
709
+ "export_saved_model_lib.export_inference_graph(\n",
710
+ " input_type='image_tensor',\n",
711
+ " batch_size=1,\n",
712
+ " input_image_size=[HEIGHT, WIDTH],\n",
713
+ " params=exp_config,\n",
714
+ " checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
715
+ " export_dir=export_dir)"
716
+ ]
717
+ },
718
+ {
719
+ "cell_type": "markdown",
720
+ "metadata": {
721
+ "id": "nM1S-tjIjvAr"
722
+ },
723
+ "source": [
724
+ "## Importing SavedModel"
725
+ ]
726
+ },
727
+ {
728
+ "cell_type": "code",
729
+ "execution_count": null,
730
+ "metadata": {
731
+ "id": "Nxi9pEwluUcT"
732
+ },
733
+ "outputs": [],
734
+ "source": [
735
+ "imported = tf.saved_model.load(export_dir)\n",
736
+ "model_fn = imported.signatures['serving_default']"
737
+ ]
738
+ },
739
+ {
740
+ "cell_type": "markdown",
741
+ "metadata": {
742
+ "id": "LbBfl6AUj_My"
743
+ },
744
+ "source": [
745
+ "## Visualize predictions"
746
+ ]
747
+ },
748
+ {
749
+ "cell_type": "code",
750
+ "execution_count": null,
751
+ "metadata": {
752
+ "id": "qifGt_ohpFhn"
753
+ },
754
+ "outputs": [],
755
+ "source": [
756
+ "def create_mask(pred_mask):\n",
757
+ " pred_mask = tf.math.argmax(pred_mask, axis=-1)\n",
758
+ " pred_mask = pred_mask[..., tf.newaxis]\n",
759
+ " return pred_mask[0]\n",
760
+ "\n",
761
+ "\n",
762
+ "for record in test_ds.take(15):\n",
763
+ " image = tf.image.resize(record['image'], size=[HEIGHT, WIDTH])\n",
764
+ " image = tf.cast(image, dtype=tf.uint8)\n",
765
+ " mask = tf.image.resize(record['segmentation_mask'], size=[HEIGHT, WIDTH])\n",
766
+ " predicted_mask = model_fn(tf.expand_dims(record['image'], axis=0))\n",
767
+ " display([image, mask, create_mask(predicted_mask['logits'])])"
768
+ ]
769
+ }
770
+ ],
771
+ "metadata": {
772
+ "accelerator": "GPU",
773
+ "colab": {
774
+ "name": "semantic_segmentation.ipynb",
775
+ "provenance": [],
776
+ "toc_visible": true
777
+ },
778
+ "kernelspec": {
779
+ "display_name": "Python 3",
780
+ "name": "python3"
781
+ }
782
+ },
783
+ "nbformat": 4,
784
+ "nbformat_minor": 0
785
+ }
models/official/README-TPU.md ADDED
@@ -0,0 +1,32 @@
1
+ # Officially Supported TensorFlow 2.1+ Models on Cloud TPU
2
+
3
+ ## Natural Language Processing
4
+
5
+ * [bert](nlp/bert): A powerful pre-trained language representation model:
6
+ BERT, which stands for Bidirectional Encoder Representations from
7
+ Transformers.
8
+ [BERT FineTuning with Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/bert-2.x) provides step-by-step instructions for Cloud TPU training. See the [BERT MNLI Tensorboard.dev metrics](https://tensorboard.dev/experiment/LijZ1IrERxKALQfr76gndA) for the MNLI fine-tuning task.
9
+ * [transformer](nlp/transformer): A transformer model to translate the WMT
10
+ English to German dataset.
11
+ See [Training transformer on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/transformer-2.x) for step-by-step instructions on Cloud TPU training.
12
+
13
+ ## Computer Vision
14
+
15
+ * [efficientnet](vision/image_classification): A family of convolutional
16
+ neural networks that scale by balancing network depth, width, and
17
+ resolution and can be used to classify ImageNet's dataset of 1000 classes.
18
+ See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/KnaWjrq5TXGfv0NW5m7rpg/#scalars).
19
+ * [mnist](vision/image_classification): A basic model to classify digits
20
+ from the MNIST dataset. See [Running MNIST on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/mnist-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/mIah5lppTASvrHqWrdr6NA).
21
+ * [mask-rcnn](vision/detection): An object detection and instance segmentation model. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/LH7k0fMsRwqUAcE09o9kPA).
22
+ * [resnet](vision/image_classification): A deep residual network that can
23
+ be used to classify ImageNet's dataset of 1000 classes.
24
+ See [Training ResNet on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/resnet-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/CxlDK8YMRrSpYEGtBRpOhg).
25
+ * [retinanet](vision/detection): A fast and powerful object detector. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/b8NRnWU3TqG6Rw0UxueU6Q).
26
+ * [shapemask](vision/detection): An object detection and instance segmentation model using shape priors. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/ZbXgVoc6Rf6mBRlPj0JpLA).
27
+
28
+ ## Recommendation
29
+ * [dlrm](recommendation/ranking): [Deep Learning Recommendation Model for
30
+ Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091).
31
+ * [dcn v2](recommendation/ranking): [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535).
32
+ * [ncf](recommendation): Neural Collaborative Filtering. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/0k3gKjZlR1ewkVTRyLB6IQ).
models/official/README.md ADDED
@@ -0,0 +1,166 @@
1
+ <div align="center">
2
+ <img src="https://storage.googleapis.com/tf_model_garden/tf_model_garden_logo.png">
3
+ </div>
4
+
5
+ # TensorFlow Official Models
6
+
7
+ The TensorFlow official models are a collection of models
8
+ that use TensorFlow’s high-level APIs.
9
+ They are intended to be well-maintained, tested, and kept up to date
10
+ with the latest TensorFlow API.
11
+
12
+ They should also be reasonably optimized for fast performance while still
13
+ being easy to read.
14
+ These models are used as end-to-end tests, ensuring that the models run
15
+ with the same or improved speed and performance with each new TensorFlow build.
16
+
17
+ The API documentation of the latest stable release is published to
18
+ [tensorflow.org](https://www.tensorflow.org/api_docs/python/tfm).
19
+
20
+ ## More models to come!
21
+
22
+ The team is actively developing new models.
23
+ In the near future, we will add:
24
+
25
+ * State-of-the-art language understanding models.
26
+ * State-of-the-art image classification models.
27
+ * State-of-the-art object detection and instance segmentation models.
28
+ * State-of-the-art video classification models.
29
+
30
+ ## Table of Contents
31
+
32
+ - [Models and Implementations](#models-and-implementations)
33
+ * [Computer Vision](#computer-vision)
34
+ + [Image Classification](#image-classification)
35
+ + [Object Detection and Segmentation](#object-detection-and-segmentation)
36
+ + [Video Classification](#video-classification)
37
+ * [Natural Language Processing](#natural-language-processing)
38
+ * [Recommendation](#recommendation)
39
+ - [How to get started with the official models](#how-to-get-started-with-the-official-models)
40
+ - [Contributions](#contributions)
41
+
42
+ ## Models and Implementations
43
+
44
+ ### [Computer Vision](vision/README.md)
45
+
46
+ #### Image Classification
47
+
48
+ | Model | Reference (Paper) |
49
+ |-------|-------------------|
50
+ | [ResNet](vision/MODEL_GARDEN.md) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
51
+ | [ResNet-RS](vision/MODEL_GARDEN.md) | [Revisiting ResNets: Improved Training and Scaling Strategies](https://arxiv.org/abs/2103.07579) |
52
+ | [EfficientNet](vision/MODEL_GARDEN.md) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
53
+ | [Vision Transformer](vision/MODEL_GARDEN.md) | [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) |
54
+
55
+ #### Object Detection and Segmentation
56
+
57
+ | Model | Reference (Paper) |
58
+ |-------|-------------------|
59
+ | [RetinaNet](vision/MODEL_GARDEN.md) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
60
+ | [Mask R-CNN](vision/MODEL_GARDEN.md) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
61
+ | [YOLO](projects/yolo/README.md) | [YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors](https://arxiv.org/abs/2207.02696) |
62
+ | [SpineNet](vision/MODEL_GARDEN.md) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
63
+ | [Cascade RCNN-RS and RetinaNet-RS](vision/MODEL_GARDEN.md) | [Simple Training Strategies and Model Scaling for Object Detection](https://arxiv.org/abs/2107.00057)|
64
+
65
+ #### Video Classification
66
+
67
+ | Model | Reference (Paper) |
68
+ |-------|-------------------|
69
+ | [Mobile Video Networks (MoViNets)](projects/movinet) | [MoViNets: Mobile Video Networks for Efficient Video Recognition](https://arxiv.org/abs/2103.11511) |
70
+
71
+ ### [Natural Language Processing](nlp/README.md)
72
+
73
+ #### Pre-trained Language Model
74
+
75
+ | Model | Reference (Paper) |
76
+ |-------|-------------------|
77
+ | [ALBERT](nlp/MODEL_GARDEN.md#available-model-configs) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
78
+ | [BERT](nlp/MODEL_GARDEN.md#available-model-configs) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
79
+ | [ELECTRA](nlp/tasks/electra_task.py) | [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555) |
80
+
81
+
82
+ #### Neural Machine Translation
83
+
84
+ | Model | Reference (Paper) |
85
+ |-------|-------------------|
86
+ | [Transformer](nlp/MODEL_GARDEN.md#available-model-configs) | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |
87
+
88
+ #### Natural Language Generation
89
+
90
+ | Model | Reference (Paper) |
91
+ |-------|-------------------|
92
+ | [NHNet (News Headline generation model)](projects/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
93
+
94
+
95
+ #### Knowledge Distillation
96
+
97
+ | Model | Reference (Paper) |
98
+ |-------|-------------------|
99
+ | [MobileBERT](projects/mobilebert) | [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) |
100
+
101
+ ### Recommendation
102
+
103
+ Model | Reference (Paper)
104
+ -------------------------------- | -----------------
105
+ [DLRM](recommendation/ranking) | [Deep Learning Recommendation Model for Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091)
106
+ [DCN v2](recommendation/ranking) | [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535)
107
+ [NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031)
108
+
109
+ ## How to get started with the official models
110
+
111
+ * The official models in the master branch are developed using
112
+ [master branch of TensorFlow 2](https://github.com/tensorflow/tensorflow/tree/master).
113
+ When you clone the repository or download the `pip` binary of the master branch of the
114
+ official models, the master branch of TensorFlow is installed as a
115
+ dependency. This is equivalent to the following:
116
+
117
+ ```shell
118
+ pip3 install tf-models-nightly
119
+ pip3 install tensorflow-text-nightly # when model uses `nlp` packages
120
+ ```
121
+
122
+ * In the case of stable versions targeting a specific release, the TensorFlow Models
123
+ repository version numbers match the target TensorFlow release. For
124
+ example, [TensorFlow-models v2.8.x](https://github.com/tensorflow/models/releases/tag/v2.8.0)
125
+ is compatible with [TensorFlow v2.8.x](https://github.com/tensorflow/tensorflow/releases/tag/v2.8.0).
126
+ This is equivalent to the following:
127
+
128
+ ```shell
129
+ pip3 install tf-models-official==2.8.0
130
+ pip3 install tensorflow-text==2.8.0 # when the model uses `nlp` packages
131
+ ```
132
+
133
+ Starting from the 2.9.x release, we publish the modeling library as the
134
+ `tensorflow_models` package, and users can `import tensorflow_models` directly to
135
+ access the exported symbols. If you are
136
+ using the latest nightly version or the GitHub code directly, please follow the
137
+ docstrings in the GitHub repository.
138
+
139
+ Please follow the steps below before running models in this repository.
140
+
141
+ ### Requirements
142
+
143
+ * The latest TensorFlow Model Garden release and the latest TensorFlow 2
144
+ * If you are on a version of TensorFlow earlier than 2.2, please
145
+ upgrade your TensorFlow to [the latest TensorFlow 2](https://www.tensorflow.org/install/).
146
+ * Python 3.7+
147
+
148
+ Our integration tests run with Python 3.7. Although Python 3.6 should work, we
149
+ don't recommend earlier versions.
150
+
151
+ ### Installation
152
+
153
+ Please check [here](https://github.com/tensorflow/models#Installation) for the
154
+ instructions.
155
+
156
+ Available pypi packages:
157
+
158
+ * [tf-models-official](https://pypi.org/project/tf-models-official/)
159
+ * [tf-models-nightly](https://pypi.org/project/tf-models-nightly/): nightly
160
+ release with the latest changes.
161
+ * [tf-models-no-deps](https://pypi.org/project/tf-models-no-deps/): without
162
+ `tensorflow` and `tensorflow-text` in the `install_requires` list.
163
+
164
+ ## Contributions
165
+
166
+ If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
models/official/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
models/official/common/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
models/official/common/dataset_fn.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
29
+ """Utility library for picking an appropriate dataset function."""
30
+
31
+ import functools
32
+ from typing import Any, Callable, Type, Union
33
+
34
+ import tensorflow as tf, tf_keras
35
+
36
+ PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]]
37
+
38
+
39
+ def pick_dataset_fn(file_type: str) -> PossibleDatasetType:
+   """Returns the dataset constructor/factory to use for the given file type."""
40
+ if file_type == 'tfrecord':
41
+ return tf.data.TFRecordDataset
42
+ if file_type == 'tfrecord_compressed':
43
+ return functools.partial(tf.data.TFRecordDataset, compression_type='GZIP')
44
+ raise ValueError('Unrecognized file_type: {}'.format(file_type))
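A short usage sketch for `pick_dataset_fn` (illustrative only; the file glob below is a placeholder):

```python
# Build a tf.data pipeline from the reader returned by pick_dataset_fn.
import tensorflow as tf

from official.common import dataset_fn

reader = dataset_fn.pick_dataset_fn('tfrecord_compressed')  # gzip-compressed TFRecords
files = tf.data.Dataset.list_files('/path/to/train-*.tfrecord.gz')  # placeholder path
dataset = files.interleave(reader, num_parallel_calls=tf.data.AUTOTUNE)
```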
models/official/common/distribute_utils.py ADDED
@@ -0,0 +1,233 @@
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Helper functions for running models in a distributed setting."""
16
+
17
+ import json
18
+ import os
19
+ import tensorflow as tf, tf_keras
20
+
21
+
22
+ def _collective_communication(all_reduce_alg):
23
+ """Return a CollectiveCommunication based on all_reduce_alg.
24
+
25
+ Args:
26
+ all_reduce_alg: a string specifying which collective communication to pick,
27
+ or None.
28
+
29
+ Returns:
30
+ tf.distribute.experimental.CollectiveCommunication object
31
+
32
+ Raises:
33
+ ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
34
+ """
35
+ collective_communication_options = {
36
+ None: tf.distribute.experimental.CollectiveCommunication.AUTO,
37
+ "ring": tf.distribute.experimental.CollectiveCommunication.RING,
38
+ "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
39
+ }
40
+ if all_reduce_alg not in collective_communication_options:
41
+ raise ValueError(
42
+ "When used with `multi_worker_mirrored`, valid values for "
43
+ "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
44
+ all_reduce_alg))
45
+ return collective_communication_options[all_reduce_alg]
46
+
47
+
48
+ def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
49
+ """Return a CrossDeviceOps based on all_reduce_alg and num_packs.
50
+
51
+ Args:
52
+ all_reduce_alg: a string specifying which cross device op to pick, or None.
53
+ num_packs: an integer specifying number of packs for the cross device op.
54
+
55
+ Returns:
56
+ tf.distribute.CrossDeviceOps object or None.
57
+
58
+ Raises:
59
+ ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
60
+ """
61
+ if all_reduce_alg is None:
62
+ return None
63
+ mirrored_all_reduce_options = {
64
+ "nccl": tf.distribute.NcclAllReduce,
65
+ "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
66
+ }
67
+ if all_reduce_alg not in mirrored_all_reduce_options:
68
+ raise ValueError(
69
+ "When used with `mirrored`, valid values for all_reduce_alg are "
70
+ "[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
71
+ all_reduce_alg))
72
+ cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
73
+ return cross_device_ops_class(num_packs=num_packs)
74
+
75
+
76
+ def tpu_initialize(tpu_address):
77
+ """Initializes TPU for TF 2.x training.
78
+
79
+ Args:
80
+ tpu_address: string, bns address of master TPU worker.
81
+
82
+ Returns:
83
+ A TPUClusterResolver.
84
+ """
85
+ cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
86
+ tpu=tpu_address)
87
+ if tpu_address not in ("", "local"):
88
+ tf.config.experimental_connect_to_cluster(cluster_resolver)
89
+ tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
90
+ return cluster_resolver
91
+
92
+
93
+ def get_distribution_strategy(distribution_strategy="mirrored",
94
+ num_gpus=0,
95
+ all_reduce_alg=None,
96
+ num_packs=1,
97
+ tpu_address=None,
98
+ **kwargs):
99
+ """Return a Strategy for running the model.
100
+
101
+ Args:
102
+ distribution_strategy: a string specifying which distribution strategy to
103
+ use. Accepted values are "off", "one_device", "mirrored",
104
+ "parameter_server", "multi_worker_mirrored", and "tpu" -- case
105
+ insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
106
+ "off" means to use the default strategy which is obtained from
107
+ tf.distribute.get_strategy (for details on the default strategy, see
108
+ https://www.tensorflow.org/guide/distributed_training#default_strategy).
109
+ num_gpus: Number of GPUs to run this model.
110
+ all_reduce_alg: Optional. Specifies which algorithm to use when performing
111
+ all-reduce. For `MirroredStrategy`, valid values are "nccl" and
112
+ "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
113
+ "ring" and "nccl". If None, DistributionStrategy will choose based on
114
+ device topology.
115
+ num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
116
+ or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
117
+ tpu_address: Optional. String that represents TPU to connect to. Must not be
118
+ None if `distribution_strategy` is set to `tpu`.
119
+ **kwargs: Additional kwargs for internal usages.
120
+
121
+ Returns:
122
+ tf.distribute.Strategy object.
123
+ Raises:
124
+ ValueError: if `distribution_strategy` is "off" or "one_device" and
125
+ `num_gpus` is larger than 1; or `num_gpus` is negative or if
126
+ `distribution_strategy` is `tpu` but `tpu_address` is not specified.
127
+ """
128
+ del kwargs
129
+ if num_gpus < 0:
130
+ raise ValueError("`num_gpus` can not be negative.")
131
+
132
+ if not isinstance(distribution_strategy, str):
133
+ msg = ("distribution_strategy must be a string but got: %s." %
134
+ (distribution_strategy,))
135
+ if distribution_strategy == False: # pylint: disable=singleton-comparison,g-explicit-bool-comparison
136
+ msg += (" If you meant to pass the string 'off', make sure you add "
137
+ "quotes around 'off' so that yaml interprets it as a string "
138
+ "instead of a bool.")
139
+ raise ValueError(msg)
140
+
141
+ distribution_strategy = distribution_strategy.lower()
142
+ if distribution_strategy == "off":
143
+ if num_gpus > 1:
144
+ raise ValueError(f"When {num_gpus} GPUs are specified, "
145
+ "distribution_strategy flag cannot be set to `off`.")
146
+ # Return the default distribution strategy.
147
+ return tf.distribute.get_strategy()
148
+
149
+ if distribution_strategy == "tpu":
150
+ # When tpu_address is an empty string, we communicate with local TPUs.
151
+ cluster_resolver = tpu_initialize(tpu_address)
152
+ return tf.distribute.TPUStrategy(cluster_resolver)
153
+
154
+ if distribution_strategy == "multi_worker_mirrored":
155
+ return tf.distribute.experimental.MultiWorkerMirroredStrategy(
156
+ communication=_collective_communication(all_reduce_alg))
157
+
158
+ if distribution_strategy == "one_device":
159
+ if num_gpus == 0:
160
+ return tf.distribute.OneDeviceStrategy("device:CPU:0")
161
+ if num_gpus > 1:
162
+ raise ValueError("`OneDeviceStrategy` can not be used for more than "
163
+ "one device.")
164
+ return tf.distribute.OneDeviceStrategy("device:GPU:0")
165
+
166
+ if distribution_strategy == "mirrored":
167
+ if num_gpus == 0:
168
+ devices = ["device:CPU:0"]
169
+ else:
170
+ devices = ["device:GPU:%d" % i for i in range(num_gpus)]
171
+ return tf.distribute.MirroredStrategy(
172
+ devices=devices,
173
+ cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))
174
+
175
+ if distribution_strategy == "parameter_server":
176
+ cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
177
+ return tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
178
+
179
+ raise ValueError("Unrecognized Distribution Strategy: %r" %
180
+ distribution_strategy)
181
+
182
+
183
+ def configure_cluster(worker_hosts=None, task_index=-1):
184
+ """Set multi-worker cluster spec in TF_CONFIG environment variable.
185
+
186
+ Args:
187
+ worker_hosts: comma-separated list of worker ip:port pairs.
188
+ task_index: index of the worker.
189
+
190
+ Returns:
191
+ Number of workers in the cluster.
192
+ """
193
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
194
+ if tf_config:
195
+ num_workers = (
196
+ len(tf_config["cluster"].get("chief", [])) +
197
+ len(tf_config["cluster"].get("worker", [])))
198
+ elif worker_hosts:
199
+ workers = worker_hosts.split(",")
200
+ num_workers = len(workers)
201
+ if num_workers > 1 and task_index < 0:
202
+ raise ValueError("Must specify task_index when number of workers > 1")
203
+ task_index = 0 if num_workers == 1 else task_index
204
+ os.environ["TF_CONFIG"] = json.dumps({
205
+ "cluster": {
206
+ "worker": workers
207
+ },
208
+ "task": {
209
+ "type": "worker",
210
+ "index": task_index
211
+ }
212
+ })
213
+ else:
214
+ num_workers = 1
215
+ return num_workers
216
+
217
+
218
+ def get_strategy_scope(strategy):
219
+ if strategy:
220
+ strategy_scope = strategy.scope()
221
+ else:
222
+ strategy_scope = DummyContextManager()
223
+
224
+ return strategy_scope
225
+
226
+
227
+ class DummyContextManager(object):
228
+
229
+ def __enter__(self):
230
+ pass
231
+
232
+ def __exit__(self, *args):
233
+ pass
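A usage sketch for the helpers in this module (illustrative only; the GPU count and model are arbitrary):

```python
# Pick a distribution strategy and create variables under its scope.
import tensorflow as tf

from official.common import distribute_utils

strategy = distribute_utils.get_distribution_strategy(
    distribution_strategy='mirrored', num_gpus=2, all_reduce_alg='nccl')
with distribute_utils.get_strategy_scope(strategy):
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
```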
models/official/common/distribute_utils_test.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for distribution util functions."""
16
+
17
+ import sys
18
+ import tensorflow as tf, tf_keras
19
+
20
+ from official.common import distribute_utils
21
+
22
+ TPU_TEST = 'test_tpu' in sys.argv[0]
23
+
24
+
25
+ class DistributeUtilsTest(tf.test.TestCase):
26
+ """Tests for distribute util functions."""
27
+
28
+ def test_invalid_args(self):
29
+ with self.assertRaisesRegex(ValueError, '`num_gpus` can not be negative.'):
30
+ _ = distribute_utils.get_distribution_strategy(num_gpus=-1)
31
+
32
+ with self.assertRaisesRegex(ValueError,
33
+ '.*If you meant to pass the string .*'):
34
+ _ = distribute_utils.get_distribution_strategy(
35
+ distribution_strategy=False, num_gpus=0)
36
+ with self.assertRaisesRegex(ValueError, 'When 2 GPUs are specified.*'):
37
+ _ = distribute_utils.get_distribution_strategy(
38
+ distribution_strategy='off', num_gpus=2)
39
+ with self.assertRaisesRegex(ValueError,
40
+ '`OneDeviceStrategy` can not be used.*'):
41
+ _ = distribute_utils.get_distribution_strategy(
42
+ distribution_strategy='one_device', num_gpus=2)
43
+
44
+ def test_one_device_strategy_cpu(self):
45
+ ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
46
+ self.assertEqual(ds.num_replicas_in_sync, 1)
47
+ self.assertEqual(len(ds.extended.worker_devices), 1)
48
+ self.assertIn('CPU', ds.extended.worker_devices[0])
49
+
50
+ def test_one_device_strategy_gpu(self):
51
+ ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=1)
52
+ self.assertEqual(ds.num_replicas_in_sync, 1)
53
+ self.assertEqual(len(ds.extended.worker_devices), 1)
54
+ self.assertIn('GPU', ds.extended.worker_devices[0])
55
+
56
+ def test_mirrored_strategy(self):
57
+ # CPU only.
58
+ _ = distribute_utils.get_distribution_strategy(num_gpus=0)
59
+ # 5 GPUs.
60
+ ds = distribute_utils.get_distribution_strategy(num_gpus=5)
61
+ self.assertEqual(ds.num_replicas_in_sync, 5)
62
+ self.assertEqual(len(ds.extended.worker_devices), 5)
63
+ for device in ds.extended.worker_devices:
64
+ self.assertIn('GPU', device)
65
+
66
+ _ = distribute_utils.get_distribution_strategy(
67
+ distribution_strategy='mirrored',
68
+ num_gpus=2,
69
+ all_reduce_alg='nccl',
70
+ num_packs=2)
71
+ with self.assertRaisesRegex(
72
+ ValueError,
73
+ 'When used with `mirrored`, valid values for all_reduce_alg are.*'):
74
+ _ = distribute_utils.get_distribution_strategy(
75
+ distribution_strategy='mirrored',
76
+ num_gpus=2,
77
+ all_reduce_alg='dummy',
78
+ num_packs=2)
79
+
80
+ def test_mwms(self):
81
+ distribute_utils.configure_cluster(worker_hosts=None, task_index=-1)
82
+ ds = distribute_utils.get_distribution_strategy(
83
+ 'multi_worker_mirrored', all_reduce_alg='nccl')
84
+ self.assertIsInstance(
85
+ ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)
86
+
87
+ with self.assertRaisesRegex(
88
+ ValueError,
89
+ 'When used with `multi_worker_mirrored`, valid values.*'):
90
+ _ = distribute_utils.get_distribution_strategy(
91
+ 'multi_worker_mirrored', all_reduce_alg='dummy')
92
+
93
+ def test_no_strategy(self):
94
+ ds = distribute_utils.get_distribution_strategy('off')
95
+ self.assertIs(ds, tf.distribute.get_strategy())
96
+
97
+ def test_tpu_strategy(self):
98
+ if not TPU_TEST:
99
+ self.skipTest('Only Cloud TPU VM instances can have local TPUs.')
100
+ with self.assertRaises(ValueError):
101
+ _ = distribute_utils.get_distribution_strategy('tpu')
102
+
103
+ ds = distribute_utils.get_distribution_strategy('tpu', tpu_address='local')
104
+ self.assertIsInstance(
105
+ ds, tf.distribute.TPUStrategy)
106
+
107
+ def test_invalid_strategy(self):
108
+ with self.assertRaisesRegex(
109
+ ValueError,
110
+ 'distribution_strategy must be a string but got: False. If'):
111
+ distribute_utils.get_distribution_strategy(False)
112
+ with self.assertRaisesRegex(
113
+ ValueError, 'distribution_strategy must be a string but got: 1'):
114
+ distribute_utils.get_distribution_strategy(1)
115
+
116
+ def test_get_strategy_scope(self):
117
+ ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
118
+ with distribute_utils.get_strategy_scope(ds):
119
+ self.assertIs(tf.distribute.get_strategy(), ds)
120
+ with distribute_utils.get_strategy_scope(None):
121
+ self.assertIsNot(tf.distribute.get_strategy(), ds)
122
+
123
+ if __name__ == '__main__':
124
+ tf.test.main()
models/official/common/flags.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """The central place to define flags."""
16
+
17
+ from absl import flags
18
+
19
+
20
+ def define_flags():
21
+ """Defines flags.
22
+
23
+ All flags are defined as optional, but in practice most models use some of
24
+ these flags and so mark_flags_as_required() should be called after calling
25
+ this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
26
+ For example:
27
+
28
+ ```
29
+ from absl import flags
30
+ from official.common import flags as tfm_flags # pylint: disable=line-too-long
31
+ ...
32
+ tfm_flags.define_flags()
33
+ flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
34
+ ```
35
+
36
+ The reason all flags are optional is that unit tests often do not set or
37
+ use any of the flags.
38
+ """
39
+ flags.DEFINE_string(
40
+ 'experiment', default=None, help=
41
+ 'The experiment type registered, specifying an ExperimentConfig.')
42
+
43
+ flags.DEFINE_enum(
44
+ 'mode',
45
+ default=None,
46
+ enum_values=[
47
+ 'train', 'eval', 'train_and_eval', 'continuous_eval',
48
+ 'continuous_train_and_eval', 'train_and_validate',
49
+ 'train_and_post_eval'
50
+ ],
51
+ help='Mode to run: `train`, `eval`, `train_and_eval`, '
52
+ '`continuous_eval`, `continuous_train_and_eval` and '
53
+ '`train_and_validate` (which is not implemented in '
54
+ 'the open source version).')
55
+
56
+ flags.DEFINE_string(
57
+ 'model_dir',
58
+ default=None,
59
+ help='The directory where the model and training/evaluation summaries '
60
+ 'are stored.')
61
+
62
+ flags.DEFINE_multi_string(
63
+ 'config_file',
64
+ default=None,
65
+ help='YAML/JSON files that specify overrides. The override order '
66
+ 'follows the order of args. Note that each file '
67
+ 'can be used as an override template to override the default parameters '
68
+ 'specified in Python. If the same parameter is specified in both '
69
+ '`--config_file` and `--params_override`, `config_file` will be used '
70
+ 'first, followed by params_override.')
71
+
72
+ flags.DEFINE_string(
73
+ 'params_override',
74
+ default=None,
75
+ help='a YAML/JSON string or a YAML file which specifies additional '
76
+ 'overrides over the default parameters and those specified in '
77
+ '`--config_file`. Note that this is supposed to be used only to override '
78
+ 'the model parameters, but not the parameters like TPU specific flags. '
79
+ 'One canonical use case of `--config_file` and `--params_override` is that '
80
+ 'users first define a template config file using `--config_file`, then '
81
+ 'use `--params_override` to adjust the minimal set of tuning parameters, '
82
+ 'for example setting up different `train_batch_size`. The final override '
83
+ 'order of parameters: default_model_params --> params from config_file '
84
+ '--> params in params_override. See also the help message of '
85
+ '`--config_file`.')
86
+
87
+ # Libraries that rely on gin often make the mistake of defining flags
88
+ # inside library files, which causes conflicts.
89
+ try:
90
+ flags.DEFINE_multi_string(
91
+ 'gin_file', default=None, help='List of paths to the config files.')
92
+ except flags.DuplicateFlagError:
93
+ pass
94
+
95
+ try:
96
+ flags.DEFINE_multi_string(
97
+ 'gin_params',
98
+ default=None,
99
+ help='Newline separated list of Gin parameter bindings.')
100
+ except flags.DuplicateFlagError:
101
+ pass
102
+
103
+ flags.DEFINE_string(
104
+ 'tpu',
105
+ default=None,
106
+ help='The Cloud TPU to use for training. This should be either the name '
107
+ 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
108
+ 'url.')
109
+
110
+ flags.DEFINE_string(
111
+ 'tf_data_service', default=None, help='The tf.data service address')
112
+
113
+ flags.DEFINE_string(
114
+ 'tpu_platform', default=None, help='TPU platform type.')
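A sketch of how a training binary typically wires these flags together, mirroring the `define_flags` docstring (the `main` function here is a stub):

```python
# Define the shared flags, mark the commonly required ones, then parse.
from absl import app
from absl import flags

from official.common import flags as tfm_flags

FLAGS = flags.FLAGS


def main(_):
  print(FLAGS.experiment, FLAGS.mode, FLAGS.model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)
```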
models/official/common/registry_imports.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """All necessary imports for registration."""
16
+ # pylint: disable=unused-import
17
+ from official import vision
18
+ from official.nlp import tasks
19
+ from official.nlp.configs import experiment_configs
20
+ from official.utils.testing import mock_task
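Importing this module is what populates the experiment and task registries; a minimal sketch (the experiment name is only an example and must exist in the registry):

```python
# The unused import has the side effect of registering configs and tasks.
from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory

config = exp_factory.get_exp_config('resnet_imagenet')  # example experiment name
```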
models/official/common/streamz_counters.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Global streamz counters."""
16
+
17
+ from tensorflow.python.eager import monitoring
18
+
19
+
20
+ progressive_policy_creation_counter = monitoring.Counter(
21
+ "/tensorflow/training/fast_training/progressive_policy_creation",
22
+ "Counter for the number of ProgressivePolicy creations.")
23
+
24
+
25
+ stack_vars_to_vars_call_counter = monitoring.Counter(
26
+ "/tensorflow/training/fast_training/tf_vars_to_vars",
27
+ "Counter for the number of low-level stacking API calls.")
models/official/core/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Core is shared by both `nlp` and `vision`."""
16
+
17
+ from official.core import actions
18
+ from official.core import base_task
19
+ from official.core import base_trainer
20
+ from official.core import config_definitions
21
+ from official.core import exp_factory
22
+ from official.core import export_base
23
+ from official.core import file_writers
24
+ from official.core import input_reader
25
+ from official.core import registry
26
+ from official.core import savedmodel_checkpoint_manager
27
+ from official.core import task_factory
28
+ from official.core import tf_example_builder
29
+ from official.core import tf_example_feature_key
30
+ from official.core import train_lib
31
+ from official.core import train_utils
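Because of these re-exports, downstream code can reach the core modules through the package itself; a minimal sketch (the config below is just an example object):

```python
# The re-exports above make this shorthand possible.
from official import core

trainer_cfg = core.config_definitions.TrainerConfig()
print(trainer_cfg.train_steps)
```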