"""Script to run sagemaker training jobs for whisper finetuning jobs.

Each configured instance type is tried in order. The script stops at the
first training job that runs. If SageMaker rejects an instance type with
``ResourceLimitExceeded``, the next type is tried. If none succeed, the
script raises instead of exiting silently.
"""
import logging
import os
from pprint import pprint

import boto3
import sagemaker
from sagemaker.huggingface import HuggingFace

# When True, use the test instance list and the test run script.
TEST = True

# Instance types to try, in insertion order. Only "num_instances" is read
# below; "num_gpus" is informational.
test_sm_instances = {
    "ml.g4dn.xlarge": {
        "num_instances": 1,
        "num_gpus": 1,
    },
}
full_sm_instances = {
    "ml.g4dn.xlarge": {
        "num_instances": 1,
        "num_gpus": 1,
    },
}
sm_instances = test_sm_instances if TEST else full_sm_instances

ENTRY_POINT = "run_sm.py"
RUN_SCRIPT = "test_run.sh" if TEST else "run.sh"
IMAGE_URI = "116817510867.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:whisper-finetuning-0223e276db78adf4ea4dc5f874793cb2"
# Guards against the URI being blanked out when the script is templated.
if not IMAGE_URI:
    raise ValueError("IMAGE_URI variable not set, please update script.")

# Set the region before creating any AWS client/session so all of them use it.
os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"
iam = boto3.client("iam")
role = iam.get_role(RoleName="whisper-sagemaker-role")["Role"]["Arn"]

# One SageMaker session is shared by the estimator and by the exception lookup
# in the retry loop. That way the caught ``ResourceLimitExceeded`` class comes
# from the same client that issues the CreateTrainingJob call.
sagemaker_session = sagemaker.Session()
sm_client = sagemaker_session.sagemaker_client


def set_creds(path="creds.txt"):
    """Load ``KEY=VALUE`` lines from *path* into ``os.environ``.

    Blank lines and lines starting with ``#`` are skipped. Only the first
    ``=`` separates key from value, so values may themselves contain ``=``
    (common in tokens and passwords). Surrounding whitespace, including
    carriage returns from CRLF files, is stripped from keys and values.

    Raises:
        ValueError: if a non-blank, non-comment line has no ``=``. The
            offending line is not echoed, because it may hold a secret.
    """
    with open(path, encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            key, sep, value = line.partition("=")
            if not sep:
                raise ValueError(
                    f"Malformed line {line_no} in {path!r}: expected KEY=VALUE"
                )
            os.environ[key.strip()] = value.strip()


def parse_run_script(run_script=RUN_SCRIPT):
    """Parse the run script to get the hyperparameters.

    The line starting with ``python`` (the launch command) is skipped. Every
    other non-blank line is treated as one ``--key=value`` flag, after
    removing backslash continuations, double quotes and the leading ``--``.
    Only the first ``=`` splits key from value. A flag without ``=`` is
    recorded as ``"True"``.

    If present, ``model_index_name`` is re-wrapped in double quotes. This is
    presumably so it survives shell word-splitting downstream (TODO confirm).
    """
    hyperparameters = {}
    with open(run_script, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.replace("\\", "").replace('"', "").strip()
            if not line or line.startswith("python"):
                continue
            if line.startswith("--"):
                line = line[2:]
            key, sep, value = line.partition("=")
            hyperparameters[key.strip()] = value.strip() if sep else "True"
    if "model_index_name" in hyperparameters:
        hyperparameters["model_index_name"] = f'"{hyperparameters["model_index_name"]}"'
    return hyperparameters


set_creds()
# parse_run_script() is kept for ad-hoc use. The job below forwards only the
# repo URL and the run script name as hyperparameters.

hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable not set")

# SageMaker's Environment map only accepts string values, so optional
# variables are forwarded only when they are actually set.
env_vars = {"HF_TOKEN": hf_token}
for optional_key in ("EMAIL_ADDRESS", "EMAIL_PASSWORD", "WANDB_TOKEN"):
    optional_value = os.environ.get(optional_key)
    if optional_value is None:
        logging.warning("%s not set; not forwarding it to the training job", optional_key)
    else:
        env_vars[optional_key] = optional_value
# Show which variables are forwarded without leaking their secret values.
pprint({key: "<redacted>" for key in env_vars})

repo = f"https://huggingface.co/marinone94/{os.path.basename(os.getcwd())}"
hyperparameters = {
    "repo": repo,
    "entrypoint": RUN_SCRIPT,
}

for sm_instance_name, sm_instance_values in sm_instances.items():
    num_instances: int = int(sm_instance_values["num_instances"])
    # Building the estimator makes no API call; only fit() can hit the limit.
    hf_estimator = HuggingFace(
        entry_point=ENTRY_POINT,
        instance_type=sm_instance_name,
        instance_count=num_instances,
        role=role,
        py_version="py38",
        image_uri=IMAGE_URI,
        hyperparameters=hyperparameters,
        environment=env_vars,
        git_config={"repo": repo, "branch": "main"},
        sagemaker_session=sagemaker_session,
    )
    try:
        hf_estimator.fit()
    except sm_client.exceptions.ResourceLimitExceeded as e_0:
        logging.warning("Instance error %s\nRetrying with new instance", e_0)
    else:
        break
else:
    # Reached only if no instance type produced a training job.
    raise RuntimeError(
        "No SageMaker instance type could be used: every candidate hit "
        "ResourceLimitExceeded (or the instance list is empty)."
    )