Spaces:
Sleeping
Sleeping
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
namespace tensorflow { | |
namespace { | |
Status FinalizeOpDef(const OpDefBuilder& b, OpDef* op_def) { | |
OpRegistrationData op_reg_data; | |
const Status s = b.Finalize(&op_reg_data); | |
*op_def = op_reg_data.op_def; | |
return s; | |
} | |
// Producer and consumer have default for an attr -> graph unchanged. | |
TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) { | |
OpList op_list; | |
TF_ASSERT_OK( | |
FinalizeOpDef(OpDefBuilder("NoChangeWithDefault").Attr("a: int = 12"), | |
op_list.add_op())); | |
OpListOpRegistry registry(&op_list); | |
GraphDef graph_def; | |
TF_ASSERT_OK(NodeDefBuilder("ncwd", "NoChangeWithDefault", ®istry) | |
.Finalize(graph_def.add_node())); | |
GraphDef expected_graph_def = graph_def; | |
std::set<std::pair<string, string>> op_attr_removed; | |
TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry, | |
&op_attr_removed)); | |
TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def); | |
EXPECT_TRUE(op_attr_removed.empty()); | |
} | |
// Producer and consumer both have an attr -> graph unchanged. | |
TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) { | |
OpList op_list; | |
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("NoChangeNoDefault").Attr("a: int"), | |
op_list.add_op())); | |
OpListOpRegistry registry(&op_list); | |
GraphDef graph_def; | |
TF_ASSERT_OK(NodeDefBuilder("ncnd", "NoChangeNoDefault", ®istry) | |
.Attr("a", 42) | |
.Finalize(graph_def.add_node())); | |
GraphDef expected_graph_def = graph_def; | |
std::set<std::pair<string, string>> op_attr_removed; | |
TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry, | |
&op_attr_removed)); | |
TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def); | |
EXPECT_TRUE(op_attr_removed.empty()); | |
} | |
// Producer has default for an attr that the consumer does not know | |
// about, and the produced graph has the default value for the attr -> | |
// attr removed from graph (and so able to be consumed). | |
TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) { | |
OpList consumer_op_list; | |
TF_ASSERT_OK( | |
FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op())); | |
OpListOpRegistry consumer_registry(&consumer_op_list); | |
OpList producer_op_list; | |
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"), | |
producer_op_list.add_op())); | |
OpListOpRegistry producer_registry(&producer_op_list); | |
GraphDef produced_graph_def; | |
TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &producer_registry) | |
.Finalize(produced_graph_def.add_node())); | |
std::set<std::pair<string, string>> op_attr_removed; | |
TF_ASSERT_OK( | |
RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, | |
producer_registry, &op_attr_removed)); | |
GraphDef expected_graph_def; | |
TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &consumer_registry) | |
.Finalize(expected_graph_def.add_node())); | |
TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def); | |
std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}}); | |
EXPECT_EQ(expected_removed, op_attr_removed); | |
} | |
// Producer has default for an attr that the consumer does not know | |
// about, graph sets the attr to a value different from the default -> | |
// graph unchanged (but not able to be consumed by consumer). | |
TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) { | |
OpList consumer_op_list; | |
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"), | |
consumer_op_list.add_op())); | |
OpListOpRegistry consumer_registry(&consumer_op_list); | |
OpList producer_op_list; | |
TF_ASSERT_OK( | |
FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"), | |
producer_op_list.add_op())); | |
OpListOpRegistry producer_registry(&producer_op_list); | |
GraphDef produced_graph_def; | |
TF_ASSERT_OK(NodeDefBuilder("changed_from_default", "ChangedFromDefault", | |
&producer_registry) | |
.Attr("a", 9) | |
.Finalize(produced_graph_def.add_node())); | |
GraphDef expected_graph_def = produced_graph_def; | |
std::set<std::pair<string, string>> op_attr_removed; | |
TF_ASSERT_OK( | |
RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, | |
producer_registry, &op_attr_removed)); | |
TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def); | |
EXPECT_TRUE(op_attr_removed.empty()); | |
} | |
// Attrs starting with underscores should not be removed. | |
TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) { | |
OpList consumer_op_list; | |
TF_ASSERT_OK( | |
FinalizeOpDef(OpDefBuilder("Underscore"), consumer_op_list.add_op())); | |
OpListOpRegistry consumer_registry(&consumer_op_list); | |
OpList producer_op_list; | |
TF_ASSERT_OK( | |
FinalizeOpDef(OpDefBuilder("Underscore"), producer_op_list.add_op())); | |
// Add the _underscore attr manually since OpDefBuilder would complain | |
OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr(); | |
attr->set_name("_underscore"); | |
attr->set_type("int"); | |
attr->mutable_default_value()->set_i(17); | |
OpListOpRegistry producer_registry(&producer_op_list); | |
GraphDef produced_graph_def; | |
TF_ASSERT_OK(NodeDefBuilder("node", "Underscore", &producer_registry) | |
.Attr("_underscore", 17) | |
.Finalize(produced_graph_def.add_node())); | |
GraphDef expected_graph_def = produced_graph_def; | |
std::set<std::pair<string, string>> op_attr_removed; | |
TF_ASSERT_OK( | |
RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, | |
producer_registry, &op_attr_removed)); | |
TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def); | |
EXPECT_EQ(op_attr_removed.size(), 0); | |
} | |
TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) { | |
OpList consumer_op_list; | |
TF_ASSERT_OK( | |
FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op())); | |
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"), | |
consumer_op_list.add_op())); | |
OpListOpRegistry consumer_registry(&consumer_op_list); | |
OpList producer_op_list; | |
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"), | |
producer_op_list.add_op())); | |
TF_ASSERT_OK( | |
FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"), | |
producer_op_list.add_op())); | |
OpListOpRegistry producer_registry(&producer_op_list); | |
GraphDef produced_graph_def; | |
*produced_graph_def.mutable_library()->add_function() = | |
FunctionDefHelper::Create( | |
"my_func", {}, {}, {}, | |
{{{"x"}, "UsesDefault", {}, {{"a", 17}}}, | |
{{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}}, | |
{}); | |
OpList function_op_list; | |
*function_op_list.add_op() = | |
produced_graph_def.library().function(0).signature(); | |
OpListOpRegistry function_registry(&function_op_list); | |
TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry) | |
.Finalize(produced_graph_def.add_node())); | |
std::set<std::pair<string, string>> op_attr_removed; | |
TF_ASSERT_OK( | |
RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, | |
producer_registry, &op_attr_removed)); | |
GraphDef expected_graph_def; | |
*expected_graph_def.mutable_library()->add_function() = | |
FunctionDefHelper::Create( | |
"my_func", {}, {}, {}, | |
{{{"x"}, "UsesDefault", {}, {}}, | |
{{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}}, | |
{}); | |
TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry) | |
.Finalize(expected_graph_def.add_node())); | |
TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def); | |
EXPECT_EQ(expected_graph_def.library().DebugString(), | |
produced_graph_def.library().DebugString()); | |
std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}}); | |
EXPECT_EQ(expected_removed, op_attr_removed); | |
} | |
TEST(StrippedOpListForGraphTest, FlatTest) { | |
// Make four ops | |
OpList op_list; | |
for (const string& op : {"A", "B", "C", "D"}) { | |
OpDef* op_def = op_list.add_op(); | |
op_def->set_name(op); | |
op_def->set_summary("summary"); | |
op_def->set_description("description"); | |
op_def->set_is_commutative(op == "B"); | |
} | |
// Make a graph which uses two ops once and twice, respectively. | |
// The result should be independent of the ordering. | |
const string graph_ops[4][3] = { | |
{"C", "B", "B"}, {"B", "C", "B"}, {"B", "B", "C"}, {"C", "C", "B"}}; | |
for (const bool use_function : {false, true}) { | |
for (int order = 0; order < 4; order++) { | |
GraphDef graph_def; | |
if (use_function) { | |
FunctionDef* function_def = graph_def.mutable_library()->add_function(); | |
function_def->mutable_signature()->set_name("F"); | |
for (const string& op : graph_ops[order]) { | |
function_def->add_node_def()->set_op(op); | |
} | |
graph_def.add_node()->set_op("F"); | |
} else { | |
for (const string& op : graph_ops[order]) { | |
string name = strings::StrCat("name", graph_def.node_size()); | |
NodeDef* node = graph_def.add_node(); | |
node->set_name(name); | |
node->set_op(op); | |
} | |
} | |
// Strip the op list | |
OpList stripped_op_list; | |
TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list), | |
&stripped_op_list)); | |
// We should have exactly two ops: B and C. | |
ASSERT_EQ(stripped_op_list.op_size(), 2); | |
for (int i = 0; i < 2; i++) { | |
const OpDef& op = stripped_op_list.op(i); | |
EXPECT_EQ(op.name(), i ? "C" : "B"); | |
EXPECT_EQ(op.summary(), ""); | |
EXPECT_EQ(op.description(), ""); | |
EXPECT_EQ(op.is_commutative(), !i); | |
} | |
// Should get the same result using OpsUsedByGraph(). | |
std::set<string> used_ops; | |
OpsUsedByGraph(graph_def, &used_ops); | |
ASSERT_EQ(std::set<string>({"B", "C"}), used_ops); | |
} | |
} | |
} | |
TEST(StrippedOpListForGraphTest, NestedFunctionTest) { | |
// Make a primitive op A. | |
OpList op_list; | |
op_list.add_op()->set_name("A"); | |
for (const bool recursive : {false, true}) { | |
// Call A from function B, and B from function C. | |
GraphDef graph_def; | |
FunctionDef* b = graph_def.mutable_library()->add_function(); | |
FunctionDef* c = graph_def.mutable_library()->add_function(); | |
b->mutable_signature()->set_name("B"); | |
c->mutable_signature()->set_name("C"); | |
b->add_node_def()->set_op("A"); | |
c->add_node_def()->set_op("B"); | |
if (recursive) { | |
b->add_node_def()->set_op("B"); | |
c->add_node_def()->set_op("C"); | |
} | |
// Use C in the graph. | |
graph_def.add_node()->set_op("C"); | |
// The stripped op list should contain just A. | |
OpList stripped_op_list; | |
TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list), | |
&stripped_op_list)); | |
ASSERT_EQ(stripped_op_list.op_size(), 1); | |
ASSERT_EQ(stripped_op_list.op(0).name(), "A"); | |
// Should get the same result using OpsUsedByGraph(). | |
std::set<string> used_ops; | |
OpsUsedByGraph(graph_def, &used_ops); | |
ASSERT_EQ(std::set<string>({"A"}), used_ops); | |
} | |
} | |
} // namespace | |
} // namespace tensorflow | |