Spaces:
Running
Running
import torch | |
import numpy as np | |
from functools import reduce, partial | |
from operator import mul | |
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations | |
def chain_functions(*functions):
    """Compose callables left to right into a single one-argument callable.

    ``chain_functions(f, g, h)(x)`` evaluates ``h(g(f(x)))``: the first
    function receives the initial value and each later function receives the
    previous function's result. With no functions, the returned callable
    hands its argument back unchanged.

    Args:
        *functions: Callables, each accepting one positional argument.

    Returns:
        A callable taking ``initial`` and returning the piped result.
    """

    def chained(initial):
        value = initial
        for func in functions:
            value = func(value)
        return value

    return chained
def remove_fx_parametrisation(fx):
    """Strip every registered parametrization from ``fx`` and its submodules.

    The module tree is walked with ``fx.apply``; for each submodule that
    ``is_parametrized`` reports as parametrized, ``remove_parametrizations``
    is called once per parametrized tensor name. The module is modified in
    place.

    Args:
        fx: Object exposing ``apply(fn)`` (e.g. a ``torch.nn.Module``).

    Returns:
        The same ``fx`` object, to allow call chaining.
    """

    def remover(m):
        if not is_parametrized(m):
            return
        # Snapshot the names first: removing a parametrization mutates
        # ``m.parametrizations`` while we would otherwise be iterating it.
        for name in list(m.parametrizations.keys()):
            remove_parametrizations(m, name)

    fx.apply(remover)
    return fx
def get_chunks(keys, original_shapes):
    """Compute flat chunk sizes when only the strict upper triangle of U is kept.

    The first key containing ``"U.original"`` identifies the U matrix. Its
    lower triangle (diagonal included) is treated as not needed, so the U
    chunk shrinks by that many entries while every other chunk keeps its
    full element count.

    Args:
        keys: Sequence of state-dict key strings.
        original_shapes: Sequence of shapes, aligned with ``keys``. A 0-dim
            shape ``()`` counts as one element.

    Returns:
        A 4-tuple ``(selected_chunks, position, U_matrix_shape,
        dimensions_not_need)``:

        * ``selected_chunks`` - list of per-key element counts, with the U
          entry reduced by the number of lower-triangle entries.
        * ``position`` - index of the U key within ``keys``.
        * ``U_matrix_shape`` - ``original_shapes[position]``, unchanged.
        * ``dimensions_not_need`` - NumPy array of indices, into the
          concatenation of all flattened tensors, of U's lower-triangle
          entries.

    Raises:
        ValueError: If no key contains ``"U.original"`` or the U shape is
            not two-dimensional.
    """
    position = next(
        (index for index, key in enumerate(keys) if "U.original" in key),
        None,
    )
    if position is None:
        raise ValueError('no key containing "U.original" was found in keys')

    # An initial value of 1 makes 0-dim shapes ``()`` count as one element
    # instead of raising on an empty reduction.
    original_chunks = [reduce(mul, shape, 1) for shape in original_shapes]

    U_matrix_shape = original_shapes[position]
    n_rows, n_cols = U_matrix_shape

    # Flat (row-major) positions of the lower triangle inside U, shifted by
    # the number of elements that precede U in the flattened vector.
    offset = sum(original_chunks[:position])
    dimensions_not_need = (
        np.ravel_multi_index(np.tril_indices(n=n_rows, m=n_cols), U_matrix_shape)
        + offset
    )

    selected_chunks = (
        original_chunks[:position]
        + [original_chunks[position] - dimensions_not_need.size]
        + original_chunks[position + 1 :]
    )
    return selected_chunks, position, U_matrix_shape, dimensions_not_need
def vec2statedict(
    x: torch.Tensor,
    keys,
    original_shapes,
    selected_chunks,
    position,
    U_matrix_shape,
):
    """Rebuild a state dict from a flat vector with a reduced U chunk.

    ``x`` is split into consecutive pieces of sizes ``selected_chunks``. The
    piece at ``position`` holds only the strict upper triangle (``k=1``) of
    the U matrix in row-major order; it is scattered into a zero-filled
    buffer of U's full size so the lower triangle and diagonal are zero.
    Every piece is then reshaped to its entry in ``original_shapes``.

    Args:
        x: 1-D tensor whose length equals ``sum(selected_chunks)``.
        keys: State-dict key names, aligned with ``original_shapes``.
        original_shapes: Target shape for each key; ``()`` is allowed.
        selected_chunks: Per-key element counts used to split ``x``.
        position: Index of the U matrix within ``keys``.
        U_matrix_shape: Two-dimensional shape of the full U matrix.

    Returns:
        ``dict`` mapping each key to its reshaped tensor.
    """
    chunks = list(torch.split(x, selected_chunks))

    # Flat row-major positions of the strictly-upper-triangular entries of U.
    upper_flat = np.ravel_multi_index(
        np.triu_indices(n=U_matrix_shape[0], k=1, m=U_matrix_shape[1]),
        U_matrix_shape,
    )

    # ``new_zeros`` keeps the dtype and device of ``x``.
    U = x.new_zeros(reduce(mul, U_matrix_shape))
    U[upper_flat] = chunks[position]
    chunks[position] = U

    # Pass the shape as one tuple so that 0-dim shapes ``()`` reshape to a
    # scalar tensor instead of calling ``reshape()`` with no arguments.
    return {
        key: chunk.reshape(tuple(shape))
        for key, chunk, shape in zip(keys, chunks, original_shapes)
    }