File size: 3,946 Bytes
c239b93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from transformers import (
    HfArgumentParser,
    Wav2Vec2Config,
    Wav2Vec2FeatureExtractor
)

# Module-level logger; its level and the root handlers are configured in `main()`.
logger = logging.getLogger(__name__)


@dataclass
class ConfigArguments:
    """
    Arguments describing which config and feature extractor to set up.

    `config_params` and `feature_extractor_params` are declared as strings so
    they can be given on the command line as a Python literal (normally a
    dict, e.g. "{'vocab_size': 40}"). `__post_init__` decodes them with
    `ast.literal_eval`; values that are not strings (for example a dict that
    was passed in directly) are kept as they are.
    """
    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
                    "Don't set if you want to train a model from scratch."
        },
    )
    config_params: Optional[str] = field(
        default=None,
        metadata={"help": "Custom configuration for the specific `name_or_path`"}
    )
    feature_extractor_params: Optional[str] = field(
        default=None,
        metadata={"help": "Custom feature extractor configuration for the specific `name_or_path`"}
    )

    @staticmethod
    def _parse_params(raw, label):
        """
        Decode `raw` with `ast.literal_eval` when it is a string.

        Args:
            raw: The value of the field. Only strings are decoded; anything
                else is returned untouched so an already-decoded value does
                not trigger a spurious parse error.
            label: Name used in the message printed on failure.

        Returns:
            The decoded literal, or `raw` itself when it is not a string or
            cannot be decoded (best effort: a message is printed, nothing is
            raised).
        """
        if not isinstance(raw, str):
            return raw
        try:
            return ast.literal_eval(raw)
        # These are the exception types `ast.literal_eval` is documented to
        # raise on malformed input.
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
            print(f"Your custom `{label}` parameters are not acceptable due to {e}")
            return raw

    def __post_init__(self):
        """Decode the two custom-parameter fields in place when they are set."""
        if self.config_params:
            self.config_params = self._parse_params(self.config_params, "config")

        if self.feature_extractor_params:
            self.feature_extractor_params = self._parse_params(
                self.feature_extractor_params, "feature_extractor"
            )


def main():
    """
    Build a Wav2Vec2 config and feature extractor and save both to `output_dir`.

    Arguments are read from a JSON file when the script receives exactly one
    argument ending in ".json"; otherwise they are parsed from the command
    line. When `name_or_path` is set, both objects are loaded with
    `from_pretrained` and the extra parameters are passed as overrides. When it
    is not set, both objects are constructed directly from the parameters
    instead of calling `from_pretrained(None, ...)`.

    Custom parameters are used only when they are a non-empty dict; otherwise
    the built-in defaults below are applied, and a warning is logged if a
    value was supplied but could not be used.
    """
    parser = HfArgumentParser([ConfigArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        config_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        config_args = parser.parse_args_into_dataclasses()[0]

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)

    name_or_path = config_args.name_or_path

    config_params = config_args.config_params
    if not (config_params and isinstance(config_params, dict)):
        if config_params:
            # A value was given but is not a dict (e.g. it failed to parse).
            logger.warning(
                "Ignoring `config_params` because it is not a dict: %r. Using defaults instead.",
                config_params,
            )
        config_params = {
            "mask_time_length": 10,
            "mask_time_prob": 0.05,
            "diversity_loss_weight": 0.1,
            "num_negatives": 100,
            "do_stable_layer_norm": True,
            "feat_extract_norm": "layer",
            "vocab_size": 40,
        }

    logger.info(f"Setting up configuration {name_or_path} with extra params {config_args.config_params}")
    if name_or_path is None:
        # No checkpoint to start from: build the config from scratch.
        config = Wav2Vec2Config(**config_params)
    else:
        config = Wav2Vec2Config.from_pretrained(name_or_path, **config_params)

    feature_extractor_params = config_args.feature_extractor_params
    if not (feature_extractor_params and isinstance(feature_extractor_params, dict)):
        if feature_extractor_params:
            logger.warning(
                "Ignoring `feature_extractor_params` because it is not a dict: %r. Using defaults instead.",
                feature_extractor_params,
            )
        feature_extractor_params = {"return_attention_mask": True}

    logger.info(f"Setting up feature_extractor {name_or_path} with extra params "
                f"{config_args.feature_extractor_params}")
    if name_or_path is None:
        # No checkpoint to start from: build the feature extractor from scratch.
        feature_extractor = Wav2Vec2FeatureExtractor(**feature_extractor_params)
    else:
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            name_or_path,
            **feature_extractor_params
        )

    # Log only after each save has returned, so the message is never emitted
    # for a file that failed to be written.
    config.save_pretrained(config_args.output_dir)
    logger.info(f"Your `config` saved here {config_args.output_dir}/config.json")

    feature_extractor.save_pretrained(config_args.output_dir)
    logger.info(f"Your `feature_extractor` saved here {config_args.output_dir}/preprocessor_config.json")


# Run the setup only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()