/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
namespace tensorflow { | |
namespace test { | |
namespace function { | |
// Short alias for FunctionDefHelper, used by the definitions in this file.
using FDH = FunctionDefHelper;
// Assembles a GraphDef from `nodes` and `funcs`.
//
// The graph's version info is stamped with TF_GRAPH_DEF_VERSION (producer)
// and TF_GRAPH_DEF_VERSION_MIN_CONSUMER (min consumer). Every node is copied
// into the graph in the order given, and every function is copied into the
// graph's function library.
GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
              gtl::ArraySlice<FunctionDef> funcs) {
  GraphDef graph;

  VersionDef* version_info = graph.mutable_versions();
  version_info->set_producer(TF_GRAPH_DEF_VERSION);
  version_info->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);

  for (const auto& node : nodes) {
    *graph.add_node() = node;
  }

  // mutable_library() is called even when `funcs` is empty, as before.
  auto library = graph.mutable_library();
  for (const auto& func : funcs) {
    *library->add_function() = func;
  }

  return graph;
}
// Helper to construct a NodeDef.
//
//   name:   value stored as the node's name.
//   op:     value stored as the node's op.
//   inputs: appended to the node's input list in the order given.
//   attrs:  (attr name, value) pairs; each value's `proto` is inserted into
//           the node's attr map under its name.
//   device: value stored as the node's device.
NodeDef NDef(const string& name, const string& op,
             gtl::ArraySlice<string> inputs,
             gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs,
             const string& device) {
  NodeDef n;
  n.set_name(name);
  n.set_op(op);
  for (const auto& in : inputs) n.add_input(in);
  n.set_device(device);
  // Bind by const reference: iterating by value would copy every
  // (name, AttrValueWrapper) pair before inserting it.
  for (const auto& na : attrs) {
    n.mutable_attr()->insert({na.first, na.second.proto});
  }
  return n;
}
// Function definition "NonZero": a single Identity node forwards the input
// `x` to the output `y`. T is one of float, double, int32, int64, string.
FunctionDef NonZero() {
  return FDH::Define(
      /* function name */
      "NonZero",
      /* input arguments */
      {"x:T"},
      /* outputs */
      {"y:T"},
      /* type attributes */
      {"T:{float, double, int32, int64, string}"},
      /* body */
      {
          {{"y"}, "Identity", {"x"}, {{"T", "$T"}}},
      });
}
// Function definition "XTimesTwo": y = x * scale, where `scale` is an int64
// scalar constant 2 cast to T. T is one of float, double, int32, int64.
FunctionDef XTimesTwo() {
  // Scalar int64 tensor holding 2; used as the value of the Const node.
  const Tensor two_tensor = test::AsScalar<int64>(2);
  return FDH::Define(
      /* function name */
      "XTimesTwo",
      /* input arguments */
      {"x: T"},
      /* outputs */
      {"y: T"},
      /* type attributes */
      {"T: {float, double, int32, int64}"},
      /* body */
      {
          // two = Const(2) as int64.
          {{"two"}, "Const", {}, {{"value", two_tensor}, {"dtype", DT_INT64}}},
          // scale = Cast(two) from int64 to T.
          {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
          // y = Mul(x, scale).
          {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
      });
}
// Function definition "XTimesTwoInt32": the same node structure as a
// multiply-by-two, but with fixed int32 input/output types and no attrs.
FunctionDef XTimesTwoInt32() {
  // Scalar int64 tensor holding 2; used as the value of the Const node.
  const Tensor two_tensor = test::AsScalar<int64>(2);
  return FDH::Define(
      /* function name */
      "XTimesTwoInt32",
      /* input arguments */
      {"x: int32"},
      /* outputs */
      {"y: int32"},
      /* type attributes (none) */
      {},
      /* body */
      {
          // two = Const(2) as int64.
          {{"two"}, "Const", {}, {{"value", two_tensor}, {"dtype", DT_INT64}}},
          // scale = Cast(two) from int64 to int32.
          {{"scale"},
           "Cast",
           {"two"},
           {{"SrcT", DT_INT64}, {"DstT", DT_INT32}}},
          // y = Mul(x, scale).
          {{"y"}, "Mul", {"x", "scale"}, {{"T", DT_INT32}}},
      });
}
// Function definition "XTimesFour": chains two "XTimesTwo" nodes, feeding
// output `y` of the first (x2) into the second (y). Built with FDH::Create,
// which takes an explicit mapping from return name to node output.
FunctionDef XTimesFour() {
  return FDH::Create(
      /* function name */
      "XTimesFour",
      /* input arguments */
      {"x: T"},
      /* outputs */
      {"y: T"},
      /* type attributes */
      {"T: {float, double, int32, int64}"},
      /* body */
      {
          {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
          {{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}},
      },
      /* return mapping */
      {{"y", "y:y:0"}});
}
// Function definition "XTimes16": chains two "XTimesFour" nodes, feeding
// output `y` of the first (x4) into the second (y). Built with FDH::Create,
// which takes an explicit mapping from return name to node output.
FunctionDef XTimes16() {
  return FDH::Create(
      /* function name */
      "XTimes16",
      /* input arguments */
      {"x: T"},
      /* outputs */
      {"y: T"},
      /* type attributes */
      {"T: {float, double, int32, int64}"},
      /* body */
      {
          {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
          {{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}},
      },
      /* return mapping */
      {{"y", "y:y:0"}});
}
// Function definition "WXPlusB": mm = MatMul(w, x) with both transpose flags
// false, then y = Add(mm, b). T is one of float, double.
//
// NOTE(review): this copy of the file contained two alternative closings of
// the MatMul attr list back to back (one ending the list right after
// "transpose_b", one adding {"_kernel", "eigen"} first), which left the
// braces unbalanced. That looks like a preprocessor conditional whose
// directive lines were lost. The guard below is reconstructed on the
// assumption that the macro is INTEL_MKL -- confirm against the build
// configuration before relying on it.
FunctionDef WXPlusB() {
  return FDH::Define(
      // Name
      "WXPlusB",
      // Args
      {"w: T", "x: T", "b: T"},
      // Return values
      {"y: T"},
      // Attr def
      {"T: {float, double}"},
      // Nodes
      {
          {{"mm"},
           "MatMul",
           {"w", "x"},
           {
               {"T", "$T"},
               {"transpose_a", false},
               {"transpose_b", false},
#ifndef INTEL_MKL
               // Only added when the guard macro is not defined.
               {"_kernel", "eigen"},
#endif
           }},
          {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}
      });
}
// Function definition "Swap": two Identity nodes cross the inputs over, so
// o0 takes i1 and o1 takes i0. T is one of float, double.
FunctionDef Swap() {
  return FDH::Define(
      /* function name */
      "Swap",
      /* input arguments */
      {"i0: T", "i1: T"},
      /* outputs */
      {"o0: T", "o1: T"},
      /* type attributes */
      {"T: {float, double}"},
      /* body */
      {
          {{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}},
          {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}
      });
}
void FunctionTestSchedClosure(std::function<void()> fn) { | |
static thread::ThreadPool* w = | |
new thread::ThreadPool(Env::Default(), "Test", 8); | |
w->Schedule(std::move(fn)); | |
} | |
} // end namespace function | |
} // end namespace test | |
} // end namespace tensorflow | |