# EMAGE/models/utils/build_vocab.py
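"""Vocabulary utilities for the EMAGE/BEAT pipeline.

Builds a word-to-index Vocab from TextGrid transcripts, optionally attaches
pretrained fastText word vectors, and caches the result as a pickle.
"""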
import numpy as np
import os
import pickle

import fasttext
from loguru import logger


class Vocab:
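    """Word <-> index mapping with optional special tokens and an embedding matrix."""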
PAD_token = 0
SOS_token = 1
EOS_token = 2
UNK_token = 3
def __init__(self, name, insert_default_tokens=True):
self.name = name
self.trimmed = False
self.word_embedding_weights = None
self.reset_dictionary(insert_default_tokens)
def reset_dictionary(self, insert_default_tokens=True):
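        """Reset all mappings, optionally seeding the PAD/SOS/EOS/UNK tokens."""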
self.word2index = {}
self.word2count = {}
if insert_default_tokens:
self.index2word = {self.PAD_token: "<PAD>", self.SOS_token: "<SOS>",
self.EOS_token: "<EOS>", self.UNK_token: "<UNK>"}
else:
self.index2word = {self.UNK_token: "<UNK>"}
self.n_words = len(self.index2word) # count default tokens
def index_word(self, word):
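        """Assign the next free index to an unseen word; otherwise bump its count."""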
if word not in self.word2index:
self.word2index[word] = self.n_words
self.word2count[word] = 1
self.index2word[self.n_words] = word
self.n_words += 1
else:
self.word2count[word] += 1
def add_vocab(self, other_vocab):
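        """Merge every word counted in another Vocab into this one."""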
for word, _ in other_vocab.word2count.items():
self.index_word(word)
    def trim(self, min_count):
        """Remove words below min_count occurrences and rebuild the index."""
if self.trimmed:
return
self.trimmed = True
keep_words = []
for k, v in self.word2count.items():
if v >= min_count:
keep_words.append(k)
print(' word trimming, kept %s / %s = %.4f' % (
len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
))
# reinitialize dictionary
self.reset_dictionary()
for word in keep_words:
self.index_word(word)
def get_word_index(self, word):
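        """Return the index of a word, falling back to UNK for unknown words."""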
if word in self.word2index:
return self.word2index[word]
else:
return self.UNK_token
def load_word_vectors(self, pretrained_path, embedding_dim=300):
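        """Fill the embedding matrix from a pretrained fastText .bin model.

        Rows start as small random values, so the special tokens (which are
        not in word2index) keep a random embedding; every indexed word is
        overwritten with its fastText vector.
        """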
print(" loading word vectors from '{}'...".format(pretrained_path))
# initialize embeddings to random values for special words
init_sd = 1 / np.sqrt(embedding_dim)
weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim])
weights = weights.astype(np.float32)
# read word vectors
word_model = fasttext.load_model(pretrained_path)
        for word, idx in self.word2index.items():
            vec = word_model.get_word_vector(word)
            weights[idx] = vec
self.word_embedding_weights = weights
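# Minimal usage sketch (illustrative only):
#   vocab = Vocab("demo")
#   vocab.index_word("hello")
#   vocab.get_word_index("hello")   # -> 4, the first index after the special tokens
#   vocab.get_word_index("unseen")  # -> Vocab.UNK_token (3)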
def build_vocab(name, data_path, cache_path, word_vec_path=None, feat_dim=None):
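    """Build (or re-load) a Vocab for the given dataset.

    When word_vec_path is given, the freshly indexed Vocab gets fastText
    embeddings; otherwise the previously cached pickle at cache_path is
    loaded instead. The resulting model is (re-)pickled to cache_path.
    """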
print(' building a language model...')
lang_model = Vocab(name)
print(' indexing words from {}'.format(data_path))
index_words_from_textgrid(lang_model, data_path)
if word_vec_path is not None:
lang_model.load_word_vectors(word_vec_path, feat_dim)
else:
        print(' loading from {}'.format(cache_path))
with open(cache_path, 'rb') as f:
lang_model = pickle.load(f)
if word_vec_path is None:
lang_model.word_embedding_weights = None
elif lang_model.word_embedding_weights.shape[0] != lang_model.n_words:
        logger.warning(' failed to load word embedding weights. check this')
assert False
with open(cache_path, 'wb') as f:
pickle.dump(lang_model, f)
return lang_model
def index_words(lang_model, data_path):
    """Index words from a plain-text file, treating , . ? ! as separators."""
    with open(data_path, "r") as f:
        for line in f.readlines():
            # treat simple punctuation as word separators
            for ch in ",.?!":
                line = line.replace(ch, " ")
            for word in line.split():
                lang_model.index_word(word)
print(' indexed %d words' % lang_model.n_words)
def index_words_from_textgrid(lang_model, data_path):
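    """Index every word mark from the first tier of each TextGrid found
    under data_path/<split>/text/ for all split directories in data_path."""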
import textgrid as tg
    splits = os.listdir(data_path)
    for loadtype in splits:
        if "." in loadtype:
            continue  # skip stray files such as .ipynb_checkpoints
        texts = os.listdir(os.path.join(data_path, loadtype, "text"))
        for textfile in texts:
            tgrid = tg.TextGrid.fromFile(os.path.join(data_path, loadtype, "text", textfile))
            for word in tgrid[0]:
                # start/end times are available but unused here
                word_n, word_s, word_e = word.mark, word.minTime, word.maxTime
                # treat simple punctuation as word separators
                for ch in ",.?!":
                    word_n = word_n.replace(ch, " ")
                lang_model.index_word(word_n)
print(' indexed %d words' % lang_model.n_words)
if __name__ == "__main__":
    # vocab sizes: 11195 words for all speakers, 5793 for 4 speakers
    # NOTE: the paths below are machine-specific; point them at your own
    # beat_cache directory and fastText cc.en.300.bin model.
    build_vocab(
        "beat_english_15_141",
        "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/",
        "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/vocab.pkl",
        "/home/ma-user/work/datasets/cc.en.300.bin",
        300,
    )