#pragma once #include #include namespace at { namespace functorch { // NOTE [functorch TLS in pytorch/pytorch] // // functorch lives out-of-tree. However, it has some TLS that needs to be // propagated. The solution for that is we store a pointer to the TLS // inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to // include whatever functorch needs. // // We need to store a pointer due to the indirection: // inside functorch, we will create a subclass of FunctorchTLSBase called // FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack. // FuncTorchTLSBase doesn't have any metadata because it hasn't been defined // yet. // // Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside // functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*. // We can't directly pass around FunctorchTLSBase (without a pointer) because // FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having // more elements. struct TORCH_API FuncTorchTLSBase { virtual ~FuncTorchTLSBase() = default; virtual std::unique_ptr deepcopy() const = 0; // functorch doesn't always work with autograd.Function. // This is a hook to get into functorch -- functorch will determine // if it should raise an error message virtual int64_t checkSupportsAutogradFunction() const = 0; virtual void checkSupportsInplaceRequiresGrad() const = 0; virtual void checkSupportsRetainGrad() const = 0; }; // returns deepcopy of the functorch tls TORCH_API std::unique_ptr getCopyOfFuncTorchTLS(); // sets the functorch tls. always does a deep copy. TORCH_API void setFuncTorchTLS( const std::shared_ptr& state); // get a mutable reference to the functorch tls TORCH_API std::unique_ptr& functorchTLSAccessor(); } // namespace functorch } // namespace at