NCTCMumbai
commited on
Upload 2583 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- CTH_CODE_MAP.csv +0 -0
- CTH_Description.csv +0 -0
- CTH_WISE_DUTY_RATE.csv +0 -0
- Checkpoint/assets/vocab.txt +0 -0
- Checkpoint/keras_metadata.pb +3 -0
- Checkpoint/saved_model.pb +3 -0
- Checkpoint/variables/variables.data-00000-of-00001 +3 -0
- Checkpoint/variables/variables.index +0 -0
- ETC/fun_advaitbert.py +339 -0
- app.py +91 -0
- fun_advaitbert.py +344 -0
- models/.github/ISSUE_TEMPLATE/00-official-bug-report-issue.md +59 -0
- models/.github/ISSUE_TEMPLATE/10-official-documentation-issue.md +20 -0
- models/.github/ISSUE_TEMPLATE/20-official-feature-request-issue.md +26 -0
- models/.github/ISSUE_TEMPLATE/30-research-bug-report-issue.md +58 -0
- models/.github/ISSUE_TEMPLATE/40-research-documentation-issue.md +20 -0
- models/.github/ISSUE_TEMPLATE/50-research-feature-request-issue.md +26 -0
- models/.github/ISSUE_TEMPLATE/60-questions-help-issue.md +14 -0
- models/.github/ISSUE_TEMPLATE/config.yml +1 -0
- models/.github/PULL_REQUEST_TEMPLATE.md +41 -0
- models/.github/README_TEMPLATE.md +122 -0
- models/.gitignore +98 -0
- models/AUTHORS +10 -0
- models/CODEOWNERS +61 -0
- models/CONTRIBUTING.md +10 -0
- models/ISSUES.md +24 -0
- models/LICENSE +203 -0
- models/README.md +39 -0
- models/official/LICENSE +203 -0
- models/official/README-TPU.md +25 -0
- models/official/README.md +142 -0
- models/official/__init__.py +0 -0
- models/official/__pycache__/__init__.cpython-310.pyc +0 -0
- models/official/__pycache__/__init__.cpython-38.pyc +0 -0
- models/official/__pycache__/__init__.cpython-39.pyc +0 -0
- models/official/benchmark/__init__.py +0 -0
- models/official/benchmark/benchmark_wrappers.py +97 -0
- models/official/benchmark/bert_benchmark.py +365 -0
- models/official/benchmark/bert_benchmark_utils.py +127 -0
- models/official/benchmark/bert_pretrain_benchmark.py +179 -0
- models/official/benchmark/bert_squad_benchmark.py +608 -0
- models/official/benchmark/datastore/schema/benchmark_metric.json +56 -0
- models/official/benchmark/datastore/schema/benchmark_run.json +368 -0
- models/official/benchmark/datastore/schema/benchmark_run_status.json +14 -0
- models/official/benchmark/keras_benchmark.py +98 -0
- models/official/benchmark/keras_cifar_benchmark.py +402 -0
- models/official/benchmark/keras_imagenet_benchmark.py +1724 -0
- models/official/benchmark/models/__init__.py +0 -0
- models/official/benchmark/models/cifar_preprocessing.py +159 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
Checkpoint/variables/variables.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
models/research/compression/image_encoder/example.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
models/research/deeplab/testing/pascal_voc_seg/val-00000-of-00001.tfrecord filter=lfs diff=lfs merge=lfs -text
|
39 |
+
models/research/lfads/synth_data/trained_itb/model-65000.meta filter=lfs diff=lfs merge=lfs -text
|
40 |
+
models/research/object_detection/g3doc/img/kites_with_segment_overlay.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
models/research/object_detection/test_images/image2.jpg filter=lfs diff=lfs merge=lfs -text
|
CTH_CODE_MAP.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
CTH_Description.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
CTH_WISE_DUTY_RATE.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Checkpoint/assets/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Checkpoint/keras_metadata.pb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:38bb0f1231a198848366566e176c8948dceab7b085b658d00550a83f784731f5
|
3 |
+
size 11535
|
Checkpoint/saved_model.pb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b0054848283b4fb79fefcebe71830bedb75e023ad04c5655adbc6a2ddd1e2c60
|
3 |
+
size 11477628
|
Checkpoint/variables/variables.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b4ec6776ca3161577663eaa115fb9f965304670a1af8db7a37e9499a23082e67
|
3 |
+
size 1389095096
|
Checkpoint/variables/variables.index
ADDED
Binary file (46.6 kB). View file
|
|
ETC/fun_advaitbert.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import tensorflow as tf
|
4 |
+
import tensorflow_hub as hub
|
5 |
+
import sys
|
6 |
+
import random
|
7 |
+
sys.path.append('models')
|
8 |
+
from official.nlp.data import classifier_data_lib
|
9 |
+
from official.nlp.bert import tokenization
|
10 |
+
from official.nlp import optimization
|
11 |
+
tf.get_logger().setLevel('ERROR')
|
12 |
+
from huggingface_hub import InferenceClient
|
13 |
+
import math
|
14 |
+
import gradio as gr
|
15 |
+
|
16 |
+
num_warmup_steps=1
|
17 |
+
num_train_steps=1
|
18 |
+
init_lr = 3e-5
|
19 |
+
optimizer = optimization.create_optimizer(init_lr=init_lr,num_train_steps=num_train_steps,num_warmup_steps=num_warmup_steps,optimizer_type='adamw')
|
20 |
+
|
21 |
+
### Load Model
|
22 |
+
checkpoint_filepath=r'./Checkpoint'
|
23 |
+
model = tf.keras.models.load_model(checkpoint_filepath, custom_objects={'KerasLayer':hub.KerasLayer , 'AdamWeightDecay': optimizer})
|
24 |
+
|
25 |
+
df_report = pd.read_csv('./CTH_Description.csv')
|
26 |
+
df_report['CTH Code'] = df_report['CTH Code'].astype(str).str.zfill(8)
|
27 |
+
|
28 |
+
df_report_DUTY = pd.read_csv('./CTH_WISE_DUTY_RATE.csv')
|
29 |
+
df_report_DUTY['CTH'] = df_report_DUTY['CTH'].astype(str).str.zfill(8)
|
30 |
+
|
31 |
+
df = pd.read_csv("./CTH_CODE_MAP.csv")
|
32 |
+
df['CTH'] = df['CTH'].astype(str).str.zfill(8)
|
33 |
+
df = df[['CTH', 'code']]
|
34 |
+
|
35 |
+
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
class_names=df[['CTH','code']].drop_duplicates(subset='CTH').sort_values(by='code',ignore_index=True)['CTH'].values.tolist()
|
40 |
+
label_list=list(range(0,len(class_names)))
|
41 |
+
max_seq_length = 200 # maximum length of (token) input sequences . it can be any number
|
42 |
+
train_batch_size = 32 # batch size ( 16 choosen to avoid Out-Of-Memory errors)
|
43 |
+
|
44 |
+
# Get BERT layer and tokenizer:
|
45 |
+
# More details here: https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4
|
46 |
+
bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4" , trainable = True)
|
47 |
+
vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
|
48 |
+
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
|
49 |
+
tokenizer = tokenization.FullTokenizer(vocab_file , do_lower_case)
|
50 |
+
|
51 |
+
# This provides a function to convert each row to input features and label ( as required by BERT)
|
52 |
+
|
53 |
+
max_seq_length = 200 # maximum length of (token) input sequences . it can be any number
|
54 |
+
def to_feature(text, label, label_list=label_list, max_seq_length=max_seq_length, tokenizer=tokenizer):
|
55 |
+
example = classifier_data_lib.InputExample(guid = None,
|
56 |
+
text_a = text.numpy(),
|
57 |
+
text_b = None,
|
58 |
+
label = label.numpy())
|
59 |
+
feature = classifier_data_lib.convert_single_example(0 , example , label_list , max_seq_length , tokenizer)
|
60 |
+
|
61 |
+
return (feature.input_ids , feature.input_mask , feature.segment_ids , feature.label_id)
|
62 |
+
|
63 |
+
|
64 |
+
def to_feature_map(text, label):
|
65 |
+
input_ids , input_mask , segment_ids , label_id = tf.py_function(to_feature , inp = [text , label],
|
66 |
+
Tout = [tf.int32 , tf.int32 , tf.int32 , tf.int32])
|
67 |
+
|
68 |
+
input_ids.set_shape([max_seq_length])
|
69 |
+
input_mask.set_shape([max_seq_length])
|
70 |
+
segment_ids.set_shape([max_seq_length])
|
71 |
+
label_id.set_shape([])
|
72 |
+
|
73 |
+
x = {
|
74 |
+
"input_word_ids": input_ids,
|
75 |
+
"input_mask": input_mask,
|
76 |
+
"input_type_ids": segment_ids
|
77 |
+
}
|
78 |
+
|
79 |
+
return(x,label_id)
|
80 |
+
|
81 |
+
def print3largest(arr, arr_size):
|
82 |
+
third = first = second = -sys.maxsize
|
83 |
+
for i in range(0, arr_size):
|
84 |
+
|
85 |
+
if (arr[i] > first):
|
86 |
+
third = second
|
87 |
+
second = first
|
88 |
+
first = arr[i]
|
89 |
+
elif (arr[i] > second):
|
90 |
+
third = second
|
91 |
+
second = arr[i]
|
92 |
+
elif (arr[i] > third):
|
93 |
+
third = arr[i]
|
94 |
+
pred_value_max_three=[first, second, third]
|
95 |
+
return pred_value_max_three
|
96 |
+
|
97 |
+
def count_special_character(string):
|
98 |
+
special_char= 0
|
99 |
+
for i in range(len(string)):
|
100 |
+
ch = string[i]
|
101 |
+
if (string[i].isalpha()):
|
102 |
+
continue
|
103 |
+
else:
|
104 |
+
special_char += 1
|
105 |
+
|
106 |
+
if len(string)==special_char:
|
107 |
+
return False
|
108 |
+
else:
|
109 |
+
return True
|
110 |
+
|
111 |
+
def format_prompt(message, history):
|
112 |
+
prompt = "<s>"
|
113 |
+
for user_prompt, bot_response in history:
|
114 |
+
prompt += f"[INST] {user_prompt} [/INST]"
|
115 |
+
prompt += f" {bot_response}</s> "
|
116 |
+
prompt += f"[INST] {message} [/INST]"
|
117 |
+
return prompt
|
118 |
+
|
119 |
+
|
120 |
+
additional_inputs=[
|
121 |
+
gr.Textbox(
|
122 |
+
label="System Prompt",
|
123 |
+
max_lines=1,
|
124 |
+
interactive=True,
|
125 |
+
),
|
126 |
+
gr.Slider(
|
127 |
+
label="Temperature",
|
128 |
+
value=0.9,
|
129 |
+
minimum=0.0,
|
130 |
+
maximum=1.0,
|
131 |
+
step=0.05,
|
132 |
+
interactive=True,
|
133 |
+
info="Higher values produce more diverse outputs",
|
134 |
+
),
|
135 |
+
gr.Slider(
|
136 |
+
label="Max new tokens",
|
137 |
+
value=1024,
|
138 |
+
minimum=0,
|
139 |
+
maximum=4096,
|
140 |
+
step=64,
|
141 |
+
interactive=True,
|
142 |
+
info="The maximum numbers of new tokens",
|
143 |
+
),
|
144 |
+
gr.Slider(
|
145 |
+
label="Top-p (nucleus sampling)",
|
146 |
+
value=0.90,
|
147 |
+
minimum=0.0,
|
148 |
+
maximum=1,
|
149 |
+
step=0.05,
|
150 |
+
interactive=True,
|
151 |
+
info="Higher values sample more low-probability tokens",
|
152 |
+
),
|
153 |
+
gr.Slider(
|
154 |
+
label="Repetition penalty",
|
155 |
+
value=1.2,
|
156 |
+
minimum=1.0,
|
157 |
+
maximum=2.0,
|
158 |
+
step=0.05,
|
159 |
+
interactive=True,
|
160 |
+
info="Penalize repeated tokens",
|
161 |
+
)
|
162 |
+
]
|
163 |
+
|
164 |
+
def predict_CTH(txt):
|
165 |
+
print('Desc: ',txt)
|
166 |
+
if (txt!='') and len(txt)>=3 and (count_special_character(txt)):
|
167 |
+
valid_data = tf.data.Dataset.from_tensor_slices(([txt] , [1])) # 1 refers to 'entertainment' and 2 refers to 'sport'
|
168 |
+
valid_data = (valid_data.map(to_feature_map).batch(1))
|
169 |
+
preds = model.predict(valid_data)
|
170 |
+
predicted_values = tf.nn.softmax(preds)
|
171 |
+
arr = predicted_values.numpy().tolist()[0]
|
172 |
+
n = len(arr)
|
173 |
+
pred_value_max_three=print3largest(arr, n)
|
174 |
+
|
175 |
+
sum_all = pred_value_max_three[0] + pred_value_max_three[1] + pred_value_max_three[2]
|
176 |
+
|
177 |
+
val_1 = pred_value_max_three[0]/sum_all
|
178 |
+
val_2 = pred_value_max_three[1]/sum_all
|
179 |
+
val_3 = pred_value_max_three[2]/sum_all
|
180 |
+
|
181 |
+
if pred_value_max_three[0]<=0.000131:
|
182 |
+
Var_CTH=[]
|
183 |
+
Var_desc=[]
|
184 |
+
Var_duty=[]
|
185 |
+
pred_duty=''
|
186 |
+
pred_desc=''
|
187 |
+
pred_CTH=''
|
188 |
+
|
189 |
+
return{'Not a adequate description':float(1.0)}
|
190 |
+
else:
|
191 |
+
Var_CTH=[]
|
192 |
+
Var_desc=[]
|
193 |
+
Var_duty=[]
|
194 |
+
pred_duty=''
|
195 |
+
pred_desc=''
|
196 |
+
pred_CTH=''
|
197 |
+
|
198 |
+
for i in pred_value_max_three:
|
199 |
+
predicted_code=np.where(predicted_values.numpy()==i)[1][0]
|
200 |
+
pred_CTH=df[df['code'] == predicted_code]['CTH'].iloc[0]
|
201 |
+
|
202 |
+
try:
|
203 |
+
pred_duty=df_report_DUTY[df_report_DUTY['CTH']==str(pred_CTH)]['DUTY_RATE'].iloc[0]
|
204 |
+
pred_desc=df_report[df_report['CTH Code']==str(pred_CTH)]['Concat Description'].iloc[0]
|
205 |
+
except:
|
206 |
+
pass
|
207 |
+
|
208 |
+
Var_CTH.append(pred_CTH)
|
209 |
+
Var_desc.append(pred_desc)
|
210 |
+
Var_duty.append(pred_duty)
|
211 |
+
|
212 |
+
P1 ='CTH: '+str(Var_CTH[0])+' Duty Rate(%): '+ str(Var_duty[0])
|
213 |
+
P2 ='CTH: '+str(Var_CTH[1])+' Duty Rate(%): '+ str(Var_duty[1])
|
214 |
+
P3 ='CTH: '+str(Var_CTH[2])+' Duty Rate(%): '+ str(Var_duty[2])
|
215 |
+
|
216 |
+
Q1='Desc: '+str(Var_desc[0])
|
217 |
+
Q2='Desc: '+str(Var_desc[1])
|
218 |
+
Q3='Desc: '+str(Var_desc[2])
|
219 |
+
|
220 |
+
return {str(P1):float(val_1),str(Q1):float(val_1),
|
221 |
+
str(P2):float(val_2),str(Q2):float(val_2),
|
222 |
+
str(P3):float(val_3),str(Q3):float(val_3),}
|
223 |
+
else:
|
224 |
+
return{'Enter Correct Description':float(1.0)}
|
225 |
+
|
226 |
+
def llm_model_function(txt,history,chatbot=[], temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,):
|
227 |
+
system_prompt=[]
|
228 |
+
if (txt!='') and len(txt)>=3 and (count_special_character(txt)):
|
229 |
+
valid_data = tf.data.Dataset.from_tensor_slices(([txt] , [1])) # 1 refers to 'entertainment' and 2 refers to 'sport'
|
230 |
+
valid_data = (valid_data.map(to_feature_map).batch(1))
|
231 |
+
preds = model.predict(valid_data)
|
232 |
+
predicted_values = tf.nn.softmax(preds)
|
233 |
+
arr = predicted_values.numpy().tolist()[0]
|
234 |
+
n = len(arr)
|
235 |
+
pred_value_max_three=print3largest(arr, n)
|
236 |
+
|
237 |
+
sum_all = pred_value_max_three[0] + pred_value_max_three[1] + pred_value_max_three[2]
|
238 |
+
|
239 |
+
val_1 = pred_value_max_three[0]/sum_all
|
240 |
+
val_2 = pred_value_max_three[1]/sum_all
|
241 |
+
val_3 = pred_value_max_three[2]/sum_all
|
242 |
+
|
243 |
+
if pred_value_max_three[0]<=0.000131:
|
244 |
+
Var_CTH=[]
|
245 |
+
Var_desc=[]
|
246 |
+
Var_duty=[]
|
247 |
+
pred_duty=''
|
248 |
+
pred_desc=''
|
249 |
+
pred_CTH=''
|
250 |
+
|
251 |
+
return{'Not a adequate description':float(1.0)}
|
252 |
+
else:
|
253 |
+
Var_CTH=[]
|
254 |
+
Var_desc=[]
|
255 |
+
Var_duty=[]
|
256 |
+
pred_duty=''
|
257 |
+
pred_desc=''
|
258 |
+
pred_CTH=''
|
259 |
+
|
260 |
+
for i in pred_value_max_three:
|
261 |
+
predicted_code=np.where(predicted_values.numpy()==i)[1][0]
|
262 |
+
pred_CTH=df[df['code'] == predicted_code]['CTH'].iloc[0]
|
263 |
+
|
264 |
+
try:
|
265 |
+
pred_duty=df_report_DUTY[df_report_DUTY['CTH']==str(pred_CTH)]['DUTY_RATE'].iloc[0]
|
266 |
+
pred_desc=df_report[df_report['CTH Code']==str(pred_CTH)]['Concat Description'].iloc[0]
|
267 |
+
except:
|
268 |
+
pass
|
269 |
+
|
270 |
+
Var_CTH.append(pred_CTH)
|
271 |
+
Var_desc.append(pred_desc)
|
272 |
+
Var_duty.append(pred_duty)
|
273 |
+
|
274 |
+
P1 ='CTH: '+str(Var_CTH[0])+' Duty Rate(%): '+ str(Var_duty[0])
|
275 |
+
P2 ='CTH: '+str(Var_CTH[1])+' Duty Rate(%): '+ str(Var_duty[1])
|
276 |
+
P3 ='CTH: '+str(Var_CTH[2])+' Duty Rate(%): '+ str(Var_duty[2])
|
277 |
+
|
278 |
+
Q1='Desc: '+str(Var_desc[0])
|
279 |
+
Q2='Desc: '+str(Var_desc[1])
|
280 |
+
Q3='Desc: '+str(Var_desc[2])
|
281 |
+
|
282 |
+
output_str_msg='1. '+str(P1)+' '+str(Q1)+' '+'2. '+str(P2)+' '+str(Q2)+' '+'3. '+str(P3)+' '+str(Q3)
|
283 |
+
|
284 |
+
prompt=f'First Explain What is the product- {txt}. Which is the most appropriate 8 Digit classification code out of the three given below classes. Explain the reason step by step. if none of the three classification is applicable more precisely due to lack of any additional information, tell you need additional information and what is the that additional information. {output_str_msg} ?'
|
285 |
+
|
286 |
+
temperature = float(temperature)
|
287 |
+
if temperature < 1e-2:
|
288 |
+
temperature = 1e-2
|
289 |
+
top_p = float(top_p)
|
290 |
+
|
291 |
+
generate_kwargs = dict(
|
292 |
+
temperature=temperature,
|
293 |
+
max_new_tokens=max_new_tokens,
|
294 |
+
top_p=top_p,
|
295 |
+
repetition_penalty=repetition_penalty,
|
296 |
+
do_sample=True,
|
297 |
+
seed=42,
|
298 |
+
)
|
299 |
+
|
300 |
+
formatted_prompt = format_prompt(f", {prompt}", history)
|
301 |
+
stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
|
302 |
+
output = ""
|
303 |
+
for response in stream:
|
304 |
+
output += response.token.text
|
305 |
+
|
306 |
+
chatbot.append((txt, output))
|
307 |
+
return "", chatbot
|
308 |
+
else:
|
309 |
+
warning_msg = f"Unexpected response"
|
310 |
+
raise gr.Error(warning_msg)
|
311 |
+
|
312 |
+
def product_explaination(txt,history,chatbot=[], temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,):
|
313 |
+
print('Input Descrption is:',txt)
|
314 |
+
prompt=f'What is the product- {txt}?'
|
315 |
+
print('prompt',prompt)
|
316 |
+
temperature = float(temperature)
|
317 |
+
if temperature < 1e-2:
|
318 |
+
temperature = 1e-2
|
319 |
+
top_p = float(top_p)
|
320 |
+
|
321 |
+
generate_kwargs = dict(
|
322 |
+
temperature=temperature,
|
323 |
+
max_new_tokens=max_new_tokens,
|
324 |
+
top_p=top_p,
|
325 |
+
repetition_penalty=repetition_penalty,
|
326 |
+
do_sample=True,
|
327 |
+
seed=42,
|
328 |
+
)
|
329 |
+
|
330 |
+
formatted_prompt = format_prompt(f", {prompt}", history)
|
331 |
+
|
332 |
+
stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
|
333 |
+
output = ""
|
334 |
+
|
335 |
+
for response in stream:
|
336 |
+
output += response.token.text
|
337 |
+
|
338 |
+
chatbot.append((txt, output))
|
339 |
+
return "", chatbot
|
app.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from fun_advaitbert import predict_CTH
|
4 |
+
from fun_advaitbert import llm_model_function
|
5 |
+
from fun_advaitbert import product_explaination
|
6 |
+
|
7 |
+
title="<h1 style='color:green;text-align:center;font-size:2vw;'>AdvaitBERT:HS Code AI Explanability Through Mixtral 46.7B </a></h1>"
|
8 |
+
description = """
|
9 |
+
AdvaitBERT is modified version of BERT (Bidirectional Encoder Representation for Transformers), \
|
10 |
+
finetuned on the Text corpus of Indian Customs Declarations. It is trained for performing \
|
11 |
+
downstream tasks like automating the tariff classification and validation process of Customs \
|
12 |
+
declarations in realtime. This model may help Customs administration to efficiently use AI assisted \
|
13 |
+
NLP in realtime Customs process like Assessment, Post Clearance Audit, thereby highlighting classification \
|
14 |
+
inconsistencies and help in revenue augmentation.
|
15 |
+
"""
|
16 |
+
|
17 |
+
article="<p style='color:black;text-align:right;font-size:1vw;'>Powered by NCTC </a></p>"
|
18 |
+
|
19 |
+
|
20 |
+
css = """
|
21 |
+
.gradio-container {
|
22 |
+
width: 100vw !important;
|
23 |
+
min-height: 100vh !important;
|
24 |
+
padding:0 !important;
|
25 |
+
margin:0 !important;
|
26 |
+
max-width: none !important;
|
27 |
+
}
|
28 |
+
"""
|
29 |
+
|
30 |
+
footnote = """Note: All rights, including licensing and acceptable use policies, related to the AI models, can be found on their respective model pages on Hugging Face. Powered by NCTC
|
31 |
+
"""
|
32 |
+
|
33 |
+
#Powered by NCTC
|
34 |
+
|
35 |
+
# input_txt=gr.Textbox(label='Enter Your Product Descrption',lines=3,)
|
36 |
+
# textbox = gr.Textbox(container=False,placeholder='Enter text and click the Submit button or press Enter')
|
37 |
+
|
38 |
+
textbox = gr.Textbox(label='Enter Your Product Descrption',lines=3,)
|
39 |
+
textbox_2=textbox
|
40 |
+
|
41 |
+
print('textbox',textbox)
|
42 |
+
print('textbox_2',textbox_2)
|
43 |
+
|
44 |
+
chat_prod = gr.Chatbot(label="Product Explanation", layout='panel') #height=300
|
45 |
+
#chat_Advait = gr.Chatbot(label="Advaitbert Prediction", layout='panel')
|
46 |
+
chat_alpha = gr.Chatbot(label="AI Explanability", layout='panel')
|
47 |
+
chat_Advait=gr.Interface(predict_CTH,inputs=textbox,outputs="label",)
|
48 |
+
|
49 |
+
|
50 |
+
submit = gr.Button('Submit', variant='primary',)
|
51 |
+
submit_second = gr.Button('Submit', variant='secondary',)
|
52 |
+
#submit2 = gr.Button('Submit', variant='primary',)
|
53 |
+
retry = gr.Button('🔄Retry', variant='secondary')
|
54 |
+
undo = gr.Button('↩️Undo', variant='secondary')
|
55 |
+
|
56 |
+
with gr.Blocks(css=css) as demo:
|
57 |
+
gr.HTML(f'<h1><center> {title} </center></h1>')
|
58 |
+
gr.Markdown(description)
|
59 |
+
|
60 |
+
with gr.Row():
|
61 |
+
with gr.Column(scale=0,min_width=600):
|
62 |
+
chat_Advait.render()
|
63 |
+
|
64 |
+
with gr.Column(scale=1,min_width=600):
|
65 |
+
chat_alpha.render()
|
66 |
+
with gr.Row(equal_height=True):
|
67 |
+
with gr.Column(scale=1):
|
68 |
+
submit.render()
|
69 |
+
with gr.Column(scale=1):
|
70 |
+
undo.render()
|
71 |
+
with gr.Column(scale=1):
|
72 |
+
clear = gr.ClearButton(value='🗑️Clear',components=[chat_alpha,chat_prod,textbox])
|
73 |
+
chat_prod.render()
|
74 |
+
#submit_second.render()
|
75 |
+
|
76 |
+
gr.Markdown(footnote)
|
77 |
+
textbox.submit(llm_model_function, [textbox, chat_alpha], [textbox, chat_alpha])
|
78 |
+
textbox_2.submit(product_explaination, [textbox_2, chat_prod], [textbox_2, chat_prod])
|
79 |
+
|
80 |
+
submit.click(llm_model_function,[textbox, chat_alpha], [textbox, chat_alpha])
|
81 |
+
submit.click(product_explaination,[textbox_2, chat_prod], [textbox_2, chat_prod])
|
82 |
+
|
83 |
+
undo.click(lambda x:x[:-1], [chat_alpha], [chat_alpha])
|
84 |
+
undo.click(lambda x:x[:-1], [chat_prod], [chat_prod])
|
85 |
+
|
86 |
+
gr.Examples([
|
87 |
+
['200 SI/SI/SI LPO ALUMINIUM LIDS (QTY: 8820000 PCS/PRICE: 21.'],
|
88 |
+
],
|
89 |
+
textbox)
|
90 |
+
|
91 |
+
demo.launch(debug=True)
|
fun_advaitbert.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import tensorflow as tf
|
4 |
+
import tensorflow_hub as hub
|
5 |
+
import sys
|
6 |
+
import random
|
7 |
+
sys.path.append('models')
|
8 |
+
from official.nlp.data import classifier_data_lib
|
9 |
+
from official.nlp.bert import tokenization
|
10 |
+
from official.nlp import optimization
|
11 |
+
tf.get_logger().setLevel('ERROR')
|
12 |
+
from huggingface_hub import InferenceClient
|
13 |
+
import math
|
14 |
+
import gradio as gr
|
15 |
+
|
16 |
+
num_warmup_steps=1
|
17 |
+
num_train_steps=1
|
18 |
+
init_lr = 3e-5
|
19 |
+
optimizer = optimization.create_optimizer(init_lr=init_lr,num_train_steps=num_train_steps,num_warmup_steps=num_warmup_steps,optimizer_type='adamw')
|
20 |
+
|
21 |
+
### Load Model
|
22 |
+
checkpoint_filepath=r'./Checkpoint'
|
23 |
+
model = tf.keras.models.load_model(checkpoint_filepath, custom_objects={'KerasLayer':hub.KerasLayer , 'AdamWeightDecay': optimizer})
|
24 |
+
|
25 |
+
df_report = pd.read_csv('./CTH_Description.csv')
|
26 |
+
df_report['CTH Code'] = df_report['CTH Code'].astype(str).str.zfill(8)
|
27 |
+
|
28 |
+
df_report_DUTY = pd.read_csv('./CTH_WISE_DUTY_RATE.csv')
|
29 |
+
df_report_DUTY['CTH'] = df_report_DUTY['CTH'].astype(str).str.zfill(8)
|
30 |
+
|
31 |
+
df = pd.read_csv("./CTH_CODE_MAP.csv")
|
32 |
+
df['CTH'] = df['CTH'].astype(str).str.zfill(8)
|
33 |
+
df = df[['CTH', 'code']]
|
34 |
+
|
35 |
+
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
36 |
+
|
37 |
+
|
38 |
+
class_names=df[['CTH','code']].drop_duplicates(subset='CTH').sort_values(by='code',ignore_index=True)['CTH'].values.tolist()
|
39 |
+
label_list=list(range(0,len(class_names)))
|
40 |
+
max_seq_length = 200 # maximum length of (token) input sequences . it can be any number
|
41 |
+
train_batch_size = 32 # batch size ( 16 choosen to avoid Out-Of-Memory errors)
|
42 |
+
|
43 |
+
# Get BERT layer and tokenizer:
|
44 |
+
# More details here: https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4
|
45 |
+
bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4" , trainable = True)
|
46 |
+
vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
|
47 |
+
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
|
48 |
+
tokenizer = tokenization.FullTokenizer(vocab_file , do_lower_case)
|
49 |
+
|
50 |
+
# This provides a function to convert each row to input features and label ( as required by BERT)
|
51 |
+
|
52 |
+
max_seq_length = 200 # maximum length of (token) input sequences . it can be any number
|
53 |
+
def to_feature(text, label, label_list=label_list, max_seq_length=max_seq_length, tokenizer=tokenizer):
|
54 |
+
example = classifier_data_lib.InputExample(guid = None,
|
55 |
+
text_a = text.numpy(),
|
56 |
+
text_b = None,
|
57 |
+
label = label.numpy())
|
58 |
+
feature = classifier_data_lib.convert_single_example(0 , example , label_list , max_seq_length , tokenizer)
|
59 |
+
|
60 |
+
return (feature.input_ids , feature.input_mask , feature.segment_ids , feature.label_id)
|
61 |
+
|
62 |
+
|
63 |
+
def to_feature_map(text, label):
|
64 |
+
input_ids , input_mask , segment_ids , label_id = tf.py_function(to_feature , inp = [text , label],
|
65 |
+
Tout = [tf.int32 , tf.int32 , tf.int32 , tf.int32])
|
66 |
+
|
67 |
+
input_ids.set_shape([max_seq_length])
|
68 |
+
input_mask.set_shape([max_seq_length])
|
69 |
+
segment_ids.set_shape([max_seq_length])
|
70 |
+
label_id.set_shape([])
|
71 |
+
|
72 |
+
x = {
|
73 |
+
"input_word_ids": input_ids,
|
74 |
+
"input_mask": input_mask,
|
75 |
+
"input_type_ids": segment_ids
|
76 |
+
}
|
77 |
+
|
78 |
+
return(x,label_id)
|
79 |
+
|
80 |
+
|
81 |
+
def find_max_10_with_position(arr, arr_size):
|
82 |
+
max_values_with_position = [(-sys.maxsize, -1)] * 10
|
83 |
+
|
84 |
+
for i in range(arr_size):
|
85 |
+
for j in range(5):
|
86 |
+
value, position = max_values_with_position[j]
|
87 |
+
if arr[i] > value:
|
88 |
+
max_values_with_position[j+1:] = max_values_with_position[j:9]
|
89 |
+
max_values_with_position[j] = (arr[i], i)
|
90 |
+
break
|
91 |
+
|
92 |
+
return max_values_with_position
|
93 |
+
|
94 |
+
def count_special_character(string):
|
95 |
+
special_char= 0
|
96 |
+
for i in range(len(string)):
|
97 |
+
ch = string[i]
|
98 |
+
if (string[i].isalpha()):
|
99 |
+
continue
|
100 |
+
else:
|
101 |
+
special_char += 1
|
102 |
+
|
103 |
+
if len(string)==special_char:
|
104 |
+
return False
|
105 |
+
else:
|
106 |
+
return True
|
107 |
+
|
108 |
+
def format_prompt(message, history):
|
109 |
+
prompt = "<s>"
|
110 |
+
for user_prompt, bot_response in history:
|
111 |
+
prompt += f"[INST] {user_prompt} [/INST]"
|
112 |
+
prompt += f" {bot_response}</s> "
|
113 |
+
prompt += f"[INST] {message} [/INST]"
|
114 |
+
return prompt
|
115 |
+
|
116 |
+
|
117 |
+
additional_inputs=[
|
118 |
+
gr.Textbox(
|
119 |
+
label="System Prompt",
|
120 |
+
max_lines=1,
|
121 |
+
interactive=True,
|
122 |
+
),
|
123 |
+
gr.Slider(
|
124 |
+
label="Temperature",
|
125 |
+
value=0.5,
|
126 |
+
minimum=0.0,
|
127 |
+
maximum=1.0,
|
128 |
+
step=0.05,
|
129 |
+
interactive=True,
|
130 |
+
info="Higher values produce more diverse outputs",
|
131 |
+
),
|
132 |
+
gr.Slider(
|
133 |
+
label="Max new tokens",
|
134 |
+
value=1024,
|
135 |
+
minimum=0,
|
136 |
+
maximum=4096,
|
137 |
+
step=64,
|
138 |
+
interactive=True,
|
139 |
+
info="The maximum numbers of new tokens",
|
140 |
+
),
|
141 |
+
gr.Slider(
|
142 |
+
label="Top-p (nucleus sampling)",
|
143 |
+
value=0.90,
|
144 |
+
minimum=0.0,
|
145 |
+
maximum=1,
|
146 |
+
step=0.05,
|
147 |
+
interactive=True,
|
148 |
+
info="Higher values sample more low-probability tokens",
|
149 |
+
),
|
150 |
+
gr.Slider(
|
151 |
+
label="Repetition penalty",
|
152 |
+
value=1.2,
|
153 |
+
minimum=1.0,
|
154 |
+
maximum=2.0,
|
155 |
+
step=0.05,
|
156 |
+
interactive=True,
|
157 |
+
info="Penalize repeated tokens",
|
158 |
+
)
|
159 |
+
]
|
160 |
+
|
161 |
+
def predict_CTH(txt):
|
162 |
+
print('Desc: ',txt)
|
163 |
+
global output_str_msg
|
164 |
+
if (txt!='') and len(txt)>=3 and (count_special_character(txt)):
|
165 |
+
valid_data = tf.data.Dataset.from_tensor_slices(([txt] , [1])) # 1 refers to 'entertainment' and 2 refers to 'sport'
|
166 |
+
valid_data = (valid_data.map(to_feature_map).batch(1))
|
167 |
+
preds = model.predict(valid_data)
|
168 |
+
predicted_values = tf.nn.softmax(preds)
|
169 |
+
arr = predicted_values.numpy().tolist()[0]
|
170 |
+
n = len(arr)
|
171 |
+
|
172 |
+
pred_value_max=find_max_10_with_position(arr, n)
|
173 |
+
|
174 |
+
sum_all = 0
|
175 |
+
for i in range(10):
|
176 |
+
sum_all += pred_value_max[i][0]
|
177 |
+
|
178 |
+
|
179 |
+
val_1 = pred_value_max[0][0]/sum_all
|
180 |
+
val_2 = pred_value_max[1][0]/sum_all
|
181 |
+
val_3 = pred_value_max[2][0]/sum_all
|
182 |
+
val_4 = pred_value_max[3][0]/sum_all
|
183 |
+
val_5 = pred_value_max[4][0]/sum_all
|
184 |
+
val_6 = pred_value_max[5][0]/sum_all
|
185 |
+
val_7 = pred_value_max[6][0]/sum_all
|
186 |
+
val_8 = pred_value_max[7][0]/sum_all
|
187 |
+
val_9 = pred_value_max[8][0]/sum_all
|
188 |
+
val_10 = pred_value_max[9][0]/sum_all
|
189 |
+
|
190 |
+
if pred_value_max[0][0]<=0.000131:
|
191 |
+
Var_CTH=[]
|
192 |
+
Var_desc=[]
|
193 |
+
Var_duty=[]
|
194 |
+
pred_duty=''
|
195 |
+
pred_desc=''
|
196 |
+
pred_CTH=''
|
197 |
+
|
198 |
+
output_str_msg='Not a adequate description'
|
199 |
+
|
200 |
+
return{'Not a adequate description':float(1.0)}
|
201 |
+
else:
|
202 |
+
Var_CTH=[]
|
203 |
+
Var_desc=[]
|
204 |
+
Var_duty=[]
|
205 |
+
pred_duty=''
|
206 |
+
pred_desc=''
|
207 |
+
pred_CTH=''
|
208 |
+
|
209 |
+
for i in range(len(pred_value_max)):
|
210 |
+
#predicted_code=np.where(predicted_values.numpy()==i)[1][0]
|
211 |
+
predicted_code=pred_value_max[i][1]
|
212 |
+
pred_CTH=df[df['code'] == predicted_code]['CTH'].iloc[0]
|
213 |
+
|
214 |
+
try:
|
215 |
+
pred_duty=df_report_DUTY[df_report_DUTY['CTH']==str(pred_CTH)]['DUTY_RATE'].iloc[0]
|
216 |
+
pred_desc=df_report[df_report['CTH Code']==str(pred_CTH)]['Concat Description'].iloc[0]
|
217 |
+
except:
|
218 |
+
pred_desc=''
|
219 |
+
pred_duty=''
|
220 |
+
pass
|
221 |
+
|
222 |
+
Var_CTH.append(pred_CTH)
|
223 |
+
Var_desc.append(pred_desc)
|
224 |
+
Var_duty.append(pred_duty)
|
225 |
+
|
226 |
+
P1 ='CTH: '+str(Var_CTH[0])+' Duty Rate(%): '+ str(Var_duty[0])
|
227 |
+
P2 ='CTH: '+str(Var_CTH[1])+' Duty Rate(%): '+ str(Var_duty[1])
|
228 |
+
P3 ='CTH: '+str(Var_CTH[2])+' Duty Rate(%): '+ str(Var_duty[2])
|
229 |
+
P4 ='CTH: '+str(Var_CTH[3])+' Duty Rate(%): '+ str(Var_duty[3])
|
230 |
+
P5 ='CTH: '+str(Var_CTH[4])+' Duty Rate(%): '+ str(Var_duty[4])
|
231 |
+
P6 ='CTH: '+str(Var_CTH[5])+' Duty Rate(%): '+ str(Var_duty[5])
|
232 |
+
P7 ='CTH: '+str(Var_CTH[6])+' Duty Rate(%): '+ str(Var_duty[6])
|
233 |
+
P8 ='CTH: '+str(Var_CTH[7])+' Duty Rate(%): '+ str(Var_duty[7])
|
234 |
+
P9 ='CTH: '+str(Var_CTH[8])+' Duty Rate(%): '+ str(Var_duty[8])
|
235 |
+
P10 ='CTH: '+str(Var_CTH[9])+' Duty Rate(%): '+ str(Var_duty[9])
|
236 |
+
|
237 |
+
Q1='Desc: '+str(Var_desc[0])
|
238 |
+
Q2='Desc: '+str(Var_desc[1])
|
239 |
+
Q3='Desc: '+str(Var_desc[2])
|
240 |
+
Q4='Desc: '+str(Var_desc[3])
|
241 |
+
Q5='Desc: '+str(Var_desc[4])
|
242 |
+
Q6='Desc: '+str(Var_desc[5])
|
243 |
+
Q7='Desc: '+str(Var_desc[6])
|
244 |
+
Q8='Desc: '+str(Var_desc[7])
|
245 |
+
Q9='Desc: '+str(Var_desc[8])
|
246 |
+
Q10='Desc: '+str(Var_desc[9])
|
247 |
+
|
248 |
+
output_str_msg = (
|
249 |
+
f'1. {P1} {Q1} '
|
250 |
+
f'2. {P2} {Q2} '
|
251 |
+
f'3. {P3} {Q3} '
|
252 |
+
f'4. {P4} {Q4} '
|
253 |
+
f'5. {P5} {Q5} '
|
254 |
+
f'6. {P6} {Q6} '
|
255 |
+
f'7. {P7} {Q7} '
|
256 |
+
f'8. {P8} {Q8} '
|
257 |
+
f'9. {P9} {Q9} '
|
258 |
+
f'10. {P10} {Q10}')
|
259 |
+
|
260 |
+
print('output_str_msg',output_str_msg)
|
261 |
+
|
262 |
+
return {str(P1):float(val_1),str(Q1):float(val_1),
|
263 |
+
str(P2):float(val_2),str(Q2):float(val_2),
|
264 |
+
str(P3):float(val_3),str(Q3):float(val_3),
|
265 |
+
str(P4):float(val_4),str(Q4):float(val_4),
|
266 |
+
str(P5):float(val_5),str(Q5):float(val_5),
|
267 |
+
str(P6):float(val_6),str(Q6):float(val_6),
|
268 |
+
str(P7):float(val_7),str(Q7):float(val_7),
|
269 |
+
str(P8):float(val_8),str(Q8):float(val_8),
|
270 |
+
str(P9):float(val_9),str(Q9):float(val_9),
|
271 |
+
str(P10):float(val_10),str(Q10):float(val_10),}
|
272 |
+
else:
|
273 |
+
output_str_msg='Not a adequate description'
|
274 |
+
return{'Enter Correct Description':float(1.0)}
|
275 |
+
|
276 |
+
def llm_model_function(txt,history,chatbot=[], temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,):
|
277 |
+
system_prompt=[]
|
278 |
+
chatbot=[]
|
279 |
+
|
280 |
+
global output_str_msg
|
281 |
+
|
282 |
+
print('output_str_msg',output_str_msg)
|
283 |
+
|
284 |
+
if output_str_msg!='Not a adequate description':
|
285 |
+
|
286 |
+
prompt=f'First Explain What is the product- {txt}. Which is the most appropriate 8 Digit classification code out of the three given below classes. Explain the reason step by step. if none of the three classification is applicable more precisely due to lack of any additional information, tell you need additional information and what is the that additional information. {output_str_msg} ?'
|
287 |
+
|
288 |
+
temperature = float(temperature)
|
289 |
+
if temperature < 1e-2:
|
290 |
+
temperature = 1e-2
|
291 |
+
top_p = float(top_p)
|
292 |
+
|
293 |
+
generate_kwargs = dict(
|
294 |
+
temperature=temperature,
|
295 |
+
max_new_tokens=max_new_tokens,
|
296 |
+
top_p=top_p,
|
297 |
+
repetition_penalty=repetition_penalty,
|
298 |
+
do_sample=True,
|
299 |
+
seed=42,
|
300 |
+
)
|
301 |
+
|
302 |
+
formatted_prompt = format_prompt(f", {prompt}", history)
|
303 |
+
stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
|
304 |
+
output = ""
|
305 |
+
for response in stream:
|
306 |
+
output += response.token.text
|
307 |
+
|
308 |
+
chatbot.append((txt, output))
|
309 |
+
return "", chatbot
|
310 |
+
else:
|
311 |
+
# warning_msg = f"Unexpected response"
|
312 |
+
# raise gr.Error(warning_msg)
|
313 |
+
chatbot.append(('Not a adequate description', 'Not a adequate description'))
|
314 |
+
return "", chatbot
|
315 |
+
|
316 |
+
def product_explaination(txt,history,chatbot=[], temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,):
|
317 |
+
print('Input Descrption is:',txt)
|
318 |
+
chatbot=[]
|
319 |
+
prompt=f'What is the product- {txt}?'
|
320 |
+
#print('prompt',prompt)
|
321 |
+
temperature = float(temperature)
|
322 |
+
if temperature < 1e-2:
|
323 |
+
temperature = 1e-2
|
324 |
+
top_p = float(top_p)
|
325 |
+
|
326 |
+
generate_kwargs = dict(
|
327 |
+
temperature=temperature,
|
328 |
+
max_new_tokens=max_new_tokens,
|
329 |
+
top_p=top_p,
|
330 |
+
repetition_penalty=repetition_penalty,
|
331 |
+
do_sample=True,
|
332 |
+
seed=42,
|
333 |
+
)
|
334 |
+
|
335 |
+
formatted_prompt = format_prompt(f", {prompt}", history)
|
336 |
+
|
337 |
+
stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
|
338 |
+
output = ""
|
339 |
+
|
340 |
+
for response in stream:
|
341 |
+
output += response.token.text
|
342 |
+
|
343 |
+
chatbot.append((txt, output))
|
344 |
+
return "", chatbot
|
models/.github/ISSUE_TEMPLATE/00-official-bug-report-issue.md
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: "[Official Model] Bug Report"
|
3 |
+
about: Use this template for reporting a bug for the “official” directory
|
4 |
+
labels: type:bug,models:official
|
5 |
+
|
6 |
+
---
|
7 |
+
|
8 |
+
# Prerequisites
|
9 |
+
|
10 |
+
Please answer the following questions for yourself before submitting an issue.
|
11 |
+
|
12 |
+
- [ ] I am using the latest TensorFlow Model Garden release and TensorFlow 2.
|
13 |
+
- [ ] I am reporting the issue to the correct repository. (Model Garden official or research directory)
|
14 |
+
- [ ] I checked to make sure that this issue has not been filed already.
|
15 |
+
|
16 |
+
## 1. The entire URL of the file you are using
|
17 |
+
|
18 |
+
https://github.com/tensorflow/models/tree/master/official/...
|
19 |
+
|
20 |
+
## 2. Describe the bug
|
21 |
+
|
22 |
+
A clear and concise description of what the bug is.
|
23 |
+
|
24 |
+
## 3. Steps to reproduce
|
25 |
+
|
26 |
+
Steps to reproduce the behavior.
|
27 |
+
|
28 |
+
## 4. Expected behavior
|
29 |
+
|
30 |
+
A clear and concise description of what you expected to happen.
|
31 |
+
|
32 |
+
## 5. Additional context
|
33 |
+
|
34 |
+
Include any logs that would be helpful to diagnose the problem.
|
35 |
+
|
36 |
+
## 6. System information
|
37 |
+
|
38 |
+
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
39 |
+
- Mobile device name if the issue happens on a mobile device:
|
40 |
+
- TensorFlow installed from (source or binary):
|
41 |
+
- TensorFlow version (use command below):
|
42 |
+
- Python version:
|
43 |
+
- Bazel version (if compiling from source):
|
44 |
+
- GCC/Compiler version (if compiling from source):
|
45 |
+
- CUDA/cuDNN version:
|
46 |
+
- GPU model and memory:
|
47 |
+
|
48 |
+
<!--
|
49 |
+
Collect system information using our environment capture script.
|
50 |
+
https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
|
51 |
+
|
52 |
+
You can also obtain the TensorFlow version with:
|
53 |
+
|
54 |
+
1. TensorFlow 1.0
|
55 |
+
`python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
|
56 |
+
|
57 |
+
2. TensorFlow 2.0
|
58 |
+
`python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
59 |
+
-->
|
models/.github/ISSUE_TEMPLATE/10-official-documentation-issue.md
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: "[Official Model] Documentation Issue"
|
3 |
+
about: Use this template for reporting a documentation issue for the “official” directory
|
4 |
+
labels: type:docs,models:official
|
5 |
+
|
6 |
+
---
|
7 |
+
|
8 |
+
# Prerequisites
|
9 |
+
|
10 |
+
Please answer the following question for yourself before submitting an issue.
|
11 |
+
|
12 |
+
- [ ] I checked to make sure that this issue has not been filed already.
|
13 |
+
|
14 |
+
## 1. The entire URL of the documentation with the issue
|
15 |
+
|
16 |
+
https://github.com/tensorflow/models/tree/master/official/...
|
17 |
+
|
18 |
+
## 2. Describe the issue
|
19 |
+
|
20 |
+
A clear and concise description of what needs to be changed.
|
models/.github/ISSUE_TEMPLATE/20-official-feature-request-issue.md
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: "[Official Model] Feature request"
|
3 |
+
about: Use this template for raising a feature request for the “official” directory
|
4 |
+
labels: type:feature,models:official
|
5 |
+
|
6 |
+
---
|
7 |
+
|
8 |
+
# Prerequisites
|
9 |
+
|
10 |
+
Please answer the following question for yourself before submitting an issue.
|
11 |
+
|
12 |
+
- [ ] I checked to make sure that this feature has not been requested already.
|
13 |
+
|
14 |
+
## 1. The entire URL of the file you are using
|
15 |
+
|
16 |
+
https://github.com/tensorflow/models/tree/master/official/...
|
17 |
+
|
18 |
+
## 2. Describe the feature you request
|
19 |
+
|
20 |
+
A clear and concise description of what you want to happen.
|
21 |
+
|
22 |
+
## 3. Additional context
|
23 |
+
|
24 |
+
Add any other context about the feature request here.
|
25 |
+
|
26 |
+
## 4. Are you willing to contribute it? (Yes or No)
|
models/.github/ISSUE_TEMPLATE/30-research-bug-report-issue.md
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: "[Research Model] Bug Report"
|
3 |
+
about: Use this template for reporting a bug for the “research” directory
|
4 |
+
labels: type:bug,models:research
|
5 |
+
|
6 |
+
---
|
7 |
+
# Prerequisites
|
8 |
+
|
9 |
+
Please answer the following questions for yourself before submitting an issue.
|
10 |
+
|
11 |
+
- [ ] I am using the latest TensorFlow Model Garden release and TensorFlow 2.
|
12 |
+
- [ ] I am reporting the issue to the correct repository. (Model Garden official or research directory)
|
13 |
+
- [ ] I checked to make sure that this issue has not already been filed.
|
14 |
+
|
15 |
+
## 1. The entire URL of the file you are using
|
16 |
+
|
17 |
+
https://github.com/tensorflow/models/tree/master/research/...
|
18 |
+
|
19 |
+
## 2. Describe the bug
|
20 |
+
|
21 |
+
A clear and concise description of what the bug is.
|
22 |
+
|
23 |
+
## 3. Steps to reproduce
|
24 |
+
|
25 |
+
Steps to reproduce the behavior.
|
26 |
+
|
27 |
+
## 4. Expected behavior
|
28 |
+
|
29 |
+
A clear and concise description of what you expected to happen.
|
30 |
+
|
31 |
+
## 5. Additional context
|
32 |
+
|
33 |
+
Include any logs that would be helpful to diagnose the problem.
|
34 |
+
|
35 |
+
## 6. System information
|
36 |
+
|
37 |
+
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
38 |
+
- Mobile device name if the issue happens on a mobile device:
|
39 |
+
- TensorFlow installed from (source or binary):
|
40 |
+
- TensorFlow version (use command below):
|
41 |
+
- Python version:
|
42 |
+
- Bazel version (if compiling from source):
|
43 |
+
- GCC/Compiler version (if compiling from source):
|
44 |
+
- CUDA/cuDNN version:
|
45 |
+
- GPU model and memory:
|
46 |
+
|
47 |
+
<!--
|
48 |
+
Collect system information using our environment capture script.
|
49 |
+
https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
|
50 |
+
|
51 |
+
You can also obtain the TensorFlow version with:
|
52 |
+
|
53 |
+
1. TensorFlow 1.0
|
54 |
+
`python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
|
55 |
+
|
56 |
+
2. TensorFlow 2.0
|
57 |
+
`python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
58 |
+
-->
|
models/.github/ISSUE_TEMPLATE/40-research-documentation-issue.md
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: "[Research Model] Documentation Issue"
|
3 |
+
about: Use this template for reporting a documentation issue for the “research” directory
|
4 |
+
labels: type:docs,models:research
|
5 |
+
|
6 |
+
---
|
7 |
+
|
8 |
+
# Prerequisites
|
9 |
+
|
10 |
+
Please answer the following question for yourself before submitting an issue.
|
11 |
+
|
12 |
+
- [ ] I checked to make sure that this issue has not been filed already.
|
13 |
+
|
14 |
+
## 1. The entire URL of the documentation with the issue
|
15 |
+
|
16 |
+
https://github.com/tensorflow/models/tree/master/research/...
|
17 |
+
|
18 |
+
## 2. Describe the issue
|
19 |
+
|
20 |
+
A clear and concise description of what needs to be changed.
|
models/.github/ISSUE_TEMPLATE/50-research-feature-request-issue.md
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: "[Research Model] Feature Request"
|
3 |
+
about: Use this template for raising a feature request for the “research” directory
|
4 |
+
labels: type:feature,models:research
|
5 |
+
|
6 |
+
---
|
7 |
+
|
8 |
+
# Prerequisites
|
9 |
+
|
10 |
+
Please answer the following question for yourself before submitting an issue.
|
11 |
+
|
12 |
+
- [ ] I checked to make sure that this feature has not been requested already.
|
13 |
+
|
14 |
+
## 1. The entire URL of the file you are using
|
15 |
+
|
16 |
+
https://github.com/tensorflow/models/tree/master/research/...
|
17 |
+
|
18 |
+
## 2. Describe the feature you request
|
19 |
+
|
20 |
+
A clear and concise description of what you want to happen.
|
21 |
+
|
22 |
+
## 3. Additional context
|
23 |
+
|
24 |
+
Add any other context about the feature request here.
|
25 |
+
|
26 |
+
## 4. Are you willing to contribute it? (Yes or No)
|
models/.github/ISSUE_TEMPLATE/60-questions-help-issue.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: Questions and Help
|
3 |
+
about: Use this template for Questions and Help.
|
4 |
+
labels: type:support
|
5 |
+
|
6 |
+
---
|
7 |
+
<!--
|
8 |
+
As per our GitHub Policy (https://github.com/tensorflow/models/blob/master/ISSUES.md), we only address code bugs, documentation issues, and feature requests on GitHub.
|
9 |
+
|
10 |
+
We will automatically close questions and help related issues.
|
11 |
+
|
12 |
+
Please go to Stack Overflow (http://stackoverflow.com/questions/tagged/tensorflow-model-garden) for questions and help.
|
13 |
+
|
14 |
+
-->
|
models/.github/ISSUE_TEMPLATE/config.yml
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
blank_issues_enabled: false
|
models/.github/PULL_REQUEST_TEMPLATE.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Description
|
2 |
+
|
3 |
+
> :memo: Please include a summary of the change.
|
4 |
+
>
|
5 |
+
> * Please also include relevant motivation and context.
|
6 |
+
> * List any dependencies that are required for this change.
|
7 |
+
|
8 |
+
## Type of change
|
9 |
+
|
10 |
+
For a new feature or function, please create an issue first to discuss it
|
11 |
+
with us before submitting a pull request.
|
12 |
+
|
13 |
+
Note: Please delete options that are not relevant.
|
14 |
+
|
15 |
+
- [ ] Bug fix (non-breaking change which fixes an issue)
|
16 |
+
- [ ] Documentation update
|
17 |
+
- [ ] TensorFlow 2 migration
|
18 |
+
- [ ] New feature (non-breaking change which adds functionality)
|
19 |
+
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
20 |
+
- [ ] A new research paper code implementation
|
21 |
+
- [ ] Other (Specify)
|
22 |
+
|
23 |
+
## Tests
|
24 |
+
|
25 |
+
> :memo: Please describe the tests that you ran to verify your changes.
|
26 |
+
>
|
27 |
+
> * Provide instructions so we can reproduce.
|
28 |
+
> * Please also list any relevant details for your test configuration.
|
29 |
+
|
30 |
+
**Test Configuration**:
|
31 |
+
|
32 |
+
## Checklist
|
33 |
+
|
34 |
+
- [ ] I have signed the [Contributor License Agreement](https://github.com/tensorflow/models/wiki/Contributor-License-Agreements).
|
35 |
+
- [ ] I have read [guidelines for pull request](https://github.com/tensorflow/models/wiki/Submitting-a-pull-request).
|
36 |
+
- [ ] My code follows the [coding guidelines](https://github.com/tensorflow/models/wiki/Coding-guidelines).
|
37 |
+
- [ ] I have performed a self [code review](https://github.com/tensorflow/models/wiki/Code-review) of my own code.
|
38 |
+
- [ ] I have commented my code, particularly in hard-to-understand areas.
|
39 |
+
- [ ] I have made corresponding changes to the documentation.
|
40 |
+
- [ ] My changes generate no new warnings.
|
41 |
+
- [ ] I have added tests that prove my fix is effective or that my feature works.
|
models/.github/README_TEMPLATE.md
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
> :memo: A README.md template for releasing a paper code implementation to a GitHub repository.
|
2 |
+
>
|
3 |
+
> * Template version: 1.0.2020.170
|
4 |
+
> * Please modify sections depending on needs.
|
5 |
+
|
6 |
+
# Model name, Paper title, or Project Name
|
7 |
+
|
8 |
+
> :memo: Add a badge for the ArXiv identifier of your paper (arXiv:YYMM.NNNNN)
|
9 |
+
|
10 |
+
[![Paper](http://img.shields.io/badge/Paper-arXiv.YYMM.NNNNN-B3181B?logo=arXiv)](https://arxiv.org/abs/...)
|
11 |
+
|
12 |
+
This repository is the official or unofficial implementation of the following paper.
|
13 |
+
|
14 |
+
* Paper title: [Paper Title](https://arxiv.org/abs/YYMM.NNNNN)
|
15 |
+
|
16 |
+
## Description
|
17 |
+
|
18 |
+
> :memo: Provide description of the model.
|
19 |
+
>
|
20 |
+
> * Provide brief information of the algorithms used.
|
21 |
+
> * Provide links for demos, blog posts, etc.
|
22 |
+
|
23 |
+
## History
|
24 |
+
|
25 |
+
> :memo: Provide a changelog.
|
26 |
+
|
27 |
+
## Authors or Maintainers
|
28 |
+
|
29 |
+
> :memo: Provide maintainer information.
|
30 |
+
|
31 |
+
* Full name ([@GitHub username](https://github.com/username))
|
32 |
+
* Full name ([@GitHub username](https://github.com/username))
|
33 |
+
|
34 |
+
## Table of Contents
|
35 |
+
|
36 |
+
> :memo: Provide a table of contents to help readers navigate a lengthy README document.
|
37 |
+
|
38 |
+
## Requirements
|
39 |
+
|
40 |
+
[![TensorFlow 2.1](https://img.shields.io/badge/TensorFlow-2.1-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0)
|
41 |
+
[![Python 3.6](https://img.shields.io/badge/Python-3.6-3776AB)](https://www.python.org/downloads/release/python-360/)
|
42 |
+
|
43 |
+
> :memo: Provide details of the software required.
|
44 |
+
>
|
45 |
+
> * Add a `requirements.txt` file to the root directory for installing the necessary dependencies.
|
46 |
+
> * Describe how to install requirements using pip.
|
47 |
+
> * Alternatively, create INSTALL.md.
|
48 |
+
|
49 |
+
To install requirements:
|
50 |
+
|
51 |
+
```setup
|
52 |
+
pip install -r requirements.txt
|
53 |
+
```
|
54 |
+
|
55 |
+
## Results
|
56 |
+
|
57 |
+
> :memo: Provide a table with results. (e.g., accuracy, latency)
|
58 |
+
>
|
59 |
+
> * Provide links to the pre-trained models (checkpoint, SavedModel files).
|
60 |
+
> * Publish TensorFlow SavedModel files on TensorFlow Hub (tfhub.dev) if possible.
|
61 |
+
> * Add links to [TensorBoard.dev](https://tensorboard.dev/) for visualizing metrics.
|
62 |
+
>
|
63 |
+
> An example table for image classification results
|
64 |
+
>
|
65 |
+
> ### Image Classification
|
66 |
+
>
|
67 |
+
> | Model name | Download | Top 1 Accuracy | Top 5 Accuracy |
|
68 |
+
> |------------|----------|----------------|----------------|
|
69 |
+
> | Model name | [Checkpoint](https://drive.google.com/...), [SavedModel](https://tfhub.dev/...) | xx% | xx% |
|
70 |
+
|
71 |
+
## Dataset
|
72 |
+
|
73 |
+
> :memo: Provide information of the dataset used.
|
74 |
+
|
75 |
+
## Training
|
76 |
+
|
77 |
+
> :memo: Provide training information.
|
78 |
+
>
|
79 |
+
> * Provide details for preprocessing, hyperparameters, random seeds, and environment.
|
80 |
+
> * Provide a command line example for training.
|
81 |
+
|
82 |
+
Please run this command line for training.
|
83 |
+
|
84 |
+
```shell
|
85 |
+
python3 ...
|
86 |
+
```
|
87 |
+
|
88 |
+
## Evaluation
|
89 |
+
|
90 |
+
> :memo: Provide an evaluation script with details of how to reproduce results.
|
91 |
+
>
|
92 |
+
> * Describe data preprocessing / postprocessing steps.
|
93 |
+
> * Provide a command line example for evaluation.
|
94 |
+
|
95 |
+
Please run this command line for evaluation.
|
96 |
+
|
97 |
+
```shell
|
98 |
+
python3 ...
|
99 |
+
```
|
100 |
+
|
101 |
+
## References
|
102 |
+
|
103 |
+
> :memo: Provide links to references.
|
104 |
+
|
105 |
+
## License
|
106 |
+
|
107 |
+
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
|
108 |
+
|
109 |
+
> :memo: Place your license text in a file named LICENSE in the root of the repository.
|
110 |
+
>
|
111 |
+
> * Include information about your license.
|
112 |
+
> * Reference: [Adding a license to a repository](https://help.github.com/en/github/building-a-strong-community/adding-a-license-to-a-repository)
|
113 |
+
|
114 |
+
This project is licensed under the terms of the **Apache License 2.0**.
|
115 |
+
|
116 |
+
## Citation
|
117 |
+
|
118 |
+
> :memo: Make your repository citable.
|
119 |
+
>
|
120 |
+
> * Reference: [Making Your Code Citable](https://guides.github.com/activities/citable-code/)
|
121 |
+
|
122 |
+
If you want to cite this repository in your research paper, please use the following information.
|
models/.gitignore
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
env/
|
12 |
+
build/
|
13 |
+
develop-eggs/
|
14 |
+
dist/
|
15 |
+
downloads/
|
16 |
+
eggs/
|
17 |
+
.eggs/
|
18 |
+
lib/
|
19 |
+
lib64/
|
20 |
+
parts/
|
21 |
+
sdist/
|
22 |
+
var/
|
23 |
+
*.egg-info/
|
24 |
+
.installed.cfg
|
25 |
+
*.egg
|
26 |
+
|
27 |
+
# PyInstaller
|
28 |
+
# Usually these files are written by a python script from a template
|
29 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
30 |
+
*.manifest
|
31 |
+
*.spec
|
32 |
+
|
33 |
+
# Installer logs
|
34 |
+
pip-log.txt
|
35 |
+
pip-delete-this-directory.txt
|
36 |
+
|
37 |
+
# Unit test / coverage reports
|
38 |
+
htmlcov/
|
39 |
+
.tox/
|
40 |
+
.coverage
|
41 |
+
.coverage.*
|
42 |
+
.cache
|
43 |
+
nosetests.xml
|
44 |
+
coverage.xml
|
45 |
+
*,cover
|
46 |
+
.hypothesis/
|
47 |
+
|
48 |
+
# Translations
|
49 |
+
*.mo
|
50 |
+
*.pot
|
51 |
+
|
52 |
+
# Django stuff:
|
53 |
+
*.log
|
54 |
+
local_settings.py
|
55 |
+
|
56 |
+
# Flask stuff:
|
57 |
+
instance/
|
58 |
+
.webassets-cache
|
59 |
+
|
60 |
+
# Scrapy stuff:
|
61 |
+
.scrapy
|
62 |
+
|
63 |
+
# Sphinx documentation
|
64 |
+
docs/_build/
|
65 |
+
|
66 |
+
# PyBuilder
|
67 |
+
target/
|
68 |
+
|
69 |
+
# IPython Notebook
|
70 |
+
.ipynb_checkpoints
|
71 |
+
|
72 |
+
# pyenv
|
73 |
+
.python-version
|
74 |
+
|
75 |
+
# mypy
|
76 |
+
.mypy_cache
|
77 |
+
|
78 |
+
# celery beat schedule file
|
79 |
+
celerybeat-schedule
|
80 |
+
|
81 |
+
# dotenv
|
82 |
+
.env
|
83 |
+
|
84 |
+
# virtualenv
|
85 |
+
venv/
|
86 |
+
ENV/
|
87 |
+
|
88 |
+
# Spyder project settings
|
89 |
+
.spyderproject
|
90 |
+
|
91 |
+
# Rope project settings
|
92 |
+
.ropeproject
|
93 |
+
|
94 |
+
# PyCharm
|
95 |
+
.idea/
|
96 |
+
|
97 |
+
# For mac
|
98 |
+
.DS_Store
|
models/AUTHORS
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is the official list of authors for copyright purposes.
|
2 |
+
# This file is distinct from the CONTRIBUTORS files.
|
3 |
+
# See the latter for an explanation.
|
4 |
+
|
5 |
+
# Names should be added to this file as:
|
6 |
+
# Name or Organization <email address>
|
7 |
+
# The email address is not required for organizations.
|
8 |
+
|
9 |
+
Google Inc.
|
10 |
+
David Dao <daviddao@broad.mit.edu>
|
models/CODEOWNERS
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
* @tensorflow/tf-garden-team @tensorflow/tf-model-garden-team
|
2 |
+
/official/ @rachellj218 @saberkun @jaeyounkim
|
3 |
+
/official/nlp/ @saberkun @chenGitHuber @lehougoogle @rachellj218
|
4 |
+
/official/vision/ @pengchongjin @xianzhidu @yeqingli @arashwan @saberkun @rachellj218
|
5 |
+
/research/adv_imagenet_models/ @alexeykurakin
|
6 |
+
/research/adversarial_crypto/ @dave-andersen
|
7 |
+
/research/adversarial_logit_pairing/ @alexeykurakin
|
8 |
+
/research/adversarial_text/ @rsepassi @a-dai
|
9 |
+
/research/attention_ocr/ @xavigibert
|
10 |
+
/research/audioset/ @plakal @dpwe
|
11 |
+
/research/autoaugment/* @barretzoph
|
12 |
+
/research/autoencoders/ @snurkabill
|
13 |
+
/research/brain_coder/ @danabo
|
14 |
+
/research/cognitive_mapping_and_planning/ @s-gupta
|
15 |
+
/research/compression/ @nmjohn
|
16 |
+
/research/cvt_text/ @clarkkev @lmthang
|
17 |
+
/research/deep_contextual_bandits/ @rikel
|
18 |
+
/research/deep_speech/ @yhliang2018
|
19 |
+
/research/deeplab/ @aquariusjay @yknzhu @gpapan
|
20 |
+
/research/delf/ @andrefaraujo
|
21 |
+
/research/domain_adaptation/ @bousmalis @dmrd
|
22 |
+
/research/efficient-hrl/ @ofirnachum
|
23 |
+
/research/feelvos/ @pvoigtlaender @yuningchai @aquariusjay
|
24 |
+
/research/fivo/ @dieterichlawson
|
25 |
+
/research/global_objectives/ @mackeya-google
|
26 |
+
/research/im2txt/ @cshallue
|
27 |
+
/research/inception/ @shlens @vincentvanhoucke
|
28 |
+
/research/keypointnet/ @mnorouzi
|
29 |
+
/research/learned_optimizer/ @olganw @nirum
|
30 |
+
/research/learning_to_remember_rare_events/ @lukaszkaiser @ofirnachum
|
31 |
+
/research/learning_unsupervised_learning/ @lukemetz @nirum
|
32 |
+
/research/lexnet_nc/ @vered1986 @waterson
|
33 |
+
/research/lfads/ @jazcollins @sussillo
|
34 |
+
/research/lm_1b/ @oriolvinyals @panyx0718
|
35 |
+
/research/lm_commonsense/ @thtrieu
|
36 |
+
/research/lstm_object_detection/ @yinxiaoli @yongzhe2160
|
37 |
+
/research/marco/ @vincentvanhoucke
|
38 |
+
/research/maskgan/ @liamb315 @a-dai
|
39 |
+
/research/namignizer/ @knathanieltucker
|
40 |
+
/research/neural_gpu/ @lukaszkaiser
|
41 |
+
/research/neural_programmer/ @arvind2505
|
42 |
+
/research/next_frame_prediction/ @panyx0718
|
43 |
+
/research/object_detection/ @jch1 @tombstone @pkulzc
|
44 |
+
/research/pcl_rl/ @ofirnachum
|
45 |
+
/research/ptn/ @xcyan @arkanath @hellojas @honglaklee
|
46 |
+
/research/qa_kg/ @yuyuz
|
47 |
+
/research/real_nvp/ @laurent-dinh
|
48 |
+
/research/rebar/ @gjtucker
|
49 |
+
/research/sentiment_analysis/ @sculd
|
50 |
+
/research/seq2species/ @apbusia @depristo
|
51 |
+
/research/skip_thoughts/ @cshallue
|
52 |
+
/research/slim/ @sguada @marksandler2
|
53 |
+
/research/steve/ @buckman-google
|
54 |
+
/research/street/ @theraysmith
|
55 |
+
/research/struct2depth/ @aneliaangelova
|
56 |
+
/research/swivel/ @waterson
|
57 |
+
/research/tcn/ @coreylynch @sermanet
|
58 |
+
/research/textsum/ @panyx0718 @peterjliu
|
59 |
+
/research/transformer/ @daviddao
|
60 |
+
/research/vid2depth/ @rezama
|
61 |
+
/research/video_prediction/ @cbfinn
|
models/CONTRIBUTING.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# How to contribute
|
2 |
+
|
3 |
+
![Contributors](https://img.shields.io/github/contributors/tensorflow/models)
|
4 |
+
|
5 |
+
We encourage you to contribute to the TensorFlow Model Garden.
|
6 |
+
|
7 |
+
Please read our [guidelines](../../wiki/How-to-contribute) for details.
|
8 |
+
|
9 |
+
**NOTE**: Only [code owners](./CODEOWNERS) are allowed to merge a pull request.
|
10 |
+
Please contact the code owners of each model to merge your pull request.
|
models/ISSUES.md
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# If you open a GitHub issue, here is our policy.
|
2 |
+
|
3 |
+
* It must be a **bug**, a **feature request**, or a significant problem
|
4 |
+
with **documentation**.
|
5 |
+
* Please send a pull request instead for small documentation fixes.
|
6 |
+
* The required form must be filled out.
|
7 |
+
* The issue should be related to the repository it is created in.
|
8 |
+
|
9 |
+
General help and support should be sought on [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow-model-garden) or other non-GitHub channels.
|
10 |
+
|
11 |
+
[![](https://img.shields.io/stackexchange/stackoverflow/t/tensorflow-model-garden)](https://stackoverflow.com/questions/tagged/tensorflow-model-garden)
|
12 |
+
|
13 |
+
TensorFlow developers respond to issues.
|
14 |
+
We want to focus on work that benefits the whole community such as fixing bugs
|
15 |
+
and adding new features.
|
16 |
+
It helps us to address bugs and feature requests in a timely manner.
|
17 |
+
|
18 |
+
---
|
19 |
+
|
20 |
+
Please understand that research models in the [research directory](https://github.com/tensorflow/models/tree/master/research)
|
21 |
+
included in this repository are experimental and research-style code.
|
22 |
+
They are not officially supported by the TensorFlow team.
|
23 |
+
|
24 |
+
|
models/LICENSE
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2016 The TensorFlow Authors. All rights reserved.
|
2 |
+
|
3 |
+
Apache License
|
4 |
+
Version 2.0, January 2004
|
5 |
+
http://www.apache.org/licenses/
|
6 |
+
|
7 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
8 |
+
|
9 |
+
1. Definitions.
|
10 |
+
|
11 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
12 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
13 |
+
|
14 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
15 |
+
the copyright owner that is granting the License.
|
16 |
+
|
17 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
18 |
+
other entities that control, are controlled by, or are under common
|
19 |
+
control with that entity. For the purposes of this definition,
|
20 |
+
"control" means (i) the power, direct or indirect, to cause the
|
21 |
+
direction or management of such entity, whether by contract or
|
22 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
23 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
24 |
+
|
25 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
26 |
+
exercising permissions granted by this License.
|
27 |
+
|
28 |
+
"Source" form shall mean the preferred form for making modifications,
|
29 |
+
including but not limited to software source code, documentation
|
30 |
+
source, and configuration files.
|
31 |
+
|
32 |
+
"Object" form shall mean any form resulting from mechanical
|
33 |
+
transformation or translation of a Source form, including but
|
34 |
+
not limited to compiled object code, generated documentation,
|
35 |
+
and conversions to other media types.
|
36 |
+
|
37 |
+
"Work" shall mean the work of authorship, whether in Source or
|
38 |
+
Object form, made available under the License, as indicated by a
|
39 |
+
copyright notice that is included in or attached to the work
|
40 |
+
(an example is provided in the Appendix below).
|
41 |
+
|
42 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
43 |
+
form, that is based on (or derived from) the Work and for which the
|
44 |
+
editorial revisions, annotations, elaborations, or other modifications
|
45 |
+
represent, as a whole, an original work of authorship. For the purposes
|
46 |
+
of this License, Derivative Works shall not include works that remain
|
47 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
48 |
+
the Work and Derivative Works thereof.
|
49 |
+
|
50 |
+
"Contribution" shall mean any work of authorship, including
|
51 |
+
the original version of the Work and any modifications or additions
|
52 |
+
to that Work or Derivative Works thereof, that is intentionally
|
53 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
54 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
55 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
56 |
+
means any form of electronic, verbal, or written communication sent
|
57 |
+
to the Licensor or its representatives, including but not limited to
|
58 |
+
communication on electronic mailing lists, source code control systems,
|
59 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
60 |
+
Licensor for the purpose of discussing and improving the Work, but
|
61 |
+
excluding communication that is conspicuously marked or otherwise
|
62 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
63 |
+
|
64 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
65 |
+
on behalf of whom a Contribution has been received by Licensor and
|
66 |
+
subsequently incorporated within the Work.
|
67 |
+
|
68 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
69 |
+
this License, each Contributor hereby grants to You a perpetual,
|
70 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
71 |
+
copyright license to reproduce, prepare Derivative Works of,
|
72 |
+
publicly display, publicly perform, sublicense, and distribute the
|
73 |
+
Work and such Derivative Works in Source or Object form.
|
74 |
+
|
75 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
76 |
+
this License, each Contributor hereby grants to You a perpetual,
|
77 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
78 |
+
(except as stated in this section) patent license to make, have made,
|
79 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
80 |
+
where such license applies only to those patent claims licensable
|
81 |
+
by such Contributor that are necessarily infringed by their
|
82 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
83 |
+
with the Work to which such Contribution(s) was submitted. If You
|
84 |
+
institute patent litigation against any entity (including a
|
85 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
86 |
+
or a Contribution incorporated within the Work constitutes direct
|
87 |
+
or contributory patent infringement, then any patent licenses
|
88 |
+
granted to You under this License for that Work shall terminate
|
89 |
+
as of the date such litigation is filed.
|
90 |
+
|
91 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
92 |
+
Work or Derivative Works thereof in any medium, with or without
|
93 |
+
modifications, and in Source or Object form, provided that You
|
94 |
+
meet the following conditions:
|
95 |
+
|
96 |
+
(a) You must give any other recipients of the Work or
|
97 |
+
Derivative Works a copy of this License; and
|
98 |
+
|
99 |
+
(b) You must cause any modified files to carry prominent notices
|
100 |
+
stating that You changed the files; and
|
101 |
+
|
102 |
+
(c) You must retain, in the Source form of any Derivative Works
|
103 |
+
that You distribute, all copyright, patent, trademark, and
|
104 |
+
attribution notices from the Source form of the Work,
|
105 |
+
excluding those notices that do not pertain to any part of
|
106 |
+
the Derivative Works; and
|
107 |
+
|
108 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
109 |
+
distribution, then any Derivative Works that You distribute must
|
110 |
+
include a readable copy of the attribution notices contained
|
111 |
+
within such NOTICE file, excluding those notices that do not
|
112 |
+
pertain to any part of the Derivative Works, in at least one
|
113 |
+
of the following places: within a NOTICE text file distributed
|
114 |
+
as part of the Derivative Works; within the Source form or
|
115 |
+
documentation, if provided along with the Derivative Works; or,
|
116 |
+
within a display generated by the Derivative Works, if and
|
117 |
+
wherever such third-party notices normally appear. The contents
|
118 |
+
of the NOTICE file are for informational purposes only and
|
119 |
+
do not modify the License. You may add Your own attribution
|
120 |
+
notices within Derivative Works that You distribute, alongside
|
121 |
+
or as an addendum to the NOTICE text from the Work, provided
|
122 |
+
that such additional attribution notices cannot be construed
|
123 |
+
as modifying the License.
|
124 |
+
|
125 |
+
You may add Your own copyright statement to Your modifications and
|
126 |
+
may provide additional or different license terms and conditions
|
127 |
+
for use, reproduction, or distribution of Your modifications, or
|
128 |
+
for any such Derivative Works as a whole, provided Your use,
|
129 |
+
reproduction, and distribution of the Work otherwise complies with
|
130 |
+
the conditions stated in this License.
|
131 |
+
|
132 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
133 |
+
any Contribution intentionally submitted for inclusion in the Work
|
134 |
+
by You to the Licensor shall be under the terms and conditions of
|
135 |
+
this License, without any additional terms or conditions.
|
136 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
137 |
+
the terms of any separate license agreement you may have executed
|
138 |
+
with Licensor regarding such Contributions.
|
139 |
+
|
140 |
+
6. Trademarks. This License does not grant permission to use the trade
|
141 |
+
names, trademarks, service marks, or product names of the Licensor,
|
142 |
+
except as required for reasonable and customary use in describing the
|
143 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
144 |
+
|
145 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
146 |
+
agreed to in writing, Licensor provides the Work (and each
|
147 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
148 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
149 |
+
implied, including, without limitation, any warranties or conditions
|
150 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
151 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
152 |
+
appropriateness of using or redistributing the Work and assume any
|
153 |
+
risks associated with Your exercise of permissions under this License.
|
154 |
+
|
155 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
156 |
+
whether in tort (including negligence), contract, or otherwise,
|
157 |
+
unless required by applicable law (such as deliberate and grossly
|
158 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
159 |
+
liable to You for damages, including any direct, indirect, special,
|
160 |
+
incidental, or consequential damages of any character arising as a
|
161 |
+
result of this License or out of the use or inability to use the
|
162 |
+
Work (including but not limited to damages for loss of goodwill,
|
163 |
+
work stoppage, computer failure or malfunction, or any and all
|
164 |
+
other commercial damages or losses), even if such Contributor
|
165 |
+
has been advised of the possibility of such damages.
|
166 |
+
|
167 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
168 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
169 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
170 |
+
or other liability obligations and/or rights consistent with this
|
171 |
+
License. However, in accepting such obligations, You may act only
|
172 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
173 |
+
of any other Contributor, and only if You agree to indemnify,
|
174 |
+
defend, and hold each Contributor harmless for any liability
|
175 |
+
incurred by, or claims asserted against, such Contributor by reason
|
176 |
+
of your accepting any such warranty or additional liability.
|
177 |
+
|
178 |
+
END OF TERMS AND CONDITIONS
|
179 |
+
|
180 |
+
APPENDIX: How to apply the Apache License to your work.
|
181 |
+
|
182 |
+
To apply the Apache License to your work, attach the following
|
183 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
184 |
+
replaced with your own identifying information. (Don't include
|
185 |
+
the brackets!) The text should be enclosed in the appropriate
|
186 |
+
comment syntax for the file format. We also recommend that a
|
187 |
+
file or class name and description of purpose be included on the
|
188 |
+
same "printed page" as the copyright notice for easier
|
189 |
+
identification within third-party archives.
|
190 |
+
|
191 |
+
Copyright 2016, The Authors.
|
192 |
+
|
193 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
194 |
+
you may not use this file except in compliance with the License.
|
195 |
+
You may obtain a copy of the License at
|
196 |
+
|
197 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
198 |
+
|
199 |
+
Unless required by applicable law or agreed to in writing, software
|
200 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
201 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
202 |
+
See the License for the specific language governing permissions and
|
203 |
+
limitations under the License.
|
models/README.md
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
![Logo](https://storage.googleapis.com/model_garden_artifacts/TF_Model_Garden.png)
|
2 |
+
|
3 |
+
# Welcome to the Model Garden for TensorFlow
|
4 |
+
|
5 |
+
The TensorFlow Model Garden is a repository with a number of different implementations of state-of-the-art (SOTA) models and modeling solutions for TensorFlow users. We aim to demonstrate the best practices for modeling so that TensorFlow users
|
6 |
+
can take full advantage of TensorFlow for their research and product development.
|
7 |
+
|
8 |
+
| Directory | Description |
|
9 |
+
|-----------|-------------|
|
10 |
+
| [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read |
|
11 |
+
| [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers |
|
12 |
+
| [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |
|
13 |
+
|
14 |
+
## [Announcements](https://github.com/tensorflow/models/wiki/Announcements)
|
15 |
+
|
16 |
+
| Date | News |
|
17 |
+
|------|------|
|
18 |
+
| June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released
|
19 |
+
| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released
|
20 |
+
| May 19, 2020 | [MobileDets: Searching for Object Detection Architectures for Mobile Accelerators](https://github.com/tensorflow/models/tree/master/research/object_detection#may-19th-2020) released
|
21 |
+
| May 7, 2020 | [MnasFPN with MobileNet-V2 backbone](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#mobile-models) released for object detection
|
22 |
+
| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1
|
23 |
+
| March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) |
|
24 |
+
|
25 |
+
## [Milestones](https://github.com/tensorflow/models/milestones)
|
26 |
+
|
27 |
+
| Date | Milestone |
|
28 |
+
|------|-----------|
|
29 |
+
| July 7, 2020 | [![GitHub milestone](https://img.shields.io/github/milestones/progress/tensorflow/models/1)](https://github.com/tensorflow/models/milestone/1) |
|
30 |
+
|
31 |
+
## Contributions
|
32 |
+
|
33 |
+
[![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)
|
34 |
+
|
35 |
+
If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
|
36 |
+
|
37 |
+
## License
|
38 |
+
|
39 |
+
[Apache License 2.0](LICENSE)
|
models/official/LICENSE
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2015 The TensorFlow Authors. All rights reserved.
|
2 |
+
|
3 |
+
Apache License
|
4 |
+
Version 2.0, January 2004
|
5 |
+
http://www.apache.org/licenses/
|
6 |
+
|
7 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
8 |
+
|
9 |
+
1. Definitions.
|
10 |
+
|
11 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
12 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
13 |
+
|
14 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
15 |
+
the copyright owner that is granting the License.
|
16 |
+
|
17 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
18 |
+
other entities that control, are controlled by, or are under common
|
19 |
+
control with that entity. For the purposes of this definition,
|
20 |
+
"control" means (i) the power, direct or indirect, to cause the
|
21 |
+
direction or management of such entity, whether by contract or
|
22 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
23 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
24 |
+
|
25 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
26 |
+
exercising permissions granted by this License.
|
27 |
+
|
28 |
+
"Source" form shall mean the preferred form for making modifications,
|
29 |
+
including but not limited to software source code, documentation
|
30 |
+
source, and configuration files.
|
31 |
+
|
32 |
+
"Object" form shall mean any form resulting from mechanical
|
33 |
+
transformation or translation of a Source form, including but
|
34 |
+
not limited to compiled object code, generated documentation,
|
35 |
+
and conversions to other media types.
|
36 |
+
|
37 |
+
"Work" shall mean the work of authorship, whether in Source or
|
38 |
+
Object form, made available under the License, as indicated by a
|
39 |
+
copyright notice that is included in or attached to the work
|
40 |
+
(an example is provided in the Appendix below).
|
41 |
+
|
42 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
43 |
+
form, that is based on (or derived from) the Work and for which the
|
44 |
+
editorial revisions, annotations, elaborations, or other modifications
|
45 |
+
represent, as a whole, an original work of authorship. For the purposes
|
46 |
+
of this License, Derivative Works shall not include works that remain
|
47 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
48 |
+
the Work and Derivative Works thereof.
|
49 |
+
|
50 |
+
"Contribution" shall mean any work of authorship, including
|
51 |
+
the original version of the Work and any modifications or additions
|
52 |
+
to that Work or Derivative Works thereof, that is intentionally
|
53 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
54 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
55 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
56 |
+
means any form of electronic, verbal, or written communication sent
|
57 |
+
to the Licensor or its representatives, including but not limited to
|
58 |
+
communication on electronic mailing lists, source code control systems,
|
59 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
60 |
+
Licensor for the purpose of discussing and improving the Work, but
|
61 |
+
excluding communication that is conspicuously marked or otherwise
|
62 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
63 |
+
|
64 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
65 |
+
on behalf of whom a Contribution has been received by Licensor and
|
66 |
+
subsequently incorporated within the Work.
|
67 |
+
|
68 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
69 |
+
this License, each Contributor hereby grants to You a perpetual,
|
70 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
71 |
+
copyright license to reproduce, prepare Derivative Works of,
|
72 |
+
publicly display, publicly perform, sublicense, and distribute the
|
73 |
+
Work and such Derivative Works in Source or Object form.
|
74 |
+
|
75 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
76 |
+
this License, each Contributor hereby grants to You a perpetual,
|
77 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
78 |
+
(except as stated in this section) patent license to make, have made,
|
79 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
80 |
+
where such license applies only to those patent claims licensable
|
81 |
+
by such Contributor that are necessarily infringed by their
|
82 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
83 |
+
with the Work to which such Contribution(s) was submitted. If You
|
84 |
+
institute patent litigation against any entity (including a
|
85 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
86 |
+
or a Contribution incorporated within the Work constitutes direct
|
87 |
+
or contributory patent infringement, then any patent licenses
|
88 |
+
granted to You under this License for that Work shall terminate
|
89 |
+
as of the date such litigation is filed.
|
90 |
+
|
91 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
92 |
+
Work or Derivative Works thereof in any medium, with or without
|
93 |
+
modifications, and in Source or Object form, provided that You
|
94 |
+
meet the following conditions:
|
95 |
+
|
96 |
+
(a) You must give any other recipients of the Work or
|
97 |
+
Derivative Works a copy of this License; and
|
98 |
+
|
99 |
+
(b) You must cause any modified files to carry prominent notices
|
100 |
+
stating that You changed the files; and
|
101 |
+
|
102 |
+
(c) You must retain, in the Source form of any Derivative Works
|
103 |
+
that You distribute, all copyright, patent, trademark, and
|
104 |
+
attribution notices from the Source form of the Work,
|
105 |
+
excluding those notices that do not pertain to any part of
|
106 |
+
the Derivative Works; and
|
107 |
+
|
108 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
109 |
+
distribution, then any Derivative Works that You distribute must
|
110 |
+
include a readable copy of the attribution notices contained
|
111 |
+
within such NOTICE file, excluding those notices that do not
|
112 |
+
pertain to any part of the Derivative Works, in at least one
|
113 |
+
of the following places: within a NOTICE text file distributed
|
114 |
+
as part of the Derivative Works; within the Source form or
|
115 |
+
documentation, if provided along with the Derivative Works; or,
|
116 |
+
within a display generated by the Derivative Works, if and
|
117 |
+
wherever such third-party notices normally appear. The contents
|
118 |
+
of the NOTICE file are for informational purposes only and
|
119 |
+
do not modify the License. You may add Your own attribution
|
120 |
+
notices within Derivative Works that You distribute, alongside
|
121 |
+
or as an addendum to the NOTICE text from the Work, provided
|
122 |
+
that such additional attribution notices cannot be construed
|
123 |
+
as modifying the License.
|
124 |
+
|
125 |
+
You may add Your own copyright statement to Your modifications and
|
126 |
+
may provide additional or different license terms and conditions
|
127 |
+
for use, reproduction, or distribution of Your modifications, or
|
128 |
+
for any such Derivative Works as a whole, provided Your use,
|
129 |
+
reproduction, and distribution of the Work otherwise complies with
|
130 |
+
the conditions stated in this License.
|
131 |
+
|
132 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
133 |
+
any Contribution intentionally submitted for inclusion in the Work
|
134 |
+
by You to the Licensor shall be under the terms and conditions of
|
135 |
+
this License, without any additional terms or conditions.
|
136 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
137 |
+
the terms of any separate license agreement you may have executed
|
138 |
+
with Licensor regarding such Contributions.
|
139 |
+
|
140 |
+
6. Trademarks. This License does not grant permission to use the trade
|
141 |
+
names, trademarks, service marks, or product names of the Licensor,
|
142 |
+
except as required for reasonable and customary use in describing the
|
143 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
144 |
+
|
145 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
146 |
+
agreed to in writing, Licensor provides the Work (and each
|
147 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
148 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
149 |
+
implied, including, without limitation, any warranties or conditions
|
150 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
151 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
152 |
+
appropriateness of using or redistributing the Work and assume any
|
153 |
+
risks associated with Your exercise of permissions under this License.
|
154 |
+
|
155 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
156 |
+
whether in tort (including negligence), contract, or otherwise,
|
157 |
+
unless required by applicable law (such as deliberate and grossly
|
158 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
159 |
+
liable to You for damages, including any direct, indirect, special,
|
160 |
+
incidental, or consequential damages of any character arising as a
|
161 |
+
result of this License or out of the use or inability to use the
|
162 |
+
Work (including but not limited to damages for loss of goodwill,
|
163 |
+
work stoppage, computer failure or malfunction, or any and all
|
164 |
+
other commercial damages or losses), even if such Contributor
|
165 |
+
has been advised of the possibility of such damages.
|
166 |
+
|
167 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
168 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
169 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
170 |
+
or other liability obligations and/or rights consistent with this
|
171 |
+
License. However, in accepting such obligations, You may act only
|
172 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
173 |
+
of any other Contributor, and only if You agree to indemnify,
|
174 |
+
defend, and hold each Contributor harmless for any liability
|
175 |
+
incurred by, or claims asserted against, such Contributor by reason
|
176 |
+
of your accepting any such warranty or additional liability.
|
177 |
+
|
178 |
+
END OF TERMS AND CONDITIONS
|
179 |
+
|
180 |
+
APPENDIX: How to apply the Apache License to your work.
|
181 |
+
|
182 |
+
To apply the Apache License to your work, attach the following
|
183 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
184 |
+
replaced with your own identifying information. (Don't include
|
185 |
+
the brackets!) The text should be enclosed in the appropriate
|
186 |
+
comment syntax for the file format. We also recommend that a
|
187 |
+
file or class name and description of purpose be included on the
|
188 |
+
same "printed page" as the copyright notice for easier
|
189 |
+
identification within third-party archives.
|
190 |
+
|
191 |
+
Copyright 2015, The TensorFlow Authors.
|
192 |
+
|
193 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
194 |
+
you may not use this file except in compliance with the License.
|
195 |
+
You may obtain a copy of the License at
|
196 |
+
|
197 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
198 |
+
|
199 |
+
Unless required by applicable law or agreed to in writing, software
|
200 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
201 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
202 |
+
See the License for the specific language governing permissions and
|
203 |
+
limitations under the License.
|
models/official/README-TPU.md
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Offically Supported TensorFlow 2.1+ Models on Cloud TPU
|
2 |
+
|
3 |
+
## Natural Language Processing
|
4 |
+
|
5 |
+
* [bert](nlp/bert): A powerful pre-trained language representation model:
|
6 |
+
BERT, which stands for Bidirectional Encoder Representations from
|
7 |
+
Transformers.
|
8 |
+
[BERT FineTuning with Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/bert-2.x) provides step by step instructions on Cloud TPU training. You can look [Bert MNLI Tensorboard.dev metrics](https://tensorboard.dev/experiment/LijZ1IrERxKALQfr76gndA) for MNLI fine tuning task.
|
9 |
+
* [transformer](nlp/transformer): A transformer model to translate the WMT
|
10 |
+
English to German dataset.
|
11 |
+
[Training transformer on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/transformer-2.x) for step by step instructions on Cloud TPU training.
|
12 |
+
|
13 |
+
## Computer Vision
|
14 |
+
|
15 |
+
* [efficientnet](vision/image_classification): A family of convolutional
|
16 |
+
neural networks that scale by balancing network depth, width, and
|
17 |
+
resolution and can be used to classify ImageNet's dataset of 1000 classes.
|
18 |
+
See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/KnaWjrq5TXGfv0NW5m7rpg/#scalars).
|
19 |
+
* [mnist](vision/image_classification): A basic model to classify digits
|
20 |
+
from the MNIST dataset. See [Running MNIST on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/mnist-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/mIah5lppTASvrHqWrdr6NA).
|
21 |
+
* [mask-rcnn](vision/detection): An object detection and instance segmentation model. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/LH7k0fMsRwqUAcE09o9kPA).
|
22 |
+
* [resnet](vision/image_classification): A deep residual network that can
|
23 |
+
be used to classify ImageNet's dataset of 1000 classes.
|
24 |
+
See [Training ResNet on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/resnet-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/CxlDK8YMRrSpYEGtBRpOhg).
|
25 |
+
* [retinanet](vision/detection): A fast and powerful object detector. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/b8NRnWU3TqG6Rw0UxueU6Q).
|
models/official/README.md
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
![Logo](https://storage.googleapis.com/model_garden_artifacts/TF_Model_Garden.png)
|
2 |
+
|
3 |
+
# TensorFlow Official Models
|
4 |
+
|
5 |
+
The TensorFlow official models are a collection of models
|
6 |
+
that use TensorFlow’s high-level APIs.
|
7 |
+
They are intended to be well-maintained, tested, and kept up to date
|
8 |
+
with the latest TensorFlow API.
|
9 |
+
|
10 |
+
They should also be reasonably optimized for fast performance while still
|
11 |
+
being easy to read.
|
12 |
+
These models are used as end-to-end tests, ensuring that the models run
|
13 |
+
with the same or improved speed and performance with each new TensorFlow build.
|
14 |
+
|
15 |
+
## More models to come!
|
16 |
+
|
17 |
+
The team is actively developing new models.
|
18 |
+
In the near future, we will add:
|
19 |
+
|
20 |
+
* State-of-the-art language understanding models:
|
21 |
+
More members in Transformer family
|
22 |
+
* Start-of-the-art image classification models:
|
23 |
+
EfficientNet, MnasNet, and variants
|
24 |
+
* A set of excellent objection detection models.
|
25 |
+
|
26 |
+
## Table of Contents
|
27 |
+
|
28 |
+
- [Models and Implementations](#models-and-implementations)
|
29 |
+
* [Computer Vision](#computer-vision)
|
30 |
+
+ [Image Classification](#image-classification)
|
31 |
+
+ [Object Detection and Segmentation](#object-detection-and-segmentation)
|
32 |
+
* [Natural Language Processing](#natural-language-processing)
|
33 |
+
* [Recommendation](#recommendation)
|
34 |
+
- [How to get started with the official models](#how-to-get-started-with-the-official-models)
|
35 |
+
|
36 |
+
## Models and Implementations
|
37 |
+
|
38 |
+
### Computer Vision
|
39 |
+
|
40 |
+
#### Image Classification
|
41 |
+
|
42 |
+
| Model | Reference (Paper) |
|
43 |
+
|-------|-------------------|
|
44 |
+
| [MNIST](vision/image_classification) | A basic model to classify digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) |
|
45 |
+
| [ResNet](vision/image_classification) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
|
46 |
+
| [EfficientNet](vision/image_classification) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
|
47 |
+
|
48 |
+
#### Object Detection and Segmentation
|
49 |
+
|
50 |
+
| Model | Reference (Paper) |
|
51 |
+
|-------|-------------------|
|
52 |
+
| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
|
53 |
+
| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
|
54 |
+
| [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
|
55 |
+
|
56 |
+
### Natural Language Processing
|
57 |
+
|
58 |
+
| Model | Reference (Paper) |
|
59 |
+
|-------|-------------------|
|
60 |
+
| [ALBERT (A Lite BERT)](nlp/albert) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
|
61 |
+
| [BERT (Bidirectional Encoder Representations from Transformers)](nlp/bert) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
|
62 |
+
| [NHNet (News Headline generation model)](nlp/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
|
63 |
+
| [Transformer](nlp/transformer) | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |
|
64 |
+
| [XLNet](nlp/xlnet) | [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) |
|
65 |
+
|
66 |
+
### Recommendation
|
67 |
+
|
68 |
+
| Model | Reference (Paper) |
|
69 |
+
|-------|-------------------|
|
70 |
+
| [NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) |
|
71 |
+
|
72 |
+
## How to get started with the official models
|
73 |
+
|
74 |
+
* The models in the master branch are developed using TensorFlow 2,
|
75 |
+
and they target the TensorFlow [nightly binaries](https://github.com/tensorflow/tensorflow#installation)
|
76 |
+
built from the
|
77 |
+
[master branch of TensorFlow](https://github.com/tensorflow/tensorflow/tree/master).
|
78 |
+
* The stable versions targeting releases of TensorFlow are available
|
79 |
+
as tagged branches or [downloadable releases](https://github.com/tensorflow/models/releases).
|
80 |
+
* Model repository version numbers match the target TensorFlow release,
|
81 |
+
such that
|
82 |
+
[release v2.2.0](https://github.com/tensorflow/models/releases/tag/v2.2.0)
|
83 |
+
are compatible with
|
84 |
+
[TensorFlow v2.2.0](https://github.com/tensorflow/tensorflow/releases/tag/v2.2.0).
|
85 |
+
|
86 |
+
Please follow the below steps before running models in this repository.
|
87 |
+
|
88 |
+
### Requirements
|
89 |
+
|
90 |
+
* The latest TensorFlow Model Garden release and TensorFlow 2
|
91 |
+
* If you are on a version of TensorFlow earlier than 2.2, please
|
92 |
+
upgrade your TensorFlow to [the latest TensorFlow 2](https://www.tensorflow.org/install/).
|
93 |
+
|
94 |
+
```shell
|
95 |
+
pip3 install tf-nightly
|
96 |
+
```
|
97 |
+
|
98 |
+
### Installation
|
99 |
+
|
100 |
+
#### Method 1: Install the TensorFlow Model Garden pip package
|
101 |
+
|
102 |
+
**tf-models-nightly** is the nightly Model Garden package
|
103 |
+
created daily automatically. pip will install all models
|
104 |
+
and dependencies automatically.
|
105 |
+
|
106 |
+
```shell
|
107 |
+
pip install tf-models-nightly
|
108 |
+
```
|
109 |
+
|
110 |
+
Please check out our [example](colab/fine_tuning_bert.ipynb)
|
111 |
+
to learn how to use a PIP package.
|
112 |
+
|
113 |
+
#### Method 2: Clone the source
|
114 |
+
|
115 |
+
1. Clone the GitHub repository:
|
116 |
+
|
117 |
+
```shell
|
118 |
+
git clone https://github.com/tensorflow/models.git
|
119 |
+
```
|
120 |
+
|
121 |
+
2. Add the top-level ***/models*** folder to the Python path.
|
122 |
+
|
123 |
+
```shell
|
124 |
+
export PYTHONPATH=$PYTHONPATH:/path/to/models
|
125 |
+
```
|
126 |
+
|
127 |
+
If you are using a Colab notebook, please set the Python path with os.environ.
|
128 |
+
|
129 |
+
```python
|
130 |
+
import os
|
131 |
+
os.environ['PYTHONPATH'] += ":/path/to/models"
|
132 |
+
```
|
133 |
+
|
134 |
+
3. Install other dependencies
|
135 |
+
|
136 |
+
```shell
|
137 |
+
pip3 install --user -r official/requirements.txt
|
138 |
+
```
|
139 |
+
|
140 |
+
## Contributions
|
141 |
+
|
142 |
+
If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
|
models/official/__init__.py
ADDED
File without changes
|
models/official/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (148 Bytes). View file
|
|
models/official/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (146 Bytes). View file
|
|
models/official/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (139 Bytes). View file
|
|
models/official/benchmark/__init__.py
ADDED
File without changes
|
models/official/benchmark/benchmark_wrappers.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Lint as: python3
|
2 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ==============================================================================
|
16 |
+
"""Utils to annotate and trace benchmarks."""
|
17 |
+
|
18 |
+
from __future__ import absolute_import
|
19 |
+
from __future__ import division
|
20 |
+
from __future__ import print_function
|
21 |
+
|
22 |
+
from absl import flags
|
23 |
+
from absl import logging
|
24 |
+
from absl.testing import flagsaver
|
25 |
+
|
26 |
+
FLAGS = flags.FLAGS
|
27 |
+
|
28 |
+
flags.DEFINE_multi_string(
|
29 |
+
'benchmark_method_flags', None,
|
30 |
+
'Optional list of runtime flags of the form key=value. Specify '
|
31 |
+
'multiple times to specify different flags. These will override the FLAGS '
|
32 |
+
'object directly after hardcoded settings in individual benchmark methods '
|
33 |
+
'before they call _run_and_report benchmark. Example if we set '
|
34 |
+
'--benchmark_method_flags=train_steps=10 and a benchmark method hardcodes '
|
35 |
+
'FLAGS.train_steps=10000 and later calls _run_and_report_benchmark, '
|
36 |
+
'it\'ll only run for 10 steps. This is useful for '
|
37 |
+
'debugging/profiling workflows.')
|
38 |
+
|
39 |
+
|
40 |
+
def enable_runtime_flags(decorated_func):
|
41 |
+
"""Sets attributes from --benchmark_method_flags for method execution.
|
42 |
+
|
43 |
+
@enable_runtime_flags decorator temporarily adds flags passed in via
|
44 |
+
--benchmark_method_flags and runs the decorated function in that context.
|
45 |
+
|
46 |
+
A user can set --benchmark_method_flags=train_steps=5 to run the benchmark
|
47 |
+
method in the snippet below with FLAGS.train_steps=5 for debugging (without
|
48 |
+
modifying the benchmark code).
|
49 |
+
|
50 |
+
class ModelBenchmark():
|
51 |
+
|
52 |
+
@benchmark_wrappers.enable_runtime_flags
|
53 |
+
def _run_and_report_benchmark(self):
|
54 |
+
# run benchmark ...
|
55 |
+
# report benchmark results ...
|
56 |
+
|
57 |
+
def benchmark_method(self):
|
58 |
+
FLAGS.train_steps = 1000
|
59 |
+
...
|
60 |
+
self._run_and_report_benchmark()
|
61 |
+
|
62 |
+
Args:
|
63 |
+
decorated_func: The method that runs the benchmark after previous setup
|
64 |
+
execution that set some flags.
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
new_func: The same method which executes in a temporary context where flag
|
68 |
+
overrides from --benchmark_method_flags are active.
|
69 |
+
"""
|
70 |
+
|
71 |
+
def runner(*args, **kwargs):
|
72 |
+
"""Creates a temporary context to activate --benchmark_method_flags."""
|
73 |
+
if FLAGS.benchmark_method_flags:
|
74 |
+
saved_flag_values = flagsaver.save_flag_values()
|
75 |
+
for key_value in FLAGS.benchmark_method_flags:
|
76 |
+
key, value = key_value.split('=', 1)
|
77 |
+
try:
|
78 |
+
numeric_float = float(value)
|
79 |
+
numeric_int = int(numeric_float)
|
80 |
+
if abs(numeric_int) == abs(numeric_float):
|
81 |
+
flag_value = numeric_int
|
82 |
+
else:
|
83 |
+
flag_value = numeric_float
|
84 |
+
except ValueError:
|
85 |
+
flag_value = value
|
86 |
+
logging.info('Setting --%s=%s', key, flag_value)
|
87 |
+
setattr(FLAGS, key, flag_value)
|
88 |
+
else:
|
89 |
+
saved_flag_values = None
|
90 |
+
try:
|
91 |
+
result = decorated_func(*args, **kwargs)
|
92 |
+
return result
|
93 |
+
finally:
|
94 |
+
if saved_flag_values:
|
95 |
+
flagsaver.restore_flag_values(saved_flag_values)
|
96 |
+
|
97 |
+
return runner
|
models/official/benchmark/bert_benchmark.py
ADDED
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Executes BERT benchmarks and accuracy tests."""
|
16 |
+
|
17 |
+
from __future__ import absolute_import
|
18 |
+
from __future__ import division
|
19 |
+
from __future__ import print_function
|
20 |
+
|
21 |
+
import functools
|
22 |
+
import json
|
23 |
+
import math
|
24 |
+
import os
|
25 |
+
import time
|
26 |
+
|
27 |
+
# pylint: disable=g-bad-import-order
|
28 |
+
from absl import flags
|
29 |
+
from absl.testing import flagsaver
|
30 |
+
import tensorflow as tf
|
31 |
+
# pylint: enable=g-bad-import-order
|
32 |
+
|
33 |
+
from official.benchmark import bert_benchmark_utils as benchmark_utils
|
34 |
+
from official.benchmark import owner_utils
|
35 |
+
from official.nlp.bert import configs
|
36 |
+
from official.nlp.bert import run_classifier
|
37 |
+
from official.utils.misc import distribution_utils
|
38 |
+
from official.benchmark import benchmark_wrappers
|
39 |
+
|
40 |
+
# pylint: disable=line-too-long
|
41 |
+
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
|
42 |
+
CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_train.tf_record'
|
43 |
+
CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_eval.tf_record'
|
44 |
+
CLASSIFIER_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_meta_data'
|
45 |
+
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
|
46 |
+
# pylint: enable=line-too-long
|
47 |
+
|
48 |
+
TMP_DIR = os.getenv('TMPDIR')
|
49 |
+
FLAGS = flags.FLAGS
|
50 |
+
|
51 |
+
|
52 |
+
class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
|
53 |
+
"""Base class to hold methods common to test classes in the module."""
|
54 |
+
|
55 |
+
def __init__(self, output_dir=None, tpu=None):
|
56 |
+
super(BertClassifyBenchmarkBase, self).__init__(output_dir, tpu=tpu)
|
57 |
+
self.num_epochs = None
|
58 |
+
self.num_steps_per_epoch = None
|
59 |
+
FLAGS.steps_per_loop = 1
|
60 |
+
|
61 |
+
@flagsaver.flagsaver
|
62 |
+
def _run_bert_classifier(self, callbacks=None, use_ds=True):
|
63 |
+
"""Starts BERT classification task."""
|
64 |
+
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
|
65 |
+
input_meta_data = json.loads(reader.read().decode('utf-8'))
|
66 |
+
|
67 |
+
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
|
68 |
+
epochs = self.num_epochs if self.num_epochs else FLAGS.num_train_epochs
|
69 |
+
if self.num_steps_per_epoch:
|
70 |
+
steps_per_epoch = self.num_steps_per_epoch
|
71 |
+
else:
|
72 |
+
train_data_size = input_meta_data['train_data_size']
|
73 |
+
steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
|
74 |
+
warmup_steps = int(epochs * steps_per_epoch * 0.1)
|
75 |
+
eval_steps = int(
|
76 |
+
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
|
77 |
+
if self.tpu:
|
78 |
+
strategy = distribution_utils.get_distribution_strategy(
|
79 |
+
distribution_strategy='tpu', tpu_address=self.tpu)
|
80 |
+
else:
|
81 |
+
strategy = distribution_utils.get_distribution_strategy(
|
82 |
+
distribution_strategy='mirrored' if use_ds else 'off',
|
83 |
+
num_gpus=self.num_gpus)
|
84 |
+
|
85 |
+
max_seq_length = input_meta_data['max_seq_length']
|
86 |
+
train_input_fn = run_classifier.get_dataset_fn(
|
87 |
+
FLAGS.train_data_path,
|
88 |
+
max_seq_length,
|
89 |
+
FLAGS.train_batch_size,
|
90 |
+
is_training=True)
|
91 |
+
eval_input_fn = run_classifier.get_dataset_fn(
|
92 |
+
FLAGS.eval_data_path,
|
93 |
+
max_seq_length,
|
94 |
+
FLAGS.eval_batch_size,
|
95 |
+
is_training=False)
|
96 |
+
_, summary = run_classifier.run_bert_classifier(
|
97 |
+
strategy,
|
98 |
+
bert_config,
|
99 |
+
input_meta_data,
|
100 |
+
FLAGS.model_dir,
|
101 |
+
epochs,
|
102 |
+
steps_per_epoch,
|
103 |
+
FLAGS.steps_per_loop,
|
104 |
+
eval_steps,
|
105 |
+
warmup_steps,
|
106 |
+
FLAGS.learning_rate,
|
107 |
+
FLAGS.init_checkpoint,
|
108 |
+
train_input_fn,
|
109 |
+
eval_input_fn,
|
110 |
+
training_callbacks=False,
|
111 |
+
custom_callbacks=callbacks)
|
112 |
+
return summary
|
113 |
+
|
114 |
+
|
115 |
+
class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
|
116 |
+
"""Short benchmark performance tests for BERT model.
|
117 |
+
|
118 |
+
Tests BERT classification performance in different GPU, TPU configurations.
|
119 |
+
The naming convention of below test cases follow
|
120 |
+
`benchmark_(number of gpus)_gpu_(dataset type)` for GPUs and
|
121 |
+
`benchmark_(topology)_tpu_(dataset type)` for TPUs.
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
|
125 |
+
super(BertClassifyBenchmarkReal, self).__init__(
|
126 |
+
output_dir=output_dir, tpu=tpu)
|
127 |
+
|
128 |
+
self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
|
129 |
+
self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
|
130 |
+
self.bert_config_file = MODEL_CONFIG_FILE_PATH
|
131 |
+
self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
|
132 |
+
|
133 |
+
# Since we only care about performance metrics, we limit
|
134 |
+
# the number of training steps and epochs to prevent unnecessarily
|
135 |
+
# long tests.
|
136 |
+
self.num_steps_per_epoch = 100
|
137 |
+
self.num_epochs = 1
|
138 |
+
|
139 |
+
@benchmark_wrappers.enable_runtime_flags
|
140 |
+
def _run_and_report_benchmark(self,
|
141 |
+
training_summary_path,
|
142 |
+
min_accuracy=0,
|
143 |
+
max_accuracy=1,
|
144 |
+
use_ds=True):
|
145 |
+
"""Starts BERT performance benchmark test."""
|
146 |
+
start_time_sec = time.time()
|
147 |
+
summary = self._run_bert_classifier(
|
148 |
+
callbacks=[self.timer_callback], use_ds=use_ds)
|
149 |
+
wall_time_sec = time.time() - start_time_sec
|
150 |
+
|
151 |
+
# Since we do not load from any pretrained checkpoints, we ignore all
|
152 |
+
# accuracy metrics.
|
153 |
+
summary.pop('eval_metrics', None)
|
154 |
+
summary['start_time_sec'] = start_time_sec
|
155 |
+
|
156 |
+
super(BertClassifyBenchmarkReal, self)._report_benchmark(
|
157 |
+
stats=summary,
|
158 |
+
wall_time_sec=wall_time_sec,
|
159 |
+
min_accuracy=min_accuracy,
|
160 |
+
max_accuracy=max_accuracy)
|
161 |
+
|
162 |
+
def benchmark_1_gpu_mrpc(self):
|
163 |
+
"""Test BERT model performance with 1 GPU."""
|
164 |
+
|
165 |
+
self._setup()
|
166 |
+
self.num_gpus = 1
|
167 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc')
|
168 |
+
FLAGS.train_data_path = self.train_data_path
|
169 |
+
FLAGS.eval_data_path = self.eval_data_path
|
170 |
+
FLAGS.input_meta_data_path = self.input_meta_data_path
|
171 |
+
FLAGS.bert_config_file = self.bert_config_file
|
172 |
+
FLAGS.train_batch_size = 4
|
173 |
+
FLAGS.eval_batch_size = 4
|
174 |
+
|
175 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
176 |
+
'summaries/training_summary.txt')
|
177 |
+
self._run_and_report_benchmark(summary_path)
|
178 |
+
|
179 |
+
def benchmark_1_gpu_mrpc_xla(self):
|
180 |
+
"""Test BERT model performance with 1 GPU."""
|
181 |
+
|
182 |
+
self._setup()
|
183 |
+
self.num_gpus = 1
|
184 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc_xla')
|
185 |
+
FLAGS.train_data_path = self.train_data_path
|
186 |
+
FLAGS.eval_data_path = self.eval_data_path
|
187 |
+
FLAGS.input_meta_data_path = self.input_meta_data_path
|
188 |
+
FLAGS.bert_config_file = self.bert_config_file
|
189 |
+
FLAGS.train_batch_size = 4
|
190 |
+
FLAGS.eval_batch_size = 4
|
191 |
+
FLAGS.enable_xla = True
|
192 |
+
|
193 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
194 |
+
'summaries/training_summary.txt')
|
195 |
+
self._run_and_report_benchmark(summary_path)
|
196 |
+
|
197 |
+
def benchmark_1_gpu_mrpc_no_dist_strat(self):
|
198 |
+
"""Test BERT model performance with 1 GPU, no distribution strategy."""
|
199 |
+
|
200 |
+
self._setup()
|
201 |
+
self.num_gpus = 1
|
202 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc_no_dist_strat')
|
203 |
+
FLAGS.train_data_path = self.train_data_path
|
204 |
+
FLAGS.eval_data_path = self.eval_data_path
|
205 |
+
FLAGS.input_meta_data_path = self.input_meta_data_path
|
206 |
+
FLAGS.bert_config_file = self.bert_config_file
|
207 |
+
FLAGS.train_batch_size = 4
|
208 |
+
FLAGS.eval_batch_size = 4
|
209 |
+
|
210 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
211 |
+
'summaries/training_summary.txt')
|
212 |
+
self._run_and_report_benchmark(summary_path, use_ds=False)
|
213 |
+
|
214 |
+
@owner_utils.Owner('tf-model-garden')
|
215 |
+
def benchmark_8_gpu_mrpc(self):
|
216 |
+
"""Test BERT model performance with 8 GPUs."""
|
217 |
+
|
218 |
+
self._setup()
|
219 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
|
220 |
+
FLAGS.train_data_path = self.train_data_path
|
221 |
+
FLAGS.eval_data_path = self.eval_data_path
|
222 |
+
FLAGS.input_meta_data_path = self.input_meta_data_path
|
223 |
+
FLAGS.bert_config_file = self.bert_config_file
|
224 |
+
|
225 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
226 |
+
'summaries/training_summary.txt')
|
227 |
+
self._run_and_report_benchmark(summary_path)
|
228 |
+
|
229 |
+
def benchmark_1_gpu_amp_mrpc_no_dist_strat(self):
|
230 |
+
"""Performance for 1 GPU no DS with automatic mixed precision."""
|
231 |
+
self._setup()
|
232 |
+
self.num_gpus = 1
|
233 |
+
FLAGS.model_dir = self._get_model_dir(
|
234 |
+
'benchmark_1_gpu_amp_mrpc_no_dist_strat')
|
235 |
+
FLAGS.train_data_path = self.train_data_path
|
236 |
+
FLAGS.eval_data_path = self.eval_data_path
|
237 |
+
FLAGS.input_meta_data_path = self.input_meta_data_path
|
238 |
+
FLAGS.bert_config_file = self.bert_config_file
|
239 |
+
FLAGS.train_batch_size = 4
|
240 |
+
FLAGS.eval_batch_size = 4
|
241 |
+
FLAGS.dtype = 'fp16'
|
242 |
+
FLAGS.fp16_implementation = 'graph_rewrite'
|
243 |
+
|
244 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
245 |
+
'summaries/training_summary.txt')
|
246 |
+
self._run_and_report_benchmark(summary_path, use_ds=False)
|
247 |
+
|
248 |
+
def benchmark_8_gpu_amp_mrpc(self):
|
249 |
+
"""Test BERT model performance with 8 GPUs with automatic mixed precision."""
|
250 |
+
|
251 |
+
self._setup()
|
252 |
+
self.num_gpus = 8
|
253 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_mrpc')
|
254 |
+
FLAGS.train_data_path = self.train_data_path
|
255 |
+
FLAGS.eval_data_path = self.eval_data_path
|
256 |
+
FLAGS.input_meta_data_path = self.input_meta_data_path
|
257 |
+
FLAGS.bert_config_file = self.bert_config_file
|
258 |
+
FLAGS.train_batch_size = 32
|
259 |
+
FLAGS.eval_batch_size = 32
|
260 |
+
FLAGS.dtype = 'fp16'
|
261 |
+
FLAGS.fp16_implementation = 'graph_rewrite'
|
262 |
+
|
263 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
264 |
+
'summaries/training_summary.txt')
|
265 |
+
self._run_and_report_benchmark(summary_path, use_ds=False)
|
266 |
+
|
267 |
+
@owner_utils.Owner('tf-model-garden')
|
268 |
+
def benchmark_2x2_tpu_mrpc(self):
|
269 |
+
"""Test BERT model performance with 2x2 TPU."""
|
270 |
+
|
271 |
+
self._setup()
|
272 |
+
FLAGS.steps_per_loop = 50
|
273 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mrpc')
|
274 |
+
FLAGS.train_data_path = self.train_data_path
|
275 |
+
FLAGS.eval_data_path = self.eval_data_path
|
276 |
+
FLAGS.input_meta_data_path = self.input_meta_data_path
|
277 |
+
FLAGS.bert_config_file = self.bert_config_file
|
278 |
+
FLAGS.train_batch_size = 32
|
279 |
+
FLAGS.eval_batch_size = 32
|
280 |
+
|
281 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
282 |
+
'summaries/training_summary.txt')
|
283 |
+
self._run_and_report_benchmark(summary_path, use_ds=False)
|
284 |
+
|
285 |
+
|
286 |
+
class BertClassifyAccuracy(BertClassifyBenchmarkBase):
|
287 |
+
"""Short accuracy test for BERT model.
|
288 |
+
|
289 |
+
Tests BERT classification task model accuracy. The naming
|
290 |
+
convention of below test cases follow
|
291 |
+
`benchmark_(number of gpus)_gpu_(dataset type)` format.
|
292 |
+
"""
|
293 |
+
|
294 |
+
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
|
295 |
+
self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
|
296 |
+
self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
|
297 |
+
self.bert_config_file = MODEL_CONFIG_FILE_PATH
|
298 |
+
self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
|
299 |
+
self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
|
300 |
+
|
301 |
+
super(BertClassifyAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)
|
302 |
+
|
303 |
+
@benchmark_wrappers.enable_runtime_flags
|
304 |
+
def _run_and_report_benchmark(self,
|
305 |
+
training_summary_path,
|
306 |
+
min_accuracy=0.84,
|
307 |
+
max_accuracy=0.88):
|
308 |
+
"""Starts BERT accuracy benchmark test."""
|
309 |
+
|
310 |
+
start_time_sec = time.time()
|
311 |
+
summary = self._run_bert_classifier(callbacks=[self.timer_callback])
|
312 |
+
wall_time_sec = time.time() - start_time_sec
|
313 |
+
|
314 |
+
super(BertClassifyAccuracy, self)._report_benchmark(
|
315 |
+
stats=summary,
|
316 |
+
wall_time_sec=wall_time_sec,
|
317 |
+
min_accuracy=min_accuracy,
|
318 |
+
max_accuracy=max_accuracy)
|
319 |
+
|
320 |
+
def _setup(self):
|
321 |
+
super(BertClassifyAccuracy, self)._setup()
|
322 |
+
FLAGS.train_data_path = self.train_data_path
|
323 |
+
FLAGS.eval_data_path = self.eval_data_path
|
324 |
+
FLAGS.input_meta_data_path = self.input_meta_data_path
|
325 |
+
FLAGS.bert_config_file = self.bert_config_file
|
326 |
+
FLAGS.init_checkpoint = self.pretrained_checkpoint_path
|
327 |
+
|
328 |
+
@owner_utils.Owner('tf-model-garden')
|
329 |
+
def benchmark_8_gpu_mrpc(self):
|
330 |
+
"""Run BERT model accuracy test with 8 GPUs.
|
331 |
+
|
332 |
+
Due to comparatively small cardinality of MRPC dataset, training
|
333 |
+
accuracy metric has high variance between trainings. As so, we
|
334 |
+
set the wide range of allowed accuracy (84% to 88%).
|
335 |
+
"""
|
336 |
+
self._setup()
|
337 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
|
338 |
+
|
339 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
340 |
+
'summaries/training_summary.txt')
|
341 |
+
self._run_and_report_benchmark(summary_path)
|
342 |
+
|
343 |
+
def benchmark_8_gpu_mrpc_xla(self):
|
344 |
+
"""Run BERT model accuracy test with 8 GPUs with XLA."""
|
345 |
+
self._setup()
|
346 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc_xla')
|
347 |
+
FLAGS.enable_xla = True
|
348 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
349 |
+
'summaries/training_summary.txt')
|
350 |
+
self._run_and_report_benchmark(summary_path)
|
351 |
+
|
352 |
+
@owner_utils.Owner('tf-model-garden')
|
353 |
+
def benchmark_2x2_tpu_mrpc(self):
|
354 |
+
"""Run BERT model accuracy test on 2x2 TPU."""
|
355 |
+
self._setup()
|
356 |
+
FLAGS.steps_per_loop = 50
|
357 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mrpc')
|
358 |
+
|
359 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
360 |
+
'summaries/training_summary.txt')
|
361 |
+
self._run_and_report_benchmark(summary_path)
|
362 |
+
|
363 |
+
|
364 |
+
if __name__ == '__main__':
|
365 |
+
tf.test.main()
|
models/official/benchmark/bert_benchmark_utils.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Utility functions or classes shared between BERT benchmarks."""
|
16 |
+
|
17 |
+
from __future__ import absolute_import
|
18 |
+
from __future__ import division
|
19 |
+
from __future__ import print_function
|
20 |
+
|
21 |
+
import time
|
22 |
+
|
23 |
+
# pylint: disable=g-bad-import-order
|
24 |
+
import numpy as np
|
25 |
+
from absl import flags
|
26 |
+
import tensorflow as tf
|
27 |
+
# pylint: enable=g-bad-import-order
|
28 |
+
|
29 |
+
from official.utils.flags import core as flags_core
|
30 |
+
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
|
31 |
+
|
32 |
+
FLAGS = flags.FLAGS
|
33 |
+
|
34 |
+
|
35 |
+
class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
|
36 |
+
"""Callback that records time it takes to run each batch."""
|
37 |
+
|
38 |
+
def __init__(self, num_batches_to_skip=10):
|
39 |
+
super(BenchmarkTimerCallback, self).__init__()
|
40 |
+
self.batch_start_times = {}
|
41 |
+
self.batch_stop_times = {}
|
42 |
+
|
43 |
+
def on_batch_begin(self, batch, logs=None):
|
44 |
+
self.batch_start_times[batch] = time.time()
|
45 |
+
|
46 |
+
def on_batch_end(self, batch, logs=None):
|
47 |
+
# If there are multiple steps_per_loop, the end batch index will not be the
|
48 |
+
# same as the starting index. Use the last starting index instead.
|
49 |
+
if batch not in self.batch_start_times:
|
50 |
+
batch = max(self.batch_start_times.keys())
|
51 |
+
|
52 |
+
self.batch_stop_times[batch] = time.time()
|
53 |
+
|
54 |
+
def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
|
55 |
+
batch_durations = []
|
56 |
+
for batch in self.batch_start_times:
|
57 |
+
if batch in self.batch_stop_times and batch >= num_batches_to_skip:
|
58 |
+
batch_durations.append(self.batch_stop_times[batch] -
|
59 |
+
self.batch_start_times[batch])
|
60 |
+
return batch_size / np.mean(batch_durations)
|
61 |
+
|
62 |
+
def get_startup_time(self, program_start_time):
|
63 |
+
return self.batch_start_times[0] - program_start_time
|
64 |
+
|
65 |
+
|
66 |
+
class BertBenchmarkBase(PerfZeroBenchmark):
|
67 |
+
"""Base class to hold methods common to test classes."""
|
68 |
+
local_flags = None
|
69 |
+
|
70 |
+
def __init__(self, output_dir=None, tpu=None, **kwargs):
|
71 |
+
super(BertBenchmarkBase, self).__init__(
|
72 |
+
output_dir=output_dir, tpu=tpu, **kwargs)
|
73 |
+
self.num_gpus = 8
|
74 |
+
self.timer_callback = None
|
75 |
+
|
76 |
+
def _setup(self):
|
77 |
+
"""Sets up and resets flags before each test."""
|
78 |
+
super(BertBenchmarkBase, self)._setup()
|
79 |
+
self.timer_callback = BenchmarkTimerCallback()
|
80 |
+
|
81 |
+
def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
|
82 |
+
"""Report benchmark results by writing to local protobuf file.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
stats: dict returned from BERT models with known entries.
|
86 |
+
wall_time_sec: the during of the benchmark execution in seconds
|
87 |
+
min_accuracy: Minimum classification accuracy constraint to verify
|
88 |
+
correctness of the model.
|
89 |
+
max_accuracy: Maximum classification accuracy constraint to verify
|
90 |
+
correctness of the model.
|
91 |
+
"""
|
92 |
+
metrics = [{
|
93 |
+
'name': 'training_loss',
|
94 |
+
'value': stats['train_loss'],
|
95 |
+
}]
|
96 |
+
if self.timer_callback:
|
97 |
+
metrics.append({
|
98 |
+
'name':
|
99 |
+
'exp_per_second',
|
100 |
+
'value':
|
101 |
+
self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
|
102 |
+
FLAGS.steps_per_loop)
|
103 |
+
})
|
104 |
+
else:
|
105 |
+
metrics.append({
|
106 |
+
'name': 'exp_per_second',
|
107 |
+
'value': 0.0,
|
108 |
+
})
|
109 |
+
if self.timer_callback and 'start_time_sec' in stats:
|
110 |
+
metrics.append({
|
111 |
+
'name': 'startup_time',
|
112 |
+
'value': self.timer_callback.get_startup_time(stats['start_time_sec'])
|
113 |
+
})
|
114 |
+
|
115 |
+
if 'eval_metrics' in stats:
|
116 |
+
metrics.append({
|
117 |
+
'name': 'eval_accuracy',
|
118 |
+
'value': stats['eval_metrics'],
|
119 |
+
'min_value': min_accuracy,
|
120 |
+
'max_value': max_accuracy,
|
121 |
+
})
|
122 |
+
flags_str = flags_core.get_nondefault_flags_as_str()
|
123 |
+
self.report_benchmark(
|
124 |
+
iters=stats['total_training_steps'],
|
125 |
+
wall_time=wall_time_sec,
|
126 |
+
metrics=metrics,
|
127 |
+
extras={'flags': flags_str})
|
models/official/benchmark/bert_pretrain_benchmark.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Lint as: python3
|
2 |
+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ==============================================================================
|
16 |
+
"""Executes benchmark testing for bert pretraining."""
|
17 |
+
# pylint: disable=line-too-long
|
18 |
+
from __future__ import print_function
|
19 |
+
|
20 |
+
import json
|
21 |
+
import os
|
22 |
+
import time
|
23 |
+
from typing import Optional
|
24 |
+
|
25 |
+
from absl import flags
|
26 |
+
from absl import logging
|
27 |
+
import tensorflow as tf # pylint: disable=g-bad-import-order
|
28 |
+
|
29 |
+
from official.benchmark import benchmark_wrappers
|
30 |
+
from official.benchmark import bert_benchmark_utils
|
31 |
+
from official.benchmark import owner_utils
|
32 |
+
from official.nlp.bert import run_pretraining
|
33 |
+
from official.utils.flags import core as flags_core
|
34 |
+
from official.utils.misc import distribution_utils
|
35 |
+
|
36 |
+
# Pretrain masked lanauge modeling accuracy range:
|
37 |
+
MIN_MLM_ACCURACY = 0.635
|
38 |
+
MAX_MLM_ACCURACY = 0.645
|
39 |
+
|
40 |
+
# Pretrain next sentence prediction accuracy range:
|
41 |
+
MIN_NSP_ACCURACY = 0.94
|
42 |
+
MAX_NSP_ACCURACY = 0.96
|
43 |
+
|
44 |
+
BERT_PRETRAIN_FILES_SEQ128 = 'gs://mlcompass-data/bert/pretraining_data/seq_128/wikipedia.tfrecord*,gs://mlcompass-data/bert/pretraining_data/seq_128/books.tfrecord*'
|
45 |
+
BERT_BASE_CONFIG_FILE = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_config.json'
|
46 |
+
|
47 |
+
FLAGS = flags.FLAGS
|
48 |
+
|
49 |
+
|
50 |
+
class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
|
51 |
+
"""Benchmark accuracy tests for BERT Pretraining."""
|
52 |
+
|
53 |
+
def __init__(self,
|
54 |
+
output_dir: Optional[str] = None,
|
55 |
+
tpu: Optional[str] = None,
|
56 |
+
**kwargs):
|
57 |
+
"""Inits BertPretrainAccuracyBenchmark class.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
output_dir: Directory where to output e.g. log files
|
61 |
+
tpu: TPU name to use in a TPU benchmark.
|
62 |
+
**kwargs: Additional keyword arguments.
|
63 |
+
"""
|
64 |
+
super(BertPretrainAccuracyBenchmark, self).__init__(
|
65 |
+
output_dir=output_dir, tpu=tpu, **kwargs)
|
66 |
+
|
67 |
+
@benchmark_wrappers.enable_runtime_flags
|
68 |
+
def _run_and_report_benchmark(self, summary_path: str, report_accuracy: bool):
|
69 |
+
"""Runs and reports the benchmark given the provided configuration."""
|
70 |
+
distribution = distribution_utils.get_distribution_strategy(
|
71 |
+
distribution_strategy='tpu', tpu_address=self.tpu)
|
72 |
+
logging.info('Flags: %s', flags_core.get_nondefault_flags_as_str())
|
73 |
+
start_time_sec = time.time()
|
74 |
+
run_pretraining.run_bert_pretrain(
|
75 |
+
strategy=distribution, custom_callbacks=self.timer_callback)
|
76 |
+
wall_time_sec = time.time() - start_time_sec
|
77 |
+
|
78 |
+
with tf.io.gfile.GFile(summary_path, 'rb') as reader:
|
79 |
+
summary = json.loads(reader.read().decode('utf-8'))
|
80 |
+
self._report_benchmark(summary, start_time_sec, wall_time_sec,
|
81 |
+
report_accuracy)
|
82 |
+
|
83 |
+
def _report_benchmark(self, summary, start_time_sec, wall_time_sec,
|
84 |
+
report_accuracy):
|
85 |
+
metrics = [{
|
86 |
+
'name': 'train_loss',
|
87 |
+
'value': summary['train_loss'],
|
88 |
+
}, {
|
89 |
+
'name':
|
90 |
+
'exp_per_second',
|
91 |
+
'value':
|
92 |
+
self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
|
93 |
+
FLAGS.steps_per_loop)
|
94 |
+
}, {
|
95 |
+
'name': 'startup_time',
|
96 |
+
'value': self.timer_callback.get_startup_time(start_time_sec)
|
97 |
+
}]
|
98 |
+
if report_accuracy:
|
99 |
+
metrics.extend([{
|
100 |
+
'name': 'masked_lm_accuracy',
|
101 |
+
'value': summary['masked_lm_accuracy'],
|
102 |
+
'min_value': MIN_MLM_ACCURACY,
|
103 |
+
'max_value': MAX_MLM_ACCURACY,
|
104 |
+
}, {
|
105 |
+
'name': 'next_sentence_accuracy',
|
106 |
+
'value': summary['next_sentence_accuracy'],
|
107 |
+
'min_value': MIN_NSP_ACCURACY,
|
108 |
+
'max_value': MAX_NSP_ACCURACY,
|
109 |
+
}])
|
110 |
+
self.report_benchmark(
|
111 |
+
iters=summary['total_training_steps'],
|
112 |
+
wall_time=wall_time_sec,
|
113 |
+
metrics=metrics,
|
114 |
+
extras={'flags': flags_core.get_nondefault_flags_as_str()})
|
115 |
+
|
116 |
+
def _specify_common_flags(self):
|
117 |
+
FLAGS.bert_config_file = BERT_BASE_CONFIG_FILE
|
118 |
+
FLAGS.train_batch_size = 512
|
119 |
+
FLAGS.learning_rate = 1e-4
|
120 |
+
FLAGS.warmup_steps = 10000
|
121 |
+
FLAGS.steps_per_loop = 10000
|
122 |
+
FLAGS.distribution_strategy = 'tpu'
|
123 |
+
FLAGS.input_files = BERT_PRETRAIN_FILES_SEQ128
|
124 |
+
FLAGS.max_seq_length = 128
|
125 |
+
FLAGS.max_predictions_per_seq = 20
|
126 |
+
FLAGS.dtype = 'bf16'
|
127 |
+
|
128 |
+
@owner_utils.Owner('tf-model-garden')
|
129 |
+
def benchmark_accuracy_8x8_tpu_bf16_seq128_500k_steps(self):
|
130 |
+
"""Test bert pretraining with 8x8 TPU for 500k steps."""
|
131 |
+
# This is used for accuracy test.
|
132 |
+
self._setup()
|
133 |
+
self._specify_common_flags()
|
134 |
+
FLAGS.num_steps_per_epoch = 500000
|
135 |
+
FLAGS.num_train_epochs = 1
|
136 |
+
FLAGS.model_dir = self._get_model_dir(
|
137 |
+
'benchmark_accuracy_8x8_tpu_bf16_seq128_500k_steps')
|
138 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
139 |
+
'summaries/training_summary.txt')
|
140 |
+
# Set train_summary_interval to -1 to disable training summary, because
|
141 |
+
# writing summary to gcs may fail and summaries are not needed for this
|
142 |
+
# accuracy benchmark test.
|
143 |
+
FLAGS.train_summary_interval = -1
|
144 |
+
self._run_and_report_benchmark(summary_path=summary_path,
|
145 |
+
report_accuracy=True)
|
146 |
+
|
147 |
+
@owner_utils.Owner('tf-model-garden')
|
148 |
+
def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps(self):
|
149 |
+
"""Test bert pretraining with 4x4 TPU for 10000 steps."""
|
150 |
+
self._setup()
|
151 |
+
self._specify_common_flags()
|
152 |
+
FLAGS.num_steps_per_epoch = 5000
|
153 |
+
FLAGS.num_train_epochs = 2
|
154 |
+
FLAGS.model_dir = self._get_model_dir(
|
155 |
+
'benchmark_perf_4x4_tpu_bf16_seq128_10k_steps')
|
156 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
157 |
+
'summaries/training_summary.txt')
|
158 |
+
# Disable accuracy check.
|
159 |
+
self._run_and_report_benchmark(
|
160 |
+
summary_path=summary_path, report_accuracy=False)
|
161 |
+
|
162 |
+
@owner_utils.Owner('tf-model-garden')
|
163 |
+
def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self):
|
164 |
+
"""Test bert pretraining with 8x8 TPU for 10000 steps."""
|
165 |
+
self._setup()
|
166 |
+
self._specify_common_flags()
|
167 |
+
FLAGS.num_steps_per_epoch = 5000
|
168 |
+
FLAGS.num_train_epochs = 2
|
169 |
+
FLAGS.model_dir = self._get_model_dir(
|
170 |
+
'benchmark_perf_8x8_tpu_bf16_seq128_10k_steps')
|
171 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
172 |
+
'summaries/training_summary.txt')
|
173 |
+
# Disable accuracy check.
|
174 |
+
self._run_and_report_benchmark(summary_path=summary_path,
|
175 |
+
report_accuracy=False)
|
176 |
+
|
177 |
+
|
178 |
+
if __name__ == '__main__':
|
179 |
+
tf.test.main()
|
models/official/benchmark/bert_squad_benchmark.py
ADDED
@@ -0,0 +1,608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Executes BERT SQuAD benchmarks and accuracy tests."""
|
16 |
+
|
17 |
+
from __future__ import absolute_import
|
18 |
+
from __future__ import division
|
19 |
+
from __future__ import print_function
|
20 |
+
|
21 |
+
import json
|
22 |
+
import os
|
23 |
+
import time
|
24 |
+
|
25 |
+
# pylint: disable=g-bad-import-order
|
26 |
+
from absl import flags
|
27 |
+
from absl import logging
|
28 |
+
from absl.testing import flagsaver
|
29 |
+
import tensorflow as tf
|
30 |
+
# pylint: enable=g-bad-import-order
|
31 |
+
|
32 |
+
from official.benchmark import bert_benchmark_utils as benchmark_utils
|
33 |
+
from official.benchmark import owner_utils
|
34 |
+
from official.nlp.bert import run_squad
|
35 |
+
from official.utils.misc import distribution_utils
|
36 |
+
from official.utils.misc import keras_utils
|
37 |
+
from official.benchmark import benchmark_wrappers
|
38 |
+
|
39 |
+
|
40 |
+
# pylint: disable=line-too-long
|
41 |
+
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
|
42 |
+
SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record'
|
43 |
+
SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json'
|
44 |
+
SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt'
|
45 |
+
SQUAD_MEDIUM_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_medium_meta_data'
|
46 |
+
SQUAD_LONG_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_long_meta_data'
|
47 |
+
SQUAD_FULL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_full_meta_data'
|
48 |
+
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
|
49 |
+
# pylint: enable=line-too-long
|
50 |
+
|
51 |
+
TMP_DIR = os.getenv('TMPDIR')
|
52 |
+
FLAGS = flags.FLAGS
|
53 |
+
|
54 |
+
|
55 |
+
class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
|
56 |
+
"""Base class to hold methods common to test classes in the module."""
|
57 |
+
|
58 |
+
def __init__(self, output_dir=None, tpu=None):
|
59 |
+
super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir, tpu=tpu)
|
60 |
+
|
61 |
+
def _read_training_summary_from_file(self):
|
62 |
+
"""Reads the training summary from a file."""
|
63 |
+
summary_path = os.path.join(FLAGS.model_dir,
|
64 |
+
'summaries/training_summary.txt')
|
65 |
+
with tf.io.gfile.GFile(summary_path, 'rb') as reader:
|
66 |
+
return json.loads(reader.read().decode('utf-8'))
|
67 |
+
|
68 |
+
def _read_input_meta_data_from_file(self):
|
69 |
+
"""Reads the input metadata from a file."""
|
70 |
+
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
|
71 |
+
return json.loads(reader.read().decode('utf-8'))
|
72 |
+
|
73 |
+
def _get_distribution_strategy(self, ds_type='mirrored'):
|
74 |
+
"""Gets the distribution strategy.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
ds_type: String, the distribution strategy type to be used. Can be
|
78 |
+
'mirrored', 'multi_worker_mirrored', 'tpu' and 'off'.
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
A `tf.distribute.DistibutionStrategy` object.
|
82 |
+
"""
|
83 |
+
if self.tpu or ds_type == 'tpu':
|
84 |
+
return distribution_utils.get_distribution_strategy(
|
85 |
+
distribution_strategy='tpu', tpu_address=self.tpu)
|
86 |
+
elif ds_type == 'multi_worker_mirrored':
|
87 |
+
# Configures cluster spec for multi-worker distribution strategy.
|
88 |
+
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
|
89 |
+
FLAGS.task_index)
|
90 |
+
return distribution_utils.get_distribution_strategy(
|
91 |
+
distribution_strategy=ds_type,
|
92 |
+
num_gpus=self.num_gpus,
|
93 |
+
all_reduce_alg=FLAGS.all_reduce_alg)
|
94 |
+
|
95 |
+
def _init_gpu_and_data_threads(self):
|
96 |
+
"""Set env variables before any TF calls."""
|
97 |
+
if FLAGS.tf_gpu_thread_mode:
|
98 |
+
keras_utils.set_gpu_thread_mode_and_count(
|
99 |
+
per_gpu_thread_count=FLAGS.per_gpu_thread_count,
|
100 |
+
gpu_thread_mode=FLAGS.tf_gpu_thread_mode,
|
101 |
+
num_gpus=self.num_gpus,
|
102 |
+
datasets_num_private_threads=FLAGS.datasets_num_private_threads)
|
103 |
+
|
104 |
+
@flagsaver.flagsaver
|
105 |
+
def _train_squad(self, run_eagerly=False, ds_type='mirrored'):
|
106 |
+
"""Runs BERT SQuAD training. Uses mirrored strategy by default."""
|
107 |
+
self._init_gpu_and_data_threads()
|
108 |
+
input_meta_data = self._read_input_meta_data_from_file()
|
109 |
+
strategy = self._get_distribution_strategy(ds_type)
|
110 |
+
|
111 |
+
run_squad.train_squad(
|
112 |
+
strategy=strategy,
|
113 |
+
input_meta_data=input_meta_data,
|
114 |
+
run_eagerly=run_eagerly,
|
115 |
+
custom_callbacks=[self.timer_callback])
|
116 |
+
|
117 |
+
@flagsaver.flagsaver
|
118 |
+
def _evaluate_squad(self, ds_type='mirrored'):
|
119 |
+
"""Runs BERT SQuAD evaluation. Uses mirrored strategy by default."""
|
120 |
+
self._init_gpu_and_data_threads()
|
121 |
+
input_meta_data = self._read_input_meta_data_from_file()
|
122 |
+
strategy = self._get_distribution_strategy(ds_type)
|
123 |
+
|
124 |
+
if input_meta_data.get('version_2_with_negative', False):
|
125 |
+
logging.error('In memory evaluation result for SQuAD v2 is not accurate')
|
126 |
+
eval_metrics = run_squad.eval_squad(strategy=strategy,
|
127 |
+
input_meta_data=input_meta_data)
|
128 |
+
# Use F1 score as reported evaluation metric.
|
129 |
+
self.eval_metrics = eval_metrics['final_f1']
|
130 |
+
|
131 |
+
|
132 |
+
class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
|
133 |
+
"""Short benchmark performance tests for BERT SQuAD model.
|
134 |
+
|
135 |
+
Tests BERT SQuAD performance in different GPU configurations.
|
136 |
+
The naming convention of below test cases follow
|
137 |
+
`benchmark_(number of gpus)_gpu` format for GPUs and
|
138 |
+
`benchmark_(topology)_tpu` format for TPUs.
|
139 |
+
"""
|
140 |
+
|
141 |
+
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
|
142 |
+
super(BertSquadBenchmarkReal, self).__init__(output_dir=output_dir, tpu=tpu)
|
143 |
+
|
144 |
+
def _setup(self):
|
145 |
+
"""Sets up the benchmark and SQuAD flags."""
|
146 |
+
super(BertSquadBenchmarkReal, self)._setup()
|
147 |
+
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
|
148 |
+
FLAGS.predict_file = SQUAD_PREDICT_FILE
|
149 |
+
FLAGS.vocab_file = SQUAD_VOCAB_FILE
|
150 |
+
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
|
151 |
+
FLAGS.num_train_epochs = 1
|
152 |
+
FLAGS.steps_per_loop = 100
|
153 |
+
|
154 |
+
@benchmark_wrappers.enable_runtime_flags
|
155 |
+
def _run_and_report_benchmark(self,
|
156 |
+
run_eagerly=False,
|
157 |
+
ds_type='mirrored'):
|
158 |
+
"""Runs the benchmark and reports various metrics."""
|
159 |
+
if FLAGS.train_batch_size <= 4 or run_eagerly:
|
160 |
+
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
|
161 |
+
else:
|
162 |
+
FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
|
163 |
+
start_time_sec = time.time()
|
164 |
+
self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
|
165 |
+
wall_time_sec = time.time() - start_time_sec
|
166 |
+
|
167 |
+
summary = self._read_training_summary_from_file()
|
168 |
+
summary['start_time_sec'] = start_time_sec
|
169 |
+
|
170 |
+
super(BertSquadBenchmarkReal, self)._report_benchmark(
|
171 |
+
stats=summary,
|
172 |
+
wall_time_sec=wall_time_sec,
|
173 |
+
min_accuracy=0,
|
174 |
+
max_accuracy=1)
|
175 |
+
|
176 |
+
def benchmark_1_gpu(self):
|
177 |
+
"""Tests BERT SQuAD model performance with 1 GPU."""
|
178 |
+
|
179 |
+
self._setup()
|
180 |
+
self.num_gpus = 1
|
181 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad')
|
182 |
+
FLAGS.train_batch_size = 4
|
183 |
+
|
184 |
+
self._run_and_report_benchmark()
|
185 |
+
|
186 |
+
def benchmark_1_gpu_eager(self):
|
187 |
+
"""Tests BERT SQuAD model performance with 1 GPU."""
|
188 |
+
|
189 |
+
self._setup()
|
190 |
+
self.num_gpus = 1
|
191 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
|
192 |
+
FLAGS.train_batch_size = 2
|
193 |
+
|
194 |
+
self._run_and_report_benchmark(run_eagerly=True)
|
195 |
+
|
196 |
+
def benchmark_1_gpu_xla(self):
|
197 |
+
"""Tests BERT SQuAD model performance with 1 GPU with XLA."""
|
198 |
+
|
199 |
+
self._setup()
|
200 |
+
self.num_gpus = 1
|
201 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad')
|
202 |
+
# XLA runs out of memory when running with batch size 4.
|
203 |
+
FLAGS.train_batch_size = 3
|
204 |
+
FLAGS.enable_xla = True
|
205 |
+
|
206 |
+
self._run_and_report_benchmark()
|
207 |
+
|
208 |
+
def benchmark_1_gpu_no_dist_strat(self):
|
209 |
+
"""Tests BERT SQuAD model performance with 1 GPU without DS."""
|
210 |
+
|
211 |
+
self._setup()
|
212 |
+
self.num_gpus = 1
|
213 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat_squad')
|
214 |
+
FLAGS.train_batch_size = 4
|
215 |
+
|
216 |
+
self._run_and_report_benchmark(ds_type='off')
|
217 |
+
|
218 |
+
def benchmark_1_gpu_eager_no_dist_strat(self):
|
219 |
+
"""Tests BERT SQuAD model performance with 1 GPU with eager execution."""
|
220 |
+
|
221 |
+
self._setup()
|
222 |
+
self.num_gpus = 1
|
223 |
+
FLAGS.model_dir = self._get_model_dir(
|
224 |
+
'benchmark_1_gpu_eager_no_dist_strat_squad')
|
225 |
+
FLAGS.train_batch_size = 4
|
226 |
+
|
227 |
+
self._run_and_report_benchmark(ds_type='off', run_eagerly=True)
|
228 |
+
|
229 |
+
@owner_utils.Owner('tf-model-garden')
|
230 |
+
def benchmark_8_gpu(self):
|
231 |
+
"""Tests BERT SQuAD model performance with 8 GPUs."""
|
232 |
+
|
233 |
+
self._setup()
|
234 |
+
self.num_gpus = 8
|
235 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
|
236 |
+
FLAGS.train_batch_size = 24
|
237 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
238 |
+
|
239 |
+
self._run_and_report_benchmark()
|
240 |
+
|
241 |
+
def benchmark_1_gpu_fp16_eager(self):
|
242 |
+
"""Tests BERT SQuAD model performance with 1 GPU and FP16."""
|
243 |
+
|
244 |
+
self._setup()
|
245 |
+
self.num_gpus = 1
|
246 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_fp16_eager')
|
247 |
+
FLAGS.train_batch_size = 4
|
248 |
+
FLAGS.dtype = 'fp16'
|
249 |
+
FLAGS.loss_scale = 'dynamic'
|
250 |
+
|
251 |
+
self._run_and_report_benchmark(run_eagerly=True)
|
252 |
+
|
253 |
+
def benchmark_1_gpu_fp16(self):
|
254 |
+
"""Tests BERT SQuAD model performance with 1 GPU and FP16."""
|
255 |
+
|
256 |
+
self._setup()
|
257 |
+
self.num_gpus = 1
|
258 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_fp16')
|
259 |
+
FLAGS.train_batch_size = 4
|
260 |
+
FLAGS.dtype = 'fp16'
|
261 |
+
FLAGS.loss_scale = 'dynamic'
|
262 |
+
|
263 |
+
self._run_and_report_benchmark()
|
264 |
+
|
265 |
+
def benchmark_1_gpu_xla_fp16(self):
|
266 |
+
"""Tests BERT SQuAD model performance with 1 GPU with XLA and FP16."""
|
267 |
+
|
268 |
+
self._setup()
|
269 |
+
self.num_gpus = 1
|
270 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad_fp16')
|
271 |
+
FLAGS.train_batch_size = 4
|
272 |
+
FLAGS.enable_xla = True
|
273 |
+
FLAGS.dtype = 'fp16'
|
274 |
+
FLAGS.loss_scale = 'dynamic'
|
275 |
+
|
276 |
+
self._run_and_report_benchmark()
|
277 |
+
|
278 |
+
def benchmark_8_gpu_fp16(self):
|
279 |
+
"""Tests BERT SQuAD model performance with 8 GPUs."""
|
280 |
+
|
281 |
+
self._setup()
|
282 |
+
self.num_gpus = 8
|
283 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
|
284 |
+
FLAGS.train_batch_size = 32
|
285 |
+
FLAGS.dtype = 'fp16'
|
286 |
+
FLAGS.loss_scale = 'dynamic'
|
287 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
288 |
+
|
289 |
+
self._run_and_report_benchmark()
|
290 |
+
|
291 |
+
def benchmark_8_gpu_xla_fp16(self):
|
292 |
+
"""Tests BERT SQuAD model performance with 8 GPUs with XLA."""
|
293 |
+
|
294 |
+
self._setup()
|
295 |
+
self.num_gpus = 8
|
296 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
|
297 |
+
FLAGS.train_batch_size = 32
|
298 |
+
FLAGS.enable_xla = True
|
299 |
+
FLAGS.dtype = 'fp16'
|
300 |
+
FLAGS.loss_scale = 'dynamic'
|
301 |
+
|
302 |
+
self._run_and_report_benchmark()
|
303 |
+
|
304 |
+
def benchmark_1_gpu_amp(self):
|
305 |
+
"""Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
|
306 |
+
|
307 |
+
self._setup()
|
308 |
+
self.num_gpus = 1
|
309 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp_squad')
|
310 |
+
FLAGS.train_batch_size = 4
|
311 |
+
FLAGS.dtype = 'fp16'
|
312 |
+
FLAGS.fp16_implementation = 'graph_rewrite'
|
313 |
+
|
314 |
+
self._run_and_report_benchmark()
|
315 |
+
|
316 |
+
def benchmark_8_gpu_amp(self):
|
317 |
+
"""Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
|
318 |
+
|
319 |
+
self._setup()
|
320 |
+
self.num_gpus = 8
|
321 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_squad')
|
322 |
+
FLAGS.train_batch_size = 32
|
323 |
+
FLAGS.dtype = 'fp16'
|
324 |
+
FLAGS.fp16_implementation = 'graph_rewrite'
|
325 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
326 |
+
|
327 |
+
self._run_and_report_benchmark()
|
328 |
+
|
329 |
+
@owner_utils.Owner('tf-model-garden')
|
330 |
+
def benchmark_2x2_tpu(self):
|
331 |
+
"""Tests BERT SQuAD model performance with 2x2 TPU."""
|
332 |
+
|
333 |
+
self._setup()
|
334 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
|
335 |
+
FLAGS.train_batch_size = 48
|
336 |
+
FLAGS.predict_batch_size = 48
|
337 |
+
FLAGS.mode = 'train'
|
338 |
+
FLAGS.learning_rate = 8e-5
|
339 |
+
FLAGS.num_train_epochs = 1
|
340 |
+
FLAGS.steps_per_loop = 100
|
341 |
+
FLAGS.do_lower_case = True
|
342 |
+
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
|
343 |
+
self._run_and_report_benchmark()
|
344 |
+
|
345 |
+
|
346 |
+
class BertSquadAccuracy(BertSquadBenchmarkBase):
|
347 |
+
"""Short accuracy test for BERT SQuAD model.
|
348 |
+
|
349 |
+
Tests BERT SQuAD accuracy. The naming convention of below test cases follow
|
350 |
+
`benchmark_(number of gpus)_gpu` format for GPUs and
|
351 |
+
`benchmark_(topology)_tpu` format for TPUs.
|
352 |
+
"""
|
353 |
+
|
354 |
+
def __init__(self, output_dir=None, tpu=None, **kwargs):
|
355 |
+
super(BertSquadAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)
|
356 |
+
|
357 |
+
def _setup(self):
|
358 |
+
"""Sets up the benchmark and SQuAD flags."""
|
359 |
+
super(BertSquadAccuracy, self)._setup()
|
360 |
+
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
|
361 |
+
FLAGS.predict_file = SQUAD_PREDICT_FILE
|
362 |
+
FLAGS.vocab_file = SQUAD_VOCAB_FILE
|
363 |
+
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
|
364 |
+
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
|
365 |
+
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
|
366 |
+
FLAGS.num_train_epochs = 2
|
367 |
+
FLAGS.steps_per_loop = 100
|
368 |
+
|
369 |
+
@benchmark_wrappers.enable_runtime_flags
|
370 |
+
def _run_and_report_benchmark(self,
|
371 |
+
run_eagerly=False,
|
372 |
+
ds_type='mirrored'):
|
373 |
+
"""Runs the benchmark and reports various metrics."""
|
374 |
+
start_time_sec = time.time()
|
375 |
+
self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
|
376 |
+
self._evaluate_squad(ds_type=ds_type)
|
377 |
+
wall_time_sec = time.time() - start_time_sec
|
378 |
+
|
379 |
+
summary = self._read_training_summary_from_file()
|
380 |
+
summary['eval_metrics'] = self.eval_metrics
|
381 |
+
summary['start_time_sec'] = start_time_sec
|
382 |
+
|
383 |
+
super(BertSquadAccuracy, self)._report_benchmark(
|
384 |
+
stats=summary,
|
385 |
+
wall_time_sec=wall_time_sec,
|
386 |
+
min_accuracy=0.900,
|
387 |
+
max_accuracy=0.920)
|
388 |
+
|
389 |
+
def benchmark_1_gpu_eager(self):
|
390 |
+
"""Tests BERT SQuAD model accuracy with 1 GPU with eager execution."""
|
391 |
+
|
392 |
+
self._setup()
|
393 |
+
self.num_gpus = 1
|
394 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
|
395 |
+
FLAGS.train_batch_size = 4
|
396 |
+
|
397 |
+
self._run_and_report_benchmark(ds_type='off', run_eagerly=True)
|
398 |
+
|
399 |
+
@owner_utils.Owner('tf-model-garden')
|
400 |
+
def benchmark_8_gpu(self):
|
401 |
+
"""Tests BERT SQuAD model accuracy with 8 GPUs."""
|
402 |
+
|
403 |
+
self._setup()
|
404 |
+
self.num_gpus = 8
|
405 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
|
406 |
+
FLAGS.train_batch_size = 24
|
407 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
408 |
+
|
409 |
+
self._run_and_report_benchmark()
|
410 |
+
|
411 |
+
def benchmark_8_gpu_fp16(self):
|
412 |
+
"""Tests BERT SQuAD model accuracy with 8 GPUs and FP16."""
|
413 |
+
|
414 |
+
self._setup()
|
415 |
+
self.num_gpus = 8
|
416 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
|
417 |
+
FLAGS.train_batch_size = 32
|
418 |
+
FLAGS.dtype = 'fp16'
|
419 |
+
FLAGS.loss_scale = 'dynamic'
|
420 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
421 |
+
|
422 |
+
self._run_and_report_benchmark()
|
423 |
+
|
424 |
+
def benchmark_8_gpu_xla(self):
|
425 |
+
"""Tests BERT SQuAD model accuracy with 8 GPUs."""
|
426 |
+
|
427 |
+
self._setup()
|
428 |
+
self.num_gpus = 8
|
429 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_xla')
|
430 |
+
FLAGS.train_batch_size = 32
|
431 |
+
FLAGS.enable_xla = True
|
432 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
433 |
+
|
434 |
+
self._run_and_report_benchmark()
|
435 |
+
|
436 |
+
@owner_utils.Owner('tf-model-garden')
|
437 |
+
def benchmark_2x2_tpu(self):
|
438 |
+
"""Tests BERT SQuAD model accuracy with 2x2 TPU."""
|
439 |
+
|
440 |
+
self._setup()
|
441 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
|
442 |
+
FLAGS.train_batch_size = 48
|
443 |
+
|
444 |
+
self._run_and_report_benchmark()
|
445 |
+
|
446 |
+
|
447 |
+
class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
|
448 |
+
"""BERT SQuAD distributed accuracy tests with multiple workers."""
|
449 |
+
|
450 |
+
def __init__(self, output_dir=None, tpu=None, **kwargs):
|
451 |
+
super(BertSquadMultiWorkerAccuracy, self).__init__(
|
452 |
+
output_dir=output_dir, tpu=tpu)
|
453 |
+
|
454 |
+
def _setup(self):
|
455 |
+
"""Sets up the benchmark and SQuAD flags."""
|
456 |
+
super(BertSquadMultiWorkerAccuracy, self)._setup()
|
457 |
+
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
|
458 |
+
FLAGS.predict_file = SQUAD_PREDICT_FILE
|
459 |
+
FLAGS.vocab_file = SQUAD_VOCAB_FILE
|
460 |
+
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
|
461 |
+
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
|
462 |
+
FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
|
463 |
+
FLAGS.num_train_epochs = 2
|
464 |
+
FLAGS.steps_per_loop = 100
|
465 |
+
|
466 |
+
@benchmark_wrappers.enable_runtime_flags
|
467 |
+
def _run_and_report_benchmark(self,
|
468 |
+
use_ds=True,
|
469 |
+
run_eagerly=False):
|
470 |
+
"""Runs the benchmark and reports various metrics."""
|
471 |
+
start_time_sec = time.time()
|
472 |
+
self._train_squad(run_eagerly=run_eagerly,
|
473 |
+
ds_type='multi_worker_mirrored')
|
474 |
+
self._evaluate_squad(ds_type='multi_worker_mirrored')
|
475 |
+
wall_time_sec = time.time() - start_time_sec
|
476 |
+
|
477 |
+
summary = self._read_training_summary_from_file()
|
478 |
+
summary['eval_metrics'] = self.eval_metrics
|
479 |
+
|
480 |
+
super(BertSquadMultiWorkerAccuracy, self)._report_benchmark(
|
481 |
+
stats=summary,
|
482 |
+
wall_time_sec=wall_time_sec,
|
483 |
+
min_accuracy=0.900,
|
484 |
+
max_accuracy=0.920)
|
485 |
+
|
486 |
+
def _benchmark_common(self, num_workers, all_reduce_alg):
|
487 |
+
"""Common to all benchmarks in this class."""
|
488 |
+
self._setup()
|
489 |
+
|
490 |
+
num_gpus = 8
|
491 |
+
FLAGS.num_gpus = num_gpus
|
492 |
+
FLAGS.dtype = 'fp16'
|
493 |
+
FLAGS.enable_xla = False
|
494 |
+
FLAGS.distribution_strategy = 'multi_worker_mirrored'
|
495 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
496 |
+
FLAGS.datasets_num_private_threads = 32
|
497 |
+
FLAGS.model_dir = self._get_model_dir(
|
498 |
+
'benchmark_8_gpu_{}_worker_fp16_{}_tweaked'.format(
|
499 |
+
num_workers, all_reduce_alg))
|
500 |
+
FLAGS.train_batch_size = 4 * num_gpus * num_workers
|
501 |
+
FLAGS.all_reduce_alg = all_reduce_alg
|
502 |
+
|
503 |
+
self._run_and_report_benchmark()
|
504 |
+
|
505 |
+
def benchmark_eager_8_gpu_2_workers_fp16_ring_tweaked(self):
|
506 |
+
"""8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
|
507 |
+
self._benchmark_common(num_workers=2, all_reduce_alg='ring')
|
508 |
+
|
509 |
+
def benchmark_eager_8_gpu_2_workers_fp16_nccl_tweaked(self):
|
510 |
+
"""8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
|
511 |
+
self._benchmark_common(num_workers=2, all_reduce_alg='nccl')
|
512 |
+
|
513 |
+
def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
|
514 |
+
"""8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
|
515 |
+
self._benchmark_common(num_workers=8, all_reduce_alg='ring')
|
516 |
+
|
517 |
+
def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
|
518 |
+
"""8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
|
519 |
+
self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
|
520 |
+
|
521 |
+
|
522 |
+
class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
|
523 |
+
"""BERT SQuAD distributed benchmark tests with multiple workers."""
|
524 |
+
|
525 |
+
def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
|
526 |
+
super(BertSquadMultiWorkerBenchmark, self).__init__(
|
527 |
+
output_dir=output_dir, tpu=tpu)
|
528 |
+
|
529 |
+
def _setup(self):
|
530 |
+
"""Sets up the benchmark and SQuAD flags."""
|
531 |
+
super(BertSquadMultiWorkerBenchmark, self)._setup()
|
532 |
+
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
|
533 |
+
FLAGS.predict_file = SQUAD_PREDICT_FILE
|
534 |
+
FLAGS.vocab_file = SQUAD_VOCAB_FILE
|
535 |
+
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
|
536 |
+
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
|
537 |
+
FLAGS.num_train_epochs = 1
|
538 |
+
FLAGS.steps_per_loop = 100
|
539 |
+
|
540 |
+
@benchmark_wrappers.enable_runtime_flags
|
541 |
+
def _run_and_report_benchmark(self,
|
542 |
+
use_ds=True,
|
543 |
+
run_eagerly=False):
|
544 |
+
"""Runs the benchmark and reports various metrics."""
|
545 |
+
if FLAGS.train_batch_size <= 4 * 8:
|
546 |
+
FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
|
547 |
+
else:
|
548 |
+
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
|
549 |
+
start_time_sec = time.time()
|
550 |
+
self._train_squad(run_eagerly=run_eagerly,
|
551 |
+
ds_type='multi_worker_mirrored')
|
552 |
+
wall_time_sec = time.time() - start_time_sec
|
553 |
+
|
554 |
+
summary = self._read_training_summary_from_file()
|
555 |
+
summary['start_time_sec'] = start_time_sec
|
556 |
+
|
557 |
+
super(BertSquadMultiWorkerBenchmark, self)._report_benchmark(
|
558 |
+
stats=summary,
|
559 |
+
wall_time_sec=wall_time_sec,
|
560 |
+
min_accuracy=0,
|
561 |
+
max_accuracy=1)
|
562 |
+
|
563 |
+
def _benchmark_common(self, num_workers, all_reduce_alg):
|
564 |
+
"""Common to all benchmarks in this class."""
|
565 |
+
self._setup()
|
566 |
+
|
567 |
+
num_gpus = 8
|
568 |
+
FLAGS.num_gpus = num_gpus
|
569 |
+
FLAGS.dtype = 'fp16'
|
570 |
+
FLAGS.enable_xla = False
|
571 |
+
FLAGS.distribution_strategy = 'multi_worker_mirrored'
|
572 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
573 |
+
FLAGS.datasets_num_private_threads = 32
|
574 |
+
FLAGS.model_dir = self._get_model_dir(
|
575 |
+
'benchmark_8_gpu_{}_worker_fp16_{}_tweaked'.format(
|
576 |
+
num_workers, all_reduce_alg))
|
577 |
+
FLAGS.train_batch_size = 4 * num_gpus * num_workers
|
578 |
+
FLAGS.all_reduce_alg = all_reduce_alg
|
579 |
+
|
580 |
+
self._run_and_report_benchmark()
|
581 |
+
|
582 |
+
def benchmark_8_gpu_1_worker_fp16_ring_tweaked(self):
|
583 |
+
"""8 GPUs per worker, 1 worker, fp16, ring all-reduce."""
|
584 |
+
self._benchmark_common(num_workers=1, all_reduce_alg='ring')
|
585 |
+
|
586 |
+
def benchmark_8_gpu_1_worker_fp16_nccl_tweaked(self):
|
587 |
+
"""8 GPUs per worker, 1 worker, fp16, nccl all-reduce."""
|
588 |
+
self._benchmark_common(num_workers=1, all_reduce_alg='nccl')
|
589 |
+
|
590 |
+
def benchmark_8_gpu_2_workers_fp16_ring_tweaked(self):
|
591 |
+
"""8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
|
592 |
+
self._benchmark_common(num_workers=2, all_reduce_alg='ring')
|
593 |
+
|
594 |
+
def benchmark_8_gpu_2_workers_fp16_nccl_tweaked(self):
|
595 |
+
"""8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
|
596 |
+
self._benchmark_common(num_workers=2, all_reduce_alg='nccl')
|
597 |
+
|
598 |
+
def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
|
599 |
+
"""8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
|
600 |
+
self._benchmark_common(num_workers=8, all_reduce_alg='ring')
|
601 |
+
|
602 |
+
def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
|
603 |
+
"""8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
|
604 |
+
self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
|
605 |
+
|
606 |
+
|
607 |
+
if __name__ == '__main__':
|
608 |
+
tf.test.main()
|
models/official/benchmark/datastore/schema/benchmark_metric.json
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"description": "The ID of the benchmark run, where this metric should tie to.",
|
4 |
+
"mode": "REQUIRED",
|
5 |
+
"name": "run_id",
|
6 |
+
"type": "STRING"
|
7 |
+
},
|
8 |
+
{
|
9 |
+
"description": "The name of the metric, which should be descriptive. E.g. training_loss, accuracy.",
|
10 |
+
"mode": "REQUIRED",
|
11 |
+
"name": "name",
|
12 |
+
"type": "STRING"
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"description": "The unit of the metric. E.g. MB per sec.",
|
16 |
+
"mode": "NULLABLE",
|
17 |
+
"name": "unit",
|
18 |
+
"type": "STRING"
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"description": "The value of the metric.",
|
22 |
+
"mode": "NULLABLE",
|
23 |
+
"name": "value",
|
24 |
+
"type": "FLOAT"
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"description": "The timestamp when the metric is recorded.",
|
28 |
+
"mode": "REQUIRED",
|
29 |
+
"name": "timestamp",
|
30 |
+
"type": "TIMESTAMP"
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"description": "The global step when this metric is recorded.",
|
34 |
+
"mode": "NULLABLE",
|
35 |
+
"name": "global_step",
|
36 |
+
"type": "INTEGER"
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"description": "Free format metadata for the extra information about the metric.",
|
40 |
+
"mode": "REPEATED",
|
41 |
+
"name": "extras",
|
42 |
+
"type": "RECORD",
|
43 |
+
"fields": [
|
44 |
+
{
|
45 |
+
"mode": "NULLABLE",
|
46 |
+
"name": "name",
|
47 |
+
"type": "STRING"
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"mode": "NULLABLE",
|
51 |
+
"name": "value",
|
52 |
+
"type": "STRING"
|
53 |
+
}
|
54 |
+
]
|
55 |
+
}
|
56 |
+
]
|
models/official/benchmark/datastore/schema/benchmark_run.json
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"description": "The UUID of the run for the benchmark.",
|
4 |
+
"mode": "REQUIRED",
|
5 |
+
"name": "model_id",
|
6 |
+
"type": "STRING"
|
7 |
+
},
|
8 |
+
{
|
9 |
+
"description": "The name of the model, E.g ResNet50, LeNet-5 etc.",
|
10 |
+
"mode": "REQUIRED",
|
11 |
+
"name": "model_name",
|
12 |
+
"type": "STRING"
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"description": "The date when the test of the model is started",
|
16 |
+
"mode": "REQUIRED",
|
17 |
+
"name": "run_date",
|
18 |
+
"type": "TIMESTAMP"
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"description": "The unique name for a test by the combination of key parameters, eg batch size, num of GPU, etc. It is hardware independent.",
|
22 |
+
"mode": "NULLABLE",
|
23 |
+
"name": "test_id",
|
24 |
+
"type": "STRING"
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"description": "The tensorflow version information.",
|
28 |
+
"fields": [
|
29 |
+
{
|
30 |
+
"description": "Version of the tensorflow. E.g. 1.7.0-rc0",
|
31 |
+
"mode": "REQUIRED",
|
32 |
+
"name": "version",
|
33 |
+
"type": "STRING"
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"description": "Git Hash of the tensorflow",
|
37 |
+
"mode": "NULLABLE",
|
38 |
+
"name": "git_hash",
|
39 |
+
"type": "STRING"
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"description": "The channel of the tensorflow binary, eg, nightly, RC, final, custom.",
|
43 |
+
"mode": "NULLABLE",
|
44 |
+
"name": "channel",
|
45 |
+
"type": "STRING"
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"description": "Identify anything special about the build, eg CUDA 10, NCCL, MKL, etc.",
|
49 |
+
"mode": "NULLABLE",
|
50 |
+
"name": "build_type",
|
51 |
+
"type": "STRING"
|
52 |
+
}
|
53 |
+
],
|
54 |
+
"mode": "REQUIRED",
|
55 |
+
"name": "tensorflow_version",
|
56 |
+
"type": "RECORD"
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"description": "The arbitrary attribute of the model.",
|
60 |
+
"fields": [
|
61 |
+
{
|
62 |
+
"description": "The name of the attribute.",
|
63 |
+
"mode": "REQUIRED",
|
64 |
+
"name": "name",
|
65 |
+
"type": "STRING"
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"description": "The value of the attribute.",
|
69 |
+
"mode": "NULLABLE",
|
70 |
+
"name": "value",
|
71 |
+
"type": "STRING"
|
72 |
+
}
|
73 |
+
],
|
74 |
+
"mode": "REPEATED",
|
75 |
+
"name": "attribute",
|
76 |
+
"type": "RECORD"
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"description": "Environment variables when the benchmark run is executed.",
|
80 |
+
"fields": [
|
81 |
+
{
|
82 |
+
"description": "The name of the variable.",
|
83 |
+
"mode": "REQUIRED",
|
84 |
+
"name": "name",
|
85 |
+
"type": "STRING"
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"description": "The value of the variable.",
|
89 |
+
"mode": "NULLABLE",
|
90 |
+
"name": "value",
|
91 |
+
"type": "STRING"
|
92 |
+
}
|
93 |
+
],
|
94 |
+
"mode": "REPEATED",
|
95 |
+
"name": "environment_variable",
|
96 |
+
"type": "RECORD"
|
97 |
+
},
|
98 |
+
{
|
99 |
+
"description": "TF Environment variables when the benchmark run is executed.",
|
100 |
+
"fields": [
|
101 |
+
{
|
102 |
+
"description": "The name of the variable.",
|
103 |
+
"mode": "REQUIRED",
|
104 |
+
"name": "name",
|
105 |
+
"type": "STRING"
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"description": "The value of the variable.",
|
109 |
+
"mode": "NULLABLE",
|
110 |
+
"name": "value",
|
111 |
+
"type": "STRING"
|
112 |
+
}
|
113 |
+
],
|
114 |
+
"mode": "REPEATED",
|
115 |
+
"name": "tensorflow_environment_variables",
|
116 |
+
"type": "RECORD"
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"description": "The list of parameters run with the model. It could contain hyperparameters or others.",
|
120 |
+
"fields": [
|
121 |
+
{
|
122 |
+
"description": "The name of the parameter.",
|
123 |
+
"mode": "REQUIRED",
|
124 |
+
"name": "name",
|
125 |
+
"type": "STRING"
|
126 |
+
},
|
127 |
+
{
|
128 |
+
"description": "The string value of the parameter.",
|
129 |
+
"mode": "NULLABLE",
|
130 |
+
"name": "string_value",
|
131 |
+
"type": "STRING"
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"description": "The bool value of the parameter.",
|
135 |
+
"mode": "NULLABLE",
|
136 |
+
"name": "bool_value",
|
137 |
+
"type": "STRING"
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"description": "The int/long value of the parameter.",
|
141 |
+
"mode": "NULLABLE",
|
142 |
+
"name": "long_value",
|
143 |
+
"type": "INTEGER"
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"description": "The double/float value of parameter.",
|
147 |
+
"mode": "NULLABLE",
|
148 |
+
"name": "float_value",
|
149 |
+
"type": "FLOAT"
|
150 |
+
}
|
151 |
+
],
|
152 |
+
"mode": "REPEATED",
|
153 |
+
"name": "run_parameters",
|
154 |
+
"type": "RECORD"
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"description": "The dataset that run with the benchmark.",
|
158 |
+
"mode": "NULLABLE",
|
159 |
+
"name": "dataset",
|
160 |
+
"type": "RECORD",
|
161 |
+
"fields": [
|
162 |
+
{
|
163 |
+
"description": "The name of the dataset that the model is trained/validated with. E.g ImageNet, mnist.",
|
164 |
+
"mode": "REQUIRED",
|
165 |
+
"name": "name",
|
166 |
+
"type": "STRING"
|
167 |
+
},
|
168 |
+
{
|
169 |
+
"description": "The arbitrary attribute of the dataset.",
|
170 |
+
"fields": [
|
171 |
+
{
|
172 |
+
"description": "The name of the attribute.",
|
173 |
+
"mode": "REQUIRED",
|
174 |
+
"name": "name",
|
175 |
+
"type": "STRING"
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"description": "The value of the attribute.",
|
179 |
+
"mode": "NULLABLE",
|
180 |
+
"name": "value",
|
181 |
+
"type": "STRING"
|
182 |
+
}
|
183 |
+
],
|
184 |
+
"mode": "REPEATED",
|
185 |
+
"name": "attribute",
|
186 |
+
"type": "RECORD"
|
187 |
+
}
|
188 |
+
]
|
189 |
+
},
|
190 |
+
{
|
191 |
+
"description": "Used to differentiate from AWS, GCE or DGX-1 at a high level",
|
192 |
+
"mode": "NULLABLE",
|
193 |
+
"name": "test_environment",
|
194 |
+
"type": "STRING"
|
195 |
+
},
|
196 |
+
{
|
197 |
+
"description": "The machine configuration of the benchmark run.",
|
198 |
+
"mode": "NULLABLE",
|
199 |
+
"name": "machine_config",
|
200 |
+
"type": "RECORD",
|
201 |
+
"fields": [
|
202 |
+
{
|
203 |
+
"description": "The platform information of the benchmark run.",
|
204 |
+
"mode": "NULLABLE",
|
205 |
+
"name": "platform_info",
|
206 |
+
"type": "RECORD",
|
207 |
+
"fields": [
|
208 |
+
{
|
209 |
+
"description": "Eg: 64bit.",
|
210 |
+
"mode": "NULLABLE",
|
211 |
+
"name": "bits",
|
212 |
+
"type": "STRING"
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"description": "Eg: ELF.",
|
216 |
+
"mode": "NULLABLE",
|
217 |
+
"name": "linkage",
|
218 |
+
"type": "STRING"
|
219 |
+
},
|
220 |
+
{
|
221 |
+
"description": "Eg: i386.",
|
222 |
+
"mode": "NULLABLE",
|
223 |
+
"name": "machine",
|
224 |
+
"type": "STRING"
|
225 |
+
},
|
226 |
+
{
|
227 |
+
"description": "Eg: 3.13.0-76-generic.",
|
228 |
+
"mode": "NULLABLE",
|
229 |
+
"name": "release",
|
230 |
+
"type": "STRING"
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"description": "Eg: Linux.",
|
234 |
+
"mode": "NULLABLE",
|
235 |
+
"name": "system",
|
236 |
+
"type": "STRING"
|
237 |
+
},
|
238 |
+
{
|
239 |
+
"description": "Eg: #120-Ubuntu SMP Mon Jan 18 15:59:10 UTC 2016.",
|
240 |
+
"mode": "NULLABLE",
|
241 |
+
"name": "version",
|
242 |
+
"type": "STRING"
|
243 |
+
}
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"description": "The CPU information of the benchmark run.",
|
248 |
+
"mode": "NULLABLE",
|
249 |
+
"name": "cpu_info",
|
250 |
+
"type": "RECORD",
|
251 |
+
"fields": [
|
252 |
+
{
|
253 |
+
"mode": "NULLABLE",
|
254 |
+
"name": "num_cores",
|
255 |
+
"type": "INTEGER"
|
256 |
+
},
|
257 |
+
{
|
258 |
+
"mode": "NULLABLE",
|
259 |
+
"name": "num_cores_allowed",
|
260 |
+
"type": "INTEGER"
|
261 |
+
},
|
262 |
+
{
|
263 |
+
"description" : "How fast are those CPUs.",
|
264 |
+
"mode": "NULLABLE",
|
265 |
+
"name": "mhz_per_cpu",
|
266 |
+
"type": "FLOAT"
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"description" : "Additional CPU info, Eg: Intel Ivybridge with HyperThreading (24 cores).",
|
270 |
+
"mode": "NULLABLE",
|
271 |
+
"name": "cpu_info",
|
272 |
+
"type": "STRING"
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"description" : "What kind of cpu scaling is enabled on the host. Eg performance, ondemand, conservative, mixed.",
|
276 |
+
"mode": "NULLABLE",
|
277 |
+
"name": "cpu_governor",
|
278 |
+
"type": "STRING"
|
279 |
+
},
|
280 |
+
{
|
281 |
+
"description": "Cache size of the CPUs.",
|
282 |
+
"mode": "NULLABLE",
|
283 |
+
"name": "cache_size",
|
284 |
+
"type": "RECORD",
|
285 |
+
"fields": [
|
286 |
+
{
|
287 |
+
"mode": "NULLABLE",
|
288 |
+
"name": "level",
|
289 |
+
"type": "STRING"
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"mode": "NULLABLE",
|
293 |
+
"name": "size",
|
294 |
+
"type": "INTEGER"
|
295 |
+
}
|
296 |
+
]
|
297 |
+
}
|
298 |
+
]
|
299 |
+
},
|
300 |
+
{
|
301 |
+
"mode": "NULLABLE",
|
302 |
+
"name": "gpu_info",
|
303 |
+
"type": "RECORD",
|
304 |
+
"fields": [
|
305 |
+
{
|
306 |
+
"mode": "NULLABLE",
|
307 |
+
"name": "count",
|
308 |
+
"type": "INTEGER"
|
309 |
+
},
|
310 |
+
{
|
311 |
+
"mode": "NULLABLE",
|
312 |
+
"name": "model",
|
313 |
+
"type": "STRING"
|
314 |
+
},
|
315 |
+
{
|
316 |
+
"mode": "NULLABLE",
|
317 |
+
"name": "cuda_version",
|
318 |
+
"type": "STRING"
|
319 |
+
}
|
320 |
+
]
|
321 |
+
},
|
322 |
+
{
|
323 |
+
"description": "The cloud instance inforation if the benchmark run is executed on cloud",
|
324 |
+
"mode": "NULLABLE",
|
325 |
+
"name": "cloud_info",
|
326 |
+
"type": "RECORD",
|
327 |
+
"fields": [
|
328 |
+
{
|
329 |
+
"description": "The instance type, E.g. n1-standard-4.",
|
330 |
+
"mode": "NULLABLE",
|
331 |
+
"name": "instance_type",
|
332 |
+
"type": "STRING"
|
333 |
+
},
|
334 |
+
{
|
335 |
+
"description": "The arbitrary attribute of the cloud info.",
|
336 |
+
"fields": [
|
337 |
+
{
|
338 |
+
"description": "The name of the attribute.",
|
339 |
+
"mode": "REQUIRED",
|
340 |
+
"name": "name",
|
341 |
+
"type": "STRING"
|
342 |
+
},
|
343 |
+
{
|
344 |
+
"description": "The value of the attribute.",
|
345 |
+
"mode": "NULLABLE",
|
346 |
+
"name": "value",
|
347 |
+
"type": "STRING"
|
348 |
+
}
|
349 |
+
],
|
350 |
+
"mode": "REPEATED",
|
351 |
+
"name": "attribute",
|
352 |
+
"type": "RECORD"
|
353 |
+
}
|
354 |
+
]
|
355 |
+
},
|
356 |
+
{
|
357 |
+
"mode": "NULLABLE",
|
358 |
+
"name": "memory_total",
|
359 |
+
"type": "INTEGER"
|
360 |
+
},
|
361 |
+
{
|
362 |
+
"mode": "NULLABLE",
|
363 |
+
"name": "memory_available",
|
364 |
+
"type": "STRING"
|
365 |
+
}
|
366 |
+
]
|
367 |
+
}
|
368 |
+
]
|
models/official/benchmark/datastore/schema/benchmark_run_status.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"description": "The UUID of the run for the benchmark.",
|
4 |
+
"mode": "REQUIRED",
|
5 |
+
"name": "run_id",
|
6 |
+
"type": "STRING"
|
7 |
+
},
|
8 |
+
{
|
9 |
+
"description": "The status of the run for the benchmark. Eg, running, failed, success",
|
10 |
+
"mode": "REQUIRED",
|
11 |
+
"name": "status",
|
12 |
+
"type": "STRING"
|
13 |
+
}
|
14 |
+
]
|
models/official/benchmark/keras_benchmark.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Executes Keras benchmarks and accuracy tests."""
|
16 |
+
|
17 |
+
from __future__ import absolute_import
|
18 |
+
from __future__ import division
|
19 |
+
from __future__ import print_function
|
20 |
+
|
21 |
+
import tensorflow as tf
|
22 |
+
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
|
23 |
+
from official.utils.flags import core as flags_core
|
24 |
+
|
25 |
+
|
26 |
+
class KerasBenchmark(PerfZeroBenchmark):
|
27 |
+
"""Base benchmark class with methods to simplify testing."""
|
28 |
+
|
29 |
+
def __init__(self,
|
30 |
+
output_dir=None,
|
31 |
+
default_flags=None,
|
32 |
+
flag_methods=None,
|
33 |
+
tpu=None):
|
34 |
+
super(KerasBenchmark, self).__init__(
|
35 |
+
output_dir=output_dir,
|
36 |
+
default_flags=default_flags,
|
37 |
+
flag_methods=flag_methods,
|
38 |
+
tpu=tpu)
|
39 |
+
|
40 |
+
def _report_benchmark(self,
|
41 |
+
stats,
|
42 |
+
wall_time_sec,
|
43 |
+
top_1_max=None,
|
44 |
+
top_1_min=None,
|
45 |
+
log_steps=None,
|
46 |
+
total_batch_size=None,
|
47 |
+
warmup=1,
|
48 |
+
start_time_sec=None):
|
49 |
+
"""Report benchmark results by writing to local protobuf file.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
stats: dict returned from keras models with known entries.
|
53 |
+
wall_time_sec: the during of the benchmark execution in seconds
|
54 |
+
top_1_max: highest passing level for top_1 accuracy.
|
55 |
+
top_1_min: lowest passing level for top_1 accuracy.
|
56 |
+
log_steps: How often the log was created for stats['step_timestamp_log'].
|
57 |
+
total_batch_size: Global batch-size.
|
58 |
+
warmup: number of entries in stats['step_timestamp_log'] to ignore.
|
59 |
+
start_time_sec: the start time of the program in seconds since epoch
|
60 |
+
"""
|
61 |
+
|
62 |
+
metrics = []
|
63 |
+
if 'accuracy_top_1' in stats:
|
64 |
+
metrics.append({'name': 'accuracy_top_1',
|
65 |
+
'value': stats['accuracy_top_1'],
|
66 |
+
'min_value': top_1_min,
|
67 |
+
'max_value': top_1_max})
|
68 |
+
metrics.append({'name': 'top_1_train_accuracy',
|
69 |
+
'value': stats['training_accuracy_top_1']})
|
70 |
+
|
71 |
+
if (warmup and 'step_timestamp_log' in stats and
|
72 |
+
len(stats['step_timestamp_log']) > warmup):
|
73 |
+
# first entry in the time_log is start of step 1. The rest of the
|
74 |
+
# entries are the end of each step recorded
|
75 |
+
time_log = stats['step_timestamp_log']
|
76 |
+
elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
|
77 |
+
num_examples = (
|
78 |
+
total_batch_size * log_steps * (len(time_log) - warmup - 1))
|
79 |
+
examples_per_sec = num_examples / elapsed
|
80 |
+
metrics.append({'name': 'exp_per_second',
|
81 |
+
'value': examples_per_sec})
|
82 |
+
|
83 |
+
if 'avg_exp_per_second' in stats:
|
84 |
+
metrics.append({'name': 'avg_exp_per_second',
|
85 |
+
'value': stats['avg_exp_per_second']})
|
86 |
+
|
87 |
+
if start_time_sec and 'step_timestamp_log' in stats:
|
88 |
+
time_log = stats['step_timestamp_log']
|
89 |
+
# time_log[0] is recorded at the beginning of the first step.
|
90 |
+
startup_time = time_log[0].timestamp - start_time_sec
|
91 |
+
metrics.append({'name': 'startup_time', 'value': startup_time})
|
92 |
+
|
93 |
+
flags_str = flags_core.get_nondefault_flags_as_str()
|
94 |
+
self.report_benchmark(
|
95 |
+
iters=-1,
|
96 |
+
wall_time=wall_time_sec,
|
97 |
+
metrics=metrics,
|
98 |
+
extras={'flags': flags_str})
|
models/official/benchmark/keras_cifar_benchmark.py
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Executes Keras benchmarks and accuracy tests."""
|
16 |
+
from __future__ import absolute_import
|
17 |
+
from __future__ import division
|
18 |
+
from __future__ import print_function
|
19 |
+
|
20 |
+
import os
|
21 |
+
import time
|
22 |
+
from absl import flags
|
23 |
+
import tensorflow as tf # pylint: disable=g-bad-import-order
|
24 |
+
|
25 |
+
from official.benchmark import keras_benchmark
|
26 |
+
from official.benchmark import benchmark_wrappers
|
27 |
+
from official.benchmark.models import resnet_cifar_main
|
28 |
+
|
29 |
+
MIN_TOP_1_ACCURACY = 0.929
|
30 |
+
MAX_TOP_1_ACCURACY = 0.938
|
31 |
+
|
32 |
+
FLAGS = flags.FLAGS
|
33 |
+
CIFAR_DATA_DIR_NAME = 'cifar-10-batches-bin'
|
34 |
+
|
35 |
+
|
36 |
+
class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
|
37 |
+
"""Accuracy tests for ResNet56 Keras CIFAR-10."""
|
38 |
+
|
39 |
+
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
40 |
+
"""A benchmark class.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
output_dir: directory where to output e.g. log files
|
44 |
+
root_data_dir: directory under which to look for dataset
|
45 |
+
**kwargs: arbitrary named arguments. This is needed to make the
|
46 |
+
constructor forward compatible in case PerfZero provides more
|
47 |
+
named arguments before updating the constructor.
|
48 |
+
"""
|
49 |
+
|
50 |
+
self.data_dir = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
|
51 |
+
flag_methods = [resnet_cifar_main.define_cifar_flags]
|
52 |
+
|
53 |
+
super(Resnet56KerasAccuracy, self).__init__(
|
54 |
+
output_dir=output_dir, flag_methods=flag_methods)
|
55 |
+
|
56 |
+
def _setup(self):
|
57 |
+
super(Resnet56KerasAccuracy, self)._setup()
|
58 |
+
FLAGS.use_tensor_lr = False
|
59 |
+
|
60 |
+
def benchmark_graph_1_gpu(self):
|
61 |
+
"""Test keras based model with Keras fit and distribution strategies."""
|
62 |
+
self._setup()
|
63 |
+
FLAGS.num_gpus = 1
|
64 |
+
FLAGS.data_dir = self.data_dir
|
65 |
+
FLAGS.batch_size = 128
|
66 |
+
FLAGS.train_epochs = 182
|
67 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
|
68 |
+
FLAGS.dtype = 'fp32'
|
69 |
+
self._run_and_report_benchmark()
|
70 |
+
|
71 |
+
def benchmark_1_gpu(self):
|
72 |
+
"""Test keras based model with eager and distribution strategies."""
|
73 |
+
self._setup()
|
74 |
+
FLAGS.num_gpus = 1
|
75 |
+
FLAGS.data_dir = self.data_dir
|
76 |
+
FLAGS.batch_size = 128
|
77 |
+
FLAGS.train_epochs = 182
|
78 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
|
79 |
+
FLAGS.dtype = 'fp32'
|
80 |
+
FLAGS.enable_eager = True
|
81 |
+
self._run_and_report_benchmark()
|
82 |
+
|
83 |
+
def benchmark_cpu(self):
|
84 |
+
"""Test keras based model on CPU."""
|
85 |
+
self._setup()
|
86 |
+
FLAGS.num_gpus = 0
|
87 |
+
FLAGS.data_dir = self.data_dir
|
88 |
+
FLAGS.batch_size = 128
|
89 |
+
FLAGS.train_epochs = 182
|
90 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_cpu')
|
91 |
+
FLAGS.dtype = 'fp32'
|
92 |
+
FLAGS.enable_eager = True
|
93 |
+
FLAGS.data_format = 'channels_last'
|
94 |
+
self._run_and_report_benchmark()
|
95 |
+
|
96 |
+
def benchmark_cpu_no_dist_strat(self):
|
97 |
+
"""Test keras based model on CPU without distribution strategies."""
|
98 |
+
self._setup()
|
99 |
+
FLAGS.num_gpus = 0
|
100 |
+
FLAGS.data_dir = self.data_dir
|
101 |
+
FLAGS.batch_size = 128
|
102 |
+
FLAGS.train_epochs = 182
|
103 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_cpu_no_dist_strat')
|
104 |
+
FLAGS.dtype = 'fp32'
|
105 |
+
FLAGS.enable_eager = True
|
106 |
+
FLAGS.distribution_strategy = 'off'
|
107 |
+
FLAGS.data_format = 'channels_last'
|
108 |
+
self._run_and_report_benchmark()
|
109 |
+
|
110 |
+
def benchmark_cpu_no_dist_strat_run_eagerly(self):
|
111 |
+
"""Test keras based model on CPU w/forced eager and no dist_strat."""
|
112 |
+
self._setup()
|
113 |
+
FLAGS.num_gpus = 0
|
114 |
+
FLAGS.data_dir = self.data_dir
|
115 |
+
FLAGS.batch_size = 128
|
116 |
+
FLAGS.train_epochs = 182
|
117 |
+
FLAGS.model_dir = self._get_model_dir(
|
118 |
+
'benchmark_cpu_no_dist_strat_run_eagerly')
|
119 |
+
FLAGS.dtype = 'fp32'
|
120 |
+
FLAGS.enable_eager = True
|
121 |
+
FLAGS.run_eagerly = True
|
122 |
+
FLAGS.distribution_strategy = 'off'
|
123 |
+
FLAGS.data_format = 'channels_last'
|
124 |
+
self._run_and_report_benchmark()
|
125 |
+
|
126 |
+
def benchmark_1_gpu_no_dist_strat(self):
|
127 |
+
"""Test keras based model with eager and no dist strat."""
|
128 |
+
self._setup()
|
129 |
+
FLAGS.num_gpus = 1
|
130 |
+
FLAGS.data_dir = self.data_dir
|
131 |
+
FLAGS.batch_size = 128
|
132 |
+
FLAGS.train_epochs = 182
|
133 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
|
134 |
+
FLAGS.dtype = 'fp32'
|
135 |
+
FLAGS.enable_eager = True
|
136 |
+
FLAGS.distribution_strategy = 'off'
|
137 |
+
self._run_and_report_benchmark()
|
138 |
+
|
139 |
+
def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
|
140 |
+
"""Test keras based model w/forced eager and no dist_strat."""
|
141 |
+
self._setup()
|
142 |
+
FLAGS.num_gpus = 1
|
143 |
+
FLAGS.data_dir = self.data_dir
|
144 |
+
FLAGS.batch_size = 128
|
145 |
+
FLAGS.train_epochs = 182
|
146 |
+
FLAGS.model_dir = self._get_model_dir(
|
147 |
+
'benchmark_1_gpu_no_dist_strat_run_eagerly')
|
148 |
+
FLAGS.dtype = 'fp32'
|
149 |
+
FLAGS.enable_eager = True
|
150 |
+
FLAGS.run_eagerly = True
|
151 |
+
FLAGS.distribution_strategy = 'off'
|
152 |
+
self._run_and_report_benchmark()
|
153 |
+
|
154 |
+
def benchmark_graph_1_gpu_no_dist_strat(self):
|
155 |
+
"""Test keras based model with Keras fit but not distribution strategies."""
|
156 |
+
self._setup()
|
157 |
+
FLAGS.distribution_strategy = 'off'
|
158 |
+
FLAGS.num_gpus = 1
|
159 |
+
FLAGS.data_dir = self.data_dir
|
160 |
+
FLAGS.batch_size = 128
|
161 |
+
FLAGS.train_epochs = 182
|
162 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu_no_dist_strat')
|
163 |
+
FLAGS.dtype = 'fp32'
|
164 |
+
self._run_and_report_benchmark()
|
165 |
+
|
166 |
+
def benchmark_2_gpu(self):
|
167 |
+
"""Test keras based model with eager and distribution strategies."""
|
168 |
+
self._setup()
|
169 |
+
FLAGS.num_gpus = 2
|
170 |
+
FLAGS.data_dir = self.data_dir
|
171 |
+
FLAGS.batch_size = 128
|
172 |
+
FLAGS.train_epochs = 182
|
173 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu')
|
174 |
+
FLAGS.dtype = 'fp32'
|
175 |
+
FLAGS.enable_eager = True
|
176 |
+
self._run_and_report_benchmark()
|
177 |
+
|
178 |
+
def benchmark_graph_2_gpu(self):
|
179 |
+
"""Test keras based model with Keras fit and distribution strategies."""
|
180 |
+
self._setup()
|
181 |
+
FLAGS.num_gpus = 2
|
182 |
+
FLAGS.data_dir = self.data_dir
|
183 |
+
FLAGS.batch_size = 128
|
184 |
+
FLAGS.train_epochs = 182
|
185 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
|
186 |
+
FLAGS.dtype = 'fp32'
|
187 |
+
self._run_and_report_benchmark()
|
188 |
+
|
189 |
+
@benchmark_wrappers.enable_runtime_flags
|
190 |
+
def _run_and_report_benchmark(self):
|
191 |
+
start_time_sec = time.time()
|
192 |
+
stats = resnet_cifar_main.run(FLAGS)
|
193 |
+
wall_time_sec = time.time() - start_time_sec
|
194 |
+
|
195 |
+
super(Resnet56KerasAccuracy, self)._report_benchmark(
|
196 |
+
stats,
|
197 |
+
wall_time_sec,
|
198 |
+
top_1_min=MIN_TOP_1_ACCURACY,
|
199 |
+
top_1_max=MAX_TOP_1_ACCURACY,
|
200 |
+
total_batch_size=FLAGS.batch_size,
|
201 |
+
log_steps=100)
|
202 |
+
|
203 |
+
|
204 |
+
class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
|
205 |
+
"""Short performance tests for ResNet56 via Keras and CIFAR-10."""
|
206 |
+
|
207 |
+
def __init__(self, output_dir=None, default_flags=None):
|
208 |
+
flag_methods = [resnet_cifar_main.define_cifar_flags]
|
209 |
+
|
210 |
+
super(Resnet56KerasBenchmarkBase, self).__init__(
|
211 |
+
output_dir=output_dir,
|
212 |
+
flag_methods=flag_methods,
|
213 |
+
default_flags=default_flags)
|
214 |
+
|
215 |
+
@benchmark_wrappers.enable_runtime_flags
|
216 |
+
def _run_and_report_benchmark(self):
|
217 |
+
start_time_sec = time.time()
|
218 |
+
stats = resnet_cifar_main.run(FLAGS)
|
219 |
+
wall_time_sec = time.time() - start_time_sec
|
220 |
+
|
221 |
+
super(Resnet56KerasBenchmarkBase, self)._report_benchmark(
|
222 |
+
stats,
|
223 |
+
wall_time_sec,
|
224 |
+
total_batch_size=FLAGS.batch_size,
|
225 |
+
log_steps=FLAGS.log_steps)
|
226 |
+
|
227 |
+
def benchmark_1_gpu(self):
|
228 |
+
"""Test 1 gpu."""
|
229 |
+
self._setup()
|
230 |
+
FLAGS.num_gpus = 1
|
231 |
+
FLAGS.enable_eager = True
|
232 |
+
FLAGS.distribution_strategy = 'one_device'
|
233 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
|
234 |
+
FLAGS.batch_size = 128
|
235 |
+
self._run_and_report_benchmark()
|
236 |
+
|
237 |
+
def benchmark_1_gpu_xla(self):
|
238 |
+
"""Test 1 gpu with xla enabled."""
|
239 |
+
self._setup()
|
240 |
+
FLAGS.num_gpus = 1
|
241 |
+
FLAGS.enable_eager = True
|
242 |
+
FLAGS.run_eagerly = False
|
243 |
+
FLAGS.enable_xla = True
|
244 |
+
FLAGS.distribution_strategy = 'one_device'
|
245 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla')
|
246 |
+
FLAGS.batch_size = 128
|
247 |
+
self._run_and_report_benchmark()
|
248 |
+
|
249 |
+
def benchmark_graph_1_gpu(self):
|
250 |
+
"""Test 1 gpu graph."""
|
251 |
+
self._setup()
|
252 |
+
FLAGS.num_gpus = 1
|
253 |
+
FLAGS.enable_eager = False
|
254 |
+
FLAGS.run_eagerly = False
|
255 |
+
FLAGS.distribution_strategy = 'one_device'
|
256 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
|
257 |
+
FLAGS.batch_size = 128
|
258 |
+
self._run_and_report_benchmark()
|
259 |
+
|
260 |
+
def benchmark_1_gpu_no_dist_strat(self):
|
261 |
+
"""Test 1 gpu without distribution strategies."""
|
262 |
+
self._setup()
|
263 |
+
FLAGS.num_gpus = 1
|
264 |
+
FLAGS.enable_eager = True
|
265 |
+
FLAGS.distribution_strategy = 'off'
|
266 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
|
267 |
+
FLAGS.batch_size = 128
|
268 |
+
self._run_and_report_benchmark()
|
269 |
+
|
270 |
+
def benchmark_graph_1_gpu_no_dist_strat(self):
|
271 |
+
"""Test 1 gpu graph mode without distribution strategies."""
|
272 |
+
self._setup()
|
273 |
+
FLAGS.num_gpus = 1
|
274 |
+
FLAGS.enable_eager = False
|
275 |
+
FLAGS.distribution_strategy = 'off'
|
276 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu_no_dist_strat')
|
277 |
+
FLAGS.batch_size = 128
|
278 |
+
self._run_and_report_benchmark()
|
279 |
+
|
280 |
+
def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
|
281 |
+
"""Test 1 gpu without distribution strategy and forced eager."""
|
282 |
+
self._setup()
|
283 |
+
FLAGS.num_gpus = 1
|
284 |
+
FLAGS.batch_size = 128
|
285 |
+
FLAGS.model_dir = self._get_model_dir(
|
286 |
+
'benchmark_1_gpu_no_dist_strat_run_eagerly')
|
287 |
+
FLAGS.dtype = 'fp32'
|
288 |
+
FLAGS.enable_eager = True
|
289 |
+
FLAGS.run_eagerly = True
|
290 |
+
FLAGS.distribution_strategy = 'off'
|
291 |
+
self._run_and_report_benchmark()
|
292 |
+
|
293 |
+
def benchmark_2_gpu(self):
|
294 |
+
"""Test 2 gpu."""
|
295 |
+
self._setup()
|
296 |
+
FLAGS.num_gpus = 2
|
297 |
+
FLAGS.enable_eager = True
|
298 |
+
FLAGS.run_eagerly = False
|
299 |
+
FLAGS.distribution_strategy = 'mirrored'
|
300 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu')
|
301 |
+
FLAGS.batch_size = 128 * 2 # 2 GPUs
|
302 |
+
self._run_and_report_benchmark()
|
303 |
+
|
304 |
+
def benchmark_graph_2_gpu(self):
|
305 |
+
"""Test 2 gpu graph mode."""
|
306 |
+
self._setup()
|
307 |
+
FLAGS.num_gpus = 2
|
308 |
+
FLAGS.enable_eager = False
|
309 |
+
FLAGS.run_eagerly = False
|
310 |
+
FLAGS.distribution_strategy = 'mirrored'
|
311 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
|
312 |
+
FLAGS.batch_size = 128 * 2 # 2 GPUs
|
313 |
+
self._run_and_report_benchmark()
|
314 |
+
|
315 |
+
def benchmark_cpu(self):
|
316 |
+
"""Test cpu."""
|
317 |
+
self._setup()
|
318 |
+
FLAGS.num_gpus = 0
|
319 |
+
FLAGS.enable_eager = True
|
320 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_cpu')
|
321 |
+
FLAGS.batch_size = 128
|
322 |
+
FLAGS.data_format = 'channels_last'
|
323 |
+
self._run_and_report_benchmark()
|
324 |
+
|
325 |
+
def benchmark_graph_cpu(self):
|
326 |
+
"""Test cpu graph mode."""
|
327 |
+
self._setup()
|
328 |
+
FLAGS.num_gpus = 0
|
329 |
+
FLAGS.enable_eager = False
|
330 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_graph_cpu')
|
331 |
+
FLAGS.batch_size = 128
|
332 |
+
FLAGS.data_format = 'channels_last'
|
333 |
+
self._run_and_report_benchmark()
|
334 |
+
|
335 |
+
def benchmark_cpu_no_dist_strat_run_eagerly(self):
|
336 |
+
"""Test cpu without distribution strategy and forced eager."""
|
337 |
+
self._setup()
|
338 |
+
FLAGS.num_gpus = 0
|
339 |
+
FLAGS.distribution_strategy = 'off'
|
340 |
+
FLAGS.enable_eager = True
|
341 |
+
FLAGS.run_eagerly = True
|
342 |
+
FLAGS.model_dir = self._get_model_dir(
|
343 |
+
'benchmark_cpu_no_dist_strat_run_eagerly')
|
344 |
+
FLAGS.batch_size = 128
|
345 |
+
FLAGS.data_format = 'channels_last'
|
346 |
+
self._run_and_report_benchmark()
|
347 |
+
|
348 |
+
def benchmark_cpu_no_dist_strat(self):
|
349 |
+
"""Test cpu without distribution strategies."""
|
350 |
+
self._setup()
|
351 |
+
FLAGS.num_gpus = 0
|
352 |
+
FLAGS.enable_eager = True
|
353 |
+
FLAGS.distribution_strategy = 'off'
|
354 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_cpu_no_dist_strat')
|
355 |
+
FLAGS.batch_size = 128
|
356 |
+
FLAGS.data_format = 'channels_last'
|
357 |
+
self._run_and_report_benchmark()
|
358 |
+
|
359 |
+
def benchmark_graph_cpu_no_dist_strat(self):
|
360 |
+
"""Test cpu graph mode without distribution strategies."""
|
361 |
+
self._setup()
|
362 |
+
FLAGS.num_gpus = 0
|
363 |
+
FLAGS.enable_eager = False
|
364 |
+
FLAGS.distribution_strategy = 'off'
|
365 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_graph_cpu_no_dist_strat')
|
366 |
+
FLAGS.batch_size = 128
|
367 |
+
FLAGS.data_format = 'channels_last'
|
368 |
+
self._run_and_report_benchmark()
|
369 |
+
|
370 |
+
|
371 |
+
class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
|
372 |
+
"""Synthetic benchmarks for ResNet56 and Keras."""
|
373 |
+
|
374 |
+
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
375 |
+
default_flags = {}
|
376 |
+
default_flags['skip_eval'] = True
|
377 |
+
default_flags['use_synthetic_data'] = True
|
378 |
+
default_flags['train_steps'] = 110
|
379 |
+
default_flags['log_steps'] = 10
|
380 |
+
default_flags['use_tensor_lr'] = False
|
381 |
+
|
382 |
+
super(Resnet56KerasBenchmarkSynth, self).__init__(
|
383 |
+
output_dir=output_dir, default_flags=default_flags)
|
384 |
+
|
385 |
+
|
386 |
+
class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase):
|
387 |
+
"""Real data benchmarks for ResNet56 and Keras."""
|
388 |
+
|
389 |
+
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
390 |
+
default_flags = {}
|
391 |
+
default_flags['skip_eval'] = True
|
392 |
+
default_flags['data_dir'] = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
|
393 |
+
default_flags['train_steps'] = 110
|
394 |
+
default_flags['log_steps'] = 10
|
395 |
+
default_flags['use_tensor_lr'] = False
|
396 |
+
|
397 |
+
super(Resnet56KerasBenchmarkReal, self).__init__(
|
398 |
+
output_dir=output_dir, default_flags=default_flags)
|
399 |
+
|
400 |
+
|
401 |
+
if __name__ == '__main__':
|
402 |
+
tf.test.main()
|
models/official/benchmark/keras_imagenet_benchmark.py
ADDED
@@ -0,0 +1,1724 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Lint as: python3
|
2 |
+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ==============================================================================
|
16 |
+
"""Executes Keras benchmarks and accuracy tests."""
|
17 |
+
# pylint: disable=line-too-long
|
18 |
+
from __future__ import print_function
|
19 |
+
|
20 |
+
import json
|
21 |
+
import os
|
22 |
+
import time
|
23 |
+
|
24 |
+
from typing import Any, MutableMapping, Optional
|
25 |
+
|
26 |
+
from absl import flags
|
27 |
+
import tensorflow as tf # pylint: disable=g-bad-import-order
|
28 |
+
|
29 |
+
from official.benchmark import benchmark_wrappers
|
30 |
+
from official.benchmark import keras_benchmark
|
31 |
+
from official.benchmark.models import resnet_imagenet_main
|
32 |
+
from official.vision.image_classification import classifier_trainer
|
33 |
+
|
34 |
+
MIN_TOP_1_ACCURACY = 0.76
|
35 |
+
MAX_TOP_1_ACCURACY = 0.77
|
36 |
+
|
37 |
+
MOBILENET_V1_MIN_TOP_1_ACCURACY = 0.65
|
38 |
+
MOBILENET_V1_MAX_TOP_1_ACCURACY = 0.68
|
39 |
+
|
40 |
+
# Range of top-1 accracies for model optimization techniques.
|
41 |
+
# Each item indicates (MIN_TOP_1_ACCURACY, MAX_TOP_1_ACCURACY).
|
42 |
+
MODEL_OPTIMIZATION_TOP_1_ACCURACY = {
|
43 |
+
'RESNET50_FINETUNE_PRUNING': (0.76, 0.77),
|
44 |
+
'MOBILENET_V1_FINETUNE_PRUNING': (0.67, 0.68),
|
45 |
+
}
|
46 |
+
|
47 |
+
FLAGS = flags.FLAGS
|
48 |
+
|
49 |
+
|
50 |
+
def _get_classifier_parameters(
|
51 |
+
num_gpus: int = 0,
|
52 |
+
builder: str = 'records',
|
53 |
+
skip_eval: bool = False,
|
54 |
+
distribution_strategy: str = 'mirrored',
|
55 |
+
per_replica_batch_size: int = 128,
|
56 |
+
epochs: int = 90,
|
57 |
+
steps: int = 0,
|
58 |
+
epochs_between_evals: int = 1,
|
59 |
+
dtype: str = 'float32',
|
60 |
+
enable_xla: bool = False,
|
61 |
+
run_eagerly: bool = False,
|
62 |
+
gpu_thread_mode: Optional[str] = None,
|
63 |
+
dataset_num_private_threads: Optional[int] = None,
|
64 |
+
loss_scale: Optional[str] = None,
|
65 |
+
report_metrics: bool = True,
|
66 |
+
batchnorm_spatial_persistent: bool = False) -> MutableMapping[str, Any]:
|
67 |
+
"""Gets classifier trainer's ResNet parameters."""
|
68 |
+
return {
|
69 |
+
'runtime': {
|
70 |
+
'num_gpus': num_gpus,
|
71 |
+
'distribution_strategy': distribution_strategy,
|
72 |
+
'run_eagerly': run_eagerly,
|
73 |
+
'enable_xla': enable_xla,
|
74 |
+
'dataset_num_private_threads': dataset_num_private_threads,
|
75 |
+
'gpu_thread_mode': gpu_thread_mode,
|
76 |
+
'loss_scale': loss_scale,
|
77 |
+
'batchnorm_spatial_persistent': batchnorm_spatial_persistent,
|
78 |
+
},
|
79 |
+
'train_dataset': {
|
80 |
+
'builder': builder,
|
81 |
+
'use_per_replica_batch_size': True,
|
82 |
+
'batch_size': per_replica_batch_size,
|
83 |
+
'image_size': 224,
|
84 |
+
'dtype': dtype,
|
85 |
+
},
|
86 |
+
'validation_dataset': {
|
87 |
+
'builder': builder,
|
88 |
+
'batch_size': per_replica_batch_size,
|
89 |
+
'use_per_replica_batch_size': True,
|
90 |
+
'image_size': 224,
|
91 |
+
'dtype': dtype,
|
92 |
+
},
|
93 |
+
'train': {
|
94 |
+
'epochs': epochs,
|
95 |
+
'steps': steps,
|
96 |
+
'callbacks': {
|
97 |
+
'enable_tensorboard': False,
|
98 |
+
'enable_checkpoint_and_export': False,
|
99 |
+
'enable_time_history': True,
|
100 |
+
},
|
101 |
+
'metrics': ['accuracy'] if report_metrics else [],
|
102 |
+
},
|
103 |
+
'model': {
|
104 |
+
'loss': {
|
105 |
+
'label_smoothing': 0.1,
|
106 |
+
},
|
107 |
+
},
|
108 |
+
'evaluation': {
|
109 |
+
'epochs_between_evals': epochs_between_evals,
|
110 |
+
'skip_eval': skip_eval,
|
111 |
+
},
|
112 |
+
}
|
113 |
+
|
114 |
+
|
115 |
+
class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
|
116 |
+
"""Benchmark accuracy tests for ResNet50 in Keras."""
|
117 |
+
|
118 |
+
def __init__(self,
|
119 |
+
output_dir: Optional[str] = None,
|
120 |
+
root_data_dir: Optional[str] = None,
|
121 |
+
**kwargs):
|
122 |
+
"""A benchmark class.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
output_dir: directory where to output e.g. log files
|
126 |
+
root_data_dir: directory under which to look for dataset
|
127 |
+
**kwargs: arbitrary named arguments. This is needed to make the
|
128 |
+
constructor forward compatible in case PerfZero provides more
|
129 |
+
named arguments before updating the constructor.
|
130 |
+
"""
|
131 |
+
|
132 |
+
flag_methods = [classifier_trainer.define_classifier_flags]
|
133 |
+
|
134 |
+
self.data_dir = os.path.join(root_data_dir, 'imagenet')
|
135 |
+
super(Resnet50KerasAccuracy, self).__init__(
|
136 |
+
output_dir=output_dir, flag_methods=flag_methods)
|
137 |
+
|
138 |
+
@benchmark_wrappers.enable_runtime_flags
|
139 |
+
def _run_and_report_benchmark(
|
140 |
+
self,
|
141 |
+
experiment_name: str,
|
142 |
+
top_1_min: float = MIN_TOP_1_ACCURACY,
|
143 |
+
top_1_max: float = MAX_TOP_1_ACCURACY,
|
144 |
+
num_gpus: int = 0,
|
145 |
+
distribution_strategy: str = 'mirrored',
|
146 |
+
per_replica_batch_size: int = 128,
|
147 |
+
epochs: int = 90,
|
148 |
+
steps: int = 0,
|
149 |
+
epochs_between_evals: int = 1,
|
150 |
+
dtype: str = 'float32',
|
151 |
+
enable_xla: bool = False,
|
152 |
+
run_eagerly: bool = False,
|
153 |
+
gpu_thread_mode: Optional[str] = None,
|
154 |
+
dataset_num_private_threads: Optional[int] = None,
|
155 |
+
loss_scale: Optional[str] = None):
|
156 |
+
"""Runs and reports the benchmark given the provided configuration."""
|
157 |
+
FLAGS.model_type = 'resnet'
|
158 |
+
FLAGS.dataset = 'imagenet'
|
159 |
+
FLAGS.mode = 'train_and_eval'
|
160 |
+
FLAGS.data_dir = self.data_dir
|
161 |
+
FLAGS.model_dir = self._get_model_dir(experiment_name)
|
162 |
+
parameters = _get_classifier_parameters(
|
163 |
+
num_gpus=num_gpus,
|
164 |
+
distribution_strategy=distribution_strategy,
|
165 |
+
per_replica_batch_size=per_replica_batch_size,
|
166 |
+
epochs=epochs,
|
167 |
+
steps=steps,
|
168 |
+
epochs_between_evals=epochs_between_evals,
|
169 |
+
dtype=dtype,
|
170 |
+
enable_xla=enable_xla,
|
171 |
+
run_eagerly=run_eagerly,
|
172 |
+
gpu_thread_mode=gpu_thread_mode,
|
173 |
+
dataset_num_private_threads=dataset_num_private_threads,
|
174 |
+
report_metrics=True,
|
175 |
+
loss_scale=loss_scale,
|
176 |
+
batchnorm_spatial_persistent=True)
|
177 |
+
FLAGS.params_override = json.dumps(parameters)
|
178 |
+
total_batch_size = num_gpus * per_replica_batch_size
|
179 |
+
|
180 |
+
start_time_sec = time.time()
|
181 |
+
stats = classifier_trainer.run(flags.FLAGS)
|
182 |
+
wall_time_sec = time.time() - start_time_sec
|
183 |
+
|
184 |
+
super(Resnet50KerasAccuracy, self)._report_benchmark(
|
185 |
+
stats,
|
186 |
+
wall_time_sec,
|
187 |
+
top_1_min=top_1_min,
|
188 |
+
top_1_max=top_1_max,
|
189 |
+
total_batch_size=total_batch_size,
|
190 |
+
log_steps=100)
|
191 |
+
|
192 |
+
def benchmark_8_gpu(self):
|
193 |
+
"""Tests Keras model with eager, dist_strat and 8 GPUs."""
|
194 |
+
self._setup()
|
195 |
+
self._run_and_report_benchmark(
|
196 |
+
experiment_name='benchmark_8_gpu',
|
197 |
+
num_gpus=8,
|
198 |
+
per_replica_batch_size=128,
|
199 |
+
epochs=90,
|
200 |
+
epochs_between_evals=10,
|
201 |
+
dtype='float32')
|
202 |
+
|
203 |
+
def benchmark_8_gpu_fp16(self):
|
204 |
+
"""Tests Keras model with eager, dist_strat, 8 GPUs, and fp16."""
|
205 |
+
self._setup()
|
206 |
+
self._run_and_report_benchmark(
|
207 |
+
experiment_name='benchmark_8_gpu_fp16',
|
208 |
+
num_gpus=8,
|
209 |
+
per_replica_batch_size=256,
|
210 |
+
epochs=90,
|
211 |
+
epochs_between_evals=10,
|
212 |
+
dtype='float16')
|
213 |
+
|
214 |
+
def benchmark_xla_8_gpu_fp16(self):
|
215 |
+
"""Tests Keras model with XLA, eager, dist_strat, 8 GPUs and fp16."""
|
216 |
+
self._setup()
|
217 |
+
self._run_and_report_benchmark(
|
218 |
+
experiment_name='benchmark_xla_8_gpu_fp16',
|
219 |
+
num_gpus=8,
|
220 |
+
per_replica_batch_size=256,
|
221 |
+
epochs=90,
|
222 |
+
epochs_between_evals=10,
|
223 |
+
dtype='float16',
|
224 |
+
enable_xla=True)
|
225 |
+
|
226 |
+
def benchmark_xla_8_gpu_fp16_dynamic(self):
|
227 |
+
"""Tests Keras model with XLA, eager, dist_strat, 8 GPUs, dynamic fp16."""
|
228 |
+
self._setup()
|
229 |
+
self._run_and_report_benchmark(
|
230 |
+
experiment_name='benchmark_xla_8_gpu_fp16_dynamic',
|
231 |
+
top_1_min=0.736,
|
232 |
+
num_gpus=8,
|
233 |
+
per_replica_batch_size=256,
|
234 |
+
epochs=90,
|
235 |
+
epochs_between_evals=10,
|
236 |
+
dtype='float16',
|
237 |
+
loss_scale='dynamic')
|
238 |
+
|
239 |
+
def _get_model_dir(self, folder_name):
|
240 |
+
return os.path.join(self.output_dir, folder_name)
|
241 |
+
|
242 |
+
|
243 |
+
class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark):
|
244 |
+
"""Benchmark accuracy tests for MobilenetV1 in Keras."""
|
245 |
+
|
246 |
+
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
247 |
+
"""A benchmark class.
|
248 |
+
|
249 |
+
Args:
|
250 |
+
output_dir: directory where to output e.g. log files
|
251 |
+
root_data_dir: directory under which to look for dataset
|
252 |
+
**kwargs: arbitrary named arguments. This is needed to make the
|
253 |
+
constructor forward compatible in case PerfZero provides more
|
254 |
+
named arguments before updating the constructor.
|
255 |
+
"""
|
256 |
+
|
257 |
+
flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
|
258 |
+
|
259 |
+
self.data_dir = os.path.join(root_data_dir, 'imagenet')
|
260 |
+
super(MobilenetV1KerasAccuracy, self).__init__(
|
261 |
+
output_dir=output_dir,
|
262 |
+
flag_methods=flag_methods,
|
263 |
+
default_flags={
|
264 |
+
'model': 'mobilenet',
|
265 |
+
'optimizer': 'mobilenet_default',
|
266 |
+
'initial_learning_rate_per_sample': 0.00039,
|
267 |
+
})
|
268 |
+
|
269 |
+
def benchmark_8_gpu(self):
|
270 |
+
"""Test Keras model with eager, dist_strat and 8 GPUs."""
|
271 |
+
self._setup()
|
272 |
+
FLAGS.num_gpus = 8
|
273 |
+
FLAGS.data_dir = self.data_dir
|
274 |
+
FLAGS.batch_size = 128 * 8
|
275 |
+
FLAGS.train_epochs = 90
|
276 |
+
FLAGS.epochs_between_evals = 10
|
277 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
|
278 |
+
FLAGS.dtype = 'fp32'
|
279 |
+
FLAGS.enable_eager = True
|
280 |
+
self._run_and_report_benchmark()
|
281 |
+
|
282 |
+
@benchmark_wrappers.enable_runtime_flags
|
283 |
+
def _run_and_report_benchmark(self,
|
284 |
+
top_1_min=MOBILENET_V1_MIN_TOP_1_ACCURACY,
|
285 |
+
top_1_max=MOBILENET_V1_MAX_TOP_1_ACCURACY):
|
286 |
+
start_time_sec = time.time()
|
287 |
+
stats = resnet_imagenet_main.run(flags.FLAGS)
|
288 |
+
wall_time_sec = time.time() - start_time_sec
|
289 |
+
|
290 |
+
super(MobilenetV1KerasAccuracy, self)._report_benchmark(
|
291 |
+
stats,
|
292 |
+
wall_time_sec,
|
293 |
+
top_1_min=top_1_min,
|
294 |
+
top_1_max=top_1_max,
|
295 |
+
total_batch_size=FLAGS.batch_size,
|
296 |
+
log_steps=100)
|
297 |
+
|
298 |
+
def _get_model_dir(self, folder_name):
|
299 |
+
return os.path.join(self.output_dir, folder_name)
|
300 |
+
|
301 |
+
|
302 |
+
class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
|
303 |
+
"""Resnet50 (classifier_trainer) benchmarks."""
|
304 |
+
|
305 |
+
def __init__(self, output_dir=None, default_flags=None,
|
306 |
+
tpu=None, dataset_builder='records', train_epochs=1,
|
307 |
+
train_steps=110, data_dir=None):
|
308 |
+
flag_methods = [classifier_trainer.define_classifier_flags]
|
309 |
+
|
310 |
+
self.dataset_builder = dataset_builder
|
311 |
+
self.train_epochs = train_epochs
|
312 |
+
self.train_steps = train_steps
|
313 |
+
self.data_dir = data_dir
|
314 |
+
|
315 |
+
super(Resnet50KerasClassifierBenchmarkBase, self).__init__(
|
316 |
+
output_dir=output_dir,
|
317 |
+
flag_methods=flag_methods,
|
318 |
+
default_flags=default_flags,
|
319 |
+
tpu=tpu)
|
320 |
+
|
321 |
+
@benchmark_wrappers.enable_runtime_flags
|
322 |
+
def _run_and_report_benchmark(
|
323 |
+
self,
|
324 |
+
experiment_name: str,
|
325 |
+
skip_steps: Optional[int] = None,
|
326 |
+
top_1_min: float = MIN_TOP_1_ACCURACY,
|
327 |
+
top_1_max: float = MAX_TOP_1_ACCURACY,
|
328 |
+
num_gpus: int = 0,
|
329 |
+
num_tpus: int = 0,
|
330 |
+
distribution_strategy: str = 'mirrored',
|
331 |
+
per_replica_batch_size: int = 128,
|
332 |
+
epochs_between_evals: int = 1,
|
333 |
+
dtype: str = 'float32',
|
334 |
+
enable_xla: bool = False,
|
335 |
+
run_eagerly: bool = False,
|
336 |
+
gpu_thread_mode: Optional[str] = None,
|
337 |
+
dataset_num_private_threads: Optional[int] = None,
|
338 |
+
loss_scale: Optional[str] = None):
|
339 |
+
"""Runs and reports the benchmark given the provided configuration."""
|
340 |
+
FLAGS.model_type = 'resnet'
|
341 |
+
FLAGS.dataset = 'imagenet'
|
342 |
+
FLAGS.mode = 'train_and_eval'
|
343 |
+
FLAGS.data_dir = self.data_dir
|
344 |
+
FLAGS.model_dir = self._get_model_dir(experiment_name)
|
345 |
+
parameters = _get_classifier_parameters(
|
346 |
+
builder=self.dataset_builder,
|
347 |
+
skip_eval=True,
|
348 |
+
num_gpus=num_gpus,
|
349 |
+
distribution_strategy=distribution_strategy,
|
350 |
+
per_replica_batch_size=per_replica_batch_size,
|
351 |
+
epochs=self.train_epochs,
|
352 |
+
steps=self.train_steps,
|
353 |
+
epochs_between_evals=epochs_between_evals,
|
354 |
+
dtype=dtype,
|
355 |
+
enable_xla=enable_xla,
|
356 |
+
gpu_thread_mode=gpu_thread_mode,
|
357 |
+
dataset_num_private_threads=dataset_num_private_threads,
|
358 |
+
loss_scale=loss_scale,
|
359 |
+
report_metrics=False,
|
360 |
+
batchnorm_spatial_persistent=True)
|
361 |
+
FLAGS.params_override = json.dumps(parameters)
|
362 |
+
if distribution_strategy == 'tpu':
|
363 |
+
total_batch_size = num_tpus * per_replica_batch_size
|
364 |
+
else:
|
365 |
+
total_batch_size = num_gpus * per_replica_batch_size
|
366 |
+
|
367 |
+
start_time_sec = time.time()
|
368 |
+
stats = classifier_trainer.run(flags.FLAGS)
|
369 |
+
wall_time_sec = time.time() - start_time_sec
|
370 |
+
# Number of logged step time entries that are excluded in performance
|
371 |
+
# report. We keep results from last 100 batches, or skip the steps based on
|
372 |
+
# input skip_steps.
|
373 |
+
warmup = (skip_steps or (self.train_steps - 100)) // FLAGS.log_steps
|
374 |
+
|
375 |
+
super(Resnet50KerasClassifierBenchmarkBase, self)._report_benchmark(
|
376 |
+
stats,
|
377 |
+
wall_time_sec,
|
378 |
+
total_batch_size=total_batch_size,
|
379 |
+
log_steps=FLAGS.log_steps,
|
380 |
+
warmup=warmup,
|
381 |
+
start_time_sec=start_time_sec)
|
382 |
+
|
383 |
+
def benchmark_1_gpu_no_dist_strat(self):
|
384 |
+
"""Tests Keras model with 1 GPU, no distribution strategy."""
|
385 |
+
self._setup()
|
386 |
+
self._run_and_report_benchmark(
|
387 |
+
experiment_name='benchmark_1_gpu_no_dist_strat',
|
388 |
+
num_gpus=1,
|
389 |
+
distribution_strategy='off',
|
390 |
+
per_replica_batch_size=128)
|
391 |
+
|
392 |
+
def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
|
393 |
+
"""Tests Keras model with 1 GPU, no distribution strategy, run eagerly."""
|
394 |
+
self._setup()
|
395 |
+
self._run_and_report_benchmark(
|
396 |
+
experiment_name='benchmark_1_gpu_no_dist_strat_run_eagerly',
|
397 |
+
num_gpus=1,
|
398 |
+
run_eagerly=True,
|
399 |
+
distribution_strategy='off',
|
400 |
+
per_replica_batch_size=64)
|
401 |
+
|
402 |
+
def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16(self):
|
403 |
+
"""Tests with 1 GPU, no distribution strategy, fp16, run eagerly."""
|
404 |
+
self._setup()
|
405 |
+
self._run_and_report_benchmark(
|
406 |
+
experiment_name='benchmark_1_gpu_no_dist_strat_run_eagerly_fp16',
|
407 |
+
num_gpus=1,
|
408 |
+
run_eagerly=True,
|
409 |
+
distribution_strategy='off',
|
410 |
+
dtype='float16',
|
411 |
+
per_replica_batch_size=128)
|
412 |
+
|
413 |
+
def benchmark_1_gpu(self):
|
414 |
+
"""Tests Keras model with 1 GPU."""
|
415 |
+
self._setup()
|
416 |
+
self._run_and_report_benchmark(
|
417 |
+
experiment_name='benchmark_1_gpu',
|
418 |
+
num_gpus=1,
|
419 |
+
distribution_strategy='one_device',
|
420 |
+
per_replica_batch_size=128)
|
421 |
+
|
422 |
+
def benchmark_xla_1_gpu(self):
|
423 |
+
"""Tests Keras model with XLA and 1 GPU."""
|
424 |
+
self._setup()
|
425 |
+
self._run_and_report_benchmark(
|
426 |
+
experiment_name='benchmark_xla_1_gpu',
|
427 |
+
num_gpus=1,
|
428 |
+
enable_xla=True,
|
429 |
+
distribution_strategy='one_device',
|
430 |
+
per_replica_batch_size=128)
|
431 |
+
|
432 |
+
def benchmark_1_gpu_fp16(self):
|
433 |
+
"""Tests Keras model with 1 GPU and fp16."""
|
434 |
+
self._setup()
|
435 |
+
self._run_and_report_benchmark(
|
436 |
+
experiment_name='benchmark_1_gpu_fp16',
|
437 |
+
num_gpus=1,
|
438 |
+
distribution_strategy='one_device',
|
439 |
+
dtype='float16',
|
440 |
+
per_replica_batch_size=256)
|
441 |
+
|
442 |
+
def benchmark_1_gpu_fp16_dynamic(self):
|
443 |
+
"""Tests Keras model with 1 GPU, fp16, and dynamic loss scaling."""
|
444 |
+
self._setup()
|
445 |
+
self._run_and_report_benchmark(
|
446 |
+
experiment_name='benchmark_1_gpu_fp16_dynamic',
|
447 |
+
num_gpus=1,
|
448 |
+
distribution_strategy='one_device',
|
449 |
+
dtype='float16',
|
450 |
+
per_replica_batch_size=256,
|
451 |
+
loss_scale='dynamic')
|
452 |
+
|
453 |
+
def benchmark_xla_1_gpu_fp16(self):
|
454 |
+
"""Tests Keras model with XLA, 1 GPU and fp16."""
|
455 |
+
self._setup()
|
456 |
+
self._run_and_report_benchmark(
|
457 |
+
experiment_name='benchmark_xla_1_gpu_fp16',
|
458 |
+
num_gpus=1,
|
459 |
+
enable_xla=True,
|
460 |
+
distribution_strategy='one_device',
|
461 |
+
dtype='float16',
|
462 |
+
per_replica_batch_size=256)
|
463 |
+
|
464 |
+
def benchmark_xla_1_gpu_fp16_tweaked(self):
|
465 |
+
"""Tests Keras model with XLA, 1 GPU, fp16, and manual config tuning."""
|
466 |
+
self._setup()
|
467 |
+
self._run_and_report_benchmark(
|
468 |
+
experiment_name='benchmark_xla_1_gpu_fp16_tweaked',
|
469 |
+
num_gpus=1,
|
470 |
+
enable_xla=True,
|
471 |
+
distribution_strategy='one_device',
|
472 |
+
dtype='float16',
|
473 |
+
per_replica_batch_size=256,
|
474 |
+
gpu_thread_mode='gpu_private')
|
475 |
+
|
476 |
+
def benchmark_xla_1_gpu_fp16_dynamic(self):
|
477 |
+
"""Tests Keras model with XLA, 1 GPU, fp16, and dynamic loss scaling."""
|
478 |
+
self._setup()
|
479 |
+
self._run_and_report_benchmark(
|
480 |
+
experiment_name='benchmark_xla_1_gpu_fp16_dynamic',
|
481 |
+
num_gpus=1,
|
482 |
+
enable_xla=True,
|
483 |
+
distribution_strategy='one_device',
|
484 |
+
dtype='float16',
|
485 |
+
per_replica_batch_size=256,
|
486 |
+
loss_scale='dynamic')
|
487 |
+
|
488 |
+
def benchmark_8_gpu(self):
|
489 |
+
"""Tests Keras model with 8 GPUs."""
|
490 |
+
self._setup()
|
491 |
+
self._run_and_report_benchmark(
|
492 |
+
experiment_name='benchmark_8_gpu',
|
493 |
+
num_gpus=8,
|
494 |
+
distribution_strategy='mirrored',
|
495 |
+
per_replica_batch_size=128)
|
496 |
+
|
497 |
+
def benchmark_8_gpu_tweaked(self):
|
498 |
+
"""Tests Keras model with manual config tuning and 8 GPUs."""
|
499 |
+
self._setup()
|
500 |
+
self._run_and_report_benchmark(
|
501 |
+
experiment_name='benchmark_8_gpu_tweaked',
|
502 |
+
num_gpus=8,
|
503 |
+
distribution_strategy='mirrored',
|
504 |
+
per_replica_batch_size=128,
|
505 |
+
dataset_num_private_threads=14)
|
506 |
+
|
507 |
+
def benchmark_xla_8_gpu(self):
|
508 |
+
"""Tests Keras model with XLA and 8 GPUs."""
|
509 |
+
self._setup()
|
510 |
+
self._run_and_report_benchmark(
|
511 |
+
experiment_name='benchmark_xla_8_gpu',
|
512 |
+
num_gpus=8,
|
513 |
+
enable_xla=True,
|
514 |
+
distribution_strategy='mirrored',
|
515 |
+
per_replica_batch_size=128)
|
516 |
+
|
517 |
+
def benchmark_xla_8_gpu_tweaked(self):
|
518 |
+
"""Tests Keras model with manual config tuning, 8 GPUs, and XLA."""
|
519 |
+
self._setup()
|
520 |
+
self._run_and_report_benchmark(
|
521 |
+
experiment_name='benchmark_xla_8_gpu_tweaked',
|
522 |
+
num_gpus=8,
|
523 |
+
enable_xla=True,
|
524 |
+
distribution_strategy='mirrored',
|
525 |
+
per_replica_batch_size=128,
|
526 |
+
gpu_thread_mode='gpu_private',
|
527 |
+
dataset_num_private_threads=24)
|
528 |
+
|
529 |
+
def benchmark_8_gpu_fp16(self):
|
530 |
+
"""Tests Keras model with 8 GPUs and fp16."""
|
531 |
+
self._setup()
|
532 |
+
self._run_and_report_benchmark(
|
533 |
+
experiment_name='benchmark_8_gpu_fp16',
|
534 |
+
num_gpus=8,
|
535 |
+
dtype='float16',
|
536 |
+
distribution_strategy='mirrored',
|
537 |
+
per_replica_batch_size=256)
|
538 |
+
|
539 |
+
def benchmark_8_gpu_fp16_tweaked(self):
|
540 |
+
"""Tests Keras model with 8 GPUs, fp16, and manual config tuning."""
|
541 |
+
self._setup()
|
542 |
+
self._run_and_report_benchmark(
|
543 |
+
experiment_name='benchmark_8_gpu_fp16_tweaked',
|
544 |
+
num_gpus=8,
|
545 |
+
dtype='float16',
|
546 |
+
distribution_strategy='mirrored',
|
547 |
+
per_replica_batch_size=256,
|
548 |
+
gpu_thread_mode='gpu_private',
|
549 |
+
dataset_num_private_threads=40)
|
550 |
+
|
551 |
+
def benchmark_8_gpu_fp16_dynamic_tweaked(self):
|
552 |
+
"""Tests Keras model with 8 GPUs, fp16, dynamic loss scaling, and tuned."""
|
553 |
+
self._setup()
|
554 |
+
self._run_and_report_benchmark(
|
555 |
+
experiment_name='benchmark_8_gpu_fp16_dynamic_tweaked',
|
556 |
+
num_gpus=8,
|
557 |
+
dtype='float16',
|
558 |
+
distribution_strategy='mirrored',
|
559 |
+
per_replica_batch_size=256,
|
560 |
+
loss_scale='dynamic',
|
561 |
+
gpu_thread_mode='gpu_private',
|
562 |
+
dataset_num_private_threads=40)
|
563 |
+
|
564 |
+
def benchmark_xla_8_gpu_fp16(self):
|
565 |
+
"""Tests Keras model with XLA, 8 GPUs and fp16."""
|
566 |
+
self._setup()
|
567 |
+
self._run_and_report_benchmark(
|
568 |
+
experiment_name='benchmark_xla_8_gpu_fp16',
|
569 |
+
dtype='float16',
|
570 |
+
num_gpus=8,
|
571 |
+
enable_xla=True,
|
572 |
+
distribution_strategy='mirrored',
|
573 |
+
per_replica_batch_size=256)
|
574 |
+
|
575 |
+
def benchmark_xla_8_gpu_fp16_tweaked(self):
|
576 |
+
"""Test Keras model with manual config tuning, XLA, 8 GPUs and fp16."""
|
577 |
+
self._setup()
|
578 |
+
self._run_and_report_benchmark(
|
579 |
+
experiment_name='benchmark_xla_8_gpu_fp16_tweaked',
|
580 |
+
dtype='float16',
|
581 |
+
num_gpus=8,
|
582 |
+
enable_xla=True,
|
583 |
+
distribution_strategy='mirrored',
|
584 |
+
per_replica_batch_size=256,
|
585 |
+
gpu_thread_mode='gpu_private',
|
586 |
+
dataset_num_private_threads=48)
|
587 |
+
|
588 |
+
def benchmark_xla_8_gpu_fp16_tweaked_delay_measure(self):
|
589 |
+
"""Tests with manual config tuning, XLA, 8 GPUs and fp16.
|
590 |
+
|
591 |
+
Delay performance measurement for stable performance on 96 vCPU platforms.
|
592 |
+
"""
|
593 |
+
self._setup()
|
594 |
+
self._run_and_report_benchmark(
|
595 |
+
experiment_name='benchmark_xla_8_gpu_fp16_tweaked_delay_measure',
|
596 |
+
dtype='float16',
|
597 |
+
num_gpus=8,
|
598 |
+
enable_xla=True,
|
599 |
+
distribution_strategy='mirrored',
|
600 |
+
per_replica_batch_size=256,
|
601 |
+
gpu_thread_mode='gpu_private',
|
602 |
+
dataset_num_private_threads=48,
|
603 |
+
steps=310)
|
604 |
+
|
605 |
+
def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
|
606 |
+
"""Tests Keras model with config tuning, XLA, 8 GPUs and dynamic fp16."""
|
607 |
+
self._setup()
|
608 |
+
self._run_and_report_benchmark(
|
609 |
+
experiment_name='benchmark_xla_8_gpu_fp16_dynamic_tweaked',
|
610 |
+
dtype='float16',
|
611 |
+
num_gpus=8,
|
612 |
+
enable_xla=True,
|
613 |
+
distribution_strategy='mirrored',
|
614 |
+
per_replica_batch_size=256,
|
615 |
+
gpu_thread_mode='gpu_private',
|
616 |
+
loss_scale='dynamic',
|
617 |
+
dataset_num_private_threads=48)
|
618 |
+
|
619 |
+
def benchmark_2x2_tpu_bf16(self):
|
620 |
+
"""Test Keras model with 2x2 TPU, bf16."""
|
621 |
+
self._setup()
|
622 |
+
self._run_and_report_benchmark(
|
623 |
+
experiment_name='benchmark_2x2_tpu_bf16',
|
624 |
+
dtype='bfloat16',
|
625 |
+
num_tpus=8,
|
626 |
+
distribution_strategy='tpu',
|
627 |
+
per_replica_batch_size=128)
|
628 |
+
|
629 |
+
def benchmark_4x4_tpu_bf16(self):
|
630 |
+
"""Test Keras model with 4x4 TPU, bf16."""
|
631 |
+
self._setup()
|
632 |
+
self._run_and_report_benchmark(
|
633 |
+
experiment_name='benchmark_4x4_tpu_bf16',
|
634 |
+
dtype='bfloat16',
|
635 |
+
num_tpus=32,
|
636 |
+
distribution_strategy='tpu',
|
637 |
+
per_replica_batch_size=128)
|
638 |
+
|
639 |
+
def benchmark_8x8_tpu_bf16(self):
|
640 |
+
"""Test Keras model with 8x8 TPU, bf16."""
|
641 |
+
self._setup()
|
642 |
+
self._run_and_report_benchmark(
|
643 |
+
experiment_name='benchmark_8x8_tpu_bf16',
|
644 |
+
dtype='bfloat16',
|
645 |
+
num_tpus=128,
|
646 |
+
distribution_strategy='tpu',
|
647 |
+
per_replica_batch_size=64)
|
648 |
+
|
649 |
+
def fill_report_object(self, stats):
|
650 |
+
super(Resnet50KerasClassifierBenchmarkBase, self).fill_report_object(
|
651 |
+
stats,
|
652 |
+
total_batch_size=FLAGS.batch_size,
|
653 |
+
log_steps=FLAGS.log_steps)
|
654 |
+
|
655 |
+
|
656 |
+
class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
|
657 |
+
"""Resnet50 benchmarks."""
|
658 |
+
|
659 |
+
def __init__(self, output_dir=None, default_flags=None, tpu=None):
|
660 |
+
flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
|
661 |
+
|
662 |
+
super(Resnet50KerasBenchmarkBase, self).__init__(
|
663 |
+
output_dir=output_dir,
|
664 |
+
flag_methods=flag_methods,
|
665 |
+
default_flags=default_flags,
|
666 |
+
tpu=tpu)
|
667 |
+
|
668 |
+
@benchmark_wrappers.enable_runtime_flags
|
669 |
+
def _run_and_report_benchmark(self, skip_steps=None):
|
670 |
+
start_time_sec = time.time()
|
671 |
+
stats = resnet_imagenet_main.run(FLAGS)
|
672 |
+
wall_time_sec = time.time() - start_time_sec
|
673 |
+
# Number of logged step time entries that are excluded in performance
|
674 |
+
# report. We keep results from last 100 batches, or skip the steps based on
|
675 |
+
# input skip_steps.
|
676 |
+
warmup = (skip_steps or (FLAGS.train_steps - 100)) // FLAGS.log_steps
|
677 |
+
|
678 |
+
super(Resnet50KerasBenchmarkBase, self)._report_benchmark(
|
679 |
+
stats,
|
680 |
+
wall_time_sec,
|
681 |
+
total_batch_size=FLAGS.batch_size,
|
682 |
+
log_steps=FLAGS.log_steps,
|
683 |
+
warmup=warmup,
|
684 |
+
start_time_sec=start_time_sec)
|
685 |
+
|
686 |
+
def benchmark_1_gpu_no_dist_strat(self):
|
687 |
+
"""Test Keras model with 1 GPU, no distribution strategy."""
|
688 |
+
self._setup()
|
689 |
+
|
690 |
+
FLAGS.num_gpus = 1
|
691 |
+
FLAGS.enable_eager = True
|
692 |
+
FLAGS.distribution_strategy = 'off'
|
693 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
|
694 |
+
FLAGS.batch_size = 128
|
695 |
+
self._run_and_report_benchmark()
|
696 |
+
|
697 |
+
def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
|
698 |
+
"""Test Keras model with 1 GPU, no distribution strategy, run eagerly."""
|
699 |
+
self._setup()
|
700 |
+
|
701 |
+
FLAGS.num_gpus = 1
|
702 |
+
FLAGS.enable_eager = True
|
703 |
+
FLAGS.run_eagerly = True
|
704 |
+
FLAGS.distribution_strategy = 'off'
|
705 |
+
FLAGS.model_dir = self._get_model_dir(
|
706 |
+
'benchmark_1_gpu_no_dist_strat_run_eagerly')
|
707 |
+
FLAGS.batch_size = 64
|
708 |
+
self._run_and_report_benchmark()
|
709 |
+
|
710 |
+
def benchmark_1_gpu_no_dist_strat_run_eagerly_tweaked(self):
|
711 |
+
"""Test Keras model with 1 GPU, no distribution strategy, run eagerly."""
|
712 |
+
self._setup()
|
713 |
+
|
714 |
+
FLAGS.num_gpus = 1
|
715 |
+
FLAGS.enable_eager = True
|
716 |
+
FLAGS.run_eagerly = True
|
717 |
+
FLAGS.explicit_gpu_placement = True
|
718 |
+
FLAGS.distribution_strategy = 'off'
|
719 |
+
FLAGS.model_dir = self._get_model_dir(
|
720 |
+
'benchmark_1_gpu_no_dist_strat_run_eagerly_tweaked')
|
721 |
+
FLAGS.batch_size = 64
|
722 |
+
self._run_and_report_benchmark()
|
723 |
+
|
724 |
+
def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16(self):
|
725 |
+
"""Test with 1 GPU, no distribution strategy, fp16, run eagerly."""
|
726 |
+
self._setup()
|
727 |
+
|
728 |
+
FLAGS.num_gpus = 1
|
729 |
+
FLAGS.enable_eager = True
|
730 |
+
FLAGS.run_eagerly = True
|
731 |
+
FLAGS.distribution_strategy = 'off'
|
732 |
+
FLAGS.model_dir = self._get_model_dir(
|
733 |
+
'benchmark_1_gpu_no_dist_strat_run_eagerly_fp16')
|
734 |
+
FLAGS.dtype = 'fp16'
|
735 |
+
FLAGS.batch_size = 128
|
736 |
+
self._run_and_report_benchmark()
|
737 |
+
|
738 |
+
def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16_tweaked(self):
|
739 |
+
"""Test with 1 GPU, no distribution strategy, fp16, run eagerly."""
|
740 |
+
self._setup()
|
741 |
+
|
742 |
+
FLAGS.num_gpus = 1
|
743 |
+
FLAGS.enable_eager = True
|
744 |
+
FLAGS.run_eagerly = True
|
745 |
+
FLAGS.explicit_gpu_placement = True
|
746 |
+
FLAGS.distribution_strategy = 'off'
|
747 |
+
FLAGS.model_dir = self._get_model_dir(
|
748 |
+
'benchmark_1_gpu_no_dist_strat_run_eagerly_fp16_tweaked')
|
749 |
+
FLAGS.dtype = 'fp16'
|
750 |
+
FLAGS.batch_size = 128
|
751 |
+
self._run_and_report_benchmark()
|
752 |
+
|
753 |
+
def benchmark_1_gpu(self):
|
754 |
+
"""Test Keras model with 1 GPU."""
|
755 |
+
self._setup()
|
756 |
+
|
757 |
+
FLAGS.num_gpus = 1
|
758 |
+
FLAGS.enable_eager = True
|
759 |
+
FLAGS.distribution_strategy = 'one_device'
|
760 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
|
761 |
+
FLAGS.batch_size = 128
|
762 |
+
self._run_and_report_benchmark()
|
763 |
+
|
764 |
+
def benchmark_1_gpu_amp(self):
|
765 |
+
"""Test Keras model with 1 GPU with automatic mixed precision."""
|
766 |
+
self._setup()
|
767 |
+
|
768 |
+
FLAGS.num_gpus = 1
|
769 |
+
FLAGS.enable_eager = True
|
770 |
+
FLAGS.dtype = 'fp16'
|
771 |
+
FLAGS.fp16_implementation = 'graph_rewrite'
|
772 |
+
FLAGS.distribution_strategy = 'one_device'
|
773 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
|
774 |
+
FLAGS.batch_size = 256
|
775 |
+
self._run_and_report_benchmark()
|
776 |
+
|
777 |
+
def benchmark_xla_1_gpu(self):
|
778 |
+
"""Test Keras model with XLA and 1 GPU."""
|
779 |
+
self._setup()
|
780 |
+
|
781 |
+
FLAGS.num_gpus = 1
|
782 |
+
FLAGS.enable_eager = True
|
783 |
+
FLAGS.enable_xla = True
|
784 |
+
FLAGS.distribution_strategy = 'one_device'
|
785 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu')
|
786 |
+
FLAGS.batch_size = 128
|
787 |
+
self._run_and_report_benchmark()
|
788 |
+
|
789 |
+
def benchmark_xla_1_gpu_amp(self):
|
790 |
+
"""Test Keras model with XLA and 1 GPU with automatic mixed precision."""
|
791 |
+
self._setup()
|
792 |
+
|
793 |
+
FLAGS.num_gpus = 1
|
794 |
+
FLAGS.enable_eager = True
|
795 |
+
FLAGS.dtype = 'fp16'
|
796 |
+
FLAGS.fp16_implementation = 'graph_rewrite'
|
797 |
+
FLAGS.enable_xla = True
|
798 |
+
FLAGS.distribution_strategy = 'one_device'
|
799 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
|
800 |
+
FLAGS.batch_size = 256
|
801 |
+
self._run_and_report_benchmark()
|
802 |
+
|
803 |
+
def benchmark_1_gpu_fp16(self):
|
804 |
+
"""Test Keras model with 1 GPU and fp16."""
|
805 |
+
self._setup()
|
806 |
+
|
807 |
+
FLAGS.num_gpus = 1
|
808 |
+
FLAGS.enable_eager = True
|
809 |
+
FLAGS.distribution_strategy = 'one_device'
|
810 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16')
|
811 |
+
FLAGS.dtype = 'fp16'
|
812 |
+
FLAGS.batch_size = 256
|
813 |
+
self._run_and_report_benchmark()
|
814 |
+
|
815 |
+
def benchmark_1_gpu_fp16_dynamic(self):
|
816 |
+
"""Test Keras model with 1 GPU, fp16, and dynamic loss scaling."""
|
817 |
+
self._setup()
|
818 |
+
|
819 |
+
FLAGS.num_gpus = 1
|
820 |
+
FLAGS.enable_eager = True
|
821 |
+
FLAGS.distribution_strategy = 'one_device'
|
822 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_dynamic')
|
823 |
+
FLAGS.dtype = 'fp16'
|
824 |
+
FLAGS.batch_size = 256
|
825 |
+
FLAGS.loss_scale = 'dynamic'
|
826 |
+
self._run_and_report_benchmark()
|
827 |
+
|
828 |
+
def benchmark_xla_1_gpu_fp16(self):
|
829 |
+
"""Test Keras model with XLA, 1 GPU and fp16."""
|
830 |
+
self._setup()
|
831 |
+
|
832 |
+
FLAGS.num_gpus = 1
|
833 |
+
FLAGS.enable_eager = True
|
834 |
+
FLAGS.enable_xla = True
|
835 |
+
FLAGS.distribution_strategy = 'one_device'
|
836 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16')
|
837 |
+
FLAGS.dtype = 'fp16'
|
838 |
+
FLAGS.batch_size = 256
|
839 |
+
self._run_and_report_benchmark()
|
840 |
+
|
841 |
+
def benchmark_xla_1_gpu_fp16_tweaked(self):
|
842 |
+
"""Test Keras model with XLA, 1 GPU, fp16, and manual config tuning."""
|
843 |
+
self._setup()
|
844 |
+
|
845 |
+
FLAGS.num_gpus = 1
|
846 |
+
FLAGS.enable_eager = True
|
847 |
+
FLAGS.enable_xla = True
|
848 |
+
FLAGS.distribution_strategy = 'one_device'
|
849 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_tweaked')
|
850 |
+
FLAGS.dtype = 'fp16'
|
851 |
+
FLAGS.batch_size = 256
|
852 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
853 |
+
self._run_and_report_benchmark()
|
854 |
+
|
855 |
+
def benchmark_xla_1_gpu_fp16_dynamic(self):
|
856 |
+
"""Test Keras model with XLA, 1 GPU, fp16, and dynamic loss scaling."""
|
857 |
+
self._setup()
|
858 |
+
|
859 |
+
FLAGS.num_gpus = 1
|
860 |
+
FLAGS.enable_eager = True
|
861 |
+
FLAGS.enable_xla = True
|
862 |
+
FLAGS.distribution_strategy = 'one_device'
|
863 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_dynamic')
|
864 |
+
FLAGS.dtype = 'fp16'
|
865 |
+
FLAGS.batch_size = 256
|
866 |
+
FLAGS.loss_scale = 'dynamic'
|
867 |
+
self._run_and_report_benchmark()
|
868 |
+
|
869 |
+
def benchmark_8_gpu(self):
|
870 |
+
"""Test Keras model with 8 GPUs."""
|
871 |
+
self._setup()
|
872 |
+
|
873 |
+
FLAGS.num_gpus = 8
|
874 |
+
FLAGS.enable_eager = True
|
875 |
+
FLAGS.distribution_strategy = 'mirrored'
|
876 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
|
877 |
+
FLAGS.batch_size = 128 * 8 # 8 GPUs
|
878 |
+
self._run_and_report_benchmark()
|
879 |
+
|
880 |
+
def benchmark_8_gpu_amp(self):
|
881 |
+
"""Test Keras model with 8 GPUs with automatic mixed precision."""
|
882 |
+
self._setup()
|
883 |
+
|
884 |
+
FLAGS.num_gpus = 8
|
885 |
+
FLAGS.enable_eager = True
|
886 |
+
FLAGS.dtype = 'fp16'
|
887 |
+
FLAGS.fp16_implementation = 'graph_rewrite'
|
888 |
+
FLAGS.distribution_strategy = 'mirrored'
|
889 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
|
890 |
+
FLAGS.batch_size = 256 * 8 # 8 GPUs
|
891 |
+
self._run_and_report_benchmark()
|
892 |
+
|
893 |
+
def benchmark_8_gpu_tweaked(self):
|
894 |
+
"""Test Keras model with manual config tuning and 8 GPUs."""
|
895 |
+
self._setup()
|
896 |
+
|
897 |
+
FLAGS.num_gpus = 8
|
898 |
+
FLAGS.enable_eager = True
|
899 |
+
FLAGS.distribution_strategy = 'mirrored'
|
900 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_tweaked')
|
901 |
+
FLAGS.batch_size = 128 * 8 # 8 GPUs
|
902 |
+
FLAGS.datasets_num_private_threads = 14
|
903 |
+
self._run_and_report_benchmark()
|
904 |
+
|
905 |
+
def benchmark_xla_8_gpu(self):
|
906 |
+
"""Test Keras model with XLA and 8 GPUs."""
|
907 |
+
self._setup()
|
908 |
+
|
909 |
+
FLAGS.num_gpus = 8
|
910 |
+
FLAGS.enable_eager = True
|
911 |
+
FLAGS.enable_xla = True
|
912 |
+
FLAGS.distribution_strategy = 'mirrored'
|
913 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu')
|
914 |
+
FLAGS.batch_size = 128 * 8 # 8 GPUs
|
915 |
+
self._run_and_report_benchmark()
|
916 |
+
|
917 |
+
def benchmark_xla_8_gpu_amp(self):
|
918 |
+
"""Test Keras model with XLA and 8 GPUs with automatic mixed precision."""
|
919 |
+
self._setup()
|
920 |
+
|
921 |
+
FLAGS.num_gpus = 8
|
922 |
+
FLAGS.enable_eager = True
|
923 |
+
FLAGS.dtype = 'fp16'
|
924 |
+
FLAGS.fp16_implementation = 'graph_rewrite'
|
925 |
+
FLAGS.enable_xla = True
|
926 |
+
FLAGS.distribution_strategy = 'mirrored'
|
927 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
|
928 |
+
FLAGS.batch_size = 256 * 8 # 8 GPUs
|
929 |
+
self._run_and_report_benchmark()
|
930 |
+
|
931 |
+
def benchmark_xla_8_gpu_tweaked(self):
|
932 |
+
"""Test Keras model with manual config tuning, 8 GPUs, and XLA."""
|
933 |
+
self._setup()
|
934 |
+
|
935 |
+
FLAGS.num_gpus = 8
|
936 |
+
FLAGS.enable_eager = True
|
937 |
+
FLAGS.enable_xla = True
|
938 |
+
FLAGS.distribution_strategy = 'mirrored'
|
939 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_tweaked')
|
940 |
+
FLAGS.batch_size = 128 * 8
|
941 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
942 |
+
FLAGS.datasets_num_private_threads = 24
|
943 |
+
self._run_and_report_benchmark()
|
944 |
+
|
945 |
+
def benchmark_8_gpu_fp16(self):
|
946 |
+
"""Test Keras model with 8 GPUs and fp16."""
|
947 |
+
self._setup()
|
948 |
+
|
949 |
+
FLAGS.num_gpus = 8
|
950 |
+
FLAGS.dtype = 'fp16'
|
951 |
+
FLAGS.enable_eager = True
|
952 |
+
FLAGS.distribution_strategy = 'mirrored'
|
953 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
|
954 |
+
FLAGS.batch_size = 256 * 8 # 8 GPUs
|
955 |
+
self._run_and_report_benchmark()
|
956 |
+
|
957 |
+
def benchmark_8_gpu_fp16_tweaked(self):
|
958 |
+
"""Test Keras model with 8 GPUs, fp16, and manual config tuning."""
|
959 |
+
self._setup()
|
960 |
+
|
961 |
+
FLAGS.num_gpus = 8
|
962 |
+
FLAGS.dtype = 'fp16'
|
963 |
+
FLAGS.enable_eager = True
|
964 |
+
FLAGS.distribution_strategy = 'mirrored'
|
965 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked')
|
966 |
+
FLAGS.batch_size = 256 * 8 # 8 GPUs
|
967 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
968 |
+
FLAGS.dataset_num_private_threads = 40
|
969 |
+
self._run_and_report_benchmark()
|
970 |
+
|
971 |
+
def benchmark_8_gpu_fp16_dynamic_tweaked(self):
|
972 |
+
"""Test Keras model with 8 GPUs, fp16, dynamic loss scaling, and tuned."""
|
973 |
+
self._setup()
|
974 |
+
|
975 |
+
FLAGS.num_gpus = 8
|
976 |
+
FLAGS.dtype = 'fp16'
|
977 |
+
FLAGS.enable_eager = True
|
978 |
+
FLAGS.distribution_strategy = 'mirrored'
|
979 |
+
FLAGS.model_dir = self._get_model_dir(
|
980 |
+
'benchmark_8_gpu_fp16_dynamic_tweaked')
|
981 |
+
FLAGS.batch_size = 256 * 8 # 8 GPUs
|
982 |
+
FLAGS.loss_scale = 'dynamic'
|
983 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
984 |
+
FLAGS.dataset_num_private_threads = 40
|
985 |
+
self._run_and_report_benchmark()
|
986 |
+
|
987 |
+
def benchmark_xla_8_gpu_fp16(self):
|
988 |
+
"""Test Keras model with XLA, 8 GPUs and fp16."""
|
989 |
+
self._setup()
|
990 |
+
|
991 |
+
FLAGS.num_gpus = 8
|
992 |
+
FLAGS.dtype = 'fp16'
|
993 |
+
FLAGS.enable_eager = True
|
994 |
+
FLAGS.enable_xla = True
|
995 |
+
FLAGS.distribution_strategy = 'mirrored'
|
996 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16')
|
997 |
+
FLAGS.batch_size = 256 * 8 # 8 GPUs
|
998 |
+
self._run_and_report_benchmark()
|
999 |
+
|
1000 |
+
def benchmark_xla_8_gpu_fp16_tweaked(self):
|
1001 |
+
"""Test Keras model with manual config tuning, XLA, 8 GPUs and fp16."""
|
1002 |
+
self._setup()
|
1003 |
+
|
1004 |
+
FLAGS.num_gpus = 8
|
1005 |
+
FLAGS.dtype = 'fp16'
|
1006 |
+
FLAGS.enable_eager = True
|
1007 |
+
FLAGS.enable_xla = True
|
1008 |
+
FLAGS.distribution_strategy = 'mirrored'
|
1009 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16_tweaked')
|
1010 |
+
FLAGS.batch_size = 256 * 8 # 8 GPUs
|
1011 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
1012 |
+
FLAGS.datasets_num_private_threads = 48
|
1013 |
+
self._run_and_report_benchmark()
|
1014 |
+
|
1015 |
+
def benchmark_xla_8_gpu_fp16_tweaked_delay_measure(self):
|
1016 |
+
"""Test with manual config tuning, XLA, 8 GPUs and fp16.
|
1017 |
+
|
1018 |
+
Delay performance measurement for stable performance on 96 vCPU platforms.
|
1019 |
+
"""
|
1020 |
+
self._setup()
|
1021 |
+
|
1022 |
+
FLAGS.num_gpus = 8
|
1023 |
+
FLAGS.dtype = 'fp16'
|
1024 |
+
FLAGS.enable_eager = True
|
1025 |
+
FLAGS.enable_xla = True
|
1026 |
+
FLAGS.distribution_strategy = 'mirrored'
|
1027 |
+
FLAGS.model_dir = self._get_model_dir(
|
1028 |
+
'benchmark_xla_8_gpu_fp16_tweaked_delay_measure')
|
1029 |
+
FLAGS.batch_size = 256 * 8
|
1030 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
1031 |
+
FLAGS.datasets_num_private_threads = 48
|
1032 |
+
FLAGS.train_steps = 310
|
1033 |
+
self._run_and_report_benchmark()
|
1034 |
+
|
1035 |
+
def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
|
1036 |
+
"""Test Keras model with config tuning, XLA, 8 GPUs and dynamic fp16."""
|
1037 |
+
self._setup()
|
1038 |
+
|
1039 |
+
FLAGS.num_gpus = 8
|
1040 |
+
FLAGS.dtype = 'fp16'
|
1041 |
+
FLAGS.enable_eager = True
|
1042 |
+
FLAGS.enable_xla = True
|
1043 |
+
FLAGS.distribution_strategy = 'mirrored'
|
1044 |
+
FLAGS.model_dir = self._get_model_dir(
|
1045 |
+
'benchmark_xla_8_gpu_fp16_dynamic_tweaked')
|
1046 |
+
FLAGS.batch_size = 256 * 8 # 8 GPUs
|
1047 |
+
FLAGS.loss_scale = 'dynamic'
|
1048 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
1049 |
+
FLAGS.datasets_num_private_threads = 48
|
1050 |
+
self._run_and_report_benchmark()
|
1051 |
+
|
1052 |
+
def benchmark_2x2_tpu_bf16(self):
|
1053 |
+
"""Test Keras model with 2x2 TPU, bf16."""
|
1054 |
+
self._setup()
|
1055 |
+
|
1056 |
+
FLAGS.dtype = 'bf16'
|
1057 |
+
FLAGS.distribution_strategy = 'tpu'
|
1058 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16')
|
1059 |
+
FLAGS.batch_size = 1024
|
1060 |
+
self._run_and_report_benchmark()
|
1061 |
+
|
1062 |
+
def benchmark_4x4_tpu_bf16(self):
|
1063 |
+
"""Test Keras model with 4x4 TPU, bf16."""
|
1064 |
+
self._setup()
|
1065 |
+
|
1066 |
+
FLAGS.dtype = 'bf16'
|
1067 |
+
FLAGS.distribution_strategy = 'tpu'
|
1068 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16')
|
1069 |
+
FLAGS.batch_size = 4096
|
1070 |
+
self._run_and_report_benchmark()
|
1071 |
+
|
1072 |
+
def benchmark_8x8_tpu_bf16(self):
|
1073 |
+
"""Test Keras model with 8x8 TPU, bf16."""
|
1074 |
+
self._setup()
|
1075 |
+
|
1076 |
+
FLAGS.dtype = 'bf16'
|
1077 |
+
FLAGS.distribution_strategy = 'tpu'
|
1078 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8x8_tpu_bf16')
|
1079 |
+
FLAGS.batch_size = 8192
|
1080 |
+
self._run_and_report_benchmark()
|
1081 |
+
|
1082 |
+
def fill_report_object(self, stats):
|
1083 |
+
super(Resnet50KerasBenchmarkBase, self).fill_report_object(
|
1084 |
+
stats,
|
1085 |
+
total_batch_size=FLAGS.batch_size,
|
1086 |
+
log_steps=FLAGS.log_steps)
|
1087 |
+
|
1088 |
+
|
1089 |
+
class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase):
|
1090 |
+
"""Resnet50 synthetic benchmark tests."""
|
1091 |
+
|
1092 |
+
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
|
1093 |
+
def_flags = {}
|
1094 |
+
def_flags['log_steps'] = 10
|
1095 |
+
|
1096 |
+
super(Resnet50KerasBenchmarkSynth, self).__init__(
|
1097 |
+
output_dir=output_dir, default_flags=def_flags, tpu=tpu,
|
1098 |
+
dataset_builder='synthetic', train_epochs=1, train_steps=110)
|
1099 |
+
|
1100 |
+
|
1101 |
+
class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase):
|
1102 |
+
"""Resnet50 real data benchmark tests."""
|
1103 |
+
|
1104 |
+
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
|
1105 |
+
data_dir = os.path.join(root_data_dir, 'imagenet')
|
1106 |
+
def_flags = {}
|
1107 |
+
def_flags['log_steps'] = 10
|
1108 |
+
|
1109 |
+
super(Resnet50KerasBenchmarkReal, self).__init__(
|
1110 |
+
output_dir=output_dir, default_flags=def_flags, tpu=tpu,
|
1111 |
+
dataset_builder='records', train_epochs=1, train_steps=110,
|
1112 |
+
data_dir=data_dir)
|
1113 |
+
|
1114 |
+
|
1115 |
+
class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase):
|
1116 |
+
"""Resnet50 real data (stored in remote storage) benchmark tests."""
|
1117 |
+
|
1118 |
+
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
1119 |
+
def_flags = {}
|
1120 |
+
def_flags['skip_eval'] = True
|
1121 |
+
def_flags['report_accuracy_metrics'] = False
|
1122 |
+
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
|
1123 |
+
# Defining multiple epochs overrides the train_steps setting in benchmarks.
|
1124 |
+
def_flags['train_epochs'] = 2
|
1125 |
+
# Cache dataset so performance is stable after the first epoch.
|
1126 |
+
def_flags['training_dataset_cache'] = True
|
1127 |
+
def_flags['log_steps'] = 100
|
1128 |
+
# Note that for single GPU and pure eager tests which are less likely to be
|
1129 |
+
# input bound and more stable, these tests will run for shorter time by
|
1130 |
+
# overriding FLAGS.train_epochs, train_seteps, log_steps in benchmark
|
1131 |
+
# methods, and skip_steps in _run_and_report_benchmark().
|
1132 |
+
|
1133 |
+
super(Resnet50KerasBenchmarkRemoteData, self).__init__(
|
1134 |
+
output_dir=output_dir, default_flags=def_flags)
|
1135 |
+
|
1136 |
+
def _override_flags_to_run_test_shorter(self):
|
1137 |
+
FLAGS.train_epochs = 1
|
1138 |
+
FLAGS.train_steps = 300
|
1139 |
+
FLAGS.log_steps = 10
|
1140 |
+
|
1141 |
+
def benchmark_1_gpu_no_dist_strat(self):
|
1142 |
+
"""Test Keras model with 1 GPU, no distribution strategy."""
|
1143 |
+
self._setup()
|
1144 |
+
|
1145 |
+
FLAGS.num_gpus = 1
|
1146 |
+
FLAGS.enable_eager = True
|
1147 |
+
FLAGS.distribution_strategy = 'off'
|
1148 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
|
1149 |
+
FLAGS.batch_size = 128
|
1150 |
+
self._override_flags_to_run_test_shorter()
|
1151 |
+
self._run_and_report_benchmark()
|
1152 |
+
|
1153 |
+
def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
|
1154 |
+
"""Test Keras model with 1 GPU, no distribution strategy, run eagerly."""
|
1155 |
+
self._setup()
|
1156 |
+
|
1157 |
+
FLAGS.num_gpus = 1
|
1158 |
+
FLAGS.enable_eager = True
|
1159 |
+
FLAGS.run_eagerly = True
|
1160 |
+
FLAGS.distribution_strategy = 'off'
|
1161 |
+
FLAGS.model_dir = self._get_model_dir(
|
1162 |
+
'benchmark_1_gpu_no_dist_strat_run_eagerly')
|
1163 |
+
FLAGS.batch_size = 64
|
1164 |
+
self._override_flags_to_run_test_shorter()
|
1165 |
+
self._run_and_report_benchmark()
|
1166 |
+
|
1167 |
+
def benchmark_1_gpu_no_dist_strat_run_eagerly_tweaked(self):
|
1168 |
+
"""Test Keras model with 1 GPU, no distribution strategy, run eagerly."""
|
1169 |
+
self._setup()
|
1170 |
+
|
1171 |
+
FLAGS.num_gpus = 1
|
1172 |
+
FLAGS.enable_eager = True
|
1173 |
+
FLAGS.run_eagerly = True
|
1174 |
+
FLAGS.explicit_gpu_placement = True
|
1175 |
+
FLAGS.distribution_strategy = 'off'
|
1176 |
+
FLAGS.model_dir = self._get_model_dir(
|
1177 |
+
'benchmark_1_gpu_no_dist_strat_run_eagerly_tweaked')
|
1178 |
+
FLAGS.batch_size = 64
|
1179 |
+
self._override_flags_to_run_test_shorter()
|
1180 |
+
self._run_and_report_benchmark()
|
1181 |
+
|
1182 |
+
def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16(self):
|
1183 |
+
"""Test with 1 GPU, no distribution strategy, fp16, run eagerly."""
|
1184 |
+
self._setup()
|
1185 |
+
|
1186 |
+
FLAGS.num_gpus = 1
|
1187 |
+
FLAGS.enable_eager = True
|
1188 |
+
FLAGS.run_eagerly = True
|
1189 |
+
FLAGS.distribution_strategy = 'off'
|
1190 |
+
FLAGS.model_dir = self._get_model_dir(
|
1191 |
+
'benchmark_1_gpu_no_dist_strat_run_eagerly_fp16')
|
1192 |
+
FLAGS.dtype = 'fp16'
|
1193 |
+
FLAGS.batch_size = 128
|
1194 |
+
self._override_flags_to_run_test_shorter()
|
1195 |
+
self._run_and_report_benchmark()
|
1196 |
+
|
1197 |
+
def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16_tweaked(self):
|
1198 |
+
"""Test with 1 GPU, no distribution strategy, fp16, run eagerly."""
|
1199 |
+
self._setup()
|
1200 |
+
|
1201 |
+
FLAGS.num_gpus = 1
|
1202 |
+
FLAGS.enable_eager = True
|
1203 |
+
FLAGS.run_eagerly = True
|
1204 |
+
FLAGS.explicit_gpu_placement = True
|
1205 |
+
FLAGS.distribution_strategy = 'off'
|
1206 |
+
FLAGS.model_dir = self._get_model_dir(
|
1207 |
+
'benchmark_1_gpu_no_dist_strat_run_eagerly_fp16_tweaked')
|
1208 |
+
FLAGS.dtype = 'fp16'
|
1209 |
+
FLAGS.batch_size = 128
|
1210 |
+
self._override_flags_to_run_test_shorter()
|
1211 |
+
self._run_and_report_benchmark()
|
1212 |
+
|
1213 |
+
def benchmark_1_gpu(self):
|
1214 |
+
"""Test Keras model with 1 GPU."""
|
1215 |
+
self._setup()
|
1216 |
+
|
1217 |
+
FLAGS.num_gpus = 1
|
1218 |
+
FLAGS.enable_eager = True
|
1219 |
+
FLAGS.distribution_strategy = 'one_device'
|
1220 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
|
1221 |
+
FLAGS.batch_size = 128
|
1222 |
+
self._override_flags_to_run_test_shorter()
|
1223 |
+
self._run_and_report_benchmark()
|
1224 |
+
|
1225 |
+
def benchmark_1_gpu_amp(self):
|
1226 |
+
"""Test Keras model with 1 GPU with automatic mixed precision."""
|
1227 |
+
self._setup()
|
1228 |
+
|
1229 |
+
FLAGS.num_gpus = 1
|
1230 |
+
FLAGS.enable_eager = True
|
1231 |
+
FLAGS.dtype = 'fp16'
|
1232 |
+
FLAGS.fp16_implementation = 'graph_rewrite'
|
1233 |
+
FLAGS.distribution_strategy = 'one_device'
|
1234 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
|
1235 |
+
FLAGS.batch_size = 256
|
1236 |
+
self._override_flags_to_run_test_shorter()
|
1237 |
+
self._run_and_report_benchmark()
|
1238 |
+
|
1239 |
+
def benchmark_xla_1_gpu(self):
|
1240 |
+
"""Test Keras model with XLA and 1 GPU."""
|
1241 |
+
self._setup()
|
1242 |
+
|
1243 |
+
FLAGS.num_gpus = 1
|
1244 |
+
FLAGS.enable_eager = True
|
1245 |
+
FLAGS.enable_xla = True
|
1246 |
+
FLAGS.distribution_strategy = 'one_device'
|
1247 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu')
|
1248 |
+
FLAGS.batch_size = 128
|
1249 |
+
self._override_flags_to_run_test_shorter()
|
1250 |
+
self._run_and_report_benchmark()
|
1251 |
+
|
1252 |
+
def benchmark_xla_1_gpu_amp(self):
|
1253 |
+
"""Test Keras model with XLA and 1 GPU with automatic mixed precision."""
|
1254 |
+
self._setup()
|
1255 |
+
|
1256 |
+
FLAGS.num_gpus = 1
|
1257 |
+
FLAGS.enable_eager = True
|
1258 |
+
FLAGS.dtype = 'fp16'
|
1259 |
+
FLAGS.fp16_implementation = 'graph_rewrite'
|
1260 |
+
FLAGS.enable_xla = True
|
1261 |
+
FLAGS.distribution_strategy = 'one_device'
|
1262 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
|
1263 |
+
FLAGS.batch_size = 256
|
1264 |
+
self._override_flags_to_run_test_shorter()
|
1265 |
+
self._run_and_report_benchmark()
|
1266 |
+
|
1267 |
+
def benchmark_1_gpu_fp16(self):
|
1268 |
+
"""Test Keras model with 1 GPU and fp16."""
|
1269 |
+
self._setup()
|
1270 |
+
|
1271 |
+
FLAGS.num_gpus = 1
|
1272 |
+
FLAGS.enable_eager = True
|
1273 |
+
FLAGS.distribution_strategy = 'one_device'
|
1274 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16')
|
1275 |
+
FLAGS.dtype = 'fp16'
|
1276 |
+
FLAGS.batch_size = 256
|
1277 |
+
self._override_flags_to_run_test_shorter()
|
1278 |
+
self._run_and_report_benchmark()
|
1279 |
+
|
1280 |
+
def benchmark_1_gpu_fp16_dynamic(self):
|
1281 |
+
"""Test Keras model with 1 GPU, fp16, and dynamic loss scaling."""
|
1282 |
+
self._setup()
|
1283 |
+
|
1284 |
+
FLAGS.num_gpus = 1
|
1285 |
+
FLAGS.enable_eager = True
|
1286 |
+
FLAGS.distribution_strategy = 'one_device'
|
1287 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_dynamic')
|
1288 |
+
FLAGS.dtype = 'fp16'
|
1289 |
+
FLAGS.batch_size = 256
|
1290 |
+
FLAGS.loss_scale = 'dynamic'
|
1291 |
+
self._override_flags_to_run_test_shorter()
|
1292 |
+
self._run_and_report_benchmark()
|
1293 |
+
|
1294 |
+
def benchmark_xla_1_gpu_fp16(self):
|
1295 |
+
"""Test Keras model with XLA, 1 GPU and fp16."""
|
1296 |
+
self._setup()
|
1297 |
+
|
1298 |
+
FLAGS.num_gpus = 1
|
1299 |
+
FLAGS.enable_eager = True
|
1300 |
+
FLAGS.enable_xla = True
|
1301 |
+
FLAGS.distribution_strategy = 'one_device'
|
1302 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16')
|
1303 |
+
FLAGS.dtype = 'fp16'
|
1304 |
+
FLAGS.batch_size = 256
|
1305 |
+
self._override_flags_to_run_test_shorter()
|
1306 |
+
self._run_and_report_benchmark()
|
1307 |
+
|
1308 |
+
def benchmark_xla_1_gpu_fp16_tweaked(self):
|
1309 |
+
"""Test Keras model with XLA, 1 GPU, fp16, and manual config tuning."""
|
1310 |
+
self._setup()
|
1311 |
+
|
1312 |
+
FLAGS.num_gpus = 1
|
1313 |
+
FLAGS.enable_eager = True
|
1314 |
+
FLAGS.enable_xla = True
|
1315 |
+
FLAGS.distribution_strategy = 'one_device'
|
1316 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_tweaked')
|
1317 |
+
FLAGS.dtype = 'fp16'
|
1318 |
+
FLAGS.batch_size = 256
|
1319 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
1320 |
+
self._override_flags_to_run_test_shorter()
|
1321 |
+
self._run_and_report_benchmark()
|
1322 |
+
|
1323 |
+
def benchmark_xla_1_gpu_fp16_dynamic(self):
|
1324 |
+
"""Test Keras model with XLA, 1 GPU, fp16, and dynamic loss scaling."""
|
1325 |
+
self._setup()
|
1326 |
+
|
1327 |
+
FLAGS.num_gpus = 1
|
1328 |
+
FLAGS.enable_eager = True
|
1329 |
+
FLAGS.enable_xla = True
|
1330 |
+
FLAGS.distribution_strategy = 'one_device'
|
1331 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_dynamic')
|
1332 |
+
FLAGS.dtype = 'fp16'
|
1333 |
+
FLAGS.batch_size = 256
|
1334 |
+
FLAGS.loss_scale = 'dynamic'
|
1335 |
+
self._override_flags_to_run_test_shorter()
|
1336 |
+
self._run_and_report_benchmark()
|
1337 |
+
|
1338 |
+
@benchmark_wrappers.enable_runtime_flags
|
1339 |
+
def _run_and_report_benchmark(self):
|
1340 |
+
if FLAGS.num_gpus == 1 or FLAGS.run_eagerly:
|
1341 |
+
# For single GPU and pure eager tests which are less likely to be input
|
1342 |
+
# bound and more stable, run for shorter time and use the default
|
1343 |
+
# skip_steps.
|
1344 |
+
skip_steps = None
|
1345 |
+
else:
|
1346 |
+
# skip the first epoch for performance measurement.
|
1347 |
+
skip_steps = 600
|
1348 |
+
super(Resnet50KerasBenchmarkRemoteData,
|
1349 |
+
self)._run_and_report_benchmark(skip_steps=skip_steps)
|
1350 |
+
|
1351 |
+
|
1352 |
+
class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
|
1353 |
+
"""Trivial model with real data benchmark tests."""
|
1354 |
+
|
1355 |
+
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
1356 |
+
flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
|
1357 |
+
|
1358 |
+
def_flags = {}
|
1359 |
+
def_flags['use_trivial_model'] = True
|
1360 |
+
def_flags['skip_eval'] = True
|
1361 |
+
def_flags['report_accuracy_metrics'] = False
|
1362 |
+
def_flags['dtype'] = 'fp16'
|
1363 |
+
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
|
1364 |
+
def_flags['train_steps'] = 600
|
1365 |
+
def_flags['log_steps'] = 100
|
1366 |
+
def_flags['distribution_strategy'] = 'mirrored'
|
1367 |
+
|
1368 |
+
super(TrivialKerasBenchmarkReal, self).__init__(
|
1369 |
+
output_dir=output_dir,
|
1370 |
+
flag_methods=flag_methods,
|
1371 |
+
default_flags=def_flags)
|
1372 |
+
|
1373 |
+
@benchmark_wrappers.enable_runtime_flags
|
1374 |
+
def _run_and_report_benchmark(self):
|
1375 |
+
start_time_sec = time.time()
|
1376 |
+
stats = resnet_imagenet_main.run(FLAGS)
|
1377 |
+
wall_time_sec = time.time() - start_time_sec
|
1378 |
+
|
1379 |
+
super(TrivialKerasBenchmarkReal, self)._report_benchmark(
|
1380 |
+
stats,
|
1381 |
+
wall_time_sec,
|
1382 |
+
total_batch_size=FLAGS.batch_size,
|
1383 |
+
log_steps=FLAGS.log_steps)
|
1384 |
+
|
1385 |
+
def benchmark_8_gpu_warmup(self):
|
1386 |
+
"""Dummy test that runs over an epoch to warmup the machine."""
|
1387 |
+
self._setup()
|
1388 |
+
|
1389 |
+
FLAGS.num_gpus = 8
|
1390 |
+
FLAGS.enable_eager = True
|
1391 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_warmup')
|
1392 |
+
FLAGS.batch_size = 256 * 8
|
1393 |
+
FLAGS.train_steps = 700
|
1394 |
+
self._run_and_report_benchmark()
|
1395 |
+
|
1396 |
+
def fill_report_object(self, stats):
|
1397 |
+
super(TrivialKerasBenchmarkReal, self).fill_report_object(
|
1398 |
+
stats,
|
1399 |
+
total_batch_size=FLAGS.batch_size,
|
1400 |
+
log_steps=FLAGS.log_steps)
|
1401 |
+
|
1402 |
+
|
1403 |
+
class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark):
|
1404 |
+
"""Resnet50 distributed accuracy tests with multiple workers."""
|
1405 |
+
|
1406 |
+
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
1407 |
+
flag_methods = [classifier_trainer.define_imagenet_keras_flags]
|
1408 |
+
self.data_dir = os.path.join(root_data_dir, 'imagenet')
|
1409 |
+
super(Resnet50MultiWorkerKerasAccuracy, self).__init__(
|
1410 |
+
output_dir=output_dir, flag_methods=flag_methods)
|
1411 |
+
|
1412 |
+
def _benchmark_common(self, eager, num_workers, all_reduce_alg):
|
1413 |
+
"""Common to all benchmarks in this class."""
|
1414 |
+
self._setup()
|
1415 |
+
|
1416 |
+
num_gpus = 8
|
1417 |
+
FLAGS.num_gpus = num_gpus
|
1418 |
+
FLAGS.data_dir = self.data_dir
|
1419 |
+
FLAGS.train_epochs = 90
|
1420 |
+
FLAGS.epochs_between_evals = 10
|
1421 |
+
FLAGS.dtype = 'fp16'
|
1422 |
+
FLAGS.enable_eager = eager
|
1423 |
+
FLAGS.enable_xla = False
|
1424 |
+
FLAGS.distribution_strategy = 'multi_worker_mirrored'
|
1425 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
1426 |
+
FLAGS.datasets_num_private_threads = 32
|
1427 |
+
FLAGS.model_dir = self._get_model_dir(
|
1428 |
+
'benchmark_{}_8_gpu_{}_worker_fp16_{}_tweaked'.format(
|
1429 |
+
'eager' if eager else 'graph', num_workers, all_reduce_alg))
|
1430 |
+
FLAGS.batch_size = 256 * num_gpus * num_workers
|
1431 |
+
FLAGS.all_reduce_alg = all_reduce_alg
|
1432 |
+
|
1433 |
+
self._run_and_report_benchmark()
|
1434 |
+
|
1435 |
+
@benchmark_wrappers.enable_runtime_flags
|
1436 |
+
def _run_and_report_benchmark(self,
|
1437 |
+
top_1_min=MIN_TOP_1_ACCURACY,
|
1438 |
+
top_1_max=MAX_TOP_1_ACCURACY):
|
1439 |
+
start_time_sec = time.time()
|
1440 |
+
stats = classifier_trainer.run(flags.FLAGS)
|
1441 |
+
wall_time_sec = time.time() - start_time_sec
|
1442 |
+
|
1443 |
+
super(Resnet50MultiWorkerKerasAccuracy, self)._report_benchmark(
|
1444 |
+
stats,
|
1445 |
+
wall_time_sec,
|
1446 |
+
top_1_min=top_1_min,
|
1447 |
+
top_1_max=top_1_max,
|
1448 |
+
total_batch_size=FLAGS.batch_size,
|
1449 |
+
log_steps=100)
|
1450 |
+
|
1451 |
+
def _get_model_dir(self, folder_name):
|
1452 |
+
return os.path.join(self.output_dir, folder_name)
|
1453 |
+
|
1454 |
+
def benchmark_eager_8_gpu_2_workers_fp16_ring_tweaked(self):
|
1455 |
+
"""Eager, 8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
|
1456 |
+
self._benchmark_common(eager=True, num_workers=2, all_reduce_alg='ring')
|
1457 |
+
|
1458 |
+
def benchmark_eager_8_gpu_2_workers_fp16_nccl_tweaked(self):
|
1459 |
+
"""Eager, 8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
|
1460 |
+
self._benchmark_common(eager=True, num_workers=2, all_reduce_alg='nccl')
|
1461 |
+
|
1462 |
+
def benchmark_eager_8_gpu_8_workers_fp16_ring_tweaked(self):
|
1463 |
+
"""Eager, 8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
|
1464 |
+
self._benchmark_common(eager=True, num_workers=8, all_reduce_alg='ring')
|
1465 |
+
|
1466 |
+
def benchmark_eager_8_gpu_8_workers_fp16_nccl_tweaked(self):
|
1467 |
+
"""Eager, 8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
|
1468 |
+
self._benchmark_common(eager=True, num_workers=8, all_reduce_alg='nccl')
|
1469 |
+
|
1470 |
+
|
1471 |
+
class Resnet50MultiWorkerKerasBenchmark(Resnet50KerasBenchmarkBase):
|
1472 |
+
"""Resnet50 distributed benchmark tests with multiple workers."""
|
1473 |
+
|
1474 |
+
def __init__(self, output_dir=None, default_flags=None):
|
1475 |
+
super(Resnet50MultiWorkerKerasBenchmark, self).__init__(
|
1476 |
+
output_dir=output_dir, default_flags=default_flags)
|
1477 |
+
|
1478 |
+
def _benchmark_common(self, eager, num_workers, all_reduce_alg):
|
1479 |
+
"""Common to all benchmarks in this class."""
|
1480 |
+
self._setup()
|
1481 |
+
|
1482 |
+
num_gpus = 8
|
1483 |
+
FLAGS.num_gpus = num_gpus
|
1484 |
+
FLAGS.dtype = 'fp16'
|
1485 |
+
FLAGS.enable_eager = eager
|
1486 |
+
FLAGS.enable_xla = False
|
1487 |
+
FLAGS.distribution_strategy = 'multi_worker_mirrored'
|
1488 |
+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
|
1489 |
+
FLAGS.datasets_num_private_threads = 32
|
1490 |
+
FLAGS.model_dir = self._get_model_dir(
|
1491 |
+
'benchmark_{}_8_gpu_{}_worker_fp16_{}_tweaked'.format(
|
1492 |
+
'eager' if eager else 'graph', num_workers, all_reduce_alg))
|
1493 |
+
FLAGS.batch_size = 256 * num_gpus * num_workers
|
1494 |
+
FLAGS.all_reduce_alg = all_reduce_alg
|
1495 |
+
|
1496 |
+
self._run_and_report_benchmark()
|
1497 |
+
|
1498 |
+
def benchmark_eager_8_gpu_1_worker_fp16_ring_tweaked(self):
|
1499 |
+
"""Eager, 8 GPUs per worker, 1 worker, fp16, ring all-reduce."""
|
1500 |
+
self._benchmark_common(eager=True, num_workers=1, all_reduce_alg='ring')
|
1501 |
+
|
1502 |
+
def benchmark_eager_8_gpu_1_worker_fp16_nccl_tweaked(self):
|
1503 |
+
"""Eager, 8 GPUs per worker, 1 worker, fp16, nccl all-reduce."""
|
1504 |
+
self._benchmark_common(eager=True, num_workers=1, all_reduce_alg='nccl')
|
1505 |
+
|
1506 |
+
def benchmark_eager_8_gpu_2_workers_fp16_ring_tweaked(self):
|
1507 |
+
"""Eager, 8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
|
1508 |
+
self._benchmark_common(eager=True, num_workers=2, all_reduce_alg='ring')
|
1509 |
+
|
1510 |
+
def benchmark_eager_8_gpu_2_workers_fp16_nccl_tweaked(self):
|
1511 |
+
"""Eager, 8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
|
1512 |
+
self._benchmark_common(eager=True, num_workers=2, all_reduce_alg='nccl')
|
1513 |
+
|
1514 |
+
def benchmark_eager_8_gpu_8_workers_fp16_ring_tweaked(self):
|
1515 |
+
"""Eager, 8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
|
1516 |
+
self._benchmark_common(eager=True, num_workers=8, all_reduce_alg='ring')
|
1517 |
+
|
1518 |
+
def benchmark_eager_8_gpu_8_workers_fp16_nccl_tweaked(self):
|
1519 |
+
"""Eager, 8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
|
1520 |
+
self._benchmark_common(eager=True, num_workers=8, all_reduce_alg='nccl')
|
1521 |
+
|
1522 |
+
|
1523 |
+
class Resnet50MultiWorkerKerasBenchmarkSynth(Resnet50MultiWorkerKerasBenchmark):
|
1524 |
+
"""Resnet50 multi-worker synthetic data benchmark tests."""
|
1525 |
+
|
1526 |
+
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
1527 |
+
def_flags = {}
|
1528 |
+
def_flags['skip_eval'] = True
|
1529 |
+
def_flags['report_accuracy_metrics'] = False
|
1530 |
+
def_flags['use_synthetic_data'] = True
|
1531 |
+
def_flags['train_steps'] = 110
|
1532 |
+
def_flags['log_steps'] = 10
|
1533 |
+
|
1534 |
+
super(Resnet50MultiWorkerKerasBenchmarkSynth, self).__init__(
|
1535 |
+
output_dir=output_dir, default_flags=def_flags)
|
1536 |
+
|
1537 |
+
|
1538 |
+
class Resnet50MultiWorkerKerasBenchmarkReal(Resnet50MultiWorkerKerasBenchmark):
|
1539 |
+
"""Resnet50 multi-worker real data benchmark tests."""
|
1540 |
+
|
1541 |
+
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
|
1542 |
+
def_flags = {}
|
1543 |
+
def_flags['skip_eval'] = True
|
1544 |
+
def_flags['report_accuracy_metrics'] = False
|
1545 |
+
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
|
1546 |
+
def_flags['train_steps'] = 110
|
1547 |
+
def_flags['log_steps'] = 10
|
1548 |
+
|
1549 |
+
super(Resnet50MultiWorkerKerasBenchmarkReal, self).__init__(
|
1550 |
+
output_dir=output_dir, default_flags=def_flags)
|
1551 |
+
|
1552 |
+
|
1553 |
+
# TODO(kimjaehong): It also should be also cover other metheods of model
|
1554 |
+
# optimization techniques. In that time, this class will change to something
|
1555 |
+
# like 'KerasModelOptimizationAccuracyBase'.
|
1556 |
+
class KerasPruningAccuracyBase(keras_benchmark.KerasBenchmark):
|
1557 |
+
"""Benchmark accuracy tests for pruning method."""
|
1558 |
+
|
1559 |
+
def __init__(self,
|
1560 |
+
output_dir=None,
|
1561 |
+
root_data_dir=None,
|
1562 |
+
default_flags=None,
|
1563 |
+
**kwargs):
|
1564 |
+
"""A accuracy benchmark class for pruning method.
|
1565 |
+
|
1566 |
+
Args:
|
1567 |
+
output_dir: directory where to output e.g. log files
|
1568 |
+
root_data_dir: directory under which to look for dataset
|
1569 |
+
default_flags: default flags
|
1570 |
+
**kwargs: arbitrary named arguments. This is needed to make the
|
1571 |
+
constructor forward compatible in case PerfZero provides more
|
1572 |
+
named arguments before updating the constructor.
|
1573 |
+
"""
|
1574 |
+
if default_flags is None:
|
1575 |
+
default_flags = {}
|
1576 |
+
default_flags['pruning_method'] = 'polynomial_decay'
|
1577 |
+
default_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
|
1578 |
+
|
1579 |
+
flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
|
1580 |
+
|
1581 |
+
super(KerasPruningAccuracyBase, self).__init__(
|
1582 |
+
output_dir=output_dir,
|
1583 |
+
flag_methods=flag_methods,
|
1584 |
+
default_flags=default_flags,
|
1585 |
+
**kwargs)
|
1586 |
+
|
1587 |
+
def benchmark_8_gpu(self):
|
1588 |
+
"""Test Keras model with eager, dist_strat and 8 GPUs."""
|
1589 |
+
self._setup()
|
1590 |
+
FLAGS.num_gpus = 8
|
1591 |
+
FLAGS.batch_size = 32 * 8
|
1592 |
+
FLAGS.train_epochs = 90
|
1593 |
+
FLAGS.epochs_between_evals = 10
|
1594 |
+
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
|
1595 |
+
FLAGS.dtype = 'fp32'
|
1596 |
+
FLAGS.enable_eager = True
|
1597 |
+
self._run_and_report_benchmark()
|
1598 |
+
|
1599 |
+
@benchmark_wrappers.enable_runtime_flags
|
1600 |
+
def _run_and_report_benchmark(self,
|
1601 |
+
top_1_min=MODEL_OPTIMIZATION_TOP_1_ACCURACY[
|
1602 |
+
'RESNET50_FINETUNE_PRUNING'][0],
|
1603 |
+
top_1_max=MODEL_OPTIMIZATION_TOP_1_ACCURACY[
|
1604 |
+
'RESNET50_FINETUNE_PRUNING'][1]):
|
1605 |
+
start_time_sec = time.time()
|
1606 |
+
stats = resnet_imagenet_main.run(flags.FLAGS)
|
1607 |
+
wall_time_sec = time.time() - start_time_sec
|
1608 |
+
|
1609 |
+
super(KerasPruningAccuracyBase, self)._report_benchmark(
|
1610 |
+
stats,
|
1611 |
+
wall_time_sec,
|
1612 |
+
top_1_min=top_1_min,
|
1613 |
+
top_1_max=top_1_max,
|
1614 |
+
total_batch_size=FLAGS.batch_size,
|
1615 |
+
log_steps=100)
|
1616 |
+
|
1617 |
+
|
1618 |
+
class MobilenetV1KerasPruningAccuracy(KerasPruningAccuracyBase):
|
1619 |
+
"""Benchmark accuracy tests for MobilenetV1 with pruning method."""
|
1620 |
+
|
1621 |
+
def __init__(self, root_data_dir=None, **kwargs):
|
1622 |
+
default_flags = {
|
1623 |
+
'model': 'mobilenet',
|
1624 |
+
'optimizer': 'mobilenet_default',
|
1625 |
+
'initial_learning_rate_per_sample': 0.00007,
|
1626 |
+
'pretrained_filepath': tf.train.latest_checkpoint(
|
1627 |
+
os.path.join(root_data_dir, 'mobilenet_v1')),
|
1628 |
+
'pruning_begin_step': 0,
|
1629 |
+
'pruning_end_step': 100000,
|
1630 |
+
'pruning_initial_sparsity': 0.0,
|
1631 |
+
'pruning_final_sparsity': 0.5,
|
1632 |
+
'pruning_frequency': 100,
|
1633 |
+
}
|
1634 |
+
super(MobilenetV1KerasPruningAccuracy, self).__init__(
|
1635 |
+
root_data_dir=root_data_dir,
|
1636 |
+
default_flags=default_flags,
|
1637 |
+
**kwargs)
|
1638 |
+
|
1639 |
+
def _run_and_report_benchmark(self):
|
1640 |
+
super(MobilenetV1KerasPruningAccuracy, self)._run_and_report_benchmark(
|
1641 |
+
top_1_min=\
|
1642 |
+
MODEL_OPTIMIZATION_TOP_1_ACCURACY['MOBILENET_V1_FINETUNE_PRUNING'][0],
|
1643 |
+
top_1_max=\
|
1644 |
+
MODEL_OPTIMIZATION_TOP_1_ACCURACY['MOBILENET_V1_FINETUNE_PRUNING'][1])
|
1645 |
+
|
1646 |
+
|
1647 |
+
class Resnet50KerasPruningAccuracy(KerasPruningAccuracyBase):
|
1648 |
+
"""Benchmark accuracy tests for resnet50 with pruning method."""
|
1649 |
+
|
1650 |
+
def __init__(self, root_data_dir=None, **kwargs):
|
1651 |
+
default_flags = {
|
1652 |
+
'model': 'resnet50_v1.5',
|
1653 |
+
'optimizer': 'mobilenet_default',
|
1654 |
+
'initial_learning_rate_per_sample': 0.0000039,
|
1655 |
+
'pretrained_filepath': tf.train.latest_checkpoint(
|
1656 |
+
os.path.join(root_data_dir, 'resnet50')),
|
1657 |
+
'pruning_begin_step': 0,
|
1658 |
+
'pruning_end_step': 50000,
|
1659 |
+
'pruning_initial_sparsity': 0.0,
|
1660 |
+
'pruning_final_sparsity': 0.5,
|
1661 |
+
'pruning_frequency': 100,
|
1662 |
+
}
|
1663 |
+
super(Resnet50KerasPruningAccuracy, self).__init__(
|
1664 |
+
root_data_dir=root_data_dir,
|
1665 |
+
default_flags=default_flags,
|
1666 |
+
**kwargs)
|
1667 |
+
|
1668 |
+
def _run_and_report_benchmark(self):
|
1669 |
+
super(Resnet50KerasPruningAccuracy, self)._run_and_report_benchmark(
|
1670 |
+
top_1_min=\
|
1671 |
+
MODEL_OPTIMIZATION_TOP_1_ACCURACY['RESNET50_FINETUNE_PRUNING'][0],
|
1672 |
+
top_1_max=\
|
1673 |
+
MODEL_OPTIMIZATION_TOP_1_ACCURACY['RESNET50_FINETUNE_PRUNING'][1])
|
1674 |
+
|
1675 |
+
|
1676 |
+
class KerasPruningBenchmarkRealBase(Resnet50KerasBenchmarkBase):
|
1677 |
+
"""Pruning method benchmarks."""
|
1678 |
+
|
1679 |
+
def __init__(self, root_data_dir=None, default_flags=None, **kwargs):
|
1680 |
+
if default_flags is None:
|
1681 |
+
default_flags = {}
|
1682 |
+
default_flags.update({
|
1683 |
+
'skip_eval': True,
|
1684 |
+
'report_accuracy_metrics': False,
|
1685 |
+
'data_dir': os.path.join(root_data_dir, 'imagenet'),
|
1686 |
+
'train_steps': 110,
|
1687 |
+
'log_steps': 10,
|
1688 |
+
'pruning_method': 'polynomial_decay',
|
1689 |
+
'pruning_begin_step': 0,
|
1690 |
+
'pruning_end_step': 50000,
|
1691 |
+
'pruning_initial_sparsity': 0,
|
1692 |
+
'pruning_final_sparsity': 0.5,
|
1693 |
+
'pruning_frequency': 100,
|
1694 |
+
})
|
1695 |
+
super(KerasPruningBenchmarkRealBase, self).__init__(
|
1696 |
+
default_flags=default_flags, **kwargs)
|
1697 |
+
|
1698 |
+
|
1699 |
+
class MobilenetV1KerasPruningBenchmarkReal(KerasPruningBenchmarkRealBase):
|
1700 |
+
"""Pruning method benchmarks for MobilenetV1."""
|
1701 |
+
|
1702 |
+
def __init__(self, **kwargs):
|
1703 |
+
default_flags = {
|
1704 |
+
'model': 'mobilenet',
|
1705 |
+
'optimizer': 'mobilenet_default',
|
1706 |
+
}
|
1707 |
+
super(MobilenetV1KerasPruningBenchmarkReal, self).__init__(
|
1708 |
+
default_flags=default_flags, **kwargs)
|
1709 |
+
|
1710 |
+
|
1711 |
+
class Resnet50KerasPruningBenchmarkReal(KerasPruningBenchmarkRealBase):
|
1712 |
+
"""Pruning method benchmarks for resnet50."""
|
1713 |
+
|
1714 |
+
def __init__(self, **kwargs):
|
1715 |
+
default_flags = {
|
1716 |
+
'model': 'resnet50_v1.5',
|
1717 |
+
'optimizer': 'mobilenet_default',
|
1718 |
+
}
|
1719 |
+
super(Resnet50KerasPruningBenchmarkReal, self).__init__(
|
1720 |
+
default_flags=default_flags, **kwargs)
|
1721 |
+
|
1722 |
+
|
1723 |
+
if __name__ == '__main__':
|
1724 |
+
tf.test.main()
|
models/official/benchmark/models/__init__.py
ADDED
File without changes
|
models/official/benchmark/models/cifar_preprocessing.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Provides utilities to Cifar-10 dataset."""
|
16 |
+
|
17 |
+
from __future__ import absolute_import
|
18 |
+
from __future__ import division
|
19 |
+
from __future__ import print_function
|
20 |
+
|
21 |
+
import os
|
22 |
+
from absl import logging
|
23 |
+
import tensorflow as tf
|
24 |
+
|
25 |
+
from official.vision.image_classification.resnet import imagenet_preprocessing
|
26 |
+
|
27 |
+
HEIGHT = 32
|
28 |
+
WIDTH = 32
|
29 |
+
NUM_CHANNELS = 3
|
30 |
+
_DEFAULT_IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS
|
31 |
+
# The record is the image plus a one-byte label
|
32 |
+
_RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
|
33 |
+
|
34 |
+
# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
|
35 |
+
NUM_IMAGES = {
|
36 |
+
'train': 50000,
|
37 |
+
'validation': 10000,
|
38 |
+
}
|
39 |
+
_NUM_DATA_FILES = 5
|
40 |
+
NUM_CLASSES = 10
|
41 |
+
|
42 |
+
|
43 |
+
def parse_record(raw_record, is_training, dtype):
|
44 |
+
"""Parses a record containing a training example of an image.
|
45 |
+
|
46 |
+
The input record is parsed into a label and image, and the image is passed
|
47 |
+
through preprocessing steps (cropping, flipping, and so on).
|
48 |
+
|
49 |
+
This method converts the label to one hot to fit the loss function.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
raw_record: scalar Tensor tf.string containing a serialized
|
53 |
+
Example protocol buffer.
|
54 |
+
is_training: A boolean denoting whether the input is for training.
|
55 |
+
dtype: Data type to use for input images.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
Tuple with processed image tensor and one-hot-encoded label tensor.
|
59 |
+
"""
|
60 |
+
# Convert bytes to a vector of uint8 that is record_bytes long.
|
61 |
+
record_vector = tf.io.decode_raw(raw_record, tf.uint8)
|
62 |
+
|
63 |
+
# The first byte represents the label, which we convert from uint8 to int32
|
64 |
+
# and then to one-hot.
|
65 |
+
label = tf.cast(record_vector[0], tf.int32)
|
66 |
+
|
67 |
+
# The remaining bytes after the label represent the image, which we reshape
|
68 |
+
# from [depth * height * width] to [depth, height, width].
|
69 |
+
depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
|
70 |
+
[NUM_CHANNELS, HEIGHT, WIDTH])
|
71 |
+
|
72 |
+
# Convert from [depth, height, width] to [height, width, depth], and cast as
|
73 |
+
# float32.
|
74 |
+
image = tf.cast(tf.transpose(a=depth_major, perm=[1, 2, 0]), tf.float32)
|
75 |
+
|
76 |
+
image = preprocess_image(image, is_training)
|
77 |
+
image = tf.cast(image, dtype)
|
78 |
+
|
79 |
+
return image, label
|
80 |
+
|
81 |
+
|
82 |
+
def preprocess_image(image, is_training):
|
83 |
+
"""Preprocess a single image of layout [height, width, depth]."""
|
84 |
+
if is_training:
|
85 |
+
# Resize the image to add four extra pixels on each side.
|
86 |
+
image = tf.image.resize_with_crop_or_pad(
|
87 |
+
image, HEIGHT + 8, WIDTH + 8)
|
88 |
+
|
89 |
+
# Randomly crop a [HEIGHT, WIDTH] section of the image.
|
90 |
+
image = tf.image.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS])
|
91 |
+
|
92 |
+
# Randomly flip the image horizontally.
|
93 |
+
image = tf.image.random_flip_left_right(image)
|
94 |
+
|
95 |
+
# Subtract off the mean and divide by the variance of the pixels.
|
96 |
+
image = tf.image.per_image_standardization(image)
|
97 |
+
return image
|
98 |
+
|
99 |
+
|
100 |
+
def get_filenames(is_training, data_dir):
|
101 |
+
"""Returns a list of filenames."""
|
102 |
+
assert tf.io.gfile.exists(data_dir), (
|
103 |
+
'Run cifar10_download_and_extract.py first to download and extract the '
|
104 |
+
'CIFAR-10 data.')
|
105 |
+
|
106 |
+
if is_training:
|
107 |
+
return [
|
108 |
+
os.path.join(data_dir, 'data_batch_%d.bin' % i)
|
109 |
+
for i in range(1, _NUM_DATA_FILES + 1)
|
110 |
+
]
|
111 |
+
else:
|
112 |
+
return [os.path.join(data_dir, 'test_batch.bin')]
|
113 |
+
|
114 |
+
|
115 |
+
def input_fn(is_training,
|
116 |
+
data_dir,
|
117 |
+
batch_size,
|
118 |
+
dtype=tf.float32,
|
119 |
+
datasets_num_private_threads=None,
|
120 |
+
parse_record_fn=parse_record,
|
121 |
+
input_context=None,
|
122 |
+
drop_remainder=False):
|
123 |
+
"""Input function which provides batches for train or eval.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
is_training: A boolean denoting whether the input is for training.
|
127 |
+
data_dir: The directory containing the input data.
|
128 |
+
batch_size: The number of samples per batch.
|
129 |
+
dtype: Data type to use for images/features
|
130 |
+
datasets_num_private_threads: Number of private threads for tf.data.
|
131 |
+
parse_record_fn: Function to use for parsing the records.
|
132 |
+
input_context: A `tf.distribute.InputContext` object passed in by
|
133 |
+
`tf.distribute.Strategy`.
|
134 |
+
drop_remainder: A boolean indicates whether to drop the remainder of the
|
135 |
+
batches. If True, the batch dimension will be static.
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
A dataset that can be used for iteration.
|
139 |
+
"""
|
140 |
+
filenames = get_filenames(is_training, data_dir)
|
141 |
+
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
|
142 |
+
|
143 |
+
if input_context:
|
144 |
+
logging.info(
|
145 |
+
'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
|
146 |
+
input_context.input_pipeline_id, input_context.num_input_pipelines)
|
147 |
+
dataset = dataset.shard(input_context.num_input_pipelines,
|
148 |
+
input_context.input_pipeline_id)
|
149 |
+
|
150 |
+
return imagenet_preprocessing.process_record_dataset(
|
151 |
+
dataset=dataset,
|
152 |
+
is_training=is_training,
|
153 |
+
batch_size=batch_size,
|
154 |
+
shuffle_buffer=NUM_IMAGES['train'],
|
155 |
+
parse_record_fn=parse_record_fn,
|
156 |
+
dtype=dtype,
|
157 |
+
datasets_num_private_threads=datasets_num_private_threads,
|
158 |
+
drop_remainder=drop_remainder
|
159 |
+
)
|