from functools import wraps
from typing import Callable, Union, Tuple, Any

import torch
from torch import Tensor
from torch import distributed as dist

from .context_managers import RandContext
def cached(func: Callable[..., Tensor]):
    """
    A decorator that turns a PyTorch model call function into a cache-compatible version.
    :param func: A function that calls the PyTorch model and returns representation tensors.
    :return: A function that returns 1) representation leaf tensors for cache construction and 2) a closure for the
        second forward and the cached backward. Call 2) with 1) as argument after calling backward on the loss Tensor.
    """
    @wraps(func)
    def cache_func(*args, **kwargs):
        # Capture the RNG state so the second forward replays the same stochastic
        # operations (e.g. dropout) as the gradient-free first forward.
        rnd_state = RandContext()
        with torch.no_grad():
            reps_no_grad = func(*args, **kwargs)
        if isinstance(reps_no_grad, Tensor):
            reps_no_grad = (reps_no_grad,)
        else:
            assert all(isinstance(v, Tensor) for v in reps_no_grad)
        # Detached leaf tensors: computing the loss on these lets their .grad fields
        # collect the representation gradients needed for the cached backward.
        leaf_reps = tuple(t.detach().requires_grad_() for t in reps_no_grad)

        @wraps(func)
        def forward_backward_func(cache_reps: Union[Tensor, Tuple[Tensor, ...]]):
            # Replay the forward pass with gradients enabled under the saved RNG state.
            with rnd_state:
                reps = func(*args, **kwargs)
            if isinstance(reps, Tensor):
                reps = (reps,)
            if isinstance(cache_reps, Tensor):
                cache_reps = (cache_reps,)
            assert len(reps) == len(cache_reps)

            # Dotting each representation with its cached gradient yields a scalar
            # surrogate whose backward produces the model parameter gradients.
            surrogate = sum(map(lambda u, v: torch.dot(u.flatten(), v.grad.flatten()), reps, cache_reps), 0)
            surrogate.backward()

        return leaf_reps + (forward_backward_func,)
    return cache_func
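

# A minimal usage sketch of `cached` (illustration only, not part of the library).
# `encoder`, `x1`, `x2`, and `contrastive_loss` are hypothetical stand-ins for a
# model, two input chunks, and a loss over the leaf representations.
def _example_cached_step(encoder, x1, x2, contrastive_loss):
    @cached
    def call_model(model, batch):
        return model(batch)

    # First pass: gradient-free forwards that return detached leaf tensors.
    r1, closure1 = call_model(encoder, x1)
    r2, closure2 = call_model(encoder, x2)
    # Backward on the loss fills r1.grad / r2.grad with representation gradients.
    loss = contrastive_loss(r1, r2)
    loss.backward()
    # Second pass: each closure replays its forward and back-propagates the
    # cached gradients into the model parameters.
    closure1(r1)
    closure2(r2)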


def _cat_tensor_list(xx):
    # Concatenate a non-empty list of tensors along dim 0; pass anything else through.
    if isinstance(xx, list) and len(xx) > 0 and all(isinstance(x, Tensor) for x in xx):
        return torch.cat(xx)
    else:
        return xx


def cat_input_tensor(func: Callable[..., Tensor]):
    """
    A decorator that concatenates positional and keyword arguments of type List[Tensor] into single Tensors
    along dimension 0. This comes in handy when dealing with representation tensors produced by multiple
    cached forward passes.
    :param func: A loss function.
    :return: Decorated loss function for cached results.
    """
    @wraps(func)
    def cat_f(*args, **kwargs):
        args_cat = [_cat_tensor_list(x) for x in args]
        # Iterate over items() so keyword names stay paired with their values.
        kwargs_cat = {k: _cat_tensor_list(v) for k, v in kwargs.items()}
        return func(*args_cat, **kwargs_cat)
    return cat_f
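

# A minimal usage sketch of `cat_input_tensor` (illustration only): `loss_fn` is a
# hypothetical loss over full batches, and `x_chunks` / `y_chunks` are lists of
# per-chunk representation tensors collected from cached forwards.
def _example_cat_loss(loss_fn, x_chunks, y_chunks):
    cat_loss = cat_input_tensor(loss_fn)
    # Each list is concatenated along dim 0 before loss_fn sees it.
    return cat_loss(x_chunks, y_chunks)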


def _maybe_gather_tensor(t: Any, axis: int):
    # Non-tensor arguments pass through unchanged.
    if not isinstance(t, Tensor):
        return t
    gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, t)
    # all_gather returns tensors detached from the autograd graph; substituting the
    # local tensor back in keeps this rank's slice differentiable.
    gathered[dist.get_rank()] = t
    return torch.cat(gathered, dim=axis)


def gather_input_tensor(func: Callable[..., Tensor], axis=0):
    """
    A decorator that all-gathers positional and keyword arguments of type Tensor across distributed processes
    and concatenates them along axis. Intended to be used with a distributed contrastive learning loss.
    :param func: A loss function.
    :param axis: The axis along which the gathered tensors are concatenated.
    :return: Decorated loss function for distributed training.
    """
    @wraps(func)
    def f(*args, **kwargs):
        args_gathered = [_maybe_gather_tensor(x, axis=axis) for x in args]
        kwargs_gathered = {k: _maybe_gather_tensor(v, axis=axis) for k, v in kwargs.items()}
        return func(*args_gathered, **kwargs_gathered)
    return f
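

# A minimal usage sketch of `gather_input_tensor` (illustration only): wrapping a
# hypothetical loss so each rank computes it over representations gathered from
# every distributed process. Assumes torch.distributed has been initialized.
def _example_distributed_loss(loss_fn, local_x, local_y):
    dist_loss = gather_input_tensor(loss_fn)
    # local_x / local_y are this rank's representations; the wrapper all-gathers
    # and concatenates them along dim 0 before loss_fn runs.
    return dist_loss(local_x, local_y)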