LINC-BIT's picture
Upload 1912 files
b84549f verified
raw
history blame
6.03 kB
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import os
import sys

from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url, import_data_url
from .config_utils import Config
from .common_utils import get_json_content, print_normal, print_error, print_warning
from .nnictl_utils import get_experiment_port, get_config_filename, detect_process
from .launcher_utils import parse_time
from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA
def validate_digit(value, start, end):
    '''Validate that *value* is a non-negative integer within [start, end].

    *value* may be an int or a string of decimal digits (e.g. a CLI argument).
    Signs, surrounding whitespace, decimal points and non-decimal Unicode
    "digits" (such as superscripts) are rejected.

    Raises:
        ValueError: if *value* is not made only of decimal digits or lies
            outside the inclusive range [start, end].
    '''
    text = str(value)
    # isdecimal() matches exactly the characters int() can parse; isdigit()
    # also accepts e.g. '\u00b2', which made int() raise a confusing
    # "invalid literal" error instead of the message below.
    if not text.isdecimal() or not start <= int(text) <= end:
        raise ValueError('value (%s) must be a digit from %s to %s' % (value, start, end))
def validate_file(path):
    '''Validate that *path* names an existing regular file.

    Raises:
        FileNotFoundError: if *path* does not exist or is not a regular file
            (for example, a directory).
    '''
    # isfile rather than exists: a directory "exists" but cannot be loaded as
    # the JSON file the update/import commands read right after this check.
    if not os.path.isfile(path):
        raise FileNotFoundError('%s is not a valid file path' % path)
def validate_dispatcher(args):
    '''Exit the process unless the experiment's dispatcher can import data.

    Reads the stored ``experimentConfig`` of the experiment selected by *args*
    and takes the built-in tuner name (``tuner.builtinTunerName``) or, failing
    that, the built-in advisor name (``advisor.builtinAdvisorName``). A
    customized dispatcher (neither name set) is not checked.

    Returns None when the dispatcher is customized or listed in
    TUNERS_SUPPORTING_IMPORT_DATA. Otherwise it exits the process:
    status 0 (after a warning) if the dispatcher is listed in
    TUNERS_NO_NEED_TO_IMPORT_DATA, status 1 (after an error) if not.
    '''
    nni_config = Config(get_config_filename(args)).get_config('experimentConfig')
    if nni_config.get('tuner') and nni_config['tuner'].get('builtinTunerName'):
        dispatcher_name = nni_config['tuner']['builtinTunerName']
    elif nni_config.get('advisor') and nni_config['advisor'].get('builtinAdvisorName'):
        dispatcher_name = nni_config['advisor']['builtinAdvisorName']
    else:  # otherwise it should be a customized one
        return

    if dispatcher_name in TUNERS_SUPPORTING_IMPORT_DATA:
        return
    # sys.exit instead of the site-module ``exit`` helper, which is intended
    # for the interactive shell and is absent when Python runs with -S.
    if dispatcher_name in TUNERS_NO_NEED_TO_IMPORT_DATA:
        print_warning("There is no need to import data for %s" % dispatcher_name)
        sys.exit(0)
    print_error("%s does not support importing additional data" % dispatcher_name)
    sys.exit(1)
def load_search_space(path):
    '''Load the JSON file at *path* and return it re-serialized as a string.

    Raises:
        ValueError: if the value returned by get_json_content is falsy
            (e.g. None, an empty object or an empty list).
    '''
    data = get_json_content(path)
    # Check the parsed value, not the json.dumps output: json.dumps never
    # returns an empty string (None -> 'null', {} -> '{}'), so a check on the
    # serialized text could never fire.
    if not data:
        raise ValueError('searchSpace file should not be empty')
    return json.dumps(data)
def get_query_type(key):
    '''Return the REST ``update_type`` query string for profile field *key*.

    Returns None when *key* is not one of the updatable fields
    (trialConcurrency, maxExecDuration, searchSpace, maxTrialNum).
    '''
    known_updates = (
        ('trialConcurrency', '?update_type=TRIAL_CONCURRENCY'),
        ('maxExecDuration', '?update_type=MAX_EXEC_DURATION'),
        ('searchSpace', '?update_type=SEARCH_SPACE'),
        ('maxTrialNum', '?update_type=MAX_TRIAL_NUM'),
    )
    return next((query for field, query in known_updates if key == field), None)
def update_experiment_profile(args, key, value):
    '''Set ``params[key] = value`` on the experiment profile via the REST server.

    Reads the REST port from the experiment's stored config, GETs the current
    profile, patches the single field and PUTs the whole profile back with the
    ``update_type`` query matching *key*.

    Returns the PUT response on success, otherwise None.
    '''
    rest_port = Config(get_config_filename(args)).get_config('restServerPort')
    is_running, _ = check_rest_server_quick(rest_port)
    if not is_running:
        print_error('Restful server is not running...')
        return None

    current = rest_get(experiment_url(rest_port), REST_TIME_OUT)
    if not (current and check_response(current)):
        return None

    profile = json.loads(current.text)
    profile['params'][key] = value
    updated = rest_put(experiment_url(rest_port) + get_query_type(key),
                       json.dumps(profile), REST_TIME_OUT)
    if updated and check_response(updated):
        return updated
    return None
def update_searchspace(args):
    '''Replace the experiment's search space with the JSON in args.filename.

    Does nothing further if get_experiment_port yields None; otherwise prints
    whether the update succeeded.
    '''
    validate_file(args.filename)
    search_space = load_search_space(args.filename)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    succeeded = update_experiment_profile(args, 'searchSpace', search_space)
    if succeeded:
        print_normal('Update %s success!' % 'searchSpace')
    else:
        print_error('Update %s failed!' % 'searchSpace')
def update_concurrency(args):
    '''Set trialConcurrency to args.value (an integer in 1..1000).

    Does nothing further if get_experiment_port yields None; otherwise prints
    whether the update succeeded.
    '''
    validate_digit(args.value, 1, 1000)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    new_concurrency = int(args.value)
    if not update_experiment_profile(args, 'trialConcurrency', new_concurrency):
        print_error('Update %s failed!' % 'concurrency')
        return
    print_normal('Update %s success!' % 'concurrency')
def update_duration(args):
    '''Set maxExecDuration from args.value after converting it with parse_time.

    args.value is overwritten with the parsed value. Does nothing further if
    get_experiment_port yields None; otherwise prints whether the update
    succeeded.
    '''
    # parse time, change time unit to seconds
    args.value = parse_time(args.value)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    ok = update_experiment_profile(args, 'maxExecDuration', int(args.value))
    report = print_normal if ok else print_error
    report(('Update %s success!' if ok else 'Update %s failed!') % 'duration')
def update_trialnum(args):
    '''Set maxTrialNum to args.value (an integer in 1..999999999).

    Resolves the experiment port first, like the other update_* commands in
    this module, and does nothing further if get_experiment_port yields None.
    Otherwise prints whether the update succeeded.
    '''
    validate_digit(args.value, 1, 999999999)
    # Same port lookup/guard that update_searchspace, update_concurrency and
    # update_duration perform before touching the experiment profile.
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    if update_experiment_profile(args, 'maxTrialNum', int(args.value)):
        print_normal('Update %s success!' % 'trialnum')
    else:
        print_error('Update %s failed!' % 'trialnum')
def import_data(args):
    '''Import the additional trial data in args.filename into the experiment.

    Validates the file and the experiment's dispatcher, then loads the data.
    It stops with an error message if detect_process reports the recorded
    restServerPid as not running, or if check_rest_server_quick reports the
    server on restServerPort as not running. Otherwise it posts the data and
    reports only a failure.
    '''
    validate_file(args.filename)
    validate_dispatcher(args)
    payload = load_search_space(args.filename)

    experiment_config = Config(get_config_filename(args))
    rest_port = experiment_config.get_config('restServerPort')
    rest_pid = experiment_config.get_config('restServerPid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return
    server_up, _ = check_rest_server_quick(rest_port)
    if not server_up:
        print_error('Restful server is not running')
        return

    args.port = rest_port
    if args.port is None:
        return
    if not import_data_to_restful_server(args, payload):
        print_error('Import data failed!')
def import_data_to_restful_server(args, content):
    '''POST *content* to the experiment's import-data REST endpoint.

    The REST port is read from the experiment's stored config. Returns the
    POST response on success, otherwise None (after printing an error if the
    server is not running).
    '''
    rest_port = Config(get_config_filename(args)).get_config('restServerPort')
    is_running, _ = check_rest_server_quick(rest_port)
    if not is_running:
        print_error('Restful server is not running...')
        return None
    response = rest_post(import_data_url(rest_port), content, REST_TIME_OUT)
    return response if response and check_response(response) else None