chriskasparaws commited on
Commit
c5aa5a6
·
verified ·
1 Parent(s): 8988d02

Upload 10 files

Browse files
assets.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"class_name": "MultiModalPredictor", "column_types": {"image": "image_path", "label": "categorical"}, "label_column": "label", "problem_type": "multiclass", "eval_metric_name": "accuracy", "validation_metric_name": "accuracy", "output_shape": 4, "classes": null, "save_path": "/opt/ml/model", "pretrained_path": null, "version": "0.6.1"}
code/constants.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################################
2
+ # ENVIRONMENT VARIABLES #
3
+ ########################################################################################################################
4
+ NUM_GPU = "NUM_GPU"
5
+ SAGEMAKER_INFERENCE_OUTPUT = "SAGEMAKER_INFERENCE_OUTPUT"
6
+
7
+ ########################################################################################################################
8
+ # OUTPUT CONSTANTS #
9
+ ########################################################################################################################
10
+ PROBABILITY = "probability"
11
+ PROBABILITIES = "probabilities"
12
+ PREDICTED_LABEL = "predicted_label"
13
+ LABELS = "labels"
14
+
15
+ ########################################################################################################################
16
+ # DATA FORMAT CONSTANTS #
17
+ ########################################################################################################################
18
+ BYTE_ARRAY_FORMAT = "application/x-image"
19
+ JPEG_FORMAT = "image/jpeg"
20
+ PNG_FORMAT = "image/png"
21
+ JSON_FORMAT = "application/json"
22
+ CSV_FORMAT = "text/csv"
23
+ COMMA_DELIMITER = ","
24
+ BRACKET_FORMATTER = '"{}"'
25
+ ALLOWED_INPUT_FORMATS = [BYTE_ARRAY_FORMAT, JPEG_FORMAT, PNG_FORMAT]
26
+ ALLOWED_OUTPUT_FORMATS = [JSON_FORMAT, CSV_FORMAT]
27
+
28
+ ########################################################################################################################
29
+ # INFERENCE DATA CONSTANTS #
30
+ ########################################################################################################################
31
+ IMAGE_COLUMN_NAME = "image"
code/multimodal_serve.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import json
3
+ import os
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from autogluon.multimodal import MultiModalPredictor
8
+ from constants import (
9
+ ALLOWED_INPUT_FORMATS,
10
+ ALLOWED_OUTPUT_FORMATS,
11
+ BRACKET_FORMATTER,
12
+ COMMA_DELIMITER,
13
+ IMAGE_COLUMN_NAME,
14
+ JSON_FORMAT,
15
+ LABELS,
16
+ NUM_GPU,
17
+ PREDICTED_LABEL,
18
+ PROBABILITIES,
19
+ PROBABILITY,
20
+ SAGEMAKER_INFERENCE_OUTPUT,
21
+ )
22
+ from utils import infer_type_and_cast_value
23
+
24
+ INFERENCE_OUTPUT = (
25
+ infer_type_and_cast_value(os.getenv(SAGEMAKER_INFERENCE_OUTPUT))
26
+ if SAGEMAKER_INFERENCE_OUTPUT in os.environ
27
+ else [PREDICTED_LABEL]
28
+ )
29
+ NUM_GPUS = infer_type_and_cast_value(os.getenv(NUM_GPU))
30
+
31
+
32
+ def generate_single_csv_line_inference_selection(data):
33
+ """Generate a single csv line response.
34
+
35
+ :param data: list of output generated from the model
36
+ :return: csv line for the predictions
37
+ """
38
+ contents: str
39
+ for single_prediction in data:
40
+ contents = (
41
+ BRACKET_FORMATTER.format(single_prediction)
42
+ if isinstance(single_prediction, list)
43
+ else str(single_prediction)
44
+ )
45
+ return contents
46
+
47
+
48
+ def model_fn(model_dir):
49
+ """Load model from previously saved artifact.
50
+
51
+ :param model_dir: local path to the model directory
52
+ :return: loaded model
53
+ """
54
+ predictor = MultiModalPredictor.load(model_dir)
55
+ if NUM_GPUS is not None:
56
+ predictor._config.env.num_gpus = NUM_GPUS
57
+
58
+ return predictor
59
+
60
+
61
+ def convert_to_json_compatible_type(value):
62
+ """Convert the input value to a JSON compatible type.
63
+
64
+ :param value: input value
65
+ :return: JSON compatible value
66
+ """
67
+ string_value = "{}".format(value)
68
+ try:
69
+ return ast.literal_eval(string_value)
70
+ except Exception:
71
+ return string_value
72
+
73
+
74
+ def transform_fn(model, request_body, input_content_type, output_content_type):
75
+ """Transform function for serving inference requests.
76
+
77
+ If INFERENCE_OUTPUT is provided, then the predictions are generated in the requested format and concatenated in the
78
+ same order. Otherwise, prediction_labels are generated by default.
79
+
80
+ :param model: loaded model
81
+ :param request_body: request body
82
+ :param input_content_type: content type of the input
83
+ :param output_content_type: content type of the response
84
+ :return: prediction response
85
+ """
86
+ if input_content_type.lower() not in ALLOWED_INPUT_FORMATS:
87
+ raise Exception(
88
+ f"{input_content_type} input content type not supported. Supported formats are {ALLOWED_INPUT_FORMATS}"
89
+ )
90
+
91
+ if output_content_type.lower() not in ALLOWED_OUTPUT_FORMATS:
92
+ raise Exception(
93
+ f"{output_content_type} output content type not supported. Supported formats are {ALLOWED_OUTPUT_FORMATS}"
94
+ )
95
+
96
+ data = pd.DataFrame({IMAGE_COLUMN_NAME: [request_body]})
97
+ result_dict = dict()
98
+
99
+ result = []
100
+ inference_output_list = (
101
+ INFERENCE_OUTPUT if isinstance(INFERENCE_OUTPUT, list) else [INFERENCE_OUTPUT]
102
+ )
103
+ for output_type in inference_output_list:
104
+ if output_type == PREDICTED_LABEL:
105
+ prediction = model.predict(data)
106
+ result_dict[PREDICTED_LABEL] = convert_to_json_compatible_type(prediction.squeeze())
107
+ elif output_type == PROBABILITIES:
108
+ predict_probs = model.predict_proba(data)
109
+ prediction = predict_probs.to_numpy()
110
+ result_dict[PROBABILITIES] = predict_probs.squeeze().tolist()
111
+ elif output_type == LABELS:
112
+ labels = model.class_labels
113
+ prediction = np.array([labels]).astype("str")
114
+ result_dict[LABELS] = labels.tolist()
115
+ else:
116
+ predict_probabilities = model.predict_proba(data).to_numpy()
117
+ prediction = np.max(predict_probabilities, axis=1)
118
+ result_dict[PROBABILITY] = prediction.squeeze().tolist()
119
+ result.append(generate_single_csv_line_inference_selection(prediction.tolist()))
120
+ response = COMMA_DELIMITER.join(result)
121
+
122
+ if output_content_type == JSON_FORMAT:
123
+ response = json.dumps(result_dict)
124
+
125
+ return response, output_content_type
code/utils.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Optional
4
+
5
+
6
+ def is_float(value: str):
7
+ """Check if the input value is float.
8
+
9
+ :param value: value
10
+ :return: True / False based on whether the value is float or not
11
+ """
12
+ try:
13
+ float(value)
14
+ except (TypeError, ValueError):
15
+ return False
16
+ else:
17
+ return True
18
+
19
+
20
+ def is_int(value: str):
21
+ """Check if the input value is int.
22
+
23
+ :param value: value
24
+ :return: True / False based on whether the value is int or not
25
+ """
26
+ try:
27
+ float_value = float(value)
28
+ int_value = int(value)
29
+ except (TypeError, ValueError):
30
+ return False
31
+ else:
32
+ return float_value == int_value
33
+
34
+
35
+ def is_list(value: str):
36
+ """Check if the input value is list.
37
+
38
+ Currently, we support list in the following format -
39
+
40
+ 1. "a,b,c,d"
41
+ 2. "1,2,3,4"
42
+
43
+ :param value: value
44
+ :return: True / False based on whether the value is list or not
45
+ """
46
+ return "," in value
47
+
48
+
49
+ def is_boolean(value: str):
50
+ """Check if the input value is boolean.
51
+
52
+ :param value: value
53
+ :return: True / False based on whether the value is boolean or not
54
+ """
55
+ return value.lower() in ["true", "false"]
56
+
57
+
58
+ def parse_boolean(value: str):
59
+ """Parse the boolean value.
60
+
61
+ :param value: value
62
+ :return: Parsed boolean values
63
+ """
64
+ return True if value.lower() == "true" else False
65
+
66
+
67
+ def parse_list(value: str):
68
+ """Parse the list value.
69
+
70
+ Currently, we support list in the following format -
71
+
72
+ 1. "a,b,c,d"
73
+ 2. "1,2,3,4"
74
+
75
+ :param value: value
76
+ :return: Parsed list
77
+ """
78
+ values = value.split(",")
79
+ clean_values = [v.strip(" \"'") for v in values]
80
+ return [infer_type_and_cast_value(v) for v in clean_values]
81
+
82
+
83
+ def infer_type_and_cast_value(value: Optional[str]):
84
+ """Infer the type of value and casts it accordingly.
85
+
86
+ :param value: value
87
+ :return: casted value
88
+ """
89
+ if value is None:
90
+ return value
91
+ elif is_int(value):
92
+ return int(value)
93
+ elif is_float(value):
94
+ return float(value)
95
+ elif is_boolean(value):
96
+ return parse_boolean(value)
97
+ elif is_list(value):
98
+ return parse_list(value)
99
+ else:
100
+ return value
101
+
102
+
103
+ def __setup_fault_handler(file_path: str = None):
104
+ """Set up fault handler.
105
+
106
+ :param file_path: path to the error file
107
+ :return:
108
+ """
109
+ try:
110
+ import faulthandler
111
+
112
+ if not faulthandler.is_enabled():
113
+ if file_path is not None:
114
+ faulthandler.enable(os.open(file_path, os.O_APPEND), all_threads=True)
115
+ else:
116
+ faulthandler.enable()
117
+ except ImportError:
118
+ logging.warn("No faulthandler found")
119
+
120
+
121
+ def get_error_logger():
122
+ """Return the logger from logging for id ERROR_LOGGER_ID ."""
123
+ return logging.getLogger("error")
124
+
125
+
126
+ def setup_trusted_log(error_volume: str, error_file_path: str):
127
+ """Set up trusted logs for the script.
128
+
129
+ :param error_volume: volume where the errors should be written
130
+ :param error_file_path: path to the error_file
131
+ :return: trusted logger
132
+ """
133
+ trusted_log_formatter = logging.Formatter(
134
+ "[%(asctime)s %(levelname)s %(thread)d %(filename)s:%(lineno)d] %(message)s",
135
+ datefmt="%m/%d/%Y %H:%M:%S",
136
+ )
137
+ os.makedirs(error_volume, exist_ok=True)
138
+ trusted_log_handler = logging.FileHandler(error_file_path)
139
+ __setup_fault_handler(file_path=error_file_path)
140
+ trusted_log_handler.setFormatter(trusted_log_formatter)
141
+ trusted_log_handler.setLevel(logging.INFO)
142
+
143
+ error_logger = get_error_logger()
144
+ error_logger.addHandler(trusted_log_handler)
145
+ error_logger.propagate = False
146
+
147
+
148
+ def write_trusted_log_info(private_info_message):
149
+ """Write private info message to the trusted log channel.
150
+
151
+ :param private_info_message: private trusted log message
152
+ :return:
153
+ """
154
+ trusted_logger = get_error_logger()
155
+ trusted_logger.info(private_info_message)
156
+
157
+
158
+ def write_failure_reason(failure_reason_text, file_path):
159
+ """Write failure reason to failure file.
160
+
161
+ :param failure_reason_text: reason for failure
162
+ :param file_path: path to the failure file
163
+ :return:
164
+ """
165
+ if not os.path.exists(os.path.dirname(file_path)):
166
+ os.makedirs(os.path.dirname(file_path))
167
+ with open(file_path, "w") as f:
168
+ f.write(failure_reason_text)
169
+
170
+
171
+ def write_trusted_log_exception(
172
+ error_message, caused_by, failure_file_path, failure_prefix="Algorithm Error"
173
+ ):
174
+ """Write private exception message to the trusted error channel.
175
+
176
+ :param error_message: error_message
177
+ :param caused_by: cause for the error
178
+ :param failure_file_path: failure file path. Usually /opt/ml/output/failure
179
+ :param failure_prefix: prefix to attach to the error message
180
+ :return:
181
+ """
182
+ message = "{}: {}".format(failure_prefix, error_message)
183
+ error_detail = "Caused by: {}".format(caused_by)
184
+ message += "\n\n{}".format(error_detail)
185
+ err_logger = get_error_logger()
186
+ err_logger.exception(message)
187
+ write_failure_reason(message, failure_file_path)
188
+ return message
config.yaml ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ names:
3
+ - timm_image
4
+ categorical_transformer:
5
+ out_features: 192
6
+ d_token: 192
7
+ ffn_d_hidden: 192
8
+ num_trans_blocks: 0
9
+ num_attn_heads: 8
10
+ residual_dropout: 0.0
11
+ attention_dropout: 0.2
12
+ ffn_dropout: 0.1
13
+ normalization: layer_norm
14
+ ffn_activation: reglu
15
+ head_activation: relu
16
+ data_types:
17
+ - categorical
18
+ additive_attention: false
19
+ share_qv_weights: false
20
+ numerical_transformer:
21
+ out_features: 192
22
+ d_token: 192
23
+ ffn_d_hidden: 192
24
+ num_trans_blocks: 0
25
+ num_attn_heads: 8
26
+ residual_dropout: 0.0
27
+ attention_dropout: 0.2
28
+ ffn_dropout: 0.1
29
+ normalization: layer_norm
30
+ ffn_activation: reglu
31
+ head_activation: relu
32
+ data_types:
33
+ - numerical
34
+ embedding_arch:
35
+ - linear
36
+ - relu
37
+ merge: concat
38
+ additive_attention: false
39
+ share_qv_weights: false
40
+ ner_text:
41
+ checkpoint_name: bert-base-cased
42
+ max_text_len: 512
43
+ gradient_checkpointing: false
44
+ low_cpu_mem_usage: false
45
+ data_types:
46
+ - text
47
+ tokenizer_name: hf_auto
48
+ insert_sep: false
49
+ text_segment_num: 2
50
+ stochastic_chunk: false
51
+ special_tags:
52
+ - X
53
+ - O
54
+ t_few:
55
+ checkpoint_name: t5-small
56
+ gradient_checkpointing: false
57
+ data_types:
58
+ - text
59
+ tokenizer_name: hf_auto
60
+ length_norm: 1.0
61
+ unlikely_loss: 1.0
62
+ mc_loss: 1.0
63
+ max_text_len: 512
64
+ text_segment_num: 2
65
+ insert_sep: true
66
+ low_cpu_mem_usage: false
67
+ stochastic_chunk: false
68
+ text_aug_detect_length: 10
69
+ text_trivial_aug_maxscale: 0.0
70
+ timm_image:
71
+ checkpoint_name: swin_base_patch4_window7_224
72
+ mix_choice: all_logits
73
+ data_types:
74
+ - image
75
+ train_transform_types:
76
+ - resize_shorter_side
77
+ - center_crop
78
+ - trivial_augment
79
+ val_transform_types:
80
+ - resize_shorter_side
81
+ - center_crop
82
+ image_norm: imagenet
83
+ image_size: 224
84
+ max_img_num_per_col: 2
85
+ mmdet_image:
86
+ checkpoint_name: yolov3_mobilenetv2_320_300e_coco
87
+ data_types:
88
+ - image
89
+ train_transform_types:
90
+ - resize_shorter_side
91
+ - center_crop
92
+ - trivial_augment
93
+ val_transform_types:
94
+ - resize_shorter_side
95
+ - center_crop
96
+ image_norm: imagenet
97
+ image_size: 224
98
+ max_img_num_per_col: 2
99
+ mmocr_text_detection:
100
+ checkpoint_name: TextSnake
101
+ data_types:
102
+ - image
103
+ train_transform_types:
104
+ - resize_shorter_side
105
+ - center_crop
106
+ - trivial_augment
107
+ val_transform_types:
108
+ - resize_shorter_side
109
+ - center_crop
110
+ image_norm: imagenet
111
+ image_size: 224
112
+ max_img_num_per_col: 2
113
+ mmocr_text_recognition:
114
+ checkpoint_name: ABINet
115
+ data_types:
116
+ - image
117
+ train_transform_types:
118
+ - resize_shorter_side
119
+ - center_crop
120
+ - trivial_augment
121
+ val_transform_types:
122
+ - resize_shorter_side
123
+ - center_crop
124
+ image_norm: imagenet
125
+ image_size: 224
126
+ max_img_num_per_col: 2
127
+ clip:
128
+ checkpoint_name: openai/clip-vit-base-patch32
129
+ data_types:
130
+ - image
131
+ - text
132
+ train_transform_types:
133
+ - resize_shorter_side
134
+ - center_crop
135
+ - trivial_augment
136
+ val_transform_types:
137
+ - resize_shorter_side
138
+ - center_crop
139
+ image_norm: clip
140
+ image_size: 224
141
+ max_img_num_per_col: 2
142
+ tokenizer_name: clip
143
+ max_text_len: 77
144
+ insert_sep: false
145
+ text_segment_num: 1
146
+ stochastic_chunk: false
147
+ text_aug_detect_length: 10
148
+ text_trivial_aug_maxscale: 0.0
149
+ text_train_augment_types: null
150
+ fusion_transformer:
151
+ hidden_size: 192
152
+ n_blocks: 3
153
+ attention_n_heads: 8
154
+ adapt_in_features: max
155
+ attention_dropout: 0.2
156
+ residual_dropout: 0.0
157
+ ffn_dropout: 0.1
158
+ ffn_d_hidden: 192
159
+ normalization: layer_norm
160
+ ffn_activation: geglu
161
+ head_activation: relu
162
+ data_types: null
163
+ additive_attention: false
164
+ share_qv_weights: false
165
+ data:
166
+ image:
167
+ missing_value_strategy: skip
168
+ text:
169
+ normalize_text: false
170
+ categorical:
171
+ minimum_cat_count: 100
172
+ maximum_num_cat: 20
173
+ convert_to_text: true
174
+ numerical:
175
+ convert_to_text: false
176
+ scaler_with_mean: true
177
+ scaler_with_std: true
178
+ label:
179
+ numerical_label_preprocessing: standardscaler
180
+ pos_label: null
181
+ mixup:
182
+ turn_on: false
183
+ mixup_alpha: 0.8
184
+ cutmix_alpha: 1.0
185
+ cutmix_minmax: null
186
+ prob: 1.0
187
+ switch_prob: 0.5
188
+ mode: batch
189
+ turn_off_epoch: 5
190
+ label_smoothing: 0.1
191
+ templates:
192
+ turn_on: false
193
+ num_templates: 30
194
+ template_length: 2048
195
+ preset_templates:
196
+ - super_glue
197
+ - rte
198
+ custom_templates: null
199
+ optimization:
200
+ optim_type: adamw
201
+ learning_rate: 0.001
202
+ weight_decay: 0.001
203
+ lr_choice: layerwise_decay
204
+ lr_decay: 0.9
205
+ lr_schedule: cosine_decay
206
+ max_epochs: 10
207
+ max_steps: -1
208
+ warmup_steps: 0.1
209
+ end_lr: 0
210
+ lr_mult: 1
211
+ patience: 10
212
+ val_check_interval: 0.5
213
+ check_val_every_n_epoch: 1
214
+ gradient_clip_val: 1
215
+ gradient_clip_algorithm: norm
216
+ track_grad_norm: -1
217
+ log_every_n_steps: 10
218
+ val_metric: null
219
+ top_k: 3
220
+ top_k_average_method: best
221
+ efficient_finetune: null
222
+ lora:
223
+ module_filter: null
224
+ filter:
225
+ - query
226
+ - value
227
+ - ^q$
228
+ - ^v$
229
+ - ^k$
230
+ - ^o$
231
+ r: 8
232
+ alpha: 8
233
+ loss_function: auto
234
+ env:
235
+ num_gpus: 4
236
+ num_nodes: 1
237
+ batch_size: 128
238
+ per_gpu_batch_size: 32
239
+ eval_batch_size_ratio: 4
240
+ per_gpu_batch_size_evaluation: null
241
+ precision: 16
242
+ num_workers: 2
243
+ num_workers_evaluation: 2
244
+ fast_dev_run: false
245
+ deterministic: false
246
+ auto_select_gpus: true
247
+ strategy: ddp
248
+ deepspeed_allgather_size: 1000000000.0
249
+ deepspeed_allreduce_size: 1000000000.0
data_processors.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4de012a54d1f0b1e4087f8d8fe96649416ce438a359da29b8aeb68a6eb7b7e03
3
+ size 348790580
df_preprocessor.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9e5001bf2241ae6b6fdc72edb06ddef63842d8bec8f79b66f0ad58315b1f403
3
+ size 22548
events.out.tfevents.1706665637.algo-1.21.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61b5cf65a2950a863dfa5d46f2c253e4d9e7729ba00c9c6b22b5a12df4c03eda
3
+ size 8085
hparams.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ optim_type: adamw
2
+ lr_choice: layerwise_decay
3
+ lr_schedule: cosine_decay
4
+ lr: 0.001
5
+ lr_decay: 0.9
6
+ end_lr: 0
7
+ lr_mult: 1
8
+ weight_decay: 0.001
9
+ warmup_steps: 0.1
10
+ validation_metric_name: accuracy
11
+ custom_metric_func: null
12
+ efficient_finetune: null
13
+ trainable_param_names: []
14
+ mixup_fn: null
15
+ mixup_off_epoch: 5
model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e7891ddbec5fc09cc894c993664dab03244b780705c508fe04889e990948652
3
+ size 348680105