# models/util/generation.py
import time

import torch

# TODO: implement option to include the conditioning bit of input in the output
def autoregressive_generation_multimodal(features, model, autoreg_mods=[], teacher_forcing=False, ground_truth=False):
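    """Generate sequences step by step from a multimodal model, one output frame per step.

    Args:
        features: dict mapping "in_" + modality name to a numpy array for that modality.
        model: network exposing input_mods/output_mods, per-modality input/output lengths,
            time offsets, input types, input_fix_length_types and device.
        autoreg_mods: output modalities whose predictions are fed back as inputs each step.
        teacher_forcing: if True, slide ground-truth frames (rather than predictions) into the
            context windows of the autoregressive modalities.
        ground_truth: if True, copy ground-truth frames to the outputs instead of model predictions.

    Returns:
        A list with one tensor of concatenated generated frames per output modality.
    """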
inputs_ = []
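    # Convert each input modality to a tensor on the model's device:
    # "c" inputs become float tensors, any other type becomes a long (integer) tensor.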
for i,mod in enumerate(model.input_mods):
input_ = features["in_"+mod]
if model.input_types[i] == "c":
input_ = torch.from_numpy(input_).float().to(model.device)
else:
input_ = torch.from_numpy(input_).long().to(model.device)
inputs_.append(input_)
output_time_offsets = model.output_time_offsets
input_time_offsets = model.input_time_offsets
input_lengths = model.input_lengths
output_lengths = model.output_lengths
input_mods = model.input_mods
output_mods = model.output_mods
# predicted_inputs = model.predicted_inputs
for mod in autoreg_mods:
assert mod in output_mods
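    # Seed each modality's sliding context window with its first input_lengths[i] frames,
    # starting at that modality's input time offset.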
input_tmp = []
for i,mod in enumerate(input_mods):
input_tmp.append(inputs_[i].clone()[input_time_offsets[i]:input_time_offsets[i]+input_lengths[i]])
#TODO: append the initial conditioning bit to the output too
model.eval()
output_seq = []
#sequence_length = inputs_[0].shape[0]
#TODO: make this less ad-hoc
sequence_length = inputs_[-1].shape[0]
print(sequence_length)
#import pdb;pdb.set_trace()
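    # Generate frame by frame; no gradients are needed at inference time.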
with torch.no_grad():
# for t in range(min(512, sequence_length-max(input_lengths)-1)):
        start_time = time.time()
for t in range(sequence_length-max(input_lengths)+1):
#for t in range(512):
print(t)
inputs = [x.clone().to(model.device) for x in input_tmp]
# import pdb;pdb.set_trace()
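            # Run the model on the current context windows (skipped when copying ground truth).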
if not ground_truth:
outputs = model.forward(inputs)
#outputs[0][:,0,-4] = 0.0
#outputs[0][:,0,-6] = 0.0
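            # First step: start each output modality's sequence with its new frame.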
if t == 0:
for i, mod in enumerate(output_mods):
# output[:,0,:-3] = torch.clamp(output[:,0,:-3],-3,3)
if not ground_truth:
output = outputs[i]
else:
j = input_mods.index(mod)
output = inputs_[j][t+output_time_offsets[i]+output_lengths[i]:t+output_time_offsets[i]+output_lengths[i]+1]
output_seq.append(output[:1].detach().clone())
#output_seq.append(inputs_[i][t+input_time_offsets[i]+input_lengths[i]:t+input_time_offsets[i]+input_lengths[i]+1]+0.15*torch.randn(1,219).to(model.device))
else:
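                # Later steps: append the new frame to each output modality's sequence.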
for i, mod in enumerate(output_mods):
#output_seq[i] = torch.cat([output_seq[i], inputs_[i][t+input_time_offsets[i]+input_lengths[i]:t+input_time_offsets[i]+input_lengths[i]+1]+0.15*torch.randn(1,219).to(model.device)])
if not ground_truth:
output = outputs[i]
else:
j = input_mods.index(mod)
output = inputs_[j][t+output_time_offsets[i]+output_lengths[i]:t+output_time_offsets[i]+output_lengths[i]+1]
output_seq[i] = torch.cat([output_seq[i], output[:1].detach().clone()])
# output[:,0,:-3] = torch.clamp(output[:,0,:-3],-3,3)
# print(outputs[i][:1])
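            # Slide each modality's context window forward for the next step.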
if t < sequence_length-1:
for i, mod in enumerate(input_mods):
if mod in autoreg_mods:
j = output_mods.index(mod)
if not ground_truth:
output = outputs[j]
else:
output = inputs_[i][t+input_time_offsets[j]+input_lengths[j]:t+input_time_offsets[j]+input_lengths[j]+1]
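                        # Teacher forcing slides the ground-truth frame into the context window;
                        # otherwise the newly produced frame is fed back in.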
if teacher_forcing:
input_tmp[i] = torch.cat([input_tmp[i][1:],inputs_[i][t+input_time_offsets[i]+input_lengths[i]:t+input_time_offsets[i]+input_lengths[i]+1]],0)
else:
# import pdb;pdb.set_trace()
input_tmp[i] = torch.cat([input_tmp[i][1:],output[:1].detach().clone()],0)
# print(torch.mean((inputs_[i][t+input_time_offsets[i]+input_lengths[i]+1:t+input_time_offsets[i]+input_lengths[i]+1+1]-outputs[j][:1].detach().clone())**2))
if not ground_truth:
print(torch.mean((inputs_[i][t+output_time_offsets[j]:t+output_time_offsets[j]+1]-outputs[j][:1].detach().clone())**2))
else:
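                        # Modalities that are not generated autoregressively: "single"-type inputs
                        # keep their window unchanged, the rest slide along the ground-truth input.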
if model.input_fix_length_types[i] == "single":
#input_tmp[i] = torch.cat([input_tmp[i][1:],inputs_[i][input_time_offsets[i]+input_lengths[i]+t:input_time_offsets[i]+input_lengths[i]+t+1]],0)
pass
else:
input_tmp[i] = torch.cat([input_tmp[i][1:],inputs_[i][input_time_offsets[i]+input_lengths[i]+t:input_time_offsets[i]+input_lengths[i]+t+1]],0)
print("--- %s seconds ---" % (time.time() - start_time))
return output_seq
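
# Example usage (a sketch only; the modality names, arrays and model object below are
# hypothetical and not defined in this file):
#
#   features = {"in_audio": audio_np, "in_motion": motion_np}
#   generated = autoregressive_generation_multimodal(features, model, autoreg_mods=["motion"])
#   motion_pred = generated[model.output_mods.index("motion")]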