|
import atexit |
|
import sacred |
|
import argparse |
|
import time |
|
import math |
|
import subprocess |
|
import shutil |
|
import os |
|
import json |
|
import threading |
|
import requests |
|
import glob |
|
from configs import fetch_model_params |
|
import socket |
|
import subprocess |
|
import queue |
|
import sys |
|
import signal |
|
|
|
|
|
# Command-line interface for the experiment orchestrator.
parser = argparse.ArgumentParser()
parser.add_argument('--tpu', type=str, required=True)              # name of the TPU to train on
parser.add_argument('--model', type=str, required=True)            # config name, or a path ending in .json
parser.add_argument('--experiment_name', type=str, required=True)  # sacred experiment name
parser.add_argument('--steps_per_checkpoint', type=int, default=5000)
# NOTE: store_false — the flag is forwarded to main.py only when autostacking
# is disabled (see the inverted check in train_thread).
parser.add_argument('--autostack', action="store_false")
parser.add_argument('--auto_layout', action="store_true")
parser.add_argument('--auto_layout_and_mesh_shape', action="store_true")
parser.add_argument('--new', action='store_true')
parser.add_argument('--test', action='store_true')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--predict', action='store_true')
parser.add_argument('--no_delete_tpu', action='store_true')        # on failure, keep the TPU instead of recreating it
# Heartbeat: how long (seconds) tensorboard may go without a new loss scalar
# before the run is considered stalled; the initial window is more generous.
parser.add_argument('--initial_heartbeat_timeout', type=int, default=7200)
parser.add_argument('--heartbeat_timeout', type=int, default=1800)
args = parser.parse_args()

# Resolve the full model hyperparameter dict for the chosen config.
params = fetch_model_params(args.model)

# Sacred experiment with a local MongoDB observer for metric/artifact logging.
ex = sacred.Experiment(args.experiment_name)
ex.observers.append(sacred.observers.QueuedMongoObserver(url='127.0.0.1:27017', db_name='db', username='user', password='password'))
|
|
|
|
|
def get_open_port(lo=8000, hi=8100):
    """Return the first TCP port in [lo, hi) with nothing listening on it.

    A port is treated as free when a connect() attempt to localhost fails.
    This is inherently racy: the port could be taken between this check and
    the caller actually binding it, but it is good enough for picking a
    tensorboard port.

    Args:
        lo: first port to try (inclusive).
        hi: last port to try (exclusive).

    Returns:
        The first free port number in the range.

    Raises:
        RuntimeError: if every port in [lo, hi) is in use. (The original
            silently returned None here, which would later be interpolated
            into a shell command; fail loudly instead.)
    """
    for port in range(lo, hi):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            # connect_ex returns 0 on success (something is listening -> busy)
            # and an errno otherwise (nothing listening -> free).
            if s.connect_ex(('localhost', port)) != 0:
                return port
    raise RuntimeError('no open port found in range [{}, {})'.format(lo, hi))
|
|
|
|
|
def train_thread(args, tpu, id, q):
    """Run main.py against `tpu` in a subprocess and babysit it until it exits.

    Blocks for the lifetime of the training process. On a clean exit
    (returncode 0) it SIGINTs the whole orchestrator process so the sacred
    run shuts down. On any other exit it recreates the TPU via the external
    `pu` tool (unless --no_delete_tpu) so the outer loop in main() can
    relaunch training.

    Args:
        args: parsed CLI namespace; boolean flags are forwarded to main.py.
        tpu: name of the TPU to train on.
        id: sacred run id; selects run_configs/config_<id>.json.
        q: queue.Queue polled for ('kill',) messages from the logging thread.
    """
    print('starting training on', tpu)

    # Forward each store_true flag that the user set.
    opts = ''
    for flag in ['auto_layout', 'auto_layout_and_mesh_shape', 'new', 'test', 'predict', 'eval', ]:
        if args.__getattribute__(flag):
            opts += ' --' + flag

    # --autostack is store_false: pass it through only when autostacking
    # was *disabled* on this script's command line.
    for flag in ['autostack', ]:
        if not args.__getattribute__(flag):
            opts += ' --' + flag

    cmd = "python3 main.py --tpu {tpu} --model run_configs/config_{id}.json --steps_per_checkpoint {steps_per_checkpoint} {opts} --sacred_id {run_id}".format(tpu=tpu, id=id, steps_per_checkpoint=args.steps_per_checkpoint, opts=opts, run_id=id)
    print('Running:', cmd)
    proc = subprocess.Popen(cmd, shell=True)

    # Poll the child once a minute; in between, service kill requests sent by
    # the logging thread when the metric heartbeat stalls.
    while proc.poll() is None:
        time.sleep(60)
        try:
            nq, *nargs = q.get_nowait()
            if nq == 'kill':
                print('train thread recieved kill signal from logging thread')
                # Polite SIGTERM first, then a minute of grace before SIGKILL.
                proc.terminate()
                time.sleep(60)
                if proc.poll() is None:
                    print('SIGTERM not successful, sending SIGKILL')
                    proc.kill()
        except queue.Empty:
            pass

    print('exited training!')
    if proc.returncode == 0:
        print('exited gracefully')
        # Training is done; interrupt the main process so atexit handlers and
        # sacred finish up cleanly.
        os.kill(os.getpid(), signal.SIGINT)
        return

    if args.no_delete_tpu:
        print('recreate done, exiting train_thread - not killing tpu!')
        return
    print("Recreating {} in 60sec...".format(tpu))
    time.sleep(60)
    # `pu` is an external TPU management tool; this call blocks (retrying for
    # up to an hour) until the TPU has been recreated.
    os.system("pu recreate {} --yes --retry 3600 --retry-randomness 1.5".format(tpu))
    print('recreate done, exiting train_thread')

    # Drain any stale kill requests that were queued while the TPU was being
    # recreated so they cannot kill the next training attempt immediately.
    while True:
        try:
            q.get_nowait()
            print('dropped request in queue after pu recreate')
        except queue.Empty:
            break
|
|
|
|
|
def get_json(uri, params=None, timeout=15):
    """GET `uri` and return the decoded JSON body.

    Raises requests.HTTPError on a non-2xx response and requests' usual
    timeout/connection errors on network failure.
    """
    response = requests.get(uri, params=params, timeout=timeout)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
def get_tag_sets(base_uri):
    """Return a mapping of tensorboard run name -> view of its scalar tag names.

    Queries the tensorboard scalars plugin's `tags` endpoint at `base_uri`.
    """
    tags_by_run = get_json(f'{base_uri}/data/plugin/scalars/tags', {'experiment': ''})
    assert isinstance(tags_by_run, dict)
    return {run: tag_dict.keys() for run, tag_dict in tags_by_run.items()}
|
|
|
|
|
def get_scalar_data(base_uri, run, tag):
    """Fetch all datapoints for one scalar `tag` of one tensorboard `run`.

    Returns the scalars plugin's list of [wall_time, step, value] triples.
    """
    query = {'experiment': '', 'run': run, 'tag': tag}
    data = get_json(f'{base_uri}/data/plugin/scalars/scalars', query)
    assert isinstance(data, list)
    return data
|
|
|
|
|
def get_run_data(port):
    """Scrape the current metrics from the local tensorboard instance.

    Args:
        port: port the tensorboard HTTP server is listening on (localhost).

    Returns:
        Dict that may contain 'loss', 'val_loss', 'lambada_acc' and
        'lambada_ppl' keys, each a list of [wall_time, step, value] triples.
        Returns whatever was collected so far (possibly empty) if scraping
        fails — tensorboard may simply not be up yet.
    """
    base_uri = f'http://localhost:{port}/'
    r = {}
    try:
        tag_sets = get_tag_sets(base_uri)
        runs = tag_sets.keys()
        if '.' in runs and 'loss' in tag_sets['.']:
            r['loss'] = get_scalar_data(base_uri, '.', 'loss')
        if 'eval' in runs and 'loss' in tag_sets['eval']:
            r['val_loss'] = get_scalar_data(base_uri, 'eval', 'loss')
        if 'eval_lambada' in runs:
            if 'lambada_acc' in tag_sets['eval_lambada']:
                r['lambada_acc'] = get_scalar_data(base_uri, 'eval_lambada', 'lambada_acc')
            if 'lambada_log_ppl' in tag_sets['eval_lambada']:
                # tensorboard stores log-perplexity; exponentiate to get ppl.
                r['lambada_ppl'] = [
                    [t, s, math.exp(lp)]
                    for [t, s, lp] in get_scalar_data(base_uri, 'eval_lambada', 'lambada_log_ppl')
                ]
    except Exception:
        # Best-effort scrape. Was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit and could mask an attempt to stop
        # the orchestrator; narrow it to Exception.
        import traceback
        traceback.print_exc()
    return r
|
|
|
|
|
@ex.main
def main(_run):
    """Sacred entry point: launch training and mirror its metrics into sacred.

    Spawns train_thread to run main.py on the TPU and a tensorboard instance
    in a detached screen session, then loops forever: scraping tensorboard
    scalars, prediction files and few-shot eval results into the sacred run,
    and requesting a kill + TPU recreate + relaunch whenever the loss
    heartbeat stalls for longer than the configured timeout.
    """
    print('Starting run', _run._id)
    print('experiment main invoked with argv:', " ".join(sys.argv))
    print('WARNING: please remember to remove old metric log files from the model directory.')

    # Snapshot the model config under a run-id-specific name so concurrent
    # runs don't clobber each other's configs.
    os.makedirs('run_configs', exist_ok=True)
    shutil.copy(args.model if args.model.endswith('.json') else 'configs/{}.json'.format(args.model), 'run_configs/config_{}.json'.format(_run._id))

    # Start tensorboard in a detached screen; the `||` fallback retries
    # without --bind_all in case the installed tensorboard doesn't support it.
    tensorboard_port = get_open_port()
    print('Tensorboard at port:', tensorboard_port)
    print('Tensorboard url: ', 'http://eleutherai.bmk.sh:'+ str(tensorboard_port))
    os.system("screen -S tensorboard_{} -d -m bash -c 'tensorboard --logdir {} --port {} --bind_all --reload_multifile=true || tensorboard --logdir {} --port {} --reload_multifile=true'".format(_run._id, params["model_path"], tensorboard_port,params["model_path"], tensorboard_port,))
    atexit.register(goodbye, _run._id)  # kill the tensorboard screen on exit

    curr_step = {}            # metric name -> last step already logged to sacred
    seen_predictions = set()  # prediction files already attached as artifacts

    # Extra-generous first window (2x the configured initial timeout) since
    # the first scalar can take a long time to appear after launch.
    heartbeat_timeout = args.initial_heartbeat_timeout * 2
    while True:  # one iteration per (re)launch of the training process
        last_tb_log_time = time.time()
        start_time = time.time()
        q = queue.Queue()
        trainthd = threading.Thread(target=train_thread, args=(args, args.tpu, _run._id, q))
        trainthd.start()

        while trainthd.is_alive():
            time.sleep(60)

            # After the startup grace period, tighten to the steady-state timeout.
            if start_time + args.initial_heartbeat_timeout < time.time():
                heartbeat_timeout = args.heartbeat_timeout

            print('Polling tensorboard for metrics...')
            data = get_run_data(tensorboard_port)
            for k in data.keys():
                for ts, step, val in data[k]:
                    if step <= curr_step.get(k, -1):
                        continue  # already logged this step for this metric
                    _run.log_scalar(k, val, step)
                    if k == 'loss':
                        _run.log_scalar('tb_ts', ts, step)
                        print('Logged to sacred: step={},loss={},tb_ts={}'.format(step, val, ts))

                        # Only fresh *loss* datapoints count as a heartbeat.
                        last_tb_log_time = time.time()

                    curr_step[k] = step

            # Attach any new prediction output files as sacred artifacts.
            for f in glob.glob('predictions_{}_*'.format(_run._id)):
                if f in seen_predictions:
                    continue
                print('collecting prediction file', f)
                ex.add_artifact(f)

                seen_predictions.add(f)

            # Mirror few-shot eval results (one JSON object per line) into
            # sacred as 'fs.<task>.<metric>' scalars, deduplicated by step.
            if os.path.exists(f'eval_{_run._id}.jsonl'):
                with open(f'eval_{_run._id}.jsonl') as fh:
                    for line in fh:
                        ob = json.loads(line)
                        val_step = ob['global_step']
                        val_task = ob['task']
                        for metr in ob.keys():
                            k = 'fs.' + val_task + '.' + metr
                            if metr in ['task', 'global_step']: continue
                            if val_step <= curr_step.get(k, -1): continue
                            _run.log_scalar(k, ob[metr], val_step)
                            curr_step[k] = val_step

            # Heartbeat stalled: ask train_thread to kill the training process
            # (and recreate the TPU), then wait for it to finish doing so.
            if time.time() - last_tb_log_time > heartbeat_timeout:
                q.put(('kill',))

                while trainthd.is_alive():
                    print('logging thread waiting for killing stalled run and for tpu recreate to finish')
                    time.sleep(60)

                # Give the relaunched run a fresh startup grace period.
                heartbeat_timeout = args.initial_heartbeat_timeout
                last_tb_log_time = time.time()

        if args.no_delete_tpu:
            break  # train_thread exited without recreating the TPU; we're done
|
|
|
|
|
def goodbye(id):
    """atexit hook: announce shutdown and kill this run's tensorboard screen.

    Args:
        id: sacred run id used to name the screen session at startup.
    """
    farewells = (
        "You are now leaving the Python sector.",
        "Sie verlassen den pythonischen Sektor.",
    )
    for line in farewells:
        print(line)

    os.system("screen -S tensorboard_{} -X quit".format(id))
|
|
|
|
|
if __name__ == '__main__':
    # Archive every python source file in the tree with the sacred run so the
    # exact code that produced a result is recorded alongside it.
    for source_path in glob.glob("**/*", recursive=True):
        if source_path.split('.')[-1] == 'py':
            print('Adding', source_path, 'to sacred')
            ex.add_source_file(source_path)

    # Record the TPU name plus the full model config; params values override
    # 'tpu_name' on key collision, matching dict-literal merge order.
    run_config = {'tpu_name': args.tpu}
    run_config.update(params)
    ex.add_config(run_config)

    ex.run()
|
|