Spaces:
Runtime error
Runtime error
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Arbitrary dependency between two autograd lanes."""
from typing import List, Tuple | |
import torch | |
from torch import Tensor | |
from .phony import get_phony | |
__all__: List[str] = ["fork", "Fork", "join", "Join"] | |
def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
    """Branches out from an autograd lane of the given tensor.

    Args:
        input: Tensor to branch from.

    Returns:
        A ``(tensor, phony)`` pair. When autograd is enabled and ``input``
        requires grad, both values are the outputs of ``Fork.apply(input)``.
        Otherwise ``input`` is returned unchanged together with a phony
        obtained from ``get_phony(input.device, requires_grad=False)``.
    """
    if torch.is_grad_enabled() and input.requires_grad:
        # Go through the autograd function so that the phony is recorded in
        # the same graph as ``input``.
        input, phony = Fork.apply(input)
    else:
        # Nothing to record: hand back a phony that does not require grad.
        phony = get_phony(input.device, requires_grad=False)

    return input, phony
class Fork(torch.autograd.Function):
    """Autograd function backing :func:`fork`.

    ``forward`` emits a detached copy of the input plus a detached phony
    tensor; ``backward`` passes the gradient of the first output straight
    through and ignores the gradient of the phony.
    """

    # ``torch.autograd.Function`` requires ``forward``/``backward`` to be
    # static methods; ``ctx`` is supplied by ``Function.apply``.
    @staticmethod
    def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore[override]
        """Return ``(input.detach(), phony.detach())`` for the input's device."""
        phony = get_phony(input.device, requires_grad=False)
        return input.detach(), phony.detach()

    @staticmethod
    def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor:  # type: ignore[override]
        """Return the gradient for ``input``.

        Args:
            grad_input: Gradient w.r.t. the first forward output.
            grad_grad: Gradient w.r.t. the phony output; unused.
        """
        return grad_input
def join(input: Tensor, phony: Tensor) -> Tensor:
    """Merge two autograd lanes.

    Args:
        input: Tensor whose value is carried through.
        phony: Tensor from the other lane; only its participation in the
            autograd graph matters here.

    Returns:
        ``Join.apply(input, phony)`` when autograd is enabled and at least one
        of the two tensors requires grad; otherwise ``input`` itself,
        unchanged.
    """
    if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
        input = Join.apply(input, phony)

    return input
class Join(torch.autograd.Function):
    """Autograd function backing :func:`join`.

    ``forward`` returns a detached copy of ``input`` while taking ``phony`` as
    a second argument; ``backward`` forwards the incoming gradient to
    ``input`` and reports no gradient for ``phony``.
    """

    # ``torch.autograd.Function`` requires ``forward``/``backward`` to be
    # static methods; ``ctx`` is supplied by ``Function.apply``.
    @staticmethod
    def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor:  # type: ignore[override]
        """Return ``input.detach()``; ``phony`` is not read."""
        return input.detach()

    @staticmethod
    def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]:  # type: ignore[override]
        """Return ``(grad_input, None)``: gradient for ``input``, none for ``phony``."""
        return grad_input, None