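#!/usr/bin/env python3
# wav2vec2_tiny_random / create_model_files.py
# Uploaded by patrickvonplaten (commit 919b2aa).
#
# Builds a tiny, randomly initialised Wav2Vec2Model together with a matching
# Wav2Vec2Processor (CTC tokenizer + feature extractor) and saves both to the
# ./wav2vec2_tiny_random directory.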
from transformers import Wav2Vec2Processor, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Config, Wav2Vec2Model
import json
# Tiny Wav2Vec2 hyperparameters.
#
# They are passed to the constructor instead of being assigned onto a default
# Wav2Vec2Config() afterwards. This lets __init__ derive dependent fields from
# them:
#   - num_feat_extract_layers is computed as len(conv_dim), so it no longer
#     has to be kept in sync by hand;
#   - the conv_dim / conv_kernel / conv_stride lengths are validated against
#     each other.
# Mutating a default config skips both steps and can leave derived values
# stale. For example, output_hidden_size would stay at the default hidden size
# in recent transformers releases. NOTE(review): confirm for the pinned
# transformers version.
conf = Wav2Vec2Config(
    # Two-layer convolutional feature encoder (one entry per layer).
    conv_dim=[64, 64],
    conv_kernel=[40, 40],
    conv_stride=[30, 30],
    # Transformer encoder: 2 layers, 64-dim hidden state, 2 attention heads.
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=128,
    # Convolutional positional-embedding settings.
    num_conv_pos_embeddings=64,
    num_conv_pos_embedding_groups=4,
    # Kept equal to the number of entries in the 12-token vocab written to
    # vocab.json, presumably so the checkpoint can also back a CTC head.
    vocab_size=12,
)
# Weights are randomly initialised; no pretrained checkpoint is loaded.
model = Wav2Vec2Model(conf)
# Character-level tokenizer vocabulary mapping token -> id.
# The ids are contiguous 0..11 (12 entries), matching the vocab_size of 12 set
# on the model config. "<s>", "</s>", "<unk>", "<pad>" and "|" are presumably
# the special / word-delimiter tokens expected by Wav2Vec2CTCTokenizer's
# defaults. TODO confirm.
vocab = {
    "a": 0,
    "b": 1,
    "c": 2,
    "d": 3,
    "e": 4,
    "f": 5,
    "g": 6,
    "<s>": 7,
    "</s>": 8,
    "<unk>": 9,
    "<pad>": 10,
    "|": 11,
}
# Write the vocab to the working directory so the tokenizer can load it from
# "./vocab.json".
# The encoding is explicit because ensure_ascii=False lets json emit raw
# non-ASCII characters, which the platform default encoding may not support.
with open("vocab.json", "w", encoding="utf-8") as f:
    json.dump(vocab, f, ensure_ascii=False)
# Assemble the processor from two parts:
#   - a CTC tokenizer reading the vocab file in the working directory;
#   - a feature extractor left at its library defaults.
tok = Wav2Vec2CTCTokenizer(vocab_file="./vocab.json")
extract = Wav2Vec2FeatureExtractor()
processor = Wav2Vec2Processor(feature_extractor=extract, tokenizer=tok)

# Save the processor files first, then the model, into the same directory.
save_dir = "wav2vec2_tiny_random"
for artifact in (processor, model):
    artifact.save_pretrained(save_dir)