NCTCMumbai commited on
Commit
18ddfe2
·
verified ·
1 Parent(s): ba6b919

Upload 2583 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. CTH_CODE_MAP.csv +0 -0
  3. CTH_Description.csv +0 -0
  4. CTH_WISE_DUTY_RATE.csv +0 -0
  5. Checkpoint/assets/vocab.txt +0 -0
  6. Checkpoint/keras_metadata.pb +3 -0
  7. Checkpoint/saved_model.pb +3 -0
  8. Checkpoint/variables/variables.data-00000-of-00001 +3 -0
  9. Checkpoint/variables/variables.index +0 -0
  10. ETC/fun_advaitbert.py +339 -0
  11. app.py +91 -0
  12. fun_advaitbert.py +344 -0
  13. models/.github/ISSUE_TEMPLATE/00-official-bug-report-issue.md +59 -0
  14. models/.github/ISSUE_TEMPLATE/10-official-documentation-issue.md +20 -0
  15. models/.github/ISSUE_TEMPLATE/20-official-feature-request-issue.md +26 -0
  16. models/.github/ISSUE_TEMPLATE/30-research-bug-report-issue.md +58 -0
  17. models/.github/ISSUE_TEMPLATE/40-research-documentation-issue.md +20 -0
  18. models/.github/ISSUE_TEMPLATE/50-research-feature-request-issue.md +26 -0
  19. models/.github/ISSUE_TEMPLATE/60-questions-help-issue.md +14 -0
  20. models/.github/ISSUE_TEMPLATE/config.yml +1 -0
  21. models/.github/PULL_REQUEST_TEMPLATE.md +41 -0
  22. models/.github/README_TEMPLATE.md +122 -0
  23. models/.gitignore +98 -0
  24. models/AUTHORS +10 -0
  25. models/CODEOWNERS +61 -0
  26. models/CONTRIBUTING.md +10 -0
  27. models/ISSUES.md +24 -0
  28. models/LICENSE +203 -0
  29. models/README.md +39 -0
  30. models/official/LICENSE +203 -0
  31. models/official/README-TPU.md +25 -0
  32. models/official/README.md +142 -0
  33. models/official/__init__.py +0 -0
  34. models/official/__pycache__/__init__.cpython-310.pyc +0 -0
  35. models/official/__pycache__/__init__.cpython-38.pyc +0 -0
  36. models/official/__pycache__/__init__.cpython-39.pyc +0 -0
  37. models/official/benchmark/__init__.py +0 -0
  38. models/official/benchmark/benchmark_wrappers.py +97 -0
  39. models/official/benchmark/bert_benchmark.py +365 -0
  40. models/official/benchmark/bert_benchmark_utils.py +127 -0
  41. models/official/benchmark/bert_pretrain_benchmark.py +179 -0
  42. models/official/benchmark/bert_squad_benchmark.py +608 -0
  43. models/official/benchmark/datastore/schema/benchmark_metric.json +56 -0
  44. models/official/benchmark/datastore/schema/benchmark_run.json +368 -0
  45. models/official/benchmark/datastore/schema/benchmark_run_status.json +14 -0
  46. models/official/benchmark/keras_benchmark.py +98 -0
  47. models/official/benchmark/keras_cifar_benchmark.py +402 -0
  48. models/official/benchmark/keras_imagenet_benchmark.py +1724 -0
  49. models/official/benchmark/models/__init__.py +0 -0
  50. models/official/benchmark/models/cifar_preprocessing.py +159 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Checkpoint/variables/variables.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
37
+ models/research/compression/image_encoder/example.png filter=lfs diff=lfs merge=lfs -text
38
+ models/research/deeplab/testing/pascal_voc_seg/val-00000-of-00001.tfrecord filter=lfs diff=lfs merge=lfs -text
39
+ models/research/lfads/synth_data/trained_itb/model-65000.meta filter=lfs diff=lfs merge=lfs -text
40
+ models/research/object_detection/g3doc/img/kites_with_segment_overlay.png filter=lfs diff=lfs merge=lfs -text
41
+ models/research/object_detection/test_images/image2.jpg filter=lfs diff=lfs merge=lfs -text
CTH_CODE_MAP.csv ADDED
The diff for this file is too large to render. See raw diff
 
CTH_Description.csv ADDED
The diff for this file is too large to render. See raw diff
 
CTH_WISE_DUTY_RATE.csv ADDED
The diff for this file is too large to render. See raw diff
 
Checkpoint/assets/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
Checkpoint/keras_metadata.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38bb0f1231a198848366566e176c8948dceab7b085b658d00550a83f784731f5
3
+ size 11535
Checkpoint/saved_model.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0054848283b4fb79fefcebe71830bedb75e023ad04c5655adbc6a2ddd1e2c60
3
+ size 11477628
Checkpoint/variables/variables.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4ec6776ca3161577663eaa115fb9f965304670a1af8db7a37e9499a23082e67
3
+ size 1389095096
Checkpoint/variables/variables.index ADDED
Binary file (46.6 kB). View file
 
ETC/fun_advaitbert.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import tensorflow_hub as hub
5
+ import sys
6
+ import random
7
+ sys.path.append('models')
8
+ from official.nlp.data import classifier_data_lib
9
+ from official.nlp.bert import tokenization
10
+ from official.nlp import optimization
11
+ tf.get_logger().setLevel('ERROR')
12
+ from huggingface_hub import InferenceClient
13
+ import math
14
+ import gradio as gr
15
+
16
+ num_warmup_steps=1
17
+ num_train_steps=1
18
+ init_lr = 3e-5
19
+ optimizer = optimization.create_optimizer(init_lr=init_lr,num_train_steps=num_train_steps,num_warmup_steps=num_warmup_steps,optimizer_type='adamw')
20
+
21
+ ### Load Model
22
+ checkpoint_filepath=r'./Checkpoint'
23
+ model = tf.keras.models.load_model(checkpoint_filepath, custom_objects={'KerasLayer':hub.KerasLayer , 'AdamWeightDecay': optimizer})
24
+
25
+ df_report = pd.read_csv('./CTH_Description.csv')
26
+ df_report['CTH Code'] = df_report['CTH Code'].astype(str).str.zfill(8)
27
+
28
+ df_report_DUTY = pd.read_csv('./CTH_WISE_DUTY_RATE.csv')
29
+ df_report_DUTY['CTH'] = df_report_DUTY['CTH'].astype(str).str.zfill(8)
30
+
31
+ df = pd.read_csv("./CTH_CODE_MAP.csv")
32
+ df['CTH'] = df['CTH'].astype(str).str.zfill(8)
33
+ df = df[['CTH', 'code']]
34
+
35
+ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
36
+
37
+
38
+
39
+ class_names=df[['CTH','code']].drop_duplicates(subset='CTH').sort_values(by='code',ignore_index=True)['CTH'].values.tolist()
40
+ label_list=list(range(0,len(class_names)))
41
+ max_seq_length = 200 # maximum length of (token) input sequences . it can be any number
42
+ train_batch_size = 32 # batch size ( 16 choosen to avoid Out-Of-Memory errors)
43
+
44
+ # Get BERT layer and tokenizer:
45
+ # More details here: https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4
46
+ bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4" , trainable = True)
47
+ vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
48
+ do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
49
+ tokenizer = tokenization.FullTokenizer(vocab_file , do_lower_case)
50
+
51
+ # This provides a function to convert each row to input features and label ( as required by BERT)
52
+
53
+ max_seq_length = 200 # maximum length of (token) input sequences . it can be any number
54
+ def to_feature(text, label, label_list=label_list, max_seq_length=max_seq_length, tokenizer=tokenizer):
55
+ example = classifier_data_lib.InputExample(guid = None,
56
+ text_a = text.numpy(),
57
+ text_b = None,
58
+ label = label.numpy())
59
+ feature = classifier_data_lib.convert_single_example(0 , example , label_list , max_seq_length , tokenizer)
60
+
61
+ return (feature.input_ids , feature.input_mask , feature.segment_ids , feature.label_id)
62
+
63
+
64
+ def to_feature_map(text, label):
65
+ input_ids , input_mask , segment_ids , label_id = tf.py_function(to_feature , inp = [text , label],
66
+ Tout = [tf.int32 , tf.int32 , tf.int32 , tf.int32])
67
+
68
+ input_ids.set_shape([max_seq_length])
69
+ input_mask.set_shape([max_seq_length])
70
+ segment_ids.set_shape([max_seq_length])
71
+ label_id.set_shape([])
72
+
73
+ x = {
74
+ "input_word_ids": input_ids,
75
+ "input_mask": input_mask,
76
+ "input_type_ids": segment_ids
77
+ }
78
+
79
+ return(x,label_id)
80
+
81
+ def print3largest(arr, arr_size):
82
+ third = first = second = -sys.maxsize
83
+ for i in range(0, arr_size):
84
+
85
+ if (arr[i] > first):
86
+ third = second
87
+ second = first
88
+ first = arr[i]
89
+ elif (arr[i] > second):
90
+ third = second
91
+ second = arr[i]
92
+ elif (arr[i] > third):
93
+ third = arr[i]
94
+ pred_value_max_three=[first, second, third]
95
+ return pred_value_max_three
96
+
97
+ def count_special_character(string):
98
+ special_char= 0
99
+ for i in range(len(string)):
100
+ ch = string[i]
101
+ if (string[i].isalpha()):
102
+ continue
103
+ else:
104
+ special_char += 1
105
+
106
+ if len(string)==special_char:
107
+ return False
108
+ else:
109
+ return True
110
+
111
+ def format_prompt(message, history):
112
+ prompt = "<s>"
113
+ for user_prompt, bot_response in history:
114
+ prompt += f"[INST] {user_prompt} [/INST]"
115
+ prompt += f" {bot_response}</s> "
116
+ prompt += f"[INST] {message} [/INST]"
117
+ return prompt
118
+
119
+
120
+ additional_inputs=[
121
+ gr.Textbox(
122
+ label="System Prompt",
123
+ max_lines=1,
124
+ interactive=True,
125
+ ),
126
+ gr.Slider(
127
+ label="Temperature",
128
+ value=0.9,
129
+ minimum=0.0,
130
+ maximum=1.0,
131
+ step=0.05,
132
+ interactive=True,
133
+ info="Higher values produce more diverse outputs",
134
+ ),
135
+ gr.Slider(
136
+ label="Max new tokens",
137
+ value=1024,
138
+ minimum=0,
139
+ maximum=4096,
140
+ step=64,
141
+ interactive=True,
142
+ info="The maximum numbers of new tokens",
143
+ ),
144
+ gr.Slider(
145
+ label="Top-p (nucleus sampling)",
146
+ value=0.90,
147
+ minimum=0.0,
148
+ maximum=1,
149
+ step=0.05,
150
+ interactive=True,
151
+ info="Higher values sample more low-probability tokens",
152
+ ),
153
+ gr.Slider(
154
+ label="Repetition penalty",
155
+ value=1.2,
156
+ minimum=1.0,
157
+ maximum=2.0,
158
+ step=0.05,
159
+ interactive=True,
160
+ info="Penalize repeated tokens",
161
+ )
162
+ ]
163
+
164
+ def predict_CTH(txt):
165
+ print('Desc: ',txt)
166
+ if (txt!='') and len(txt)>=3 and (count_special_character(txt)):
167
+ valid_data = tf.data.Dataset.from_tensor_slices(([txt] , [1])) # 1 refers to 'entertainment' and 2 refers to 'sport'
168
+ valid_data = (valid_data.map(to_feature_map).batch(1))
169
+ preds = model.predict(valid_data)
170
+ predicted_values = tf.nn.softmax(preds)
171
+ arr = predicted_values.numpy().tolist()[0]
172
+ n = len(arr)
173
+ pred_value_max_three=print3largest(arr, n)
174
+
175
+ sum_all = pred_value_max_three[0] + pred_value_max_three[1] + pred_value_max_three[2]
176
+
177
+ val_1 = pred_value_max_three[0]/sum_all
178
+ val_2 = pred_value_max_three[1]/sum_all
179
+ val_3 = pred_value_max_three[2]/sum_all
180
+
181
+ if pred_value_max_three[0]<=0.000131:
182
+ Var_CTH=[]
183
+ Var_desc=[]
184
+ Var_duty=[]
185
+ pred_duty=''
186
+ pred_desc=''
187
+ pred_CTH=''
188
+
189
+ return{'Not a adequate description':float(1.0)}
190
+ else:
191
+ Var_CTH=[]
192
+ Var_desc=[]
193
+ Var_duty=[]
194
+ pred_duty=''
195
+ pred_desc=''
196
+ pred_CTH=''
197
+
198
+ for i in pred_value_max_three:
199
+ predicted_code=np.where(predicted_values.numpy()==i)[1][0]
200
+ pred_CTH=df[df['code'] == predicted_code]['CTH'].iloc[0]
201
+
202
+ try:
203
+ pred_duty=df_report_DUTY[df_report_DUTY['CTH']==str(pred_CTH)]['DUTY_RATE'].iloc[0]
204
+ pred_desc=df_report[df_report['CTH Code']==str(pred_CTH)]['Concat Description'].iloc[0]
205
+ except:
206
+ pass
207
+
208
+ Var_CTH.append(pred_CTH)
209
+ Var_desc.append(pred_desc)
210
+ Var_duty.append(pred_duty)
211
+
212
+ P1 ='CTH: '+str(Var_CTH[0])+' Duty Rate(%): '+ str(Var_duty[0])
213
+ P2 ='CTH: '+str(Var_CTH[1])+' Duty Rate(%): '+ str(Var_duty[1])
214
+ P3 ='CTH: '+str(Var_CTH[2])+' Duty Rate(%): '+ str(Var_duty[2])
215
+
216
+ Q1='Desc: '+str(Var_desc[0])
217
+ Q2='Desc: '+str(Var_desc[1])
218
+ Q3='Desc: '+str(Var_desc[2])
219
+
220
+ return {str(P1):float(val_1),str(Q1):float(val_1),
221
+ str(P2):float(val_2),str(Q2):float(val_2),
222
+ str(P3):float(val_3),str(Q3):float(val_3),}
223
+ else:
224
+ return{'Enter Correct Description':float(1.0)}
225
+
226
+ def llm_model_function(txt,history,chatbot=[], temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,):
227
+ system_prompt=[]
228
+ if (txt!='') and len(txt)>=3 and (count_special_character(txt)):
229
+ valid_data = tf.data.Dataset.from_tensor_slices(([txt] , [1])) # 1 refers to 'entertainment' and 2 refers to 'sport'
230
+ valid_data = (valid_data.map(to_feature_map).batch(1))
231
+ preds = model.predict(valid_data)
232
+ predicted_values = tf.nn.softmax(preds)
233
+ arr = predicted_values.numpy().tolist()[0]
234
+ n = len(arr)
235
+ pred_value_max_three=print3largest(arr, n)
236
+
237
+ sum_all = pred_value_max_three[0] + pred_value_max_three[1] + pred_value_max_three[2]
238
+
239
+ val_1 = pred_value_max_three[0]/sum_all
240
+ val_2 = pred_value_max_three[1]/sum_all
241
+ val_3 = pred_value_max_three[2]/sum_all
242
+
243
+ if pred_value_max_three[0]<=0.000131:
244
+ Var_CTH=[]
245
+ Var_desc=[]
246
+ Var_duty=[]
247
+ pred_duty=''
248
+ pred_desc=''
249
+ pred_CTH=''
250
+
251
+ return{'Not a adequate description':float(1.0)}
252
+ else:
253
+ Var_CTH=[]
254
+ Var_desc=[]
255
+ Var_duty=[]
256
+ pred_duty=''
257
+ pred_desc=''
258
+ pred_CTH=''
259
+
260
+ for i in pred_value_max_three:
261
+ predicted_code=np.where(predicted_values.numpy()==i)[1][0]
262
+ pred_CTH=df[df['code'] == predicted_code]['CTH'].iloc[0]
263
+
264
+ try:
265
+ pred_duty=df_report_DUTY[df_report_DUTY['CTH']==str(pred_CTH)]['DUTY_RATE'].iloc[0]
266
+ pred_desc=df_report[df_report['CTH Code']==str(pred_CTH)]['Concat Description'].iloc[0]
267
+ except:
268
+ pass
269
+
270
+ Var_CTH.append(pred_CTH)
271
+ Var_desc.append(pred_desc)
272
+ Var_duty.append(pred_duty)
273
+
274
+ P1 ='CTH: '+str(Var_CTH[0])+' Duty Rate(%): '+ str(Var_duty[0])
275
+ P2 ='CTH: '+str(Var_CTH[1])+' Duty Rate(%): '+ str(Var_duty[1])
276
+ P3 ='CTH: '+str(Var_CTH[2])+' Duty Rate(%): '+ str(Var_duty[2])
277
+
278
+ Q1='Desc: '+str(Var_desc[0])
279
+ Q2='Desc: '+str(Var_desc[1])
280
+ Q3='Desc: '+str(Var_desc[2])
281
+
282
+ output_str_msg='1. '+str(P1)+' '+str(Q1)+' '+'2. '+str(P2)+' '+str(Q2)+' '+'3. '+str(P3)+' '+str(Q3)
283
+
284
+ prompt=f'First Explain What is the product- {txt}. Which is the most appropriate 8 Digit classification code out of the three given below classes. Explain the reason step by step. if none of the three classification is applicable more precisely due to lack of any additional information, tell you need additional information and what is the that additional information. {output_str_msg} ?'
285
+
286
+ temperature = float(temperature)
287
+ if temperature < 1e-2:
288
+ temperature = 1e-2
289
+ top_p = float(top_p)
290
+
291
+ generate_kwargs = dict(
292
+ temperature=temperature,
293
+ max_new_tokens=max_new_tokens,
294
+ top_p=top_p,
295
+ repetition_penalty=repetition_penalty,
296
+ do_sample=True,
297
+ seed=42,
298
+ )
299
+
300
+ formatted_prompt = format_prompt(f", {prompt}", history)
301
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
302
+ output = ""
303
+ for response in stream:
304
+ output += response.token.text
305
+
306
+ chatbot.append((txt, output))
307
+ return "", chatbot
308
+ else:
309
+ warning_msg = f"Unexpected response"
310
+ raise gr.Error(warning_msg)
311
+
312
+ def product_explaination(txt,history,chatbot=[], temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,):
313
+ print('Input Descrption is:',txt)
314
+ prompt=f'What is the product- {txt}?'
315
+ print('prompt',prompt)
316
+ temperature = float(temperature)
317
+ if temperature < 1e-2:
318
+ temperature = 1e-2
319
+ top_p = float(top_p)
320
+
321
+ generate_kwargs = dict(
322
+ temperature=temperature,
323
+ max_new_tokens=max_new_tokens,
324
+ top_p=top_p,
325
+ repetition_penalty=repetition_penalty,
326
+ do_sample=True,
327
+ seed=42,
328
+ )
329
+
330
+ formatted_prompt = format_prompt(f", {prompt}", history)
331
+
332
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
333
+ output = ""
334
+
335
+ for response in stream:
336
+ output += response.token.text
337
+
338
+ chatbot.append((txt, output))
339
+ return "", chatbot
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from fun_advaitbert import predict_CTH
4
+ from fun_advaitbert import llm_model_function
5
+ from fun_advaitbert import product_explaination
6
+
7
+ title="<h1 style='color:green;text-align:center;font-size:2vw;'>AdvaitBERT:HS Code AI Explanability Through Mixtral 46.7B </a></h1>"
8
+ description = """
9
+ AdvaitBERT is modified version of BERT (Bidirectional Encoder Representation for Transformers), \
10
+ finetuned on the Text corpus of Indian Customs Declarations. It is trained for performing \
11
+ downstream tasks like automating the tariff classification and validation process of Customs \
12
+ declarations in realtime. This model may help Customs administration to efficiently use AI assisted \
13
+ NLP in realtime Customs process like Assessment, Post Clearance Audit, thereby highlighting classification \
14
+ inconsistencies and help in revenue augmentation.
15
+ """
16
+
17
+ article="<p style='color:black;text-align:right;font-size:1vw;'>Powered by NCTC </a></p>"
18
+
19
+
20
+ css = """
21
+ .gradio-container {
22
+ width: 100vw !important;
23
+ min-height: 100vh !important;
24
+ padding:0 !important;
25
+ margin:0 !important;
26
+ max-width: none !important;
27
+ }
28
+ """
29
+
30
+ footnote = """Note: All rights, including licensing and acceptable use policies, related to the AI models, can be found on their respective model pages on Hugging Face. Powered by NCTC
31
+ """
32
+
33
+ #Powered by NCTC
34
+
35
+ # input_txt=gr.Textbox(label='Enter Your Product Descrption',lines=3,)
36
+ # textbox = gr.Textbox(container=False,placeholder='Enter text and click the Submit button or press Enter')
37
+
38
+ textbox = gr.Textbox(label='Enter Your Product Descrption',lines=3,)
39
+ textbox_2=textbox
40
+
41
+ print('textbox',textbox)
42
+ print('textbox_2',textbox_2)
43
+
44
+ chat_prod = gr.Chatbot(label="Product Explanation", layout='panel') #height=300
45
+ #chat_Advait = gr.Chatbot(label="Advaitbert Prediction", layout='panel')
46
+ chat_alpha = gr.Chatbot(label="AI Explanability", layout='panel')
47
+ chat_Advait=gr.Interface(predict_CTH,inputs=textbox,outputs="label",)
48
+
49
+
50
+ submit = gr.Button('Submit', variant='primary',)
51
+ submit_second = gr.Button('Submit', variant='secondary',)
52
+ #submit2 = gr.Button('Submit', variant='primary',)
53
+ retry = gr.Button('🔄Retry', variant='secondary')
54
+ undo = gr.Button('↩️Undo', variant='secondary')
55
+
56
+ with gr.Blocks(css=css) as demo:
57
+ gr.HTML(f'<h1><center> {title} </center></h1>')
58
+ gr.Markdown(description)
59
+
60
+ with gr.Row():
61
+ with gr.Column(scale=0,min_width=600):
62
+ chat_Advait.render()
63
+
64
+ with gr.Column(scale=1,min_width=600):
65
+ chat_alpha.render()
66
+ with gr.Row(equal_height=True):
67
+ with gr.Column(scale=1):
68
+ submit.render()
69
+ with gr.Column(scale=1):
70
+ undo.render()
71
+ with gr.Column(scale=1):
72
+ clear = gr.ClearButton(value='🗑️Clear',components=[chat_alpha,chat_prod,textbox])
73
+ chat_prod.render()
74
+ #submit_second.render()
75
+
76
+ gr.Markdown(footnote)
77
+ textbox.submit(llm_model_function, [textbox, chat_alpha], [textbox, chat_alpha])
78
+ textbox_2.submit(product_explaination, [textbox_2, chat_prod], [textbox_2, chat_prod])
79
+
80
+ submit.click(llm_model_function,[textbox, chat_alpha], [textbox, chat_alpha])
81
+ submit.click(product_explaination,[textbox_2, chat_prod], [textbox_2, chat_prod])
82
+
83
+ undo.click(lambda x:x[:-1], [chat_alpha], [chat_alpha])
84
+ undo.click(lambda x:x[:-1], [chat_prod], [chat_prod])
85
+
86
+ gr.Examples([
87
+ ['200 SI/SI/SI LPO ALUMINIUM LIDS (QTY: 8820000 PCS/PRICE: 21.'],
88
+ ],
89
+ textbox)
90
+
91
+ demo.launch(debug=True)
fun_advaitbert.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import tensorflow_hub as hub
5
+ import sys
6
+ import random
7
+ sys.path.append('models')
8
+ from official.nlp.data import classifier_data_lib
9
+ from official.nlp.bert import tokenization
10
+ from official.nlp import optimization
11
+ tf.get_logger().setLevel('ERROR')
12
+ from huggingface_hub import InferenceClient
13
+ import math
14
+ import gradio as gr
15
+
16
+ num_warmup_steps=1
17
+ num_train_steps=1
18
+ init_lr = 3e-5
19
+ optimizer = optimization.create_optimizer(init_lr=init_lr,num_train_steps=num_train_steps,num_warmup_steps=num_warmup_steps,optimizer_type='adamw')
20
+
21
+ ### Load Model
22
+ checkpoint_filepath=r'./Checkpoint'
23
+ model = tf.keras.models.load_model(checkpoint_filepath, custom_objects={'KerasLayer':hub.KerasLayer , 'AdamWeightDecay': optimizer})
24
+
25
+ df_report = pd.read_csv('./CTH_Description.csv')
26
+ df_report['CTH Code'] = df_report['CTH Code'].astype(str).str.zfill(8)
27
+
28
+ df_report_DUTY = pd.read_csv('./CTH_WISE_DUTY_RATE.csv')
29
+ df_report_DUTY['CTH'] = df_report_DUTY['CTH'].astype(str).str.zfill(8)
30
+
31
+ df = pd.read_csv("./CTH_CODE_MAP.csv")
32
+ df['CTH'] = df['CTH'].astype(str).str.zfill(8)
33
+ df = df[['CTH', 'code']]
34
+
35
+ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
36
+
37
+
38
+ class_names=df[['CTH','code']].drop_duplicates(subset='CTH').sort_values(by='code',ignore_index=True)['CTH'].values.tolist()
39
+ label_list=list(range(0,len(class_names)))
40
+ max_seq_length = 200 # maximum length of (token) input sequences . it can be any number
41
+ train_batch_size = 32 # batch size ( 16 choosen to avoid Out-Of-Memory errors)
42
+
43
+ # Get BERT layer and tokenizer:
44
+ # More details here: https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4
45
+ bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4" , trainable = True)
46
+ vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
47
+ do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
48
+ tokenizer = tokenization.FullTokenizer(vocab_file , do_lower_case)
49
+
50
+ # This provides a function to convert each row to input features and label ( as required by BERT)
51
+
52
+ max_seq_length = 200 # maximum length of (token) input sequences . it can be any number
53
+ def to_feature(text, label, label_list=label_list, max_seq_length=max_seq_length, tokenizer=tokenizer):
54
+ example = classifier_data_lib.InputExample(guid = None,
55
+ text_a = text.numpy(),
56
+ text_b = None,
57
+ label = label.numpy())
58
+ feature = classifier_data_lib.convert_single_example(0 , example , label_list , max_seq_length , tokenizer)
59
+
60
+ return (feature.input_ids , feature.input_mask , feature.segment_ids , feature.label_id)
61
+
62
+
63
+ def to_feature_map(text, label):
64
+ input_ids , input_mask , segment_ids , label_id = tf.py_function(to_feature , inp = [text , label],
65
+ Tout = [tf.int32 , tf.int32 , tf.int32 , tf.int32])
66
+
67
+ input_ids.set_shape([max_seq_length])
68
+ input_mask.set_shape([max_seq_length])
69
+ segment_ids.set_shape([max_seq_length])
70
+ label_id.set_shape([])
71
+
72
+ x = {
73
+ "input_word_ids": input_ids,
74
+ "input_mask": input_mask,
75
+ "input_type_ids": segment_ids
76
+ }
77
+
78
+ return(x,label_id)
79
+
80
+
81
+ def find_max_10_with_position(arr, arr_size):
82
+ max_values_with_position = [(-sys.maxsize, -1)] * 10
83
+
84
+ for i in range(arr_size):
85
+ for j in range(5):
86
+ value, position = max_values_with_position[j]
87
+ if arr[i] > value:
88
+ max_values_with_position[j+1:] = max_values_with_position[j:9]
89
+ max_values_with_position[j] = (arr[i], i)
90
+ break
91
+
92
+ return max_values_with_position
93
+
94
+ def count_special_character(string):
95
+ special_char= 0
96
+ for i in range(len(string)):
97
+ ch = string[i]
98
+ if (string[i].isalpha()):
99
+ continue
100
+ else:
101
+ special_char += 1
102
+
103
+ if len(string)==special_char:
104
+ return False
105
+ else:
106
+ return True
107
+
108
+ def format_prompt(message, history):
109
+ prompt = "<s>"
110
+ for user_prompt, bot_response in history:
111
+ prompt += f"[INST] {user_prompt} [/INST]"
112
+ prompt += f" {bot_response}</s> "
113
+ prompt += f"[INST] {message} [/INST]"
114
+ return prompt
115
+
116
+
117
+ additional_inputs=[
118
+ gr.Textbox(
119
+ label="System Prompt",
120
+ max_lines=1,
121
+ interactive=True,
122
+ ),
123
+ gr.Slider(
124
+ label="Temperature",
125
+ value=0.5,
126
+ minimum=0.0,
127
+ maximum=1.0,
128
+ step=0.05,
129
+ interactive=True,
130
+ info="Higher values produce more diverse outputs",
131
+ ),
132
+ gr.Slider(
133
+ label="Max new tokens",
134
+ value=1024,
135
+ minimum=0,
136
+ maximum=4096,
137
+ step=64,
138
+ interactive=True,
139
+ info="The maximum numbers of new tokens",
140
+ ),
141
+ gr.Slider(
142
+ label="Top-p (nucleus sampling)",
143
+ value=0.90,
144
+ minimum=0.0,
145
+ maximum=1,
146
+ step=0.05,
147
+ interactive=True,
148
+ info="Higher values sample more low-probability tokens",
149
+ ),
150
+ gr.Slider(
151
+ label="Repetition penalty",
152
+ value=1.2,
153
+ minimum=1.0,
154
+ maximum=2.0,
155
+ step=0.05,
156
+ interactive=True,
157
+ info="Penalize repeated tokens",
158
+ )
159
+ ]
160
+
161
+ def predict_CTH(txt):
162
+ print('Desc: ',txt)
163
+ global output_str_msg
164
+ if (txt!='') and len(txt)>=3 and (count_special_character(txt)):
165
+ valid_data = tf.data.Dataset.from_tensor_slices(([txt] , [1])) # 1 refers to 'entertainment' and 2 refers to 'sport'
166
+ valid_data = (valid_data.map(to_feature_map).batch(1))
167
+ preds = model.predict(valid_data)
168
+ predicted_values = tf.nn.softmax(preds)
169
+ arr = predicted_values.numpy().tolist()[0]
170
+ n = len(arr)
171
+
172
+ pred_value_max=find_max_10_with_position(arr, n)
173
+
174
+ sum_all = 0
175
+ for i in range(10):
176
+ sum_all += pred_value_max[i][0]
177
+
178
+
179
+ val_1 = pred_value_max[0][0]/sum_all
180
+ val_2 = pred_value_max[1][0]/sum_all
181
+ val_3 = pred_value_max[2][0]/sum_all
182
+ val_4 = pred_value_max[3][0]/sum_all
183
+ val_5 = pred_value_max[4][0]/sum_all
184
+ val_6 = pred_value_max[5][0]/sum_all
185
+ val_7 = pred_value_max[6][0]/sum_all
186
+ val_8 = pred_value_max[7][0]/sum_all
187
+ val_9 = pred_value_max[8][0]/sum_all
188
+ val_10 = pred_value_max[9][0]/sum_all
189
+
190
+ if pred_value_max[0][0]<=0.000131:
191
+ Var_CTH=[]
192
+ Var_desc=[]
193
+ Var_duty=[]
194
+ pred_duty=''
195
+ pred_desc=''
196
+ pred_CTH=''
197
+
198
+ output_str_msg='Not a adequate description'
199
+
200
+ return{'Not a adequate description':float(1.0)}
201
+ else:
202
+ Var_CTH=[]
203
+ Var_desc=[]
204
+ Var_duty=[]
205
+ pred_duty=''
206
+ pred_desc=''
207
+ pred_CTH=''
208
+
209
+ for i in range(len(pred_value_max)):
210
+ #predicted_code=np.where(predicted_values.numpy()==i)[1][0]
211
+ predicted_code=pred_value_max[i][1]
212
+ pred_CTH=df[df['code'] == predicted_code]['CTH'].iloc[0]
213
+
214
+ try:
215
+ pred_duty=df_report_DUTY[df_report_DUTY['CTH']==str(pred_CTH)]['DUTY_RATE'].iloc[0]
216
+ pred_desc=df_report[df_report['CTH Code']==str(pred_CTH)]['Concat Description'].iloc[0]
217
+ except:
218
+ pred_desc=''
219
+ pred_duty=''
220
+ pass
221
+
222
+ Var_CTH.append(pred_CTH)
223
+ Var_desc.append(pred_desc)
224
+ Var_duty.append(pred_duty)
225
+
226
+ P1 ='CTH: '+str(Var_CTH[0])+' Duty Rate(%): '+ str(Var_duty[0])
227
+ P2 ='CTH: '+str(Var_CTH[1])+' Duty Rate(%): '+ str(Var_duty[1])
228
+ P3 ='CTH: '+str(Var_CTH[2])+' Duty Rate(%): '+ str(Var_duty[2])
229
+ P4 ='CTH: '+str(Var_CTH[3])+' Duty Rate(%): '+ str(Var_duty[3])
230
+ P5 ='CTH: '+str(Var_CTH[4])+' Duty Rate(%): '+ str(Var_duty[4])
231
+ P6 ='CTH: '+str(Var_CTH[5])+' Duty Rate(%): '+ str(Var_duty[5])
232
+ P7 ='CTH: '+str(Var_CTH[6])+' Duty Rate(%): '+ str(Var_duty[6])
233
+ P8 ='CTH: '+str(Var_CTH[7])+' Duty Rate(%): '+ str(Var_duty[7])
234
+ P9 ='CTH: '+str(Var_CTH[8])+' Duty Rate(%): '+ str(Var_duty[8])
235
+ P10 ='CTH: '+str(Var_CTH[9])+' Duty Rate(%): '+ str(Var_duty[9])
236
+
237
+ Q1='Desc: '+str(Var_desc[0])
238
+ Q2='Desc: '+str(Var_desc[1])
239
+ Q3='Desc: '+str(Var_desc[2])
240
+ Q4='Desc: '+str(Var_desc[3])
241
+ Q5='Desc: '+str(Var_desc[4])
242
+ Q6='Desc: '+str(Var_desc[5])
243
+ Q7='Desc: '+str(Var_desc[6])
244
+ Q8='Desc: '+str(Var_desc[7])
245
+ Q9='Desc: '+str(Var_desc[8])
246
+ Q10='Desc: '+str(Var_desc[9])
247
+
248
+ output_str_msg = (
249
+ f'1. {P1} {Q1} '
250
+ f'2. {P2} {Q2} '
251
+ f'3. {P3} {Q3} '
252
+ f'4. {P4} {Q4} '
253
+ f'5. {P5} {Q5} '
254
+ f'6. {P6} {Q6} '
255
+ f'7. {P7} {Q7} '
256
+ f'8. {P8} {Q8} '
257
+ f'9. {P9} {Q9} '
258
+ f'10. {P10} {Q10}')
259
+
260
+ print('output_str_msg',output_str_msg)
261
+
262
+ return {str(P1):float(val_1),str(Q1):float(val_1),
263
+ str(P2):float(val_2),str(Q2):float(val_2),
264
+ str(P3):float(val_3),str(Q3):float(val_3),
265
+ str(P4):float(val_4),str(Q4):float(val_4),
266
+ str(P5):float(val_5),str(Q5):float(val_5),
267
+ str(P6):float(val_6),str(Q6):float(val_6),
268
+ str(P7):float(val_7),str(Q7):float(val_7),
269
+ str(P8):float(val_8),str(Q8):float(val_8),
270
+ str(P9):float(val_9),str(Q9):float(val_9),
271
+ str(P10):float(val_10),str(Q10):float(val_10),}
272
+ else:
273
+ output_str_msg='Not a adequate description'
274
+ return{'Enter Correct Description':float(1.0)}
275
+
276
+ def llm_model_function(txt,history,chatbot=[], temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,):
277
+ system_prompt=[]
278
+ chatbot=[]
279
+
280
+ global output_str_msg
281
+
282
+ print('output_str_msg',output_str_msg)
283
+
284
+ if output_str_msg!='Not a adequate description':
285
+
286
+ prompt=f'First Explain What is the product- {txt}. Which is the most appropriate 8 Digit classification code out of the three given below classes. Explain the reason step by step. if none of the three classification is applicable more precisely due to lack of any additional information, tell you need additional information and what is the that additional information. {output_str_msg} ?'
287
+
288
+ temperature = float(temperature)
289
+ if temperature < 1e-2:
290
+ temperature = 1e-2
291
+ top_p = float(top_p)
292
+
293
+ generate_kwargs = dict(
294
+ temperature=temperature,
295
+ max_new_tokens=max_new_tokens,
296
+ top_p=top_p,
297
+ repetition_penalty=repetition_penalty,
298
+ do_sample=True,
299
+ seed=42,
300
+ )
301
+
302
+ formatted_prompt = format_prompt(f", {prompt}", history)
303
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
304
+ output = ""
305
+ for response in stream:
306
+ output += response.token.text
307
+
308
+ chatbot.append((txt, output))
309
+ return "", chatbot
310
+ else:
311
+ # warning_msg = f"Unexpected response"
312
+ # raise gr.Error(warning_msg)
313
+ chatbot.append(('Not a adequate description', 'Not a adequate description'))
314
+ return "", chatbot
315
+
316
+ def product_explaination(txt,history,chatbot=[], temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,):
317
+ print('Input Descrption is:',txt)
318
+ chatbot=[]
319
+ prompt=f'What is the product- {txt}?'
320
+ #print('prompt',prompt)
321
+ temperature = float(temperature)
322
+ if temperature < 1e-2:
323
+ temperature = 1e-2
324
+ top_p = float(top_p)
325
+
326
+ generate_kwargs = dict(
327
+ temperature=temperature,
328
+ max_new_tokens=max_new_tokens,
329
+ top_p=top_p,
330
+ repetition_penalty=repetition_penalty,
331
+ do_sample=True,
332
+ seed=42,
333
+ )
334
+
335
+ formatted_prompt = format_prompt(f", {prompt}", history)
336
+
337
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
338
+ output = ""
339
+
340
+ for response in stream:
341
+ output += response.token.text
342
+
343
+ chatbot.append((txt, output))
344
+ return "", chatbot
models/.github/ISSUE_TEMPLATE/00-official-bug-report-issue.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "[Official Model] Bug Report"
3
+ about: Use this template for reporting a bug for the “official” directory
4
+ labels: type:bug,models:official
5
+
6
+ ---
7
+
8
+ # Prerequisites
9
+
10
+ Please answer the following questions for yourself before submitting an issue.
11
+
12
+ - [ ] I am using the latest TensorFlow Model Garden release and TensorFlow 2.
13
+ - [ ] I am reporting the issue to the correct repository. (Model Garden official or research directory)
14
+ - [ ] I checked to make sure that this issue has not been filed already.
15
+
16
+ ## 1. The entire URL of the file you are using
17
+
18
+ https://github.com/tensorflow/models/tree/master/official/...
19
+
20
+ ## 2. Describe the bug
21
+
22
+ A clear and concise description of what the bug is.
23
+
24
+ ## 3. Steps to reproduce
25
+
26
+ Steps to reproduce the behavior.
27
+
28
+ ## 4. Expected behavior
29
+
30
+ A clear and concise description of what you expected to happen.
31
+
32
+ ## 5. Additional context
33
+
34
+ Include any logs that would be helpful to diagnose the problem.
35
+
36
+ ## 6. System information
37
+
38
+ - OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
39
+ - Mobile device name if the issue happens on a mobile device:
40
+ - TensorFlow installed from (source or binary):
41
+ - TensorFlow version (use command below):
42
+ - Python version:
43
+ - Bazel version (if compiling from source):
44
+ - GCC/Compiler version (if compiling from source):
45
+ - CUDA/cuDNN version:
46
+ - GPU model and memory:
47
+
48
+ <!--
49
+ Collect system information using our environment capture script.
50
+ https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
51
+
52
+ You can also obtain the TensorFlow version with:
53
+
54
+ 1. TensorFlow 1.0
55
+ `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
56
+
57
+ 2. TensorFlow 2.0
58
+ `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
59
+ -->
models/.github/ISSUE_TEMPLATE/10-official-documentation-issue.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "[Official Model] Documentation Issue"
3
+ about: Use this template for reporting a documentation issue for the “official” directory
4
+ labels: type:docs,models:official
5
+
6
+ ---
7
+
8
+ # Prerequisites
9
+
10
+ Please answer the following question for yourself before submitting an issue.
11
+
12
+ - [ ] I checked to make sure that this issue has not been filed already.
13
+
14
+ ## 1. The entire URL of the documentation with the issue
15
+
16
+ https://github.com/tensorflow/models/tree/master/official/...
17
+
18
+ ## 2. Describe the issue
19
+
20
+ A clear and concise description of what needs to be changed.
models/.github/ISSUE_TEMPLATE/20-official-feature-request-issue.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "[Official Model] Feature request"
3
+ about: Use this template for raising a feature request for the “official” directory
4
+ labels: type:feature,models:official
5
+
6
+ ---
7
+
8
+ # Prerequisites
9
+
10
+ Please answer the following question for yourself before submitting an issue.
11
+
12
+ - [ ] I checked to make sure that this feature has not been requested already.
13
+
14
+ ## 1. The entire URL of the file you are using
15
+
16
+ https://github.com/tensorflow/models/tree/master/official/...
17
+
18
+ ## 2. Describe the feature you request
19
+
20
+ A clear and concise description of what you want to happen.
21
+
22
+ ## 3. Additional context
23
+
24
+ Add any other context about the feature request here.
25
+
26
+ ## 4. Are you willing to contribute it? (Yes or No)
models/.github/ISSUE_TEMPLATE/30-research-bug-report-issue.md ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "[Research Model] Bug Report"
3
+ about: Use this template for reporting a bug for the “research” directory
4
+ labels: type:bug,models:research
5
+
6
+ ---
7
+ # Prerequisites
8
+
9
+ Please answer the following questions for yourself before submitting an issue.
10
+
11
+ - [ ] I am using the latest TensorFlow Model Garden release and TensorFlow 2.
12
+ - [ ] I am reporting the issue to the correct repository. (Model Garden official or research directory)
13
+ - [ ] I checked to make sure that this issue has not already been filed.
14
+
15
+ ## 1. The entire URL of the file you are using
16
+
17
+ https://github.com/tensorflow/models/tree/master/research/...
18
+
19
+ ## 2. Describe the bug
20
+
21
+ A clear and concise description of what the bug is.
22
+
23
+ ## 3. Steps to reproduce
24
+
25
+ Steps to reproduce the behavior.
26
+
27
+ ## 4. Expected behavior
28
+
29
+ A clear and concise description of what you expected to happen.
30
+
31
+ ## 5. Additional context
32
+
33
+ Include any logs that would be helpful to diagnose the problem.
34
+
35
+ ## 6. System information
36
+
37
+ - OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
38
+ - Mobile device name if the issue happens on a mobile device:
39
+ - TensorFlow installed from (source or binary):
40
+ - TensorFlow version (use command below):
41
+ - Python version:
42
+ - Bazel version (if compiling from source):
43
+ - GCC/Compiler version (if compiling from source):
44
+ - CUDA/cuDNN version:
45
+ - GPU model and memory:
46
+
47
+ <!--
48
+ Collect system information using our environment capture script.
49
+ https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
50
+
51
+ You can also obtain the TensorFlow version with:
52
+
53
+ 1. TensorFlow 1.0
54
+ `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
55
+
56
+ 2. TensorFlow 2.0
57
+ `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
58
+ -->
models/.github/ISSUE_TEMPLATE/40-research-documentation-issue.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "[Research Model] Documentation Issue"
3
+ about: Use this template for reporting a documentation issue for the “research” directory
4
+ labels: type:docs,models:research
5
+
6
+ ---
7
+
8
+ # Prerequisites
9
+
10
+ Please answer the following question for yourself before submitting an issue.
11
+
12
+ - [ ] I checked to make sure that this issue has not been filed already.
13
+
14
+ ## 1. The entire URL of the documentation with the issue
15
+
16
+ https://github.com/tensorflow/models/tree/master/research/...
17
+
18
+ ## 2. Describe the issue
19
+
20
+ A clear and concise description of what needs to be changed.
models/.github/ISSUE_TEMPLATE/50-research-feature-request-issue.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "[Research Model] Feature Request"
3
+ about: Use this template for raising a feature request for the “research” directory
4
+ labels: type:feature,models:research
5
+
6
+ ---
7
+
8
+ # Prerequisites
9
+
10
+ Please answer the following question for yourself before submitting an issue.
11
+
12
+ - [ ] I checked to make sure that this feature has not been requested already.
13
+
14
+ ## 1. The entire URL of the file you are using
15
+
16
+ https://github.com/tensorflow/models/tree/master/research/...
17
+
18
+ ## 2. Describe the feature you request
19
+
20
+ A clear and concise description of what you want to happen.
21
+
22
+ ## 3. Additional context
23
+
24
+ Add any other context about the feature request here.
25
+
26
+ ## 4. Are you willing to contribute it? (Yes or No)
models/.github/ISSUE_TEMPLATE/60-questions-help-issue.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Questions and Help
3
+ about: Use this template for Questions and Help.
4
+ labels: type:support
5
+
6
+ ---
7
+ <!--
8
+ As per our GitHub Policy (https://github.com/tensorflow/models/blob/master/ISSUES.md), we only address code bugs, documentation issues, and feature requests on GitHub.
9
+
10
+ We will automatically close questions and help related issues.
11
+
12
+ Please go to Stack Overflow (http://stackoverflow.com/questions/tagged/tensorflow-model-garden) for questions and help.
13
+
14
+ -->
models/.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1 @@
 
 
1
+ blank_issues_enabled: false
models/.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Description
2
+
3
+ > :memo: Please include a summary of the change.
4
+ >
5
+ > * Please also include relevant motivation and context.
6
+ > * List any dependencies that are required for this change.
7
+
8
+ ## Type of change
9
+
10
+ For a new feature or function, please create an issue first to discuss it
11
+ with us before submitting a pull request.
12
+
13
+ Note: Please delete options that are not relevant.
14
+
15
+ - [ ] Bug fix (non-breaking change which fixes an issue)
16
+ - [ ] Documentation update
17
+ - [ ] TensorFlow 2 migration
18
+ - [ ] New feature (non-breaking change which adds functionality)
19
+ - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
20
+ - [ ] A new research paper code implementation
21
+ - [ ] Other (Specify)
22
+
23
+ ## Tests
24
+
25
+ > :memo: Please describe the tests that you ran to verify your changes.
26
+ >
27
+ > * Provide instructions so we can reproduce.
28
+ > * Please also list any relevant details for your test configuration.
29
+
30
+ **Test Configuration**:
31
+
32
+ ## Checklist
33
+
34
+ - [ ] I have signed the [Contributor License Agreement](https://github.com/tensorflow/models/wiki/Contributor-License-Agreements).
35
+ - [ ] I have read [guidelines for pull request](https://github.com/tensorflow/models/wiki/Submitting-a-pull-request).
36
+ - [ ] My code follows the [coding guidelines](https://github.com/tensorflow/models/wiki/Coding-guidelines).
37
+ - [ ] I have performed a self [code review](https://github.com/tensorflow/models/wiki/Code-review) of my own code.
38
+ - [ ] I have commented my code, particularly in hard-to-understand areas.
39
+ - [ ] I have made corresponding changes to the documentation.
40
+ - [ ] My changes generate no new warnings.
41
+ - [ ] I have added tests that prove my fix is effective or that my feature works.
models/.github/README_TEMPLATE.md ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ > :memo: A README.md template for releasing a paper code implementation to a GitHub repository.
2
+ >
3
+ > * Template version: 1.0.2020.170
4
+ > * Please modify sections depending on needs.
5
+
6
+ # Model name, Paper title, or Project Name
7
+
8
+ > :memo: Add a badge for the ArXiv identifier of your paper (arXiv:YYMM.NNNNN)
9
+
10
+ [![Paper](http://img.shields.io/badge/Paper-arXiv.YYMM.NNNNN-B3181B?logo=arXiv)](https://arxiv.org/abs/...)
11
+
12
+ This repository is the official or unofficial implementation of the following paper.
13
+
14
+ * Paper title: [Paper Title](https://arxiv.org/abs/YYMM.NNNNN)
15
+
16
+ ## Description
17
+
18
+ > :memo: Provide description of the model.
19
+ >
20
+ > * Provide brief information of the algorithms used.
21
+ > * Provide links for demos, blog posts, etc.
22
+
23
+ ## History
24
+
25
+ > :memo: Provide a changelog.
26
+
27
+ ## Authors or Maintainers
28
+
29
+ > :memo: Provide maintainer information.
30
+
31
+ * Full name ([@GitHub username](https://github.com/username))
32
+ * Full name ([@GitHub username](https://github.com/username))
33
+
34
+ ## Table of Contents
35
+
36
+ > :memo: Provide a table of contents to help readers navigate a lengthy README document.
37
+
38
+ ## Requirements
39
+
40
+ [![TensorFlow 2.1](https://img.shields.io/badge/TensorFlow-2.1-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0)
41
+ [![Python 3.6](https://img.shields.io/badge/Python-3.6-3776AB)](https://www.python.org/downloads/release/python-360/)
42
+
43
+ > :memo: Provide details of the software required.
44
+ >
45
+ > * Add a `requirements.txt` file to the root directory for installing the necessary dependencies.
46
+ > * Describe how to install requirements using pip.
47
+ > * Alternatively, create INSTALL.md.
48
+
49
+ To install requirements:
50
+
51
+ ```setup
52
+ pip install -r requirements.txt
53
+ ```
54
+
55
+ ## Results
56
+
57
+ > :memo: Provide a table with results. (e.g., accuracy, latency)
58
+ >
59
+ > * Provide links to the pre-trained models (checkpoint, SavedModel files).
60
+ > * Publish TensorFlow SavedModel files on TensorFlow Hub (tfhub.dev) if possible.
61
+ > * Add links to [TensorBoard.dev](https://tensorboard.dev/) for visualizing metrics.
62
+ >
63
+ > An example table for image classification results
64
+ >
65
+ > ### Image Classification
66
+ >
67
+ > | Model name | Download | Top 1 Accuracy | Top 5 Accuracy |
68
+ > |------------|----------|----------------|----------------|
69
+ > | Model name | [Checkpoint](https://drive.google.com/...), [SavedModel](https://tfhub.dev/...) | xx% | xx% |
70
+
71
+ ## Dataset
72
+
73
+ > :memo: Provide information of the dataset used.
74
+
75
+ ## Training
76
+
77
+ > :memo: Provide training information.
78
+ >
79
+ > * Provide details for preprocessing, hyperparameters, random seeds, and environment.
80
+ > * Provide a command line example for training.
81
+
82
+ Please run this command line for training.
83
+
84
+ ```shell
85
+ python3 ...
86
+ ```
87
+
88
+ ## Evaluation
89
+
90
+ > :memo: Provide an evaluation script with details of how to reproduce results.
91
+ >
92
+ > * Describe data preprocessing / postprocessing steps.
93
+ > * Provide a command line example for evaluation.
94
+
95
+ Please run this command line for evaluation.
96
+
97
+ ```shell
98
+ python3 ...
99
+ ```
100
+
101
+ ## References
102
+
103
+ > :memo: Provide links to references.
104
+
105
+ ## License
106
+
107
+ [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
108
+
109
+ > :memo: Place your license text in a file named LICENSE in the root of the repository.
110
+ >
111
+ > * Include information about your license.
112
+ > * Reference: [Adding a license to a repository](https://help.github.com/en/github/building-a-strong-community/adding-a-license-to-a-repository)
113
+
114
+ This project is licensed under the terms of the **Apache License 2.0**.
115
+
116
+ ## Citation
117
+
118
+ > :memo: Make your repository citable.
119
+ >
120
+ > * Reference: [Making Your Code Citable](https://guides.github.com/activities/citable-code/)
121
+
122
+ If you want to cite this repository in your research paper, please use the following information.
models/.gitignore ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ env/
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+
27
+ # PyInstaller
28
+ # Usually these files are written by a python script from a template
29
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
30
+ *.manifest
31
+ *.spec
32
+
33
+ # Installer logs
34
+ pip-log.txt
35
+ pip-delete-this-directory.txt
36
+
37
+ # Unit test / coverage reports
38
+ htmlcov/
39
+ .tox/
40
+ .coverage
41
+ .coverage.*
42
+ .cache
43
+ nosetests.xml
44
+ coverage.xml
45
+ *,cover
46
+ .hypothesis/
47
+
48
+ # Translations
49
+ *.mo
50
+ *.pot
51
+
52
+ # Django stuff:
53
+ *.log
54
+ local_settings.py
55
+
56
+ # Flask stuff:
57
+ instance/
58
+ .webassets-cache
59
+
60
+ # Scrapy stuff:
61
+ .scrapy
62
+
63
+ # Sphinx documentation
64
+ docs/_build/
65
+
66
+ # PyBuilder
67
+ target/
68
+
69
+ # IPython Notebook
70
+ .ipynb_checkpoints
71
+
72
+ # pyenv
73
+ .python-version
74
+
75
+ # mypy
76
+ .mypy_cache
77
+
78
+ # celery beat schedule file
79
+ celerybeat-schedule
80
+
81
+ # dotenv
82
+ .env
83
+
84
+ # virtualenv
85
+ venv/
86
+ ENV/
87
+
88
+ # Spyder project settings
89
+ .spyderproject
90
+
91
+ # Rope project settings
92
+ .ropeproject
93
+
94
+ # PyCharm
95
+ .idea/
96
+
97
+ # For mac
98
+ .DS_Store
models/AUTHORS ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is the official list of authors for copyright purposes.
2
+ # This file is distinct from the CONTRIBUTORS files.
3
+ # See the latter for an explanation.
4
+
5
+ # Names should be added to this file as:
6
+ # Name or Organization <email address>
7
+ # The email address is not required for organizations.
8
+
9
+ Google Inc.
10
+ David Dao <daviddao@broad.mit.edu>
models/CODEOWNERS ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ * @tensorflow/tf-garden-team @tensorflow/tf-model-garden-team
2
+ /official/ @rachellj218 @saberkun @jaeyounkim
3
+ /official/nlp/ @saberkun @chenGitHuber @lehougoogle @rachellj218
4
+ /official/vision/ @pengchongjin @xianzhidu @yeqingli @arashwan @saberkun @rachellj218
5
+ /research/adv_imagenet_models/ @alexeykurakin
6
+ /research/adversarial_crypto/ @dave-andersen
7
+ /research/adversarial_logit_pairing/ @alexeykurakin
8
+ /research/adversarial_text/ @rsepassi @a-dai
9
+ /research/attention_ocr/ @xavigibert
10
+ /research/audioset/ @plakal @dpwe
11
+ /research/autoaugment/* @barretzoph
12
+ /research/autoencoders/ @snurkabill
13
+ /research/brain_coder/ @danabo
14
+ /research/cognitive_mapping_and_planning/ @s-gupta
15
+ /research/compression/ @nmjohn
16
+ /research/cvt_text/ @clarkkev @lmthang
17
+ /research/deep_contextual_bandits/ @rikel
18
+ /research/deep_speech/ @yhliang2018
19
+ /research/deeplab/ @aquariusjay @yknzhu @gpapan
20
+ /research/delf/ @andrefaraujo
21
+ /research/domain_adaptation/ @bousmalis @dmrd
22
+ /research/efficient-hrl/ @ofirnachum
23
+ /research/feelvos/ @pvoigtlaender @yuningchai @aquariusjay
24
+ /research/fivo/ @dieterichlawson
25
+ /research/global_objectives/ @mackeya-google
26
+ /research/im2txt/ @cshallue
27
+ /research/inception/ @shlens @vincentvanhoucke
28
+ /research/keypointnet/ @mnorouzi
29
+ /research/learned_optimizer/ @olganw @nirum
30
+ /research/learning_to_remember_rare_events/ @lukaszkaiser @ofirnachum
31
+ /research/learning_unsupervised_learning/ @lukemetz @nirum
32
+ /research/lexnet_nc/ @vered1986 @waterson
33
+ /research/lfads/ @jazcollins @sussillo
34
+ /research/lm_1b/ @oriolvinyals @panyx0718
35
+ /research/lm_commonsense/ @thtrieu
36
+ /research/lstm_object_detection/ @yinxiaoli @yongzhe2160
37
+ /research/marco/ @vincentvanhoucke
38
+ /research/maskgan/ @liamb315 @a-dai
39
+ /research/namignizer/ @knathanieltucker
40
+ /research/neural_gpu/ @lukaszkaiser
41
+ /research/neural_programmer/ @arvind2505
42
+ /research/next_frame_prediction/ @panyx0718
43
+ /research/object_detection/ @jch1 @tombstone @pkulzc
44
+ /research/pcl_rl/ @ofirnachum
45
+ /research/ptn/ @xcyan @arkanath @hellojas @honglaklee
46
+ /research/qa_kg/ @yuyuz
47
+ /research/real_nvp/ @laurent-dinh
48
+ /research/rebar/ @gjtucker
49
+ /research/sentiment_analysis/ @sculd
50
+ /research/seq2species/ @apbusia @depristo
51
+ /research/skip_thoughts/ @cshallue
52
+ /research/slim/ @sguada @marksandler2
53
+ /research/steve/ @buckman-google
54
+ /research/street/ @theraysmith
55
+ /research/struct2depth/ @aneliaangelova
56
+ /research/swivel/ @waterson
57
+ /research/tcn/ @coreylynch @sermanet
58
+ /research/textsum/ @panyx0718 @peterjliu
59
+ /research/transformer/ @daviddao
60
+ /research/vid2depth/ @rezama
61
+ /research/video_prediction/ @cbfinn
models/CONTRIBUTING.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # How to contribute
2
+
3
+ ![Contributors](https://img.shields.io/github/contributors/tensorflow/models)
4
+
5
+ We encourage you to contribute to the TensorFlow Model Garden.
6
+
7
+ Please read our [guidelines](../../wiki/How-to-contribute) for details.
8
+
9
+ **NOTE**: Only [code owners](./CODEOWNERS) are allowed to merge a pull request.
10
+ Please contact the code owners of each model to merge your pull request.
models/ISSUES.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # If you open a GitHub issue, here is our policy.
2
+
3
+ * It must be a **bug**, a **feature request**, or a significant problem
4
+ with **documentation**.
5
+ * Please send a pull request instead for small documentation fixes.
6
+ * The required form must be filled out.
7
+ * The issue should be related to the repository it is created in.
8
+
9
+ General help and support should be sought on [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow-model-garden) or other non-GitHub channels.
10
+
11
+ [![](https://img.shields.io/stackexchange/stackoverflow/t/tensorflow-model-garden)](https://stackoverflow.com/questions/tagged/tensorflow-model-garden)
12
+
13
+ TensorFlow developers respond to issues.
14
+ We want to focus on work that benefits the whole community such as fixing bugs
15
+ and adding new features.
16
+ It helps us to address bugs and feature requests in a timely manner.
17
+
18
+ ---
19
+
20
+ Please understand that research models in the [research directory](https://github.com/tensorflow/models/tree/master/research)
21
+ included in this repository are experimental and research-style code.
22
+ They are not officially supported by the TensorFlow team.
23
+
24
+
models/LICENSE ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright 2016 The TensorFlow Authors. All rights reserved.
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "[]"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright 2016, The Authors.
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
models/README.md ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ![Logo](https://storage.googleapis.com/model_garden_artifacts/TF_Model_Garden.png)
2
+
3
+ # Welcome to the Model Garden for TensorFlow
4
+
5
+ The TensorFlow Model Garden is a repository with a number of different implementations of state-of-the-art (SOTA) models and modeling solutions for TensorFlow users. We aim to demonstrate the best practices for modeling so that TensorFlow users
6
+ can take full advantage of TensorFlow for their research and product development.
7
+
8
+ | Directory | Description |
9
+ |-----------|-------------|
10
+ | [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read |
11
+ | [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers |
12
+ | [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |
13
+
14
+ ## [Announcements](https://github.com/tensorflow/models/wiki/Announcements)
15
+
16
+ | Date | News |
17
+ |------|------|
18
+ | June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released
19
+ | May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released
20
+ | May 19, 2020 | [MobileDets: Searching for Object Detection Architectures for Mobile Accelerators](https://github.com/tensorflow/models/tree/master/research/object_detection#may-19th-2020) released
21
+ | May 7, 2020 | [MnasFPN with MobileNet-V2 backbone](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#mobile-models) released for object detection
22
+ | May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1
23
+ | March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) |
24
+
25
+ ## [Milestones](https://github.com/tensorflow/models/milestones)
26
+
27
+ | Date | Milestone |
28
+ |------|-----------|
29
+ | July 7, 2020 | [![GitHub milestone](https://img.shields.io/github/milestones/progress/tensorflow/models/1)](https://github.com/tensorflow/models/milestone/1) |
30
+
31
+ ## Contributions
32
+
33
+ [![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)
34
+
35
+ If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
36
+
37
+ ## License
38
+
39
+ [Apache License 2.0](LICENSE)
models/official/LICENSE ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright 2015 The TensorFlow Authors. All rights reserved.
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "[]"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright 2015, The TensorFlow Authors.
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
models/official/README-TPU.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Offically Supported TensorFlow 2.1+ Models on Cloud TPU
2
+
3
+ ## Natural Language Processing
4
+
5
+ * [bert](nlp/bert): A powerful pre-trained language representation model:
6
+ BERT, which stands for Bidirectional Encoder Representations from
7
+ Transformers.
8
+ [BERT FineTuning with Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/bert-2.x) provides step by step instructions on Cloud TPU training. You can look [Bert MNLI Tensorboard.dev metrics](https://tensorboard.dev/experiment/LijZ1IrERxKALQfr76gndA) for MNLI fine tuning task.
9
+ * [transformer](nlp/transformer): A transformer model to translate the WMT
10
+ English to German dataset.
11
+ [Training transformer on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/transformer-2.x) for step by step instructions on Cloud TPU training.
12
+
13
+ ## Computer Vision
14
+
15
+ * [efficientnet](vision/image_classification): A family of convolutional
16
+ neural networks that scale by balancing network depth, width, and
17
+ resolution and can be used to classify ImageNet's dataset of 1000 classes.
18
+ See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/KnaWjrq5TXGfv0NW5m7rpg/#scalars).
19
+ * [mnist](vision/image_classification): A basic model to classify digits
20
+ from the MNIST dataset. See [Running MNIST on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/mnist-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/mIah5lppTASvrHqWrdr6NA).
21
+ * [mask-rcnn](vision/detection): An object detection and instance segmentation model. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/LH7k0fMsRwqUAcE09o9kPA).
22
+ * [resnet](vision/image_classification): A deep residual network that can
23
+ be used to classify ImageNet's dataset of 1000 classes.
24
+ See [Training ResNet on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/resnet-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/CxlDK8YMRrSpYEGtBRpOhg).
25
+ * [retinanet](vision/detection): A fast and powerful object detector. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/b8NRnWU3TqG6Rw0UxueU6Q).
models/official/README.md ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ![Logo](https://storage.googleapis.com/model_garden_artifacts/TF_Model_Garden.png)
2
+
3
+ # TensorFlow Official Models
4
+
5
+ The TensorFlow official models are a collection of models
6
+ that use TensorFlow’s high-level APIs.
7
+ They are intended to be well-maintained, tested, and kept up to date
8
+ with the latest TensorFlow API.
9
+
10
+ They should also be reasonably optimized for fast performance while still
11
+ being easy to read.
12
+ These models are used as end-to-end tests, ensuring that the models run
13
+ with the same or improved speed and performance with each new TensorFlow build.
14
+
15
+ ## More models to come!
16
+
17
+ The team is actively developing new models.
18
+ In the near future, we will add:
19
+
20
+ * State-of-the-art language understanding models:
21
+ More members in Transformer family
22
+ * Start-of-the-art image classification models:
23
+ EfficientNet, MnasNet, and variants
24
+ * A set of excellent objection detection models.
25
+
26
+ ## Table of Contents
27
+
28
+ - [Models and Implementations](#models-and-implementations)
29
+ * [Computer Vision](#computer-vision)
30
+ + [Image Classification](#image-classification)
31
+ + [Object Detection and Segmentation](#object-detection-and-segmentation)
32
+ * [Natural Language Processing](#natural-language-processing)
33
+ * [Recommendation](#recommendation)
34
+ - [How to get started with the official models](#how-to-get-started-with-the-official-models)
35
+
36
+ ## Models and Implementations
37
+
38
+ ### Computer Vision
39
+
40
+ #### Image Classification
41
+
42
+ | Model | Reference (Paper) |
43
+ |-------|-------------------|
44
+ | [MNIST](vision/image_classification) | A basic model to classify digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) |
45
+ | [ResNet](vision/image_classification) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
46
+ | [EfficientNet](vision/image_classification) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
47
+
48
+ #### Object Detection and Segmentation
49
+
50
+ | Model | Reference (Paper) |
51
+ |-------|-------------------|
52
+ | [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
53
+ | [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
54
+ | [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
55
+
56
+ ### Natural Language Processing
57
+
58
+ | Model | Reference (Paper) |
59
+ |-------|-------------------|
60
+ | [ALBERT (A Lite BERT)](nlp/albert) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
61
+ | [BERT (Bidirectional Encoder Representations from Transformers)](nlp/bert) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
62
+ | [NHNet (News Headline generation model)](nlp/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
63
+ | [Transformer](nlp/transformer) | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |
64
+ | [XLNet](nlp/xlnet) | [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) |
65
+
66
+ ### Recommendation
67
+
68
+ | Model | Reference (Paper) |
69
+ |-------|-------------------|
70
+ | [NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) |
71
+
72
+ ## How to get started with the official models
73
+
74
+ * The models in the master branch are developed using TensorFlow 2,
75
+ and they target the TensorFlow [nightly binaries](https://github.com/tensorflow/tensorflow#installation)
76
+ built from the
77
+ [master branch of TensorFlow](https://github.com/tensorflow/tensorflow/tree/master).
78
+ * The stable versions targeting releases of TensorFlow are available
79
+ as tagged branches or [downloadable releases](https://github.com/tensorflow/models/releases).
80
+ * Model repository version numbers match the target TensorFlow release,
81
+ such that
82
+ [release v2.2.0](https://github.com/tensorflow/models/releases/tag/v2.2.0)
83
+ are compatible with
84
+ [TensorFlow v2.2.0](https://github.com/tensorflow/tensorflow/releases/tag/v2.2.0).
85
+
86
+ Please follow the below steps before running models in this repository.
87
+
88
+ ### Requirements
89
+
90
+ * The latest TensorFlow Model Garden release and TensorFlow 2
91
+ * If you are on a version of TensorFlow earlier than 2.2, please
92
+ upgrade your TensorFlow to [the latest TensorFlow 2](https://www.tensorflow.org/install/).
93
+
94
+ ```shell
95
+ pip3 install tf-nightly
96
+ ```
97
+
98
+ ### Installation
99
+
100
+ #### Method 1: Install the TensorFlow Model Garden pip package
101
+
102
+ **tf-models-nightly** is the nightly Model Garden package
103
+ created daily automatically. pip will install all models
104
+ and dependencies automatically.
105
+
106
+ ```shell
107
+ pip install tf-models-nightly
108
+ ```
109
+
110
+ Please check out our [example](colab/fine_tuning_bert.ipynb)
111
+ to learn how to use a PIP package.
112
+
113
+ #### Method 2: Clone the source
114
+
115
+ 1. Clone the GitHub repository:
116
+
117
+ ```shell
118
+ git clone https://github.com/tensorflow/models.git
119
+ ```
120
+
121
+ 2. Add the top-level ***/models*** folder to the Python path.
122
+
123
+ ```shell
124
+ export PYTHONPATH=$PYTHONPATH:/path/to/models
125
+ ```
126
+
127
+ If you are using a Colab notebook, please set the Python path with os.environ.
128
+
129
+ ```python
130
+ import os
131
+ os.environ['PYTHONPATH'] += ":/path/to/models"
132
+ ```
133
+
134
+ 3. Install other dependencies
135
+
136
+ ```shell
137
+ pip3 install --user -r official/requirements.txt
138
+ ```
139
+
140
+ ## Contributions
141
+
142
+ If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
models/official/__init__.py ADDED
File without changes
models/official/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (148 Bytes). View file
 
models/official/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (146 Bytes). View file
 
models/official/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (139 Bytes). View file
 
models/official/benchmark/__init__.py ADDED
File without changes
models/official/benchmark/benchmark_wrappers.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Lint as: python3
2
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ """Utils to annotate and trace benchmarks."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ from absl import flags
23
+ from absl import logging
24
+ from absl.testing import flagsaver
25
+
26
+ FLAGS = flags.FLAGS
27
+
28
+ flags.DEFINE_multi_string(
29
+ 'benchmark_method_flags', None,
30
+ 'Optional list of runtime flags of the form key=value. Specify '
31
+ 'multiple times to specify different flags. These will override the FLAGS '
32
+ 'object directly after hardcoded settings in individual benchmark methods '
33
+ 'before they call _run_and_report benchmark. Example if we set '
34
+ '--benchmark_method_flags=train_steps=10 and a benchmark method hardcodes '
35
+ 'FLAGS.train_steps=10000 and later calls _run_and_report_benchmark, '
36
+ 'it\'ll only run for 10 steps. This is useful for '
37
+ 'debugging/profiling workflows.')
38
+
39
+
40
+ def enable_runtime_flags(decorated_func):
41
+ """Sets attributes from --benchmark_method_flags for method execution.
42
+
43
+ @enable_runtime_flags decorator temporarily adds flags passed in via
44
+ --benchmark_method_flags and runs the decorated function in that context.
45
+
46
+ A user can set --benchmark_method_flags=train_steps=5 to run the benchmark
47
+ method in the snippet below with FLAGS.train_steps=5 for debugging (without
48
+ modifying the benchmark code).
49
+
50
+ class ModelBenchmark():
51
+
52
+ @benchmark_wrappers.enable_runtime_flags
53
+ def _run_and_report_benchmark(self):
54
+ # run benchmark ...
55
+ # report benchmark results ...
56
+
57
+ def benchmark_method(self):
58
+ FLAGS.train_steps = 1000
59
+ ...
60
+ self._run_and_report_benchmark()
61
+
62
+ Args:
63
+ decorated_func: The method that runs the benchmark after previous setup
64
+ execution that set some flags.
65
+
66
+ Returns:
67
+ new_func: The same method which executes in a temporary context where flag
68
+ overrides from --benchmark_method_flags are active.
69
+ """
70
+
71
+ def runner(*args, **kwargs):
72
+ """Creates a temporary context to activate --benchmark_method_flags."""
73
+ if FLAGS.benchmark_method_flags:
74
+ saved_flag_values = flagsaver.save_flag_values()
75
+ for key_value in FLAGS.benchmark_method_flags:
76
+ key, value = key_value.split('=', 1)
77
+ try:
78
+ numeric_float = float(value)
79
+ numeric_int = int(numeric_float)
80
+ if abs(numeric_int) == abs(numeric_float):
81
+ flag_value = numeric_int
82
+ else:
83
+ flag_value = numeric_float
84
+ except ValueError:
85
+ flag_value = value
86
+ logging.info('Setting --%s=%s', key, flag_value)
87
+ setattr(FLAGS, key, flag_value)
88
+ else:
89
+ saved_flag_values = None
90
+ try:
91
+ result = decorated_func(*args, **kwargs)
92
+ return result
93
+ finally:
94
+ if saved_flag_values:
95
+ flagsaver.restore_flag_values(saved_flag_values)
96
+
97
+ return runner
models/official/benchmark/bert_benchmark.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Executes BERT benchmarks and accuracy tests."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import functools
22
+ import json
23
+ import math
24
+ import os
25
+ import time
26
+
27
+ # pylint: disable=g-bad-import-order
28
+ from absl import flags
29
+ from absl.testing import flagsaver
30
+ import tensorflow as tf
31
+ # pylint: enable=g-bad-import-order
32
+
33
+ from official.benchmark import bert_benchmark_utils as benchmark_utils
34
+ from official.benchmark import owner_utils
35
+ from official.nlp.bert import configs
36
+ from official.nlp.bert import run_classifier
37
+ from official.utils.misc import distribution_utils
38
+ from official.benchmark import benchmark_wrappers
39
+
40
+ # pylint: disable=line-too-long
41
+ PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
42
+ CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_train.tf_record'
43
+ CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_eval.tf_record'
44
+ CLASSIFIER_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_meta_data'
45
+ MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
46
+ # pylint: enable=line-too-long
47
+
48
+ TMP_DIR = os.getenv('TMPDIR')
49
+ FLAGS = flags.FLAGS
50
+
51
+
52
+ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
53
+ """Base class to hold methods common to test classes in the module."""
54
+
55
+ def __init__(self, output_dir=None, tpu=None):
56
+ super(BertClassifyBenchmarkBase, self).__init__(output_dir, tpu=tpu)
57
+ self.num_epochs = None
58
+ self.num_steps_per_epoch = None
59
+ FLAGS.steps_per_loop = 1
60
+
61
+ @flagsaver.flagsaver
62
+ def _run_bert_classifier(self, callbacks=None, use_ds=True):
63
+ """Starts BERT classification task."""
64
+ with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
65
+ input_meta_data = json.loads(reader.read().decode('utf-8'))
66
+
67
+ bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
68
+ epochs = self.num_epochs if self.num_epochs else FLAGS.num_train_epochs
69
+ if self.num_steps_per_epoch:
70
+ steps_per_epoch = self.num_steps_per_epoch
71
+ else:
72
+ train_data_size = input_meta_data['train_data_size']
73
+ steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
74
+ warmup_steps = int(epochs * steps_per_epoch * 0.1)
75
+ eval_steps = int(
76
+ math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
77
+ if self.tpu:
78
+ strategy = distribution_utils.get_distribution_strategy(
79
+ distribution_strategy='tpu', tpu_address=self.tpu)
80
+ else:
81
+ strategy = distribution_utils.get_distribution_strategy(
82
+ distribution_strategy='mirrored' if use_ds else 'off',
83
+ num_gpus=self.num_gpus)
84
+
85
+ max_seq_length = input_meta_data['max_seq_length']
86
+ train_input_fn = run_classifier.get_dataset_fn(
87
+ FLAGS.train_data_path,
88
+ max_seq_length,
89
+ FLAGS.train_batch_size,
90
+ is_training=True)
91
+ eval_input_fn = run_classifier.get_dataset_fn(
92
+ FLAGS.eval_data_path,
93
+ max_seq_length,
94
+ FLAGS.eval_batch_size,
95
+ is_training=False)
96
+ _, summary = run_classifier.run_bert_classifier(
97
+ strategy,
98
+ bert_config,
99
+ input_meta_data,
100
+ FLAGS.model_dir,
101
+ epochs,
102
+ steps_per_epoch,
103
+ FLAGS.steps_per_loop,
104
+ eval_steps,
105
+ warmup_steps,
106
+ FLAGS.learning_rate,
107
+ FLAGS.init_checkpoint,
108
+ train_input_fn,
109
+ eval_input_fn,
110
+ training_callbacks=False,
111
+ custom_callbacks=callbacks)
112
+ return summary
113
+
114
+
115
+ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
116
+ """Short benchmark performance tests for BERT model.
117
+
118
+ Tests BERT classification performance in different GPU, TPU configurations.
119
+ The naming convention of below test cases follow
120
+ `benchmark_(number of gpus)_gpu_(dataset type)` for GPUs and
121
+ `benchmark_(topology)_tpu_(dataset type)` for TPUs.
122
+ """
123
+
124
+ def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
125
+ super(BertClassifyBenchmarkReal, self).__init__(
126
+ output_dir=output_dir, tpu=tpu)
127
+
128
+ self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
129
+ self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
130
+ self.bert_config_file = MODEL_CONFIG_FILE_PATH
131
+ self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
132
+
133
+ # Since we only care about performance metrics, we limit
134
+ # the number of training steps and epochs to prevent unnecessarily
135
+ # long tests.
136
+ self.num_steps_per_epoch = 100
137
+ self.num_epochs = 1
138
+
139
+ @benchmark_wrappers.enable_runtime_flags
140
+ def _run_and_report_benchmark(self,
141
+ training_summary_path,
142
+ min_accuracy=0,
143
+ max_accuracy=1,
144
+ use_ds=True):
145
+ """Starts BERT performance benchmark test."""
146
+ start_time_sec = time.time()
147
+ summary = self._run_bert_classifier(
148
+ callbacks=[self.timer_callback], use_ds=use_ds)
149
+ wall_time_sec = time.time() - start_time_sec
150
+
151
+ # Since we do not load from any pretrained checkpoints, we ignore all
152
+ # accuracy metrics.
153
+ summary.pop('eval_metrics', None)
154
+ summary['start_time_sec'] = start_time_sec
155
+
156
+ super(BertClassifyBenchmarkReal, self)._report_benchmark(
157
+ stats=summary,
158
+ wall_time_sec=wall_time_sec,
159
+ min_accuracy=min_accuracy,
160
+ max_accuracy=max_accuracy)
161
+
162
+ def benchmark_1_gpu_mrpc(self):
163
+ """Test BERT model performance with 1 GPU."""
164
+
165
+ self._setup()
166
+ self.num_gpus = 1
167
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc')
168
+ FLAGS.train_data_path = self.train_data_path
169
+ FLAGS.eval_data_path = self.eval_data_path
170
+ FLAGS.input_meta_data_path = self.input_meta_data_path
171
+ FLAGS.bert_config_file = self.bert_config_file
172
+ FLAGS.train_batch_size = 4
173
+ FLAGS.eval_batch_size = 4
174
+
175
+ summary_path = os.path.join(FLAGS.model_dir,
176
+ 'summaries/training_summary.txt')
177
+ self._run_and_report_benchmark(summary_path)
178
+
179
+ def benchmark_1_gpu_mrpc_xla(self):
180
+ """Test BERT model performance with 1 GPU."""
181
+
182
+ self._setup()
183
+ self.num_gpus = 1
184
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc_xla')
185
+ FLAGS.train_data_path = self.train_data_path
186
+ FLAGS.eval_data_path = self.eval_data_path
187
+ FLAGS.input_meta_data_path = self.input_meta_data_path
188
+ FLAGS.bert_config_file = self.bert_config_file
189
+ FLAGS.train_batch_size = 4
190
+ FLAGS.eval_batch_size = 4
191
+ FLAGS.enable_xla = True
192
+
193
+ summary_path = os.path.join(FLAGS.model_dir,
194
+ 'summaries/training_summary.txt')
195
+ self._run_and_report_benchmark(summary_path)
196
+
197
+ def benchmark_1_gpu_mrpc_no_dist_strat(self):
198
+ """Test BERT model performance with 1 GPU, no distribution strategy."""
199
+
200
+ self._setup()
201
+ self.num_gpus = 1
202
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc_no_dist_strat')
203
+ FLAGS.train_data_path = self.train_data_path
204
+ FLAGS.eval_data_path = self.eval_data_path
205
+ FLAGS.input_meta_data_path = self.input_meta_data_path
206
+ FLAGS.bert_config_file = self.bert_config_file
207
+ FLAGS.train_batch_size = 4
208
+ FLAGS.eval_batch_size = 4
209
+
210
+ summary_path = os.path.join(FLAGS.model_dir,
211
+ 'summaries/training_summary.txt')
212
+ self._run_and_report_benchmark(summary_path, use_ds=False)
213
+
214
+ @owner_utils.Owner('tf-model-garden')
215
+ def benchmark_8_gpu_mrpc(self):
216
+ """Test BERT model performance with 8 GPUs."""
217
+
218
+ self._setup()
219
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
220
+ FLAGS.train_data_path = self.train_data_path
221
+ FLAGS.eval_data_path = self.eval_data_path
222
+ FLAGS.input_meta_data_path = self.input_meta_data_path
223
+ FLAGS.bert_config_file = self.bert_config_file
224
+
225
+ summary_path = os.path.join(FLAGS.model_dir,
226
+ 'summaries/training_summary.txt')
227
+ self._run_and_report_benchmark(summary_path)
228
+
229
+ def benchmark_1_gpu_amp_mrpc_no_dist_strat(self):
230
+ """Performance for 1 GPU no DS with automatic mixed precision."""
231
+ self._setup()
232
+ self.num_gpus = 1
233
+ FLAGS.model_dir = self._get_model_dir(
234
+ 'benchmark_1_gpu_amp_mrpc_no_dist_strat')
235
+ FLAGS.train_data_path = self.train_data_path
236
+ FLAGS.eval_data_path = self.eval_data_path
237
+ FLAGS.input_meta_data_path = self.input_meta_data_path
238
+ FLAGS.bert_config_file = self.bert_config_file
239
+ FLAGS.train_batch_size = 4
240
+ FLAGS.eval_batch_size = 4
241
+ FLAGS.dtype = 'fp16'
242
+ FLAGS.fp16_implementation = 'graph_rewrite'
243
+
244
+ summary_path = os.path.join(FLAGS.model_dir,
245
+ 'summaries/training_summary.txt')
246
+ self._run_and_report_benchmark(summary_path, use_ds=False)
247
+
248
+ def benchmark_8_gpu_amp_mrpc(self):
249
+ """Test BERT model performance with 8 GPUs with automatic mixed precision."""
250
+
251
+ self._setup()
252
+ self.num_gpus = 8
253
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_mrpc')
254
+ FLAGS.train_data_path = self.train_data_path
255
+ FLAGS.eval_data_path = self.eval_data_path
256
+ FLAGS.input_meta_data_path = self.input_meta_data_path
257
+ FLAGS.bert_config_file = self.bert_config_file
258
+ FLAGS.train_batch_size = 32
259
+ FLAGS.eval_batch_size = 32
260
+ FLAGS.dtype = 'fp16'
261
+ FLAGS.fp16_implementation = 'graph_rewrite'
262
+
263
+ summary_path = os.path.join(FLAGS.model_dir,
264
+ 'summaries/training_summary.txt')
265
+ self._run_and_report_benchmark(summary_path, use_ds=False)
266
+
267
+ @owner_utils.Owner('tf-model-garden')
268
+ def benchmark_2x2_tpu_mrpc(self):
269
+ """Test BERT model performance with 2x2 TPU."""
270
+
271
+ self._setup()
272
+ FLAGS.steps_per_loop = 50
273
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mrpc')
274
+ FLAGS.train_data_path = self.train_data_path
275
+ FLAGS.eval_data_path = self.eval_data_path
276
+ FLAGS.input_meta_data_path = self.input_meta_data_path
277
+ FLAGS.bert_config_file = self.bert_config_file
278
+ FLAGS.train_batch_size = 32
279
+ FLAGS.eval_batch_size = 32
280
+
281
+ summary_path = os.path.join(FLAGS.model_dir,
282
+ 'summaries/training_summary.txt')
283
+ self._run_and_report_benchmark(summary_path, use_ds=False)
284
+
285
+
286
+ class BertClassifyAccuracy(BertClassifyBenchmarkBase):
287
+ """Short accuracy test for BERT model.
288
+
289
+ Tests BERT classification task model accuracy. The naming
290
+ convention of below test cases follow
291
+ `benchmark_(number of gpus)_gpu_(dataset type)` format.
292
+ """
293
+
294
+ def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
295
+ self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
296
+ self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
297
+ self.bert_config_file = MODEL_CONFIG_FILE_PATH
298
+ self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
299
+ self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
300
+
301
+ super(BertClassifyAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)
302
+
303
+ @benchmark_wrappers.enable_runtime_flags
304
+ def _run_and_report_benchmark(self,
305
+ training_summary_path,
306
+ min_accuracy=0.84,
307
+ max_accuracy=0.88):
308
+ """Starts BERT accuracy benchmark test."""
309
+
310
+ start_time_sec = time.time()
311
+ summary = self._run_bert_classifier(callbacks=[self.timer_callback])
312
+ wall_time_sec = time.time() - start_time_sec
313
+
314
+ super(BertClassifyAccuracy, self)._report_benchmark(
315
+ stats=summary,
316
+ wall_time_sec=wall_time_sec,
317
+ min_accuracy=min_accuracy,
318
+ max_accuracy=max_accuracy)
319
+
320
+ def _setup(self):
321
+ super(BertClassifyAccuracy, self)._setup()
322
+ FLAGS.train_data_path = self.train_data_path
323
+ FLAGS.eval_data_path = self.eval_data_path
324
+ FLAGS.input_meta_data_path = self.input_meta_data_path
325
+ FLAGS.bert_config_file = self.bert_config_file
326
+ FLAGS.init_checkpoint = self.pretrained_checkpoint_path
327
+
328
+ @owner_utils.Owner('tf-model-garden')
329
+ def benchmark_8_gpu_mrpc(self):
330
+ """Run BERT model accuracy test with 8 GPUs.
331
+
332
+ Due to comparatively small cardinality of MRPC dataset, training
333
+ accuracy metric has high variance between trainings. As so, we
334
+ set the wide range of allowed accuracy (84% to 88%).
335
+ """
336
+ self._setup()
337
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
338
+
339
+ summary_path = os.path.join(FLAGS.model_dir,
340
+ 'summaries/training_summary.txt')
341
+ self._run_and_report_benchmark(summary_path)
342
+
343
+ def benchmark_8_gpu_mrpc_xla(self):
344
+ """Run BERT model accuracy test with 8 GPUs with XLA."""
345
+ self._setup()
346
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc_xla')
347
+ FLAGS.enable_xla = True
348
+ summary_path = os.path.join(FLAGS.model_dir,
349
+ 'summaries/training_summary.txt')
350
+ self._run_and_report_benchmark(summary_path)
351
+
352
+ @owner_utils.Owner('tf-model-garden')
353
+ def benchmark_2x2_tpu_mrpc(self):
354
+ """Run BERT model accuracy test on 2x2 TPU."""
355
+ self._setup()
356
+ FLAGS.steps_per_loop = 50
357
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mrpc')
358
+
359
+ summary_path = os.path.join(FLAGS.model_dir,
360
+ 'summaries/training_summary.txt')
361
+ self._run_and_report_benchmark(summary_path)
362
+
363
+
364
+ if __name__ == '__main__':
365
+ tf.test.main()
models/official/benchmark/bert_benchmark_utils.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utility functions or classes shared between BERT benchmarks."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import time
22
+
23
+ # pylint: disable=g-bad-import-order
24
+ import numpy as np
25
+ from absl import flags
26
+ import tensorflow as tf
27
+ # pylint: enable=g-bad-import-order
28
+
29
+ from official.utils.flags import core as flags_core
30
+ from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
31
+
32
+ FLAGS = flags.FLAGS
33
+
34
+
35
+ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
36
+ """Callback that records time it takes to run each batch."""
37
+
38
+ def __init__(self, num_batches_to_skip=10):
39
+ super(BenchmarkTimerCallback, self).__init__()
40
+ self.batch_start_times = {}
41
+ self.batch_stop_times = {}
42
+
43
+ def on_batch_begin(self, batch, logs=None):
44
+ self.batch_start_times[batch] = time.time()
45
+
46
+ def on_batch_end(self, batch, logs=None):
47
+ # If there are multiple steps_per_loop, the end batch index will not be the
48
+ # same as the starting index. Use the last starting index instead.
49
+ if batch not in self.batch_start_times:
50
+ batch = max(self.batch_start_times.keys())
51
+
52
+ self.batch_stop_times[batch] = time.time()
53
+
54
+ def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
55
+ batch_durations = []
56
+ for batch in self.batch_start_times:
57
+ if batch in self.batch_stop_times and batch >= num_batches_to_skip:
58
+ batch_durations.append(self.batch_stop_times[batch] -
59
+ self.batch_start_times[batch])
60
+ return batch_size / np.mean(batch_durations)
61
+
62
+ def get_startup_time(self, program_start_time):
63
+ return self.batch_start_times[0] - program_start_time
64
+
65
+
66
+ class BertBenchmarkBase(PerfZeroBenchmark):
67
+ """Base class to hold methods common to test classes."""
68
+ local_flags = None
69
+
70
+ def __init__(self, output_dir=None, tpu=None, **kwargs):
71
+ super(BertBenchmarkBase, self).__init__(
72
+ output_dir=output_dir, tpu=tpu, **kwargs)
73
+ self.num_gpus = 8
74
+ self.timer_callback = None
75
+
76
+ def _setup(self):
77
+ """Sets up and resets flags before each test."""
78
+ super(BertBenchmarkBase, self)._setup()
79
+ self.timer_callback = BenchmarkTimerCallback()
80
+
81
+ def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
82
+ """Report benchmark results by writing to local protobuf file.
83
+
84
+ Args:
85
+ stats: dict returned from BERT models with known entries.
86
+ wall_time_sec: the during of the benchmark execution in seconds
87
+ min_accuracy: Minimum classification accuracy constraint to verify
88
+ correctness of the model.
89
+ max_accuracy: Maximum classification accuracy constraint to verify
90
+ correctness of the model.
91
+ """
92
+ metrics = [{
93
+ 'name': 'training_loss',
94
+ 'value': stats['train_loss'],
95
+ }]
96
+ if self.timer_callback:
97
+ metrics.append({
98
+ 'name':
99
+ 'exp_per_second',
100
+ 'value':
101
+ self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
102
+ FLAGS.steps_per_loop)
103
+ })
104
+ else:
105
+ metrics.append({
106
+ 'name': 'exp_per_second',
107
+ 'value': 0.0,
108
+ })
109
+ if self.timer_callback and 'start_time_sec' in stats:
110
+ metrics.append({
111
+ 'name': 'startup_time',
112
+ 'value': self.timer_callback.get_startup_time(stats['start_time_sec'])
113
+ })
114
+
115
+ if 'eval_metrics' in stats:
116
+ metrics.append({
117
+ 'name': 'eval_accuracy',
118
+ 'value': stats['eval_metrics'],
119
+ 'min_value': min_accuracy,
120
+ 'max_value': max_accuracy,
121
+ })
122
+ flags_str = flags_core.get_nondefault_flags_as_str()
123
+ self.report_benchmark(
124
+ iters=stats['total_training_steps'],
125
+ wall_time=wall_time_sec,
126
+ metrics=metrics,
127
+ extras={'flags': flags_str})
models/official/benchmark/bert_pretrain_benchmark.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Lint as: python3
2
+ # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ """Executes benchmark testing for bert pretraining."""
17
+ # pylint: disable=line-too-long
18
+ from __future__ import print_function
19
+
20
+ import json
21
+ import os
22
+ import time
23
+ from typing import Optional
24
+
25
+ from absl import flags
26
+ from absl import logging
27
+ import tensorflow as tf # pylint: disable=g-bad-import-order
28
+
29
+ from official.benchmark import benchmark_wrappers
30
+ from official.benchmark import bert_benchmark_utils
31
+ from official.benchmark import owner_utils
32
+ from official.nlp.bert import run_pretraining
33
+ from official.utils.flags import core as flags_core
34
+ from official.utils.misc import distribution_utils
35
+
36
+ # Pretrain masked lanauge modeling accuracy range:
37
+ MIN_MLM_ACCURACY = 0.635
38
+ MAX_MLM_ACCURACY = 0.645
39
+
40
+ # Pretrain next sentence prediction accuracy range:
41
+ MIN_NSP_ACCURACY = 0.94
42
+ MAX_NSP_ACCURACY = 0.96
43
+
44
+ BERT_PRETRAIN_FILES_SEQ128 = 'gs://mlcompass-data/bert/pretraining_data/seq_128/wikipedia.tfrecord*,gs://mlcompass-data/bert/pretraining_data/seq_128/books.tfrecord*'
45
+ BERT_BASE_CONFIG_FILE = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_config.json'
46
+
47
+ FLAGS = flags.FLAGS
48
+
49
+
50
+ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
51
+ """Benchmark accuracy tests for BERT Pretraining."""
52
+
53
+ def __init__(self,
54
+ output_dir: Optional[str] = None,
55
+ tpu: Optional[str] = None,
56
+ **kwargs):
57
+ """Inits BertPretrainAccuracyBenchmark class.
58
+
59
+ Args:
60
+ output_dir: Directory where to output e.g. log files
61
+ tpu: TPU name to use in a TPU benchmark.
62
+ **kwargs: Additional keyword arguments.
63
+ """
64
+ super(BertPretrainAccuracyBenchmark, self).__init__(
65
+ output_dir=output_dir, tpu=tpu, **kwargs)
66
+
67
+ @benchmark_wrappers.enable_runtime_flags
68
+ def _run_and_report_benchmark(self, summary_path: str, report_accuracy: bool):
69
+ """Runs and reports the benchmark given the provided configuration."""
70
+ distribution = distribution_utils.get_distribution_strategy(
71
+ distribution_strategy='tpu', tpu_address=self.tpu)
72
+ logging.info('Flags: %s', flags_core.get_nondefault_flags_as_str())
73
+ start_time_sec = time.time()
74
+ run_pretraining.run_bert_pretrain(
75
+ strategy=distribution, custom_callbacks=self.timer_callback)
76
+ wall_time_sec = time.time() - start_time_sec
77
+
78
+ with tf.io.gfile.GFile(summary_path, 'rb') as reader:
79
+ summary = json.loads(reader.read().decode('utf-8'))
80
+ self._report_benchmark(summary, start_time_sec, wall_time_sec,
81
+ report_accuracy)
82
+
83
+ def _report_benchmark(self, summary, start_time_sec, wall_time_sec,
84
+ report_accuracy):
85
+ metrics = [{
86
+ 'name': 'train_loss',
87
+ 'value': summary['train_loss'],
88
+ }, {
89
+ 'name':
90
+ 'exp_per_second',
91
+ 'value':
92
+ self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
93
+ FLAGS.steps_per_loop)
94
+ }, {
95
+ 'name': 'startup_time',
96
+ 'value': self.timer_callback.get_startup_time(start_time_sec)
97
+ }]
98
+ if report_accuracy:
99
+ metrics.extend([{
100
+ 'name': 'masked_lm_accuracy',
101
+ 'value': summary['masked_lm_accuracy'],
102
+ 'min_value': MIN_MLM_ACCURACY,
103
+ 'max_value': MAX_MLM_ACCURACY,
104
+ }, {
105
+ 'name': 'next_sentence_accuracy',
106
+ 'value': summary['next_sentence_accuracy'],
107
+ 'min_value': MIN_NSP_ACCURACY,
108
+ 'max_value': MAX_NSP_ACCURACY,
109
+ }])
110
+ self.report_benchmark(
111
+ iters=summary['total_training_steps'],
112
+ wall_time=wall_time_sec,
113
+ metrics=metrics,
114
+ extras={'flags': flags_core.get_nondefault_flags_as_str()})
115
+
116
+ def _specify_common_flags(self):
117
+ FLAGS.bert_config_file = BERT_BASE_CONFIG_FILE
118
+ FLAGS.train_batch_size = 512
119
+ FLAGS.learning_rate = 1e-4
120
+ FLAGS.warmup_steps = 10000
121
+ FLAGS.steps_per_loop = 10000
122
+ FLAGS.distribution_strategy = 'tpu'
123
+ FLAGS.input_files = BERT_PRETRAIN_FILES_SEQ128
124
+ FLAGS.max_seq_length = 128
125
+ FLAGS.max_predictions_per_seq = 20
126
+ FLAGS.dtype = 'bf16'
127
+
128
+ @owner_utils.Owner('tf-model-garden')
129
+ def benchmark_accuracy_8x8_tpu_bf16_seq128_500k_steps(self):
130
+ """Test bert pretraining with 8x8 TPU for 500k steps."""
131
+ # This is used for accuracy test.
132
+ self._setup()
133
+ self._specify_common_flags()
134
+ FLAGS.num_steps_per_epoch = 500000
135
+ FLAGS.num_train_epochs = 1
136
+ FLAGS.model_dir = self._get_model_dir(
137
+ 'benchmark_accuracy_8x8_tpu_bf16_seq128_500k_steps')
138
+ summary_path = os.path.join(FLAGS.model_dir,
139
+ 'summaries/training_summary.txt')
140
+ # Set train_summary_interval to -1 to disable training summary, because
141
+ # writing summary to gcs may fail and summaries are not needed for this
142
+ # accuracy benchmark test.
143
+ FLAGS.train_summary_interval = -1
144
+ self._run_and_report_benchmark(summary_path=summary_path,
145
+ report_accuracy=True)
146
+
147
+ @owner_utils.Owner('tf-model-garden')
148
+ def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps(self):
149
+ """Test bert pretraining with 4x4 TPU for 10000 steps."""
150
+ self._setup()
151
+ self._specify_common_flags()
152
+ FLAGS.num_steps_per_epoch = 5000
153
+ FLAGS.num_train_epochs = 2
154
+ FLAGS.model_dir = self._get_model_dir(
155
+ 'benchmark_perf_4x4_tpu_bf16_seq128_10k_steps')
156
+ summary_path = os.path.join(FLAGS.model_dir,
157
+ 'summaries/training_summary.txt')
158
+ # Disable accuracy check.
159
+ self._run_and_report_benchmark(
160
+ summary_path=summary_path, report_accuracy=False)
161
+
162
+ @owner_utils.Owner('tf-model-garden')
163
+ def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self):
164
+ """Test bert pretraining with 8x8 TPU for 10000 steps."""
165
+ self._setup()
166
+ self._specify_common_flags()
167
+ FLAGS.num_steps_per_epoch = 5000
168
+ FLAGS.num_train_epochs = 2
169
+ FLAGS.model_dir = self._get_model_dir(
170
+ 'benchmark_perf_8x8_tpu_bf16_seq128_10k_steps')
171
+ summary_path = os.path.join(FLAGS.model_dir,
172
+ 'summaries/training_summary.txt')
173
+ # Disable accuracy check.
174
+ self._run_and_report_benchmark(summary_path=summary_path,
175
+ report_accuracy=False)
176
+
177
+
178
+ if __name__ == '__main__':
179
+ tf.test.main()
models/official/benchmark/bert_squad_benchmark.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Executes BERT SQuAD benchmarks and accuracy tests."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import json
22
+ import os
23
+ import time
24
+
25
+ # pylint: disable=g-bad-import-order
26
+ from absl import flags
27
+ from absl import logging
28
+ from absl.testing import flagsaver
29
+ import tensorflow as tf
30
+ # pylint: enable=g-bad-import-order
31
+
32
+ from official.benchmark import bert_benchmark_utils as benchmark_utils
33
+ from official.benchmark import owner_utils
34
+ from official.nlp.bert import run_squad
35
+ from official.utils.misc import distribution_utils
36
+ from official.utils.misc import keras_utils
37
+ from official.benchmark import benchmark_wrappers
38
+
39
+
40
+ # pylint: disable=line-too-long
41
+ PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
42
+ SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record'
43
+ SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json'
44
+ SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt'
45
+ SQUAD_MEDIUM_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_medium_meta_data'
46
+ SQUAD_LONG_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_long_meta_data'
47
+ SQUAD_FULL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_full_meta_data'
48
+ MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
49
+ # pylint: enable=line-too-long
50
+
51
+ TMP_DIR = os.getenv('TMPDIR')
52
+ FLAGS = flags.FLAGS
53
+
54
+
55
+ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
56
+ """Base class to hold methods common to test classes in the module."""
57
+
58
+ def __init__(self, output_dir=None, tpu=None):
59
+ super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir, tpu=tpu)
60
+
61
+ def _read_training_summary_from_file(self):
62
+ """Reads the training summary from a file."""
63
+ summary_path = os.path.join(FLAGS.model_dir,
64
+ 'summaries/training_summary.txt')
65
+ with tf.io.gfile.GFile(summary_path, 'rb') as reader:
66
+ return json.loads(reader.read().decode('utf-8'))
67
+
68
+ def _read_input_meta_data_from_file(self):
69
+ """Reads the input metadata from a file."""
70
+ with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
71
+ return json.loads(reader.read().decode('utf-8'))
72
+
73
+ def _get_distribution_strategy(self, ds_type='mirrored'):
74
+ """Gets the distribution strategy.
75
+
76
+ Args:
77
+ ds_type: String, the distribution strategy type to be used. Can be
78
+ 'mirrored', 'multi_worker_mirrored', 'tpu' and 'off'.
79
+
80
+ Returns:
81
+ A `tf.distribute.DistibutionStrategy` object.
82
+ """
83
+ if self.tpu or ds_type == 'tpu':
84
+ return distribution_utils.get_distribution_strategy(
85
+ distribution_strategy='tpu', tpu_address=self.tpu)
86
+ elif ds_type == 'multi_worker_mirrored':
87
+ # Configures cluster spec for multi-worker distribution strategy.
88
+ _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
89
+ FLAGS.task_index)
90
+ return distribution_utils.get_distribution_strategy(
91
+ distribution_strategy=ds_type,
92
+ num_gpus=self.num_gpus,
93
+ all_reduce_alg=FLAGS.all_reduce_alg)
94
+
95
+ def _init_gpu_and_data_threads(self):
96
+ """Set env variables before any TF calls."""
97
+ if FLAGS.tf_gpu_thread_mode:
98
+ keras_utils.set_gpu_thread_mode_and_count(
99
+ per_gpu_thread_count=FLAGS.per_gpu_thread_count,
100
+ gpu_thread_mode=FLAGS.tf_gpu_thread_mode,
101
+ num_gpus=self.num_gpus,
102
+ datasets_num_private_threads=FLAGS.datasets_num_private_threads)
103
+
104
+ @flagsaver.flagsaver
105
+ def _train_squad(self, run_eagerly=False, ds_type='mirrored'):
106
+ """Runs BERT SQuAD training. Uses mirrored strategy by default."""
107
+ self._init_gpu_and_data_threads()
108
+ input_meta_data = self._read_input_meta_data_from_file()
109
+ strategy = self._get_distribution_strategy(ds_type)
110
+
111
+ run_squad.train_squad(
112
+ strategy=strategy,
113
+ input_meta_data=input_meta_data,
114
+ run_eagerly=run_eagerly,
115
+ custom_callbacks=[self.timer_callback])
116
+
117
+ @flagsaver.flagsaver
118
+ def _evaluate_squad(self, ds_type='mirrored'):
119
+ """Runs BERT SQuAD evaluation. Uses mirrored strategy by default."""
120
+ self._init_gpu_and_data_threads()
121
+ input_meta_data = self._read_input_meta_data_from_file()
122
+ strategy = self._get_distribution_strategy(ds_type)
123
+
124
+ if input_meta_data.get('version_2_with_negative', False):
125
+ logging.error('In memory evaluation result for SQuAD v2 is not accurate')
126
+ eval_metrics = run_squad.eval_squad(strategy=strategy,
127
+ input_meta_data=input_meta_data)
128
+ # Use F1 score as reported evaluation metric.
129
+ self.eval_metrics = eval_metrics['final_f1']
130
+
131
+
132
+ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
133
+ """Short benchmark performance tests for BERT SQuAD model.
134
+
135
+ Tests BERT SQuAD performance in different GPU configurations.
136
+ The naming convention of below test cases follow
137
+ `benchmark_(number of gpus)_gpu` format for GPUs and
138
+ `benchmark_(topology)_tpu` format for TPUs.
139
+ """
140
+
141
+ def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
142
+ super(BertSquadBenchmarkReal, self).__init__(output_dir=output_dir, tpu=tpu)
143
+
144
+ def _setup(self):
145
+ """Sets up the benchmark and SQuAD flags."""
146
+ super(BertSquadBenchmarkReal, self)._setup()
147
+ FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
148
+ FLAGS.predict_file = SQUAD_PREDICT_FILE
149
+ FLAGS.vocab_file = SQUAD_VOCAB_FILE
150
+ FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
151
+ FLAGS.num_train_epochs = 1
152
+ FLAGS.steps_per_loop = 100
153
+
154
+ @benchmark_wrappers.enable_runtime_flags
155
+ def _run_and_report_benchmark(self,
156
+ run_eagerly=False,
157
+ ds_type='mirrored'):
158
+ """Runs the benchmark and reports various metrics."""
159
+ if FLAGS.train_batch_size <= 4 or run_eagerly:
160
+ FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
161
+ else:
162
+ FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
163
+ start_time_sec = time.time()
164
+ self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
165
+ wall_time_sec = time.time() - start_time_sec
166
+
167
+ summary = self._read_training_summary_from_file()
168
+ summary['start_time_sec'] = start_time_sec
169
+
170
+ super(BertSquadBenchmarkReal, self)._report_benchmark(
171
+ stats=summary,
172
+ wall_time_sec=wall_time_sec,
173
+ min_accuracy=0,
174
+ max_accuracy=1)
175
+
176
+ def benchmark_1_gpu(self):
177
+ """Tests BERT SQuAD model performance with 1 GPU."""
178
+
179
+ self._setup()
180
+ self.num_gpus = 1
181
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad')
182
+ FLAGS.train_batch_size = 4
183
+
184
+ self._run_and_report_benchmark()
185
+
186
+ def benchmark_1_gpu_eager(self):
187
+ """Tests BERT SQuAD model performance with 1 GPU."""
188
+
189
+ self._setup()
190
+ self.num_gpus = 1
191
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
192
+ FLAGS.train_batch_size = 2
193
+
194
+ self._run_and_report_benchmark(run_eagerly=True)
195
+
196
+ def benchmark_1_gpu_xla(self):
197
+ """Tests BERT SQuAD model performance with 1 GPU with XLA."""
198
+
199
+ self._setup()
200
+ self.num_gpus = 1
201
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad')
202
+ # XLA runs out of memory when running with batch size 4.
203
+ FLAGS.train_batch_size = 3
204
+ FLAGS.enable_xla = True
205
+
206
+ self._run_and_report_benchmark()
207
+
208
+ def benchmark_1_gpu_no_dist_strat(self):
209
+ """Tests BERT SQuAD model performance with 1 GPU without DS."""
210
+
211
+ self._setup()
212
+ self.num_gpus = 1
213
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat_squad')
214
+ FLAGS.train_batch_size = 4
215
+
216
+ self._run_and_report_benchmark(ds_type='off')
217
+
218
+ def benchmark_1_gpu_eager_no_dist_strat(self):
219
+ """Tests BERT SQuAD model performance with 1 GPU with eager execution."""
220
+
221
+ self._setup()
222
+ self.num_gpus = 1
223
+ FLAGS.model_dir = self._get_model_dir(
224
+ 'benchmark_1_gpu_eager_no_dist_strat_squad')
225
+ FLAGS.train_batch_size = 4
226
+
227
+ self._run_and_report_benchmark(ds_type='off', run_eagerly=True)
228
+
229
+ @owner_utils.Owner('tf-model-garden')
230
+ def benchmark_8_gpu(self):
231
+ """Tests BERT SQuAD model performance with 8 GPUs."""
232
+
233
+ self._setup()
234
+ self.num_gpus = 8
235
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
236
+ FLAGS.train_batch_size = 24
237
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
238
+
239
+ self._run_and_report_benchmark()
240
+
241
+ def benchmark_1_gpu_fp16_eager(self):
242
+ """Tests BERT SQuAD model performance with 1 GPU and FP16."""
243
+
244
+ self._setup()
245
+ self.num_gpus = 1
246
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_fp16_eager')
247
+ FLAGS.train_batch_size = 4
248
+ FLAGS.dtype = 'fp16'
249
+ FLAGS.loss_scale = 'dynamic'
250
+
251
+ self._run_and_report_benchmark(run_eagerly=True)
252
+
253
+ def benchmark_1_gpu_fp16(self):
254
+ """Tests BERT SQuAD model performance with 1 GPU and FP16."""
255
+
256
+ self._setup()
257
+ self.num_gpus = 1
258
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_fp16')
259
+ FLAGS.train_batch_size = 4
260
+ FLAGS.dtype = 'fp16'
261
+ FLAGS.loss_scale = 'dynamic'
262
+
263
+ self._run_and_report_benchmark()
264
+
265
+ def benchmark_1_gpu_xla_fp16(self):
266
+ """Tests BERT SQuAD model performance with 1 GPU with XLA and FP16."""
267
+
268
+ self._setup()
269
+ self.num_gpus = 1
270
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad_fp16')
271
+ FLAGS.train_batch_size = 4
272
+ FLAGS.enable_xla = True
273
+ FLAGS.dtype = 'fp16'
274
+ FLAGS.loss_scale = 'dynamic'
275
+
276
+ self._run_and_report_benchmark()
277
+
278
+ def benchmark_8_gpu_fp16(self):
279
+ """Tests BERT SQuAD model performance with 8 GPUs."""
280
+
281
+ self._setup()
282
+ self.num_gpus = 8
283
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
284
+ FLAGS.train_batch_size = 32
285
+ FLAGS.dtype = 'fp16'
286
+ FLAGS.loss_scale = 'dynamic'
287
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
288
+
289
+ self._run_and_report_benchmark()
290
+
291
+ def benchmark_8_gpu_xla_fp16(self):
292
+ """Tests BERT SQuAD model performance with 8 GPUs with XLA."""
293
+
294
+ self._setup()
295
+ self.num_gpus = 8
296
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
297
+ FLAGS.train_batch_size = 32
298
+ FLAGS.enable_xla = True
299
+ FLAGS.dtype = 'fp16'
300
+ FLAGS.loss_scale = 'dynamic'
301
+
302
+ self._run_and_report_benchmark()
303
+
304
+ def benchmark_1_gpu_amp(self):
305
+ """Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
306
+
307
+ self._setup()
308
+ self.num_gpus = 1
309
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp_squad')
310
+ FLAGS.train_batch_size = 4
311
+ FLAGS.dtype = 'fp16'
312
+ FLAGS.fp16_implementation = 'graph_rewrite'
313
+
314
+ self._run_and_report_benchmark()
315
+
316
+ def benchmark_8_gpu_amp(self):
317
+ """Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
318
+
319
+ self._setup()
320
+ self.num_gpus = 8
321
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_squad')
322
+ FLAGS.train_batch_size = 32
323
+ FLAGS.dtype = 'fp16'
324
+ FLAGS.fp16_implementation = 'graph_rewrite'
325
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
326
+
327
+ self._run_and_report_benchmark()
328
+
329
+ @owner_utils.Owner('tf-model-garden')
330
+ def benchmark_2x2_tpu(self):
331
+ """Tests BERT SQuAD model performance with 2x2 TPU."""
332
+
333
+ self._setup()
334
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
335
+ FLAGS.train_batch_size = 48
336
+ FLAGS.predict_batch_size = 48
337
+ FLAGS.mode = 'train'
338
+ FLAGS.learning_rate = 8e-5
339
+ FLAGS.num_train_epochs = 1
340
+ FLAGS.steps_per_loop = 100
341
+ FLAGS.do_lower_case = True
342
+ FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
343
+ self._run_and_report_benchmark()
344
+
345
+
346
+ class BertSquadAccuracy(BertSquadBenchmarkBase):
347
+ """Short accuracy test for BERT SQuAD model.
348
+
349
+ Tests BERT SQuAD accuracy. The naming convention of below test cases follow
350
+ `benchmark_(number of gpus)_gpu` format for GPUs and
351
+ `benchmark_(topology)_tpu` format for TPUs.
352
+ """
353
+
354
+ def __init__(self, output_dir=None, tpu=None, **kwargs):
355
+ super(BertSquadAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)
356
+
357
+ def _setup(self):
358
+ """Sets up the benchmark and SQuAD flags."""
359
+ super(BertSquadAccuracy, self)._setup()
360
+ FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
361
+ FLAGS.predict_file = SQUAD_PREDICT_FILE
362
+ FLAGS.vocab_file = SQUAD_VOCAB_FILE
363
+ FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
364
+ FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
365
+ FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
366
+ FLAGS.num_train_epochs = 2
367
+ FLAGS.steps_per_loop = 100
368
+
369
+ @benchmark_wrappers.enable_runtime_flags
370
+ def _run_and_report_benchmark(self,
371
+ run_eagerly=False,
372
+ ds_type='mirrored'):
373
+ """Runs the benchmark and reports various metrics."""
374
+ start_time_sec = time.time()
375
+ self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
376
+ self._evaluate_squad(ds_type=ds_type)
377
+ wall_time_sec = time.time() - start_time_sec
378
+
379
+ summary = self._read_training_summary_from_file()
380
+ summary['eval_metrics'] = self.eval_metrics
381
+ summary['start_time_sec'] = start_time_sec
382
+
383
+ super(BertSquadAccuracy, self)._report_benchmark(
384
+ stats=summary,
385
+ wall_time_sec=wall_time_sec,
386
+ min_accuracy=0.900,
387
+ max_accuracy=0.920)
388
+
389
+ def benchmark_1_gpu_eager(self):
390
+ """Tests BERT SQuAD model accuracy with 1 GPU with eager execution."""
391
+
392
+ self._setup()
393
+ self.num_gpus = 1
394
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
395
+ FLAGS.train_batch_size = 4
396
+
397
+ self._run_and_report_benchmark(ds_type='off', run_eagerly=True)
398
+
399
+ @owner_utils.Owner('tf-model-garden')
400
+ def benchmark_8_gpu(self):
401
+ """Tests BERT SQuAD model accuracy with 8 GPUs."""
402
+
403
+ self._setup()
404
+ self.num_gpus = 8
405
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
406
+ FLAGS.train_batch_size = 24
407
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
408
+
409
+ self._run_and_report_benchmark()
410
+
411
+ def benchmark_8_gpu_fp16(self):
412
+ """Tests BERT SQuAD model accuracy with 8 GPUs and FP16."""
413
+
414
+ self._setup()
415
+ self.num_gpus = 8
416
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_fp16')
417
+ FLAGS.train_batch_size = 32
418
+ FLAGS.dtype = 'fp16'
419
+ FLAGS.loss_scale = 'dynamic'
420
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
421
+
422
+ self._run_and_report_benchmark()
423
+
424
+ def benchmark_8_gpu_xla(self):
425
+ """Tests BERT SQuAD model accuracy with 8 GPUs."""
426
+
427
+ self._setup()
428
+ self.num_gpus = 8
429
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_xla')
430
+ FLAGS.train_batch_size = 32
431
+ FLAGS.enable_xla = True
432
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
433
+
434
+ self._run_and_report_benchmark()
435
+
436
+ @owner_utils.Owner('tf-model-garden')
437
+ def benchmark_2x2_tpu(self):
438
+ """Tests BERT SQuAD model accuracy with 2x2 TPU."""
439
+
440
+ self._setup()
441
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
442
+ FLAGS.train_batch_size = 48
443
+
444
+ self._run_and_report_benchmark()
445
+
446
+
447
+ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
448
+ """BERT SQuAD distributed accuracy tests with multiple workers."""
449
+
450
+ def __init__(self, output_dir=None, tpu=None, **kwargs):
451
+ super(BertSquadMultiWorkerAccuracy, self).__init__(
452
+ output_dir=output_dir, tpu=tpu)
453
+
454
+ def _setup(self):
455
+ """Sets up the benchmark and SQuAD flags."""
456
+ super(BertSquadMultiWorkerAccuracy, self)._setup()
457
+ FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
458
+ FLAGS.predict_file = SQUAD_PREDICT_FILE
459
+ FLAGS.vocab_file = SQUAD_VOCAB_FILE
460
+ FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
461
+ FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
462
+ FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
463
+ FLAGS.num_train_epochs = 2
464
+ FLAGS.steps_per_loop = 100
465
+
466
+ @benchmark_wrappers.enable_runtime_flags
467
+ def _run_and_report_benchmark(self,
468
+ use_ds=True,
469
+ run_eagerly=False):
470
+ """Runs the benchmark and reports various metrics."""
471
+ start_time_sec = time.time()
472
+ self._train_squad(run_eagerly=run_eagerly,
473
+ ds_type='multi_worker_mirrored')
474
+ self._evaluate_squad(ds_type='multi_worker_mirrored')
475
+ wall_time_sec = time.time() - start_time_sec
476
+
477
+ summary = self._read_training_summary_from_file()
478
+ summary['eval_metrics'] = self.eval_metrics
479
+
480
+ super(BertSquadMultiWorkerAccuracy, self)._report_benchmark(
481
+ stats=summary,
482
+ wall_time_sec=wall_time_sec,
483
+ min_accuracy=0.900,
484
+ max_accuracy=0.920)
485
+
486
+ def _benchmark_common(self, num_workers, all_reduce_alg):
487
+ """Common to all benchmarks in this class."""
488
+ self._setup()
489
+
490
+ num_gpus = 8
491
+ FLAGS.num_gpus = num_gpus
492
+ FLAGS.dtype = 'fp16'
493
+ FLAGS.enable_xla = False
494
+ FLAGS.distribution_strategy = 'multi_worker_mirrored'
495
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
496
+ FLAGS.datasets_num_private_threads = 32
497
+ FLAGS.model_dir = self._get_model_dir(
498
+ 'benchmark_8_gpu_{}_worker_fp16_{}_tweaked'.format(
499
+ num_workers, all_reduce_alg))
500
+ FLAGS.train_batch_size = 4 * num_gpus * num_workers
501
+ FLAGS.all_reduce_alg = all_reduce_alg
502
+
503
+ self._run_and_report_benchmark()
504
+
505
+ def benchmark_eager_8_gpu_2_workers_fp16_ring_tweaked(self):
506
+ """8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
507
+ self._benchmark_common(num_workers=2, all_reduce_alg='ring')
508
+
509
+ def benchmark_eager_8_gpu_2_workers_fp16_nccl_tweaked(self):
510
+ """8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
511
+ self._benchmark_common(num_workers=2, all_reduce_alg='nccl')
512
+
513
+ def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
514
+ """8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
515
+ self._benchmark_common(num_workers=8, all_reduce_alg='ring')
516
+
517
+ def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
518
+ """8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
519
+ self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
520
+
521
+
522
+ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
523
+ """BERT SQuAD distributed benchmark tests with multiple workers."""
524
+
525
+ def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
526
+ super(BertSquadMultiWorkerBenchmark, self).__init__(
527
+ output_dir=output_dir, tpu=tpu)
528
+
529
+ def _setup(self):
530
+ """Sets up the benchmark and SQuAD flags."""
531
+ super(BertSquadMultiWorkerBenchmark, self)._setup()
532
+ FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
533
+ FLAGS.predict_file = SQUAD_PREDICT_FILE
534
+ FLAGS.vocab_file = SQUAD_VOCAB_FILE
535
+ FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
536
+ FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
537
+ FLAGS.num_train_epochs = 1
538
+ FLAGS.steps_per_loop = 100
539
+
540
+ @benchmark_wrappers.enable_runtime_flags
541
+ def _run_and_report_benchmark(self,
542
+ use_ds=True,
543
+ run_eagerly=False):
544
+ """Runs the benchmark and reports various metrics."""
545
+ if FLAGS.train_batch_size <= 4 * 8:
546
+ FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
547
+ else:
548
+ FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
549
+ start_time_sec = time.time()
550
+ self._train_squad(run_eagerly=run_eagerly,
551
+ ds_type='multi_worker_mirrored')
552
+ wall_time_sec = time.time() - start_time_sec
553
+
554
+ summary = self._read_training_summary_from_file()
555
+ summary['start_time_sec'] = start_time_sec
556
+
557
+ super(BertSquadMultiWorkerBenchmark, self)._report_benchmark(
558
+ stats=summary,
559
+ wall_time_sec=wall_time_sec,
560
+ min_accuracy=0,
561
+ max_accuracy=1)
562
+
563
+ def _benchmark_common(self, num_workers, all_reduce_alg):
564
+ """Common to all benchmarks in this class."""
565
+ self._setup()
566
+
567
+ num_gpus = 8
568
+ FLAGS.num_gpus = num_gpus
569
+ FLAGS.dtype = 'fp16'
570
+ FLAGS.enable_xla = False
571
+ FLAGS.distribution_strategy = 'multi_worker_mirrored'
572
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
573
+ FLAGS.datasets_num_private_threads = 32
574
+ FLAGS.model_dir = self._get_model_dir(
575
+ 'benchmark_8_gpu_{}_worker_fp16_{}_tweaked'.format(
576
+ num_workers, all_reduce_alg))
577
+ FLAGS.train_batch_size = 4 * num_gpus * num_workers
578
+ FLAGS.all_reduce_alg = all_reduce_alg
579
+
580
+ self._run_and_report_benchmark()
581
+
582
+ def benchmark_8_gpu_1_worker_fp16_ring_tweaked(self):
583
+ """8 GPUs per worker, 1 worker, fp16, ring all-reduce."""
584
+ self._benchmark_common(num_workers=1, all_reduce_alg='ring')
585
+
586
+ def benchmark_8_gpu_1_worker_fp16_nccl_tweaked(self):
587
+ """8 GPUs per worker, 1 worker, fp16, nccl all-reduce."""
588
+ self._benchmark_common(num_workers=1, all_reduce_alg='nccl')
589
+
590
+ def benchmark_8_gpu_2_workers_fp16_ring_tweaked(self):
591
+ """8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
592
+ self._benchmark_common(num_workers=2, all_reduce_alg='ring')
593
+
594
+ def benchmark_8_gpu_2_workers_fp16_nccl_tweaked(self):
595
+ """8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
596
+ self._benchmark_common(num_workers=2, all_reduce_alg='nccl')
597
+
598
+ def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
599
+ """8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
600
+ self._benchmark_common(num_workers=8, all_reduce_alg='ring')
601
+
602
+ def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
603
+ """8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
604
+ self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
605
+
606
+
607
+ if __name__ == '__main__':
608
+ tf.test.main()
models/official/benchmark/datastore/schema/benchmark_metric.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "description": "The ID of the benchmark run, where this metric should tie to.",
4
+ "mode": "REQUIRED",
5
+ "name": "run_id",
6
+ "type": "STRING"
7
+ },
8
+ {
9
+ "description": "The name of the metric, which should be descriptive. E.g. training_loss, accuracy.",
10
+ "mode": "REQUIRED",
11
+ "name": "name",
12
+ "type": "STRING"
13
+ },
14
+ {
15
+ "description": "The unit of the metric. E.g. MB per sec.",
16
+ "mode": "NULLABLE",
17
+ "name": "unit",
18
+ "type": "STRING"
19
+ },
20
+ {
21
+ "description": "The value of the metric.",
22
+ "mode": "NULLABLE",
23
+ "name": "value",
24
+ "type": "FLOAT"
25
+ },
26
+ {
27
+ "description": "The timestamp when the metric is recorded.",
28
+ "mode": "REQUIRED",
29
+ "name": "timestamp",
30
+ "type": "TIMESTAMP"
31
+ },
32
+ {
33
+ "description": "The global step when this metric is recorded.",
34
+ "mode": "NULLABLE",
35
+ "name": "global_step",
36
+ "type": "INTEGER"
37
+ },
38
+ {
39
+ "description": "Free format metadata for the extra information about the metric.",
40
+ "mode": "REPEATED",
41
+ "name": "extras",
42
+ "type": "RECORD",
43
+ "fields": [
44
+ {
45
+ "mode": "NULLABLE",
46
+ "name": "name",
47
+ "type": "STRING"
48
+ },
49
+ {
50
+ "mode": "NULLABLE",
51
+ "name": "value",
52
+ "type": "STRING"
53
+ }
54
+ ]
55
+ }
56
+ ]
models/official/benchmark/datastore/schema/benchmark_run.json ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "description": "The UUID of the run for the benchmark.",
4
+ "mode": "REQUIRED",
5
+ "name": "model_id",
6
+ "type": "STRING"
7
+ },
8
+ {
9
+ "description": "The name of the model, E.g ResNet50, LeNet-5 etc.",
10
+ "mode": "REQUIRED",
11
+ "name": "model_name",
12
+ "type": "STRING"
13
+ },
14
+ {
15
+ "description": "The date when the test of the model is started",
16
+ "mode": "REQUIRED",
17
+ "name": "run_date",
18
+ "type": "TIMESTAMP"
19
+ },
20
+ {
21
+ "description": "The unique name for a test by the combination of key parameters, eg batch size, num of GPU, etc. It is hardware independent.",
22
+ "mode": "NULLABLE",
23
+ "name": "test_id",
24
+ "type": "STRING"
25
+ },
26
+ {
27
+ "description": "The tensorflow version information.",
28
+ "fields": [
29
+ {
30
+ "description": "Version of the tensorflow. E.g. 1.7.0-rc0",
31
+ "mode": "REQUIRED",
32
+ "name": "version",
33
+ "type": "STRING"
34
+ },
35
+ {
36
+ "description": "Git Hash of the tensorflow",
37
+ "mode": "NULLABLE",
38
+ "name": "git_hash",
39
+ "type": "STRING"
40
+ },
41
+ {
42
+ "description": "The channel of the tensorflow binary, eg, nightly, RC, final, custom.",
43
+ "mode": "NULLABLE",
44
+ "name": "channel",
45
+ "type": "STRING"
46
+ },
47
+ {
48
+ "description": "Identify anything special about the build, eg CUDA 10, NCCL, MKL, etc.",
49
+ "mode": "NULLABLE",
50
+ "name": "build_type",
51
+ "type": "STRING"
52
+ }
53
+ ],
54
+ "mode": "REQUIRED",
55
+ "name": "tensorflow_version",
56
+ "type": "RECORD"
57
+ },
58
+ {
59
+ "description": "The arbitrary attribute of the model.",
60
+ "fields": [
61
+ {
62
+ "description": "The name of the attribute.",
63
+ "mode": "REQUIRED",
64
+ "name": "name",
65
+ "type": "STRING"
66
+ },
67
+ {
68
+ "description": "The value of the attribute.",
69
+ "mode": "NULLABLE",
70
+ "name": "value",
71
+ "type": "STRING"
72
+ }
73
+ ],
74
+ "mode": "REPEATED",
75
+ "name": "attribute",
76
+ "type": "RECORD"
77
+ },
78
+ {
79
+ "description": "Environment variables when the benchmark run is executed.",
80
+ "fields": [
81
+ {
82
+ "description": "The name of the variable.",
83
+ "mode": "REQUIRED",
84
+ "name": "name",
85
+ "type": "STRING"
86
+ },
87
+ {
88
+ "description": "The value of the variable.",
89
+ "mode": "NULLABLE",
90
+ "name": "value",
91
+ "type": "STRING"
92
+ }
93
+ ],
94
+ "mode": "REPEATED",
95
+ "name": "environment_variable",
96
+ "type": "RECORD"
97
+ },
98
+ {
99
+ "description": "TF Environment variables when the benchmark run is executed.",
100
+ "fields": [
101
+ {
102
+ "description": "The name of the variable.",
103
+ "mode": "REQUIRED",
104
+ "name": "name",
105
+ "type": "STRING"
106
+ },
107
+ {
108
+ "description": "The value of the variable.",
109
+ "mode": "NULLABLE",
110
+ "name": "value",
111
+ "type": "STRING"
112
+ }
113
+ ],
114
+ "mode": "REPEATED",
115
+ "name": "tensorflow_environment_variables",
116
+ "type": "RECORD"
117
+ },
118
+ {
119
+ "description": "The list of parameters run with the model. It could contain hyperparameters or others.",
120
+ "fields": [
121
+ {
122
+ "description": "The name of the parameter.",
123
+ "mode": "REQUIRED",
124
+ "name": "name",
125
+ "type": "STRING"
126
+ },
127
+ {
128
+ "description": "The string value of the parameter.",
129
+ "mode": "NULLABLE",
130
+ "name": "string_value",
131
+ "type": "STRING"
132
+ },
133
+ {
134
+ "description": "The bool value of the parameter.",
135
+ "mode": "NULLABLE",
136
+ "name": "bool_value",
137
+ "type": "STRING"
138
+ },
139
+ {
140
+ "description": "The int/long value of the parameter.",
141
+ "mode": "NULLABLE",
142
+ "name": "long_value",
143
+ "type": "INTEGER"
144
+ },
145
+ {
146
+ "description": "The double/float value of parameter.",
147
+ "mode": "NULLABLE",
148
+ "name": "float_value",
149
+ "type": "FLOAT"
150
+ }
151
+ ],
152
+ "mode": "REPEATED",
153
+ "name": "run_parameters",
154
+ "type": "RECORD"
155
+ },
156
+ {
157
+ "description": "The dataset that run with the benchmark.",
158
+ "mode": "NULLABLE",
159
+ "name": "dataset",
160
+ "type": "RECORD",
161
+ "fields": [
162
+ {
163
+ "description": "The name of the dataset that the model is trained/validated with. E.g ImageNet, mnist.",
164
+ "mode": "REQUIRED",
165
+ "name": "name",
166
+ "type": "STRING"
167
+ },
168
+ {
169
+ "description": "The arbitrary attribute of the dataset.",
170
+ "fields": [
171
+ {
172
+ "description": "The name of the attribute.",
173
+ "mode": "REQUIRED",
174
+ "name": "name",
175
+ "type": "STRING"
176
+ },
177
+ {
178
+ "description": "The value of the attribute.",
179
+ "mode": "NULLABLE",
180
+ "name": "value",
181
+ "type": "STRING"
182
+ }
183
+ ],
184
+ "mode": "REPEATED",
185
+ "name": "attribute",
186
+ "type": "RECORD"
187
+ }
188
+ ]
189
+ },
190
+ {
191
+ "description": "Used to differentiate from AWS, GCE or DGX-1 at a high level",
192
+ "mode": "NULLABLE",
193
+ "name": "test_environment",
194
+ "type": "STRING"
195
+ },
196
+ {
197
+ "description": "The machine configuration of the benchmark run.",
198
+ "mode": "NULLABLE",
199
+ "name": "machine_config",
200
+ "type": "RECORD",
201
+ "fields": [
202
+ {
203
+ "description": "The platform information of the benchmark run.",
204
+ "mode": "NULLABLE",
205
+ "name": "platform_info",
206
+ "type": "RECORD",
207
+ "fields": [
208
+ {
209
+ "description": "Eg: 64bit.",
210
+ "mode": "NULLABLE",
211
+ "name": "bits",
212
+ "type": "STRING"
213
+ },
214
+ {
215
+ "description": "Eg: ELF.",
216
+ "mode": "NULLABLE",
217
+ "name": "linkage",
218
+ "type": "STRING"
219
+ },
220
+ {
221
+ "description": "Eg: i386.",
222
+ "mode": "NULLABLE",
223
+ "name": "machine",
224
+ "type": "STRING"
225
+ },
226
+ {
227
+ "description": "Eg: 3.13.0-76-generic.",
228
+ "mode": "NULLABLE",
229
+ "name": "release",
230
+ "type": "STRING"
231
+ },
232
+ {
233
+ "description": "Eg: Linux.",
234
+ "mode": "NULLABLE",
235
+ "name": "system",
236
+ "type": "STRING"
237
+ },
238
+ {
239
+ "description": "Eg: #120-Ubuntu SMP Mon Jan 18 15:59:10 UTC 2016.",
240
+ "mode": "NULLABLE",
241
+ "name": "version",
242
+ "type": "STRING"
243
+ }
244
+ ]
245
+ },
246
+ {
247
+ "description": "The CPU information of the benchmark run.",
248
+ "mode": "NULLABLE",
249
+ "name": "cpu_info",
250
+ "type": "RECORD",
251
+ "fields": [
252
+ {
253
+ "mode": "NULLABLE",
254
+ "name": "num_cores",
255
+ "type": "INTEGER"
256
+ },
257
+ {
258
+ "mode": "NULLABLE",
259
+ "name": "num_cores_allowed",
260
+ "type": "INTEGER"
261
+ },
262
+ {
263
+ "description" : "How fast are those CPUs.",
264
+ "mode": "NULLABLE",
265
+ "name": "mhz_per_cpu",
266
+ "type": "FLOAT"
267
+ },
268
+ {
269
+ "description" : "Additional CPU info, Eg: Intel Ivybridge with HyperThreading (24 cores).",
270
+ "mode": "NULLABLE",
271
+ "name": "cpu_info",
272
+ "type": "STRING"
273
+ },
274
+ {
275
+ "description" : "What kind of cpu scaling is enabled on the host. Eg performance, ondemand, conservative, mixed.",
276
+ "mode": "NULLABLE",
277
+ "name": "cpu_governor",
278
+ "type": "STRING"
279
+ },
280
+ {
281
+ "description": "Cache size of the CPUs.",
282
+ "mode": "NULLABLE",
283
+ "name": "cache_size",
284
+ "type": "RECORD",
285
+ "fields": [
286
+ {
287
+ "mode": "NULLABLE",
288
+ "name": "level",
289
+ "type": "STRING"
290
+ },
291
+ {
292
+ "mode": "NULLABLE",
293
+ "name": "size",
294
+ "type": "INTEGER"
295
+ }
296
+ ]
297
+ }
298
+ ]
299
+ },
300
+ {
301
+ "mode": "NULLABLE",
302
+ "name": "gpu_info",
303
+ "type": "RECORD",
304
+ "fields": [
305
+ {
306
+ "mode": "NULLABLE",
307
+ "name": "count",
308
+ "type": "INTEGER"
309
+ },
310
+ {
311
+ "mode": "NULLABLE",
312
+ "name": "model",
313
+ "type": "STRING"
314
+ },
315
+ {
316
+ "mode": "NULLABLE",
317
+ "name": "cuda_version",
318
+ "type": "STRING"
319
+ }
320
+ ]
321
+ },
322
+ {
323
+ "description": "The cloud instance inforation if the benchmark run is executed on cloud",
324
+ "mode": "NULLABLE",
325
+ "name": "cloud_info",
326
+ "type": "RECORD",
327
+ "fields": [
328
+ {
329
+ "description": "The instance type, E.g. n1-standard-4.",
330
+ "mode": "NULLABLE",
331
+ "name": "instance_type",
332
+ "type": "STRING"
333
+ },
334
+ {
335
+ "description": "The arbitrary attribute of the cloud info.",
336
+ "fields": [
337
+ {
338
+ "description": "The name of the attribute.",
339
+ "mode": "REQUIRED",
340
+ "name": "name",
341
+ "type": "STRING"
342
+ },
343
+ {
344
+ "description": "The value of the attribute.",
345
+ "mode": "NULLABLE",
346
+ "name": "value",
347
+ "type": "STRING"
348
+ }
349
+ ],
350
+ "mode": "REPEATED",
351
+ "name": "attribute",
352
+ "type": "RECORD"
353
+ }
354
+ ]
355
+ },
356
+ {
357
+ "mode": "NULLABLE",
358
+ "name": "memory_total",
359
+ "type": "INTEGER"
360
+ },
361
+ {
362
+ "mode": "NULLABLE",
363
+ "name": "memory_available",
364
+ "type": "STRING"
365
+ }
366
+ ]
367
+ }
368
+ ]
models/official/benchmark/datastore/schema/benchmark_run_status.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "description": "The UUID of the run for the benchmark.",
4
+ "mode": "REQUIRED",
5
+ "name": "run_id",
6
+ "type": "STRING"
7
+ },
8
+ {
9
+ "description": "The status of the run for the benchmark. Eg, running, failed, success",
10
+ "mode": "REQUIRED",
11
+ "name": "status",
12
+ "type": "STRING"
13
+ }
14
+ ]
models/official/benchmark/keras_benchmark.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Executes Keras benchmarks and accuracy tests."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import tensorflow as tf
22
+ from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
23
+ from official.utils.flags import core as flags_core
24
+
25
+
26
+ class KerasBenchmark(PerfZeroBenchmark):
27
+ """Base benchmark class with methods to simplify testing."""
28
+
29
+ def __init__(self,
30
+ output_dir=None,
31
+ default_flags=None,
32
+ flag_methods=None,
33
+ tpu=None):
34
+ super(KerasBenchmark, self).__init__(
35
+ output_dir=output_dir,
36
+ default_flags=default_flags,
37
+ flag_methods=flag_methods,
38
+ tpu=tpu)
39
+
40
+ def _report_benchmark(self,
41
+ stats,
42
+ wall_time_sec,
43
+ top_1_max=None,
44
+ top_1_min=None,
45
+ log_steps=None,
46
+ total_batch_size=None,
47
+ warmup=1,
48
+ start_time_sec=None):
49
+ """Report benchmark results by writing to local protobuf file.
50
+
51
+ Args:
52
+ stats: dict returned from keras models with known entries.
53
+ wall_time_sec: the during of the benchmark execution in seconds
54
+ top_1_max: highest passing level for top_1 accuracy.
55
+ top_1_min: lowest passing level for top_1 accuracy.
56
+ log_steps: How often the log was created for stats['step_timestamp_log'].
57
+ total_batch_size: Global batch-size.
58
+ warmup: number of entries in stats['step_timestamp_log'] to ignore.
59
+ start_time_sec: the start time of the program in seconds since epoch
60
+ """
61
+
62
+ metrics = []
63
+ if 'accuracy_top_1' in stats:
64
+ metrics.append({'name': 'accuracy_top_1',
65
+ 'value': stats['accuracy_top_1'],
66
+ 'min_value': top_1_min,
67
+ 'max_value': top_1_max})
68
+ metrics.append({'name': 'top_1_train_accuracy',
69
+ 'value': stats['training_accuracy_top_1']})
70
+
71
+ if (warmup and 'step_timestamp_log' in stats and
72
+ len(stats['step_timestamp_log']) > warmup):
73
+ # first entry in the time_log is start of step 1. The rest of the
74
+ # entries are the end of each step recorded
75
+ time_log = stats['step_timestamp_log']
76
+ elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
77
+ num_examples = (
78
+ total_batch_size * log_steps * (len(time_log) - warmup - 1))
79
+ examples_per_sec = num_examples / elapsed
80
+ metrics.append({'name': 'exp_per_second',
81
+ 'value': examples_per_sec})
82
+
83
+ if 'avg_exp_per_second' in stats:
84
+ metrics.append({'name': 'avg_exp_per_second',
85
+ 'value': stats['avg_exp_per_second']})
86
+
87
+ if start_time_sec and 'step_timestamp_log' in stats:
88
+ time_log = stats['step_timestamp_log']
89
+ # time_log[0] is recorded at the beginning of the first step.
90
+ startup_time = time_log[0].timestamp - start_time_sec
91
+ metrics.append({'name': 'startup_time', 'value': startup_time})
92
+
93
+ flags_str = flags_core.get_nondefault_flags_as_str()
94
+ self.report_benchmark(
95
+ iters=-1,
96
+ wall_time=wall_time_sec,
97
+ metrics=metrics,
98
+ extras={'flags': flags_str})
models/official/benchmark/keras_cifar_benchmark.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Executes Keras benchmarks and accuracy tests."""
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+ from __future__ import print_function
19
+
20
+ import os
21
+ import time
22
+ from absl import flags
23
+ import tensorflow as tf # pylint: disable=g-bad-import-order
24
+
25
+ from official.benchmark import keras_benchmark
26
+ from official.benchmark import benchmark_wrappers
27
+ from official.benchmark.models import resnet_cifar_main
28
+
29
+ MIN_TOP_1_ACCURACY = 0.929
30
+ MAX_TOP_1_ACCURACY = 0.938
31
+
32
+ FLAGS = flags.FLAGS
33
+ CIFAR_DATA_DIR_NAME = 'cifar-10-batches-bin'
34
+
35
+
36
+ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
37
+ """Accuracy tests for ResNet56 Keras CIFAR-10."""
38
+
39
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
40
+ """A benchmark class.
41
+
42
+ Args:
43
+ output_dir: directory where to output e.g. log files
44
+ root_data_dir: directory under which to look for dataset
45
+ **kwargs: arbitrary named arguments. This is needed to make the
46
+ constructor forward compatible in case PerfZero provides more
47
+ named arguments before updating the constructor.
48
+ """
49
+
50
+ self.data_dir = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
51
+ flag_methods = [resnet_cifar_main.define_cifar_flags]
52
+
53
+ super(Resnet56KerasAccuracy, self).__init__(
54
+ output_dir=output_dir, flag_methods=flag_methods)
55
+
56
+ def _setup(self):
57
+ super(Resnet56KerasAccuracy, self)._setup()
58
+ FLAGS.use_tensor_lr = False
59
+
60
+ def benchmark_graph_1_gpu(self):
61
+ """Test keras based model with Keras fit and distribution strategies."""
62
+ self._setup()
63
+ FLAGS.num_gpus = 1
64
+ FLAGS.data_dir = self.data_dir
65
+ FLAGS.batch_size = 128
66
+ FLAGS.train_epochs = 182
67
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
68
+ FLAGS.dtype = 'fp32'
69
+ self._run_and_report_benchmark()
70
+
71
+ def benchmark_1_gpu(self):
72
+ """Test keras based model with eager and distribution strategies."""
73
+ self._setup()
74
+ FLAGS.num_gpus = 1
75
+ FLAGS.data_dir = self.data_dir
76
+ FLAGS.batch_size = 128
77
+ FLAGS.train_epochs = 182
78
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
79
+ FLAGS.dtype = 'fp32'
80
+ FLAGS.enable_eager = True
81
+ self._run_and_report_benchmark()
82
+
83
+ def benchmark_cpu(self):
84
+ """Test keras based model on CPU."""
85
+ self._setup()
86
+ FLAGS.num_gpus = 0
87
+ FLAGS.data_dir = self.data_dir
88
+ FLAGS.batch_size = 128
89
+ FLAGS.train_epochs = 182
90
+ FLAGS.model_dir = self._get_model_dir('benchmark_cpu')
91
+ FLAGS.dtype = 'fp32'
92
+ FLAGS.enable_eager = True
93
+ FLAGS.data_format = 'channels_last'
94
+ self._run_and_report_benchmark()
95
+
96
+ def benchmark_cpu_no_dist_strat(self):
97
+ """Test keras based model on CPU without distribution strategies."""
98
+ self._setup()
99
+ FLAGS.num_gpus = 0
100
+ FLAGS.data_dir = self.data_dir
101
+ FLAGS.batch_size = 128
102
+ FLAGS.train_epochs = 182
103
+ FLAGS.model_dir = self._get_model_dir('benchmark_cpu_no_dist_strat')
104
+ FLAGS.dtype = 'fp32'
105
+ FLAGS.enable_eager = True
106
+ FLAGS.distribution_strategy = 'off'
107
+ FLAGS.data_format = 'channels_last'
108
+ self._run_and_report_benchmark()
109
+
110
+ def benchmark_cpu_no_dist_strat_run_eagerly(self):
111
+ """Test keras based model on CPU w/forced eager and no dist_strat."""
112
+ self._setup()
113
+ FLAGS.num_gpus = 0
114
+ FLAGS.data_dir = self.data_dir
115
+ FLAGS.batch_size = 128
116
+ FLAGS.train_epochs = 182
117
+ FLAGS.model_dir = self._get_model_dir(
118
+ 'benchmark_cpu_no_dist_strat_run_eagerly')
119
+ FLAGS.dtype = 'fp32'
120
+ FLAGS.enable_eager = True
121
+ FLAGS.run_eagerly = True
122
+ FLAGS.distribution_strategy = 'off'
123
+ FLAGS.data_format = 'channels_last'
124
+ self._run_and_report_benchmark()
125
+
126
+ def benchmark_1_gpu_no_dist_strat(self):
127
+ """Test keras based model with eager and no dist strat."""
128
+ self._setup()
129
+ FLAGS.num_gpus = 1
130
+ FLAGS.data_dir = self.data_dir
131
+ FLAGS.batch_size = 128
132
+ FLAGS.train_epochs = 182
133
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
134
+ FLAGS.dtype = 'fp32'
135
+ FLAGS.enable_eager = True
136
+ FLAGS.distribution_strategy = 'off'
137
+ self._run_and_report_benchmark()
138
+
139
+ def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
140
+ """Test keras based model w/forced eager and no dist_strat."""
141
+ self._setup()
142
+ FLAGS.num_gpus = 1
143
+ FLAGS.data_dir = self.data_dir
144
+ FLAGS.batch_size = 128
145
+ FLAGS.train_epochs = 182
146
+ FLAGS.model_dir = self._get_model_dir(
147
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly')
148
+ FLAGS.dtype = 'fp32'
149
+ FLAGS.enable_eager = True
150
+ FLAGS.run_eagerly = True
151
+ FLAGS.distribution_strategy = 'off'
152
+ self._run_and_report_benchmark()
153
+
154
+ def benchmark_graph_1_gpu_no_dist_strat(self):
155
+ """Test keras based model with Keras fit but not distribution strategies."""
156
+ self._setup()
157
+ FLAGS.distribution_strategy = 'off'
158
+ FLAGS.num_gpus = 1
159
+ FLAGS.data_dir = self.data_dir
160
+ FLAGS.batch_size = 128
161
+ FLAGS.train_epochs = 182
162
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu_no_dist_strat')
163
+ FLAGS.dtype = 'fp32'
164
+ self._run_and_report_benchmark()
165
+
166
+ def benchmark_2_gpu(self):
167
+ """Test keras based model with eager and distribution strategies."""
168
+ self._setup()
169
+ FLAGS.num_gpus = 2
170
+ FLAGS.data_dir = self.data_dir
171
+ FLAGS.batch_size = 128
172
+ FLAGS.train_epochs = 182
173
+ FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu')
174
+ FLAGS.dtype = 'fp32'
175
+ FLAGS.enable_eager = True
176
+ self._run_and_report_benchmark()
177
+
178
+ def benchmark_graph_2_gpu(self):
179
+ """Test keras based model with Keras fit and distribution strategies."""
180
+ self._setup()
181
+ FLAGS.num_gpus = 2
182
+ FLAGS.data_dir = self.data_dir
183
+ FLAGS.batch_size = 128
184
+ FLAGS.train_epochs = 182
185
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
186
+ FLAGS.dtype = 'fp32'
187
+ self._run_and_report_benchmark()
188
+
189
+ @benchmark_wrappers.enable_runtime_flags
190
+ def _run_and_report_benchmark(self):
191
+ start_time_sec = time.time()
192
+ stats = resnet_cifar_main.run(FLAGS)
193
+ wall_time_sec = time.time() - start_time_sec
194
+
195
+ super(Resnet56KerasAccuracy, self)._report_benchmark(
196
+ stats,
197
+ wall_time_sec,
198
+ top_1_min=MIN_TOP_1_ACCURACY,
199
+ top_1_max=MAX_TOP_1_ACCURACY,
200
+ total_batch_size=FLAGS.batch_size,
201
+ log_steps=100)
202
+
203
+
204
+ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
205
+ """Short performance tests for ResNet56 via Keras and CIFAR-10."""
206
+
207
+ def __init__(self, output_dir=None, default_flags=None):
208
+ flag_methods = [resnet_cifar_main.define_cifar_flags]
209
+
210
+ super(Resnet56KerasBenchmarkBase, self).__init__(
211
+ output_dir=output_dir,
212
+ flag_methods=flag_methods,
213
+ default_flags=default_flags)
214
+
215
+ @benchmark_wrappers.enable_runtime_flags
216
+ def _run_and_report_benchmark(self):
217
+ start_time_sec = time.time()
218
+ stats = resnet_cifar_main.run(FLAGS)
219
+ wall_time_sec = time.time() - start_time_sec
220
+
221
+ super(Resnet56KerasBenchmarkBase, self)._report_benchmark(
222
+ stats,
223
+ wall_time_sec,
224
+ total_batch_size=FLAGS.batch_size,
225
+ log_steps=FLAGS.log_steps)
226
+
227
+ def benchmark_1_gpu(self):
228
+ """Test 1 gpu."""
229
+ self._setup()
230
+ FLAGS.num_gpus = 1
231
+ FLAGS.enable_eager = True
232
+ FLAGS.distribution_strategy = 'one_device'
233
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
234
+ FLAGS.batch_size = 128
235
+ self._run_and_report_benchmark()
236
+
237
+ def benchmark_1_gpu_xla(self):
238
+ """Test 1 gpu with xla enabled."""
239
+ self._setup()
240
+ FLAGS.num_gpus = 1
241
+ FLAGS.enable_eager = True
242
+ FLAGS.run_eagerly = False
243
+ FLAGS.enable_xla = True
244
+ FLAGS.distribution_strategy = 'one_device'
245
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla')
246
+ FLAGS.batch_size = 128
247
+ self._run_and_report_benchmark()
248
+
249
+ def benchmark_graph_1_gpu(self):
250
+ """Test 1 gpu graph."""
251
+ self._setup()
252
+ FLAGS.num_gpus = 1
253
+ FLAGS.enable_eager = False
254
+ FLAGS.run_eagerly = False
255
+ FLAGS.distribution_strategy = 'one_device'
256
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
257
+ FLAGS.batch_size = 128
258
+ self._run_and_report_benchmark()
259
+
260
+ def benchmark_1_gpu_no_dist_strat(self):
261
+ """Test 1 gpu without distribution strategies."""
262
+ self._setup()
263
+ FLAGS.num_gpus = 1
264
+ FLAGS.enable_eager = True
265
+ FLAGS.distribution_strategy = 'off'
266
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
267
+ FLAGS.batch_size = 128
268
+ self._run_and_report_benchmark()
269
+
270
+ def benchmark_graph_1_gpu_no_dist_strat(self):
271
+ """Test 1 gpu graph mode without distribution strategies."""
272
+ self._setup()
273
+ FLAGS.num_gpus = 1
274
+ FLAGS.enable_eager = False
275
+ FLAGS.distribution_strategy = 'off'
276
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu_no_dist_strat')
277
+ FLAGS.batch_size = 128
278
+ self._run_and_report_benchmark()
279
+
280
+ def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
281
+ """Test 1 gpu without distribution strategy and forced eager."""
282
+ self._setup()
283
+ FLAGS.num_gpus = 1
284
+ FLAGS.batch_size = 128
285
+ FLAGS.model_dir = self._get_model_dir(
286
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly')
287
+ FLAGS.dtype = 'fp32'
288
+ FLAGS.enable_eager = True
289
+ FLAGS.run_eagerly = True
290
+ FLAGS.distribution_strategy = 'off'
291
+ self._run_and_report_benchmark()
292
+
293
+ def benchmark_2_gpu(self):
294
+ """Test 2 gpu."""
295
+ self._setup()
296
+ FLAGS.num_gpus = 2
297
+ FLAGS.enable_eager = True
298
+ FLAGS.run_eagerly = False
299
+ FLAGS.distribution_strategy = 'mirrored'
300
+ FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu')
301
+ FLAGS.batch_size = 128 * 2 # 2 GPUs
302
+ self._run_and_report_benchmark()
303
+
304
+ def benchmark_graph_2_gpu(self):
305
+ """Test 2 gpu graph mode."""
306
+ self._setup()
307
+ FLAGS.num_gpus = 2
308
+ FLAGS.enable_eager = False
309
+ FLAGS.run_eagerly = False
310
+ FLAGS.distribution_strategy = 'mirrored'
311
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
312
+ FLAGS.batch_size = 128 * 2 # 2 GPUs
313
+ self._run_and_report_benchmark()
314
+
315
+ def benchmark_cpu(self):
316
+ """Test cpu."""
317
+ self._setup()
318
+ FLAGS.num_gpus = 0
319
+ FLAGS.enable_eager = True
320
+ FLAGS.model_dir = self._get_model_dir('benchmark_cpu')
321
+ FLAGS.batch_size = 128
322
+ FLAGS.data_format = 'channels_last'
323
+ self._run_and_report_benchmark()
324
+
325
+ def benchmark_graph_cpu(self):
326
+ """Test cpu graph mode."""
327
+ self._setup()
328
+ FLAGS.num_gpus = 0
329
+ FLAGS.enable_eager = False
330
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_cpu')
331
+ FLAGS.batch_size = 128
332
+ FLAGS.data_format = 'channels_last'
333
+ self._run_and_report_benchmark()
334
+
335
+ def benchmark_cpu_no_dist_strat_run_eagerly(self):
336
+ """Test cpu without distribution strategy and forced eager."""
337
+ self._setup()
338
+ FLAGS.num_gpus = 0
339
+ FLAGS.distribution_strategy = 'off'
340
+ FLAGS.enable_eager = True
341
+ FLAGS.run_eagerly = True
342
+ FLAGS.model_dir = self._get_model_dir(
343
+ 'benchmark_cpu_no_dist_strat_run_eagerly')
344
+ FLAGS.batch_size = 128
345
+ FLAGS.data_format = 'channels_last'
346
+ self._run_and_report_benchmark()
347
+
348
+ def benchmark_cpu_no_dist_strat(self):
349
+ """Test cpu without distribution strategies."""
350
+ self._setup()
351
+ FLAGS.num_gpus = 0
352
+ FLAGS.enable_eager = True
353
+ FLAGS.distribution_strategy = 'off'
354
+ FLAGS.model_dir = self._get_model_dir('benchmark_cpu_no_dist_strat')
355
+ FLAGS.batch_size = 128
356
+ FLAGS.data_format = 'channels_last'
357
+ self._run_and_report_benchmark()
358
+
359
+ def benchmark_graph_cpu_no_dist_strat(self):
360
+ """Test cpu graph mode without distribution strategies."""
361
+ self._setup()
362
+ FLAGS.num_gpus = 0
363
+ FLAGS.enable_eager = False
364
+ FLAGS.distribution_strategy = 'off'
365
+ FLAGS.model_dir = self._get_model_dir('benchmark_graph_cpu_no_dist_strat')
366
+ FLAGS.batch_size = 128
367
+ FLAGS.data_format = 'channels_last'
368
+ self._run_and_report_benchmark()
369
+
370
+
371
+ class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
372
+ """Synthetic benchmarks for ResNet56 and Keras."""
373
+
374
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
375
+ default_flags = {}
376
+ default_flags['skip_eval'] = True
377
+ default_flags['use_synthetic_data'] = True
378
+ default_flags['train_steps'] = 110
379
+ default_flags['log_steps'] = 10
380
+ default_flags['use_tensor_lr'] = False
381
+
382
+ super(Resnet56KerasBenchmarkSynth, self).__init__(
383
+ output_dir=output_dir, default_flags=default_flags)
384
+
385
+
386
+ class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase):
387
+ """Real data benchmarks for ResNet56 and Keras."""
388
+
389
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
390
+ default_flags = {}
391
+ default_flags['skip_eval'] = True
392
+ default_flags['data_dir'] = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
393
+ default_flags['train_steps'] = 110
394
+ default_flags['log_steps'] = 10
395
+ default_flags['use_tensor_lr'] = False
396
+
397
+ super(Resnet56KerasBenchmarkReal, self).__init__(
398
+ output_dir=output_dir, default_flags=default_flags)
399
+
400
+
401
+ if __name__ == '__main__':
402
+ tf.test.main()
models/official/benchmark/keras_imagenet_benchmark.py ADDED
@@ -0,0 +1,1724 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Lint as: python3
2
+ # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ """Executes Keras benchmarks and accuracy tests."""
17
+ # pylint: disable=line-too-long
18
+ from __future__ import print_function
19
+
20
+ import json
21
+ import os
22
+ import time
23
+
24
+ from typing import Any, MutableMapping, Optional
25
+
26
+ from absl import flags
27
+ import tensorflow as tf # pylint: disable=g-bad-import-order
28
+
29
+ from official.benchmark import benchmark_wrappers
30
+ from official.benchmark import keras_benchmark
31
+ from official.benchmark.models import resnet_imagenet_main
32
+ from official.vision.image_classification import classifier_trainer
33
+
34
+ MIN_TOP_1_ACCURACY = 0.76
35
+ MAX_TOP_1_ACCURACY = 0.77
36
+
37
+ MOBILENET_V1_MIN_TOP_1_ACCURACY = 0.65
38
+ MOBILENET_V1_MAX_TOP_1_ACCURACY = 0.68
39
+
40
+ # Range of top-1 accracies for model optimization techniques.
41
+ # Each item indicates (MIN_TOP_1_ACCURACY, MAX_TOP_1_ACCURACY).
42
+ MODEL_OPTIMIZATION_TOP_1_ACCURACY = {
43
+ 'RESNET50_FINETUNE_PRUNING': (0.76, 0.77),
44
+ 'MOBILENET_V1_FINETUNE_PRUNING': (0.67, 0.68),
45
+ }
46
+
47
+ FLAGS = flags.FLAGS
48
+
49
+
50
+ def _get_classifier_parameters(
51
+ num_gpus: int = 0,
52
+ builder: str = 'records',
53
+ skip_eval: bool = False,
54
+ distribution_strategy: str = 'mirrored',
55
+ per_replica_batch_size: int = 128,
56
+ epochs: int = 90,
57
+ steps: int = 0,
58
+ epochs_between_evals: int = 1,
59
+ dtype: str = 'float32',
60
+ enable_xla: bool = False,
61
+ run_eagerly: bool = False,
62
+ gpu_thread_mode: Optional[str] = None,
63
+ dataset_num_private_threads: Optional[int] = None,
64
+ loss_scale: Optional[str] = None,
65
+ report_metrics: bool = True,
66
+ batchnorm_spatial_persistent: bool = False) -> MutableMapping[str, Any]:
67
+ """Gets classifier trainer's ResNet parameters."""
68
+ return {
69
+ 'runtime': {
70
+ 'num_gpus': num_gpus,
71
+ 'distribution_strategy': distribution_strategy,
72
+ 'run_eagerly': run_eagerly,
73
+ 'enable_xla': enable_xla,
74
+ 'dataset_num_private_threads': dataset_num_private_threads,
75
+ 'gpu_thread_mode': gpu_thread_mode,
76
+ 'loss_scale': loss_scale,
77
+ 'batchnorm_spatial_persistent': batchnorm_spatial_persistent,
78
+ },
79
+ 'train_dataset': {
80
+ 'builder': builder,
81
+ 'use_per_replica_batch_size': True,
82
+ 'batch_size': per_replica_batch_size,
83
+ 'image_size': 224,
84
+ 'dtype': dtype,
85
+ },
86
+ 'validation_dataset': {
87
+ 'builder': builder,
88
+ 'batch_size': per_replica_batch_size,
89
+ 'use_per_replica_batch_size': True,
90
+ 'image_size': 224,
91
+ 'dtype': dtype,
92
+ },
93
+ 'train': {
94
+ 'epochs': epochs,
95
+ 'steps': steps,
96
+ 'callbacks': {
97
+ 'enable_tensorboard': False,
98
+ 'enable_checkpoint_and_export': False,
99
+ 'enable_time_history': True,
100
+ },
101
+ 'metrics': ['accuracy'] if report_metrics else [],
102
+ },
103
+ 'model': {
104
+ 'loss': {
105
+ 'label_smoothing': 0.1,
106
+ },
107
+ },
108
+ 'evaluation': {
109
+ 'epochs_between_evals': epochs_between_evals,
110
+ 'skip_eval': skip_eval,
111
+ },
112
+ }
113
+
114
+
115
+ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
116
+ """Benchmark accuracy tests for ResNet50 in Keras."""
117
+
118
+ def __init__(self,
119
+ output_dir: Optional[str] = None,
120
+ root_data_dir: Optional[str] = None,
121
+ **kwargs):
122
+ """A benchmark class.
123
+
124
+ Args:
125
+ output_dir: directory where to output e.g. log files
126
+ root_data_dir: directory under which to look for dataset
127
+ **kwargs: arbitrary named arguments. This is needed to make the
128
+ constructor forward compatible in case PerfZero provides more
129
+ named arguments before updating the constructor.
130
+ """
131
+
132
+ flag_methods = [classifier_trainer.define_classifier_flags]
133
+
134
+ self.data_dir = os.path.join(root_data_dir, 'imagenet')
135
+ super(Resnet50KerasAccuracy, self).__init__(
136
+ output_dir=output_dir, flag_methods=flag_methods)
137
+
138
+ @benchmark_wrappers.enable_runtime_flags
139
+ def _run_and_report_benchmark(
140
+ self,
141
+ experiment_name: str,
142
+ top_1_min: float = MIN_TOP_1_ACCURACY,
143
+ top_1_max: float = MAX_TOP_1_ACCURACY,
144
+ num_gpus: int = 0,
145
+ distribution_strategy: str = 'mirrored',
146
+ per_replica_batch_size: int = 128,
147
+ epochs: int = 90,
148
+ steps: int = 0,
149
+ epochs_between_evals: int = 1,
150
+ dtype: str = 'float32',
151
+ enable_xla: bool = False,
152
+ run_eagerly: bool = False,
153
+ gpu_thread_mode: Optional[str] = None,
154
+ dataset_num_private_threads: Optional[int] = None,
155
+ loss_scale: Optional[str] = None):
156
+ """Runs and reports the benchmark given the provided configuration."""
157
+ FLAGS.model_type = 'resnet'
158
+ FLAGS.dataset = 'imagenet'
159
+ FLAGS.mode = 'train_and_eval'
160
+ FLAGS.data_dir = self.data_dir
161
+ FLAGS.model_dir = self._get_model_dir(experiment_name)
162
+ parameters = _get_classifier_parameters(
163
+ num_gpus=num_gpus,
164
+ distribution_strategy=distribution_strategy,
165
+ per_replica_batch_size=per_replica_batch_size,
166
+ epochs=epochs,
167
+ steps=steps,
168
+ epochs_between_evals=epochs_between_evals,
169
+ dtype=dtype,
170
+ enable_xla=enable_xla,
171
+ run_eagerly=run_eagerly,
172
+ gpu_thread_mode=gpu_thread_mode,
173
+ dataset_num_private_threads=dataset_num_private_threads,
174
+ report_metrics=True,
175
+ loss_scale=loss_scale,
176
+ batchnorm_spatial_persistent=True)
177
+ FLAGS.params_override = json.dumps(parameters)
178
+ total_batch_size = num_gpus * per_replica_batch_size
179
+
180
+ start_time_sec = time.time()
181
+ stats = classifier_trainer.run(flags.FLAGS)
182
+ wall_time_sec = time.time() - start_time_sec
183
+
184
+ super(Resnet50KerasAccuracy, self)._report_benchmark(
185
+ stats,
186
+ wall_time_sec,
187
+ top_1_min=top_1_min,
188
+ top_1_max=top_1_max,
189
+ total_batch_size=total_batch_size,
190
+ log_steps=100)
191
+
192
+ def benchmark_8_gpu(self):
193
+ """Tests Keras model with eager, dist_strat and 8 GPUs."""
194
+ self._setup()
195
+ self._run_and_report_benchmark(
196
+ experiment_name='benchmark_8_gpu',
197
+ num_gpus=8,
198
+ per_replica_batch_size=128,
199
+ epochs=90,
200
+ epochs_between_evals=10,
201
+ dtype='float32')
202
+
203
+ def benchmark_8_gpu_fp16(self):
204
+ """Tests Keras model with eager, dist_strat, 8 GPUs, and fp16."""
205
+ self._setup()
206
+ self._run_and_report_benchmark(
207
+ experiment_name='benchmark_8_gpu_fp16',
208
+ num_gpus=8,
209
+ per_replica_batch_size=256,
210
+ epochs=90,
211
+ epochs_between_evals=10,
212
+ dtype='float16')
213
+
214
+ def benchmark_xla_8_gpu_fp16(self):
215
+ """Tests Keras model with XLA, eager, dist_strat, 8 GPUs and fp16."""
216
+ self._setup()
217
+ self._run_and_report_benchmark(
218
+ experiment_name='benchmark_xla_8_gpu_fp16',
219
+ num_gpus=8,
220
+ per_replica_batch_size=256,
221
+ epochs=90,
222
+ epochs_between_evals=10,
223
+ dtype='float16',
224
+ enable_xla=True)
225
+
226
+ def benchmark_xla_8_gpu_fp16_dynamic(self):
227
+ """Tests Keras model with XLA, eager, dist_strat, 8 GPUs, dynamic fp16."""
228
+ self._setup()
229
+ self._run_and_report_benchmark(
230
+ experiment_name='benchmark_xla_8_gpu_fp16_dynamic',
231
+ top_1_min=0.736,
232
+ num_gpus=8,
233
+ per_replica_batch_size=256,
234
+ epochs=90,
235
+ epochs_between_evals=10,
236
+ dtype='float16',
237
+ loss_scale='dynamic')
238
+
239
+ def _get_model_dir(self, folder_name):
240
+ return os.path.join(self.output_dir, folder_name)
241
+
242
+
243
+ class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark):
244
+ """Benchmark accuracy tests for MobilenetV1 in Keras."""
245
+
246
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
247
+ """A benchmark class.
248
+
249
+ Args:
250
+ output_dir: directory where to output e.g. log files
251
+ root_data_dir: directory under which to look for dataset
252
+ **kwargs: arbitrary named arguments. This is needed to make the
253
+ constructor forward compatible in case PerfZero provides more
254
+ named arguments before updating the constructor.
255
+ """
256
+
257
+ flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
258
+
259
+ self.data_dir = os.path.join(root_data_dir, 'imagenet')
260
+ super(MobilenetV1KerasAccuracy, self).__init__(
261
+ output_dir=output_dir,
262
+ flag_methods=flag_methods,
263
+ default_flags={
264
+ 'model': 'mobilenet',
265
+ 'optimizer': 'mobilenet_default',
266
+ 'initial_learning_rate_per_sample': 0.00039,
267
+ })
268
+
269
+ def benchmark_8_gpu(self):
270
+ """Test Keras model with eager, dist_strat and 8 GPUs."""
271
+ self._setup()
272
+ FLAGS.num_gpus = 8
273
+ FLAGS.data_dir = self.data_dir
274
+ FLAGS.batch_size = 128 * 8
275
+ FLAGS.train_epochs = 90
276
+ FLAGS.epochs_between_evals = 10
277
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
278
+ FLAGS.dtype = 'fp32'
279
+ FLAGS.enable_eager = True
280
+ self._run_and_report_benchmark()
281
+
282
+ @benchmark_wrappers.enable_runtime_flags
283
+ def _run_and_report_benchmark(self,
284
+ top_1_min=MOBILENET_V1_MIN_TOP_1_ACCURACY,
285
+ top_1_max=MOBILENET_V1_MAX_TOP_1_ACCURACY):
286
+ start_time_sec = time.time()
287
+ stats = resnet_imagenet_main.run(flags.FLAGS)
288
+ wall_time_sec = time.time() - start_time_sec
289
+
290
+ super(MobilenetV1KerasAccuracy, self)._report_benchmark(
291
+ stats,
292
+ wall_time_sec,
293
+ top_1_min=top_1_min,
294
+ top_1_max=top_1_max,
295
+ total_batch_size=FLAGS.batch_size,
296
+ log_steps=100)
297
+
298
+ def _get_model_dir(self, folder_name):
299
+ return os.path.join(self.output_dir, folder_name)
300
+
301
+
302
+ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
303
+ """Resnet50 (classifier_trainer) benchmarks."""
304
+
305
+ def __init__(self, output_dir=None, default_flags=None,
306
+ tpu=None, dataset_builder='records', train_epochs=1,
307
+ train_steps=110, data_dir=None):
308
+ flag_methods = [classifier_trainer.define_classifier_flags]
309
+
310
+ self.dataset_builder = dataset_builder
311
+ self.train_epochs = train_epochs
312
+ self.train_steps = train_steps
313
+ self.data_dir = data_dir
314
+
315
+ super(Resnet50KerasClassifierBenchmarkBase, self).__init__(
316
+ output_dir=output_dir,
317
+ flag_methods=flag_methods,
318
+ default_flags=default_flags,
319
+ tpu=tpu)
320
+
321
+ @benchmark_wrappers.enable_runtime_flags
322
+ def _run_and_report_benchmark(
323
+ self,
324
+ experiment_name: str,
325
+ skip_steps: Optional[int] = None,
326
+ top_1_min: float = MIN_TOP_1_ACCURACY,
327
+ top_1_max: float = MAX_TOP_1_ACCURACY,
328
+ num_gpus: int = 0,
329
+ num_tpus: int = 0,
330
+ distribution_strategy: str = 'mirrored',
331
+ per_replica_batch_size: int = 128,
332
+ epochs_between_evals: int = 1,
333
+ dtype: str = 'float32',
334
+ enable_xla: bool = False,
335
+ run_eagerly: bool = False,
336
+ gpu_thread_mode: Optional[str] = None,
337
+ dataset_num_private_threads: Optional[int] = None,
338
+ loss_scale: Optional[str] = None):
339
+ """Runs and reports the benchmark given the provided configuration."""
340
+ FLAGS.model_type = 'resnet'
341
+ FLAGS.dataset = 'imagenet'
342
+ FLAGS.mode = 'train_and_eval'
343
+ FLAGS.data_dir = self.data_dir
344
+ FLAGS.model_dir = self._get_model_dir(experiment_name)
345
+ parameters = _get_classifier_parameters(
346
+ builder=self.dataset_builder,
347
+ skip_eval=True,
348
+ num_gpus=num_gpus,
349
+ distribution_strategy=distribution_strategy,
350
+ per_replica_batch_size=per_replica_batch_size,
351
+ epochs=self.train_epochs,
352
+ steps=self.train_steps,
353
+ epochs_between_evals=epochs_between_evals,
354
+ dtype=dtype,
355
+ enable_xla=enable_xla,
356
+ gpu_thread_mode=gpu_thread_mode,
357
+ dataset_num_private_threads=dataset_num_private_threads,
358
+ loss_scale=loss_scale,
359
+ report_metrics=False,
360
+ batchnorm_spatial_persistent=True)
361
+ FLAGS.params_override = json.dumps(parameters)
362
+ if distribution_strategy == 'tpu':
363
+ total_batch_size = num_tpus * per_replica_batch_size
364
+ else:
365
+ total_batch_size = num_gpus * per_replica_batch_size
366
+
367
+ start_time_sec = time.time()
368
+ stats = classifier_trainer.run(flags.FLAGS)
369
+ wall_time_sec = time.time() - start_time_sec
370
+ # Number of logged step time entries that are excluded in performance
371
+ # report. We keep results from last 100 batches, or skip the steps based on
372
+ # input skip_steps.
373
+ warmup = (skip_steps or (self.train_steps - 100)) // FLAGS.log_steps
374
+
375
+ super(Resnet50KerasClassifierBenchmarkBase, self)._report_benchmark(
376
+ stats,
377
+ wall_time_sec,
378
+ total_batch_size=total_batch_size,
379
+ log_steps=FLAGS.log_steps,
380
+ warmup=warmup,
381
+ start_time_sec=start_time_sec)
382
+
383
+ def benchmark_1_gpu_no_dist_strat(self):
384
+ """Tests Keras model with 1 GPU, no distribution strategy."""
385
+ self._setup()
386
+ self._run_and_report_benchmark(
387
+ experiment_name='benchmark_1_gpu_no_dist_strat',
388
+ num_gpus=1,
389
+ distribution_strategy='off',
390
+ per_replica_batch_size=128)
391
+
392
+ def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
393
+ """Tests Keras model with 1 GPU, no distribution strategy, run eagerly."""
394
+ self._setup()
395
+ self._run_and_report_benchmark(
396
+ experiment_name='benchmark_1_gpu_no_dist_strat_run_eagerly',
397
+ num_gpus=1,
398
+ run_eagerly=True,
399
+ distribution_strategy='off',
400
+ per_replica_batch_size=64)
401
+
402
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16(self):
403
+ """Tests with 1 GPU, no distribution strategy, fp16, run eagerly."""
404
+ self._setup()
405
+ self._run_and_report_benchmark(
406
+ experiment_name='benchmark_1_gpu_no_dist_strat_run_eagerly_fp16',
407
+ num_gpus=1,
408
+ run_eagerly=True,
409
+ distribution_strategy='off',
410
+ dtype='float16',
411
+ per_replica_batch_size=128)
412
+
413
+ def benchmark_1_gpu(self):
414
+ """Tests Keras model with 1 GPU."""
415
+ self._setup()
416
+ self._run_and_report_benchmark(
417
+ experiment_name='benchmark_1_gpu',
418
+ num_gpus=1,
419
+ distribution_strategy='one_device',
420
+ per_replica_batch_size=128)
421
+
422
+ def benchmark_xla_1_gpu(self):
423
+ """Tests Keras model with XLA and 1 GPU."""
424
+ self._setup()
425
+ self._run_and_report_benchmark(
426
+ experiment_name='benchmark_xla_1_gpu',
427
+ num_gpus=1,
428
+ enable_xla=True,
429
+ distribution_strategy='one_device',
430
+ per_replica_batch_size=128)
431
+
432
+ def benchmark_1_gpu_fp16(self):
433
+ """Tests Keras model with 1 GPU and fp16."""
434
+ self._setup()
435
+ self._run_and_report_benchmark(
436
+ experiment_name='benchmark_1_gpu_fp16',
437
+ num_gpus=1,
438
+ distribution_strategy='one_device',
439
+ dtype='float16',
440
+ per_replica_batch_size=256)
441
+
442
+ def benchmark_1_gpu_fp16_dynamic(self):
443
+ """Tests Keras model with 1 GPU, fp16, and dynamic loss scaling."""
444
+ self._setup()
445
+ self._run_and_report_benchmark(
446
+ experiment_name='benchmark_1_gpu_fp16_dynamic',
447
+ num_gpus=1,
448
+ distribution_strategy='one_device',
449
+ dtype='float16',
450
+ per_replica_batch_size=256,
451
+ loss_scale='dynamic')
452
+
453
+ def benchmark_xla_1_gpu_fp16(self):
454
+ """Tests Keras model with XLA, 1 GPU and fp16."""
455
+ self._setup()
456
+ self._run_and_report_benchmark(
457
+ experiment_name='benchmark_xla_1_gpu_fp16',
458
+ num_gpus=1,
459
+ enable_xla=True,
460
+ distribution_strategy='one_device',
461
+ dtype='float16',
462
+ per_replica_batch_size=256)
463
+
464
+ def benchmark_xla_1_gpu_fp16_tweaked(self):
465
+ """Tests Keras model with XLA, 1 GPU, fp16, and manual config tuning."""
466
+ self._setup()
467
+ self._run_and_report_benchmark(
468
+ experiment_name='benchmark_xla_1_gpu_fp16_tweaked',
469
+ num_gpus=1,
470
+ enable_xla=True,
471
+ distribution_strategy='one_device',
472
+ dtype='float16',
473
+ per_replica_batch_size=256,
474
+ gpu_thread_mode='gpu_private')
475
+
476
+ def benchmark_xla_1_gpu_fp16_dynamic(self):
477
+ """Tests Keras model with XLA, 1 GPU, fp16, and dynamic loss scaling."""
478
+ self._setup()
479
+ self._run_and_report_benchmark(
480
+ experiment_name='benchmark_xla_1_gpu_fp16_dynamic',
481
+ num_gpus=1,
482
+ enable_xla=True,
483
+ distribution_strategy='one_device',
484
+ dtype='float16',
485
+ per_replica_batch_size=256,
486
+ loss_scale='dynamic')
487
+
488
+ def benchmark_8_gpu(self):
489
+ """Tests Keras model with 8 GPUs."""
490
+ self._setup()
491
+ self._run_and_report_benchmark(
492
+ experiment_name='benchmark_8_gpu',
493
+ num_gpus=8,
494
+ distribution_strategy='mirrored',
495
+ per_replica_batch_size=128)
496
+
497
+ def benchmark_8_gpu_tweaked(self):
498
+ """Tests Keras model with manual config tuning and 8 GPUs."""
499
+ self._setup()
500
+ self._run_and_report_benchmark(
501
+ experiment_name='benchmark_8_gpu_tweaked',
502
+ num_gpus=8,
503
+ distribution_strategy='mirrored',
504
+ per_replica_batch_size=128,
505
+ dataset_num_private_threads=14)
506
+
507
+ def benchmark_xla_8_gpu(self):
508
+ """Tests Keras model with XLA and 8 GPUs."""
509
+ self._setup()
510
+ self._run_and_report_benchmark(
511
+ experiment_name='benchmark_xla_8_gpu',
512
+ num_gpus=8,
513
+ enable_xla=True,
514
+ distribution_strategy='mirrored',
515
+ per_replica_batch_size=128)
516
+
517
+ def benchmark_xla_8_gpu_tweaked(self):
518
+ """Tests Keras model with manual config tuning, 8 GPUs, and XLA."""
519
+ self._setup()
520
+ self._run_and_report_benchmark(
521
+ experiment_name='benchmark_xla_8_gpu_tweaked',
522
+ num_gpus=8,
523
+ enable_xla=True,
524
+ distribution_strategy='mirrored',
525
+ per_replica_batch_size=128,
526
+ gpu_thread_mode='gpu_private',
527
+ dataset_num_private_threads=24)
528
+
529
+ def benchmark_8_gpu_fp16(self):
530
+ """Tests Keras model with 8 GPUs and fp16."""
531
+ self._setup()
532
+ self._run_and_report_benchmark(
533
+ experiment_name='benchmark_8_gpu_fp16',
534
+ num_gpus=8,
535
+ dtype='float16',
536
+ distribution_strategy='mirrored',
537
+ per_replica_batch_size=256)
538
+
539
+ def benchmark_8_gpu_fp16_tweaked(self):
540
+ """Tests Keras model with 8 GPUs, fp16, and manual config tuning."""
541
+ self._setup()
542
+ self._run_and_report_benchmark(
543
+ experiment_name='benchmark_8_gpu_fp16_tweaked',
544
+ num_gpus=8,
545
+ dtype='float16',
546
+ distribution_strategy='mirrored',
547
+ per_replica_batch_size=256,
548
+ gpu_thread_mode='gpu_private',
549
+ dataset_num_private_threads=40)
550
+
551
+ def benchmark_8_gpu_fp16_dynamic_tweaked(self):
552
+ """Tests Keras model with 8 GPUs, fp16, dynamic loss scaling, and tuned."""
553
+ self._setup()
554
+ self._run_and_report_benchmark(
555
+ experiment_name='benchmark_8_gpu_fp16_dynamic_tweaked',
556
+ num_gpus=8,
557
+ dtype='float16',
558
+ distribution_strategy='mirrored',
559
+ per_replica_batch_size=256,
560
+ loss_scale='dynamic',
561
+ gpu_thread_mode='gpu_private',
562
+ dataset_num_private_threads=40)
563
+
564
+ def benchmark_xla_8_gpu_fp16(self):
565
+ """Tests Keras model with XLA, 8 GPUs and fp16."""
566
+ self._setup()
567
+ self._run_and_report_benchmark(
568
+ experiment_name='benchmark_xla_8_gpu_fp16',
569
+ dtype='float16',
570
+ num_gpus=8,
571
+ enable_xla=True,
572
+ distribution_strategy='mirrored',
573
+ per_replica_batch_size=256)
574
+
575
+ def benchmark_xla_8_gpu_fp16_tweaked(self):
576
+ """Test Keras model with manual config tuning, XLA, 8 GPUs and fp16."""
577
+ self._setup()
578
+ self._run_and_report_benchmark(
579
+ experiment_name='benchmark_xla_8_gpu_fp16_tweaked',
580
+ dtype='float16',
581
+ num_gpus=8,
582
+ enable_xla=True,
583
+ distribution_strategy='mirrored',
584
+ per_replica_batch_size=256,
585
+ gpu_thread_mode='gpu_private',
586
+ dataset_num_private_threads=48)
587
+
588
+ def benchmark_xla_8_gpu_fp16_tweaked_delay_measure(self):
589
+ """Tests with manual config tuning, XLA, 8 GPUs and fp16.
590
+
591
+ Delay performance measurement for stable performance on 96 vCPU platforms.
592
+ """
593
+ self._setup()
594
+ self._run_and_report_benchmark(
595
+ experiment_name='benchmark_xla_8_gpu_fp16_tweaked_delay_measure',
596
+ dtype='float16',
597
+ num_gpus=8,
598
+ enable_xla=True,
599
+ distribution_strategy='mirrored',
600
+ per_replica_batch_size=256,
601
+ gpu_thread_mode='gpu_private',
602
+ dataset_num_private_threads=48,
603
+ steps=310)
604
+
605
+ def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
606
+ """Tests Keras model with config tuning, XLA, 8 GPUs and dynamic fp16."""
607
+ self._setup()
608
+ self._run_and_report_benchmark(
609
+ experiment_name='benchmark_xla_8_gpu_fp16_dynamic_tweaked',
610
+ dtype='float16',
611
+ num_gpus=8,
612
+ enable_xla=True,
613
+ distribution_strategy='mirrored',
614
+ per_replica_batch_size=256,
615
+ gpu_thread_mode='gpu_private',
616
+ loss_scale='dynamic',
617
+ dataset_num_private_threads=48)
618
+
619
+ def benchmark_2x2_tpu_bf16(self):
620
+ """Test Keras model with 2x2 TPU, bf16."""
621
+ self._setup()
622
+ self._run_and_report_benchmark(
623
+ experiment_name='benchmark_2x2_tpu_bf16',
624
+ dtype='bfloat16',
625
+ num_tpus=8,
626
+ distribution_strategy='tpu',
627
+ per_replica_batch_size=128)
628
+
629
+ def benchmark_4x4_tpu_bf16(self):
630
+ """Test Keras model with 4x4 TPU, bf16."""
631
+ self._setup()
632
+ self._run_and_report_benchmark(
633
+ experiment_name='benchmark_4x4_tpu_bf16',
634
+ dtype='bfloat16',
635
+ num_tpus=32,
636
+ distribution_strategy='tpu',
637
+ per_replica_batch_size=128)
638
+
639
+ def benchmark_8x8_tpu_bf16(self):
640
+ """Test Keras model with 8x8 TPU, bf16."""
641
+ self._setup()
642
+ self._run_and_report_benchmark(
643
+ experiment_name='benchmark_8x8_tpu_bf16',
644
+ dtype='bfloat16',
645
+ num_tpus=128,
646
+ distribution_strategy='tpu',
647
+ per_replica_batch_size=64)
648
+
649
+ def fill_report_object(self, stats):
650
+ super(Resnet50KerasClassifierBenchmarkBase, self).fill_report_object(
651
+ stats,
652
+ total_batch_size=FLAGS.batch_size,
653
+ log_steps=FLAGS.log_steps)
654
+
655
+
656
+ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
657
+ """Resnet50 benchmarks."""
658
+
659
+ def __init__(self, output_dir=None, default_flags=None, tpu=None):
660
+ flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
661
+
662
+ super(Resnet50KerasBenchmarkBase, self).__init__(
663
+ output_dir=output_dir,
664
+ flag_methods=flag_methods,
665
+ default_flags=default_flags,
666
+ tpu=tpu)
667
+
668
+ @benchmark_wrappers.enable_runtime_flags
669
+ def _run_and_report_benchmark(self, skip_steps=None):
670
+ start_time_sec = time.time()
671
+ stats = resnet_imagenet_main.run(FLAGS)
672
+ wall_time_sec = time.time() - start_time_sec
673
+ # Number of logged step time entries that are excluded in performance
674
+ # report. We keep results from last 100 batches, or skip the steps based on
675
+ # input skip_steps.
676
+ warmup = (skip_steps or (FLAGS.train_steps - 100)) // FLAGS.log_steps
677
+
678
+ super(Resnet50KerasBenchmarkBase, self)._report_benchmark(
679
+ stats,
680
+ wall_time_sec,
681
+ total_batch_size=FLAGS.batch_size,
682
+ log_steps=FLAGS.log_steps,
683
+ warmup=warmup,
684
+ start_time_sec=start_time_sec)
685
+
686
+ def benchmark_1_gpu_no_dist_strat(self):
687
+ """Test Keras model with 1 GPU, no distribution strategy."""
688
+ self._setup()
689
+
690
+ FLAGS.num_gpus = 1
691
+ FLAGS.enable_eager = True
692
+ FLAGS.distribution_strategy = 'off'
693
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
694
+ FLAGS.batch_size = 128
695
+ self._run_and_report_benchmark()
696
+
697
+ def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
698
+ """Test Keras model with 1 GPU, no distribution strategy, run eagerly."""
699
+ self._setup()
700
+
701
+ FLAGS.num_gpus = 1
702
+ FLAGS.enable_eager = True
703
+ FLAGS.run_eagerly = True
704
+ FLAGS.distribution_strategy = 'off'
705
+ FLAGS.model_dir = self._get_model_dir(
706
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly')
707
+ FLAGS.batch_size = 64
708
+ self._run_and_report_benchmark()
709
+
710
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_tweaked(self):
711
+ """Test Keras model with 1 GPU, no distribution strategy, run eagerly."""
712
+ self._setup()
713
+
714
+ FLAGS.num_gpus = 1
715
+ FLAGS.enable_eager = True
716
+ FLAGS.run_eagerly = True
717
+ FLAGS.explicit_gpu_placement = True
718
+ FLAGS.distribution_strategy = 'off'
719
+ FLAGS.model_dir = self._get_model_dir(
720
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly_tweaked')
721
+ FLAGS.batch_size = 64
722
+ self._run_and_report_benchmark()
723
+
724
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16(self):
725
+ """Test with 1 GPU, no distribution strategy, fp16, run eagerly."""
726
+ self._setup()
727
+
728
+ FLAGS.num_gpus = 1
729
+ FLAGS.enable_eager = True
730
+ FLAGS.run_eagerly = True
731
+ FLAGS.distribution_strategy = 'off'
732
+ FLAGS.model_dir = self._get_model_dir(
733
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly_fp16')
734
+ FLAGS.dtype = 'fp16'
735
+ FLAGS.batch_size = 128
736
+ self._run_and_report_benchmark()
737
+
738
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16_tweaked(self):
739
+ """Test with 1 GPU, no distribution strategy, fp16, run eagerly."""
740
+ self._setup()
741
+
742
+ FLAGS.num_gpus = 1
743
+ FLAGS.enable_eager = True
744
+ FLAGS.run_eagerly = True
745
+ FLAGS.explicit_gpu_placement = True
746
+ FLAGS.distribution_strategy = 'off'
747
+ FLAGS.model_dir = self._get_model_dir(
748
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly_fp16_tweaked')
749
+ FLAGS.dtype = 'fp16'
750
+ FLAGS.batch_size = 128
751
+ self._run_and_report_benchmark()
752
+
753
+ def benchmark_1_gpu(self):
754
+ """Test Keras model with 1 GPU."""
755
+ self._setup()
756
+
757
+ FLAGS.num_gpus = 1
758
+ FLAGS.enable_eager = True
759
+ FLAGS.distribution_strategy = 'one_device'
760
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
761
+ FLAGS.batch_size = 128
762
+ self._run_and_report_benchmark()
763
+
764
+ def benchmark_1_gpu_amp(self):
765
+ """Test Keras model with 1 GPU with automatic mixed precision."""
766
+ self._setup()
767
+
768
+ FLAGS.num_gpus = 1
769
+ FLAGS.enable_eager = True
770
+ FLAGS.dtype = 'fp16'
771
+ FLAGS.fp16_implementation = 'graph_rewrite'
772
+ FLAGS.distribution_strategy = 'one_device'
773
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
774
+ FLAGS.batch_size = 256
775
+ self._run_and_report_benchmark()
776
+
777
+ def benchmark_xla_1_gpu(self):
778
+ """Test Keras model with XLA and 1 GPU."""
779
+ self._setup()
780
+
781
+ FLAGS.num_gpus = 1
782
+ FLAGS.enable_eager = True
783
+ FLAGS.enable_xla = True
784
+ FLAGS.distribution_strategy = 'one_device'
785
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu')
786
+ FLAGS.batch_size = 128
787
+ self._run_and_report_benchmark()
788
+
789
+ def benchmark_xla_1_gpu_amp(self):
790
+ """Test Keras model with XLA and 1 GPU with automatic mixed precision."""
791
+ self._setup()
792
+
793
+ FLAGS.num_gpus = 1
794
+ FLAGS.enable_eager = True
795
+ FLAGS.dtype = 'fp16'
796
+ FLAGS.fp16_implementation = 'graph_rewrite'
797
+ FLAGS.enable_xla = True
798
+ FLAGS.distribution_strategy = 'one_device'
799
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
800
+ FLAGS.batch_size = 256
801
+ self._run_and_report_benchmark()
802
+
803
+ def benchmark_1_gpu_fp16(self):
804
+ """Test Keras model with 1 GPU and fp16."""
805
+ self._setup()
806
+
807
+ FLAGS.num_gpus = 1
808
+ FLAGS.enable_eager = True
809
+ FLAGS.distribution_strategy = 'one_device'
810
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16')
811
+ FLAGS.dtype = 'fp16'
812
+ FLAGS.batch_size = 256
813
+ self._run_and_report_benchmark()
814
+
815
+ def benchmark_1_gpu_fp16_dynamic(self):
816
+ """Test Keras model with 1 GPU, fp16, and dynamic loss scaling."""
817
+ self._setup()
818
+
819
+ FLAGS.num_gpus = 1
820
+ FLAGS.enable_eager = True
821
+ FLAGS.distribution_strategy = 'one_device'
822
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_dynamic')
823
+ FLAGS.dtype = 'fp16'
824
+ FLAGS.batch_size = 256
825
+ FLAGS.loss_scale = 'dynamic'
826
+ self._run_and_report_benchmark()
827
+
828
+ def benchmark_xla_1_gpu_fp16(self):
829
+ """Test Keras model with XLA, 1 GPU and fp16."""
830
+ self._setup()
831
+
832
+ FLAGS.num_gpus = 1
833
+ FLAGS.enable_eager = True
834
+ FLAGS.enable_xla = True
835
+ FLAGS.distribution_strategy = 'one_device'
836
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16')
837
+ FLAGS.dtype = 'fp16'
838
+ FLAGS.batch_size = 256
839
+ self._run_and_report_benchmark()
840
+
841
+ def benchmark_xla_1_gpu_fp16_tweaked(self):
842
+ """Test Keras model with XLA, 1 GPU, fp16, and manual config tuning."""
843
+ self._setup()
844
+
845
+ FLAGS.num_gpus = 1
846
+ FLAGS.enable_eager = True
847
+ FLAGS.enable_xla = True
848
+ FLAGS.distribution_strategy = 'one_device'
849
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_tweaked')
850
+ FLAGS.dtype = 'fp16'
851
+ FLAGS.batch_size = 256
852
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
853
+ self._run_and_report_benchmark()
854
+
855
+ def benchmark_xla_1_gpu_fp16_dynamic(self):
856
+ """Test Keras model with XLA, 1 GPU, fp16, and dynamic loss scaling."""
857
+ self._setup()
858
+
859
+ FLAGS.num_gpus = 1
860
+ FLAGS.enable_eager = True
861
+ FLAGS.enable_xla = True
862
+ FLAGS.distribution_strategy = 'one_device'
863
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_dynamic')
864
+ FLAGS.dtype = 'fp16'
865
+ FLAGS.batch_size = 256
866
+ FLAGS.loss_scale = 'dynamic'
867
+ self._run_and_report_benchmark()
868
+
869
+ def benchmark_8_gpu(self):
870
+ """Test Keras model with 8 GPUs."""
871
+ self._setup()
872
+
873
+ FLAGS.num_gpus = 8
874
+ FLAGS.enable_eager = True
875
+ FLAGS.distribution_strategy = 'mirrored'
876
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
877
+ FLAGS.batch_size = 128 * 8 # 8 GPUs
878
+ self._run_and_report_benchmark()
879
+
880
+ def benchmark_8_gpu_amp(self):
881
+ """Test Keras model with 8 GPUs with automatic mixed precision."""
882
+ self._setup()
883
+
884
+ FLAGS.num_gpus = 8
885
+ FLAGS.enable_eager = True
886
+ FLAGS.dtype = 'fp16'
887
+ FLAGS.fp16_implementation = 'graph_rewrite'
888
+ FLAGS.distribution_strategy = 'mirrored'
889
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
890
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
891
+ self._run_and_report_benchmark()
892
+
893
+ def benchmark_8_gpu_tweaked(self):
894
+ """Test Keras model with manual config tuning and 8 GPUs."""
895
+ self._setup()
896
+
897
+ FLAGS.num_gpus = 8
898
+ FLAGS.enable_eager = True
899
+ FLAGS.distribution_strategy = 'mirrored'
900
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_tweaked')
901
+ FLAGS.batch_size = 128 * 8 # 8 GPUs
902
+ FLAGS.datasets_num_private_threads = 14
903
+ self._run_and_report_benchmark()
904
+
905
+ def benchmark_xla_8_gpu(self):
906
+ """Test Keras model with XLA and 8 GPUs."""
907
+ self._setup()
908
+
909
+ FLAGS.num_gpus = 8
910
+ FLAGS.enable_eager = True
911
+ FLAGS.enable_xla = True
912
+ FLAGS.distribution_strategy = 'mirrored'
913
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu')
914
+ FLAGS.batch_size = 128 * 8 # 8 GPUs
915
+ self._run_and_report_benchmark()
916
+
917
+ def benchmark_xla_8_gpu_amp(self):
918
+ """Test Keras model with XLA and 8 GPUs with automatic mixed precision."""
919
+ self._setup()
920
+
921
+ FLAGS.num_gpus = 8
922
+ FLAGS.enable_eager = True
923
+ FLAGS.dtype = 'fp16'
924
+ FLAGS.fp16_implementation = 'graph_rewrite'
925
+ FLAGS.enable_xla = True
926
+ FLAGS.distribution_strategy = 'mirrored'
927
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
928
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
929
+ self._run_and_report_benchmark()
930
+
931
+ def benchmark_xla_8_gpu_tweaked(self):
932
+ """Test Keras model with manual config tuning, 8 GPUs, and XLA."""
933
+ self._setup()
934
+
935
+ FLAGS.num_gpus = 8
936
+ FLAGS.enable_eager = True
937
+ FLAGS.enable_xla = True
938
+ FLAGS.distribution_strategy = 'mirrored'
939
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_tweaked')
940
+ FLAGS.batch_size = 128 * 8
941
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
942
+ FLAGS.datasets_num_private_threads = 24
943
+ self._run_and_report_benchmark()
944
+
945
+ def benchmark_8_gpu_fp16(self):
946
+ """Test Keras model with 8 GPUs and fp16."""
947
+ self._setup()
948
+
949
+ FLAGS.num_gpus = 8
950
+ FLAGS.dtype = 'fp16'
951
+ FLAGS.enable_eager = True
952
+ FLAGS.distribution_strategy = 'mirrored'
953
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
954
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
955
+ self._run_and_report_benchmark()
956
+
957
+ def benchmark_8_gpu_fp16_tweaked(self):
958
+ """Test Keras model with 8 GPUs, fp16, and manual config tuning."""
959
+ self._setup()
960
+
961
+ FLAGS.num_gpus = 8
962
+ FLAGS.dtype = 'fp16'
963
+ FLAGS.enable_eager = True
964
+ FLAGS.distribution_strategy = 'mirrored'
965
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked')
966
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
967
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
968
+ FLAGS.dataset_num_private_threads = 40
969
+ self._run_and_report_benchmark()
970
+
971
+ def benchmark_8_gpu_fp16_dynamic_tweaked(self):
972
+ """Test Keras model with 8 GPUs, fp16, dynamic loss scaling, and tuned."""
973
+ self._setup()
974
+
975
+ FLAGS.num_gpus = 8
976
+ FLAGS.dtype = 'fp16'
977
+ FLAGS.enable_eager = True
978
+ FLAGS.distribution_strategy = 'mirrored'
979
+ FLAGS.model_dir = self._get_model_dir(
980
+ 'benchmark_8_gpu_fp16_dynamic_tweaked')
981
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
982
+ FLAGS.loss_scale = 'dynamic'
983
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
984
+ FLAGS.dataset_num_private_threads = 40
985
+ self._run_and_report_benchmark()
986
+
987
+ def benchmark_xla_8_gpu_fp16(self):
988
+ """Test Keras model with XLA, 8 GPUs and fp16."""
989
+ self._setup()
990
+
991
+ FLAGS.num_gpus = 8
992
+ FLAGS.dtype = 'fp16'
993
+ FLAGS.enable_eager = True
994
+ FLAGS.enable_xla = True
995
+ FLAGS.distribution_strategy = 'mirrored'
996
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16')
997
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
998
+ self._run_and_report_benchmark()
999
+
1000
+ def benchmark_xla_8_gpu_fp16_tweaked(self):
1001
+ """Test Keras model with manual config tuning, XLA, 8 GPUs and fp16."""
1002
+ self._setup()
1003
+
1004
+ FLAGS.num_gpus = 8
1005
+ FLAGS.dtype = 'fp16'
1006
+ FLAGS.enable_eager = True
1007
+ FLAGS.enable_xla = True
1008
+ FLAGS.distribution_strategy = 'mirrored'
1009
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16_tweaked')
1010
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
1011
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
1012
+ FLAGS.datasets_num_private_threads = 48
1013
+ self._run_and_report_benchmark()
1014
+
1015
+ def benchmark_xla_8_gpu_fp16_tweaked_delay_measure(self):
1016
+ """Test with manual config tuning, XLA, 8 GPUs and fp16.
1017
+
1018
+ Delay performance measurement for stable performance on 96 vCPU platforms.
1019
+ """
1020
+ self._setup()
1021
+
1022
+ FLAGS.num_gpus = 8
1023
+ FLAGS.dtype = 'fp16'
1024
+ FLAGS.enable_eager = True
1025
+ FLAGS.enable_xla = True
1026
+ FLAGS.distribution_strategy = 'mirrored'
1027
+ FLAGS.model_dir = self._get_model_dir(
1028
+ 'benchmark_xla_8_gpu_fp16_tweaked_delay_measure')
1029
+ FLAGS.batch_size = 256 * 8
1030
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
1031
+ FLAGS.datasets_num_private_threads = 48
1032
+ FLAGS.train_steps = 310
1033
+ self._run_and_report_benchmark()
1034
+
1035
+ def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
1036
+ """Test Keras model with config tuning, XLA, 8 GPUs and dynamic fp16."""
1037
+ self._setup()
1038
+
1039
+ FLAGS.num_gpus = 8
1040
+ FLAGS.dtype = 'fp16'
1041
+ FLAGS.enable_eager = True
1042
+ FLAGS.enable_xla = True
1043
+ FLAGS.distribution_strategy = 'mirrored'
1044
+ FLAGS.model_dir = self._get_model_dir(
1045
+ 'benchmark_xla_8_gpu_fp16_dynamic_tweaked')
1046
+ FLAGS.batch_size = 256 * 8 # 8 GPUs
1047
+ FLAGS.loss_scale = 'dynamic'
1048
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
1049
+ FLAGS.datasets_num_private_threads = 48
1050
+ self._run_and_report_benchmark()
1051
+
1052
+ def benchmark_2x2_tpu_bf16(self):
1053
+ """Test Keras model with 2x2 TPU, bf16."""
1054
+ self._setup()
1055
+
1056
+ FLAGS.dtype = 'bf16'
1057
+ FLAGS.distribution_strategy = 'tpu'
1058
+ FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16')
1059
+ FLAGS.batch_size = 1024
1060
+ self._run_and_report_benchmark()
1061
+
1062
+ def benchmark_4x4_tpu_bf16(self):
1063
+ """Test Keras model with 4x4 TPU, bf16."""
1064
+ self._setup()
1065
+
1066
+ FLAGS.dtype = 'bf16'
1067
+ FLAGS.distribution_strategy = 'tpu'
1068
+ FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16')
1069
+ FLAGS.batch_size = 4096
1070
+ self._run_and_report_benchmark()
1071
+
1072
+ def benchmark_8x8_tpu_bf16(self):
1073
+ """Test Keras model with 8x8 TPU, bf16."""
1074
+ self._setup()
1075
+
1076
+ FLAGS.dtype = 'bf16'
1077
+ FLAGS.distribution_strategy = 'tpu'
1078
+ FLAGS.model_dir = self._get_model_dir('benchmark_8x8_tpu_bf16')
1079
+ FLAGS.batch_size = 8192
1080
+ self._run_and_report_benchmark()
1081
+
1082
+ def fill_report_object(self, stats):
1083
+ super(Resnet50KerasBenchmarkBase, self).fill_report_object(
1084
+ stats,
1085
+ total_batch_size=FLAGS.batch_size,
1086
+ log_steps=FLAGS.log_steps)
1087
+
1088
+
1089
+ class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase):
1090
+ """Resnet50 synthetic benchmark tests."""
1091
+
1092
+ def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
1093
+ def_flags = {}
1094
+ def_flags['log_steps'] = 10
1095
+
1096
+ super(Resnet50KerasBenchmarkSynth, self).__init__(
1097
+ output_dir=output_dir, default_flags=def_flags, tpu=tpu,
1098
+ dataset_builder='synthetic', train_epochs=1, train_steps=110)
1099
+
1100
+
1101
+ class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase):
1102
+ """Resnet50 real data benchmark tests."""
1103
+
1104
+ def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
1105
+ data_dir = os.path.join(root_data_dir, 'imagenet')
1106
+ def_flags = {}
1107
+ def_flags['log_steps'] = 10
1108
+
1109
+ super(Resnet50KerasBenchmarkReal, self).__init__(
1110
+ output_dir=output_dir, default_flags=def_flags, tpu=tpu,
1111
+ dataset_builder='records', train_epochs=1, train_steps=110,
1112
+ data_dir=data_dir)
1113
+
1114
+
1115
+ class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase):
1116
+ """Resnet50 real data (stored in remote storage) benchmark tests."""
1117
+
1118
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
1119
+ def_flags = {}
1120
+ def_flags['skip_eval'] = True
1121
+ def_flags['report_accuracy_metrics'] = False
1122
+ def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
1123
+ # Defining multiple epochs overrides the train_steps setting in benchmarks.
1124
+ def_flags['train_epochs'] = 2
1125
+ # Cache dataset so performance is stable after the first epoch.
1126
+ def_flags['training_dataset_cache'] = True
1127
+ def_flags['log_steps'] = 100
1128
+ # Note that for single GPU and pure eager tests which are less likely to be
1129
+ # input bound and more stable, these tests will run for shorter time by
1130
+ # overriding FLAGS.train_epochs, train_seteps, log_steps in benchmark
1131
+ # methods, and skip_steps in _run_and_report_benchmark().
1132
+
1133
+ super(Resnet50KerasBenchmarkRemoteData, self).__init__(
1134
+ output_dir=output_dir, default_flags=def_flags)
1135
+
1136
+ def _override_flags_to_run_test_shorter(self):
1137
+ FLAGS.train_epochs = 1
1138
+ FLAGS.train_steps = 300
1139
+ FLAGS.log_steps = 10
1140
+
1141
+ def benchmark_1_gpu_no_dist_strat(self):
1142
+ """Test Keras model with 1 GPU, no distribution strategy."""
1143
+ self._setup()
1144
+
1145
+ FLAGS.num_gpus = 1
1146
+ FLAGS.enable_eager = True
1147
+ FLAGS.distribution_strategy = 'off'
1148
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat')
1149
+ FLAGS.batch_size = 128
1150
+ self._override_flags_to_run_test_shorter()
1151
+ self._run_and_report_benchmark()
1152
+
1153
+ def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
1154
+ """Test Keras model with 1 GPU, no distribution strategy, run eagerly."""
1155
+ self._setup()
1156
+
1157
+ FLAGS.num_gpus = 1
1158
+ FLAGS.enable_eager = True
1159
+ FLAGS.run_eagerly = True
1160
+ FLAGS.distribution_strategy = 'off'
1161
+ FLAGS.model_dir = self._get_model_dir(
1162
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly')
1163
+ FLAGS.batch_size = 64
1164
+ self._override_flags_to_run_test_shorter()
1165
+ self._run_and_report_benchmark()
1166
+
1167
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_tweaked(self):
1168
+ """Test Keras model with 1 GPU, no distribution strategy, run eagerly."""
1169
+ self._setup()
1170
+
1171
+ FLAGS.num_gpus = 1
1172
+ FLAGS.enable_eager = True
1173
+ FLAGS.run_eagerly = True
1174
+ FLAGS.explicit_gpu_placement = True
1175
+ FLAGS.distribution_strategy = 'off'
1176
+ FLAGS.model_dir = self._get_model_dir(
1177
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly_tweaked')
1178
+ FLAGS.batch_size = 64
1179
+ self._override_flags_to_run_test_shorter()
1180
+ self._run_and_report_benchmark()
1181
+
1182
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16(self):
1183
+ """Test with 1 GPU, no distribution strategy, fp16, run eagerly."""
1184
+ self._setup()
1185
+
1186
+ FLAGS.num_gpus = 1
1187
+ FLAGS.enable_eager = True
1188
+ FLAGS.run_eagerly = True
1189
+ FLAGS.distribution_strategy = 'off'
1190
+ FLAGS.model_dir = self._get_model_dir(
1191
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly_fp16')
1192
+ FLAGS.dtype = 'fp16'
1193
+ FLAGS.batch_size = 128
1194
+ self._override_flags_to_run_test_shorter()
1195
+ self._run_and_report_benchmark()
1196
+
1197
+ def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16_tweaked(self):
1198
+ """Test with 1 GPU, no distribution strategy, fp16, run eagerly."""
1199
+ self._setup()
1200
+
1201
+ FLAGS.num_gpus = 1
1202
+ FLAGS.enable_eager = True
1203
+ FLAGS.run_eagerly = True
1204
+ FLAGS.explicit_gpu_placement = True
1205
+ FLAGS.distribution_strategy = 'off'
1206
+ FLAGS.model_dir = self._get_model_dir(
1207
+ 'benchmark_1_gpu_no_dist_strat_run_eagerly_fp16_tweaked')
1208
+ FLAGS.dtype = 'fp16'
1209
+ FLAGS.batch_size = 128
1210
+ self._override_flags_to_run_test_shorter()
1211
+ self._run_and_report_benchmark()
1212
+
1213
+ def benchmark_1_gpu(self):
1214
+ """Test Keras model with 1 GPU."""
1215
+ self._setup()
1216
+
1217
+ FLAGS.num_gpus = 1
1218
+ FLAGS.enable_eager = True
1219
+ FLAGS.distribution_strategy = 'one_device'
1220
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
1221
+ FLAGS.batch_size = 128
1222
+ self._override_flags_to_run_test_shorter()
1223
+ self._run_and_report_benchmark()
1224
+
1225
+ def benchmark_1_gpu_amp(self):
1226
+ """Test Keras model with 1 GPU with automatic mixed precision."""
1227
+ self._setup()
1228
+
1229
+ FLAGS.num_gpus = 1
1230
+ FLAGS.enable_eager = True
1231
+ FLAGS.dtype = 'fp16'
1232
+ FLAGS.fp16_implementation = 'graph_rewrite'
1233
+ FLAGS.distribution_strategy = 'one_device'
1234
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_amp')
1235
+ FLAGS.batch_size = 256
1236
+ self._override_flags_to_run_test_shorter()
1237
+ self._run_and_report_benchmark()
1238
+
1239
+ def benchmark_xla_1_gpu(self):
1240
+ """Test Keras model with XLA and 1 GPU."""
1241
+ self._setup()
1242
+
1243
+ FLAGS.num_gpus = 1
1244
+ FLAGS.enable_eager = True
1245
+ FLAGS.enable_xla = True
1246
+ FLAGS.distribution_strategy = 'one_device'
1247
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu')
1248
+ FLAGS.batch_size = 128
1249
+ self._override_flags_to_run_test_shorter()
1250
+ self._run_and_report_benchmark()
1251
+
1252
+ def benchmark_xla_1_gpu_amp(self):
1253
+ """Test Keras model with XLA and 1 GPU with automatic mixed precision."""
1254
+ self._setup()
1255
+
1256
+ FLAGS.num_gpus = 1
1257
+ FLAGS.enable_eager = True
1258
+ FLAGS.dtype = 'fp16'
1259
+ FLAGS.fp16_implementation = 'graph_rewrite'
1260
+ FLAGS.enable_xla = True
1261
+ FLAGS.distribution_strategy = 'one_device'
1262
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_amp')
1263
+ FLAGS.batch_size = 256
1264
+ self._override_flags_to_run_test_shorter()
1265
+ self._run_and_report_benchmark()
1266
+
1267
+ def benchmark_1_gpu_fp16(self):
1268
+ """Test Keras model with 1 GPU and fp16."""
1269
+ self._setup()
1270
+
1271
+ FLAGS.num_gpus = 1
1272
+ FLAGS.enable_eager = True
1273
+ FLAGS.distribution_strategy = 'one_device'
1274
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16')
1275
+ FLAGS.dtype = 'fp16'
1276
+ FLAGS.batch_size = 256
1277
+ self._override_flags_to_run_test_shorter()
1278
+ self._run_and_report_benchmark()
1279
+
1280
+ def benchmark_1_gpu_fp16_dynamic(self):
1281
+ """Test Keras model with 1 GPU, fp16, and dynamic loss scaling."""
1282
+ self._setup()
1283
+
1284
+ FLAGS.num_gpus = 1
1285
+ FLAGS.enable_eager = True
1286
+ FLAGS.distribution_strategy = 'one_device'
1287
+ FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_fp16_dynamic')
1288
+ FLAGS.dtype = 'fp16'
1289
+ FLAGS.batch_size = 256
1290
+ FLAGS.loss_scale = 'dynamic'
1291
+ self._override_flags_to_run_test_shorter()
1292
+ self._run_and_report_benchmark()
1293
+
1294
+ def benchmark_xla_1_gpu_fp16(self):
1295
+ """Test Keras model with XLA, 1 GPU and fp16."""
1296
+ self._setup()
1297
+
1298
+ FLAGS.num_gpus = 1
1299
+ FLAGS.enable_eager = True
1300
+ FLAGS.enable_xla = True
1301
+ FLAGS.distribution_strategy = 'one_device'
1302
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16')
1303
+ FLAGS.dtype = 'fp16'
1304
+ FLAGS.batch_size = 256
1305
+ self._override_flags_to_run_test_shorter()
1306
+ self._run_and_report_benchmark()
1307
+
1308
+ def benchmark_xla_1_gpu_fp16_tweaked(self):
1309
+ """Test Keras model with XLA, 1 GPU, fp16, and manual config tuning."""
1310
+ self._setup()
1311
+
1312
+ FLAGS.num_gpus = 1
1313
+ FLAGS.enable_eager = True
1314
+ FLAGS.enable_xla = True
1315
+ FLAGS.distribution_strategy = 'one_device'
1316
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_tweaked')
1317
+ FLAGS.dtype = 'fp16'
1318
+ FLAGS.batch_size = 256
1319
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
1320
+ self._override_flags_to_run_test_shorter()
1321
+ self._run_and_report_benchmark()
1322
+
1323
+ def benchmark_xla_1_gpu_fp16_dynamic(self):
1324
+ """Test Keras model with XLA, 1 GPU, fp16, and dynamic loss scaling."""
1325
+ self._setup()
1326
+
1327
+ FLAGS.num_gpus = 1
1328
+ FLAGS.enable_eager = True
1329
+ FLAGS.enable_xla = True
1330
+ FLAGS.distribution_strategy = 'one_device'
1331
+ FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_dynamic')
1332
+ FLAGS.dtype = 'fp16'
1333
+ FLAGS.batch_size = 256
1334
+ FLAGS.loss_scale = 'dynamic'
1335
+ self._override_flags_to_run_test_shorter()
1336
+ self._run_and_report_benchmark()
1337
+
1338
+ @benchmark_wrappers.enable_runtime_flags
1339
+ def _run_and_report_benchmark(self):
1340
+ if FLAGS.num_gpus == 1 or FLAGS.run_eagerly:
1341
+ # For single GPU and pure eager tests which are less likely to be input
1342
+ # bound and more stable, run for shorter time and use the default
1343
+ # skip_steps.
1344
+ skip_steps = None
1345
+ else:
1346
+ # skip the first epoch for performance measurement.
1347
+ skip_steps = 600
1348
+ super(Resnet50KerasBenchmarkRemoteData,
1349
+ self)._run_and_report_benchmark(skip_steps=skip_steps)
1350
+
1351
+
1352
+ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
1353
+ """Trivial model with real data benchmark tests."""
1354
+
1355
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
1356
+ flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
1357
+
1358
+ def_flags = {}
1359
+ def_flags['use_trivial_model'] = True
1360
+ def_flags['skip_eval'] = True
1361
+ def_flags['report_accuracy_metrics'] = False
1362
+ def_flags['dtype'] = 'fp16'
1363
+ def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
1364
+ def_flags['train_steps'] = 600
1365
+ def_flags['log_steps'] = 100
1366
+ def_flags['distribution_strategy'] = 'mirrored'
1367
+
1368
+ super(TrivialKerasBenchmarkReal, self).__init__(
1369
+ output_dir=output_dir,
1370
+ flag_methods=flag_methods,
1371
+ default_flags=def_flags)
1372
+
1373
+ @benchmark_wrappers.enable_runtime_flags
1374
+ def _run_and_report_benchmark(self):
1375
+ start_time_sec = time.time()
1376
+ stats = resnet_imagenet_main.run(FLAGS)
1377
+ wall_time_sec = time.time() - start_time_sec
1378
+
1379
+ super(TrivialKerasBenchmarkReal, self)._report_benchmark(
1380
+ stats,
1381
+ wall_time_sec,
1382
+ total_batch_size=FLAGS.batch_size,
1383
+ log_steps=FLAGS.log_steps)
1384
+
1385
+ def benchmark_8_gpu_warmup(self):
1386
+ """Dummy test that runs over an epoch to warmup the machine."""
1387
+ self._setup()
1388
+
1389
+ FLAGS.num_gpus = 8
1390
+ FLAGS.enable_eager = True
1391
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_warmup')
1392
+ FLAGS.batch_size = 256 * 8
1393
+ FLAGS.train_steps = 700
1394
+ self._run_and_report_benchmark()
1395
+
1396
+ def fill_report_object(self, stats):
1397
+ super(TrivialKerasBenchmarkReal, self).fill_report_object(
1398
+ stats,
1399
+ total_batch_size=FLAGS.batch_size,
1400
+ log_steps=FLAGS.log_steps)
1401
+
1402
+
1403
+ class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark):
1404
+ """Resnet50 distributed accuracy tests with multiple workers."""
1405
+
1406
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
1407
+ flag_methods = [classifier_trainer.define_imagenet_keras_flags]
1408
+ self.data_dir = os.path.join(root_data_dir, 'imagenet')
1409
+ super(Resnet50MultiWorkerKerasAccuracy, self).__init__(
1410
+ output_dir=output_dir, flag_methods=flag_methods)
1411
+
1412
+ def _benchmark_common(self, eager, num_workers, all_reduce_alg):
1413
+ """Common to all benchmarks in this class."""
1414
+ self._setup()
1415
+
1416
+ num_gpus = 8
1417
+ FLAGS.num_gpus = num_gpus
1418
+ FLAGS.data_dir = self.data_dir
1419
+ FLAGS.train_epochs = 90
1420
+ FLAGS.epochs_between_evals = 10
1421
+ FLAGS.dtype = 'fp16'
1422
+ FLAGS.enable_eager = eager
1423
+ FLAGS.enable_xla = False
1424
+ FLAGS.distribution_strategy = 'multi_worker_mirrored'
1425
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
1426
+ FLAGS.datasets_num_private_threads = 32
1427
+ FLAGS.model_dir = self._get_model_dir(
1428
+ 'benchmark_{}_8_gpu_{}_worker_fp16_{}_tweaked'.format(
1429
+ 'eager' if eager else 'graph', num_workers, all_reduce_alg))
1430
+ FLAGS.batch_size = 256 * num_gpus * num_workers
1431
+ FLAGS.all_reduce_alg = all_reduce_alg
1432
+
1433
+ self._run_and_report_benchmark()
1434
+
1435
+ @benchmark_wrappers.enable_runtime_flags
1436
+ def _run_and_report_benchmark(self,
1437
+ top_1_min=MIN_TOP_1_ACCURACY,
1438
+ top_1_max=MAX_TOP_1_ACCURACY):
1439
+ start_time_sec = time.time()
1440
+ stats = classifier_trainer.run(flags.FLAGS)
1441
+ wall_time_sec = time.time() - start_time_sec
1442
+
1443
+ super(Resnet50MultiWorkerKerasAccuracy, self)._report_benchmark(
1444
+ stats,
1445
+ wall_time_sec,
1446
+ top_1_min=top_1_min,
1447
+ top_1_max=top_1_max,
1448
+ total_batch_size=FLAGS.batch_size,
1449
+ log_steps=100)
1450
+
1451
+ def _get_model_dir(self, folder_name):
1452
+ return os.path.join(self.output_dir, folder_name)
1453
+
1454
+ def benchmark_eager_8_gpu_2_workers_fp16_ring_tweaked(self):
1455
+ """Eager, 8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
1456
+ self._benchmark_common(eager=True, num_workers=2, all_reduce_alg='ring')
1457
+
1458
+ def benchmark_eager_8_gpu_2_workers_fp16_nccl_tweaked(self):
1459
+ """Eager, 8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
1460
+ self._benchmark_common(eager=True, num_workers=2, all_reduce_alg='nccl')
1461
+
1462
+ def benchmark_eager_8_gpu_8_workers_fp16_ring_tweaked(self):
1463
+ """Eager, 8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
1464
+ self._benchmark_common(eager=True, num_workers=8, all_reduce_alg='ring')
1465
+
1466
+ def benchmark_eager_8_gpu_8_workers_fp16_nccl_tweaked(self):
1467
+ """Eager, 8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
1468
+ self._benchmark_common(eager=True, num_workers=8, all_reduce_alg='nccl')
1469
+
1470
+
1471
+ class Resnet50MultiWorkerKerasBenchmark(Resnet50KerasBenchmarkBase):
1472
+ """Resnet50 distributed benchmark tests with multiple workers."""
1473
+
1474
+ def __init__(self, output_dir=None, default_flags=None):
1475
+ super(Resnet50MultiWorkerKerasBenchmark, self).__init__(
1476
+ output_dir=output_dir, default_flags=default_flags)
1477
+
1478
+ def _benchmark_common(self, eager, num_workers, all_reduce_alg):
1479
+ """Common to all benchmarks in this class."""
1480
+ self._setup()
1481
+
1482
+ num_gpus = 8
1483
+ FLAGS.num_gpus = num_gpus
1484
+ FLAGS.dtype = 'fp16'
1485
+ FLAGS.enable_eager = eager
1486
+ FLAGS.enable_xla = False
1487
+ FLAGS.distribution_strategy = 'multi_worker_mirrored'
1488
+ FLAGS.tf_gpu_thread_mode = 'gpu_private'
1489
+ FLAGS.datasets_num_private_threads = 32
1490
+ FLAGS.model_dir = self._get_model_dir(
1491
+ 'benchmark_{}_8_gpu_{}_worker_fp16_{}_tweaked'.format(
1492
+ 'eager' if eager else 'graph', num_workers, all_reduce_alg))
1493
+ FLAGS.batch_size = 256 * num_gpus * num_workers
1494
+ FLAGS.all_reduce_alg = all_reduce_alg
1495
+
1496
+ self._run_and_report_benchmark()
1497
+
1498
+ def benchmark_eager_8_gpu_1_worker_fp16_ring_tweaked(self):
1499
+ """Eager, 8 GPUs per worker, 1 worker, fp16, ring all-reduce."""
1500
+ self._benchmark_common(eager=True, num_workers=1, all_reduce_alg='ring')
1501
+
1502
+ def benchmark_eager_8_gpu_1_worker_fp16_nccl_tweaked(self):
1503
+ """Eager, 8 GPUs per worker, 1 worker, fp16, nccl all-reduce."""
1504
+ self._benchmark_common(eager=True, num_workers=1, all_reduce_alg='nccl')
1505
+
1506
+ def benchmark_eager_8_gpu_2_workers_fp16_ring_tweaked(self):
1507
+ """Eager, 8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
1508
+ self._benchmark_common(eager=True, num_workers=2, all_reduce_alg='ring')
1509
+
1510
+ def benchmark_eager_8_gpu_2_workers_fp16_nccl_tweaked(self):
1511
+ """Eager, 8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
1512
+ self._benchmark_common(eager=True, num_workers=2, all_reduce_alg='nccl')
1513
+
1514
+ def benchmark_eager_8_gpu_8_workers_fp16_ring_tweaked(self):
1515
+ """Eager, 8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
1516
+ self._benchmark_common(eager=True, num_workers=8, all_reduce_alg='ring')
1517
+
1518
+ def benchmark_eager_8_gpu_8_workers_fp16_nccl_tweaked(self):
1519
+ """Eager, 8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
1520
+ self._benchmark_common(eager=True, num_workers=8, all_reduce_alg='nccl')
1521
+
1522
+
1523
+ class Resnet50MultiWorkerKerasBenchmarkSynth(Resnet50MultiWorkerKerasBenchmark):
1524
+ """Resnet50 multi-worker synthetic data benchmark tests."""
1525
+
1526
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
1527
+ def_flags = {}
1528
+ def_flags['skip_eval'] = True
1529
+ def_flags['report_accuracy_metrics'] = False
1530
+ def_flags['use_synthetic_data'] = True
1531
+ def_flags['train_steps'] = 110
1532
+ def_flags['log_steps'] = 10
1533
+
1534
+ super(Resnet50MultiWorkerKerasBenchmarkSynth, self).__init__(
1535
+ output_dir=output_dir, default_flags=def_flags)
1536
+
1537
+
1538
+ class Resnet50MultiWorkerKerasBenchmarkReal(Resnet50MultiWorkerKerasBenchmark):
1539
+ """Resnet50 multi-worker real data benchmark tests."""
1540
+
1541
+ def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
1542
+ def_flags = {}
1543
+ def_flags['skip_eval'] = True
1544
+ def_flags['report_accuracy_metrics'] = False
1545
+ def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
1546
+ def_flags['train_steps'] = 110
1547
+ def_flags['log_steps'] = 10
1548
+
1549
+ super(Resnet50MultiWorkerKerasBenchmarkReal, self).__init__(
1550
+ output_dir=output_dir, default_flags=def_flags)
1551
+
1552
+
1553
+ # TODO(kimjaehong): It also should be also cover other metheods of model
1554
+ # optimization techniques. In that time, this class will change to something
1555
+ # like 'KerasModelOptimizationAccuracyBase'.
1556
+ class KerasPruningAccuracyBase(keras_benchmark.KerasBenchmark):
1557
+ """Benchmark accuracy tests for pruning method."""
1558
+
1559
+ def __init__(self,
1560
+ output_dir=None,
1561
+ root_data_dir=None,
1562
+ default_flags=None,
1563
+ **kwargs):
1564
+ """A accuracy benchmark class for pruning method.
1565
+
1566
+ Args:
1567
+ output_dir: directory where to output e.g. log files
1568
+ root_data_dir: directory under which to look for dataset
1569
+ default_flags: default flags
1570
+ **kwargs: arbitrary named arguments. This is needed to make the
1571
+ constructor forward compatible in case PerfZero provides more
1572
+ named arguments before updating the constructor.
1573
+ """
1574
+ if default_flags is None:
1575
+ default_flags = {}
1576
+ default_flags['pruning_method'] = 'polynomial_decay'
1577
+ default_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
1578
+
1579
+ flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
1580
+
1581
+ super(KerasPruningAccuracyBase, self).__init__(
1582
+ output_dir=output_dir,
1583
+ flag_methods=flag_methods,
1584
+ default_flags=default_flags,
1585
+ **kwargs)
1586
+
1587
+ def benchmark_8_gpu(self):
1588
+ """Test Keras model with eager, dist_strat and 8 GPUs."""
1589
+ self._setup()
1590
+ FLAGS.num_gpus = 8
1591
+ FLAGS.batch_size = 32 * 8
1592
+ FLAGS.train_epochs = 90
1593
+ FLAGS.epochs_between_evals = 10
1594
+ FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
1595
+ FLAGS.dtype = 'fp32'
1596
+ FLAGS.enable_eager = True
1597
+ self._run_and_report_benchmark()
1598
+
1599
+ @benchmark_wrappers.enable_runtime_flags
1600
+ def _run_and_report_benchmark(self,
1601
+ top_1_min=MODEL_OPTIMIZATION_TOP_1_ACCURACY[
1602
+ 'RESNET50_FINETUNE_PRUNING'][0],
1603
+ top_1_max=MODEL_OPTIMIZATION_TOP_1_ACCURACY[
1604
+ 'RESNET50_FINETUNE_PRUNING'][1]):
1605
+ start_time_sec = time.time()
1606
+ stats = resnet_imagenet_main.run(flags.FLAGS)
1607
+ wall_time_sec = time.time() - start_time_sec
1608
+
1609
+ super(KerasPruningAccuracyBase, self)._report_benchmark(
1610
+ stats,
1611
+ wall_time_sec,
1612
+ top_1_min=top_1_min,
1613
+ top_1_max=top_1_max,
1614
+ total_batch_size=FLAGS.batch_size,
1615
+ log_steps=100)
1616
+
1617
+
1618
+ class MobilenetV1KerasPruningAccuracy(KerasPruningAccuracyBase):
1619
+ """Benchmark accuracy tests for MobilenetV1 with pruning method."""
1620
+
1621
+ def __init__(self, root_data_dir=None, **kwargs):
1622
+ default_flags = {
1623
+ 'model': 'mobilenet',
1624
+ 'optimizer': 'mobilenet_default',
1625
+ 'initial_learning_rate_per_sample': 0.00007,
1626
+ 'pretrained_filepath': tf.train.latest_checkpoint(
1627
+ os.path.join(root_data_dir, 'mobilenet_v1')),
1628
+ 'pruning_begin_step': 0,
1629
+ 'pruning_end_step': 100000,
1630
+ 'pruning_initial_sparsity': 0.0,
1631
+ 'pruning_final_sparsity': 0.5,
1632
+ 'pruning_frequency': 100,
1633
+ }
1634
+ super(MobilenetV1KerasPruningAccuracy, self).__init__(
1635
+ root_data_dir=root_data_dir,
1636
+ default_flags=default_flags,
1637
+ **kwargs)
1638
+
1639
+ def _run_and_report_benchmark(self):
1640
+ super(MobilenetV1KerasPruningAccuracy, self)._run_and_report_benchmark(
1641
+ top_1_min=\
1642
+ MODEL_OPTIMIZATION_TOP_1_ACCURACY['MOBILENET_V1_FINETUNE_PRUNING'][0],
1643
+ top_1_max=\
1644
+ MODEL_OPTIMIZATION_TOP_1_ACCURACY['MOBILENET_V1_FINETUNE_PRUNING'][1])
1645
+
1646
+
1647
+ class Resnet50KerasPruningAccuracy(KerasPruningAccuracyBase):
1648
+ """Benchmark accuracy tests for resnet50 with pruning method."""
1649
+
1650
+ def __init__(self, root_data_dir=None, **kwargs):
1651
+ default_flags = {
1652
+ 'model': 'resnet50_v1.5',
1653
+ 'optimizer': 'mobilenet_default',
1654
+ 'initial_learning_rate_per_sample': 0.0000039,
1655
+ 'pretrained_filepath': tf.train.latest_checkpoint(
1656
+ os.path.join(root_data_dir, 'resnet50')),
1657
+ 'pruning_begin_step': 0,
1658
+ 'pruning_end_step': 50000,
1659
+ 'pruning_initial_sparsity': 0.0,
1660
+ 'pruning_final_sparsity': 0.5,
1661
+ 'pruning_frequency': 100,
1662
+ }
1663
+ super(Resnet50KerasPruningAccuracy, self).__init__(
1664
+ root_data_dir=root_data_dir,
1665
+ default_flags=default_flags,
1666
+ **kwargs)
1667
+
1668
+ def _run_and_report_benchmark(self):
1669
+ super(Resnet50KerasPruningAccuracy, self)._run_and_report_benchmark(
1670
+ top_1_min=\
1671
+ MODEL_OPTIMIZATION_TOP_1_ACCURACY['RESNET50_FINETUNE_PRUNING'][0],
1672
+ top_1_max=\
1673
+ MODEL_OPTIMIZATION_TOP_1_ACCURACY['RESNET50_FINETUNE_PRUNING'][1])
1674
+
1675
+
1676
+ class KerasPruningBenchmarkRealBase(Resnet50KerasBenchmarkBase):
1677
+ """Pruning method benchmarks."""
1678
+
1679
+ def __init__(self, root_data_dir=None, default_flags=None, **kwargs):
1680
+ if default_flags is None:
1681
+ default_flags = {}
1682
+ default_flags.update({
1683
+ 'skip_eval': True,
1684
+ 'report_accuracy_metrics': False,
1685
+ 'data_dir': os.path.join(root_data_dir, 'imagenet'),
1686
+ 'train_steps': 110,
1687
+ 'log_steps': 10,
1688
+ 'pruning_method': 'polynomial_decay',
1689
+ 'pruning_begin_step': 0,
1690
+ 'pruning_end_step': 50000,
1691
+ 'pruning_initial_sparsity': 0,
1692
+ 'pruning_final_sparsity': 0.5,
1693
+ 'pruning_frequency': 100,
1694
+ })
1695
+ super(KerasPruningBenchmarkRealBase, self).__init__(
1696
+ default_flags=default_flags, **kwargs)
1697
+
1698
+
1699
+ class MobilenetV1KerasPruningBenchmarkReal(KerasPruningBenchmarkRealBase):
1700
+ """Pruning method benchmarks for MobilenetV1."""
1701
+
1702
+ def __init__(self, **kwargs):
1703
+ default_flags = {
1704
+ 'model': 'mobilenet',
1705
+ 'optimizer': 'mobilenet_default',
1706
+ }
1707
+ super(MobilenetV1KerasPruningBenchmarkReal, self).__init__(
1708
+ default_flags=default_flags, **kwargs)
1709
+
1710
+
1711
+ class Resnet50KerasPruningBenchmarkReal(KerasPruningBenchmarkRealBase):
1712
+ """Pruning method benchmarks for resnet50."""
1713
+
1714
+ def __init__(self, **kwargs):
1715
+ default_flags = {
1716
+ 'model': 'resnet50_v1.5',
1717
+ 'optimizer': 'mobilenet_default',
1718
+ }
1719
+ super(Resnet50KerasPruningBenchmarkReal, self).__init__(
1720
+ default_flags=default_flags, **kwargs)
1721
+
1722
+
1723
+ if __name__ == '__main__':
1724
+ tf.test.main()
models/official/benchmark/models/__init__.py ADDED
File without changes
models/official/benchmark/models/cifar_preprocessing.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Provides utilities to Cifar-10 dataset."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import os
22
+ from absl import logging
23
+ import tensorflow as tf
24
+
25
+ from official.vision.image_classification.resnet import imagenet_preprocessing
26
+
27
+ HEIGHT = 32
28
+ WIDTH = 32
29
+ NUM_CHANNELS = 3
30
+ _DEFAULT_IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS
31
+ # The record is the image plus a one-byte label
32
+ _RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
33
+
34
+ # TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
35
+ NUM_IMAGES = {
36
+ 'train': 50000,
37
+ 'validation': 10000,
38
+ }
39
+ _NUM_DATA_FILES = 5
40
+ NUM_CLASSES = 10
41
+
42
+
43
+ def parse_record(raw_record, is_training, dtype):
44
+ """Parses a record containing a training example of an image.
45
+
46
+ The input record is parsed into a label and image, and the image is passed
47
+ through preprocessing steps (cropping, flipping, and so on).
48
+
49
+ This method converts the label to one hot to fit the loss function.
50
+
51
+ Args:
52
+ raw_record: scalar Tensor tf.string containing a serialized
53
+ Example protocol buffer.
54
+ is_training: A boolean denoting whether the input is for training.
55
+ dtype: Data type to use for input images.
56
+
57
+ Returns:
58
+ Tuple with processed image tensor and one-hot-encoded label tensor.
59
+ """
60
+ # Convert bytes to a vector of uint8 that is record_bytes long.
61
+ record_vector = tf.io.decode_raw(raw_record, tf.uint8)
62
+
63
+ # The first byte represents the label, which we convert from uint8 to int32
64
+ # and then to one-hot.
65
+ label = tf.cast(record_vector[0], tf.int32)
66
+
67
+ # The remaining bytes after the label represent the image, which we reshape
68
+ # from [depth * height * width] to [depth, height, width].
69
+ depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
70
+ [NUM_CHANNELS, HEIGHT, WIDTH])
71
+
72
+ # Convert from [depth, height, width] to [height, width, depth], and cast as
73
+ # float32.
74
+ image = tf.cast(tf.transpose(a=depth_major, perm=[1, 2, 0]), tf.float32)
75
+
76
+ image = preprocess_image(image, is_training)
77
+ image = tf.cast(image, dtype)
78
+
79
+ return image, label
80
+
81
+
82
+ def preprocess_image(image, is_training):
83
+ """Preprocess a single image of layout [height, width, depth]."""
84
+ if is_training:
85
+ # Resize the image to add four extra pixels on each side.
86
+ image = tf.image.resize_with_crop_or_pad(
87
+ image, HEIGHT + 8, WIDTH + 8)
88
+
89
+ # Randomly crop a [HEIGHT, WIDTH] section of the image.
90
+ image = tf.image.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS])
91
+
92
+ # Randomly flip the image horizontally.
93
+ image = tf.image.random_flip_left_right(image)
94
+
95
+ # Subtract off the mean and divide by the variance of the pixels.
96
+ image = tf.image.per_image_standardization(image)
97
+ return image
98
+
99
+
100
+ def get_filenames(is_training, data_dir):
101
+ """Returns a list of filenames."""
102
+ assert tf.io.gfile.exists(data_dir), (
103
+ 'Run cifar10_download_and_extract.py first to download and extract the '
104
+ 'CIFAR-10 data.')
105
+
106
+ if is_training:
107
+ return [
108
+ os.path.join(data_dir, 'data_batch_%d.bin' % i)
109
+ for i in range(1, _NUM_DATA_FILES + 1)
110
+ ]
111
+ else:
112
+ return [os.path.join(data_dir, 'test_batch.bin')]
113
+
114
+
115
+ def input_fn(is_training,
116
+ data_dir,
117
+ batch_size,
118
+ dtype=tf.float32,
119
+ datasets_num_private_threads=None,
120
+ parse_record_fn=parse_record,
121
+ input_context=None,
122
+ drop_remainder=False):
123
+ """Input function which provides batches for train or eval.
124
+
125
+ Args:
126
+ is_training: A boolean denoting whether the input is for training.
127
+ data_dir: The directory containing the input data.
128
+ batch_size: The number of samples per batch.
129
+ dtype: Data type to use for images/features
130
+ datasets_num_private_threads: Number of private threads for tf.data.
131
+ parse_record_fn: Function to use for parsing the records.
132
+ input_context: A `tf.distribute.InputContext` object passed in by
133
+ `tf.distribute.Strategy`.
134
+ drop_remainder: A boolean indicates whether to drop the remainder of the
135
+ batches. If True, the batch dimension will be static.
136
+
137
+ Returns:
138
+ A dataset that can be used for iteration.
139
+ """
140
+ filenames = get_filenames(is_training, data_dir)
141
+ dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
142
+
143
+ if input_context:
144
+ logging.info(
145
+ 'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
146
+ input_context.input_pipeline_id, input_context.num_input_pipelines)
147
+ dataset = dataset.shard(input_context.num_input_pipelines,
148
+ input_context.input_pipeline_id)
149
+
150
+ return imagenet_preprocessing.process_record_dataset(
151
+ dataset=dataset,
152
+ is_training=is_training,
153
+ batch_size=batch_size,
154
+ shuffle_buffer=NUM_IMAGES['train'],
155
+ parse_record_fn=parse_record_fn,
156
+ dtype=dtype,
157
+ datasets_num_private_threads=datasets_num_private_threads,
158
+ drop_remainder=drop_remainder
159
+ )