|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import logging
import os
import shlex
import subprocess as sp
import sys
from functools import partial

from demo_utils import download_model_folder
|
|
|
|
|
# All folders are derived from this script's own (symlink-resolved) location,
# so paths built from them do not depend on the current working directory.
PROJECT_FOLDER = os.path.dirname(os.path.realpath(__file__))

# Launch helper scripts with the interpreter that is running this file, so
# subprocesses see the same environment / installed packages. A bare 'python'
# may resolve to a different interpreter (or none) on PATH. Fall back to it
# only when sys.executable is unavailable (empty string or None).
PYTHON_EXE = sys.executable or 'python'

MODEL_FOLDER = os.path.join(PROJECT_FOLDER, 'models')  # downloaded model + training output
DATA_FOLDER = os.path.join(PROJECT_FOLDER, 'data')  # training / evaluation data

print(f'PROJECT_FOLDER = {PROJECT_FOLDER}')
|
|
|
# Command-line options. `--data` selects which dataset variant is prepared
# below. `choices` makes argparse reject anything else with a usage error.
# The previous `assert` check was skipped entirely under `python -O`.
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='dummy',
                    choices=('dummy', 'small', 'full'),
                    help='choose from dummy, small and full')
dargs = parser.parse_args()
|
|
|
|
|
# Configure root logging once for the whole script: INFO level with a
# timestamp / level / logger-name prefix on every record.
_LOGGING_OPTIONS = dict(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logging.basicConfig(**_LOGGING_OPTIONS)

# Module logger used for the step-by-step progress messages.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Make sure the models directory exists, reusing one left by a previous run.
# The notice is informational only. makedirs(exist_ok=True) does the real
# work in both cases and avoids a check-then-create race. It still raises
# FileExistsError if MODEL_FOLDER exists but is not a directory.
if os.path.isdir(MODEL_FOLDER):
    print(f'Found existing models folder at {MODEL_FOLDER}, skip creating a new one!')
os.makedirs(MODEL_FOLDER, exist_ok=True)
|
|
|
|
|
|
|
|
|
logger.info('Downloading models...')

# Pin the destination once. download_model_folder (from demo_utils) receives
# MODEL_FOLDER through its DATA_FOLDER keyword.
download_model = partial(download_model_folder, DATA_FOLDER=MODEL_FOLDER)

# Which pretrained checkpoint to fetch.
_pretrained_spec = {'model_size': 'small', 'dataset': 'multiref', 'from_scratch': False}

# Path returned by demo_utils: presumably the local directory holding the
# downloaded model files (TODO confirm against download_model_folder).
target_folder = download_model(**_pretrained_spec)

logger.info('Done!\n')
|
|
|
|
|
|
|
|
|
|
|
logger.info('Downloading and Extracting Data...')

if dargs.data == 'dummy':
    # prepare4db.sh lives in (and is run from) the data folder. Its stdout and
    # stderr are captured together and shown only if it fails.
    ret = sp.run(['bash', 'prepare4db.sh'],
                 stdout=sp.PIPE, stderr=sp.STDOUT, cwd=DATA_FOLDER)
elif dargs.data in ('small', 'full'):
    # Build the reddit extractor with SIZE set in the child's environment.
    # This replaces the shell string
    # `cd reddit_extractor; SIZE=... make -j 8; cd ..`, whose `cd` was
    # relative to the caller's cwd and whose `;` let make run in the wrong
    # directory when the cd failed.
    # stdout is captured (as os.popen did) and stderr goes to the terminal.
    ret = sp.run(['make', '-j', '8'],
                 stdout=sp.PIPE,
                 cwd=os.path.join(PROJECT_FOLDER, 'reddit_extractor'),
                 env={**os.environ, 'SIZE': dargs.data})
else:
    raise ValueError('you need to implement your own data type, or use either dummy, small, or full')

# Previously a failed download/extraction went unnoticed until a later step.
# Stop here, reporting the same way the preprocessing step does.
if ret.returncode != 0:
    print(f'error occurred, {ret.stdout.decode(errors="replace")}')
    sys.exit(ret.returncode)
|
|
|
logger.info('Preparing Data...')

data_path = os.path.join(DATA_FOLDER, 'train.tsv')
MAX_LEN = 128  # passed to prepro.py as --max_seq_len and embedded in the db name

# e.g. <DATA_FOLDER>/train.128len.db -- the output location checked before
# running prepro.py with the corpus path above.
data_db = f'{os.path.splitext(data_path)[0]}.{MAX_LEN}len.db'

if os.path.isdir(data_db):
    print(f'{data_db} exists, skip prepro.py')
else:
    cmd = ['prepro.py', '--corpus', data_path, '--max_seq_len', f'{MAX_LEN}']
    print(' '.join(cmd))
    # Pass the argv list straight through. Joining and re-splitting on ' '
    # broke any path containing a space (e.g. in PROJECT_FOLDER).
    ret = sp.run([PYTHON_EXE] + cmd, stdout=sp.PIPE, stderr=sp.STDOUT, cwd=PROJECT_FOLDER)
    if ret.returncode != 0:
        # Decode so the captured output is readable, not a b'...' repr.
        print(f'error occurred, {ret.stdout.decode(errors="replace")}')
        sys.exit(ret.returncode)

logger.info('Done!\n')
|
|
|
|
|
|
|
|
|
logger.info('Generating training CMD!')
logger.info('If there is any problem, please copy (modify) and run command below')
logger.info('#########################################################################')

train_cmd = 'LSP_train.py'
args = [
    '--model_name_or_path', target_folder,
    '--init_checkpoint', os.path.join(target_folder, 'pytorch_model.bin'),
    '--train_input_file', data_db,
    '--eval_input_file', './data/dummy_data.tsv',  # relative to cwd=PROJECT_FOLDER
    '--output_dir', os.path.join(MODEL_FOLDER, 'output_model'),
    '--seed', '42',
    # Keep in sync with the length data_db was preprocessed with.
    '--max_seq_length', str(MAX_LEN),
    '--train_batch_size', '512',
    '--gradient_accumulation_steps', '8',
    '--eval_batch_size', '64',
    '--learning_rate', '1e-5',
    '--num_optim_steps', '10000',
    '--valid_step', '5000',
    '--warmup_steps', '4000',
    '--normalize_data', 'true',
    '--fp16', 'true',
    '--lr_schedule', 'noam',
    '--loss_scale', '0.0',
    '--no_token_id', 'true',
    '--pbar', 'true',
]

# Keep argv as a list. Joining and re-splitting on ' ' broke paths containing
# spaces. Shell-quote only the printed copy-paste version.
full_cmd = [PYTHON_EXE, train_cmd] + args
print(' '.join(shlex.quote(token) for token in full_cmd))

logger.info('#########################################################################')

# Tee the trainer's combined stdout/stderr to the console and to ./output.log.
# The Popen context manager closes the pipe and waits for the child, so
# process.returncode is populated once the block exits.
with open('./output.log', 'wb') as f:
    with sp.Popen(full_cmd, stdout=sp.PIPE, stderr=sp.STDOUT, cwd=PROJECT_FOLDER) as process:
        for line in iter(process.stdout.readline, b''):
            # errors='replace' keeps a stray undecodable byte from killing the
            # log stream. sys.stdout.encoding can be None when stdout is replaced.
            sys.stdout.write(line.decode(sys.stdout.encoding or 'utf-8', errors='replace'))
            f.write(line)

# Previously "Done!" was logged even when training crashed.
if process.returncode != 0:
    logger.error(f'Training exited with code {process.returncode}; see ./output.log')
    sys.exit(process.returncode)

logger.info('Done!\n')
|
|