Spaces:
Build error
Build error
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
namespace tensorflow { | |
// Builds a multi-line textual summary of `graph_def`.
//
// The first line is "versions = <short debug string of graph_def.versions()>;"
// and it is followed by one "<SummarizeNodeDef(node)>;" line per node, in the
// order the nodes are stored in the graph. Every line ends with ";\n".
string SummarizeGraphDef(const GraphDef& graph_def) {
  string summary;

  // Header line describing the graph's version info.
  const string versions_text = ProtoShortDebugString(graph_def.versions());
  strings::StrAppend(&summary, "versions = ", versions_text, ";\n");

  // One line per node.
  const int num_nodes = graph_def.node_size();
  for (int i = 0; i < num_nodes; ++i) {
    strings::StrAppend(&summary, SummarizeNodeDef(graph_def.node(i)), ";\n");
  }
  return summary;
}
// Checks every node of `graph_def` with ValidateExternalNodeDefSyntax(), in
// storage order. Stops at, and returns, the first error reported for a node;
// returns OK if no node reports an error (including for a graph without
// nodes).
Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
  const int num_nodes = graph_def.node_size();
  for (int i = 0; i < num_nodes; ++i) {
    TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(graph_def.node(i)));
  }
  return Status::OK();
}
// For every node of `graph_def` with index >= `node_offset`, looks up the
// node's op in `op_registry` and passes the resulting OpDef together with the
// node to AddDefaultsToNodeDef().
//
// Args:
//   graph_def:   graph whose nodes are updated in place.
//   op_registry: registry used to resolve each node's op name to an OpDef.
//   node_offset: index of the first node to process; must lie in
//                [0, graph_def->node_size()]. An offset equal to the node
//                count is accepted and processes nothing.
//
// Returns InvalidArgument if `node_offset` is out of range, or the error from
// the first failed op lookup. When a lookup fails, nodes that precede the
// failing one have already been processed.
Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
                                 const OpRegistryInterface& op_registry,
                                 int node_offset) {
  const int num_nodes = graph_def->node_size();
  // A negative offset would make the loop below call mutable_node() with a
  // negative index, so reject it along with offsets past the end.
  if (node_offset < 0 || node_offset > num_nodes) {
    return errors::InvalidArgument(
        "Tried to add default attrs to GraphDef "
        "starting at offset ",
        node_offset, " with total nodes in graph: ", num_nodes);
  }

  for (int i = node_offset; i < num_nodes; ++i) {
    NodeDef* node_def = graph_def->mutable_node(i);
    const OpDef* op_def = nullptr;
    TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(node_def->op(), &op_def));
    AddDefaultsToNodeDef(*op_def, node_def);
  }

  return Status::OK();
}
// Strips from `node_def` every attr that satisfies all of:
//   * its name does not start with "_",
//   * the consumer's OpDef for the node's op has no attr of that name,
//   * the producer's OpDef for the node's op declares the attr with a default
//     value, and the node's value equals that default.
//
// Each removed attr is also recorded as an (op name, attr name) pair in
// `*op_attr_removed` when that pointer is non-null.
//
// Returns the lookup error if the node's op cannot be resolved in either
// registry (the producer registry is consulted first), and InvalidArgument if
// an attr unknown to the consumer is also missing from the producer's OpDef.
// No attr is removed when an error is returned.
static Status RemoveNewDefaultAttrsFromNodeDef(
    NodeDef* node_def, const OpRegistryInterface& consumer_op_registry,
    const OpRegistryInterface& producer_op_registry,
    std::set<std::pair<string, string>>* op_attr_removed) {
  const OpDef* producer_op_def;
  TF_RETURN_IF_ERROR(
      producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
  const OpDef* consumer_op_def;
  TF_RETURN_IF_ERROR(
      consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));

  // Phase 1: decide which attrs to drop. Nothing is erased here, so the
  // iterators of the attr map stay valid while we walk it.
  std::vector<string> droppable_attrs;
  for (const auto& name_and_value : node_def->attr()) {
    const auto& attr_name = name_and_value.first;

    // Names starting with '_' are never candidates for removal.
    if (StringPiece(attr_name).starts_with("_")) continue;
    // Attrs the consumer knows about are kept.
    if (FindAttr(attr_name, *consumer_op_def) != nullptr) continue;

    const OpDef::AttrDef* producer_attr_def =
        FindAttr(attr_name, *producer_op_def);
    if (producer_attr_def == nullptr) {
      return errors::InvalidArgument(
          "Attr '", attr_name, "' missing in producer's OpDef: ",
          SummarizeOpDef(*producer_op_def), " but found in node: ",
          SummarizeNodeDef(*node_def));
    }

    // Only attrs still carrying the producer's default value are dropped.
    if (!producer_attr_def->has_default_value()) continue;
    if (AreAttrValuesEqual(producer_attr_def->default_value(),
                           name_and_value.second)) {
      droppable_attrs.emplace_back(attr_name);
    }
  }

  // Phase 2: erase the selected attrs and report them to the caller.
  const bool record_removals = (op_attr_removed != nullptr);
  for (const string& dropped : droppable_attrs) {
    node_def->mutable_attr()->erase(dropped);
    if (record_removals) {
      op_attr_removed->insert(std::make_pair(node_def->op(), dropped));
    }
  }
  return Status::OK();
}
static bool IsFunction(const GraphDef& graph_def, const string& op_name) { | |
for (const auto& func_def : graph_def.library().function()) { | |
if (op_name == func_def.signature().name()) return true; | |
} | |
return false; | |
} | |
Status RemoveNewDefaultAttrsFromGraphDef( | |
GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry, | |
const OpRegistryInterface& producer_op_registry, | |
std::set<std::pair<string, string>>* op_attr_removed) { | |
// TODO(joshL): Make IsFunction() faster by collecting the names of | |
// all functions as a preprocessing step. | |
for (int n = 0; n < graph_def->node_size(); ++n) { | |
NodeDef* node_def = graph_def->mutable_node(n); | |
if (!IsFunction(*graph_def, node_def->op())) { | |
TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef( | |
node_def, consumer_op_registry, producer_op_registry, | |
op_attr_removed)); | |
} | |
} | |
for (int f = 0; f < graph_def->library().function_size(); ++f) { | |
FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f); | |
for (int n = 0; n < func_def->node_def_size(); ++n) { | |
NodeDef* node_def = func_def->mutable_node_def(n); | |
if (!IsFunction(*graph_def, node_def->op())) { | |
// TODO(josh11b): Better handling of attrs with placeholder values. | |
TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef( | |
node_def, consumer_op_registry, producer_op_registry, | |
op_attr_removed)); | |
} | |
} | |
} | |
return Status::OK(); | |
} | |
void OpsUsedByGraph(const GraphDef& graph_def, | |
std::set<string>* ops_used_in_graph) { | |
// Map function names to definitions. | |
std::unordered_map<string, const FunctionDef*> name_to_function; | |
for (const auto& function : graph_def.library().function()) { | |
name_to_function.insert( | |
std::make_pair(function.signature().name(), &function)); | |
} | |
// Collect the sorted list of op names. Since functions can reference | |
// functions, we need a recursive traversal. | |
std::set<string> used_ops; // Includes both primitive ops and functions | |
std::vector<const FunctionDef*> functions_to_process; // A subset of used_ops | |
// Collect the logic to mark an op in a lambda; it'll be used twice below. | |
const auto mark_op_as_used = [&used_ops, &functions_to_process, | |
&name_to_function](const string& op) { | |
if (used_ops.insert(op).second) { | |
// If it's a function, we'll need to process further | |
const auto it = name_to_function.find(op); | |
if (it != name_to_function.end()) { | |
functions_to_process.push_back(it->second); | |
} | |
} | |
}; | |
for (const auto& node : graph_def.node()) { | |
mark_op_as_used(node.op()); | |
} | |
while (!functions_to_process.empty()) { | |
const FunctionDef* fun = functions_to_process.back(); | |
functions_to_process.pop_back(); | |
for (const auto& node : fun->node_def()) { | |
mark_op_as_used(node.op()); | |
} | |
} | |
// Filter out function names to produce output. | |
// TODO(josh11b): Change the above code to produce this directly. | |
ops_used_in_graph->clear(); | |
for (const string& op_name : used_ops) { | |
if (name_to_function.find(op_name) == name_to_function.end()) { | |
ops_used_in_graph->insert(op_name); | |
} | |
} | |
} | |
// Replaces the contents of `*stripped_op_list` with one OpDef per op name
// reported by OpsUsedByGraph() for `graph_def`, in the sorted order of that
// std::set. Each OpDef is copied from `op_registry` and then passed through
// RemoveDescriptionsFromOpDef().
//
// Returns the error from the first op name that `op_registry` fails to
// resolve; in that case `*stripped_op_list` holds only the ops added before
// the failure.
Status StrippedOpListForGraph(const GraphDef& graph_def,
                              const OpRegistryInterface& op_registry,
                              OpList* stripped_op_list) {
  std::set<string> op_names;
  OpsUsedByGraph(graph_def, &op_names);

  // Start from an empty list, then append in the set's (sorted) order.
  stripped_op_list->clear_op();
  for (auto it = op_names.begin(); it != op_names.end(); ++it) {
    const OpDef* registered_def;
    TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(*it, &registered_def));

    OpDef* stripped = stripped_op_list->add_op();
    stripped->CopyFrom(*registered_def);
    RemoveDescriptionsFromOpDef(stripped);
  }
  return Status::OK();
}
} // namespace tensorflow | |