"""
tf_hub.py
Find text embeddings using pre-trained TensorFlow Hub models
"""
import os
import pickle
import numpy as np
from arxiv_public_data.config import DIR_OUTPUT, LOGGER
from arxiv_public_data.embeddings.util import batch_fulltext
logger = LOGGER.getChild('embds')
try:
import tensorflow as tf
import tensorflow_hub as hub
import sentencepiece as spm
except ImportError as e:
    logger.warning("This module requires 'tensorflow', 'tensorflow-hub', and "
                   "'sentencepiece'\n"
                   'Please install these modules to use tf_hub.py')
UNIV_SENTENCE_ENCODER_URL = ('https://tfhub.dev/google/'
'universal-sentence-encoder/2')
ELMO_URL = "https://tfhub.dev/google/elmo/2"
ELMO_KWARGS = dict(signature='default', as_dict=True)
ELMO_MODULE_KWARGS = dict(trainable=True)
ELMO_DICTKEY = 'default'
DIR_EMBEDDING = os.path.join(DIR_OUTPUT, 'embeddings')
if not os.path.exists(DIR_EMBEDDING):
os.mkdir(DIR_EMBEDDING)
def elmo_strings(batches, filename, batchsize=32):
"""
Compute and save vector embeddings of lists of strings in batches
Parameters
----------
    batches : iterable of batches (lists or arrays) of strings to be embedded
filename : str
filename to store embeddings
(optional)
batchsize : int
size of batches
"""
g = tf.Graph()
with g.as_default():
module = hub.Module(ELMO_URL, **ELMO_MODULE_KWARGS)
text_input = tf.placeholder(dtype=tf.string, shape=[None])
embeddings = module(text_input, **ELMO_KWARGS)
init_op = tf.group([tf.global_variables_initializer(),
tf.tables_initializer()])
g.finalize()
with tf.Session(graph=g) as sess:
sess.run(init_op)
for i, batch in enumerate(batches):
# grab mean-pooling of contextualized word reps
logger.info("Computing/saving batch {}".format(i))
with open(filename, 'ab') as fout:
pickle.dump(sess.run(
embeddings, feed_dict={text_input: batch}
)[ELMO_DICTKEY], fout)
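
# A minimal usage sketch for `elmo_strings` (the variable `abstracts` and the
# output filename are illustrative, not part of the pipeline): split the
# strings into batches and let the encoder append one pickled array per batch.
#
#   abstracts = ["first abstract ...", "second abstract ..."]
#   batches = np.array_split(np.array(abstracts, dtype='object'),
#                            max(1, len(abstracts) // 32))
#   elmo_strings(batches, os.path.join(DIR_EMBEDDING, 'elmo-abstracts.pkl'))
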
UNIV_SENTENCE_LITE = "https://tfhub.dev/google/universal-sentence-encoder-lite/2"
def get_sentence_piece_model():
with tf.Session() as sess:
module = hub.Module(UNIV_SENTENCE_LITE)
return sess.run(module(signature="spm_path"))
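
# `get_sentence_piece_model` runs the module's 'spm_path' signature, which
# returns the filesystem path of the SentencePiece model shipped with the Lite
# encoder; that path is loaded with spm.SentencePieceProcessor() below.
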
def process_to_IDs_in_sparse_format(sp, sentences):
"""
    A utility method that processes sentences with the SentencePiece
    processor 'sp' and returns the results in a tf.SparseTensor-like format:
(values, indices, dense_shape)
"""
ids = [sp.EncodeAsIds(x) for x in sentences]
max_len = max(len(x) for x in ids)
    dense_shape = (len(ids), max_len)
    values = [item for sublist in ids for item in sublist]
    indices = [[row, col] for row in range(len(ids))
               for col in range(len(ids[row]))]
return (values, indices, dense_shape)
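
# For example, two sentences whose token-id lists have lengths 3 and 2 yield
# values of length 5, indices [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], and
# dense_shape (2, 3): the ragged id lists flattened into the COO layout that
# tf.sparse_placeholder expects.
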
def universal_sentence_encoder_lite(batches, filename, spm_path, batchsize=32):
"""
Compute and save vector embeddings of lists of strings in batches
Parameters
----------
    batches : iterable of batches (lists or arrays) of strings to be embedded
filename : str
filename to store embeddings
spm_path : str
path to sentencepiece model from `get_sentence_piece_model`
(optional)
batchsize : int
size of batches
"""
sp = spm.SentencePieceProcessor()
sp.Load(spm_path)
g = tf.Graph()
with g.as_default():
module = hub.Module(UNIV_SENTENCE_LITE)
input_placeholder = tf.sparse_placeholder(
tf.int64, shape=(None, None)
)
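        # ragged batches of SentencePiece token ids are fed in the
        # (values, indices, dense_shape) form produced by
        # process_to_IDs_in_sparse_format above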
embeddings = module(
inputs=dict(
values=input_placeholder.values, indices=input_placeholder.indices,
dense_shape=input_placeholder.dense_shape
)
)
init_op = tf.group([tf.global_variables_initializer(),
tf.tables_initializer()])
g.finalize()
with tf.Session(graph=g) as sess:
sess.run(init_op)
for i, batch in enumerate(batches):
values, indices, dense_shape = process_to_IDs_in_sparse_format(sp, batch)
logger.info("Computing/saving batch {}".format(i))
emb = sess.run(
embeddings,
feed_dict={
input_placeholder.values: values,
input_placeholder.indices: indices,
input_placeholder.dense_shape: dense_shape
}
)
with open(filename, 'ab') as fout:
pickle.dump(emb, fout)
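
# A minimal usage sketch for the Lite encoder (the output filename is
# illustrative): fetch the SentencePiece model path once and pass it through
# as the `spm_path` argument.
#
#   spm_path = get_sentence_piece_model()
#   universal_sentence_encoder_lite(
#       batches, os.path.join(DIR_EMBEDDING, 'use-lite-abstracts.pkl'), spm_path
#   )
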
def create_save_embeddings(batches, filename, encoder, headers=[], encoder_args=(),
encoder_kwargs={}, savedir=DIR_EMBEDDING):
"""
Create vector embeddings of strings and save them to filename
Parameters
----------
    batches : iterator of batches (lists or arrays) of strings
    filename : str
        embeddings will be saved to savedir/filename
        (by default DIR_OUTPUT/embeddings/filename)
    encoder : function(batches, savename, *args, **kwargs)
        encodes strings in batches into vectors and saves them to savename
    (optional)
    headers : list of objects to pickle into the embeddings file before
        the embeddings
    savedir : str
        directory for the embeddings file (default DIR_EMBEDDING)
Examples
--------
# For list of strings, create batched numpy array of objects
batches = np.array_split(
np.array(strings, dtype='object'), len(strings)//batchsize
)
headers = []
# For the fulltext which cannot fit in memory, use `util.batch_fulltext`
md_index, all_ids, batch_gen = batch_fulltext()
headers = [md_index, all_ids]
# Universal Sentence Encoder Lite:
spm_path = get_sentence_piece_model()
create_save_embeddings(batches, filename, universal_sentence_encoder_lite,
headers=headers, encoder_args=(spm_path,))
# ELMO:
    create_save_embeddings(batches, filename, elmo_strings, headers=headers)
"""
if not os.path.exists(savedir):
os.makedirs(savedir)
savename = os.path.join(savedir, filename)
with open(savename, 'ab') as fout:
for h in headers:
pickle.dump(h, fout)
logger.info("Saving embeddings to {}".format(savename))
encoder(batches, savename, *encoder_args,
**encoder_kwargs)
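
# The file written by `create_save_embeddings` is a stream of consecutive
# pickles: any headers first, then one embedding array per batch. A minimal
# reader sketch (the helper name is illustrative and not part of the original
# pipeline; assumes the whole file fits in memory):
def read_pickled_embeddings(savename):
    """Unpickle every object from `savename` until EOF and return them."""
    objects = []
    with open(savename, 'rb') as fin:
        while True:
            try:
                objects.append(pickle.load(fin))
            except EOFError:
                break
    return objects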