// Spaces: Sleeping, Sleeping  (web-page residue from extraction; not part of this header)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
==============================================================================*/
namespace tensorflow { | |
// Forward declarations. The declarations in this header only use these types
// through pointers or references, so their full definitions are not needed
// here.
class CancellationManager;
class GraphDef;
class OpKernel;
class ResourceMgr;
class Rendezvous;
class ScopedStepContainer;
class StepStatsCollector;
class Node;
// FunctionDefHelper::Create is a convenient helper to construct a | |
// FunctionDef proto. | |
// E.g., | |
// FunctionDef my_func = FunctionDefHelper::Create( | |
// "my_func_name", | |
// {"x:T", "y:T" /* one string per argument */}, | |
// {"z:T" /* one string per return value */}, | |
// {"T: {float, double}" /* one string per attribute */}, | |
// { | |
// {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}} | |
// /* one entry per function node */ | |
// }, | |
// /* Mapping between function returns and function node outputs. */ | |
// {{"z", "o:z"}}); | |
// | |
// For the old Function::Node approach, use FunctionDefHelper::Define() | |
// E.g., | |
// FunctionDef my_func = FunctionDefHelper::Define( | |
// "my_func_name", | |
// {"x:T", "y:T" /* one string per argument */}, | |
// {"z:T" /* one string per return value */}, | |
// {"T: {float, double}" /* one string per attribute */}, | |
// { | |
// {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}} | |
// /* one entry per function node */ | |
// }); | |
class FunctionDefHelper { | |
public: | |
// AttrValueWrapper has copy constructors for the type T so that | |
// it's easy to construct a simple AttrValue proto. | |
// | |
// If T is a string type (const char*, string, or StringPiece), and | |
// it starts with "$", we construct a AttrValue of "placeholder". | |
// | |
// E.g., | |
// std::<string, AttrValueWrapper> x = {"T", "$T"} | |
// is a named attr value placeholder. | |
struct AttrValueWrapper { | |
AttrValue proto; | |
AttrValueWrapper() {} | |
template <typename T> | |
AttrValueWrapper(T val) { // NOLINT(runtime/explicit) | |
SetAttrValue(val, &proto); | |
} | |
private: | |
void InitFromString(StringPiece val); | |
}; | |
// Constructs an AttrValue.func given the "name" and "attrs". | |
static AttrValueWrapper FunctionRef( | |
const string& name, | |
gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs); | |
static AttrValueWrapper FunctionRef(const string& name) { | |
return FunctionRef(name, {}); | |
} | |
// Node is used to construct FunctionDef.Node using initialization | |
// lists. E.g., | |
// Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}; // z = x * y | |
struct Node { | |
// When constructing a NodeDef, the first entry in ret is used as | |
// the node name, the remaining values are ignored. | |
std::vector<string> ret; | |
string op; | |
std::vector<string> arg; | |
std::vector<std::pair<string, AttrValueWrapper>> attr; | |
std::vector<string> dep; | |
NodeDef ToNodeDef() const; | |
}; | |
// The Create() function uses the new NodeDef field. `ret_def` | |
// holds a mapping from the function output names from `out_def` to | |
// the node outputs from `node_def`. | |
static FunctionDef Create(const string& function_name, | |
gtl::ArraySlice<string> in_def, | |
gtl::ArraySlice<string> out_def, | |
gtl::ArraySlice<string> attr_def, | |
gtl::ArraySlice<Node> node_def, | |
gtl::ArraySlice<std::pair<string, string>> ret_def); | |
// The two Define() functions use the old FunctionDef::Node field. | |
// TODO(josh11b): Get rid of these and transition to the one above. | |
static FunctionDef Define(const string& function_name, | |
gtl::ArraySlice<string> arg_def, | |
gtl::ArraySlice<string> ret_def, | |
gtl::ArraySlice<string> attr_def, | |
gtl::ArraySlice<Node> node_def); | |
// Defines an anonymous function. I.e., its name is not relevant. | |
static FunctionDef Define(gtl::ArraySlice<string> arg_def, | |
gtl::ArraySlice<string> ret_def, | |
gtl::ArraySlice<string> attr_def, | |
gtl::ArraySlice<Node> node_def); | |
// Helpers to construct a constant scalar. | |
template <typename T> | |
static Node Const(const string& name, const T& val) { | |
Node n = {{name}, "Const"}; | |
const DataType dtype = DataTypeToEnum<T>::value; | |
n.attr.push_back({"dtype", dtype}); | |
Tensor t(dtype, TensorShape({})); | |
t.scalar<T>()() = val; | |
n.attr.push_back({"value", t}); | |
return n; | |
} | |
template <typename T> | |
static Node Const(const string& name, gtl::ArraySlice<T> vals) { | |
Node n = {{name}, "Const"}; | |
const DataType dtype = DataTypeToEnum<T>::value; | |
n.attr.push_back({"dtype", dtype}); | |
int64 num = vals.size(); | |
Tensor t(dtype, TensorShape({num})); | |
for (size_t i = 0; i < vals.size(); ++i) { | |
t.flat<T>()(i) = vals[i]; | |
} | |
n.attr.push_back({"value", t}); | |
return n; | |
} | |
}; | |
template <> | |
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) { | |
InitFromString(val); | |
} | |
template <> | |
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper( | |
const string& val) { | |
InitFromString(val); | |
} | |
template <> | |
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) { | |
InitFromString(val); | |
} | |
// Instantiate a function.
//
// "fdef" encodes a TF function with some attrs in fdef.signature.attr
// containing placeholders.  InstantiateFunction binds these
// placeholders and produces an instantiated function encoded in
// "result.nodes". The value to substitute a placeholder is given by
// "attr_values", which is a map from a placeholder name to an attr
// value.
//
// InstantiateFunction calls "get_function" to find signatures of other
// functions and primitive ops.

// GetFunctionSignature(func name, opdef) returns OK if the func name is found
// and opdef is filled with a pointer to the corresponding signature
// (an OpDef proto). Otherwise, returns an error.
typedef std::function<Status(const string&, const OpDef**)>
    GetFunctionSignature;
struct InstantiationResult { | |
DataTypeVector arg_types; | |
DataTypeVector ret_types; | |
std::vector<NodeDef> nodes; | |
}; | |
// Instantiates "fdef" using the placeholder bindings in "attr_values" and
// writes the outcome into "*result". "get_function" is used to look up the
// signatures of other functions and primitive ops.
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
                           GetFunctionSignature get_function,
                           InstantiationResult* result);
// Returns a debug string for a function definition.
//
// The returned text is multiple-line. It is intended to be
// human-readable rather than being friendly to parsers. It is _NOT_
// intended to be the canonical string representation of "func_def".
// Particularly, it may not include all information presented in
// "func_def" (e.g., comments, description of the function arguments,
// etc.)
string DebugString(const FunctionDef& func_def);
string DebugString(const GraphDef& instantiated_func_def);
string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
// Returns a debug string for a top level graph (the main program and
// its supporting functions defined in its library).
string DebugStringWhole(const GraphDef& gdef);
// Returns true if f1 == f2. Compares all fields, including descriptions. Order
// of NodeDefs doesn't matter.
bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
// Returns a hash of `fdef` that is consistent with the FunctionDefsEqual
// method. In other words, if two fdefs compare equal, their hash values will
// be the same.
uint64 FunctionDefHash(const FunctionDef& fdef);
// Returns a canonicalized string for the instantiation of the
// function of the given "name" and attributes "attrs".
//
// The returned string is guaranteed to be stable within one address
// space. But it may change as the implementation evolves. Therefore,
// it should not be persisted or compared across address spaces.
string Canonicalize(const string& funcname, AttrSlice attrs);
class CallFrameInterface { | |
public: | |
virtual ~CallFrameInterface() {} | |
virtual size_t num_args() const = 0; | |
virtual size_t num_retvals() const = 0; | |
virtual Status GetArg(int index, Tensor* val) const = 0; | |
virtual Status SetRetval(int index, const Tensor& val) = 0; | |
}; | |
// Represents a function call frame. I.e., the data structure used to | |
// pass arguments to a function and retrieve its results. | |
// | |
// Runtime must arrange accesses to one FunctionCallFrame s.t. | |
// 1. SetArgs() happens before any GetArg(); | |
// 2. GetRetvals happens after all SetRetval(); | |
class FunctionCallFrame : public CallFrameInterface { | |
public: | |
FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types); | |
~FunctionCallFrame(); | |
// Caller methods. | |
Status SetArgs(gtl::ArraySlice<Tensor> args); | |
Status GetRetvals(std::vector<Tensor>* rets) const; | |
Status ConsumeRetvals(std::vector<Tensor>* rets); | |
size_t num_args() const override { return arg_types_.size(); } | |
size_t num_retvals() const override { return ret_types_.size(); } | |
// Callee methods. | |
Status GetArg(int index, Tensor* val) const override; | |
Status SetRetval(int index, const Tensor& val) override; | |
private: | |
DataTypeVector arg_types_; | |
DataTypeVector ret_types_; | |
gtl::InlinedVector<Tensor, 4> args_; | |
struct Retval { | |
bool has_val = false; | |
Tensor val; | |
}; | |
gtl::InlinedVector<Retval, 4> rets_; | |
TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame); | |
}; | |
// Helper to maintain a map between function names in a given | |
// FunctionDefLibrary and function definitions. | |
class FunctionLibraryDefinition : public OpRegistryInterface { | |
public: | |
explicit FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def); | |
FunctionLibraryDefinition(const OpRegistryInterface* default_registry, | |
const FunctionDefLibrary& lib_def); | |
~FunctionLibraryDefinition() override; | |
FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) = | |
delete; | |
// Returns nullptr if "func" is not defined in "lib_def". Otherwise, | |
// returns its definition proto. | |
const FunctionDef* Find(const string& func) const; | |
// Adds function definition 'fdef' to this function library. | |
// Returns status 'ok' on success, or error otherwise. This is a no-op if | |
// 'fdef' already exists in this function library. | |
// If 'fdef' is successfully added to the library, it will be accessible | |
// from 'LookUp' and included in the proto returned by 'ToProto'. | |
// This operation is atomic. | |
Status AddFunctionDef(const FunctionDef& fdef); | |
// Adds gradient definition 'grad' to this function library. | |
// This is a no-op if 'grad' already exists in this function library. | |
// If 'grad' is successfully added, it will be accessible via 'FindGradient' | |
// and included in the proto returned by 'ToProto'. | |
// This operation is atomic. | |
Status AddGradientDef(const GradientDef& grad); | |
// Adds the functions and gradients in 'other' to this function library. | |
// Duplicate functions and gradients are ignored. | |
// This operation is atomic. | |
Status AddLibrary(const FunctionLibraryDefinition& other); | |
// Adds the functions and gradients in 'lib_def' to this function library. | |
// Duplicate functions and gradients are ignored. | |
// This operation is atomic. | |
Status AddLibrary(const FunctionDefLibrary& lib_def); | |
// If the gradient function for 'func' is specified explicitly in | |
// the library, returns the gradient function name. Otherwise, | |
// returns an empty string. | |
string FindGradient(const string& func) const; | |
// OpRegistryInterface method. Useful for constructing a Graph. | |
// | |
// If "op" is defined in the library, returns its signature. | |
// Otherwise, assume "op" is a primitive op and returns its op | |
// signature and shape inference function. | |
Status LookUp(const string& op_type_name, | |
const OpRegistrationData** op_reg_data) const override; | |
static constexpr const char* const kGradientOp = "SymbolicGradient"; | |
static constexpr const char* const kFuncAttr = "f"; | |
// Given a node def 'ndef', inspects attributes of the callee | |
// function to derive the attribute 'value' for 'attr'. Returns OK | |
// iff the attribute is given by the function's definition. | |
// TODO(irving): Remove; keep only the const Node& version. | |
template <typename T> | |
Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const; | |
// Given a node, inspects attributes of the callee function to derive the | |
// attribute 'value' for 'attr'. Returns OK iff the attribute is given by the | |
// function's definition. | |
template <typename T> | |
Status GetAttr(const Node& node, const string& attr, T* value) const; | |
// Returns a proto representation of the state of this function library. | |
FunctionDefLibrary ToProto() const; | |
const OpRegistryInterface* default_registry() const { | |
return default_registry_; | |
} | |
private: | |
// Shape inference for functions is handled separately by ShapeRefiner. | |
struct FunctionDefAndOpRegistration { | |
FunctionDefAndOpRegistration(const FunctionDef& fdef_in); | |
FunctionDef fdef; | |
OpRegistrationData op_registration_data; | |
}; | |
// Same as AddFunctionDef/AddGradientDef except these methods set | |
// `added` to true if the `fdef`/`grad` were actually added to this. | |
Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added); | |
Status AddGradientDefHelper(const GradientDef& grad, bool* added); | |
const OpRegistryInterface* const default_registry_; | |
gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>> | |
function_defs_; | |
gtl::FlatMap<string, string> func_grad_; | |
// Helper function for GetAttr. Returns the FunctionDef* to get the | |
// attr from. | |
const FunctionDef* GetAttrImpl(const NodeDef& ndef) const; | |
// Remove function `func` from the library. `func` must be in the library. | |
void RemoveFunction(const string& func); | |
// Remove gradient of function `func` from the library. `func` must have | |
// a gradient. | |
void RemoveGradient(const string& func); | |
// Remove all functions in `funcs` and all gradients of | |
// functions in `funcs_with_grads` from this library. | |
void Remove(const std::vector<string>& funcs, | |
const std::vector<string>& funcs_with_grads); | |
}; | |
// Forward declare. Defined in common_runtime/function.h
struct FunctionBody;

// Forward declare. Defined in common_runtime/device.h
class Device;
class FunctionLibraryRuntime { | |
public: | |
virtual ~FunctionLibraryRuntime() {} | |
// Instantiate a function with the given "attrs". | |
// | |
// Returns OK and fills in "handle" if the instantiation succeeds. | |
// Otherwise returns an error and "handle" is undefined. | |
typedef uint64 Handle; | |
virtual Status Instantiate(const string& function_name, AttrSlice attrs, | |
Handle* handle) = 0; | |
// Releases state associated with the handle. | |
virtual Status ReleaseHandle(Handle handle) = 0; | |
// Returns the function body for the instantiated function given its | |
// handle 'h'. Returns nullptr if "h" is not found. | |
// | |
// *this keeps the ownership of the returned object, which remains alive | |
// as long as *this. | |
virtual const FunctionBody* GetFunctionBody(Handle h) = 0; | |
// Asynchronously invokes the instantiated function identified by | |
// "handle". | |
// | |
// If function execution succeeds, "done" is called with OK and | |
// "*rets" is filled with the function's return values. Otheriwse, | |
// "done" is called with an error status. | |
// | |
// Does not take ownership of "rets". | |
// In the cross-process scenario, runner isn't used for making the Async | |
// RPC calls. | |
struct Options { | |
// The id of the step that is calling this function. | |
int64 step_id = 0; | |
Rendezvous* rendezvous = nullptr; | |
CancellationManager* cancellation_manager = nullptr; | |
ScopedStepContainer* step_container = nullptr; | |
StepStatsCollector* stats_collector = nullptr; | |
std::function<void(std::function<void()>)>* runner = nullptr; | |
// Parameters for remote function execution. | |
bool remote_execution = false; | |
string source_device = ""; // Fully specified device name. | |
// Allocator attributes specifying where the args are / rets should be put. | |
// These should either be {} or match the length of args / retvals. If {}, | |
// the default allocator attributes will be assumed for all args / retvals. | |
std::vector<AllocatorAttributes> args_alloc_attrs; | |
std::vector<AllocatorAttributes> rets_alloc_attrs; | |
// If true, we create a new IntraProcessRendezvous, else use the existing | |
// one. | |
bool create_rendezvous = false; | |
}; | |
typedef std::function<void(const Status&)> DoneCallback; | |
virtual void Run(const Options& opts, Handle handle, | |
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, | |
DoneCallback done) = 0; | |
virtual void Run(const Options& opts, Handle handle, | |
CallFrameInterface* call_frame, DoneCallback done) = 0; | |
// Creates a "kernel" for the given node def "ndef". | |
// | |
// If succeeds, returns OK and the caller takes the ownership of the | |
// returned "*kernel". Otherwise, returns an error. | |
virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0; | |
// Returns true iff 'function' is stateful. | |
virtual bool IsStateful(const string& function_name) = 0; | |
// Returns the device on which the function executes. | |
virtual Device* device() = 0; | |
// Returns the function library definition that backs this runtime. | |
virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition() | |
const = 0; | |
// Returns the environment on which the function executes. | |
virtual Env* env() = 0; | |
// Returns a debug string showing the definition of the function of | |
// 'handle'. | |
virtual string DebugString(Handle handle) = 0; | |
// Returns the graph version number. | |
virtual int graph_def_version() = 0; | |
typedef uint64 LocalHandle; | |
}; | |
// Sentinel handle values. Handle and LocalHandle are uint64 typedefs, so the
// -1 initializer converts to the largest value of that type (assuming uint64
// is an unsigned 64-bit integer, as its name suggests).
const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&, | |
std::unique_ptr<OpKernel>*)> | |
CustomKernelCreator; | |
// Used to instantiate and run functions in a distributed system. | |
class DistributedFunctionLibraryRuntime { | |
public: | |
virtual ~DistributedFunctionLibraryRuntime() {} | |
// The _target attr in attrs determines where the function is instantiated. | |
virtual Status Instantiate(const string& function_name, | |
const FunctionLibraryDefinition& lib_def, | |
AttrSlice attrs, | |
FunctionLibraryRuntime::LocalHandle* handle) = 0; | |
// opts.runner isn't used for execution. | |
virtual void Run(const FunctionLibraryRuntime::Options& opts, | |
FunctionLibraryRuntime::LocalHandle handle, | |
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, | |
FunctionLibraryRuntime::DoneCallback done) = 0; | |
}; | |
// Extracts the actual type from "attr_values" based on its definition
// "arg_def".
//
// If "arg_def" is a N*T type, *is_type_list is set to false, and
// *dtypes is set to be a vector of size N and each element is T.
//
// If "arg_def" is a list(type), *is_type_list is set to true, and
// *dtypes is set to be a vector of types specified in attrs for
// arg_def.
//
// Otherwise (arg_def is a simple type T), *is_type_list is set to
// false, and *dtypes is set to a single element vector, whose only
// element is T.
Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
                  bool* is_type_list, DataTypeVector* dtypes);
// To register a gradient function for a builtin op, one should use | |
// REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>); | |
// | |
// Typically, the c++ grad factory is a plan function that can be | |
// converted into ::tensorflow::gradient::Creator, which is | |
// std::function<Status(const AttrSlice&, FunctionDef*)>. | |
// | |
// A ::tensorflow::gradient::Creator should populate in FunctionDef* with a | |
// definition of a brain function which compute the gradient for the | |
// <op_name> when the <op_name> is instantiated with the given attrs. | |
// | |
// E.g., | |
// | |
// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) { | |
// bool transpose_a; | |
// TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a)); | |
// bool transpose_b; | |
// TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b)); | |
// DataType dtype; | |
// TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype)); | |
// if (!transpose_a && !transpose_b) { | |
// *g = FunctionDefHelper::Define( | |
// "MatMulGrad", | |
// {"x:T ", "y:T", "dz:T"}, // Inputs to this function | |
// {"dx:T", "dy:T"}, // Outputs from this function | |
// {"T: {float, double}"}, // Attributes needed by this function | |
// { | |
// {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}}, | |
// {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}}, | |
// {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}}, | |
// {{"dy"}, "MatMul", {"x_", "dz"}, {{"T", "$T"}}}, | |
// }); | |
// } else { | |
// ... ... | |
// } | |
// return Status::OK(); | |
// } | |
// | |
// NOTE: $T is substituted with the type variable "T" when the | |
// gradient function MatMul is instantiated. | |
// | |
// TODO(zhifengc): Better documentation somewhere. | |
// Macros to define a gradient function factory for a primitive | |
// operation. | |
namespace gradient { | |
// Register a gradient creator for the "op". | |
typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator; | |
bool RegisterOp(const string& op, Creator func); | |
// Returns OK the gradient creator for the "op" is found (may be | |
// nullptr if REGISTER_OP_NO_GRADIENT is used. | |
Status GetOpGradientCreator(const string& op, Creator* creator); | |
}; | |
// Declare explicit instantiations of GetAttr.
// NOTE(review): the GET_ATTR macro definition is not visible in this
// excerpt — confirm it is defined before these uses.
GET_ATTR(string)
GET_ATTR(bool)
} // end namespace tensorflow | |