QiyuWu's picture
Upload 100 files
1fd7780 verified
# Width of each of the 26 slices taken by convert_s_tensor_to_list, in order
# (presumably per-layer style-space channel counts — confirm against the model):
# 15 slots of 512, then three each of 256, 128 and 64, then two of 32.
STYLESPACE_DIMENSIONS = [512] * 15 + [256] * 3 + [128] * 3 + [64] * 3 + [32] * 2
def aggregate_loss_dict(agg_loss_dict):
    """Average per-key values across a sequence of loss mappings.

    Args:
        agg_loss_dict: despite its name, an iterable of mappings (e.g. a list
            of dicts), each mapping a loss name to a value that supports
            ``+`` and division by an int.

    Returns:
        A dict mapping every key seen in any input mapping to the arithmetic
        mean of the values recorded for it. A key absent from some mappings
        is averaged only over the mappings that contain it. Keys appear in
        first-seen order. An empty input yields an empty dict.
    """
    collected = {}
    for output in agg_loss_dict:
        for key in output:
            # Append in place; rebuilding the list per insert was O(n^2).
            collected.setdefault(key, []).append(output[key])
    # A key is only added together with a value, so every list is non-empty
    # and no zero-length guard is needed.
    return {key: sum(values) / len(values) for key, values in collected.items()}
def convert_s_tensor_to_list(batch, *, dims=None, stride=512):
    """Split a packed tensor into per-slot slices along its last axis.

    Slot ``i`` starts at offset ``stride * i`` on the last axis, and only its
    first ``dims[i]`` entries are kept. Any remaining entries of a window
    narrower than ``stride`` are dropped (presumably padding — confirm against
    whatever produces ``batch``).

    Args:
        batch: array-like supporting NumPy/PyTorch-style basic slicing with
            ``...`` (the original indexing assumed exactly 3 dimensions; any
            rank >= 1 now works, with identical results for 3-D input).
        dims: slice widths, one per slot. Defaults to STYLESPACE_DIMENSIONS.
        stride: distance between consecutive slot starts along the last axis.

    Returns:
        A list of ``len(dims)`` slices of ``batch``, in slot order.
    """
    if dims is None:
        dims = STYLESPACE_DIMENSIONS
    # Basic slicing does not check bounds, so a last axis shorter than
    # stride * len(dims) yields truncated or empty slices rather than an error.
    return [batch[..., stride * i: stride * i + width] for i, width in enumerate(dims)]