more small changes
- corpus/angry_tweets/jsonlines2tsv.py +66 -0
- corpus/angry_tweets/test.csv +0 -0
- corpus/angry_tweets/test.tsv +0 -0
- corpus/angry_tweets/test2.tsv +0 -0
- corpus/angry_tweets/train.csv +0 -0
- corpus/angry_tweets/train.tsv +0 -0
- corpus/dane/jsonlines2tsv.py +69 -0
- corpus/dane/test.jsonl +0 -0
- corpus/dane/test.tsv +0 -0
- corpus/dane/test_tokens.tsv +0 -0
- corpus/dane/train.jsonl +0 -0
- corpus/dane/train.tsv +0 -0
- corpus/dane/train_tokens.tsv +0 -0
- finetune_categorisation_large.gin +14 -7
- finetune_large.sh +1 -1
- log/angry_tweets-1705000.jsonl +0 -0
- log/angry_tweets-metrics.jsonl +3 -0
- my_metrics.py +1 -1
- my_preprocessors.py +1 -0
- tasks.py +44 -25
corpus/angry_tweets/jsonlines2tsv.py
ADDED
@@ -0,0 +1,66 @@
+import csv
+import ujson as json
+import gzip
+import sys
+from tqdm import tqdm
+
+
+def validate_to_set(x):
+    if x is None:
+        return set()
+    elif isinstance(x, (tuple, list)):
+        return set(x)
+    elif isinstance(x, str):
+        return set([x])
+    return -1  # signals an unsupported type to the caller
+
+
+def main(in_path, out_path, delim='\t', keep_fields=None, skip_fields=None):
+    """Convert a JSON-lines file to a delimited (TSV) file.
+
+    :param str in_path: input .jsonl (or .jsonl.gz) file
+    :param str out_path: output TSV file
+    :param str delim: output field delimiter
+    :param list|str keep_fields: fields to keep (all fields if empty)
+    :param list|str skip_fields: fields to drop
+    """
+    keep_fields = validate_to_set(keep_fields)
+    if keep_fields == -1:
+        return
+    skip_fields = validate_to_set(skip_fields)
+    if skip_fields == -1:
+        return
+
+    fmt = in_path.split('.')[-1]
+    if fmt == 'gz':
+        open_to_use = gzip.open
+    else:
+        open_to_use = open
+
+    # Read the file once to collect every field name;
+    # skipped when an explicit set of keep fields is given.
+    line_count = None
+    if len(keep_fields) == 0:
+        line_count = 0
+        for line in tqdm(open_to_use(in_path)):
+            keep_fields.update(list(json.loads(line).keys()))
+            line_count += 1
+
+    keep_fields.difference_update(skip_fields)
+
+    # Force alphabetical column order.
+    keep_list = sorted(keep_fields)
+
+    with open(out_path, 'w') as outfile:
+        writer = csv.writer(outfile, delimiter=delim)
+        #writer.writerow(keep_list)
+        for line in tqdm(open_to_use(in_path), total=line_count):
+            jsn = json.loads(line)
+            writer.writerow([jsn[x].replace("\n", " ").replace("\t", " ") if x in jsn else '' for x in keep_list])
+
+
+if __name__ == "__main__":
+    if len(sys.argv) < 3:
+        print('Usage: python jsonlines2tsv.py <in_file> <out_file>')
+        sys.exit(1)
+    main(sys.argv[1], sys.argv[2], skip_fields=['content'])
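As a sanity check on the converter above, a minimal sketch of what one row looks like after conversion (the sample record is hypothetical): fields come out in alphabetical order and 'content' is dropped via skip_fields.

import csv
import io

record = {"text": "Sikke en dag!", "label": "negativ", "content": "dropped"}
keep_list = sorted(set(record) - {"content"})  # ['label', 'text']

buf = io.StringIO()
writer = csv.writer(buf, delimiter="\t")
writer.writerow([record[x].replace("\n", " ").replace("\t", " ")
                 if x in record else "" for x in keep_list])
print(repr(buf.getvalue()))  # 'negativ\tSikke en dag!\r\n'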
corpus/angry_tweets/test.csv
ADDED
The diff for this file is too large to render.

corpus/angry_tweets/test.tsv
ADDED
The diff for this file is too large to render.

corpus/angry_tweets/test2.tsv
ADDED
The diff for this file is too large to render.

corpus/angry_tweets/train.csv
ADDED
The diff for this file is too large to render.

corpus/angry_tweets/train.tsv
ADDED
The diff for this file is too large to render.
corpus/dane/jsonlines2tsv.py
ADDED
@@ -0,0 +1,69 @@
+import csv
+import ujson as json
+import gzip
+import sys
+from tqdm import tqdm
+
+
+def validate_to_set(x):
+    if x is None:
+        return set()
+    elif isinstance(x, (tuple, list)):
+        return set(x)
+    elif isinstance(x, str):
+        return set([x])
+    return -1  # signals an unsupported type to the caller
+
+
+def main(in_path, out_path, delim='\t', keep_fields=None, skip_fields=None):
+    """Convert a JSON-lines file to a delimited (TSV) file.
+
+    :param str in_path: input .jsonl (or .jsonl.gz) file
+    :param str out_path: output TSV file
+    :param str delim: output field delimiter
+    :param list|str keep_fields: fields to keep (all fields if empty)
+    :param list|str skip_fields: fields to drop
+    """
+    keep_fields = validate_to_set(keep_fields)
+    if keep_fields == -1:
+        return
+    skip_fields = validate_to_set(skip_fields)
+    if skip_fields == -1:
+        return
+
+    fmt = in_path.split('.')[-1]
+    if fmt == 'gz':
+        open_to_use = gzip.open
+    else:
+        open_to_use = open
+
+    # Read the file once to collect every field name;
+    # skipped when an explicit set of keep fields is given.
+    line_count = None
+    if len(keep_fields) == 0:
+        line_count = 0
+        for line in tqdm(open_to_use(in_path)):
+            keep_fields.update(list(json.loads(line).keys()))
+            line_count += 1
+
+    keep_fields.difference_update(skip_fields)
+
+    # Force alphabetical column order, then add the derived 'combined' column.
+    keep_list = sorted(keep_fields)
+    keep_list.append("combined")
+
+    with open(out_path, 'w') as outfile:
+        writer = csv.writer(outfile, delimiter=delim)
+        #writer.writerow(keep_list)
+        for line in tqdm(open_to_use(in_path), total=line_count):
+            jsn = json.loads(line)
+            jsn['combined'] = dict(zip(jsn['tokens'], jsn['ner_tags']))
+            #writer.writerow([jsn[x].replace("\n", " ").replace("\t", " ") if x in jsn else '' for x in keep_list])
+            writer.writerow([jsn[x] if x in jsn else '' for x in keep_list])
+
+
+if __name__ == "__main__":
+    if len(sys.argv) < 3:
+        print('Usage: python jsonlines2tsv.py <in_file> <out_file>')
+        sys.exit(1)
+    main(sys.argv[1], sys.argv[2], keep_fields=['doc','tokens','ner_tags'])
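One caveat about the 'combined' column above: dict(zip(tokens, ner_tags)) keys the mapping by surface form, so repeated tokens collapse into a single entry and per-position alignment is lost. A toy illustration (the sentence is made up):

tokens = ["Peter", "bor", "i", "Aarhus", "i", "Danmark"]
ner_tags = ["B-PER", "O", "O", "B-LOC", "O", "B-LOC"]

combined = dict(zip(tokens, ner_tags))
print(combined)
# {'Peter': 'B-PER', 'bor': 'O', 'i': 'O', 'Aarhus': 'B-LOC', 'Danmark': 'B-LOC'}
# 6 tokens in, 5 entries out: the two "i" tokens merged into one key.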
corpus/dane/test.jsonl
ADDED
The diff for this file is too large to render.

corpus/dane/test.tsv
ADDED
The diff for this file is too large to render.

corpus/dane/test_tokens.tsv
ADDED
The diff for this file is too large to render.

corpus/dane/train.jsonl
ADDED
The diff for this file is too large to render.

corpus/dane/train.tsv
ADDED
The diff for this file is too large to render.

corpus/dane/train_tokens.tsv
ADDED
The diff for this file is too large to render.
finetune_categorisation_large.gin
CHANGED
@@ -10,9 +10,9 @@ from t5x import utils
 include "t5x/examples/t5/mt5/large.gin"
 include "t5x/configs/runs/finetune.gin"
 
-MIXTURE_OR_TASK_NAME = "
-TASK_FEATURE_LENGTHS = {"inputs":
-TRAIN_STEPS =
+MIXTURE_OR_TASK_NAME = "dane"
+TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 512}
+TRAIN_STEPS = 1_720_000  # 1,700,000 pre-trained steps + 20,000 fine-tuning steps.
 USE_CACHED_TASKS = False
 DROPOUT_RATE = 0.1
 RANDOM_SEED = 0
@@ -24,17 +24,24 @@ RANDOM_SEED = 0
 # `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`.
 #LOSS_NORMALIZING_FACTOR = 234496
 
-#INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_base/checkpoint_1000000"
-#INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000"
-#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_run1/checkpoint_1100000"
 INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_NCC_plus_English_pluss200k_scandinavian_t5x_large/checkpoint_1700000"
 
 
+# Fixing a small error: infer_eval needs the same feature lengths.
+infer_eval/utils.DatasetConfig:
+  task_feature_lengths = %TASK_FEATURE_LENGTHS
+
+# Saving every 2000 steps.
+utils.SaveCheckpointConfig:
+  period = 2000
+
+
+
 #train_script.train:
 #  eval_period = 500
 #  partitioner = @partitioning.ModelBasedPjitPartitioner()
-
+partitioning.PjitPartitioner.num_partitions = 1
 
 # `num_decodes` is equivalent to a beam size in a beam search decoding.
 # models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
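For context on the "small error" fix: in gin-config, a scope/ prefix such as infer_eval/ applies a binding only to calls made inside that scope, which is how the inference-eval dataset picks up the same TASK_FEATURE_LENGTHS as training. A minimal sketch; the configurable below is illustrative, not part of t5x:

import gin

@gin.configurable
def dataset_config(task_feature_lengths=None):
    return task_feature_lengths

gin.parse_config("""
TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 512}
infer_eval/dataset_config.task_feature_lengths = %TASK_FEATURE_LENGTHS
""")

with gin.config_scope("infer_eval"):
    print(dataset_config())  # {'inputs': 256, 'targets': 512}
print(dataset_config())      # None: the binding only applies inside the scope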
finetune_large.sh
CHANGED
@@ -1,7 +1,7 @@
 PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
 T5X_DIR="../../t5x"  # Directory where t5x is cloned.
 # Needs to be updated when moving to tpu-v4; it should then be in another zone.
-MODEL_DIR="gs://nb-t5x-us-central2/finetuned/
+MODEL_DIR="gs://nb-t5x-us-central2/finetuned/v3_eval_dane_scandinavian_large"
 export PYTHONPATH=${PROJECT_DIR}
 
 python3 ${T5X_DIR}/t5x/train.py \
log/angry_tweets-1705000.jsonl
ADDED
The diff for this file is too large to render.
log/angry_tweets-metrics.jsonl
ADDED
@@ -0,0 +1,3 @@
+{"step": 1703000, "accuracy": 70.96466093600765, "f1_macro": 70.6282173005562}
+{"step": 1704000, "accuracy": 71.82425978987584, "f1_macro": 71.53409993558694}
+{"step": 1705000, "accuracy": 71.25119388729703, "f1_macro": 71.00190678021416}
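A small sketch for reading this log back and picking the strongest checkpoint (path as committed above):

import json

with open("log/angry_tweets-metrics.jsonl") as f:
    rows = [json.loads(line) for line in f]

best = max(rows, key=lambda r: r["f1_macro"])
print(best["step"], round(best["f1_macro"], 2))  # 1704000 71.53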
my_metrics.py
CHANGED
@@ -2,6 +2,6 @@ import sklearn.metrics
 import numpy as np
 
 def f1_macro(targets, predictions):
-    targets, predictions = np.asarray(targets).astype(
+    targets, predictions = np.asarray(targets).astype(str), np.asarray(predictions).astype(str)
     return {"f1_macro": 100*sklearn.metrics.f1_score(targets, predictions, average='macro')}
 
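The added cast matters because targets and predictions can reach the metric as bytes or other non-string types depending on the task; normalizing both sides to str keeps sklearn's label comparison consistent. A quick sketch with hypothetical labels:

import numpy as np
import sklearn.metrics

targets = [b"positiv", b"negativ", b"neutral"]  # e.g. byte strings from TF
predictions = ["positiv", "negativ", "negativ"]

t = np.asarray(targets).astype(str)  # b"positiv" -> "positiv"
p = np.asarray(predictions).astype(str)
print(100 * sklearn.metrics.f1_score(t, p, average="macro"))  # ~55.6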
my_preprocessors.py
CHANGED
@@ -36,6 +36,7 @@ def parse_tsv(line, field_names=None, field_delim='\t'):
   Returns:
     A feature dict mapping field name to string value.
   """
+  breakpoint()
   field_names = field_names or ['inputs', 'targets']
   return dict(
       zip(field_names,
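For reference, parse_tsv maps one delimited line to a feature dict; a sketch of the same zip-based pairing on a plain Python string (the sample line is hypothetical, column order as registered for angry_tweets below):

line = "negativ\tSikke en dag!"
field_names = ["target", "source"]  # order used for the angry_tweets task

features = dict(zip(field_names, line.split("\t")))
print(features)  # {'target': 'negativ', 'source': 'Sikke en dag!'}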
tasks.py
CHANGED
@@ -5,8 +5,8 @@ import seqio
 import my_metrics
 import tensorflow_datasets as tfds
 from t5.evaluation import metrics
-
-import my_preprocessors
+from t5.data import preprocessors
+#import my_preprocessors
 import t5
 import tensorflow.compat.v1 as tf
 
@@ -28,6 +28,25 @@ json_angry_tweets_path = {
     "test": "gs://notram-public/finetune_datasets/angry_tweets/test.jsonl"
 }
 
+tsv_angry_tweets_path = {
+    "train": "gs://notram-public/finetune_datasets/angry_tweets/train.tsv",
+    "validation": "gs://notram-public/finetune_datasets/angry_tweets/test.tsv",
+    "test": "gs://notram-public/finetune_datasets/angry_tweets/test.tsv"
+}
+
+
+tsv_dane_path = {
+    "train": "gs://notram-public/finetune_datasets/dane/train.tsv",
+    "validation": "gs://notram-public/finetune_datasets/dane/test.tsv",
+    "test": "gs://notram-public/finetune_datasets/dane/test.tsv"
+}
+
+tsv_dane_tokens_path = {
+    "train": "gs://notram-public/finetune_datasets/dane/train_tokens.tsv",
+    "validation": "gs://notram-public/finetune_datasets/dane/test_tokens.tsv",
+    "test": "gs://notram-public/finetune_datasets/dane/test_tokens.tsv"
+}
+
 
 vocabulary = seqio.SentencePieceVocabulary(
     'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)
@@ -44,7 +63,8 @@ DEFAULT_OUTPUT_FEATURES = {
 def categorise_preprocessor(ds):
   def normalize_text(text):
     """Lowercase and remove quotes from a TensorFlow string."""
-    text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
+    #text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
+    ...
     return text
 
   def to_inputs_and_targets(ex):
@@ -60,25 +80,6 @@ def categorise_preprocessor(ds):
   return ds.map(to_inputs_and_targets,
                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
-def scandeval_preprocessor(ds):
-  def normalize_text(text):
-    """Lowercase and remove quotes from a TensorFlow string."""
-    text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
-    return text
-
-  def to_inputs_and_targets(ex):
-    """Map {"text": ..., "label": ...} -> {"inputs": ..., "targets": ...}."""
-    return {
-        "inputs":
-            tf.strings.join(
-                [normalize_text(ex["text"])]),
-        "targets":
-            tf.strings.join(
-                [normalize_text(ex["label"])]),
-    }
-  return ds.map(to_inputs_and_targets,
-                num_parallel_calls=tf.data.experimental.AUTOTUNE)
-
 
 seqio.TaskRegistry.add(
     "parliament",
@@ -117,13 +118,31 @@ seqio.TaskRegistry.add(
 seqio.TaskRegistry.add(
     "angry_tweets",
     source=seqio.TextLineDataSource(
-        split_to_filepattern=
+        split_to_filepattern=tsv_angry_tweets_path,
         #num_input_examples=num_nq_examples
     ),
     preprocessors=[
         functools.partial(
-
-
+            t5.data.preprocessors.parse_tsv,
+            field_names=["target", "source"]),
+        categorise_preprocessor,
+        seqio.preprocessors.tokenize_and_append_eos,
+    ],
+    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
+    output_features=DEFAULT_OUTPUT_FEATURES,
+)
+
+seqio.TaskRegistry.add(
+    "dane",
+    source=seqio.TextLineDataSource(
+        split_to_filepattern=tsv_dane_tokens_path,
+        #num_input_examples=num_nq_examples
+    ),
+    preprocessors=[
+        functools.partial(
+            t5.data.preprocessors.parse_tsv,
+            field_names=["source", "placeholder1", "placeholder2", "target"]),
+        categorise_preprocessor,
         seqio.preprocessors.tokenize_and_append_eos,
     ],
     metric_fns=[metrics.accuracy, my_metrics.f1_macro],
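Once registered, either task can be smoke-tested outside the t5x training loop; a hedged sketch (requires read access to the gs://notram-public buckets):

import seqio
import tasks  # noqa: F401 -- importing registers "angry_tweets" and "dane"

task = seqio.get_mixture_or_task("dane")
ds = task.get_dataset(
    sequence_length={"inputs": 256, "targets": 512},  # matches the gin config
    split="validation",
)
for ex in ds.take(1):
    print(ex["inputs"][:10])  # first token ids of the tokenized inputs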