LINC-BIT's picture
Upload 1912 files
b84549f verified
raw
history blame
1.19 kB
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
from typing import NewType, Any
import nni
from .serializer import json_loads
# NOTE: ``RetiariiAdvisor`` below is only a stand-in type so that the
# annotations in this module pass flake8. The real RetiariiAdvisor class
# cannot be imported here because that would create a circular import.
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
# Module-level advisor slot: starts as None, is assigned by
# register_advisor() and read by get_advisor().
_advisor: 'RetiariiAdvisor' = None
def get_advisor() -> 'RetiariiAdvisor':
    """
    Return the advisor previously stored by ``register_advisor``.

    Raises
    ------
    AssertionError
        If no advisor has been registered yet.
    """
    # Raised explicitly instead of via ``assert`` so the check (and its
    # message) stays active when Python runs with ``-O``; the exception
    # type is unchanged for existing callers.
    if _advisor is None:
        raise AssertionError('No advisor has been registered; call register_advisor() first.')
    return _advisor
def register_advisor(advisor: 'RetiariiAdvisor'):
    """
    Store ``advisor`` in the module-level slot read by ``get_advisor``.

    Registration is only allowed while the slot is still empty (``None``).

    Parameters
    ----------
    advisor
        The advisor object to store.

    Raises
    ------
    AssertionError
        If an advisor is already registered.
    """
    global _advisor
    # Raised explicitly instead of via ``assert`` so that a second
    # registration cannot silently overwrite the first one under ``-O``.
    if _advisor is not None:
        raise AssertionError('An advisor is already registered; it cannot be replaced.')
    _advisor = advisor
def send_trial(parameters: dict) -> int:
    """
    Submit a new trial through the registered advisor. Runs on the tuner side.

    Returns the ID that uniquely identifies the submitted trial.
    """
    advisor = get_advisor()
    trial_id = advisor.send_trial(parameters)
    return trial_id
def receive_trial_parameters() -> dict:
    """
    Fetch the parameters of the next trial. Runs on the trial side.

    NNI does not load the data with the Retiarii serializer, so the value it
    returns is dumped back to a JSON string and decoded again with our own
    ``json_loads``.
    """
    raw_params = nni.get_next_parameter()
    encoded = json.dumps(raw_params)
    return json_loads(encoded)
def get_experiment_id() -> str:
    """
    Return the experiment ID reported by ``nni.get_experiment_id()``.
    """
    experiment_id = nni.get_experiment_id()
    return experiment_id