# Copied from rut5compressed/util.py of rut5compressed repository.

import logging
import re
from functools import wraps
from re import Pattern
from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch as T

from .modules import TTCompressedLinear

# Library code logs through a module logger instead of the root logger so
# that it never configures logging implicitly on behalf of the application.
logger = logging.getLogger(__name__)


def map_module(root: T.nn.Module,
               func: Callable[[T.nn.Module, str], T.nn.Module],
               patt: Optional[str] = None) -> T.nn.Module:
    """Apply ``func`` to every module in the tree whose path matches ``patt``.

    The tree is traversed in post-order: children are visited (and possibly
    replaced) before their parent. Paths are slash-separated child names
    relative to ``root``; the root itself is reported as ``'/'`` and its
    children as ``'/<name>'``, ``'/<name>/<subname>'`` and so on.

    Parameters
    ----------
    root : torch.nn.Module
        Module to modify.
    func : callable
        Function ``func(module, path)`` applied to every module whose path
        matches ``patt``. Its return value replaces the module in its parent.
    patt : str, optional
        Regular expression used to filter modules by path. It is applied with
        ``re.match`` (anchored at the start of the path). Every module
        matches if omitted.

    Returns
    -------
    torch.nn.Module
        The root module, or ``func``'s replacement for it. Submodules are
        replaced in-place.

    Raises
    ------
    ValueError
        If ``func`` returns an object that is not a ``torch.nn.Module``.
    """
    @wraps(func)
    def func_safe(*args, **kwargs):
        node = func(*args, **kwargs)
        if not isinstance(node, T.nn.Module):
            raise ValueError('Mapped result must be torch.nn.Module type '
                             f'but given {type(node)}.')
        return node

    return _map_module(root, func_safe, re.compile(patt or r'.*'), '')


def _map_module(root: T.nn.Module,
                func: Callable[[T.nn.Module, str], T.nn.Module],
                patt: Pattern, path: str) -> T.nn.Module:
    """Recursive worker for ``map_module``; ``path`` is ``''`` for the root."""
    # Children first, so ``func`` sees the already-transformed subtree when
    # it is finally applied to ``root``.
    for name, child in root.named_children():
        node = _map_module(child, func, patt, f'{path}/{name}')
        # Only re-register the child when ``func`` produced a new object.
        if node is not child:
            setattr(root, name, node)
    if patt.match(path or '/'):
        root = func(root, path or '/')
    return root


def convert_linear(module: T.nn.Linear, ctor, **kwargs) -> T.nn.Module:
    """Convert a linear module into one with approximate matmul.

    Non-linear modules are returned intact.

    NOTE(review): the conversion itself is not implemented yet; ``ctor`` and
    ``kwargs`` are currently unused.

    Raises
    ------
    NotImplementedError
        If ``module`` is a ``torch.nn.Linear``.
    """
    if not isinstance(module, T.nn.Linear):
        return module
    raise NotImplementedError


def _count_tensors(module: T.nn.Module,
                   unit: Callable[[T.Tensor], int]) -> int:
    """Sum ``t.numel() * unit(t)`` over the tensors held by ``module``.

    Counting starts from every parameter and buffer in the module tree. It is
    then corrected, per submodule, for pruning masks, and extended with the
    weight/bias of quantized linear layers. Those tensors are fetched through
    the ``weight()``/``bias()`` accessor methods, which the code treats as
    not covered by ``parameters()``/``buffers()``.
    """
    total = sum(x.numel() * unit(x) for x in module.parameters()) + \
        sum(x.numel() * unit(x) for x in module.buffers())

    def account(node: T.nn.Module, path: str) -> T.nn.Module:
        nonlocal total

        # Pruning: for every tensor attribute ``<name>_mask`` that has a
        # sibling attribute ``<name>``, count only the unmasked elements of
        # the weight, plus the mask itself.
        # NOTE(review): only plain instance attributes (``vars(node)``) are
        # inspected; masks registered as buffers live in ``node._buffers``
        # and are not seen here. Confirm this matches the pruning scheme.
        for name, attr in vars(node).items():
            if not name.endswith('_mask') or not isinstance(attr, T.Tensor):
                continue
            weight_name = name[:-5]
            if not hasattr(node, weight_name):
                continue
            weight = getattr(node, weight_name)
            # Convert to a Python int so the running total never becomes a
            # torch tensor.
            kept = int(attr.count_nonzero())
            total -= (weight.numel() - kept) * unit(weight)
            total += attr.numel() * unit(attr)

        # Quantized linear layers expose their tensors through accessor
        # methods rather than as parameters.
        if isinstance(node, T.nn.quantized.Linear):
            weight = node.weight()
            total += weight.numel() * unit(weight)
            if (bias := node.bias()) is not None:
                total += bias.numel() * unit(bias)
        return node

    map_module(module, account)
    return total


def numel(module: T.nn.Module):
    """Return the number of stored tensor elements in ``module``.

    Parameters and buffers are counted, pruned weight elements are excluded,
    and the packed weights of quantized linear layers are included.
    """
    return _count_tensors(module, lambda tensor: 1)


def sizeof(module: T.nn.Module):
    """Return the size in bytes of the tensors stored in ``module``.

    Uses the same accounting as ``numel``, with every element weighted by its
    ``element_size()``.
    """
    return _count_tensors(module, lambda tensor: tensor.element_size())


def flatten_module(module: T.nn.Module,
                   regexp: Optional[str] = None) -> Dict[str, T.nn.Module]:
    """Return a ``{path: module}`` mapping of all modules matching ``regexp``.

    Paths follow the ``map_module`` convention (``'/'`` for the root).
    Entries are inserted in post-order traversal order.
    """
    modules: Dict[str, T.nn.Module] = {}

    def collect(node: T.nn.Module, path: str) -> T.nn.Module:
        modules[path] = node
        return node

    map_module(module, collect, regexp)
    return modules


def print_flatten(module: T.nn.Module):
    """Print a table of index, path and class name for every module.

    Modules are listed in the post-order traversal order of ``map_module``.
    """
    rows = []

    def collect(node: T.nn.Module, path: str) -> T.nn.Module:
        rows.append((path, node.__class__.__name__))
        return node

    map_module(module, collect)

    # Each column must fit both its header label and its widest value;
    # indices run from 0 to len(rows) - 1.
    indx_len = max(len('#'), len(str(max(len(rows) - 1, 0))))
    path_len = max([len('Path')] + [len(path) for path, _ in rows])
    name_len = max([len('Layer')] + [len(name) for _, name in rows])

    fmt = f'{{indx:>{indx_len}s}} {{path:{path_len}s}} {{name:{name_len}s}}'
    print(fmt.format(indx='#', path='Path', name='Layer'))
    print('-' * (indx_len + path_len + name_len + 2))
    for i, (path, name) in enumerate(rows):
        print(fmt.format(indx=str(i), path=path, name=name))


def compress_linear_tt(module: T.nn.Module, path: str,
                       shape: Tuple[Tuple[int, ...], Tuple[int, ...]],
                       rank: int) -> T.nn.Module:
    """Replace a ``torch.nn.Linear`` with a ``TTCompressedLinear``.

    Intended as a ``map_module`` callback. Modules that are not
    ``torch.nn.Linear`` are returned unchanged.

    Parameters
    ----------
    module : torch.nn.Module
        Candidate module.
    path : str
        Path of ``module`` in the module tree; used only for logging.
    shape : tuple of two tuples of int
        Factorizations of the input and output feature sizes. If the
        products match the layer only with input and output swapped, the
        two factorizations are swapped before use.
    rank : int
        Rank forwarded to ``TTCompressedLinear.from_linear``.

    Raises
    ------
    ValueError
        If the products of ``shape`` match the layer's feature sizes in
        neither order.
    """
    if not isinstance(module, T.nn.Linear):
        return module

    # TODO(@not-found): We need proper compression config.
    inp_size = np.prod(shape[0])
    out_size = np.prod(shape[1])
    if inp_size == module.in_features and out_size == module.out_features:
        pass
    elif inp_size == module.out_features and out_size == module.in_features:
        shape = (shape[1], shape[0])
    else:
        raise ValueError(
            'Input and output features does not match to compression shape: '
            f'{shape[0]} vs {module.in_features} and {shape[1]} vs '
            f'{module.out_features}.')

    logger.info('apply tt compression to layer %s', path)
    return TTCompressedLinear.from_linear(module, shape, rank)