File size: 3,548 Bytes
1c300e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Script to run sagemaker training jobs for whisper finetuning jobs."""

import logging
import os
from pprint import pprint

import boto3
import sagemaker
from sagemaker.huggingface import HuggingFace


# Switch between the test configuration and the full configuration below.
TEST = True


# Per-instance-type settings, keyed by SageMaker instance type name.
test_sm_instances = {
    "ml.g4dn.2xlarge": {"num_instances": 1, "num_gpus": 1},
}

full_sm_instances = {
    "ml.g4dn.xlarge": {"num_instances": 1, "num_gpus": 1},
}

ENTRY_POINT = "run_speech_recognition_seq2seq_streaming.py"

# Both the instance table and the run script follow the TEST switch.
if TEST:
    sm_instances = test_sm_instances
    RUN_SCRIPT = "test_run.sh"
else:
    sm_instances = full_sm_instances
    RUN_SCRIPT = "run.sh"

IMAGE_URI = "116817510867.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:whisper-finetuning-0223e276db78adf4ea4dc5f874793cb2"
# Guard for the case where the constant above is edited to None.
if IMAGE_URI is None:
    raise ValueError("IMAGE_URI variable not set, please update script.")

# The IAM client is created first; the default region is exported afterwards
# and is therefore in place for the SageMaker session and client below.
iam = boto3.client("iam")
os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"
# Resolve the execution role's ARN from its name.
role = iam.get_role(RoleName="whisper-sagemaker-role")["Role"]["Arn"]
_ = sagemaker.Session()  # not sure if this is necessary
# This client's exception classes are used when launching the training job.
sm_client = boto3.client("sagemaker")


def set_creds(creds_path="creds.txt"):
    """Load ``KEY=value`` pairs from a credentials file into ``os.environ``.

    Blank lines and lines whose first non-blank character is ``#`` are
    ignored. Only the first ``=`` on a line separates the key from the value,
    so values may themselves contain ``=`` (e.g. base64-padded tokens).

    Args:
        creds_path: Path of the credentials file. Defaults to ``creds.txt``
            in the current working directory.

    Raises:
        ValueError: If a non-blank, non-comment line has no ``=`` or has an
            empty key.
        OSError: If the file cannot be opened.
    """
    with open(creds_path) as f:
        for line_number, raw_line in enumerate(f, start=1):
            entry = raw_line.rstrip("\r\n")
            stripped = entry.strip()
            if not stripped or stripped.startswith("#"):
                continue
            key, separator, value = entry.partition("=")
            key = key.strip()
            if not separator or not key:
                # The line content is deliberately left out of the message:
                # it may contain a secret.
                raise ValueError(
                    f"Malformed line {line_number} in {creds_path}: "
                    "expected KEY=value"
                )
            os.environ[key] = value


def parse_run_script(run_script=None):
    """Parse the run script to get the hyperparameters.

    Every line except those starting with ``python`` is treated as a single
    ``--key=value`` or ``--flag`` argument. Backslashes, tabs, double quotes
    and surrounding whitespace are removed before parsing. Blank lines and
    lines starting with ``#`` are skipped. Only the leading ``--`` and the
    first ``=`` are treated as syntax, so values may contain ``--`` or ``=``.

    Args:
        run_script: Path of the shell script to parse. Defaults to the
            module-level ``RUN_SCRIPT``.

    Returns:
        A dict mapping argument names (without the leading ``--``) to string
        values. Flags given without a value map to ``"True"``. When present,
        ``model_index_name`` is wrapped in double quotes again.

    Raises:
        ValueError: If an argument name is empty or 256 or more characters
            long.
    """
    script_path = RUN_SCRIPT if run_script is None else run_script
    hyperparameters = {}
    with open(script_path, "r") as f:
        for raw_line in f:
            # The command line itself carries no hyperparameter.
            if raw_line.startswith("python"):
                continue
            argument = (
                raw_line.replace("\\", "")
                .replace("\t", "")
                .replace('"', "")
                .strip()
            )
            if not argument or argument.startswith("#"):
                continue
            # Strip only the option prefix, never a "--" inside the value.
            if argument.startswith("--"):
                argument = argument[2:]
            key, separator, value = argument.partition("=")
            if not 0 < len(key) < 256:
                raise ValueError(
                    f"Key {key} is not allowed, len must be between 0 and 256"
                )
            # A bare flag has no "=", so it is recorded as enabled.
            hyperparameters[key] = value if separator else "True"
    if "model_index_name" in hyperparameters:
        # Quotes were stripped above; restore them for this one value
        # (presumably because the name may contain spaces - TODO confirm).
        index_name = hyperparameters["model_index_name"]
        hyperparameters["model_index_name"] = f'"{index_name}"'
    return hyperparameters


# Load credentials into the environment, then read the training arguments
# out of the run script.
set_creds()
hyperparameters = parse_run_script()
pprint(hyperparameters)

hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable not set")

# Forward only variables that are actually set: the training-job environment
# is a string-to-string map, so a None value would be rejected at launch.
env_vars = {
    name: value
    for name, value in (
        ("HF_TOKEN", hf_token),
        ("EMAIL_ADDRESS", os.environ.get("EMAIL_ADDRESS")),
        ("EMAIL_PASSWORD", os.environ.get("EMAIL_PASSWORD")),
        ("WANDB_TOKEN", os.environ.get("WANDB_TOKEN")),
    )
    if value is not None
}
# Show which variables are forwarded without echoing the secrets themselves.
pprint({name: "<redacted>" for name in env_vars})

# Try each configured instance type in turn and stop at the first one whose
# training job runs.
for sm_instance_name, sm_instance_values in sm_instances.items():
    num_instances: int = int(sm_instance_values["num_instances"])
    # num_gpus is not read anywhere in the visible code.
    num_gpus: int = int(sm_instance_values["num_gpus"])
    try:
        # instantiate and fit the sm Estimator
        hf_estimator = HuggingFace(
            entry_point=ENTRY_POINT,
            instance_type=sm_instance_name,
            instance_count=num_instances,
            role=role,
            py_version="py38",
            image_uri=IMAGE_URI,
            hyperparameters=hyperparameters,
            environment=env_vars,
        )
        hf_estimator.fit()
        break
    # NOTE(review): fit() goes through the estimator's own SageMaker session;
    # confirm its ResourceLimitExceeded is the same class as sm_client's,
    # otherwise this handler never fires.
    except sm_client.exceptions.ResourceLimitExceeded as e_0:
        logging.warning("Instance error %s\nRetrying with new instance", e_0)
else:
    # Reached only when no iteration hit `break`, i.e. no job was launched.
    raise RuntimeError(
        "No training job was started: every configured instance type hit a "
        "SageMaker resource limit."
    )