# Using flags in official models

1. **All common flags must be incorporated in the models.**

   Common flags (e.g. `batch_size` and `model_dir`) are provided by various flag definition
   functions and channeled through `official.utils.flags.core`. For instance, to define common
   supervised learning parameters one could use the following code:

   ```python
   from absl import app as absl_app
   from absl import flags

   from official.utils.flags import core as flags_core


   def define_flags():
     flags_core.define_base()
     flags.adopt_key_flags(flags_core)


   def main(_):
     flags_obj = flags.FLAGS
     print(flags_obj)


   if __name__ == "__main__":
     absl_app.run(main)
   ```
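
   Flags defined this way are parsed when `absl_app.run(main)` is invoked, so they can be set on
   the command line in the usual absl fashion. For example (the script name below is hypothetical;
   `batch_size` and `model_dir` are among the common flags noted above):

   ```shell
   python example.py --batch_size=64 --model_dir=/tmp/example
   ```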
2. **Validate flag values.**

   See the [Validators](#validators) section for implementation details.

   Due to strict ordering requirements, validators in the official model repo should not access
   the file system, for example to verify that files exist.
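
   As a minimal sketch using absl's generic `flags.register_validator` API (the flag name and
   default here are assumed for illustration; the repo's own helpers may differ):

   ```python
   from absl import app as absl_app
   from absl import flags

   flags.DEFINE_integer("batch_size", 32, "Batch size for training.")

   # Reject non-positive batch sizes. The checker inspects only the flag
   # value itself and never touches the file system.
   flags.register_validator(
       "batch_size",
       lambda value: value > 0,
       message="--batch_size must be a positive integer.")


   def main(_):
     print(flags.FLAGS.batch_size)


   if __name__ == "__main__":
     absl_app.run(main)
   ```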

3. **Flag values should not be mutated.**

   Instead of mutating flag values, use getter functions to return the desired values. An example
   is the `get_tf_dtype` function below:

   ```python
   from absl import flags

   import tensorflow as tf

   # Map a string flag value to the corresponding TensorFlow dtype.
   DTYPE_MAP = {
       "fp16": tf.float16,
       "fp32": tf.float32,
   }

   def get_tf_dtype(flags_obj):
     if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
       # If the graph_rewrite is used, we build the graph with fp32, and let the
       # graph rewrite change ops to fp16.
       return tf.float32
     return DTYPE_MAP[flags_obj.dtype]


   def main(_):
     flags_obj = flags.FLAGS

     # Do not mutate flags_obj
     # if flags_obj.fp16_implementation == "graph_rewrite":
     #   flags_obj.dtype = "fp32"  # Don't do this

     print(get_tf_dtype(flags_obj))
     ...
   ```