|
|
|
|
|
|
|
from .utils import to_json |
|
from .runtime.env_vars import trial_env_vars |
|
from .runtime import platform |
|
|
|
|
|
# Public API of this module: helpers a trial script uses to fetch
# hyper-parameters, report results, and query its own identifiers.
__all__ = [
    'get_next_parameter', 'get_current_parameter',
    'report_intermediate_result', 'report_final_result',
    'get_experiment_id', 'get_trial_id', 'get_sequence_id',
]
|
|
|
|
|
# Most recent record returned by platform.get_next_parameter(); read as a dict
# with 'parameters' and 'parameter_id' keys. None until get_next_parameter() is
# called, and set back to None whenever the platform returns None.
_params = None

# Trial identity is queried from the platform once, at module import time, and
# cached so the get_*_id() getters below are plain lookups.
_experiment_id = platform.get_experiment_id()

_trial_id = platform.get_trial_id()

_sequence_id = platform.get_sequence_id()
|
|
|
|
|
def get_next_parameter():
    """
    Get the hyper-parameters generated by the tuner.

    For a multiphase experiment, each call returns a new group of hyper-parameters.
    For a non-multiphase experiment (multiPhase not configured or set to False),
    hyper-parameters are returned only on the first call in each trial job, and every
    later call returns None. In a non-multiphase experiment this API should therefore
    be called only once per trial job.

    Note that a None result also replaces the cached parameters, so
    get_current_parameter() returns None afterwards.

    Returns
    -------
    dict
        The hyper-parameters generated by the tuner, keyed by the names defined in
        the search space. None if the tuner cannot generate any more hyper-parameters.
    """
    global _params
    # Cache the whole record (including its parameter_id) so get_current_parameter()
    # and the report_* functions can refer back to it.
    _params = platform.get_next_parameter()
    if _params is None:
        return None
    return _params['parameters']
|
|
|
def get_current_parameter(tag=None):
    """
    Return the hyper-parameters most recently fetched by get_next_parameter.

    Parameters
    ----------
    tag: str, optional
        Name of a single hyper-parameter to look up. When None (the default),
        the whole group of hyper-parameters is returned.

    Returns
    -------
    object
        The value for ``tag``, or the whole parameter group when ``tag`` is None.
        None if no parameters are currently cached. Looking up a ``tag`` that is
        not present raises (KeyError when the group is a dict).
    """
    record = _params  # read-only access, so no ``global`` declaration is needed
    if record is None:
        return None
    group = record['parameters']
    return group if tag is None else group[tag]
|
|
|
def get_experiment_id():
    """
    Return the identifier of the experiment this trial belongs to.

    The value is queried from the platform once, when this module is imported.

    Returns
    -------
    str
        Identifier of the current experiment.
    """
    return _experiment_id
|
|
|
def get_trial_id():
    """
    Get the trial job ID, the string identifier of a trial job (for example 'MoXrp').

    Within one experiment, each trial job has a unique string ID. The value is
    queried from the platform once, when this module is imported.

    Returns
    -------
    str
        Identifier of the current trial job, i.e. the one calling this API.
    """
    return _trial_id
|
|
|
def get_sequence_id():
    """
    Get the trial job sequence number.

    A sequence number is an integer assigned to each trial job based on the order in
    which the jobs are submitted, counting up from 0. Within one experiment, both the
    trial job ID and the sequence number are unique per trial job; they differ only in
    data type. The value is queried from the platform once, when this module is
    imported.

    Returns
    -------
    int
        Sequence number of the current trial job, i.e. the one calling this API.
    """
    return _sequence_id
|
|
|
# Sequence number attached to the next intermediate result; advanced by
# report_intermediate_result() each time a report is built.
_intermediate_seq = 0


def report_intermediate_result(metric):
    """
    Report an intermediate result to NNI.

    Parameters
    ----------
    metric:
        Serializable object. It is serialized with ``to_json`` and sent as the
        ``value`` of a ``PERIODICAL`` report.

    Raises
    ------
    AssertionError
        If ``trial_env_vars.NNI_PLATFORM`` is not None and no parameters are cached
        (get_next_parameter() has not returned parameters).
    """
    global _intermediate_seq
    # Explicit check rather than ``assert`` so the guard survives ``python -O``;
    # AssertionError is kept so callers see the same exception type as before.
    if not _params and trial_env_vars.NNI_PLATFORM is not None:
        raise AssertionError(
            'nni.get_next_parameter() needs to be called before report_intermediate_result')
    payload = to_json({
        'parameter_id': _params['parameter_id'] if _params else None,
        'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
        'type': 'PERIODICAL',
        'sequence': _intermediate_seq,
        # the metric is serialized on its own and embedded in the outer report
        'value': to_json(metric)
    })
    _intermediate_seq += 1
    platform.send_metric(payload)
|
|
|
def report_final_result(metric):
    """
    Report the final result to NNI.

    Parameters
    ----------
    metric: serializable object
        Usually (for built-in tuners to work), it should be a number, or
        a dict with key "default" (a number), and any other extra keys.
        It is serialized with ``to_json`` and sent as the ``value`` of a
        ``FINAL`` report.

    Raises
    ------
    AssertionError
        If ``trial_env_vars.NNI_PLATFORM`` is not None and no parameters are cached
        (get_next_parameter() has not returned parameters).
    """
    # Explicit check rather than ``assert`` so the guard survives ``python -O``;
    # AssertionError is kept so callers see the same exception type as before.
    if not _params and trial_env_vars.NNI_PLATFORM is not None:
        raise AssertionError(
            'nni.get_next_parameter() needs to be called before report_final_result')
    payload = to_json({
        'parameter_id': _params['parameter_id'] if _params else None,
        'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
        'type': 'FINAL',
        'sequence': 0,
        # the metric is serialized on its own and embedded in the outer report
        'value': to_json(metric)
    })
    platform.send_metric(payload)
|
|