Spaces:
Running
Running
File size: 1,786 Bytes
d737ecd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import torch
import numpy as np
from functools import reduce, partial
from operator import mul
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
def chain_functions(*functions):
    """Compose ``functions`` into a single callable applied left to right.

    The returned callable takes one argument, feeds it to the first function,
    passes that result to the second, and so on, returning the final result.
    With no functions supplied, the returned callable hands its argument back
    unchanged.
    """

    def composed(initial):
        result = initial
        for fn in functions:
            result = fn(result)
        return result

    return composed
def remove_fx_parametrisation(fx):
    """Remove every parametrization from ``fx`` and all of its submodules.

    ``fx.apply`` visits ``fx`` and each submodule; any module for which
    ``is_parametrized`` is true has all of its parametrizations removed via
    ``remove_parametrizations`` (called with its default arguments).

    Args:
        fx: A module exposing ``apply`` (e.g. a ``torch.nn.Module``).

    Returns:
        The same ``fx`` object, modified in place.
    """

    def _strip(module):
        if is_parametrized(module):
            # Copy the names up front: removing a parametrization changes
            # ``module.parametrizations`` while we would be iterating it.
            for tensor_name in tuple(module.parametrizations.keys()):
                remove_parametrizations(module, tensor_name)

    fx.apply(_strip)
    return fx
def get_chunks(keys, original_shapes):
    """Compute how a flat vector splits when the lower triangle of ``U`` is omitted.

    The first key containing ``"U.original"`` identifies the ``U`` matrix. Every
    entry of ``original_shapes`` is turned into its element count; the count for
    ``U`` is then reduced by the number of entries on or below its main diagonal.

    Args:
        keys: Sequence of string keys, aligned with ``original_shapes``
            (presumably state-dict keys -- confirm against the caller).
        original_shapes: Sequence of shapes, one per key. A 0-dim shape ``()``
            counts as a single element.

    Returns:
        A tuple ``(selected_chunks, position, U_matrix_shape, dimensions_not_need)``:

        * ``selected_chunks`` -- list of element counts per key, with the
          ``U`` entry shrunk to its strictly-upper-triangular size.
        * ``position`` -- index of the ``U`` key within ``keys``.
        * ``U_matrix_shape`` -- ``original_shapes[position]``.
        * ``dimensions_not_need`` -- NumPy array of positions, within the
          concatenation of all fully flattened entries, that belong to the
          lower triangle (diagonal included) of ``U``.

    Raises:
        ValueError: If no key contains ``"U.original"``, or if the shape found
            at that position is not two-dimensional.
    """
    position = next(
        (index for index, key in enumerate(keys) if "U.original" in key), None
    )
    if position is None:
        raise ValueError('no key containing "U.original" was found in keys')

    # The initial value 1 makes a 0-dim shape ``()`` count as one element
    # instead of raising TypeError on an empty reduction.
    original_chunks = [reduce(mul, shape, 1) for shape in original_shapes]

    U_matrix_shape = original_shapes[position]
    n_rows, n_cols = U_matrix_shape

    # Row-major flat indices of the lower triangle (k=0, so the diagonal is
    # included), shifted by the number of elements that precede U.
    dimensions_not_need = np.ravel_multi_index(
        np.tril_indices(n=n_rows, m=n_cols), U_matrix_shape
    ) + sum(original_chunks[:position])

    selected_chunks = (
        original_chunks[:position]
        + [original_chunks[position] - dimensions_not_need.size]
        + original_chunks[position + 1 :]
    )
    return selected_chunks, position, U_matrix_shape, dimensions_not_need
def vec2statedict(
    x: torch.Tensor,
    keys,
    original_shapes,
    selected_chunks,
    position,
    U_matrix_shape,
):
    """Split a flat vector into named tensors, rebuilding ``U`` from its upper triangle.

    ``x`` is split into pieces of sizes ``selected_chunks``. The piece at
    ``position`` is scattered into the strictly-upper-triangular entries
    (``k=1``) of a zero-filled matrix with ``U_matrix_shape`` elements; all
    other entries of that matrix stay zero. Every piece is then reshaped to the
    matching entry of ``original_shapes`` and paired with the matching key.

    Args:
        x: 1-D tensor whose length equals ``sum(selected_chunks)``.
        keys: Keys for the resulting dict, aligned with ``original_shapes``.
        original_shapes: Target shape for each piece; ``()`` yields a 0-dim
            tensor.
        selected_chunks: Split sizes passed to ``torch.split``.
        position: Index of the ``U`` piece within the split.
        U_matrix_shape: Two-element shape ``(rows, cols)`` of the ``U`` matrix.

    Returns:
        A dict mapping each key to its reshaped tensor.
    """
    chunks = list(torch.split(x, selected_chunks))

    # Row-major flat positions of the entries strictly above the diagonal.
    upper_flat = np.ravel_multi_index(
        np.triu_indices(n=U_matrix_shape[0], k=1, m=U_matrix_shape[1]),
        U_matrix_shape,
    )

    # new_zeros keeps the dtype and device of ``x``.
    U = x.new_zeros(reduce(mul, U_matrix_shape))
    U[upper_flat] = chunks[position]
    chunks[position] = U

    # Pass the shape as one tuple: ``reshape(*())`` would be a call with no
    # arguments, which torch rejects, so 0-dim shapes need this form.
    return {
        key: chunk.reshape(tuple(shape))
        for key, chunk, shape in zip(keys, chunks, original_shapes)
    }
|