import numpy as np
import glob
import os
import pickle
import lmdb
import pyarrow
import fasttext
from loguru import logger
from scipy import linalg

class Vocab:
    PAD_token = 0
    SOS_token = 1
    EOS_token = 2
    UNK_token = 3

    def __init__(self, name, insert_default_tokens=True):
        self.name = name
        self.trimmed = False
        self.word_embedding_weights = None
        self.reset_dictionary(insert_default_tokens)

    def reset_dictionary(self, insert_default_tokens=True):
        self.word2index = {}
        self.word2count = {}
        if insert_default_tokens:
            self.index2word = {self.PAD_token: "<PAD>", self.SOS_token: "<SOS>",
                               self.EOS_token: "<EOS>", self.UNK_token: "<UNK>"}
        else:
            self.index2word = {self.UNK_token: "<UNK>"}
        self.n_words = len(self.index2word)  # count default tokens

    def index_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def add_vocab(self, other_vocab):
        for word, _ in other_vocab.word2count.items():
            self.index_word(word)

    # remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []
        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print(' word trimming, kept %s / %s = %.4f' % (
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # reinitialize dictionary
        self.reset_dictionary()
        for word in keep_words:
            self.index_word(word)

    def get_word_index(self, word):
        if word in self.word2index:
            return self.word2index[word]
        else:
            return self.UNK_token

    def load_word_vectors(self, pretrained_path, embedding_dim=300):
        print(" loading word vectors from '{}'...".format(pretrained_path))

        # initialize embeddings to random values for special words
        init_sd = 1 / np.sqrt(embedding_dim)
        weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim])
        weights = weights.astype(np.float32)

        # read word vectors
        word_model = fasttext.load_model(pretrained_path)
        for word, id in self.word2index.items():
            vec = word_model.get_word_vector(word)
            weights[id] = vec

        self.word_embedding_weights = weights

    def __get_embedding_weight(self, pretrained_path, embedding_dim=300):
        """ function modified from http://ronny.rest/blog/post_2017_08_04_glove/ """
        print("Loading word embedding '{}'...".format(pretrained_path))
        cache_path = os.path.splitext(pretrained_path)[0] + '_cache.pkl'
        weights = None

        # use cached file if it exists
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                print(' using cached result from {}'.format(cache_path))
                weights = pickle.load(f)
                if weights.shape != (self.n_words, embedding_dim):
                    logger.warning(' failed to load word embedding weights. reinitializing...')
                    weights = None
        if weights is None:
            # initialize embeddings to random values for special and OOV words
            init_sd = 1 / np.sqrt(embedding_dim)
            weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim])
            weights = weights.astype(np.float32)

            with open(pretrained_path, encoding="utf-8", mode="r") as textFile:
                num_embedded_words = 0
                for line_raw in textFile:
                    # extract the word, and embeddings vector
                    line = line_raw.split()
                    try:
                        word, vector = (line[0], np.array(line[1:], dtype=np.float32))
                        # if word == 'love': # debugging
                        #     print(word, vector)

                        # if it is in our vocab, then update the corresponding weights
                        id = self.word2index.get(word, None)
                        if id is not None:
                            weights[id] = vector
                            num_embedded_words += 1
                    except ValueError:
                        print(' parsing error at {}...'.format(line_raw[:50]))
                        continue
                print(' {} / {} word vectors are found in the embedding'.format(num_embedded_words, len(self.word2index)))

            with open(cache_path, 'wb') as f:
                pickle.dump(weights, f)

        return weights
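
# Illustrative sketch (not part of the original file): minimal standalone use of the
# Vocab class. Word ids start right after the four default tokens, and unknown words
# fall back to UNK_token. The helper name below is hypothetical.
def _vocab_usage_example():
    vocab = Vocab("demo")
    for token in "hello world hello".split():
        vocab.index_word(token)
    assert vocab.get_word_index("hello") == 4              # first word after <PAD>/<SOS>/<EOS>/<UNK>
    assert vocab.word2count["hello"] == 2                  # repeats only increase the count
    assert vocab.get_word_index("missing") == Vocab.UNK_token
    return vocab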

def build_vocab(name, data_path, cache_path, word_vec_path=None, feat_dim=None):
    print(' building a language model...')
    if not os.path.exists(cache_path):
        lang_model = Vocab(name)
        print(' indexing words from {}'.format(data_path))
        index_words_from_textgrid(lang_model, data_path)
        if word_vec_path is not None:
            lang_model.load_word_vectors(word_vec_path, feat_dim)
        with open(cache_path, 'wb') as f:
            pickle.dump(lang_model, f)
    else:
        print(' loaded from {}'.format(cache_path))
        with open(cache_path, 'rb') as f:
            lang_model = pickle.load(f)
        if word_vec_path is None:
            lang_model.word_embedding_weights = None
        elif lang_model.word_embedding_weights.shape[0] != lang_model.n_words:
            logger.warning(' failed to load word embedding weights. check this')
            assert False
    return lang_model

def index_words(lang_model, data_path):
    # index words from a plain text file
    with open(data_path, "r") as f:
        for line in f.readlines():
            line = line.replace(",", " ")
            line = line.replace(".", " ")
            line = line.replace("?", " ")
            line = line.replace("!", " ")
            for word in line.split():
                lang_model.index_word(word)
    print(' indexed %d words' % lang_model.n_words)

def index_words_from_textgrid(lang_model, data_path):
    import textgrid as tg
    from tqdm import tqdm
    #trainvaltest = os.listdir(data_path)
    # for loadtype in trainvaltest:
    #     if "." in loadtype: continue  # ignore .ipynb_checkpoints
    texts = os.listdir(data_path + "/textgrid/")
    #print(texts)
    for textfile in tqdm(texts):
        tgrid = tg.TextGrid.fromFile(data_path + "/textgrid/" + textfile)
        for word in tgrid[0]:
            word_n, word_s, word_e = word.mark, word.minTime, word.maxTime
            word_n = word_n.replace(",", " ")
            word_n = word_n.replace(".", " ")
            word_n = word_n.replace("?", " ")
            word_n = word_n.replace("!", " ")
            #print(word_n)
            lang_model.index_word(word_n)
    print(' indexed %d words' % lang_model.n_words)
    print(lang_model.word2index, lang_model.word2count)

if __name__ == "__main__":
    # 11195 for all, 5793 for 4 speakers
    # build_vocab("beat_english_15_141", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/vocab.pkl", "/home/ma-user/work/datasets/cc.en.300.bin", 300)
    build_vocab("beat_chinese_v1.0.0", "/data/datasets/beat_chinese_v1.0.0/", "/data/datasets/beat_chinese_v1.0.0/weights/vocab.pkl", "/home/ma-user/work/cc.zh.300.bin", 300)