File size: 1,786 Bytes
d737ecd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import numpy as np
from functools import reduce, partial
from operator import mul
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations


def chain_functions(*functions):
    """Compose ``functions`` into a single one-argument callable.

    The returned callable feeds its argument to the first function, passes
    that result to the second, and so on, returning the final result. With
    no functions given, the returned callable hands its argument back
    unchanged.
    """

    def pipeline(initial):
        result = initial
        # Apply left to right: the first function listed runs first.
        for step in functions:
            result = step(result)
        return result

    return pipeline


def remove_fx_parametrisation(fx):
    """Strip every registered parametrization from ``fx`` and its submodules.

    ``fx.apply`` is used to visit the modules, so ``fx`` is expected to expose
    an ``apply(fn)`` method (presumably a ``torch.nn.Module`` -- confirm
    against callers). The same ``fx`` object is returned after the in-place
    modification.
    """

    def strip_parametrizations(module):
        if is_parametrized(module):
            # Snapshot the names up front so the loop does not iterate over
            # the container while entries are being removed from it.
            for tensor_name in tuple(module.parametrizations.keys()):
                remove_parametrizations(module, tensor_name)

    fx.apply(strip_parametrizations)
    return fx


def get_chunks(keys, original_shapes):
    """Work out the flat-vector layout once the lower triangle of ``U`` is dropped.

    The first entry of ``keys`` containing ``"U.original"`` identifies the
    ``U`` matrix. Its lower-triangular entries (diagonal included) are
    excluded, so its chunk shrinks to the size of the strictly upper triangle.

    Args:
        keys: Iterable of string names, aligned with ``original_shapes``.
        original_shapes: Indexable sequence of shapes, one per key. A
            zero-dimensional shape ``()`` counts as a single element.

    Returns:
        A tuple ``(selected_chunks, position, U_matrix_shape,
        dimensions_not_need)`` where

        * ``selected_chunks`` is the list of per-entry element counts with
          the ``U`` entry reduced to its strictly-upper-triangular size,
        * ``position`` is the index of the ``U`` entry in ``keys``,
        * ``U_matrix_shape`` is ``original_shapes[position]``,
        * ``dimensions_not_need`` is a NumPy array holding the positions of
          the dropped entries within the concatenation of all flattened
          entries (row-major order, offset by the elements before ``U``).

    Raises:
        ValueError: If no key contains ``"U.original"``, or if the ``U``
            shape is not two-dimensional.
    """
    position = next(
        (index for index, key in enumerate(keys) if "U.original" in key), None
    )
    if position is None:
        raise ValueError('no key containing "U.original" was found in `keys`')

    # The initial value of 1 lets a zero-dimensional shape () count as one
    # element instead of making reduce() fail on an empty sequence.
    original_chunks = [reduce(mul, shape, 1) for shape in original_shapes]
    U_matrix_shape = original_shapes[position]

    n_rows, n_cols = U_matrix_shape
    offset = sum(original_chunks[:position])
    # Row-major flat indices of the lower triangle (diagonal included),
    # shifted so they address the full concatenated vector.
    dimensions_not_need = (
        np.ravel_multi_index(np.tril_indices(n=n_rows, m=n_cols), U_matrix_shape)
        + offset
    )

    selected_chunks = (
        original_chunks[:position]
        + [original_chunks[position] - dimensions_not_need.size]
        + original_chunks[position + 1 :]
    )
    return selected_chunks, position, U_matrix_shape, dimensions_not_need


def vec2statedict(
    x: torch.Tensor,
    keys,
    original_shapes,
    selected_chunks,
    position,
    U_matrix_shape,
):
    """Split a flat vector into named tensors, rebuilding the ``U`` matrix.

    ``x`` is split along its first dimension into pieces of the sizes in
    ``selected_chunks``. The piece at ``position`` holds only the strictly
    upper-triangular entries of ``U`` (row-major order); it is scattered into
    a zero-filled matrix so the diagonal and lower triangle are zero. Every
    piece is then reshaped to the matching entry of ``original_shapes``.

    Args:
        x: One-dimensional tensor whose length equals ``sum(selected_chunks)``.
        keys: Names for the resulting dictionary, aligned with
            ``original_shapes``.
        original_shapes: Target shape for each entry; ``()`` is allowed for
            single-element entries.
        selected_chunks: Number of elements of ``x`` belonging to each entry.
        position: Index of the ``U`` entry within ``keys``.
        U_matrix_shape: Two-dimensional shape ``(rows, cols)`` of ``U``.

    Returns:
        A ``dict`` mapping each key to its reshaped tensor.
    """
    chunks = list(torch.split(x, selected_chunks))

    # Row-major flat positions of the strictly upper triangle (k=1 skips the
    # diagonal) inside the flattened U matrix.
    upper_flat_idx = np.ravel_multi_index(
        np.triu_indices(n=U_matrix_shape[0], k=1, m=U_matrix_shape[1]),
        U_matrix_shape,
    )
    U = x.new_zeros(reduce(mul, U_matrix_shape))
    # Convert the index explicitly so it is a long tensor on x's device.
    U[torch.as_tensor(upper_flat_idx, dtype=torch.long, device=x.device)] = chunks[
        position
    ]
    chunks[position] = U

    # Pass the shape as one tuple: unpacking an empty shape would call
    # reshape() with no arguments, which fails for zero-dimensional entries.
    return {
        key: chunk.reshape(tuple(shape))
        for key, chunk, shape in zip(keys, chunks, original_shapes)
    }