File size: 1,186 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
from typing import Any, NewType, Optional

import nni

from .serializer import json_loads
# NOTE: the real ``RetiariiAdvisor`` cannot be imported here because that would
# create a circular import. This placeholder type exists only so that the
# annotations in this module resolve (e.g. to satisfy flake8).
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)

# Module-level advisor instance; ``None`` until ``register_advisor`` is called,
# hence the ``Optional`` annotation.
_advisor: Optional[RetiariiAdvisor] = None
def get_advisor() -> 'RetiariiAdvisor':
    """
    Return the advisor previously stored by ``register_advisor``.

    Raises
    ------
    RuntimeError
        If no advisor has been registered yet.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
    # which would make this silently return ``None``.
    if _advisor is None:
        raise RuntimeError('No Retiarii advisor has been registered; call register_advisor() first.')
    return _advisor
def register_advisor(advisor: 'RetiariiAdvisor'):
    """
    Store ``advisor`` in the module-level slot read by ``get_advisor``.

    Only one advisor may be registered.

    Raises
    ------
    RuntimeError
        If an advisor has already been registered.
    """
    global _advisor
    # Explicit check instead of ``assert`` so double registration is still
    # rejected when running under ``python -O``.
    if _advisor is not None:
        raise RuntimeError('A Retiarii advisor has already been registered.')
    _advisor = advisor
def send_trial(parameters: dict) -> int:
    """
    Submit a new trial with the given parameters. Executed on the tuner side.

    Returns an ID that uniquely identifies the submitted trial.
    """
    advisor = get_advisor()
    trial_id = advisor.send_trial(parameters)
    return trial_id
def receive_trial_parameters() -> dict:
    """
    Receive the parameters of a new trial. Executed on the trial side.

    NNI decodes the parameters without the Retiarii serializer, so they are
    re-encoded with ``json.dumps`` and decoded again with Retiarii's ``json_loads``.
    """
    raw_parameters = nni.get_next_parameter()
    return json_loads(json.dumps(raw_parameters))
def get_experiment_id() -> str:
    """
    Return the experiment ID reported by ``nni.get_experiment_id``.
    """
    experiment_id = nni.get_experiment_id()
    return experiment_id
|