Spaces:
Runtime error
Runtime error
import torch | |
# TODO: add an option to include the conditioning part of the input in the output.
def autoregressive_generation_multimodal(features, model, autoreg_mods=(), teacher_forcing=False, ground_truth=False, verbose=True):
    """Generate output sequences by stepping ``model`` one frame at a time.

    Each input modality keeps a sliding window taken from its recorded
    sequence.  At every step the model is run on the current windows, the
    first frame (``[:1]``) of each output is stored, and every window is
    advanced by one frame along axis 0.

    Args:
        features: Mapping that holds, for every name in ``model.input_mods``,
            a NumPy array under the key ``"in_" + name``.  Arrays are sliced
            along their first axis, which is the axis being stepped over.
        model: Object exposing ``input_mods``, ``output_mods``,
            ``input_types``, ``input_time_offsets``, ``output_time_offsets``,
            ``input_lengths``, ``output_lengths``, ``input_fix_length_types``,
            ``device``, ``eval()`` and ``forward(list_of_tensors)``.  The
            value returned by ``forward`` is indexed by output-modality
            position.
        autoreg_mods: Names of modalities (each must be in
            ``model.output_mods``) whose next input frame is taken from the
            model's own prediction.
        teacher_forcing: If true, autoregressive modalities are fed the
            recorded next frame instead of the prediction.
        ground_truth: If true, the model is never run; the returned frames
            are read from the recorded inputs instead.
        verbose: If true (default, matching the historical behaviour), print
            the sequence length, the step counter, a per-step squared error
            for autoregressive modalities, and the elapsed time.

    Returns:
        A list with one tensor per entry of ``model.output_mods``: the
        per-step frames concatenated along dim 0.  An empty list if no step
        is executed.

    Raises:
        ValueError: If ``autoreg_mods`` names a modality that is not in
            ``model.output_mods``.
    """
    # Function-scope import: the module's visible import block only has torch.
    import time

    input_mods = model.input_mods
    output_mods = model.output_mods
    input_time_offsets = model.input_time_offsets
    output_time_offsets = model.output_time_offsets
    input_lengths = model.input_lengths
    output_lengths = model.output_lengths

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    unknown = [mod for mod in autoreg_mods if mod not in output_mods]
    if unknown:
        raise ValueError(
            "autoreg_mods must be a subset of model.output_mods; unknown: %r" % (unknown,)
        )

    # Input type "c" is converted to float tensors, anything else to long.
    inputs_ = []
    for i, mod in enumerate(input_mods):
        tensor = torch.from_numpy(features["in_" + mod])
        tensor = tensor.float() if model.input_types[i] == "c" else tensor.long()
        inputs_.append(tensor.to(model.device))

    # Initial window of every input modality.
    windows = [
        seq[input_time_offsets[i]:input_time_offsets[i] + input_lengths[i]].clone()
        for i, seq in enumerate(inputs_)
    ]

    # TODO: append the initial conditioning part to the output too.
    model.eval()

    # TODO: make this less ad hoc -- the length is taken from the LAST input modality.
    sequence_length = inputs_[-1].shape[0]
    num_steps = sequence_length - max(input_lengths) + 1
    if verbose:
        print(sequence_length)

    # Frames are collected per output modality and concatenated once at the
    # end, rather than re-concatenating the growing sequence on every step.
    collected = [[] for _ in output_mods]

    start_time = time.time()
    with torch.no_grad():
        for t in range(num_steps):
            if verbose:
                print(t)

            outputs = None
            if not ground_truth:
                # Clone so the model cannot modify the rolling windows in place.
                outputs = model.forward([w.clone().to(model.device) for w in windows])

            for i, mod in enumerate(output_mods):
                if ground_truth:
                    j = input_mods.index(mod)
                    start = t + output_time_offsets[i] + output_lengths[i]
                    # May be an empty slice near the end of the recording.
                    output = inputs_[j][start:start + 1]
                else:
                    output = outputs[i]
                collected[i].append(output[:1].detach().clone())

            if t < sequence_length - 1:
                for i, mod in enumerate(input_mods):
                    # Index of the recorded frame that follows the current window.
                    next_idx = t + input_time_offsets[i] + input_lengths[i]
                    recorded_next = inputs_[i][next_idx:next_idx + 1]

                    if mod in autoreg_mods:
                        j = output_mods.index(mod)
                        if teacher_forcing or ground_truth:
                            # Fix: the ground-truth path used to index the
                            # input offsets/lengths with the OUTPUT index j;
                            # the input index i is the correct one.
                            new_frame = recorded_next
                        else:
                            new_frame = outputs[j][:1].detach().clone()
                        windows[i] = torch.cat([windows[i][1:], new_frame], 0)

                        if verbose and not ground_truth:
                            k = t + output_time_offsets[j]
                            print(torch.mean((inputs_[i][k:k + 1] - outputs[j][:1].detach().clone()) ** 2))
                    elif model.input_fix_length_types[i] != "single":
                        # Inputs of fix-length type "single" keep their
                        # initial window; all others follow the recording.
                        windows[i] = torch.cat([windows[i][1:], recorded_next], 0)

    if verbose:
        print("--- %s seconds ---" % (time.time() - start_time))

    if num_steps <= 0:
        # No step ran: same result as before (nothing was ever appended).
        return []
    return [torch.cat(frames) for frames in collected]