Fraser-Greenlee
add dreamcoder codebase
e1c1753
"""
Deprecated network.py module. This file only exists to support backwards-compatibility
with old pickle files. See lib/__init__.py for more information.
"""
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter
# UPGRADING TO INPUT -> OUTPUT -> TARGET
# Todo:
# [X] Output attending to input
# [X] Target attending to output
# [ ] check passing hidden state between encoders/decoder (+ pass c?)
# [ ] add v_output
def choose(matrix, idxs):
if isinstance(idxs, Variable):
idxs = idxs.data
assert(matrix.ndimension() == 2)
unrolled_idxs = idxs + \
torch.arange(0, matrix.size(0)).type_as(idxs) * matrix.size(1)
return matrix.view(matrix.nelement())[unrolled_idxs]
class Network(nn.Module):
"""
Todo:
- Beam search
- check if this is right? attend during P->FC rather than during softmax->P?
- allow length 0 inputs/targets
- give n_examples as input to FC
- Initialise new weights randomly, rather than as zeroes
"""
def __init__(
self,
input_vocabulary,
target_vocabulary,
hidden_size=512,
embedding_size=128,
cell_type="LSTM"):
"""
:param list input_vocabulary: list of possible inputs
:param list target_vocabulary: list of possible targets
"""
super(Network, self).__init__()
self.h_input_encoder_size = hidden_size
self.h_output_encoder_size = hidden_size
self.h_decoder_size = hidden_size
self.embedding_size = embedding_size
self.input_vocabulary = input_vocabulary
self.target_vocabulary = target_vocabulary
# Number of tokens in input vocabulary
self.v_input = len(input_vocabulary)
# Number of tokens in target vocabulary
self.v_target = len(target_vocabulary)
self.cell_type = cell_type
if cell_type == 'GRU':
self.input_encoder_cell = nn.GRUCell(
input_size=self.v_input + 1,
hidden_size=self.h_input_encoder_size,
bias=True)
self.input_encoder_init = Parameter(
torch.rand(1, self.h_input_encoder_size))
self.output_encoder_cell = nn.GRUCell(
input_size=self.v_input +
1 +
self.h_input_encoder_size,
hidden_size=self.h_output_encoder_size,
bias=True)
self.decoder_cell = nn.GRUCell(
input_size=self.v_target + 1,
hidden_size=self.h_decoder_size,
bias=True)
if cell_type == 'LSTM':
self.input_encoder_cell = nn.LSTMCell(
input_size=self.v_input + 1,
hidden_size=self.h_input_encoder_size,
bias=True)
self.input_encoder_init = nn.ParameterList([Parameter(torch.rand(
1, self.h_input_encoder_size)), Parameter(torch.rand(1, self.h_input_encoder_size))])
self.output_encoder_cell = nn.LSTMCell(
input_size=self.v_input +
1 +
self.h_input_encoder_size,
hidden_size=self.h_output_encoder_size,
bias=True)
self.output_encoder_init_c = Parameter(
torch.rand(1, self.h_output_encoder_size))
self.decoder_cell = nn.LSTMCell(
input_size=self.v_target + 1,
hidden_size=self.h_decoder_size,
bias=True)
self.decoder_init_c = Parameter(torch.rand(1, self.h_decoder_size))
self.W = nn.Linear(
self.h_output_encoder_size +
self.h_decoder_size,
self.embedding_size)
self.V = nn.Linear(self.embedding_size, self.v_target + 1)
self.input_A = nn.Bilinear(
self.h_input_encoder_size,
self.h_output_encoder_size,
1,
bias=False)
self.output_A = nn.Bilinear(
self.h_output_encoder_size,
self.h_decoder_size,
1,
bias=False)
self.input_EOS = torch.zeros(1, self.v_input + 1)
self.input_EOS[:, -1] = 1
self.input_EOS = Parameter(self.input_EOS)
self.output_EOS = torch.zeros(1, self.v_input + 1)
self.output_EOS[:, -1] = 1
self.output_EOS = Parameter(self.output_EOS)
self.target_EOS = torch.zeros(1, self.v_target + 1)
self.target_EOS[:, -1] = 1
self.target_EOS = Parameter(self.target_EOS)
def __getstate__(self):
if hasattr(self, 'opt'):
return dict([(k, v) for k, v in self.__dict__.items(
) if k is not 'opt'] + [('optstate', self.opt.state_dict())])
# return {**{k:v for k,v in self.__dict__.items() if k is not 'opt'},
# 'optstate': self.opt.state_dict()}
else:
return self.__dict__
def __setstate__(self, state):
self.__dict__.update(state)
# Legacy:
if isinstance(self.input_encoder_init, tuple):
self.input_encoder_init = nn.ParameterList(
list(self.input_encoder_init))
def clear_optimiser(self):
if hasattr(self, 'opt'):
del self.opt
if hasattr(self, 'optstate'):
del self.optstate
def get_optimiser(self):
self.opt = torch.optim.Adam(self.parameters(), lr=0.001)
if hasattr(self, 'optstate'):
self.opt.load_state_dict(self.optstate)
def optimiser_step(self, inputs, outputs, target):
if not hasattr(self, 'opt'):
self.get_optimiser()
score = self.score(inputs, outputs, target, autograd=True).mean()
(-score).backward()
self.opt.step()
self.opt.zero_grad()
return score.data[0]
def set_target_vocabulary(self, target_vocabulary):
if target_vocabulary == self.target_vocabulary:
return
V_weight = []
V_bias = []
decoder_ih = []
for i in range(len(target_vocabulary)):
if target_vocabulary[i] in self.target_vocabulary:
j = self.target_vocabulary.index(target_vocabulary[i])
V_weight.append(self.V.weight.data[j:j + 1])
V_bias.append(self.V.bias.data[j:j + 1])
decoder_ih.append(self.decoder_cell.weight_ih.data[:, j:j + 1])
else:
V_weight.append(torch.zeros(1, self.V.weight.size(1)))
V_bias.append(torch.ones(1) * -10)
decoder_ih.append(
torch.zeros(
self.decoder_cell.weight_ih.data.size(0), 1))
V_weight.append(self.V.weight.data[-1:])
V_bias.append(self.V.bias.data[-1:])
decoder_ih.append(self.decoder_cell.weight_ih.data[:, -1:])
self.target_vocabulary = target_vocabulary
self.v_target = len(target_vocabulary)
self.target_EOS.data = torch.zeros(1, self.v_target + 1)
self.target_EOS.data[:, -1] = 1
self.V.weight.data = torch.cat(V_weight, dim=0)
self.V.bias.data = torch.cat(V_bias, dim=0)
self.V.out_features = self.V.bias.data.size(0)
self.decoder_cell.weight_ih.data = torch.cat(decoder_ih, dim=1)
self.decoder_cell.input_size = self.decoder_cell.weight_ih.data.size(1)
self.clear_optimiser()
def input_encoder_get_init(self, batch_size):
if self.cell_type == "GRU":
return self.input_encoder_init.repeat(batch_size, 1)
if self.cell_type == "LSTM":
return tuple(x.repeat(batch_size, 1)
for x in self.input_encoder_init)
def output_encoder_get_init(self, input_encoder_h):
if self.cell_type == "GRU":
return input_encoder_h
if self.cell_type == "LSTM":
return (
input_encoder_h,
self.output_encoder_init_c.repeat(
input_encoder_h.size(0),
1))
def decoder_get_init(self, output_encoder_h):
if self.cell_type == "GRU":
return output_encoder_h
if self.cell_type == "LSTM":
return (
output_encoder_h,
self.decoder_init_c.repeat(
output_encoder_h.size(0),
1))
def cell_get_h(self, cell_state):
if self.cell_type == "GRU":
return cell_state
if self.cell_type == "LSTM":
return cell_state[0]
def score(self, inputs, outputs, target, autograd=False):
inputs = self.inputsToTensors(inputs)
outputs = self.inputsToTensors(outputs)
target = self.targetToTensor(target)
target, score = self.run(inputs, outputs, target=target, mode="score")
# target = self.tensorToOutput(target)
if autograd:
return score
else:
return score.data
def sample(self, inputs, outputs):
inputs = self.inputsToTensors(inputs)
outputs = self.inputsToTensors(outputs)
target, score = self.run(inputs, outputs, mode="sample")
target = self.tensorToOutput(target)
return target
def sampleAndScore(self, inputs, outputs, nRepeats=None):
inputs = self.inputsToTensors(inputs)
outputs = self.inputsToTensors(outputs)
if nRepeats is None:
target, score = self.run(inputs, outputs, mode="sample")
target = self.tensorToOutput(target)
return target, score.data
else:
target = []
score = []
for i in range(nRepeats):
# print("repeat %d" % i)
t, s = self.run(inputs, outputs, mode="sample")
t = self.tensorToOutput(t)
target.extend(t)
score.extend(list(s.data))
return target, score
def run(self, inputs, outputs, target=None, mode="sample"):
"""
:param mode: "score" returns log p(target|input), "sample" returns target ~ p(-|input)
:param List[LongTensor] inputs: n_examples * (max_length_input * batch_size)
:param List[LongTensor] target: max_length_target * batch_size
"""
assert((mode == "score" and target is not None) or mode == "sample")
n_examples = len(inputs)
max_length_input = [inputs[j].size(0) for j in range(n_examples)]
max_length_output = [outputs[j].size(0) for j in range(n_examples)]
max_length_target = target.size(0) if target is not None else 10
batch_size = inputs[0].size(1)
score = Variable(torch.zeros(batch_size))
inputs_scatter = [Variable(torch.zeros(max_length_input[j], batch_size, self.v_input + 1).scatter_(
2, inputs[j][:, :, None], 1)) for j in range(n_examples)] # n_examples * (max_length_input * batch_size * v_input+1)
outputs_scatter = [Variable(torch.zeros(max_length_output[j], batch_size, self.v_input + 1).scatter_(
2, outputs[j][:, :, None], 1)) for j in range(n_examples)] # n_examples * (max_length_output * batch_size * v_input+1)
if target is not None:
target_scatter = Variable(torch.zeros(max_length_target,
batch_size,
self.v_target + 1).scatter_(2,
target[:,
:,
None],
1)) # max_length_target * batch_size * v_target+1
# -------------- Input Encoder -------------
# n_examples * (max_length_input * batch_size * h_encoder_size)
input_H = []
input_embeddings = [] # h for example at INPUT_EOS
# 0 until (and including) INPUT_EOS, then -inf
input_attention_mask = []
for j in range(n_examples):
active = torch.Tensor(max_length_input[j], batch_size).byte()
active[0, :] = 1
state = self.input_encoder_get_init(batch_size)
hs = []
for i in range(max_length_input[j]):
state = self.input_encoder_cell(
inputs_scatter[j][i, :, :], state)
if i + 1 < max_length_input[j]:
active[i + 1, :] = active[i, :] * \
(inputs[j][i, :] != self.v_input)
h = self.cell_get_h(state)
hs.append(h[None, :, :])
input_H.append(torch.cat(hs, 0))
embedding_idx = active.sum(0).long() - 1
embedding = input_H[j].gather(0, Variable(
embedding_idx[None, :, None].repeat(1, 1, self.h_input_encoder_size)))[0]
input_embeddings.append(embedding)
input_attention_mask.append(Variable(active.float().log()))
# -------------- Output Encoder -------------
def input_attend(j, h_out):
"""
'general' attention from https://arxiv.org/pdf/1508.04025.pdf
:param j: Index of example
:param h_out: batch_size * h_output_encoder_size
"""
scores = self.input_A(
input_H[j].view(
max_length_input[j] * batch_size,
self.h_input_encoder_size),
h_out.view(
batch_size,
self.h_output_encoder_size).repeat(
max_length_input[j],
1)).view(
max_length_input[j],
batch_size) + input_attention_mask[j]
c = (F.softmax(scores[:, :, None], dim=0) * input_H[j]).sum(0)
return c
# n_examples * (max_length_input * batch_size * h_encoder_size)
output_H = []
output_embeddings = [] # h for example at INPUT_EOS
# 0 until (and including) INPUT_EOS, then -inf
output_attention_mask = []
for j in range(n_examples):
active = torch.Tensor(max_length_output[j], batch_size).byte()
active[0, :] = 1
state = self.output_encoder_get_init(input_embeddings[j])
hs = []
h = self.cell_get_h(state)
for i in range(max_length_output[j]):
state = self.output_encoder_cell(torch.cat(
[outputs_scatter[j][i, :, :], input_attend(j, h)], 1), state)
if i + 1 < max_length_output[j]:
active[i + 1, :] = active[i, :] * \
(outputs[j][i, :] != self.v_input)
h = self.cell_get_h(state)
hs.append(h[None, :, :])
output_H.append(torch.cat(hs, 0))
embedding_idx = active.sum(0).long() - 1
embedding = output_H[j].gather(0, Variable(
embedding_idx[None, :, None].repeat(1, 1, self.h_output_encoder_size)))[0]
output_embeddings.append(embedding)
output_attention_mask.append(Variable(active.float().log()))
# ------------------ Decoder -----------------
def output_attend(j, h_dec):
"""
'general' attention from https://arxiv.org/pdf/1508.04025.pdf
:param j: Index of example
:param h_dec: batch_size * h_decoder_size
"""
scores = self.output_A(
output_H[j].view(
max_length_output[j] * batch_size,
self.h_output_encoder_size),
h_dec.view(
batch_size,
self.h_decoder_size).repeat(
max_length_output[j],
1)).view(
max_length_output[j],
batch_size) + output_attention_mask[j]
c = (F.softmax(scores[:, :, None], dim=0) * output_H[j]).sum(0)
return c
# Multi-example pooling: Figure 3, https://arxiv.org/pdf/1703.07469.pdf
target = target if mode == "score" else torch.zeros(
max_length_target, batch_size).long()
decoder_states = [
self.decoder_get_init(
output_embeddings[j]) for j in range(n_examples)] # P
active = torch.ones(batch_size).byte()
for i in range(max_length_target):
FC = []
for j in range(n_examples):
h = self.cell_get_h(decoder_states[j])
p_aug = torch.cat([h, output_attend(j, h)], 1)
FC.append(F.tanh(self.W(p_aug)[None, :, :]))
# batch_size * embedding_size
m = torch.max(torch.cat(FC, 0), 0)[0]
logsoftmax = F.log_softmax(self.V(m), dim=1)
if mode == "sample":
target[i, :] = torch.multinomial(
logsoftmax.data.exp(), 1)[:, 0]
score = score + \
choose(logsoftmax, target[i, :]) * Variable(active.float())
active *= (target[i, :] != self.v_target)
for j in range(n_examples):
if mode == "score":
target_char_scatter = target_scatter[i, :, :]
elif mode == "sample":
target_char_scatter = Variable(torch.zeros(
batch_size, self.v_target + 1).scatter_(1, target[i, :, None], 1))
decoder_states[j] = self.decoder_cell(
target_char_scatter, decoder_states[j])
return target, score
def inputsToTensors(self, inputss):
"""
:param inputss: size = nBatch * nExamples
"""
tensors = []
for j in range(len(inputss[0])):
inputs = [x[j] for x in inputss]
maxlen = max(len(s) for s in inputs)
t = torch.ones(
1 if maxlen == 0 else maxlen + 1,
len(inputs)).long() * self.v_input
for i in range(len(inputs)):
s = inputs[i]
if len(s) > 0:
t[:len(s), i] = torch.LongTensor(
[self.input_vocabulary.index(x) for x in s])
tensors.append(t)
return tensors
def targetToTensor(self, targets):
"""
:param targets:
"""
maxlen = max(len(s) for s in targets)
t = torch.ones(
1 if maxlen == 0 else maxlen + 1,
len(targets)).long() * self.v_target
for i in range(len(targets)):
s = targets[i]
if len(s) > 0:
t[:len(s), i] = torch.LongTensor(
[self.target_vocabulary.index(x) for x in s])
return t
def tensorToOutput(self, tensor):
"""
:param tensor: max_length * batch_size
"""
out = []
for i in range(tensor.size(1)):
l = tensor[:, i].tolist()
if l[0] == self.v_target:
out.append([])
elif self.v_target in l:
final = tensor[:, i].tolist().index(self.v_target)
out.append([self.target_vocabulary[x]
for x in tensor[:final, i]])
else:
out.append([self.target_vocabulary[x] for x in tensor[:, i]])
return out