LINC-BIT's picture
Upload 1912 files
b84549f verified
raw
history blame
4.41 kB
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .utils import to_json
from .runtime.env_vars import trial_env_vars
from .runtime import platform
__all__ = [
'get_next_parameter',
'get_current_parameter',
'report_intermediate_result',
'report_final_result',
'get_experiment_id',
'get_trial_id',
'get_sequence_id'
]
_params = None
_experiment_id = platform.get_experiment_id()
_trial_id = platform.get_trial_id()
_sequence_id = platform.get_sequence_id()
def get_next_parameter():
"""
Get the hyper paremeters generated by tuner. For a multiphase experiment, it returns a new group of hyper
parameters at each call of get_next_parameter. For a non-multiphase (multiPhase is not configured or set to False)
experiment, it returns hyper parameters only on the first call for each trial job, it returns None since second call.
This API should be called only once in each trial job of an experiment which is not specified as multiphase.
Returns
-------
dict
A dict object contains the hyper parameters generated by tuner, the keys of the dict are defined in
search space. Returns None if no more hyper parameters can be generated by tuner.
"""
global _params
_params = platform.get_next_parameter()
if _params is None:
return None
return _params['parameters']
def get_current_parameter(tag=None):
"""
Get current hyper parameters generated by tuner. It returns the same group of hyper parameters as the last
call of get_next_parameter returns.
Parameters
----------
tag: str
hyper parameter key
"""
global _params
if _params is None:
return None
if tag is None:
return _params['parameters']
return _params['parameters'][tag]
def get_experiment_id():
"""
Get experiment ID.
Returns
-------
str
Identifier of current experiment
"""
return _experiment_id
def get_trial_id():
"""
Get trial job ID which is string identifier of a trial job, for example 'MoXrp'. In one experiment, each trial
job has an unique string ID.
Returns
-------
str
Identifier of current trial job which is calling this API.
"""
return _trial_id
def get_sequence_id():
"""
Get trial job sequence nubmer. A sequence number is an integer value assigned to each trial job base on the
order they are submitted, incremental starting from 0. In one experiment, both trial job ID and sequence number
are unique for each trial job, they are of different data types.
Returns
-------
int
Sequence number of current trial job which is calling this API.
"""
return _sequence_id
_intermediate_seq = 0
def overwrite_intermediate_seq(value):
"""
Overwrite intermediate sequence value.
Parameters
----------
value:
int
"""
assert isinstance(value, int)
global _intermediate_seq
_intermediate_seq = value
def report_intermediate_result(metric):
"""
Reports intermediate result to NNI.
Parameters
----------
metric:
serializable object.
"""
global _intermediate_seq
assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_intermediate_result'
metric = to_json({
'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'PERIODICAL',
'sequence': _intermediate_seq,
'value': to_json(metric)
})
_intermediate_seq += 1
platform.send_metric(metric)
def report_final_result(metric):
"""
Reports final result to NNI.
Parameters
----------
metric: serializable object
Usually (for built-in tuners to work), it should be a number, or
a dict with key "default" (a number), and any other extra keys.
"""
assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_final_result'
metric = to_json({
'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'FINAL',
'sequence': 0,
'value': to_json(metric)
})
platform.send_metric(metric)