marinone94
committed on
Commit
·
21f22fe
1
Parent(s):
8f1a9b5
final swedish training
Browse files- run.sh +9 -8
- run_speech_recognition_seq2seq_streaming.py +197 -23
run.sh
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
python run_speech_recognition_seq2seq_streaming.py \
|
2 |
--model_name_or_path="marinone94/whisper-medium-nordic" \
|
3 |
-
--
|
4 |
-
--
|
5 |
--language="swedish" \
|
6 |
-
--train_split_name="train+validation" \
|
|
|
|
|
7 |
--eval_split_name="test" \
|
8 |
--model_index_name="Whisper Medium Swedish" \
|
9 |
-
--max_steps="
|
10 |
--output_dir="./" \
|
11 |
--per_device_train_batch_size="32" \
|
12 |
--per_device_eval_batch_size="16" \
|
@@ -20,9 +22,9 @@ python run_speech_recognition_seq2seq_streaming.py \
|
|
20 |
--generation_max_length="225" \
|
21 |
--length_column_name="input_length" \
|
22 |
--max_duration_in_seconds="30" \
|
23 |
-
--text_column_name="sentence" \
|
24 |
--freeze_feature_encoder="False" \
|
25 |
-
--report_to="
|
26 |
--metric_for_best_model="wer" \
|
27 |
--greater_is_better="False" \
|
28 |
--load_best_model_at_end \
|
@@ -34,5 +36,4 @@ python run_speech_recognition_seq2seq_streaming.py \
|
|
34 |
--predict_with_generate \
|
35 |
--do_normalize_eval \
|
36 |
--streaming \
|
37 |
-
--use_auth_token
|
38 |
-
--push_to_hub
|
|
|
1 |
python run_speech_recognition_seq2seq_streaming.py \
|
2 |
--model_name_or_path="marinone94/whisper-medium-nordic" \
|
3 |
+
--dataset_train_name="mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,google/fleurs" \
|
4 |
+
--dataset_train_config_name="sv-SE,nst,sv_se" \
|
5 |
--language="swedish" \
|
6 |
+
--train_split_name="train+validation,train,train+validation+test" \
|
7 |
+
--dataset_eval_name="mozilla-foundation/common_voice_11_0" \
|
8 |
+
--dataset_eval_config_name="sv-SE" \
|
9 |
--eval_split_name="test" \
|
10 |
--model_index_name="Whisper Medium Swedish" \
|
11 |
+
--max_steps="5000" \
|
12 |
--output_dir="./" \
|
13 |
--per_device_train_batch_size="32" \
|
14 |
--per_device_eval_batch_size="16" \
|
|
|
22 |
--generation_max_length="225" \
|
23 |
--length_column_name="input_length" \
|
24 |
--max_duration_in_seconds="30" \
|
25 |
+
--text_column_name="sentence,raw_transcription" \
|
26 |
--freeze_feature_encoder="False" \
|
27 |
+
--report_to="wandb" \
|
28 |
--metric_for_best_model="wer" \
|
29 |
--greater_is_better="False" \
|
30 |
--load_best_model_at_end \
|
|
|
36 |
--predict_with_generate \
|
37 |
--do_normalize_eval \
|
38 |
--streaming \
|
39 |
+
--use_auth_token
|
|
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -20,6 +20,7 @@ with 🤗 Datasets' streaming mode.
|
|
20 |
# You can also adapt this script for your own sequence to sequence speech
|
21 |
# recognition task. Pointers for this are left as comments.
|
22 |
|
|
|
23 |
import logging
|
24 |
import os
|
25 |
import sys
|
@@ -28,6 +29,7 @@ from typing import Any, Dict, List, Optional, Union
|
|
28 |
|
29 |
import datasets
|
30 |
import torch
|
|
|
31 |
from datasets import DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
|
32 |
from torch.utils.data import IterableDataset
|
33 |
|
@@ -60,6 +62,42 @@ require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/spe
|
|
60 |
logger = logging.getLogger(__name__)
|
61 |
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
@dataclass
|
64 |
class ModelArguments:
|
65 |
"""
|
@@ -265,27 +303,131 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
265 |
return batch
|
266 |
|
267 |
|
268 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
"""
|
270 |
Utility function to load a dataset in streaming mode. For datasets with multiple splits,
|
271 |
each split is loaded individually and then splits combined by taking alternating examples from
|
272 |
each (interleaving).
|
273 |
"""
|
274 |
-
|
|
|
|
|
|
|
|
|
|
|
275 |
# load multiple splits separated by the `+` symbol with streaming mode
|
276 |
-
dataset_splits = [
|
277 |
-
|
278 |
-
|
279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
# interleave multiple splits to form one dataset
|
281 |
-
interleaved_dataset = interleave_datasets(dataset_splits)
|
282 |
return interleaved_dataset
|
283 |
else:
|
284 |
# load a single split *with* streaming mode
|
285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
return dataset
|
287 |
|
288 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
289 |
def main():
|
290 |
# 1. Parse input arguments
|
291 |
# See all possible arguments in src/transformers/training_args.py
|
@@ -349,25 +491,41 @@ def main():
|
|
349 |
# Set seed before initializing model.
|
350 |
set_seed(training_args.seed)
|
351 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
352 |
# 4. Load dataset
|
353 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
354 |
|
355 |
if training_args.do_train:
|
356 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
357 |
-
data_args.
|
358 |
-
data_args.
|
359 |
split=data_args.train_split_name,
|
360 |
-
use_auth_token=
|
361 |
streaming=data_args.streaming,
|
|
|
|
|
|
|
|
|
362 |
)
|
363 |
|
364 |
if training_args.do_eval:
|
365 |
raw_datasets["eval"] = load_maybe_streaming_dataset(
|
366 |
-
data_args.
|
367 |
-
data_args.
|
368 |
split=data_args.eval_split_name,
|
369 |
-
use_auth_token=
|
370 |
streaming=data_args.streaming,
|
|
|
|
|
|
|
|
|
371 |
)
|
372 |
|
373 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
@@ -394,7 +552,7 @@ def main():
|
|
394 |
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
395 |
cache_dir=model_args.cache_dir,
|
396 |
revision=model_args.model_revision,
|
397 |
-
use_auth_token=
|
398 |
)
|
399 |
|
400 |
config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
|
@@ -402,25 +560,19 @@ def main():
|
|
402 |
if training_args.gradient_checkpointing:
|
403 |
config.update({"use_cache": False})
|
404 |
|
405 |
-
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
406 |
-
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
|
407 |
-
cache_dir=model_args.cache_dir,
|
408 |
-
revision=model_args.model_revision,
|
409 |
-
use_auth_token=True if model_args.use_auth_token else None,
|
410 |
-
)
|
411 |
tokenizer = AutoTokenizer.from_pretrained(
|
412 |
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
413 |
cache_dir=model_args.cache_dir,
|
414 |
use_fast=model_args.use_fast_tokenizer,
|
415 |
revision=model_args.model_revision,
|
416 |
-
use_auth_token=
|
417 |
)
|
418 |
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
419 |
model_args.model_name_or_path,
|
420 |
config=config,
|
421 |
cache_dir=model_args.cache_dir,
|
422 |
revision=model_args.model_revision,
|
423 |
-
use_auth_token=
|
424 |
)
|
425 |
|
426 |
if model.config.decoder_start_token_id is None:
|
@@ -568,6 +720,9 @@ def main():
|
|
568 |
callbacks=[ShuffleCallback()] if data_args.streaming else None,
|
569 |
)
|
570 |
|
|
|
|
|
|
|
571 |
# 12. Training
|
572 |
if training_args.do_train:
|
573 |
checkpoint = None
|
@@ -617,10 +772,29 @@ def main():
|
|
617 |
if model_args.model_index_name is not None:
|
618 |
kwargs["model_name"] = model_args.model_index_name
|
619 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
620 |
if training_args.push_to_hub:
|
621 |
trainer.push_to_hub(**kwargs)
|
622 |
else:
|
623 |
trainer.create_model_card(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
624 |
|
625 |
return results
|
626 |
|
|
|
20 |
# You can also adapt this script for your own sequence to sequence speech
|
21 |
# recognition task. Pointers for this are left as comments.
|
22 |
|
23 |
+
import json
|
24 |
import logging
|
25 |
import os
|
26 |
import sys
|
|
|
29 |
|
30 |
import datasets
|
31 |
import torch
|
32 |
+
import wandb
|
33 |
from datasets import DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
|
34 |
from torch.utils.data import IterableDataset
|
35 |
|
|
|
62 |
logger = logging.getLogger(__name__)
|
63 |
|
64 |
|
65 |
+
SENDING_NOTIFICATION = "*** Sending notification to email ***"
RECIPIENT_ADDRESS = "marinone94@gmail.com"

# Credentials are resolved at import time: environment variables first, then an
# optional ./creds.txt file made of KEY=VALUE lines.
# NOTE: the *string* "None" (not the None object) is the sentinel for a missing
# wandb token; it is kept as-is because it is passed straight to wandb.login below.
wandb_token = os.environ.get("WANDB_TOKEN", "None")
hf_token = os.environ.get("HF_TOKEN", None)
if (hf_token is None or wandb_token == "None") and os.path.exists("./creds.txt"):
    with open("./creds.txt", "r") as f:
        for line in f:
            # Split on the first "=" only, so values that themselves contain "="
            # survive intact, and skip blank/malformed lines instead of raising
            # ValueError on tuple unpacking.
            key, separator, value = line.partition("=")
            if not separator:
                continue
            key = key.strip()
            value = value.strip()
            if key == "HF_TOKEN":
                hf_token = value
            elif key == "WANDB_TOKEN":
                wandb_token = value
            elif key == "EMAIL_ADDRESS":
                # Exported through the environment because notify_me reads it from there.
                os.environ["EMAIL_ADDRESS"] = value
            elif key == "EMAIL_PASSWORD":
                os.environ["EMAIL_PASSWORD"] = value

if hf_token is not None:
    try:
        # Persist the token to a fixed location under /root; failure to write is
        # tolerated and only logged.
        os.makedirs("/root/.huggingface", exist_ok=True)
        with open("/root/.huggingface/token", "w") as f:
            f.write(hf_token)
        logger.info("Huggingface API key set")
    except (PermissionError, OSError):
        logger.warning("Huggingface API key not set, relying on ~/.huggingface/token")
else:
    logger.warning("Huggingface API key not set, relying on ~/.huggingface/token")

# Authenticate against Weights & Biases and open the run at import time.
wandb.login(key=wandb_token, relogin=True, timeout=5)
wandb.init(project="whisper", entity="pn-aa")

logger.info("Wandb API key set, logging to wandb")
|
99 |
+
|
100 |
+
|
101 |
@dataclass
|
102 |
class ModelArguments:
|
103 |
"""
|
|
|
303 |
return batch
|
304 |
|
305 |
|
306 |
+
def rename_col_and_resample(dataset, dataset_name, text_column_names, text_col_name_ref, audio_column_name, sampling_rate):
    """
    Normalise a dataset so that it only exposes the reference text column and the audio column.

    Args:
        dataset: object exposing `features`, `rename_column`, `cast_column` and `remove_columns`
            (presumably a 🤗 Datasets `Dataset` / `IterableDataset` - confirm against callers).
        dataset_name: name used in log messages only.
        text_column_names: candidate transcription column names, tried in order.
        text_col_name_ref: name the transcription column must have on return.
        audio_column_name: name of the audio column, or `None` when there is no audio column to keep.
        sampling_rate: target sampling rate; the audio column is re-cast only when both
            `audio_column_name` and `sampling_rate` are given and the rates differ.

    Returns:
        The dataset with the text column renamed to `text_col_name_ref`, the audio column
        re-cast if needed, and every other column removed.

    Raises:
        ValueError: if none of `text_column_names` is present in the dataset, or if
            `audio_column_name` is given but is not a column of the dataset.
    """
    feature_names = list(dataset.features.keys())
    logger.info(f"Dataset {dataset_name} - Features: {feature_names}")

    if text_col_name_ref not in feature_names:
        # Pick the first candidate that actually exists in this dataset.
        candidates = text_column_names or []
        source_column = next((name for name in candidates if name in feature_names), None)
        if source_column is None:
            raise ValueError(
                "None of the text column names provided found in dataset. "
                f"Text columns: {text_column_names}. "
                f"Dataset columns: {feature_names}"
            )
        logger.info(f"Renaming text column {source_column} to {text_col_name_ref}")
        dataset = dataset.rename_column(source_column, text_col_name_ref)

    if audio_column_name is not None:
        if audio_column_name not in dataset.features:
            # Fail with an explicit message rather than an opaque lookup/remove error.
            raise ValueError(
                f"Audio column '{audio_column_name}' not found in dataset {dataset_name}. "
                f"Dataset columns: {list(dataset.features.keys())}"
            )
        if sampling_rate is not None:
            ds_sr = int(dataset.features[audio_column_name].sampling_rate)
            if ds_sr != sampling_rate:
                dataset = dataset.cast_column(
                    audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
                )

    # Keep only audio and sentence
    columns_to_keep = {text_col_name_ref, audio_column_name}
    columns_to_drop = [name for name in dataset.features.keys() if name not in columns_to_keep]
    dataset = dataset.remove_columns(column_names=columns_to_drop)
    return dataset
|
339 |
+
|
340 |
+
|
341 |
+
def load_maybe_streaming_dataset(
    dataset_names,
    dataset_config_names,
    split="train",
    streaming=True,
    audio_column_name=None,
    sampling_rate=None,
    **kwargs
):
    """
    Utility function to load one or more datasets, optionally in streaming mode, and bring them
    to a common schema via `rename_col_and_resample`.

    `dataset_names`, `dataset_config_names` and `split` are comma-separated strings that are
    matched position by position. Within one position, several splits may be joined with `+`.
    When more than one dataset or split is requested, each one is loaded individually and the
    results are combined with `interleave_datasets(..., stopping_strategy="all_exhausted")`.

    Args:
        dataset_names: comma-separated dataset names.
        dataset_config_names: comma-separated config names, one per dataset; an empty entry
            (or an empty/None value) loads the dataset without a config.
        split: comma-separated split specs, one per dataset, each optionally using `+`.
        streaming: forwarded to `load_dataset`.
        audio_column_name: forwarded to `rename_col_and_resample`.
        sampling_rate: forwarded to `rename_col_and_resample`.
        **kwargs: must contain `text_column_name` (comma-separated candidate transcription
            columns, the first one being the reference name); the rest is forwarded to
            `load_dataset`.

    Returns:
        A single dataset, or the interleaving of all requested datasets/splits.

    Raises:
        ValueError: if `text_column_name` is missing or empty, or if the number of dataset
            names, config names and split specs do not match.
    """
    text_column_name = kwargs.pop("text_column_name", None)
    if not text_column_name:
        # Without it there is no reference column to rename to / keep.
        raise ValueError("`text_column_name` must be provided (comma-separated candidate text columns).")
    text_column_names = text_column_name.split(",")
    text_col_name_ref = text_column_names[0]

    if "," in dataset_names or "+" in split:
        # load multiple splits separated by the `+` symbol with streaming mode
        names = dataset_names.split(",")
        configs = dataset_config_names.split(",") if dataset_config_names else [""] * len(names)
        split_specs = split.split(",")
        if not (len(names) == len(configs) == len(split_specs)):
            # zip() would silently drop the unmatched trailing entries.
            raise ValueError(
                "Dataset names, config names and splits must have the same number of "
                f"comma-separated entries, got {len(names)}, {len(configs)} and {len(split_specs)}."
            )

        dataset_splits = []
        for dataset_name, dataset_config_name, split_names in zip(names, configs, split_specs):
            for split_name in split_names.split("+"):
                if dataset_config_name:
                    dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
                else:
                    dataset = load_dataset(dataset_name, split=split_name, streaming=streaming, **kwargs)

                dataset = rename_col_and_resample(
                    dataset,
                    dataset_name,
                    text_column_names,
                    text_col_name_ref,
                    audio_column_name,
                    sampling_rate
                )

                dataset_splits.append(dataset)

        # interleave multiple splits to form one dataset
        interleaved_dataset = interleave_datasets(dataset_splits, stopping_strategy="all_exhausted")
        return interleaved_dataset
    else:
        # load a single split *with* streaming mode
        dataset = load_dataset(dataset_names, dataset_config_names, split=split, streaming=streaming, **kwargs)
        dataset = rename_col_and_resample(
            dataset,
            dataset_names,
            text_column_names,
            text_col_name_ref,
            audio_column_name,
            sampling_rate
        )
        return dataset
|
399 |
|
400 |
|
401 |
+
def notify_me(recipient, message=None):
    """
    Send an email to the specified address with the specified message, on a best-effort basis.

    The sender account is read from the `EMAIL_ADDRESS` and `EMAIL_PASSWORD` environment
    variables. If either is missing a warning is logged and nothing is sent. Delivery errors
    are logged as warnings instead of being raised, so a failed notification cannot abort the
    work that follows it.

    Args:
        recipient: destination email address.
        message: body of the email; defaults to "Training is finished!".

    Returns:
        None.
    """
    log = logging.getLogger(__name__)
    sender = os.environ.get("EMAIL_ADDRESS", None)
    password = os.environ.get("EMAIL_PASSWORD", None)
    if sender is None:
        log.warning("No email address specified, not sending notification")
    if password is None:
        log.warning("No email password specified, not sending notification")
    if sender is None or password is None:
        # Honour the warnings above: logging in without credentials can only fail.
        return
    if message is None:
        message = "Training is finished!"

    import smtplib
    from email.mime.text import MIMEText

    msg = MIMEText(message)
    msg["Subject"] = "Training updates..."
    # Use the authenticated account rather than a hard-coded address so the header
    # always matches the envelope sender.
    msg["From"] = sender
    msg["To"] = recipient

    # send the email
    try:
        # The context manager closes the connection even when a step fails; the timeout
        # keeps an unreachable server from blocking indefinitely.
        with smtplib.SMTP("smtp.gmail.com", 587, timeout=30) as smtp_obj:
            smtp_obj.starttls()
            smtp_obj.login(sender, password)
            smtp_obj.sendmail(sender, recipient, msg.as_string())
    except (smtplib.SMTPException, OSError) as err:
        log.warning("Could not send notification email: %s", err)
|
429 |
+
|
430 |
+
|
431 |
def main():
|
432 |
# 1. Parse input arguments
|
433 |
# See all possible arguments in src/transformers/training_args.py
|
|
|
491 |
# Set seed before initializing model.
|
492 |
set_seed(training_args.seed)
|
493 |
|
494 |
+
# Load feature extractor
|
495 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
496 |
+
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
|
497 |
+
cache_dir=model_args.cache_dir,
|
498 |
+
revision=model_args.model_revision,
|
499 |
+
use_auth_token=hf_token if model_args.use_auth_token else None,
|
500 |
+
)
|
501 |
+
|
502 |
# 4. Load dataset
|
503 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
504 |
|
505 |
if training_args.do_train:
|
506 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
507 |
+
data_args.dataset_train_name,
|
508 |
+
data_args.dataset_train_config_name,
|
509 |
split=data_args.train_split_name,
|
510 |
+
use_auth_token=hf_token if model_args.use_auth_token else None,
|
511 |
streaming=data_args.streaming,
|
512 |
+
text_column_name=data_args.text_column_name,
|
513 |
+
audio_column_name=data_args.audio_column_name,
|
514 |
+
sampling_rate=int(feature_extractor.sampling_rate),
|
515 |
+
# language=data_args.language_train
|
516 |
)
|
517 |
|
518 |
if training_args.do_eval:
|
519 |
raw_datasets["eval"] = load_maybe_streaming_dataset(
|
520 |
+
data_args.dataset_eval_name,
|
521 |
+
data_args.dataset_eval_config_name,
|
522 |
split=data_args.eval_split_name,
|
523 |
+
use_auth_token=hf_token if model_args.use_auth_token else None,
|
524 |
streaming=data_args.streaming,
|
525 |
+
text_column_name=data_args.text_column_name,
|
526 |
+
audio_column_name=data_args.audio_column_name,
|
527 |
+
sampling_rate=int(feature_extractor.sampling_rate),
|
528 |
+
# language=data_args.language_eval
|
529 |
)
|
530 |
|
531 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
|
|
552 |
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
553 |
cache_dir=model_args.cache_dir,
|
554 |
revision=model_args.model_revision,
|
555 |
+
use_auth_token=hf_token if model_args.use_auth_token else None,
|
556 |
)
|
557 |
|
558 |
config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
|
|
|
560 |
if training_args.gradient_checkpointing:
|
561 |
config.update({"use_cache": False})
|
562 |
|
|
|
|
|
|
|
|
|
|
|
|
|
563 |
tokenizer = AutoTokenizer.from_pretrained(
|
564 |
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
565 |
cache_dir=model_args.cache_dir,
|
566 |
use_fast=model_args.use_fast_tokenizer,
|
567 |
revision=model_args.model_revision,
|
568 |
+
use_auth_token=hf_token if model_args.use_auth_token else None,
|
569 |
)
|
570 |
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
571 |
model_args.model_name_or_path,
|
572 |
config=config,
|
573 |
cache_dir=model_args.cache_dir,
|
574 |
revision=model_args.model_revision,
|
575 |
+
use_auth_token=hf_token if model_args.use_auth_token else None,
|
576 |
)
|
577 |
|
578 |
if model.config.decoder_start_token_id is None:
|
|
|
720 |
callbacks=[ShuffleCallback()] if data_args.streaming else None,
|
721 |
)
|
722 |
|
723 |
+
orig_push_to_hub = trainer.args.push_to_hub
|
724 |
+
trainer.args.push_to_hub = False
|
725 |
+
|
726 |
# 12. Training
|
727 |
if training_args.do_train:
|
728 |
checkpoint = None
|
|
|
772 |
if model_args.model_index_name is not None:
|
773 |
kwargs["model_name"] = model_args.model_index_name
|
774 |
|
775 |
+
logger.info("*** Training stats written ***")
|
776 |
+
logger.info(json.dumps(kwargs, indent=4))
|
777 |
+
|
778 |
+
# Training complete notification
|
779 |
+
logger.info("*** Training and eval complete ***")
|
780 |
+
logger.info(SENDING_NOTIFICATION)
|
781 |
+
with open(os.path.join(training_args.output_dir, "train_results.json"), "r") as f:
|
782 |
+
train_results = json.load(f)
|
783 |
+
with open(os.path.join(training_args.output_dir, "eval_results.json"), "r") as f:
|
784 |
+
eval_results = json.load(f)
|
785 |
+
notify_me(recipient=RECIPIENT_ADDRESS,
|
786 |
+
message=f"Training complete! {train_results = } {eval_results = }")
|
787 |
+
|
788 |
+
trainer.args.push_to_hub = orig_push_to_hub
|
789 |
if training_args.push_to_hub:
|
790 |
trainer.push_to_hub(**kwargs)
|
791 |
else:
|
792 |
trainer.create_model_card(**kwargs)
|
793 |
+
|
794 |
+
with open(os.path.join(training_args.output_dir, "README.md"), "r") as f:
|
795 |
+
readme = f.read()
|
796 |
+
notify_me(recipient=RECIPIENT_ADDRESS,
|
797 |
+
message=f"Model pushed to hub! {readme = }")
|
798 |
|
799 |
return results
|
800 |
|