""" | |
tf_hub.py | |
Find text embeddings using pre-trained TensorFlow Hub models | |
""" | |
import os | |
import pickle | |
import numpy as np | |
from arxiv_public_data.config import DIR_OUTPUT, LOGGER | |
from arxiv_public_data.embeddings.util import batch_fulltext | |
logger = LOGGER.getChild('embds') | |
try:
    import tensorflow as tf
    import tensorflow_hub as hub
    import sentencepiece as spm
except ImportError as e:
    logger.warning("This module requires 'tensorflow', 'tensorflow-hub', and "
                   "'sentencepiece'\n"
                   'Please install these modules to use tf_hub.py')

UNIV_SENTENCE_ENCODER_URL = ('https://tfhub.dev/google/'
                             'universal-sentence-encoder/2')

ELMO_URL = "https://tfhub.dev/google/elmo/2"
ELMO_KWARGS = dict(signature='default', as_dict=True)
ELMO_MODULE_KWARGS = dict(trainable=True)
ELMO_DICTKEY = 'default'

DIR_EMBEDDING = os.path.join(DIR_OUTPUT, 'embeddings')
if not os.path.exists(DIR_EMBEDDING):
    os.mkdir(DIR_EMBEDDING)

def elmo_strings(batches, filename, batchsize=32):
    """
    Compute and save vector embeddings of batches of strings

    Parameters
    ----------
    batches : iterable of lists of str
        batches of strings to be embedded
    filename : str
        filename to store embeddings
    (optional)
    batchsize : int
        nominal size of batches (batching itself is determined by `batches`)
    """
    g = tf.Graph()
    with g.as_default():
        module = hub.Module(ELMO_URL, **ELMO_MODULE_KWARGS)
        text_input = tf.placeholder(dtype=tf.string, shape=[None])
        embeddings = module(text_input, **ELMO_KWARGS)
        init_op = tf.group([tf.global_variables_initializer(),
                            tf.tables_initializer()])
    g.finalize()

    with tf.Session(graph=g) as sess:
        sess.run(init_op)
        for i, batch in enumerate(batches):
            # grab mean-pooling of contextualized word reps
            logger.info("Computing/saving batch {}".format(i))
            with open(filename, 'ab') as fout:
                pickle.dump(sess.run(
                    embeddings, feed_dict={text_input: batch}
                )[ELMO_DICTKEY], fout)
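
# A minimal reader sketch (not part of the original module): each call to
# `pickle.dump` above appends one record to the file, so reading it back
# takes repeated `pickle.load` calls until EOF. The helper name
# `iter_embeddings` is an assumption for illustration.
def iter_embeddings(filename):
    """Yield successive pickled records (headers first, then batch arrays)"""
    with open(filename, 'rb') as fin:
        while True:
            try:
                yield pickle.load(fin)
            except EOFError:
                return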

UNIV_SENTENCE_LITE = "https://tfhub.dev/google/universal-sentence-encoder-lite/2"

def get_sentence_piece_model():
    """Return the path of the SentencePiece model bundled with the module"""
    with tf.Session() as sess:
        module = hub.Module(UNIV_SENTENCE_LITE)
        return sess.run(module(signature="spm_path"))

def process_to_IDs_in_sparse_format(sp, sentences):
    """
    A utility method that processes sentences with the SentencePiece
    processor `sp` and returns the results in a tf.SparseTensor-like
    format: (values, indices, dense_shape)
    """
    ids = [sp.EncodeAsIds(x) for x in sentences]
    max_len = max(len(x) for x in ids)
    dense_shape = (len(ids), max_len)
    values = [item for sublist in ids for item in sublist]
    indices = [[row, col] for row in range(len(ids))
               for col in range(len(ids[row]))]
    return (values, indices, dense_shape)
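
# Illustrative example of the sparse format above (the token IDs are made up):
# if sp.EncodeAsIds gives ids = [[9, 41, 27], [9, 8]], then
#   values      = [9, 41, 27, 9, 8]
#   indices     = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]
#   dense_shape = (2, 3)
# which matches the triple fed to the tf.sparse_placeholder below.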

def universal_sentence_encoder_lite(batches, filename, spm_path, batchsize=32):
    """
    Compute and save vector embeddings of batches of strings

    Parameters
    ----------
    batches : iterable of lists of str
        batches of strings to be embedded
    filename : str
        filename to store embeddings
    spm_path : str
        path to SentencePiece model from `get_sentence_piece_model`
    (optional)
    batchsize : int
        nominal size of batches (batching itself is determined by `batches`)
    """
    sp = spm.SentencePieceProcessor()
    sp.Load(spm_path)

    g = tf.Graph()
    with g.as_default():
        module = hub.Module(UNIV_SENTENCE_LITE)
        input_placeholder = tf.sparse_placeholder(
            tf.int64, shape=(None, None)
        )
        embeddings = module(
            inputs=dict(
                values=input_placeholder.values,
                indices=input_placeholder.indices,
                dense_shape=input_placeholder.dense_shape
            )
        )
        init_op = tf.group([tf.global_variables_initializer(),
                            tf.tables_initializer()])
    g.finalize()

    with tf.Session(graph=g) as sess:
        sess.run(init_op)
        for i, batch in enumerate(batches):
            values, indices, dense_shape = process_to_IDs_in_sparse_format(
                sp, batch
            )
            logger.info("Computing/saving batch {}".format(i))
            emb = sess.run(
                embeddings,
                feed_dict={
                    input_placeholder.values: values,
                    input_placeholder.indices: indices,
                    input_placeholder.dense_shape: dense_shape
                }
            )
            with open(filename, 'ab') as fout:
                pickle.dump(emb, fout)
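
# Usage sketch (toy inputs; the filename here is an assumption for
# illustration):
#
#   spm_path = get_sentence_piece_model()
#   batches = [["first abstract", "second abstract"], ["third abstract"]]
#   universal_sentence_encoder_lite(batches, 'use-lite.pkl', spm_path)
#
# Each batch is appended to 'use-lite.pkl' as a separate pickle record.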

def create_save_embeddings(batches, filename, encoder, headers=[],
                           encoder_args=(), encoder_kwargs={},
                           savedir=DIR_EMBEDDING):
    """
    Create vector embeddings of strings and save them to filename

    Parameters
    ----------
    batches : iterable of lists of str
    filename : str
        embeddings will be saved in `savedir/filename`
        (DIR_EMBEDDING/filename by default)
    encoder : function(batches, savename, *args, **kwargs)
        encodes strings in batches into vectors and saves them
    (optional)
    headers : list of objects to pickle into the embeddings file first

    Examples
    --------
    # For a list of strings, create a batched numpy array of objects
    batches = np.array_split(
        np.array(strings, dtype='object'), len(strings)//batchsize
    )
    headers = []

    # For the fulltext, which cannot fit in memory, use `util.batch_fulltext`
    md_index, all_ids, batch_gen = batch_fulltext()
    headers = [md_index, all_ids]

    # Universal Sentence Encoder Lite:
    spm_path = get_sentence_piece_model()
    create_save_embeddings(batches, filename, universal_sentence_encoder_lite,
                           headers=headers, encoder_args=(spm_path,))

    # ELMo:
    create_save_embeddings(batches, filename, elmo_strings, headers=headers)
    """
    if not os.path.exists(savedir):
        os.makedirs(savedir)
    savename = os.path.join(savedir, filename)

    with open(savename, 'ab') as fout:
        for h in headers:
            pickle.dump(h, fout)

    logger.info("Saving embeddings to {}".format(savename))
    encoder(batches, savename, *encoder_args, **encoder_kwargs)
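
if __name__ == '__main__':
    # Illustrative smoke test, not part of the original pipeline: embed a few
    # toy strings with ELMo. The strings, filename, and batch split here are
    # assumptions for demonstration only.
    strings = ['convex optimization', 'quantum field theory',
               'graph embeddings of citation networks']
    batches = np.array_split(np.array(strings, dtype='object'), 2)
    create_save_embeddings(batches, 'elmo-demo.pkl', elmo_strings)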