#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""DrQA Document Reader model"""

import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import logging
import copy

from .config import override_model_args
from .rnn_reader import RnnDocReader

logger = logging.getLogger(__name__)


class DocReader(object):
    """High level model that handles initializing the underlying network
    architecture, saving, updating examples, and predicting examples.
    """

    # --------------------------------------------------------------------------
    # Initialization
    # --------------------------------------------------------------------------
    def __init__(self, args, word_dict, feature_dict,
                 state_dict=None, normalize=True):
        # Book-keeping.
        self.args = args
        self.word_dict = word_dict
        self.args.vocab_size = len(word_dict)
        self.feature_dict = feature_dict
        self.args.num_features = len(feature_dict)
        self.updates = 0
        self.use_cuda = False
        self.parallel = False

        # Build the network. If normalize is False, scores are not normalized
        # to 0-1 per paragraph (no softmax).
        if args.model_type == 'rnn':
            self.network = RnnDocReader(args, normalize)
        else:
            raise RuntimeError('Unsupported model: %s' % args.model_type)

        # Load saved state
        if state_dict:
            # Load buffer separately
            if 'fixed_embedding' in state_dict:
                fixed_embedding = state_dict.pop('fixed_embedding')
                self.network.load_state_dict(state_dict)
                self.network.register_buffer('fixed_embedding', fixed_embedding)
            else:
                self.network.load_state_dict(state_dict)
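
    # A minimal construction sketch (hypothetical `args`, dictionaries, and
    # data; illustrative only, not part of the original API surface):
    #
    #   >>> reader = DocReader(args, word_dict, feature_dict, normalize=True)
    #   >>> reader.init_optimizer()
    #
    # `args` is expected to carry at least `model_type`, `embedding_dim`,
    # `optimizer`, and the embedding tuning/fixing flags used below.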

    def expand_dictionary(self, words):
        """Add words to the DocReader dictionary if they do not exist. The
        underlying embedding matrix is also expanded (with random embeddings).

        Args:
            words: iterable of tokens to add to the dictionary.
        Output:
            added: set of tokens that were added.
        """
        to_add = {self.word_dict.normalize(w) for w in words
                  if w not in self.word_dict}

        # Add words to dictionary and expand embedding layer
        if len(to_add) > 0:
            logger.info('Adding %d new words to dictionary...' % len(to_add))
            for w in to_add:
                self.word_dict.add(w)
            self.args.vocab_size = len(self.word_dict)
            logger.info('New vocab size: %d' % len(self.word_dict))

            old_embedding = self.network.embedding.weight.data
            self.network.embedding = torch.nn.Embedding(self.args.vocab_size,
                                                        self.args.embedding_dim,
                                                        padding_idx=0)
            new_embedding = self.network.embedding.weight.data
            new_embedding[:old_embedding.size(0)] = old_embedding

        # Return added words
        return to_add

    def load_embeddings(self, words, embedding_file):
        """Load pretrained embeddings for a given list of words, if they exist.

        Args:
            words: iterable of tokens. Only those that are indexed in the
              dictionary are kept.
            embedding_file: path to text file of embeddings, space separated.
        """
        words = {w for w in words if w in self.word_dict}
        logger.info('Loading pre-trained embeddings for %d words from %s' %
                    (len(words), embedding_file))
        embedding = self.network.embedding.weight.data

        # When normalized, some words are duplicated. (Average the embeddings).
        vec_counts = {}
        with open(embedding_file) as f:
            # Skip first line if of form count/dim.
            line = f.readline().rstrip().split(' ')
            if len(line) != 2:
                f.seek(0)
            for line in f:
                parsed = line.rstrip().split(' ')
                assert(len(parsed) == embedding.size(1) + 1)
                w = self.word_dict.normalize(parsed[0])
                if w in words:
                    vec = torch.Tensor([float(i) for i in parsed[1:]])
                    if w not in vec_counts:
                        vec_counts[w] = 1
                        embedding[self.word_dict[w]].copy_(vec)
                    else:
                        logging.warning(
                            'WARN: Duplicate embedding found for %s' % w
                        )
                        vec_counts[w] = vec_counts[w] + 1
                        embedding[self.word_dict[w]].add_(vec)

        for w, c in vec_counts.items():
            embedding[self.word_dict[w]].div_(c)

        logger.info('Loaded %d embeddings (%.2f%%)' %
                    (len(vec_counts), 100 * len(vec_counts) / len(words)))
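
    # The embedding file is expected to be plain text with one token per line,
    # the token followed by its vector components, all space separated, e.g.
    # (made-up illustrative values):
    #
    #   the 0.0465 0.2131 -0.0074 ...
    #   of 0.7085 0.5708 -0.4716 ...
    #
    # An optional first line of the form "<count> <dim>" (word2vec style) is
    # detected and skipped by the length check above.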

    def tune_embeddings(self, words):
        """Unfix the embeddings of a list of words. This is only relevant if
        only some of the embeddings are being tuned (tune_partial = N).

        Shuffles the N specified words to the front of the dictionary, and
        saves the original vectors of the remaining words (indices N + 1 to
        vocab) in a fixed buffer.

        Args:
            words: iterable of tokens contained in dictionary.
        """
        words = {w for w in words if w in self.word_dict}

        if len(words) == 0:
            logger.warning('Tried to tune embeddings, but no words given!')
            return

        if len(words) == len(self.word_dict):
            logger.warning('Tuning ALL embeddings in dictionary')
            return

        # Shuffle words and vectors
        embedding = self.network.embedding.weight.data
        for idx, swap_word in enumerate(words, self.word_dict.START):
            # Get current word + embedding for this index
            curr_word = self.word_dict[idx]
            curr_emb = embedding[idx].clone()
            old_idx = self.word_dict[swap_word]

            # Swap embeddings + dictionary indices
            embedding[idx].copy_(embedding[old_idx])
            embedding[old_idx].copy_(curr_emb)
            self.word_dict[swap_word] = idx
            self.word_dict[idx] = swap_word
            self.word_dict[curr_word] = old_idx
            self.word_dict[old_idx] = curr_word

        # Save the original, fixed embeddings
        self.network.register_buffer(
            'fixed_embedding', embedding[idx + 1:].clone()
        )
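
    # Reading of the above as a sketch: the N tuned words are swapped to the
    # front of the dictionary (right after the special tokens) and stay
    # trainable, while every later vector is copied into the `fixed_embedding`
    # buffer. After each optimizer step, reset_parameters() below writes that
    # buffer back, so only the first N word vectors actually move.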

    def init_optimizer(self, state_dict=None):
        """Initialize an optimizer for the free parameters of the network.

        Args:
            state_dict: saved optimizer state (not applied by this
              implementation).
        """
        if self.args.fix_embeddings:
            for p in self.network.embedding.parameters():
                p.requires_grad = False
        parameters = [p for p in self.network.parameters() if p.requires_grad]
        if self.args.optimizer == 'sgd':
            self.optimizer = optim.SGD(parameters, self.args.learning_rate,
                                       momentum=self.args.momentum,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'adamax':
            self.optimizer = optim.Adamax(parameters,
                                          weight_decay=self.args.weight_decay)
        else:
            raise RuntimeError('Unsupported optimizer: %s' %
                               self.args.optimizer)

    # --------------------------------------------------------------------------
    # Learning
    # --------------------------------------------------------------------------

    def update(self, ex):
        """Forward a batch of examples; step the optimizer to update weights."""
        if not self.optimizer:
            raise RuntimeError('No optimizer set.')

        # Train mode
        self.network.train()

        # Transfer to GPU
        if self.use_cuda:
            inputs = [e if e is None else e.cuda(non_blocking=True)
                      for e in ex[:5]]
            target_s = ex[5].cuda(non_blocking=True)
            target_e = ex[6].cuda(non_blocking=True)
        else:
            inputs = [e if e is None else e for e in ex[:5]]
            target_s = ex[5]
            target_e = ex[6]

        # Run forward
        score_s, score_e = self.network(*inputs)

        # Compute loss and accuracies
        loss = F.nll_loss(score_s, target_s) + F.nll_loss(score_e, target_e)

        # Clear gradients and run backward
        self.optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                       self.args.grad_clipping)

        # Update parameters
        self.optimizer.step()
        self.updates += 1

        # Reset any partially fixed parameters (e.g. rare words)
        self.reset_parameters()

        return loss.item(), ex[0].size(0)
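
    # A minimal training-loop sketch (hypothetical `train_loader` yielding
    # batches in the tuple layout consumed above; illustrative only):
    #
    #   >>> for epoch in range(num_epochs):
    #   ...     for ex in train_loader:
    #   ...         loss, batch_size = reader.update(ex)
    #
    # Each batch `ex` is expected to hold the five network inputs in ex[:5]
    # and the gold start/end indices in ex[5] and ex[6].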

    def reset_parameters(self):
        """Reset any partially fixed parameters to original states."""
        # Reset fixed embeddings to original value
        if self.args.tune_partial > 0:
            if self.parallel:
                embedding = self.network.module.embedding.weight.data
                fixed_embedding = self.network.module.fixed_embedding
            else:
                embedding = self.network.embedding.weight.data
                fixed_embedding = self.network.fixed_embedding

            # Embeddings to fix are the last indices
            offset = embedding.size(0) - fixed_embedding.size(0)
            if offset >= 0:
                embedding[offset:] = fixed_embedding

    # --------------------------------------------------------------------------
    # Prediction
    # --------------------------------------------------------------------------

    def predict(self, ex, candidates=None, top_n=1, async_pool=None):
        """Forward a batch of examples only to get predictions.

        Args:
            ex: the batch
            candidates: batch * variable length list of string answer options.
              The model will only consider exact spans contained in this list.
            top_n: Number of predictions to return per batch element.
            async_pool: If provided, non-gpu post-processing will be offloaded
              to this CPU process pool.
        Output:
            pred_s: batch * top_n predicted start indices
            pred_e: batch * top_n predicted end indices
            pred_score: batch * top_n prediction scores

        If async_pool is given, these will be AsyncResult handles.
        """
        # Eval mode
        self.network.eval()

        # Transfer to GPU
        if self.use_cuda:
            inputs = [e if e is None else e.cuda(non_blocking=True)
                      for e in ex[:5]]
        else:
            inputs = [e for e in ex[:5]]

        # Run forward
        with torch.no_grad():
            score_s, score_e = self.network(*inputs)

        # Decode predictions
        score_s = score_s.data.cpu()
        score_e = score_e.data.cpu()
        if candidates:
            args = (score_s, score_e, candidates, top_n, self.args.max_len)
            if async_pool:
                return async_pool.apply_async(self.decode_candidates, args)
            else:
                return self.decode_candidates(*args)
        else:
            args = (score_s, score_e, top_n, self.args.max_len)
            if async_pool:
                return async_pool.apply_async(self.decode, args)
            else:
                return self.decode(*args)
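
    # A minimal prediction sketch (hypothetical batch `ex` from a dev loader;
    # illustrative only):
    #
    #   >>> pred_s, pred_e, pred_score = reader.predict(ex, top_n=5)
    #   >>> pred_s[0]  # top-5 start indices for the first example in the batch
    #
    # With async_pool set, the same call instead returns a multiprocessing
    # AsyncResult; call .get() on it to obtain the tuple.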

    @staticmethod
    def decode(score_s, score_e, top_n=1, max_len=None):
        """Take argmax of constrained score_s * score_e.

        Args:
            score_s: independent start predictions
            score_e: independent end predictions
            top_n: number of top scored pairs to take
            max_len: max span length to consider
        """
        pred_s = []
        pred_e = []
        pred_score = []
        max_len = max_len or score_s.size(1)
        for i in range(score_s.size(0)):
            # Outer product of scores to get full p_s * p_e matrix
            scores = torch.ger(score_s[i], score_e[i])

            # Zero out negative length and over-length span scores
            scores.triu_().tril_(max_len - 1)

            # Take argmax or top n
            scores = scores.numpy()
            scores_flat = scores.flatten()
            if top_n == 1:
                idx_sort = [np.argmax(scores_flat)]
            elif len(scores_flat) < top_n:
                idx_sort = np.argsort(-scores_flat)
            else:
                idx = np.argpartition(-scores_flat, top_n)[0:top_n]
                idx_sort = idx[np.argsort(-scores_flat[idx])]
            s_idx, e_idx = np.unravel_index(idx_sort, scores.shape)
            pred_s.append(s_idx)
            pred_e.append(e_idx)
            pred_score.append(scores_flat[idx_sort])
        return pred_s, pred_e, pred_score
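
    # For intuition, a sketch with made-up numbers: for one example with
    # p_s = [0.7, 0.2, 0.1] and p_e = [0.1, 0.3, 0.6], torch.ger(p_s, p_e)
    # yields the 3x3 matrix of all p_s[i] * p_e[j]; triu_() zeroes entries
    # with j < i (end before start) and tril_(max_len - 1) zeroes entries
    # with j - i >= max_len, so the argmax is always a valid span of length
    # at most max_len.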

    @staticmethod
    def decode_candidates(score_s, score_e, candidates, top_n=1, max_len=None):
        """Take argmax of constrained score_s * score_e. Except only consider
        spans that are in the candidates list.
        """
        pred_s = []
        pred_e = []
        pred_score = []
        for i in range(score_s.size(0)):
            # Extract original tokens stored with candidates
            tokens = candidates[i]['input']
            cands = candidates[i]['cands']

            if not cands:
                # try getting from globals? (multiprocessing in pipeline mode)
                from ..pipeline.drqa import PROCESS_CANDS
                cands = PROCESS_CANDS
            if not cands:
                raise RuntimeError('No candidates given.')

            # Score all valid candidates found in text.
            # Brute force get all ngrams and compare against the candidate list.
            max_len = max_len or len(tokens)
            scores, s_idx, e_idx = [], [], []
            for s, e in tokens.ngrams(n=max_len, as_strings=False):
                span = tokens.slice(s, e).untokenize()
                if span in cands or span.lower() in cands:
                    # Match! Record its score.
                    scores.append(score_s[i][s] * score_e[i][e - 1])
                    s_idx.append(s)
                    e_idx.append(e - 1)

            if len(scores) == 0:
                # No candidates present
                pred_s.append([])
                pred_e.append([])
                pred_score.append([])
            else:
                # Rank found candidates
                scores = np.array(scores)
                s_idx = np.array(s_idx)
                e_idx = np.array(e_idx)
                idx_sort = np.argsort(-scores)[0:top_n]
                pred_s.append(s_idx[idx_sort])
                pred_e.append(e_idx[idx_sort])
                pred_score.append(scores[idx_sort])
        return pred_s, pred_e, pred_score

    # --------------------------------------------------------------------------
    # Saving and loading
    # --------------------------------------------------------------------------

    def save(self, filename):
        if self.parallel:
            network = self.network.module
        else:
            network = self.network
        state_dict = copy.copy(network.state_dict())
        if 'fixed_embedding' in state_dict:
            state_dict.pop('fixed_embedding')
        params = {
            'state_dict': state_dict,
            'word_dict': self.word_dict,
            'feature_dict': self.feature_dict,
            'args': self.args,
        }
        try:
            torch.save(params, filename)
        except BaseException:
            logger.warning('WARN: Saving failed... continuing anyway.')

    def checkpoint(self, filename, epoch):
        if self.parallel:
            network = self.network.module
        else:
            network = self.network
        params = {
            'state_dict': network.state_dict(),
            'word_dict': self.word_dict,
            'feature_dict': self.feature_dict,
            'args': self.args,
            'epoch': epoch,
            'optimizer': self.optimizer.state_dict(),
        }
        try:
            torch.save(params, filename)
        except BaseException:
            logger.warning('WARN: Saving failed... continuing anyway.')

    @staticmethod
    def load(filename, new_args=None, normalize=True):
        logger.info('Loading model %s' % filename)
        saved_params = torch.load(
            filename, map_location=lambda storage, loc: storage
        )
        word_dict = saved_params['word_dict']
        feature_dict = saved_params['feature_dict']
        state_dict = saved_params['state_dict']
        args = saved_params['args']
        if new_args:
            args = override_model_args(args, new_args)
        return DocReader(args, word_dict, feature_dict, state_dict, normalize)

    @staticmethod
    def load_checkpoint(filename, normalize=True):
        logger.info('Loading model %s' % filename)
        saved_params = torch.load(
            filename, map_location=lambda storage, loc: storage
        )
        word_dict = saved_params['word_dict']
        feature_dict = saved_params['feature_dict']
        state_dict = saved_params['state_dict']
        epoch = saved_params['epoch']
        optimizer = saved_params['optimizer']
        args = saved_params['args']
        model = DocReader(args, word_dict, feature_dict, state_dict, normalize)
        model.init_optimizer(optimizer)
        return model, epoch
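
    # A minimal load sketch (hypothetical paths; illustrative only):
    #
    #   >>> reader = DocReader.load('/path/to/model.mdl')
    #   >>> reader, start_epoch = DocReader.load_checkpoint('/path/to/model.ckpt')
    #
    # load() builds a fresh DocReader from the saved args, dictionaries, and
    # network weights; load_checkpoint() additionally restores the epoch
    # counter and re-initializes the optimizer so training can resume.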

    # --------------------------------------------------------------------------
    # Runtime
    # --------------------------------------------------------------------------

    def cuda(self):
        self.use_cuda = True
        self.network = self.network.cuda()

    def cpu(self):
        self.use_cuda = False
        self.network = self.network.cpu()

    def parallelize(self):
        """Use data parallel to copy the model across several gpus.

        This will take all gpus visible with CUDA_VISIBLE_DEVICES.
        """
        self.parallel = True
        self.network = torch.nn.DataParallel(self.network)
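
    # End-to-end usage sketch (hypothetical objects and paths; illustrative
    # only, not part of the original file):
    #
    #   >>> reader = DocReader.load('/path/to/model.mdl')
    #   >>> reader.cuda()          # move the network to GPU and set use_cuda
    #   >>> reader.parallelize()   # optional: wrap in DataParallel
    #   >>> preds = reader.predict(ex, top_n=1)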