Spaces:
Build error
Build error
Upload 161 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- allocation_description.proto +27 -0
- allocator.cc +130 -0
- allocator.h +394 -0
- allocator_registry.cc +80 -0
- allocator_registry.h +80 -0
- allocator_test.cc +186 -0
- api_def.proto +120 -0
- attr_value.proto +62 -0
- attr_value_util.cc +551 -0
- attr_value_util.h +116 -0
- attr_value_util_test.cc +195 -0
- bfloat16.cc +50 -0
- bfloat16.h +62 -0
- bfloat16_test.cc +158 -0
- cancellation.cc +94 -0
- cancellation.h +137 -0
- cancellation_test.cc +118 -0
- common_shape_fns.cc +1399 -0
- common_shape_fns.h +290 -0
- common_shape_fns_test.cc +1131 -0
- control_flow.h +58 -0
- cost_graph.proto +72 -0
- device_attributes.proto +35 -0
- device_base.cc +30 -0
- device_base.h +243 -0
- fake_input.cc +240 -0
- fake_input.h +40 -0
- function.cc +1322 -0
- function.h +625 -0
- function.proto +101 -0
- function_test.cc +1339 -0
- function_testlib.cc +204 -0
- function_testlib.h +90 -0
- graph.proto +56 -0
- graph_def_util.cc +218 -0
- graph_def_util.h +115 -0
- graph_def_util_test.cc +321 -0
- graph_transfer_info.proto +68 -0
- iterator.proto +17 -0
- kernel_def.proto +36 -0
- kernel_def_builder.cc +75 -0
- kernel_def_builder.h +87 -0
- kernel_def_builder_test.cc +91 -0
- load_library.cc +104 -0
- log_memory.cc +102 -0
- log_memory.h +111 -0
- log_memory.proto +93 -0
- lookup_interface.cc +87 -0
- lookup_interface.h +145 -0
- memory_types.cc +156 -0
allocation_description.proto
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
syntax = "proto3";
|
2 |
+
|
3 |
+
package tensorflow;
|
4 |
+
option cc_enable_arenas = true;
|
5 |
+
option java_outer_classname = "AllocationDescriptionProtos";
|
6 |
+
option java_multiple_files = true;
|
7 |
+
option java_package = "org.tensorflow.framework";
|
8 |
+
|
9 |
+
message AllocationDescription {
|
10 |
+
// Total number of bytes requested
|
11 |
+
int64 requested_bytes = 1;
|
12 |
+
|
13 |
+
// Total number of bytes allocated if known
|
14 |
+
int64 allocated_bytes = 2;
|
15 |
+
|
16 |
+
// Name of the allocator used
|
17 |
+
string allocator_name = 3;
|
18 |
+
|
19 |
+
// Identifier of the allocated buffer if known
|
20 |
+
int64 allocation_id = 4;
|
21 |
+
|
22 |
+
// Set if this tensor only has one remaining reference
|
23 |
+
bool has_single_reference = 5;
|
24 |
+
|
25 |
+
// Address of the allocation.
|
26 |
+
uint64 ptr = 6;
|
27 |
+
};
|
allocator.cc
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/allocator.h"
|
17 |
+
|
18 |
+
#include "tensorflow/core/framework/allocator_registry.h"
|
19 |
+
#include "tensorflow/core/framework/log_memory.h"
|
20 |
+
#include "tensorflow/core/framework/tracking_allocator.h"
|
21 |
+
#include "tensorflow/core/lib/strings/stringprintf.h"
|
22 |
+
#include "tensorflow/core/platform/mem.h"
|
23 |
+
#include "tensorflow/core/platform/mutex.h"
|
24 |
+
#include "tensorflow/core/platform/types.h"
|
25 |
+
|
26 |
+
namespace tensorflow {
|
27 |
+
|
28 |
+
void AllocatorStats::Clear() {
|
29 |
+
this->num_allocs = 0;
|
30 |
+
this->bytes_in_use = 0;
|
31 |
+
this->max_bytes_in_use = 0;
|
32 |
+
this->max_alloc_size = 0;
|
33 |
+
this->bytes_limit = 0;
|
34 |
+
}
|
35 |
+
|
36 |
+
string AllocatorStats::DebugString() const {
|
37 |
+
return strings::Printf(
|
38 |
+
"Limit: %20lld\n"
|
39 |
+
"InUse: %20lld\n"
|
40 |
+
"MaxInUse: %20lld\n"
|
41 |
+
"NumAllocs: %20lld\n"
|
42 |
+
"MaxAllocSize: %20lld\n",
|
43 |
+
this->bytes_limit, this->bytes_in_use, this->max_bytes_in_use,
|
44 |
+
this->num_allocs, this->max_alloc_size);
|
45 |
+
}
|
46 |
+
|
47 |
+
constexpr size_t Allocator::kAllocatorAlignment;
|
48 |
+
|
49 |
+
Allocator::~Allocator() {}
|
50 |
+
|
51 |
+
void RunResourceCtor(ResourceHandle* p, size_t n) {
|
52 |
+
for (size_t i = 0; i < n; ++p, ++i) new (p) ResourceHandle();
|
53 |
+
}
|
54 |
+
|
55 |
+
void RunResourceDtor(ResourceHandle* p, size_t n) {
|
56 |
+
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
|
57 |
+
}
|
58 |
+
|
59 |
+
// If true, cpu allocator collects more stats.
|
60 |
+
static bool cpu_allocator_collect_stats = false;
|
61 |
+
// If true, cpu allocator collects full stats.
|
62 |
+
static bool cpu_allocator_collect_full_stats = false;
|
63 |
+
|
64 |
+
void EnableCPUAllocatorStats(bool enable) {
|
65 |
+
cpu_allocator_collect_stats = enable;
|
66 |
+
}
|
67 |
+
void EnableCPUAllocatorFullStats(bool enable) {
|
68 |
+
cpu_allocator_collect_full_stats = enable;
|
69 |
+
}
|
70 |
+
|
71 |
+
class CPUAllocator : public Allocator {
|
72 |
+
public:
|
73 |
+
CPUAllocator() {}
|
74 |
+
|
75 |
+
~CPUAllocator() override {}
|
76 |
+
|
77 |
+
string Name() override { return "cpu"; }
|
78 |
+
|
79 |
+
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
|
80 |
+
void* p = port::AlignedMalloc(num_bytes, alignment);
|
81 |
+
if (cpu_allocator_collect_stats) {
|
82 |
+
const std::size_t alloc_size = port::MallocExtension_GetAllocatedSize(p);
|
83 |
+
mutex_lock l(mu_);
|
84 |
+
++stats_.num_allocs;
|
85 |
+
stats_.bytes_in_use += alloc_size;
|
86 |
+
stats_.max_bytes_in_use =
|
87 |
+
std::max<int64>(stats_.max_bytes_in_use, stats_.bytes_in_use);
|
88 |
+
stats_.max_alloc_size =
|
89 |
+
std::max<int64>(stats_.max_alloc_size, alloc_size);
|
90 |
+
}
|
91 |
+
return p;
|
92 |
+
}
|
93 |
+
|
94 |
+
void DeallocateRaw(void* ptr) override {
|
95 |
+
if (cpu_allocator_collect_stats) {
|
96 |
+
const std::size_t alloc_size =
|
97 |
+
port::MallocExtension_GetAllocatedSize(ptr);
|
98 |
+
mutex_lock l(mu_);
|
99 |
+
stats_.bytes_in_use -= alloc_size;
|
100 |
+
}
|
101 |
+
port::AlignedFree(ptr);
|
102 |
+
}
|
103 |
+
|
104 |
+
void GetStats(AllocatorStats* stats) override {
|
105 |
+
mutex_lock l(mu_);
|
106 |
+
*stats = stats_;
|
107 |
+
}
|
108 |
+
|
109 |
+
size_t AllocatedSizeSlow(void* ptr) override {
|
110 |
+
return port::MallocExtension_GetAllocatedSize(ptr);
|
111 |
+
}
|
112 |
+
|
113 |
+
private:
|
114 |
+
mutex mu_;
|
115 |
+
AllocatorStats stats_ GUARDED_BY(mu_);
|
116 |
+
|
117 |
+
TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator);
|
118 |
+
};
|
119 |
+
|
120 |
+
Allocator* cpu_allocator() {
|
121 |
+
static Allocator* cpu_alloc = AllocatorRegistry::Global()->GetAllocator();
|
122 |
+
if (cpu_allocator_collect_full_stats && !cpu_alloc->TracksAllocationSizes()) {
|
123 |
+
cpu_alloc = new TrackingAllocator(cpu_alloc, true);
|
124 |
+
}
|
125 |
+
return cpu_alloc;
|
126 |
+
}
|
127 |
+
|
128 |
+
REGISTER_MEM_ALLOCATOR("DefaultCPUAllocator", 100, CPUAllocator);
|
129 |
+
|
130 |
+
} // namespace tensorflow
|
allocator.h
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
|
18 |
+
|
19 |
+
#include <stdlib.h>
|
20 |
+
|
21 |
+
#include <limits>
|
22 |
+
|
23 |
+
#include "tensorflow/core/framework/numeric_types.h"
|
24 |
+
#include "tensorflow/core/framework/resource_handle.h"
|
25 |
+
#include "tensorflow/core/framework/type_traits.h"
|
26 |
+
#include "tensorflow/core/framework/variant.h"
|
27 |
+
#include "tensorflow/core/platform/logging.h"
|
28 |
+
#include "tensorflow/core/platform/types.h"
|
29 |
+
|
30 |
+
namespace tensorflow {
|
31 |
+
|
32 |
+
// Attributes for a single allocation call. Different calls to the same
|
33 |
+
// allocator could potentially have different allocation attributes.
|
34 |
+
struct AllocationAttributes {
|
35 |
+
// If the first attempt to allocate the memory fails, the allocation
|
36 |
+
// should return immediately without retrying.
|
37 |
+
// An example use case is optional scratch spaces where a failure
|
38 |
+
// has only performance impact.
|
39 |
+
bool no_retry_on_failure = false;
|
40 |
+
// If a Tensor is allocated without the following set to true, then
|
41 |
+
// it is logged as an unknown allocation. During execution Tensors
|
42 |
+
// should be allocated through the OpKernelContext which records
|
43 |
+
// which Op is performing the allocation, and sets this flag to
|
44 |
+
// true.
|
45 |
+
bool allocation_will_be_logged = false;
|
46 |
+
};
|
47 |
+
|
48 |
+
// Runtime statistics collected by an allocator.
|
49 |
+
struct AllocatorStats {
|
50 |
+
int64 num_allocs; // Number of allocations.
|
51 |
+
int64 bytes_in_use; // Number of bytes in use.
|
52 |
+
int64 max_bytes_in_use; // The maximum bytes in use.
|
53 |
+
int64 max_alloc_size; // The max single allocation seen.
|
54 |
+
|
55 |
+
// The upper limit what the allocator can allocate, if such a limit
|
56 |
+
// is known. Certain allocator may return 0 to indicate the limit is
|
57 |
+
// unknown.
|
58 |
+
int64 bytes_limit;
|
59 |
+
|
60 |
+
AllocatorStats() { Clear(); }
|
61 |
+
|
62 |
+
void Clear();
|
63 |
+
string DebugString() const;
|
64 |
+
};
|
65 |
+
|
66 |
+
// Allocator is an abstract interface for allocating and deallocating
|
67 |
+
// device memory.
|
68 |
+
class Allocator {
|
69 |
+
public:
|
70 |
+
#ifdef EIGEN_VECTORIZE_AVX512
|
71 |
+
// Align to 64 byte boundary.
|
72 |
+
static constexpr size_t kAllocatorAlignment = 64;
|
73 |
+
#else
|
74 |
+
// Align to 32 byte boundary.
|
75 |
+
static constexpr size_t kAllocatorAlignment = 32;
|
76 |
+
#endif
|
77 |
+
|
78 |
+
virtual ~Allocator();
|
79 |
+
|
80 |
+
// Return a string identifying this allocator
|
81 |
+
virtual string Name() = 0;
|
82 |
+
|
83 |
+
// Return an uninitialized block of memory that is "num_bytes" bytes
|
84 |
+
// in size. The returned pointer is guaranteed to be aligned to a
|
85 |
+
// multiple of "alignment" bytes.
|
86 |
+
// REQUIRES: "alignment" is a power of 2.
|
87 |
+
virtual void* AllocateRaw(size_t alignment, size_t num_bytes) = 0;
|
88 |
+
|
89 |
+
// Return an uninitialized block of memory that is "num_bytes" bytes
|
90 |
+
// in size with specified allocation attributes. The returned pointer is
|
91 |
+
// guaranteed to be aligned to a multiple of "alignment" bytes.
|
92 |
+
// REQUIRES: "alignment" is a power of 2.
|
93 |
+
virtual void* AllocateRaw(size_t alignment, size_t num_bytes,
|
94 |
+
const AllocationAttributes& allocation_attr) {
|
95 |
+
// The default behavior is to use the implementation without any allocation
|
96 |
+
// attributes.
|
97 |
+
return AllocateRaw(alignment, num_bytes);
|
98 |
+
}
|
99 |
+
|
100 |
+
// Deallocate a block of memory pointer to by "ptr"
|
101 |
+
// REQUIRES: "ptr" was previously returned by a call to AllocateRaw
|
102 |
+
virtual void DeallocateRaw(void* ptr) = 0;
|
103 |
+
|
104 |
+
// Convenience functions to do typed allocation. C++ constructors
|
105 |
+
// and destructors are invoked for complex types if necessary,
|
106 |
+
// depending on the concrete Allocator implementation. May return
|
107 |
+
// NULL if the tensor has too many elements to represent in a single
|
108 |
+
// allocation.
|
109 |
+
template <typename T>
|
110 |
+
T* Allocate(size_t num_elements) {
|
111 |
+
return Allocate<T>(num_elements, AllocationAttributes());
|
112 |
+
}
|
113 |
+
|
114 |
+
template <typename T>
|
115 |
+
T* Allocate(size_t num_elements,
|
116 |
+
const AllocationAttributes& allocation_attr) {
|
117 |
+
// TODO(jeff): Do we need to allow clients to pass in alignment
|
118 |
+
// requirements?
|
119 |
+
|
120 |
+
if (num_elements > (std::numeric_limits<size_t>::max() / sizeof(T))) {
|
121 |
+
return NULL;
|
122 |
+
}
|
123 |
+
|
124 |
+
void* p = AllocateRaw(kAllocatorAlignment, sizeof(T) * num_elements,
|
125 |
+
allocation_attr);
|
126 |
+
T* typed_p = reinterpret_cast<T*>(p);
|
127 |
+
if (typed_p) RunCtor<T>(typed_p, num_elements);
|
128 |
+
return typed_p;
|
129 |
+
}
|
130 |
+
|
131 |
+
template <typename T>
|
132 |
+
void Deallocate(T* ptr, size_t num_elements) {
|
133 |
+
if (ptr) {
|
134 |
+
RunDtor<T>(ptr, num_elements);
|
135 |
+
DeallocateRaw(ptr);
|
136 |
+
}
|
137 |
+
}
|
138 |
+
|
139 |
+
// Returns true if this allocator tracks the sizes of allocations.
|
140 |
+
// RequestedSize and AllocatedSize must be overridden if
|
141 |
+
// TracksAllocationSizes is overridden to return true.
|
142 |
+
virtual bool TracksAllocationSizes() { return false; }
|
143 |
+
|
144 |
+
// Returns true if this allocator requires tensors with 0 elements
|
145 |
+
// to allocate buffers. This is false for most allocators, but may
|
146 |
+
// be used by special-case allocators that want to track tensor
|
147 |
+
// usage.
|
148 |
+
virtual bool ShouldAllocateEmptyTensors() { return false; }
|
149 |
+
|
150 |
+
// Returns the user-requested size of the data allocated at
|
151 |
+
// 'ptr'. Note that the actual buffer allocated might be larger
|
152 |
+
// than requested, but this function returns the size requested by
|
153 |
+
// the user.
|
154 |
+
//
|
155 |
+
// REQUIRES: TracksAllocationSizes() is true.
|
156 |
+
//
|
157 |
+
// REQUIRES: 'ptr!=nullptr' and points to a buffer previously
|
158 |
+
// allocated by this allocator.
|
159 |
+
virtual size_t RequestedSize(void* ptr) {
|
160 |
+
CHECK(false) << "allocator doesn't track sizes";
|
161 |
+
return size_t(0);
|
162 |
+
}
|
163 |
+
|
164 |
+
// Returns the allocated size of the buffer at 'ptr' if known,
|
165 |
+
// otherwise returns RequestedSize(ptr). AllocatedSize(ptr) is
|
166 |
+
// guaranteed to be >= RequestedSize(ptr).
|
167 |
+
//
|
168 |
+
// REQUIRES: TracksAllocationSizes() is true.
|
169 |
+
//
|
170 |
+
// REQUIRES: 'ptr!=nullptr' and points to a buffer previously
|
171 |
+
// allocated by this allocator.
|
172 |
+
virtual size_t AllocatedSize(void* ptr) { return RequestedSize(ptr); }
|
173 |
+
|
174 |
+
// Returns either 0 or an identifier assigned to the buffer at 'ptr'
|
175 |
+
// when the buffer was returned by AllocateRaw. If non-zero, the
|
176 |
+
// identifier differs from every other ID assigned by this
|
177 |
+
// allocator.
|
178 |
+
//
|
179 |
+
// REQUIRES: TracksAllocationSizes() is true.
|
180 |
+
//
|
181 |
+
// REQUIRES: 'ptr!=nullptr' and points to a buffer previously
|
182 |
+
// allocated by this allocator.
|
183 |
+
virtual int64 AllocationId(void* ptr) { return 0; }
|
184 |
+
|
185 |
+
// Returns the allocated size of the buffer at 'ptr' if known,
|
186 |
+
// otherwise returns 0. This method can be called when
|
187 |
+
// TracksAllocationSizes() is false, but can be extremely slow.
|
188 |
+
//
|
189 |
+
// REQUIRES: 'ptr!=nullptr' and points to a buffer previously
|
190 |
+
// allocated by this allocator.
|
191 |
+
virtual size_t AllocatedSizeSlow(void* ptr) {
|
192 |
+
if (TracksAllocationSizes()) {
|
193 |
+
return AllocatedSize(ptr);
|
194 |
+
}
|
195 |
+
return 0;
|
196 |
+
}
|
197 |
+
|
198 |
+
// Fills in 'stats' with statistics collected by this allocator.
|
199 |
+
virtual void GetStats(AllocatorStats* stats) { stats->Clear(); }
|
200 |
+
|
201 |
+
private:
|
202 |
+
// No constructors or destructors are run for simple types
|
203 |
+
template <typename T>
|
204 |
+
void RunCtor(T* p, size_t n) {
|
205 |
+
static_assert(is_simple_type<T>::value, "T is not a simple type.");
|
206 |
+
}
|
207 |
+
|
208 |
+
template <typename T>
|
209 |
+
void RunDtor(T* p, size_t n) {}
|
210 |
+
|
211 |
+
// custom constructors and destructors that can be overridden for
|
212 |
+
// non-standard allocators
|
213 |
+
|
214 |
+
// Runs string's default constructor for p[0], p[1], ..., p[n-1].
|
215 |
+
virtual void RunStringCtor(string* p, size_t n) {
|
216 |
+
for (size_t i = 0; i < n; ++p, ++i) new (p) string();
|
217 |
+
}
|
218 |
+
|
219 |
+
// Runs string's default destructor for p[0], p[1], ..., p[n-1].
|
220 |
+
virtual void RunStringDtor(string* p, size_t n) {
|
221 |
+
for (size_t i = 0; i < n; ++p, ++i) p->~string();
|
222 |
+
}
|
223 |
+
|
224 |
+
virtual void RunResourceCtor(ResourceHandle* p, size_t n) {
|
225 |
+
for (size_t i = 0; i < n; ++p, ++i) new (p) ResourceHandle();
|
226 |
+
}
|
227 |
+
|
228 |
+
// Runs string's default destructor for p[0], p[1], ..., p[n-1].
|
229 |
+
virtual void RunResourceDtor(ResourceHandle* p, size_t n) {
|
230 |
+
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
|
231 |
+
}
|
232 |
+
|
233 |
+
virtual void RunVariantCtor(Variant* p, size_t n) {
|
234 |
+
for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
|
235 |
+
}
|
236 |
+
|
237 |
+
virtual void RunVariantDtor(Variant* p, size_t n) {
|
238 |
+
for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
|
239 |
+
}
|
240 |
+
|
241 |
+
// TODO(jeff): Maybe provide some interface to give info about
|
242 |
+
// current allocation state (total number of bytes available for
|
243 |
+
// allocation, number of bytes free on device, etc.)
|
244 |
+
};
|
245 |
+
|
246 |
+
// Allocator-specific constructors and destructors are used for
|
247 |
+
// strings
|
248 |
+
template <>
|
249 |
+
inline void Allocator::RunCtor(string* p, size_t n) {
|
250 |
+
RunStringCtor(p, n);
|
251 |
+
}
|
252 |
+
|
253 |
+
template <>
|
254 |
+
inline void Allocator::RunDtor(string* p, size_t n) {
|
255 |
+
RunStringDtor(p, n);
|
256 |
+
}
|
257 |
+
|
258 |
+
template <>
|
259 |
+
inline void Allocator::RunCtor(ResourceHandle* p, size_t n) {
|
260 |
+
RunResourceCtor(p, n);
|
261 |
+
}
|
262 |
+
|
263 |
+
template <>
|
264 |
+
inline void Allocator::RunDtor(ResourceHandle* p, size_t n) {
|
265 |
+
RunResourceDtor(p, n);
|
266 |
+
}
|
267 |
+
|
268 |
+
template <>
|
269 |
+
inline void Allocator::RunCtor(Variant* p, size_t n) {
|
270 |
+
RunVariantCtor(p, n);
|
271 |
+
}
|
272 |
+
|
273 |
+
template <>
|
274 |
+
inline void Allocator::RunDtor(Variant* p, size_t n) {
|
275 |
+
RunVariantDtor(p, n);
|
276 |
+
}
|
277 |
+
|
278 |
+
// An implementation of Allocator that delegates all calls to another Allocator.
|
279 |
+
//
|
280 |
+
// Useful to clients who want to override part of the functionality of another
|
281 |
+
// allocator.
|
282 |
+
class AllocatorWrapper : public Allocator {
|
283 |
+
public:
|
284 |
+
explicit AllocatorWrapper(Allocator* wrapped) : wrapped_(wrapped) {}
|
285 |
+
|
286 |
+
~AllocatorWrapper() override {}
|
287 |
+
|
288 |
+
// Returns the wrapped allocator to which all calls are delegated.
|
289 |
+
Allocator* wrapped() const { return wrapped_; }
|
290 |
+
|
291 |
+
string Name() override { return wrapped_->Name(); }
|
292 |
+
|
293 |
+
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
|
294 |
+
return wrapped_->AllocateRaw(alignment, num_bytes);
|
295 |
+
}
|
296 |
+
|
297 |
+
void* AllocateRaw(size_t alignment, size_t num_bytes,
|
298 |
+
const AllocationAttributes& allocation_attr) override {
|
299 |
+
return wrapped_->AllocateRaw(alignment, num_bytes, allocation_attr);
|
300 |
+
}
|
301 |
+
|
302 |
+
void DeallocateRaw(void* ptr) override { wrapped_->DeallocateRaw(ptr); }
|
303 |
+
|
304 |
+
bool TracksAllocationSizes() override {
|
305 |
+
return wrapped_->TracksAllocationSizes();
|
306 |
+
}
|
307 |
+
|
308 |
+
bool ShouldAllocateEmptyTensors() override {
|
309 |
+
return wrapped_->TracksAllocationSizes();
|
310 |
+
}
|
311 |
+
|
312 |
+
size_t RequestedSize(void* ptr) override {
|
313 |
+
return wrapped_->RequestedSize(ptr);
|
314 |
+
}
|
315 |
+
|
316 |
+
size_t AllocatedSize(void* ptr) override {
|
317 |
+
return wrapped_->AllocatedSize(ptr);
|
318 |
+
}
|
319 |
+
|
320 |
+
int64 AllocationId(void* ptr) override { return wrapped_->AllocationId(ptr); }
|
321 |
+
|
322 |
+
size_t AllocatedSizeSlow(void* ptr) override {
|
323 |
+
return wrapped_->AllocatedSizeSlow(ptr);
|
324 |
+
}
|
325 |
+
|
326 |
+
private:
|
327 |
+
Allocator* const wrapped_;
|
328 |
+
};
|
329 |
+
|
330 |
+
// A tensorflow Op may need access to different kinds of memory that
|
331 |
+
// are not simply a function of the device to which the Op has been
|
332 |
+
// assigned. For example, an Op executing on a GPU may still need
|
333 |
+
// to allocate CPU RAM for some purpose. Internal to the tensorflow
|
334 |
+
// runtime we may choose to allocate CPU ram from special regions
|
335 |
+
// that have been prepared for higher performance in some use
|
336 |
+
// contexts, e.g. doing DMA with particular devices. For these
|
337 |
+
// reasons, the Device interface does not expose just one memory
|
338 |
+
// Allocator, but instead provides an accessor that takes a
|
339 |
+
// specification of the desired memory attributes in order to select
|
340 |
+
// an Allocator.
|
341 |
+
//
|
342 |
+
// Example use:
|
343 |
+
// // Allocator for ordinary device memory:
|
344 |
+
// Allocator* a = allocator(AllocatorAttributes());
|
345 |
+
// ...
|
346 |
+
// // Allocator for CPU RAM, regardless of where Op is executing:
|
347 |
+
// AllocatorAttributes attr;
|
348 |
+
// attr.set_on_host(true);
|
349 |
+
// Allocator* a = allocator(attr);
|
350 |
+
struct AllocatorAttributes {
|
351 |
+
void set_on_host(bool v) { value |= (static_cast<int>(v)); }
|
352 |
+
bool on_host() const { return value & 0x1; }
|
353 |
+
void set_nic_compatible(bool v) { value |= (static_cast<int>(v) << 1); }
|
354 |
+
bool nic_compatible() const { return value & (0x1 << 1); }
|
355 |
+
void set_gpu_compatible(bool v) { value |= (static_cast<int>(v) << 2); }
|
356 |
+
bool gpu_compatible() const { return value & (0x1 << 2); }
|
357 |
+
void Merge(AllocatorAttributes other) { value |= other.value; }
|
358 |
+
// Returns true if the fields set in *this is a subset of or equal to
|
359 |
+
// those set in other.
|
360 |
+
bool IsEqualOrLessRestrictiveThan(const AllocatorAttributes& other) const {
|
361 |
+
return (value | other.value) == other.value;
|
362 |
+
}
|
363 |
+
|
364 |
+
// NOTE: The upper 8 bits of the value are reserved for
|
365 |
+
// device-specific uses. Implementors of a device can interpret these
|
366 |
+
// upper 8 bits in device-specific ways, and ops implemented for those
|
367 |
+
// devices are responsible for setting those 8 bits appropriately.
|
368 |
+
uint32 value = 0;
|
369 |
+
};
|
370 |
+
|
371 |
+
// Returns a trivial implementation of Allocator which uses the system
|
372 |
+
// default malloc. The returned allocator is a process singleton.
|
373 |
+
Allocator* cpu_allocator();
|
374 |
+
|
375 |
+
// If 'enable' is true, the process-wide cpu allocator collects
|
376 |
+
// AllocatorStats. By default, it's disabled.
|
377 |
+
void EnableCPUAllocatorStats(bool enable);
|
378 |
+
|
379 |
+
// If 'enable' is true, the process-wide cpu allocator collects full
|
380 |
+
// statistics. By default, it's disabled.
|
381 |
+
void EnableCPUAllocatorFullStats(bool enable);
|
382 |
+
|
383 |
+
// Abstract interface of an object that does the underlying suballoc/free of
|
384 |
+
// memory for a higher-level allocator.
|
385 |
+
class SubAllocator {
|
386 |
+
public:
|
387 |
+
virtual ~SubAllocator() {}
|
388 |
+
virtual void* Alloc(size_t alignment, size_t num_bytes) = 0;
|
389 |
+
virtual void Free(void* ptr, size_t num_bytes) = 0;
|
390 |
+
};
|
391 |
+
|
392 |
+
} // namespace tensorflow
|
393 |
+
|
394 |
+
#endif // TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
|
allocator_registry.cc
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include <string>
|
17 |
+
|
18 |
+
#include "tensorflow/core/framework/allocator_registry.h"
|
19 |
+
#include "tensorflow/core/platform/logging.h"
|
20 |
+
|
21 |
+
namespace tensorflow {
|
22 |
+
|
23 |
+
// static
|
24 |
+
AllocatorRegistry* AllocatorRegistry::Global() {
|
25 |
+
static AllocatorRegistry* global_allocator_registry = new AllocatorRegistry;
|
26 |
+
return global_allocator_registry;
|
27 |
+
}
|
28 |
+
|
29 |
+
Allocator* AllocatorRegistry::GetRegisteredAllocator(const string& name,
|
30 |
+
int priority) {
|
31 |
+
for (auto entry : allocators_) {
|
32 |
+
if (!name.compare(entry.name) && priority == entry.priority) {
|
33 |
+
return entry.allocator;
|
34 |
+
}
|
35 |
+
}
|
36 |
+
return nullptr;
|
37 |
+
}
|
38 |
+
|
39 |
+
void AllocatorRegistry::Register(const string& name, int priority,
|
40 |
+
Allocator* allocator) {
|
41 |
+
CHECK(!name.empty()) << "Need a valid name for Allocator";
|
42 |
+
CHECK_GE(priority, 0) << "Priority needs to be non-negative";
|
43 |
+
|
44 |
+
Allocator* existing = GetRegisteredAllocator(name, priority);
|
45 |
+
if (existing != nullptr) {
|
46 |
+
// A duplicate is if the registration name and priority match
|
47 |
+
// but the Allocator::Name()'s don't match.
|
48 |
+
CHECK_EQ(existing->Name(), allocator->Name())
|
49 |
+
<< "Allocator with name: [" << name << "], type [" << existing->Name()
|
50 |
+
<< "], priority: [" << priority
|
51 |
+
<< "] already registered. Choose a different name to register "
|
52 |
+
<< "an allocator of type " << allocator->Name();
|
53 |
+
|
54 |
+
// The allocator names match, so we can just return.
|
55 |
+
// It should be safe to delete the allocator since the caller
|
56 |
+
// gives up ownership of it.
|
57 |
+
delete allocator;
|
58 |
+
return;
|
59 |
+
}
|
60 |
+
|
61 |
+
AllocatorRegistryEntry tmp_entry;
|
62 |
+
tmp_entry.name = name;
|
63 |
+
tmp_entry.priority = priority;
|
64 |
+
tmp_entry.allocator = allocator;
|
65 |
+
|
66 |
+
allocators_.push_back(tmp_entry);
|
67 |
+
int high_pri = -1;
|
68 |
+
for (auto entry : allocators_) {
|
69 |
+
if (high_pri < entry.priority) {
|
70 |
+
m_curr_allocator_ = entry.allocator;
|
71 |
+
high_pri = entry.priority;
|
72 |
+
}
|
73 |
+
}
|
74 |
+
}
|
75 |
+
|
76 |
+
Allocator* AllocatorRegistry::GetAllocator() {
|
77 |
+
return CHECK_NOTNULL(m_curr_allocator_);
|
78 |
+
}
|
79 |
+
|
80 |
+
} // namespace tensorflow
|
allocator_registry.h
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
// Classes to maintain a static registry of memory allocators
|
17 |
+
#ifndef TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_
|
18 |
+
#define TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_
|
19 |
+
|
20 |
+
#include <string>
|
21 |
+
#include <vector>
|
22 |
+
|
23 |
+
#include "tensorflow/core/framework/allocator.h"
|
24 |
+
|
25 |
+
namespace tensorflow {
|
26 |
+
|
27 |
+
// A global AllocatorRegistry is used to hold allocators for CPU backends
|
28 |
+
class AllocatorRegistry {
|
29 |
+
public:
|
30 |
+
// Add an allocator to the registry. Caller releases ownership of
|
31 |
+
// 'allocator'.
|
32 |
+
void Register(const string& name, int priority, Allocator* allocator);
|
33 |
+
|
34 |
+
// Return allocator with highest priority
|
35 |
+
// If multiple allocators have the same high priority, return one of them
|
36 |
+
Allocator* GetAllocator();
|
37 |
+
|
38 |
+
// Returns the global registry of allocators.
|
39 |
+
static AllocatorRegistry* Global();
|
40 |
+
|
41 |
+
private:
|
42 |
+
typedef struct {
|
43 |
+
string name;
|
44 |
+
int priority;
|
45 |
+
Allocator* allocator; // not owned
|
46 |
+
} AllocatorRegistryEntry;
|
47 |
+
|
48 |
+
// Returns the Allocator registered for 'name' and 'priority',
|
49 |
+
// or 'nullptr' if not found.
|
50 |
+
Allocator* GetRegisteredAllocator(const string& name, int priority);
|
51 |
+
|
52 |
+
std::vector<AllocatorRegistryEntry> allocators_;
|
53 |
+
Allocator* m_curr_allocator_; // not owned
|
54 |
+
};
|
55 |
+
|
56 |
+
namespace allocator_registration {
|
57 |
+
|
58 |
+
class AllocatorRegistration {
|
59 |
+
public:
|
60 |
+
AllocatorRegistration(const string& name, int priority,
|
61 |
+
Allocator* allocator) {
|
62 |
+
AllocatorRegistry::Global()->Register(name, priority, allocator);
|
63 |
+
}
|
64 |
+
};
|
65 |
+
|
66 |
+
} // namespace allocator_registration
|
67 |
+
|
68 |
+
#define REGISTER_MEM_ALLOCATOR(name, priority, allocator) \
|
69 |
+
REGISTER_MEM_ALLOCATOR_UNIQ_HELPER(__COUNTER__, name, priority, allocator)
|
70 |
+
|
71 |
+
#define REGISTER_MEM_ALLOCATOR_UNIQ_HELPER(ctr, name, priority, allocator) \
|
72 |
+
REGISTER_MEM_ALLOCATOR_UNIQ(ctr, name, priority, allocator)
|
73 |
+
|
74 |
+
#define REGISTER_MEM_ALLOCATOR_UNIQ(ctr, name, priority, allocator) \
|
75 |
+
static allocator_registration::AllocatorRegistration \
|
76 |
+
register_allocator_##ctr(name, priority, new allocator)
|
77 |
+
|
78 |
+
} // namespace tensorflow
|
79 |
+
|
80 |
+
#endif // TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_
|
allocator_test.cc
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/allocator.h"
|
17 |
+
|
18 |
+
#include <algorithm>
|
19 |
+
#include <vector>
|
20 |
+
|
21 |
+
#include "tensorflow/core/platform/logging.h"
|
22 |
+
#include "tensorflow/core/platform/test.h"
|
23 |
+
#include "tensorflow/core/platform/test_benchmark.h"
|
24 |
+
|
25 |
+
namespace tensorflow {
|
26 |
+
|
27 |
+
static void CheckStats(Allocator* a, int64 num_allocs, int64 bytes_in_use,
|
28 |
+
int64 max_bytes_in_use, int64 max_alloc_size) {
|
29 |
+
AllocatorStats stats;
|
30 |
+
a->GetStats(&stats);
|
31 |
+
LOG(INFO) << "Alloc stats: \n" << stats.DebugString();
|
32 |
+
#if defined(PLATFORM_GOOGLE) && defined(NDEBUG)
|
33 |
+
// NOTE: allocator stats expectation depends on the system malloc,
|
34 |
+
// and can vary as that changes.
|
35 |
+
static const int64 kSlop = 5 * 1024;
|
36 |
+
EXPECT_GT(stats.bytes_in_use, bytes_in_use - kSlop);
|
37 |
+
EXPECT_LT(stats.bytes_in_use, bytes_in_use + kSlop);
|
38 |
+
EXPECT_GT(stats.max_bytes_in_use, max_bytes_in_use - kSlop);
|
39 |
+
EXPECT_LT(stats.max_bytes_in_use, max_bytes_in_use + kSlop);
|
40 |
+
EXPECT_EQ(stats.num_allocs, num_allocs);
|
41 |
+
EXPECT_EQ(stats.max_alloc_size, max_alloc_size);
|
42 |
+
#endif
|
43 |
+
}
|
44 |
+
|
45 |
+
TEST(AllocatorAttributesTest, AllCombos) {
|
46 |
+
for (bool on_host : {false, true}) {
|
47 |
+
for (bool nic_compatible : {false, true}) {
|
48 |
+
for (bool gpu_compatible : {false, true}) {
|
49 |
+
AllocatorAttributes aa;
|
50 |
+
aa.set_on_host(on_host);
|
51 |
+
aa.set_nic_compatible(nic_compatible);
|
52 |
+
aa.set_gpu_compatible(gpu_compatible);
|
53 |
+
EXPECT_EQ(on_host, aa.on_host());
|
54 |
+
EXPECT_EQ(nic_compatible, aa.nic_compatible());
|
55 |
+
EXPECT_EQ(gpu_compatible, aa.gpu_compatible());
|
56 |
+
}
|
57 |
+
}
|
58 |
+
}
|
59 |
+
}
|
60 |
+
|
61 |
+
TEST(AllocatorAttributesTest, IsEqualOrLessRestrictiveThan) {
|
62 |
+
AllocatorAttributes a, b;
|
63 |
+
EXPECT_TRUE(a.IsEqualOrLessRestrictiveThan(b));
|
64 |
+
EXPECT_TRUE(a.IsEqualOrLessRestrictiveThan(a));
|
65 |
+
EXPECT_TRUE(b.IsEqualOrLessRestrictiveThan(b));
|
66 |
+
|
67 |
+
b.set_gpu_compatible(true);
|
68 |
+
// The set of flags in b is not a subset of those in a.
|
69 |
+
EXPECT_TRUE(a.IsEqualOrLessRestrictiveThan(b));
|
70 |
+
EXPECT_FALSE(b.IsEqualOrLessRestrictiveThan(a));
|
71 |
+
EXPECT_TRUE(a.IsEqualOrLessRestrictiveThan(a));
|
72 |
+
EXPECT_TRUE(b.IsEqualOrLessRestrictiveThan(b));
|
73 |
+
|
74 |
+
a.set_nic_compatible(true);
|
75 |
+
// Neither a nor b is a subset of the other.
|
76 |
+
EXPECT_FALSE(a.IsEqualOrLessRestrictiveThan(b));
|
77 |
+
EXPECT_FALSE(b.IsEqualOrLessRestrictiveThan(a));
|
78 |
+
|
79 |
+
a.set_gpu_compatible(true);
|
80 |
+
// The set of flags in b is a proper subset of those in a.
|
81 |
+
EXPECT_TRUE(b.IsEqualOrLessRestrictiveThan(a));
|
82 |
+
EXPECT_FALSE(a.IsEqualOrLessRestrictiveThan(b));
|
83 |
+
}
|
84 |
+
|
85 |
+
TEST(CPUAllocatorTest, Simple) {
|
86 |
+
EnableCPUAllocatorStats(true);
|
87 |
+
Allocator* a = cpu_allocator();
|
88 |
+
std::vector<void*> ptrs;
|
89 |
+
for (int s = 1; s < 1024; s++) {
|
90 |
+
void* raw = a->AllocateRaw(1, s);
|
91 |
+
ptrs.push_back(raw);
|
92 |
+
}
|
93 |
+
std::sort(ptrs.begin(), ptrs.end());
|
94 |
+
CheckStats(a, 1023, 552640, 552640, 1024);
|
95 |
+
for (size_t i = 0; i < ptrs.size(); i++) {
|
96 |
+
if (i > 0) {
|
97 |
+
CHECK_NE(ptrs[i], ptrs[i - 1]); // No dups
|
98 |
+
}
|
99 |
+
a->DeallocateRaw(ptrs[i]);
|
100 |
+
}
|
101 |
+
CheckStats(a, 1023, 0, 552640, 1024);
|
102 |
+
float* t1 = a->Allocate<float>(1024);
|
103 |
+
double* t2 = a->Allocate<double>(1048576);
|
104 |
+
CheckStats(a, 1025, 1048576 * sizeof(double) + 1024 * sizeof(float),
|
105 |
+
1048576 * sizeof(double) + 1024 * sizeof(float),
|
106 |
+
1048576 * sizeof(double));
|
107 |
+
|
108 |
+
a->Deallocate(t1, 1024);
|
109 |
+
a->Deallocate(t2, 1048576);
|
110 |
+
|
111 |
+
CheckStats(a, 1025, 0, 1048576 * sizeof(double) + 1024 * sizeof(float),
|
112 |
+
1048576 * sizeof(double));
|
113 |
+
EnableCPUAllocatorStats(false);
|
114 |
+
}
|
115 |
+
|
116 |
+
// Define a struct that we will use to observe behavior in the unit tests
|
117 |
+
struct TestStruct {
|
118 |
+
int x; // not used just want to make sure sizeof(TestStruct) > 1
|
119 |
+
};
|
120 |
+
|
121 |
+
TEST(CPUAllocatorTest, CheckStructSize) { CHECK_GT(sizeof(TestStruct), 1); }
|
122 |
+
|
123 |
+
TEST(CPUAllocatorTest, AllocateOverflowMaxSizeT) {
|
124 |
+
Allocator* a = cpu_allocator();
|
125 |
+
|
126 |
+
// The maximum size_t value will definitely overflow.
|
127 |
+
size_t count_to_allocate = std::numeric_limits<size_t>::max();
|
128 |
+
TestStruct* const test_pointer = a->Allocate<TestStruct>(count_to_allocate);
|
129 |
+
|
130 |
+
CHECK_EQ(test_pointer, reinterpret_cast<TestStruct*>(NULL));
|
131 |
+
}
|
132 |
+
|
133 |
+
TEST(CPUAllocatorTest, AllocateOverflowSmallest) {
|
134 |
+
Allocator* a = cpu_allocator();
|
135 |
+
|
136 |
+
// count_to_allocate is the smallest count that will cause overflow.
|
137 |
+
const size_t count_to_allocate =
|
138 |
+
(std::numeric_limits<size_t>::max() / sizeof(TestStruct)) + 1;
|
139 |
+
TestStruct* const test_pointer = a->Allocate<TestStruct>(count_to_allocate);
|
140 |
+
|
141 |
+
CHECK_EQ(test_pointer, reinterpret_cast<TestStruct*>(NULL));
|
142 |
+
}
|
143 |
+
|
144 |
+
TEST(CPUAllocatorTest, Sizes) {
|
145 |
+
Allocator* a = cpu_allocator();
|
146 |
+
|
147 |
+
EXPECT_EQ(false, a->TracksAllocationSizes());
|
148 |
+
}
|
149 |
+
|
150 |
+
namespace {
|
151 |
+
|
152 |
+
AllocatorAttributes DeviceAllocatorAttribute() {
|
153 |
+
AllocatorAttributes attr;
|
154 |
+
attr.value |= (0x1 << 24);
|
155 |
+
return attr;
|
156 |
+
}
|
157 |
+
|
158 |
+
bool HasDeviceAllocatorAttribute(const AllocatorAttributes& attr) {
|
159 |
+
return attr.value & (0x1 << 24);
|
160 |
+
}
|
161 |
+
|
162 |
+
} // namespace
|
163 |
+
|
164 |
+
TEST(CustomAllocatorAttributes, TestSetterAndGetter) {
|
165 |
+
AllocatorAttributes attr = DeviceAllocatorAttribute();
|
166 |
+
EXPECT_TRUE(HasDeviceAllocatorAttribute(attr));
|
167 |
+
EXPECT_FALSE(HasDeviceAllocatorAttribute(AllocatorAttributes()));
|
168 |
+
}
|
169 |
+
|
170 |
+
static void BM_Allocation(int iters, int arg) {
|
171 |
+
Allocator* a = cpu_allocator();
|
172 |
+
// Exercise a few different allocation sizes
|
173 |
+
std::vector<int> sizes = {256, 4096, 16384, 524288, 512, 1048576};
|
174 |
+
int size_index = 0;
|
175 |
+
|
176 |
+
if (arg) EnableCPUAllocatorStats(true);
|
177 |
+
while (--iters > 0) {
|
178 |
+
int bytes = sizes[size_index++ % sizes.size()];
|
179 |
+
void* p = a->AllocateRaw(1, bytes);
|
180 |
+
a->DeallocateRaw(p);
|
181 |
+
}
|
182 |
+
if (arg) EnableCPUAllocatorStats(false);
|
183 |
+
}
|
184 |
+
BENCHMARK(BM_Allocation)->Arg(0)->Arg(1);
|
185 |
+
|
186 |
+
} // namespace tensorflow
|
api_def.proto
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Defines the text format for including per-op API definition and
|
2 |
+
// overrides for client language op code generators.
|
3 |
+
|
4 |
+
syntax = "proto3";
|
5 |
+
|
6 |
+
package tensorflow;
|
7 |
+
option cc_enable_arenas = true;
|
8 |
+
option java_outer_classname = "ApiDefProtos";
|
9 |
+
option java_multiple_files = true;
|
10 |
+
option java_package = "org.tensorflow.framework";
|
11 |
+
import "tensorflow/core/framework/attr_value.proto";
|
12 |
+
|
13 |
+
// Used to specify and override the default API & behavior in the
|
14 |
+
// generated code for client languages, from what you would get from
|
15 |
+
// the OpDef alone. There will be a set of ApiDefs that are common
|
16 |
+
// to all client languages, and another set per client language.
|
17 |
+
// The per-client-language ApiDefs will inherit values from the
|
18 |
+
// common ApiDefs which it can either replace or modify.
|
19 |
+
//
|
20 |
+
// We separate the API definition from the OpDef so we can evolve the
|
21 |
+
// API while remaining backwards compatible when interpretting old
|
22 |
+
// graphs. Overrides go in an "api_def.pbtxt" file with a text-format
|
23 |
+
// ApiDefs message.
|
24 |
+
//
|
25 |
+
// WARNING: Be *very* careful changing the API for any existing op --
|
26 |
+
// you can change the semantics of existing code. These changes may
|
27 |
+
// need to wait until a major release of TensorFlow to avoid breaking
|
28 |
+
// our compatibility promises.
|
29 |
+
message ApiDef {
|
30 |
+
// Name of the op (in the OpDef) to specify the API for.
|
31 |
+
string graph_op_name = 1;
|
32 |
+
|
33 |
+
enum Visibility {
|
34 |
+
// Normally this is "VISIBLE" unless you are inheriting a
|
35 |
+
// different value from another ApiDef.
|
36 |
+
DEFAULT_VISIBILITY = 0;
|
37 |
+
// Publicly visible in the API.
|
38 |
+
VISIBLE = 1;
|
39 |
+
// Do not include this op in the generated API. If visibility is
|
40 |
+
// set to 'SKIP', other fields are ignored for this op.
|
41 |
+
SKIP = 2;
|
42 |
+
// Hide this op by putting it into an internal namespace (or whatever
|
43 |
+
// is appropriate in the target language).
|
44 |
+
HIDDEN = 3;
|
45 |
+
}
|
46 |
+
Visibility visibility = 2;
|
47 |
+
|
48 |
+
// If you specify any endpoint, this will replace all of the
|
49 |
+
// inherited endpoints. The first endpoint should be the
|
50 |
+
// "canonical" endpoint, and should not be deprecated (unless all
|
51 |
+
// endpoints are deprecated).
|
52 |
+
message Endpoint {
|
53 |
+
// Name should be either like "CamelCaseName" or
|
54 |
+
// "Package.CamelCaseName". Client-language-specific ApiDefs may
|
55 |
+
// use a snake_case convention instead of CamelCase.
|
56 |
+
string name = 1;
|
57 |
+
|
58 |
+
// First GraphDef version at which the op is disallowed.
|
59 |
+
int32 deprecation_version = 2;
|
60 |
+
}
|
61 |
+
repeated Endpoint endpoint = 3;
|
62 |
+
|
63 |
+
message Arg {
|
64 |
+
string name = 1;
|
65 |
+
|
66 |
+
// Change the name used to access this arg in the API from what
|
67 |
+
// is used in the GraphDef. Note that these names in `backticks`
|
68 |
+
// will also be replaced in the summary & description fields.
|
69 |
+
string rename_to = 2;
|
70 |
+
|
71 |
+
// Note: this will replace any inherited arg doc. There is no
|
72 |
+
// current way of modifying arg descriptions (other than replacing
|
73 |
+
// them entirely) as can be done with op descriptions.
|
74 |
+
string description = 3;
|
75 |
+
}
|
76 |
+
repeated Arg in_arg = 4;
|
77 |
+
repeated Arg out_arg = 5;
|
78 |
+
// List of original in_arg names to specify new argument order.
|
79 |
+
// Length of arg_order should be either empty to keep current order
|
80 |
+
// or match size of in_arg.
|
81 |
+
repeated string arg_order = 11;
|
82 |
+
|
83 |
+
// Description of the graph-construction-time configuration of this
|
84 |
+
// Op. That is to say, this describes the attr fields that will
|
85 |
+
// be specified in the NodeDef.
|
86 |
+
message Attr {
|
87 |
+
string name = 1;
|
88 |
+
|
89 |
+
// Change the name used to access this attr in the API from what
|
90 |
+
// is used in the GraphDef. Note that these names in `backticks`
|
91 |
+
// will also be replaced in the summary & description fields.
|
92 |
+
string rename_to = 2;
|
93 |
+
|
94 |
+
// Specify a new default value to use for this attr. This default
|
95 |
+
// will be used when creating new graphs, as opposed to the
|
96 |
+
// default in the OpDef, which will be used when interpreting old
|
97 |
+
// GraphDefs.
|
98 |
+
AttrValue default_value = 3;
|
99 |
+
|
100 |
+
// Note: this will replace any inherited attr doc, there is no current
|
101 |
+
// way of modifying attr descriptions as can be done with op descriptions.
|
102 |
+
string description = 4;
|
103 |
+
}
|
104 |
+
repeated Attr attr = 6;
|
105 |
+
|
106 |
+
// One-line human-readable description of what the Op does.
|
107 |
+
string summary = 7;
|
108 |
+
|
109 |
+
// Additional, longer human-readable description of what the Op does.
|
110 |
+
string description = 8;
|
111 |
+
|
112 |
+
// Modify an existing/inherited description by adding text to the beginning
|
113 |
+
// or end.
|
114 |
+
string description_prefix = 9;
|
115 |
+
string description_suffix = 10;
|
116 |
+
}
|
117 |
+
|
118 |
+
message ApiDefs {
|
119 |
+
repeated ApiDef op = 1;
|
120 |
+
}
|
attr_value.proto
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
syntax = "proto3";
|
2 |
+
|
3 |
+
package tensorflow;
|
4 |
+
option cc_enable_arenas = true;
|
5 |
+
option java_outer_classname = "AttrValueProtos";
|
6 |
+
option java_multiple_files = true;
|
7 |
+
option java_package = "org.tensorflow.framework";
|
8 |
+
|
9 |
+
import "tensorflow/core/framework/tensor.proto";
|
10 |
+
import "tensorflow/core/framework/tensor_shape.proto";
|
11 |
+
import "tensorflow/core/framework/types.proto";
|
12 |
+
|
13 |
+
// Protocol buffer representing the value for an attr used to configure an Op.
|
14 |
+
// Comment indicates the corresponding attr type. Only the field matching the
|
15 |
+
// attr type may be filled.
|
16 |
+
message AttrValue {
|
17 |
+
// LINT.IfChange
|
18 |
+
message ListValue {
|
19 |
+
repeated bytes s = 2; // "list(string)"
|
20 |
+
repeated int64 i = 3 [packed = true]; // "list(int)"
|
21 |
+
repeated float f = 4 [packed = true]; // "list(float)"
|
22 |
+
repeated bool b = 5 [packed = true]; // "list(bool)"
|
23 |
+
repeated DataType type = 6 [packed = true]; // "list(type)"
|
24 |
+
repeated TensorShapeProto shape = 7; // "list(shape)"
|
25 |
+
repeated TensorProto tensor = 8; // "list(tensor)"
|
26 |
+
repeated NameAttrList func = 9; // "list(attr)"
|
27 |
+
}
|
28 |
+
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc)
|
29 |
+
|
30 |
+
oneof value {
|
31 |
+
bytes s = 2; // "string"
|
32 |
+
int64 i = 3; // "int"
|
33 |
+
float f = 4; // "float"
|
34 |
+
bool b = 5; // "bool"
|
35 |
+
DataType type = 6; // "type"
|
36 |
+
TensorShapeProto shape = 7; // "shape"
|
37 |
+
TensorProto tensor = 8; // "tensor"
|
38 |
+
ListValue list = 1; // any "list(...)"
|
39 |
+
|
40 |
+
// "func" represents a function. func.name is a function's name or
|
41 |
+
// a primitive op's name. func.attr.first is the name of an attr
|
42 |
+
// defined for that function. func.attr.second is the value for
|
43 |
+
// that attr in the instantiation.
|
44 |
+
NameAttrList func = 10;
|
45 |
+
|
46 |
+
// This is a placeholder only used in nodes defined inside a
|
47 |
+
// function. It indicates the attr value will be supplied when
|
48 |
+
// the function is instantiated. For example, let us suppose a
|
49 |
+
// node "N" in function "FN". "N" has an attr "A" with value
|
50 |
+
// placeholder = "foo". When FN is instantiated with attr "foo"
|
51 |
+
// set to "bar", the instantiated node N's attr A will have been
|
52 |
+
// given the value "bar".
|
53 |
+
string placeholder = 9;
|
54 |
+
}
|
55 |
+
}
|
56 |
+
|
57 |
+
// A list of attr names and their values. The whole list is attached
|
58 |
+
// with a string name. E.g., MatMul[T=float].
|
59 |
+
message NameAttrList {
|
60 |
+
string name = 1;
|
61 |
+
map<string, AttrValue> attr = 2;
|
62 |
+
}
|
attr_value_util.cc
ADDED
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/attr_value_util.h"
|
17 |
+
|
18 |
+
#include <string>
|
19 |
+
#include <vector>
|
20 |
+
|
21 |
+
#include "tensorflow/core/framework/attr_value.pb_text.h"
|
22 |
+
#include "tensorflow/core/framework/tensor.pb_text.h"
|
23 |
+
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
24 |
+
#include "tensorflow/core/framework/types.h"
|
25 |
+
#include "tensorflow/core/framework/types.pb_text.h"
|
26 |
+
#include "tensorflow/core/lib/core/errors.h"
|
27 |
+
#include "tensorflow/core/lib/core/stringpiece.h"
|
28 |
+
#include "tensorflow/core/lib/hash/hash.h"
|
29 |
+
#include "tensorflow/core/lib/strings/str_util.h"
|
30 |
+
#include "tensorflow/core/platform/protobuf.h"
|
31 |
+
|
32 |
+
namespace tensorflow {
|
33 |
+
namespace {
|
34 |
+
|
35 |
+
string SummarizeString(const string& str) {
|
36 |
+
return strings::StrCat("\"", str_util::CEscape(str), "\"");
|
37 |
+
}
|
38 |
+
|
39 |
+
string SummarizeTensor(const TensorProto& tensor_proto) {
|
40 |
+
Tensor t;
|
41 |
+
if (!t.FromProto(tensor_proto)) {
|
42 |
+
return strings::StrCat(
|
43 |
+
"<Invalid TensorProto: ", ProtoShortDebugString(tensor_proto), ">");
|
44 |
+
}
|
45 |
+
return t.DebugString();
|
46 |
+
}
|
47 |
+
|
48 |
+
string SummarizeFunc(const NameAttrList& func) {
|
49 |
+
std::vector<string> entries;
|
50 |
+
for (auto p : func.attr()) {
|
51 |
+
entries.push_back(
|
52 |
+
strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
|
53 |
+
}
|
54 |
+
std::sort(entries.begin(), entries.end());
|
55 |
+
return strings::StrCat(func.name(), "[", str_util::Join(entries, ", "), "]");
|
56 |
+
}
|
57 |
+
|
58 |
+
} // namespace
|
59 |
+
|
60 |
+
string SummarizeAttrValue(const AttrValue& attr_value) {
|
61 |
+
switch (attr_value.value_case()) {
|
62 |
+
case AttrValue::kS:
|
63 |
+
return SummarizeString(attr_value.s());
|
64 |
+
case AttrValue::kI:
|
65 |
+
return strings::StrCat(attr_value.i());
|
66 |
+
case AttrValue::kF:
|
67 |
+
return strings::StrCat(attr_value.f());
|
68 |
+
case AttrValue::kB:
|
69 |
+
return attr_value.b() ? "true" : "false";
|
70 |
+
case AttrValue::kType:
|
71 |
+
return EnumName_DataType(attr_value.type());
|
72 |
+
case AttrValue::kShape:
|
73 |
+
return PartialTensorShape::DebugString(attr_value.shape());
|
74 |
+
case AttrValue::kTensor:
|
75 |
+
return SummarizeTensor(attr_value.tensor());
|
76 |
+
case AttrValue::kList: {
|
77 |
+
string ret = "[";
|
78 |
+
if (attr_value.list().s_size() > 0) {
|
79 |
+
for (int i = 0; i < attr_value.list().s_size(); ++i) {
|
80 |
+
if (i > 0) strings::StrAppend(&ret, ", ");
|
81 |
+
strings::StrAppend(&ret, SummarizeString(attr_value.list().s(i)));
|
82 |
+
}
|
83 |
+
} else if (attr_value.list().i_size() > 0) {
|
84 |
+
for (int i = 0; i < attr_value.list().i_size(); ++i) {
|
85 |
+
if (i > 0) strings::StrAppend(&ret, ", ");
|
86 |
+
strings::StrAppend(&ret, attr_value.list().i(i));
|
87 |
+
}
|
88 |
+
} else if (attr_value.list().f_size() > 0) {
|
89 |
+
for (int i = 0; i < attr_value.list().f_size(); ++i) {
|
90 |
+
if (i > 0) strings::StrAppend(&ret, ", ");
|
91 |
+
strings::StrAppend(&ret, attr_value.list().f(i));
|
92 |
+
}
|
93 |
+
} else if (attr_value.list().b_size() > 0) {
|
94 |
+
for (int i = 0; i < attr_value.list().b_size(); ++i) {
|
95 |
+
if (i > 0) strings::StrAppend(&ret, ", ");
|
96 |
+
strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false");
|
97 |
+
}
|
98 |
+
} else if (attr_value.list().type_size() > 0) {
|
99 |
+
for (int i = 0; i < attr_value.list().type_size(); ++i) {
|
100 |
+
if (i > 0) strings::StrAppend(&ret, ", ");
|
101 |
+
strings::StrAppend(&ret,
|
102 |
+
EnumName_DataType(attr_value.list().type(i)));
|
103 |
+
}
|
104 |
+
} else if (attr_value.list().shape_size() > 0) {
|
105 |
+
for (int i = 0; i < attr_value.list().shape_size(); ++i) {
|
106 |
+
if (i > 0) strings::StrAppend(&ret, ", ");
|
107 |
+
strings::StrAppend(
|
108 |
+
&ret, TensorShape::DebugString(attr_value.list().shape(i)));
|
109 |
+
}
|
110 |
+
} else if (attr_value.list().tensor_size() > 0) {
|
111 |
+
for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
|
112 |
+
if (i > 0) strings::StrAppend(&ret, ", ");
|
113 |
+
strings::StrAppend(&ret,
|
114 |
+
SummarizeTensor(attr_value.list().tensor(i)));
|
115 |
+
}
|
116 |
+
} else if (attr_value.list().func_size() > 0) {
|
117 |
+
for (int i = 0; i < attr_value.list().func_size(); ++i) {
|
118 |
+
if (i > 0) strings::StrAppend(&ret, ", ");
|
119 |
+
strings::StrAppend(&ret, SummarizeFunc(attr_value.list().func(i)));
|
120 |
+
}
|
121 |
+
}
|
122 |
+
|
123 |
+
strings::StrAppend(&ret, "]");
|
124 |
+
return ret;
|
125 |
+
}
|
126 |
+
case AttrValue::kFunc: {
|
127 |
+
return SummarizeFunc(attr_value.func());
|
128 |
+
}
|
129 |
+
case AttrValue::kPlaceholder:
|
130 |
+
return strings::StrCat("$", attr_value.placeholder());
|
131 |
+
case AttrValue::VALUE_NOT_SET:
|
132 |
+
return "<Unknown AttrValue type>";
|
133 |
+
}
|
134 |
+
return "<Unknown AttrValue type>"; // Prevent missing return warning
|
135 |
+
}
|
136 |
+
|
137 |
+
Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
|
138 |
+
int num_set = 0;
|
139 |
+
|
140 |
+
#define VALIDATE_FIELD(name, type_string, oneof_case) \
|
141 |
+
do { \
|
142 |
+
if (attr_value.has_list()) { \
|
143 |
+
if (attr_value.list().name##_size() > 0) { \
|
144 |
+
if (type != "list(" type_string ")") { \
|
145 |
+
return errors::InvalidArgument( \
|
146 |
+
"AttrValue had value with type 'list(" type_string ")' when '", \
|
147 |
+
type, "' expected"); \
|
148 |
+
} \
|
149 |
+
++num_set; \
|
150 |
+
} \
|
151 |
+
} else if (attr_value.value_case() == AttrValue::oneof_case) { \
|
152 |
+
if (type != type_string) { \
|
153 |
+
return errors::InvalidArgument( \
|
154 |
+
"AttrValue had value with type '" type_string "' when '", type, \
|
155 |
+
"' expected"); \
|
156 |
+
} \
|
157 |
+
++num_set; \
|
158 |
+
} \
|
159 |
+
} while (false)
|
160 |
+
|
161 |
+
VALIDATE_FIELD(s, "string", kS);
|
162 |
+
VALIDATE_FIELD(i, "int", kI);
|
163 |
+
VALIDATE_FIELD(f, "float", kF);
|
164 |
+
VALIDATE_FIELD(b, "bool", kB);
|
165 |
+
VALIDATE_FIELD(type, "type", kType);
|
166 |
+
VALIDATE_FIELD(shape, "shape", kShape);
|
167 |
+
VALIDATE_FIELD(tensor, "tensor", kTensor);
|
168 |
+
VALIDATE_FIELD(func, "func", kFunc);
|
169 |
+
|
170 |
+
#undef VALIDATE_FIELD
|
171 |
+
|
172 |
+
if (attr_value.value_case() == AttrValue::kPlaceholder) {
|
173 |
+
return errors::InvalidArgument(
|
174 |
+
"AttrValue had value with unexpected type 'placeholder'");
|
175 |
+
}
|
176 |
+
|
177 |
+
// If the attr type is 'list', we expect attr_value.has_list() to be
|
178 |
+
// true. However, proto3's attr_value.has_list() can be false when
|
179 |
+
// set to an empty list for GraphDef versions <= 4. So we simply
|
180 |
+
// check if has_list is false and some other field in attr_value is
|
181 |
+
// set to flag the error. This test can be made more strict once
|
182 |
+
// support for GraphDef versions <= 4 is dropped.
|
183 |
+
if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) {
|
184 |
+
if (num_set) {
|
185 |
+
return errors::InvalidArgument(
|
186 |
+
"AttrValue missing value with expected type '", type, "'");
|
187 |
+
} else {
|
188 |
+
// Indicate that we have a list, but an empty one.
|
189 |
+
++num_set;
|
190 |
+
}
|
191 |
+
}
|
192 |
+
|
193 |
+
// Okay to have an empty list, but not to be missing a non-list value.
|
194 |
+
if (num_set == 0 && !StringPiece(type).starts_with("list(")) {
|
195 |
+
return errors::InvalidArgument(
|
196 |
+
"AttrValue missing value with expected type '", type, "'");
|
197 |
+
}
|
198 |
+
|
199 |
+
// Ref types and DT_INVALID are illegal, and DataTypes must
|
200 |
+
// be a valid enum type.
|
201 |
+
if (type == "type") {
|
202 |
+
if (!DataType_IsValid(attr_value.type())) {
|
203 |
+
return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
|
204 |
+
attr_value.type());
|
205 |
+
}
|
206 |
+
if (IsRefType(attr_value.type())) {
|
207 |
+
return errors::InvalidArgument(
|
208 |
+
"AttrValue must not have reference type value of ",
|
209 |
+
DataTypeString(attr_value.type()));
|
210 |
+
}
|
211 |
+
if (attr_value.type() == DT_INVALID) {
|
212 |
+
return errors::InvalidArgument("AttrValue has invalid DataType");
|
213 |
+
}
|
214 |
+
} else if (type == "list(type)") {
|
215 |
+
for (auto as_int : attr_value.list().type()) {
|
216 |
+
const DataType dtype = static_cast<DataType>(as_int);
|
217 |
+
if (!DataType_IsValid(dtype)) {
|
218 |
+
return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
|
219 |
+
as_int);
|
220 |
+
}
|
221 |
+
if (IsRefType(dtype)) {
|
222 |
+
return errors::InvalidArgument(
|
223 |
+
"AttrValue must not have reference type value of ",
|
224 |
+
DataTypeString(dtype));
|
225 |
+
}
|
226 |
+
if (dtype == DT_INVALID) {
|
227 |
+
return errors::InvalidArgument("AttrValue contains invalid DataType");
|
228 |
+
}
|
229 |
+
}
|
230 |
+
}
|
231 |
+
|
232 |
+
return Status::OK();
|
233 |
+
}
|
234 |
+
|
235 |
+
// Parses "text" (a text-format proto fragment) into *out, interpreting it as
// the attr type named by "type" (e.g. "int", "list(string)").
// Returns true on success, false if "type" is unknown or "text" is malformed.
bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
  // Recognize an optional "list(...)" wrapper around the base type.
  const bool is_list = type.Consume("list(");

  // Map each supported attr type name to the AttrValue proto field that
  // stores it.  Order matches the original if/else chain; Consume() only
  // strips a matching prefix, so iterating in order is equivalent.
  struct TypeField {
    const char* type_name;
    const char* field;
  };
  static const TypeField kTypeFields[] = {
      {"string", "s"},         {"int", "i"},
      {"float", "f"},          {"bool", "b"},
      {"type", "type"},        {"shape", "shape"},
      {"tensor", "tensor"},    {"func", "func"},
      {"placeholder", "placeholder"},
  };
  const char* field_name = nullptr;
  for (const TypeField& tf : kTypeFields) {
    if (type.Consume(tf.type_name)) {
      field_name = tf.field;
      break;
    }
  }
  if (field_name == nullptr) {
    return false;  // Unrecognized base type.
  }
  if (is_list && !type.Consume(")")) {
    return false;  // "list(" without a closing ")".
  }

  // Build a text proto that TextFormat can parse into AttrValue.
  string to_parse;
  if (is_list) {
    // TextFormat treats "i: 7" the same as "i: [7]", but we only want to
    // accept list values that are explicitly bracketed.
    StringPiece cleaned = text;
    str_util::RemoveLeadingWhitespace(&cleaned);
    str_util::RemoveTrailingWhitespace(&cleaned);
    const bool bracketed = cleaned.size() >= 2 && cleaned[0] == '[' &&
                           cleaned[cleaned.size() - 1] == ']';
    if (!bracketed) {
      return false;
    }
    cleaned.remove_prefix(1);
    str_util::RemoveLeadingWhitespace(&cleaned);
    if (cleaned.size() == 1) {
      // Only "]" remains: the user wrote "[]".  Return an empty list
      // directly, since TextFormat rejects "i: []".
      out->Clear();
      out->mutable_list();
      return true;
    }
    to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
  } else {
    to_parse = strings::StrCat(field_name, ": ", text);
  }

  return ProtoParseFromString(to_parse, out);
}
|
292 |
+
|
293 |
+
// Copying an AttrValue wholesale is the trivial case.
void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }

// Defines a scalar SetAttrValue overload that stores `value` into the
// AttrValue field named FIELD (e.g. set_i, set_f).
#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
  void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }

// Defines a list SetAttrValue overload that stores each element of `value`
// into the repeated list().FIELD of the AttrValue.
#define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD)                       \
  void SetAttrValue(ARG_TYPE value, AttrValue* out) {                     \
    out->mutable_list()->Clear(); /* create list() even if value empty */ \
    for (const auto& v : value) {                                         \
      out->mutable_list()->add_##FIELD(v);                                \
    }                                                                     \
  }

// Defines both the scalar overload and the ArraySlice list overload.
#define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
  DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD)        \
  DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)

DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
DEFINE_SET_ATTR_VALUE_BOTH(int64, i)
DEFINE_SET_ATTR_VALUE_BOTH(int32, i)
DEFINE_SET_ATTR_VALUE_BOTH(float, f)
DEFINE_SET_ATTR_VALUE_BOTH(double, f)
DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
// std::vector<bool> and initializer_list<bool> are not covered by
// ArraySlice<bool>, so they get explicit list overloads.
DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
|
321 |
+
|
322 |
+
// Stores a (non-owning) StringPiece as the attr's string value by copying
// its bytes into the proto.
void SetAttrValue(StringPiece value, AttrValue* out) {
  out->set_s(value.data(), value.size());
}

// Stores each StringPiece in `value` into the attr's string list.
void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
  out->mutable_list()->Clear();  // Create list() even if value empty.
  for (const auto& v : value) {
    out->mutable_list()->add_s(v.data(), v.size());
  }
}
|
332 |
+
|
333 |
+
// Stores a fully-defined TensorShape as the attr's shape value.
void SetAttrValue(const TensorShape& value, AttrValue* out) {
  value.AsProto(out->mutable_shape());
}

// Stores an already-serialized shape proto directly.
void SetAttrValue(const TensorShapeProto& value, AttrValue* out) {
  *out->mutable_shape() = value;
}

// Stores a possibly-partially-known shape (unknown dims allowed).
void SetAttrValue(const PartialTensorShape& value, AttrValue* out) {
  value.AsProto(out->mutable_shape());
}

// Stores each TensorShape in `value` into the attr's shape list.
void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
  out->mutable_list()->Clear();  // Create list() even if value empty.
  for (const auto& v : value) {
    v.AsProto(out->mutable_list()->add_shape());
  }
}

// Stores each shape proto in `value` into the attr's shape list.
void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) {
  out->mutable_list()->Clear();  // Create list() even if value empty.
  for (const auto& v : value) {
    *out->mutable_list()->add_shape() = v;
  }
}

// Stores each PartialTensorShape in `value` into the attr's shape list.
void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,
                  AttrValue* out) {
  out->mutable_list()->Clear();  // Create list() even if value empty.
  for (const auto& v : value) {
    v.AsProto(out->mutable_list()->add_shape());
  }
}
|
366 |
+
|
367 |
+
// Stores a Tensor as the attr's tensor value.  Multi-element tensors use the
// compact serialized content form (AsProtoTensorContent); tensors with at
// most one element use the per-field form (AsProtoField).
void SetAttrValue(const Tensor& value, AttrValue* out) {
  if (value.NumElements() > 1) {
    value.AsProtoTensorContent(out->mutable_tensor());
  } else {
    value.AsProtoField(out->mutable_tensor());
  }
}

// Stores each Tensor in `value` into the attr's tensor list, using the same
// compact/per-field choice as the scalar overload above.
void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
  out->mutable_list()->Clear();  // Create list() even if value empty.
  for (const auto& v : value) {
    if (v.NumElements() > 1) {
      v.AsProtoTensorContent(out->mutable_list()->add_tensor());
    } else {
      v.AsProtoField(out->mutable_list()->add_tensor());
    }
  }
}

// Stores an already-serialized TensorProto directly.
void SetAttrValue(const TensorProto& value, AttrValue* out) {
  *out->mutable_tensor() = value;
}

// Stores each TensorProto in `value` into the attr's tensor list.
void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
  out->mutable_list()->Clear();  // Create list() even if value empty.
  for (const auto& v : value) {
    *out->mutable_list()->add_tensor() = v;
  }
}

// Stores a function reference (name + attrs) as the attr's func value.
void SetAttrValue(const NameAttrList& value, AttrValue* out) {
  *out->mutable_func() = value;
}

// Stores each function reference in `value` into the attr's func list.
void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
  out->mutable_list()->Clear();  // Create list() even if value empty.
  for (const auto& v : value) {
    *out->mutable_list()->add_func() = v;
  }
}
|
407 |
+
|
408 |
+
// Returns true iff `a` and `b` represent the same attr value.
// Tensors and funcs need special handling; everything else is compared via
// deterministic serialization.
bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
  // There are multiple equivalent representations of attr values containing
  // TensorProtos. Compare them by constructing Tensors and serializing them
  // back. Comparing Tensor objects is pretty tricky.
  if (a.has_tensor() != b.has_tensor()) {
    return false;
  } else if (a.has_tensor() && b.has_tensor()) {
    // Round-trip both protos through Tensor to canonicalize their encoding.
    // NOTE(review): FromProto failure is only DCHECKed; in opt builds a
    // malformed proto proceeds with a default tensor — presumably callers
    // only pass valid protos.
    Tensor at(a.tensor().dtype());
    bool success = at.FromProto(a.tensor());
    DCHECK(success);

    Tensor bt(b.tensor().dtype());
    success = bt.FromProto(b.tensor());
    DCHECK(success);

    TensorProto ap;
    at.AsProtoTensorContent(&ap);

    TensorProto bp;
    bt.AsProtoTensorContent(&bp);

    // Compare the canonicalized protos by deterministic serialization.
    string a_str, b_str;
    SerializeToStringDeterministic(ap, &a_str);
    SerializeToStringDeterministic(bp, &b_str);
    return a_str == b_str;
  }

  // `func` field contains a nested AttrValue. Compare such AttrValues
  // recursively.
  if (a.has_func() != b.has_func()) {
    return false;
  } else if (a.has_func() && b.has_func()) {
    const NameAttrList& af = a.func();
    const NameAttrList& bf = b.func();
    if (af.name() != bf.name()) return false;
    // Check that the attr maps hold exactly the same keys with equal values:
    // every entry of bf must match one in am, and nothing may be left over.
    std::unordered_map<string, AttrValue> am(af.attr().begin(),
                                             af.attr().end());
    for (const auto& bm_pair : bf.attr()) {
      const auto& iter = am.find(bm_pair.first);
      if (iter == am.end()) return false;
      if (!AreAttrValuesEqual(iter->second, bm_pair.second)) return false;
      am.erase(iter);
    }
    if (!am.empty()) return false;
    return true;
  }

  // All other fields in AttrValue have deterministic representations.
  // It is safe to compare their serialized strings.
  string a_str, b_str;
  SerializeToStringDeterministic(a, &a_str);
  SerializeToStringDeterministic(b, &b_str);
  return a_str == b_str;
}
|
462 |
+
|
463 |
+
// Returns a hash of `a` consistent with AreAttrValuesEqual: attr values that
// compare equal hash equal.  Tensor and func cases mirror the equality
// function's canonicalization.
uint64 AttrValueHash(const AttrValue& a) {
  if (a.has_tensor()) {
    // Deal with multiple representations by parsing TensorProto to
    // Tensor and serializing it back. This is slow, but current use case
    // don't need high efficiency.
    Tensor tensor(a.tensor().dtype());
    bool success = tensor.FromProto(a.tensor());
    DCHECK(success);
    TensorProto p;
    tensor.AsProtoTensorContent(&p);
    string s;
    SerializeToStringDeterministic(p, &s);
    return Hash64(s);
  }
  if (a.has_func()) {
    const NameAttrList& func = a.func();
    uint64 h = Hash64(func.name());
    // Iterate attrs in sorted key order (std::map) so the hash does not
    // depend on the protobuf map's iteration order.
    std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
    for (const auto& pair : map) {
      h = Hash64(pair.first.data(), pair.first.size(), h);
      h = Hash64Combine(AttrValueHash(pair.second), h);
    }
    return h;
  }

  // If `a` is not a tensor or func, get a hash of serialized string.
  string s;
  SerializeToStringDeterministic(a, &s);
  return Hash64(s);
}
|
493 |
+
|
494 |
+
bool HasPlaceHolder(const AttrValue& val) {
|
495 |
+
switch (val.value_case()) {
|
496 |
+
case AttrValue::kList: {
|
497 |
+
for (const NameAttrList& func : val.list().func()) {
|
498 |
+
for (const auto& p : func.attr()) {
|
499 |
+
if (HasPlaceHolder(p.second)) {
|
500 |
+
return true;
|
501 |
+
}
|
502 |
+
}
|
503 |
+
}
|
504 |
+
break;
|
505 |
+
}
|
506 |
+
case AttrValue::kFunc:
|
507 |
+
for (const auto& p : val.func().attr()) {
|
508 |
+
if (HasPlaceHolder(p.second)) {
|
509 |
+
return true;
|
510 |
+
}
|
511 |
+
}
|
512 |
+
break;
|
513 |
+
case AttrValue::kPlaceholder:
|
514 |
+
return true;
|
515 |
+
default:
|
516 |
+
break;
|
517 |
+
}
|
518 |
+
return false;
|
519 |
+
}
|
520 |
+
|
521 |
+
// Recursively replaces every placeholder reachable from `value` by invoking
// `substitute(placeholder_name, attr_to_overwrite)`.  Returns false if any
// substitution fails or if `value` has no value set at all.
bool SubstitutePlaceholders(const SubstituteFunc& substitute,
                            AttrValue* value) {
  switch (value->value_case()) {
    case AttrValue::kList: {
      // Only func entries of a list can carry nested attr values.
      for (NameAttrList& func : *value->mutable_list()->mutable_func()) {
        for (auto& p : *func.mutable_attr()) {
          if (!SubstitutePlaceholders(substitute, &p.second)) {
            return false;
          }
        }
      }
      break;
    }
    case AttrValue::kFunc:
      for (auto& p : *(value->mutable_func()->mutable_attr())) {
        if (!SubstitutePlaceholders(substitute, &p.second)) {
          return false;
        }
      }
      break;
    case AttrValue::kPlaceholder:
      // Delegate the actual replacement; `substitute` overwrites *value.
      return substitute(value->placeholder(), value);
    case AttrValue::VALUE_NOT_SET:
      // An unset attr cannot be substituted into — report failure.
      return false;
    default:
      // Plain scalar/list values contain no placeholders; nothing to do.
      break;
  }
  return true;
}
|
550 |
+
|
551 |
+
} // namespace tensorflow
|
attr_value_util.h
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
|
18 |
+
|
19 |
+
#include <functional>
|
20 |
+
#include <string>
|
21 |
+
#include <vector>
|
22 |
+
|
23 |
+
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
24 |
+
#include "tensorflow/core/framework/tensor.h"
|
25 |
+
#include "tensorflow/core/framework/tensor_shape.h"
|
26 |
+
#include "tensorflow/core/framework/types.h"
|
27 |
+
#include "tensorflow/core/lib/core/status.h"
|
28 |
+
#include "tensorflow/core/lib/core/stringpiece.h"
|
29 |
+
#include "tensorflow/core/lib/gtl/array_slice.h"
|
30 |
+
|
31 |
+
namespace tensorflow {
|
32 |
+
|
33 |
+
// Forward declare protos so their symbols can be removed from .so exports
|
34 |
+
class AttrValue;
|
35 |
+
class NameAttrList;
|
36 |
+
|
37 |
+
// A human-readable rendering of attr_value, that is more concise than a
|
38 |
+
// text-format proto.
|
39 |
+
string SummarizeAttrValue(const AttrValue& attr_value);
|
40 |
+
|
41 |
+
// Generates an error if attr_value doesn't have the indicated attr type.
|
42 |
+
Status AttrValueHasType(const AttrValue& attr_value, StringPiece type);
|
43 |
+
|
44 |
+
// Converts a text proto value from "text" into the field of *out
|
45 |
+
// indicated by "type" (e.g. from the type field of an AttrDef).
|
46 |
+
// Examples:
|
47 |
+
// * If type:"int" and text:"-14", then *out is set to "i: -14"
|
48 |
+
// * If type:"list(string)" and text:"['foo', 'bar']",
|
49 |
+
// then *out is set to "list { s: ['foo', 'bar'] }"
|
50 |
+
// Returns true on success.
|
51 |
+
bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out);
|
52 |
+
|
53 |
+
// Sets *out based on the type of value.
|
54 |
+
void SetAttrValue(const string& value, AttrValue* out);
|
55 |
+
void SetAttrValue(const char* value, AttrValue* out);
|
56 |
+
void SetAttrValue(StringPiece value, AttrValue* out);
|
57 |
+
void SetAttrValue(int64 value, AttrValue* out);
|
58 |
+
void SetAttrValue(int32 value, AttrValue* out);
|
59 |
+
void SetAttrValue(float value, AttrValue* out);
|
60 |
+
void SetAttrValue(double value, AttrValue* out);
|
61 |
+
void SetAttrValue(bool value, AttrValue* out);
|
62 |
+
void SetAttrValue(DataType value, AttrValue* out);
|
63 |
+
void SetAttrValue(const TensorShape& value, AttrValue* out);
|
64 |
+
void SetAttrValue(const TensorShapeProto& value, AttrValue* out);
|
65 |
+
void SetAttrValue(const PartialTensorShape& value, AttrValue* out);
|
66 |
+
void SetAttrValue(const Tensor& value, AttrValue* out);
|
67 |
+
void SetAttrValue(const TensorProto& value, AttrValue* out);
|
68 |
+
void SetAttrValue(const NameAttrList& value, AttrValue* out);
|
69 |
+
|
70 |
+
void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out);
|
71 |
+
void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out);
|
72 |
+
void SetAttrValue(gtl::ArraySlice<StringPiece> value, AttrValue* out);
|
73 |
+
void SetAttrValue(gtl::ArraySlice<int64> value, AttrValue* out);
|
74 |
+
void SetAttrValue(gtl::ArraySlice<int32> value, AttrValue* out);
|
75 |
+
void SetAttrValue(gtl::ArraySlice<float> value, AttrValue* out);
|
76 |
+
void SetAttrValue(gtl::ArraySlice<double> value, AttrValue* out);
|
77 |
+
void SetAttrValue(gtl::ArraySlice<bool> value, AttrValue* out);
|
78 |
+
void SetAttrValue(const std::vector<bool>& value, AttrValue* out);
|
79 |
+
void SetAttrValue(std::initializer_list<bool> value, AttrValue* out);
|
80 |
+
void SetAttrValue(DataTypeSlice value, AttrValue* out);
|
81 |
+
void SetAttrValue(gtl::ArraySlice<TensorShape> value, AttrValue* out);
|
82 |
+
void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out);
|
83 |
+
void SetAttrValue(gtl::ArraySlice<PartialTensorShape> value, AttrValue* out);
|
84 |
+
void SetAttrValue(gtl::ArraySlice<Tensor> value, AttrValue* out);
|
85 |
+
void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out);
|
86 |
+
void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out);
|
87 |
+
|
88 |
+
void SetAttrValue(const AttrValue& value, AttrValue* out);
|
89 |
+
|
90 |
+
// Returns true if a and b have the same value.
|
91 |
+
bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b);
|
92 |
+
|
93 |
+
// Returns a hash of `a` that is consistent with AreAttrValuesEqual. In other
|
94 |
+
// words, if two AttrValues compare equal according to AreAttrValuesEqual,
|
95 |
+
// they will have the same hash value.
|
96 |
+
// Similarly to protobuf deterministic serialization, hash value is
|
97 |
+
// guaranteed to be stable only for a given binary. In particular, one should
|
98 |
+
// probably not persist the returned value.
|
99 |
+
uint64 AttrValueHash(const AttrValue& a);
|
100 |
+
|
101 |
+
// Returns true if "val" has a placeholder.
|
102 |
+
bool HasPlaceHolder(const AttrValue& val);
|
103 |
+
|
104 |
+
// SubstitutePlaceholders recursively replaces placeholders in 'value'
|
105 |
+
// with an attr value by calling SubstituteFunc. Returns true iff all
|
106 |
+
// placeholders in "value" are replaced with a value.
|
107 |
+
//
|
108 |
+
// SubstituteFunc is given a placeholder string. If the placeholder is
|
109 |
+
// unknown, SubstituteFunc returns false. Otherwise, overwrites the
|
110 |
+
// attr value and returns true.
|
111 |
+
using SubstituteFunc = std::function<bool(const string&, AttrValue*)>;
|
112 |
+
bool SubstitutePlaceholders(const SubstituteFunc& substitute, AttrValue* value);
|
113 |
+
|
114 |
+
} // namespace tensorflow
|
115 |
+
|
116 |
+
#endif // TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
|
attr_value_util_test.cc
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/attr_value_util.h"
|
17 |
+
|
18 |
+
#include <vector>
|
19 |
+
#include "tensorflow/core/framework/attr_value.pb.h"
|
20 |
+
#include "tensorflow/core/lib/core/status_test_util.h"
|
21 |
+
#include "tensorflow/core/platform/protobuf.h"
|
22 |
+
#include "tensorflow/core/platform/test.h"
|
23 |
+
|
24 |
+
namespace tensorflow {
|
25 |
+
|
26 |
+
// A few helpers to construct AttrValue protos.
|
27 |
+
template <typename T>
|
28 |
+
AttrValue V(T value) {
|
29 |
+
AttrValue ret;
|
30 |
+
SetAttrValue(value, &ret);
|
31 |
+
return ret;
|
32 |
+
}
|
33 |
+
|
34 |
+
AttrValue P(const string& p) {
|
35 |
+
AttrValue ret;
|
36 |
+
ret.set_placeholder(p);
|
37 |
+
return ret;
|
38 |
+
}
|
39 |
+
|
40 |
+
AttrValue F(const string& name,
|
41 |
+
std::vector<std::pair<string, AttrValue>> pairs) {
|
42 |
+
AttrValue ret;
|
43 |
+
ret.mutable_func()->set_name(name);
|
44 |
+
ret.mutable_func()->mutable_attr()->insert(pairs.begin(), pairs.end());
|
45 |
+
return ret;
|
46 |
+
}
|
47 |
+
|
48 |
+
AttrValue Fs(
|
49 |
+
std::vector<std::pair<string, std::vector<std::pair<string, AttrValue>>>>
|
50 |
+
funcs) {
|
51 |
+
AttrValue ret;
|
52 |
+
for (const auto& func : funcs) {
|
53 |
+
NameAttrList* entry = ret.mutable_list()->add_func();
|
54 |
+
entry->set_name(func.first);
|
55 |
+
entry->mutable_attr()->insert(func.second.begin(), func.second.end());
|
56 |
+
}
|
57 |
+
return ret;
|
58 |
+
}
|
59 |
+
|
60 |
+
TEST(AttrValueUtil, HasType) {
|
61 |
+
// OK
|
62 |
+
EXPECT_TRUE(AttrValueHasType(V(123), "int").ok());
|
63 |
+
EXPECT_TRUE(AttrValueHasType(V(1.2), "float").ok());
|
64 |
+
EXPECT_TRUE(AttrValueHasType(V(DT_FLOAT), "type").ok());
|
65 |
+
EXPECT_TRUE(AttrValueHasType(F("f", {}), "func").ok());
|
66 |
+
EXPECT_TRUE(AttrValueHasType(Fs({{"f", {}}, {"g", {}}}), "list(func)").ok());
|
67 |
+
|
68 |
+
// not OK.
|
69 |
+
EXPECT_FALSE(AttrValueHasType(V(123), "func").ok());
|
70 |
+
EXPECT_FALSE(AttrValueHasType(V(1.2), "int").ok());
|
71 |
+
EXPECT_FALSE(AttrValueHasType(V(DT_FLOAT), "shape").ok());
|
72 |
+
EXPECT_FALSE(AttrValueHasType(F("f", {}), "string").ok());
|
73 |
+
EXPECT_FALSE(AttrValueHasType(P("T"), "float").ok());
|
74 |
+
EXPECT_FALSE(AttrValueHasType(V(static_cast<DataType>(1000)), "type").ok());
|
75 |
+
std::vector<DataType> list_type({static_cast<DataType>(1000)});
|
76 |
+
EXPECT_FALSE(AttrValueHasType(V(list_type), "list(type)").ok());
|
77 |
+
}
|
78 |
+
|
79 |
+
SubstituteFunc ReplaceTWith(const AttrValue& val) {
|
80 |
+
return [val](const string& placeholder, AttrValue* target) {
|
81 |
+
if (placeholder == "T") {
|
82 |
+
*target = val;
|
83 |
+
return true;
|
84 |
+
} else {
|
85 |
+
return false;
|
86 |
+
}
|
87 |
+
};
|
88 |
+
}
|
89 |
+
|
90 |
+
TEST(AttrValueUtil, Basic) {
|
91 |
+
auto v = F("MatMul", {{"dtype", P("T")},
|
92 |
+
{"transpose_a", V(false)},
|
93 |
+
{"transpose_b", V(true)},
|
94 |
+
{"use_cublas", V(true)}});
|
95 |
+
TF_EXPECT_OK(AttrValueHasType(v, "func"));
|
96 |
+
EXPECT_TRUE(HasPlaceHolder(v));
|
97 |
+
|
98 |
+
EXPECT_EQ(
|
99 |
+
SummarizeAttrValue(v),
|
100 |
+
"MatMul[dtype=$T, transpose_a=false, transpose_b=true, use_cublas=true]");
|
101 |
+
|
102 |
+
SubstitutePlaceholders(ReplaceTWith(V(DT_FLOAT)), &v);
|
103 |
+
EXPECT_TRUE(!HasPlaceHolder(v));
|
104 |
+
EXPECT_EQ(SummarizeAttrValue(v),
|
105 |
+
"MatMul[dtype=DT_FLOAT, transpose_a=false, transpose_b=true, "
|
106 |
+
"use_cublas=true]");
|
107 |
+
}
|
108 |
+
|
109 |
+
TEST(AttrValueUtil, Shaped) {
|
110 |
+
auto v =
|
111 |
+
F("OpRequiresShape", {{"shape_full", V(TensorShape({1, 0}))},
|
112 |
+
{"shape_part", V(PartialTensorShape({-1, 1, 0}))}});
|
113 |
+
TF_EXPECT_OK(AttrValueHasType(v, "func"));
|
114 |
+
EXPECT_FALSE(HasPlaceHolder(v));
|
115 |
+
|
116 |
+
EXPECT_EQ(SummarizeAttrValue(v),
|
117 |
+
"OpRequiresShape[shape_full=[1,0], shape_part=[?,1,0]]");
|
118 |
+
}
|
119 |
+
|
120 |
+
TEST(AttrValueUtil, DeepAttr) {
|
121 |
+
auto v = Fs({{"f", {{"T", P("T")}}}, {"g", {{"T", P("T")}}}});
|
122 |
+
TF_EXPECT_OK(AttrValueHasType(v, "list(func)"));
|
123 |
+
EXPECT_TRUE(HasPlaceHolder(v));
|
124 |
+
|
125 |
+
for (int i = 0; i < 3; ++i) {
|
126 |
+
v = F("f", {{"T", P("T")}, {"F", v}});
|
127 |
+
EXPECT_TRUE(HasPlaceHolder(v));
|
128 |
+
}
|
129 |
+
EXPECT_EQ(SummarizeAttrValue(v),
|
130 |
+
"f[F=f[F=f[F=[f[T=$T], g[T=$T]], T=$T], T=$T], T=$T]");
|
131 |
+
|
132 |
+
SubstitutePlaceholders(ReplaceTWith(F("x", {})), &v);
|
133 |
+
EXPECT_TRUE(!HasPlaceHolder(v));
|
134 |
+
EXPECT_EQ(SummarizeAttrValue(v),
|
135 |
+
"f[F=f[F=f[F=[f[T=x[]], g[T=x[]]], T=x[]], T=x[]], T=x[]]");
|
136 |
+
}
|
137 |
+
|
138 |
+
AttrValue FromText(const string& text) {
|
139 |
+
AttrValue attr;
|
140 |
+
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &attr));
|
141 |
+
return attr;
|
142 |
+
}
|
143 |
+
|
144 |
+
void ExpectDifferent(const AttrValue& a1, const AttrValue& a2) {
|
145 |
+
EXPECT_FALSE(AreAttrValuesEqual(a1, a2));
|
146 |
+
EXPECT_FALSE(AreAttrValuesEqual(a2, a1));
|
147 |
+
EXPECT_NE(AttrValueHash(a1), AttrValueHash(a2));
|
148 |
+
}
|
149 |
+
|
150 |
+
TEST(AttrValueEquality, StringAndFuncTensors) {
|
151 |
+
AttrValue a = FromText(R"(
|
152 |
+
tensor {
|
153 |
+
dtype: DT_STRING
|
154 |
+
tensor_shape {
|
155 |
+
dim {
|
156 |
+
size: 2
|
157 |
+
}
|
158 |
+
}
|
159 |
+
string_val: 'reader_dataset_ops_test/tmphtXHks/text_line.0.txt'
|
160 |
+
string_val: 'reader_dataset_ops_test/tmphtXHks/text_line.1.txt'
|
161 |
+
})");
|
162 |
+
EXPECT_TRUE(AreAttrValuesEqual(a, a));
|
163 |
+
EXPECT_EQ(AttrValueHash(a), AttrValueHash(a));
|
164 |
+
|
165 |
+
AttrValue b = a;
|
166 |
+
(*b.mutable_tensor()->mutable_string_val(0))[3] = '1';
|
167 |
+
ExpectDifferent(a, b);
|
168 |
+
|
169 |
+
AttrValue c1;
|
170 |
+
c1.mutable_func()->set_name("func_name");
|
171 |
+
(*c1.mutable_func()->mutable_attr())["attr1"] = a;
|
172 |
+
(*c1.mutable_func()->mutable_attr())["attr2"] = b;
|
173 |
+
EXPECT_TRUE(AreAttrValuesEqual(c1, c1));
|
174 |
+
EXPECT_EQ(AttrValueHash(c1), AttrValueHash(c1));
|
175 |
+
|
176 |
+
ExpectDifferent(c1, a);
|
177 |
+
|
178 |
+
AttrValue c2 = c1;
|
179 |
+
c2.mutable_func()->set_name("func_name2");
|
180 |
+
ExpectDifferent(c1, c2);
|
181 |
+
|
182 |
+
c2 = c1;
|
183 |
+
(*c2.mutable_func()->mutable_attr())["attr3"] = b;
|
184 |
+
ExpectDifferent(c1, c2);
|
185 |
+
|
186 |
+
c2 = c1;
|
187 |
+
(*c2.mutable_func()->mutable_attr())["attr2"] = a;
|
188 |
+
ExpectDifferent(c1, c2);
|
189 |
+
|
190 |
+
c2 = c1;
|
191 |
+
c2.mutable_func()->mutable_attr()->erase("attr2");
|
192 |
+
ExpectDifferent(c1, c2);
|
193 |
+
}
|
194 |
+
|
195 |
+
} // namespace tensorflow
|
bfloat16.cc
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/bfloat16.h"
|
17 |
+
|
18 |
+
namespace tensorflow {
|
19 |
+
|
20 |
+
// Converts `size` floats in `src` to bfloat16 in `dst` by truncation: each
// bfloat16 is the high-order 16 bits of the IEEE-754 float (no rounding).
// The byte-order #if selects which uint16 half of each float is the high one.
void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) {
  const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
  uint16_t* q = reinterpret_cast<uint16_t*>(dst);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  // Big-endian: the high half is the first uint16 of the float.
  for (; size != 0; p += 2, q++, size--) {
    *q = p[0];
  }
#else
  // Little-endian: the high half is the second uint16 of the float.
  for (; size != 0; p += 2, q++, size--) {
    *q = p[1];
  }
#endif
}
33 |
+
|
34 |
+
// Converts `size` bfloat16 values in `src` to floats in `dst` by placing each
// 16-bit value in the float's high half and zeroing the low mantissa bits.
// Exact inverse of FloatToBFloat16 for values that round-trip.
void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) {
  const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
  uint16_t* q = reinterpret_cast<uint16_t*>(dst);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  // Big-endian: high half first, then zeroed low half.
  for (; size != 0; p++, q += 2, size--) {
    q[0] = *p;
    q[1] = 0;
  }
#else
  // Little-endian: zeroed low half first, then high half.
  for (; size != 0; p++, q += 2, size--) {
    q[0] = 0;
    q[1] = *p;
  }
#endif
}
49 |
+
|
50 |
+
} // end namespace tensorflow
|
bfloat16.h
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_BFLOAT16_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_BFLOAT16_H_
|
18 |
+
|
19 |
+
#include "tensorflow/core/framework/numeric_types.h"
|
20 |
+
#include "tensorflow/core/platform/types.h"
|
21 |
+
|
22 |
+
#if defined(PLATFORM_WINDOWS)
|
23 |
+
#include "tensorflow/core/platform/windows/cpu_info.h"
|
24 |
+
#endif
|
25 |
+
|
26 |
+
// Compact 16-bit encoding of floating point numbers. This representation uses
|
27 |
+
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. It
|
28 |
+
// is assumed that floats are in IEEE 754 format so the representation is just
|
29 |
+
// bits 16-31 of a single precision float.
|
30 |
+
//
|
31 |
+
// NOTE: The IEEE floating point standard defines a float16 format that
|
32 |
+
// is different than this format (it has fewer bits of exponent and more
|
33 |
+
// bits of mantissa). We don't use that format here because conversion
|
34 |
+
// to/from 32-bit floats is more complex for that format, and the
|
35 |
+
// conversion for this format is very simple.
|
36 |
+
//
|
37 |
+
// Because of the existing IEEE float16 type, we do not name our representation
|
38 |
+
// "float16" but just use "uint16".
|
39 |
+
//
|
40 |
+
// <-----our 16bits float------->
|
41 |
+
// s e e e e e e e e f f f f f f f f f f f f f f f f f f f f f f f
|
42 |
+
// <------------------------------float-------------------------->
|
43 |
+
// 3 3 2 2 1 1 0
|
44 |
+
// 1 0 3 2 5 4 0
|
45 |
+
//
|
46 |
+
//
|
47 |
+
// This type only supports conversion back and forth with float.
|
48 |
+
//
|
49 |
+
// This file must be compilable by nvcc.
|
50 |
+
//
|
51 |
+
// The type is defined in framework/numeric_types.h.
|
52 |
+
|
53 |
+
namespace tensorflow {
|
54 |
+
|
55 |
+
// Conversion routines between an array of float and bfloat16 of
|
56 |
+
// "size".
|
57 |
+
void FloatToBFloat16(const float* src, bfloat16* dst, int64 size);
|
58 |
+
void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size);
|
59 |
+
|
60 |
+
} // namespace tensorflow
|
61 |
+
|
62 |
+
#endif // TENSORFLOW_FRAMEWORK_BFLOAT16_H_
|
bfloat16_test.cc
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/bfloat16.h"
|
17 |
+
|
18 |
+
#include "tensorflow/core/framework/numeric_types.h"
|
19 |
+
#include "tensorflow/core/lib/core/casts.h"
|
20 |
+
#include "tensorflow/core/platform/test.h"
|
21 |
+
#include "tensorflow/core/platform/test_benchmark.h"
|
22 |
+
|
23 |
+
namespace tensorflow {
|
24 |
+
namespace {
|
25 |
+
|
26 |
+
// Constructing bfloat16 from an integer keeps only the upper 16 bits of the
// float encoding: 12.0f is 0x41400000, so the stored bits must be 0x4140.
TEST(Bfloat16Test, Simple) {
  const bfloat16 b(12);
  EXPECT_EQ(0x4140, b.value);
}
|
31 |
+
|
32 |
+
float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
|
33 |
+
uint32_t low_mantissa) {
|
34 |
+
return bit_cast<float>((sign << 31) + (exponent << 23) +
|
35 |
+
(high_mantissa << 16) + low_mantissa);
|
36 |
+
}
|
37 |
+
|
38 |
+
// One case for the parameterized truncation tests below.
struct Bfloat16TestParam {
  float input;     // value converted to bfloat16
  float expected;  // float value expected after the bfloat16 round trip
};
|
42 |
+
|
43 |
+
// Fixture for value-parameterized tests driven by Bfloat16TestParam.
class Bfloat16Test : public ::testing::Test,
                     public ::testing::WithParamInterface<Bfloat16TestParam> {};
|
45 |
+
|
46 |
+
// float -> bfloat16 conversion truncates the low 16 mantissa bits.
// For NaN inputs the truncated result may be NaN or Inf, both acceptable.
TEST_P(Bfloat16Test, TruncateTest) {
  const float in = GetParam().input;
  const bfloat16 truncated(in);
  if (std::isnan(in)) {
    EXPECT_TRUE(std::isnan(static_cast<float>(truncated)) ||
                std::isinf(static_cast<float>(truncated)));
    return;
  }
  EXPECT_EQ(GetParam().expected, static_cast<float>(truncated));
}
|
54 |
+
|
55 |
+
// Each pair is {input, expected-after-truncation}. The cases cover: positive
// and negative normals, mantissas with the dropped low bits set in various
// patterns (all truncated, never rounded), Inf/NaN exponents (0b11111111),
// and denormals (exponent 0b00000000).
INSTANTIATE_TEST_CASE_P(
    Bfloat16Test_Instantiation, Bfloat16Test,
    ::testing::Values(
        Bfloat16TestParam{
            BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011),
            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
        Bfloat16TestParam{
            BinaryToFloat(1, 0b10000000, 0b1001000, 0b1111010111000011),
            BinaryToFloat(1, 0b10000000, 0b1001000, 0b0000000000000000)},
        Bfloat16TestParam{
            BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
        Bfloat16TestParam{
            BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000001),
            BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000000)},
        Bfloat16TestParam{
            BinaryToFloat(0, 0b11111111, 0b1111111, 0b1111111111111111),
            BinaryToFloat(0, 0b11111111, 0b1111111, 0b0000000000000000)},
        Bfloat16TestParam{
            BinaryToFloat(1, 0b10000000, 0b1001000, 0b1100000000000000),
            BinaryToFloat(1, 0b10000000, 0b1001000, 0b0000000000000000)},
        Bfloat16TestParam{
            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
        Bfloat16TestParam{
            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0100000000000000),
            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
        Bfloat16TestParam{
            BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
        Bfloat16TestParam{
            BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000),
            BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)},
        Bfloat16TestParam{
            BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000),
            BinaryToFloat(0, 0b00000000, 0b1111111, 0b0000000000000000)}));
|
91 |
+
|
92 |
+
// Round-tripping float -> bfloat16 -> float must preserve each value to
// within bfloat16's 7-bit mantissa precision: relative error < 2^-7.
TEST(Bfloat16Test, Conversion) {
  constexpr int kCount = 100;
  float original[kCount];
  for (int i = 0; i < kCount; ++i) {
    original[i] = i + 1.25;
  }
  bfloat16 truncated[kCount];
  float roundtripped[kCount];
  FloatToBFloat16(original, truncated, kCount);
  BFloat16ToFloat(truncated, roundtripped, kCount);
  for (int i = 0; i < kCount; ++i) {
    EXPECT_LE(fabs(roundtripped[i] - original[i]) / original[i], 1.0 / 128);
  }
}
|
107 |
+
|
108 |
+
// epsilon() is the gap between 1.0 and the next representable bfloat16:
// adding it to 1.0 must change the value, while adding half of it must be
// absorbed (the sum truncates back to exactly 1.0).
TEST(Bfloat16Test, Epsilon) {
  EXPECT_LT(1.0f, static_cast<float>(bfloat16::epsilon() + bfloat16(1.0f)));
  EXPECT_EQ(1.0f, static_cast<float>((bfloat16::epsilon() / bfloat16(2.0f)) +
                                     bfloat16(1.0f)));
}
|
113 |
+
|
114 |
+
// Unary negation flips only the sign bit, so exactly representable values
// negate exactly.
TEST(Bfloat16Test, Negate) {
  EXPECT_EQ(-3.0f, static_cast<float>(-bfloat16(3.0f)));
  EXPECT_EQ(4.5f, static_cast<float>(-bfloat16(-4.5f)));
}
|
118 |
+
|
119 |
+
static void BM_FloatToBFloat16(int iters) {
|
120 |
+
testing::StopTiming();
|
121 |
+
static const int N = 32 << 20;
|
122 |
+
const int64 tot = static_cast<int64>(iters) * N;
|
123 |
+
testing::ItemsProcessed(tot);
|
124 |
+
testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
|
125 |
+
|
126 |
+
float* inp = new float[N];
|
127 |
+
bfloat16* out = new bfloat16[N];
|
128 |
+
|
129 |
+
testing::StartTiming();
|
130 |
+
while (iters--) {
|
131 |
+
FloatToBFloat16(inp, out, N);
|
132 |
+
}
|
133 |
+
delete[] inp;
|
134 |
+
delete[] out;
|
135 |
+
}
|
136 |
+
BENCHMARK(BM_FloatToBFloat16);
|
137 |
+
|
138 |
+
static void BM_BFloat16ToFloat(int iters) {
|
139 |
+
testing::StopTiming();
|
140 |
+
static const int N = 32 << 20;
|
141 |
+
const int64 tot = static_cast<int64>(iters) * N;
|
142 |
+
testing::ItemsProcessed(tot);
|
143 |
+
testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
|
144 |
+
|
145 |
+
bfloat16* inp = new bfloat16[N];
|
146 |
+
float* out = new float[N];
|
147 |
+
|
148 |
+
testing::StartTiming();
|
149 |
+
while (iters--) {
|
150 |
+
BFloat16ToFloat(inp, out, N);
|
151 |
+
}
|
152 |
+
delete[] inp;
|
153 |
+
delete[] out;
|
154 |
+
}
|
155 |
+
BENCHMARK(BM_BFloat16ToFloat);
|
156 |
+
|
157 |
+
} // namespace
|
158 |
+
} // namespace tensorflow
|
cancellation.cc
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/cancellation.h"
|
17 |
+
|
18 |
+
#include "tensorflow/core/lib/core/errors.h"
|
19 |
+
#include "tensorflow/core/platform/logging.h"
|
20 |
+
|
21 |
+
namespace tensorflow {
|
22 |
+
|
23 |
+
// Sentinel token; get_cancellation_token() starts at 0 and only increments,
// so -1 is never a live token.
const CancellationToken CancellationManager::kInvalidToken = -1;
|
24 |
+
|
25 |
+
// Starts in the not-cancelled state with token 0 as the next handout.
CancellationManager::CancellationManager()
    : is_cancelling_(false),
      is_cancelled_(false),
      next_cancellation_token_(0) {}
|
29 |
+
|
30 |
+
void CancellationManager::StartCancel() {
|
31 |
+
gtl::FlatMap<CancellationToken, CancelCallback> callbacks_to_run;
|
32 |
+
{
|
33 |
+
mutex_lock l(mu_);
|
34 |
+
if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
|
35 |
+
return;
|
36 |
+
}
|
37 |
+
is_cancelling_ = true;
|
38 |
+
std::swap(callbacks_, callbacks_to_run);
|
39 |
+
}
|
40 |
+
// We call these callbacks without holding mu_, so that concurrent
|
41 |
+
// calls to DeregisterCallback, which can happen asynchronously, do
|
42 |
+
// not block. The callbacks remain valid because any concurrent call
|
43 |
+
// to DeregisterCallback will block until the
|
44 |
+
// cancelled_notification_ is notified.
|
45 |
+
for (auto key_and_value : callbacks_to_run) {
|
46 |
+
key_and_value.second();
|
47 |
+
}
|
48 |
+
{
|
49 |
+
mutex_lock l(mu_);
|
50 |
+
is_cancelling_ = false;
|
51 |
+
is_cancelled_.store(true, std::memory_order_release);
|
52 |
+
}
|
53 |
+
cancelled_notification_.Notify();
|
54 |
+
}
|
55 |
+
|
56 |
+
// Hands out a fresh, monotonically increasing token under mu_.
CancellationToken CancellationManager::get_cancellation_token() {
  mutex_lock l(mu_);
  const CancellationToken token = next_cancellation_token_;
  ++next_cancellation_token_;
  return token;
}
|
60 |
+
|
61 |
+
// Registers `callback` under `token` unless cancellation has started or
// finished. Returns whether the callback was stored; when false, the caller
// owns any cancellation cleanup.
bool CancellationManager::RegisterCallback(CancellationToken token,
                                           CancelCallback callback) {
  mutex_lock l(mu_);
  // Tokens must come from get_cancellation_token(), so they are always
  // strictly below the next unissued value.
  CHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token";
  if (is_cancelled_ || is_cancelling_) {
    return false;
  }
  std::swap(callbacks_[token], callback);
  return true;
}
|
71 |
+
|
72 |
+
// Removes the callback registered under `token`.
//
// Returns true iff the callback was removed before being invoked. If
// cancellation already completed, returns false immediately. If cancellation
// is in progress, blocks until every callback has finished, then returns
// false — this guarantees the caller does not return and free objects that
// a still-running callback in StartCancel may be using.
bool CancellationManager::DeregisterCallback(CancellationToken token) {
  {
    // RAII lock (instead of the original manual mu_.lock()/mu_.unlock()
    // pairs) keeps both early-return paths exception-safe and makes it
    // impossible to miss an unlock.
    mutex_lock l(mu_);
    if (is_cancelled_) {
      return false;
    }
    if (!is_cancelling_) {
      callbacks_.erase(token);
      return true;
    }
  }
  // StartCancel() is currently running the callbacks. Wait (without holding
  // mu_) for all of them to be called before returning.
  cancelled_notification_.WaitForNotification();
  return false;
}
|
91 |
+
|
92 |
+
// Ensure any still-registered callbacks run (exactly once) before the
// manager's state is destroyed.
CancellationManager::~CancellationManager() { StartCancel(); }
|
93 |
+
|
94 |
+
} // end namespace tensorflow
|
cancellation.h
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_
|
18 |
+
|
19 |
+
#include <atomic>
|
20 |
+
#include <functional>
|
21 |
+
|
22 |
+
#include "tensorflow/core/lib/core/notification.h"
|
23 |
+
#include "tensorflow/core/lib/core/status.h"
|
24 |
+
#include "tensorflow/core/lib/gtl/flatmap.h"
|
25 |
+
#include "tensorflow/core/lib/hash/hash.h"
|
26 |
+
#include "tensorflow/core/platform/mutex.h"
|
27 |
+
#include "tensorflow/core/platform/thread_annotations.h"
|
28 |
+
#include "tensorflow/core/platform/types.h"
|
29 |
+
|
30 |
+
namespace tensorflow {
|
31 |
+
|
32 |
+
// A token that can be used to register and deregister a
|
33 |
+
// CancelCallback with a CancellationManager.
|
34 |
+
//
|
35 |
+
// CancellationToken values must be created by a call to
|
36 |
+
// CancellationManager::get_cancellation_token.
|
37 |
+
typedef int64 CancellationToken;
|
38 |
+
|
39 |
+
// A callback that is invoked when a step is canceled.
|
40 |
+
//
|
41 |
+
// NOTE(mrry): See caveats about CancelCallback implementations in the
|
42 |
+
// comment for CancellationManager::RegisterCallback.
|
43 |
+
typedef std::function<void()> CancelCallback;
|
44 |
+
|
45 |
+
// Coordinates cancellation of a group of pending operations: callers obtain
// a token, register a cancel callback under it, and StartCancel() runs every
// registered callback exactly once.
class CancellationManager {
 public:
  // A value that won't be returned by get_cancellation_token().
  static const CancellationToken kInvalidToken;

  CancellationManager();
  ~CancellationManager();

  // Run all callbacks associated with this manager.
  void StartCancel();

  // Returns true iff StartCancel() has been called.
  bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }

  // Returns a token that must be used in calls to RegisterCallback
  // and DeregisterCallback.
  CancellationToken get_cancellation_token();

  // Attempts to register the given callback to be invoked when this
  // manager is cancelled. Returns true if the callback was
  // registered; returns false if this manager was already cancelled,
  // and the callback was not registered.
  //
  // If this method returns false, it is the caller's responsibility
  // to perform any cancellation cleanup.
  //
  // This method is tricky to use correctly. The following usage pattern
  // is recommended:
  //
  // class ObjectWithCancellableOperation {
  //   mutex mu_;
  //   void CancellableOperation(CancellationManager* cm,
  //                             std::function<void(Status)> callback) {
  //     bool already_cancelled;
  //     CancellationToken token = cm->get_cancellation_token();
  //     {
  //       mutex_lock(mu_);
  //       already_cancelled = !cm->RegisterCallback(
  //           [this, token]() { Cancel(token); });
  //       if (!already_cancelled) {
  //         // Issue asynchronous operation. Associate the pending operation
  //         // with `token` in some object state, or provide another way for
  //         // the Cancel method to look up the operation for cancellation.
  //         // Ensure that `cm->DeregisterCallback(token)` is called without
  //         // holding `mu_`, before `callback` is invoked.
  //         // ...
  //       }
  //     }
  //     if (already_cancelled) {
  //       callback(errors::Cancelled("Operation was cancelled"));
  //     }
  //   }
  //
  //   void Cancel(CancellationToken token) {
  //     mutex_lock(mu_);
  //     // Take action to cancel the operation with the given cancellation
  //     // token.
  //   }
  //
  // NOTE(mrry): The caller should take care that (i) the calling code
  // is robust to `callback` being invoked asynchronously (e.g. from
  // another thread), (ii) `callback` is deregistered by a call to
  // this->DeregisterCallback(token) when the operation completes
  // successfully, and (iii) `callback` does not invoke any method
  // on this cancellation manager. Furthermore, it is important that
  // the eventual caller of the complementary DeregisterCallback does not
  // hold any mutexes that are required by `callback`.
  bool RegisterCallback(CancellationToken token, CancelCallback callback);

  // Deregister the callback that, when registered, was associated
  // with the given cancellation token. Returns true iff the callback
  // was deregistered and will not be invoked; otherwise returns false
  // after the callback has been invoked, blocking if necessary.
  //
  // NOTE(mrry): This method may block if cancellation is in progress.
  // The caller of this method must not hold any mutexes that are required
  // to invoke any cancellation callback that has been registered with this
  // cancellation manager.
  bool DeregisterCallback(CancellationToken token);

 private:
  // True while StartCancel() is invoking callbacks (written under mu_).
  bool is_cancelling_;
  // Set (with release semantics) once all cancellation callbacks have run.
  std::atomic_bool is_cancelled_;

  mutex mu_;
  // Notified after StartCancel() has finished running every callback;
  // DeregisterCallback() waits on this during an in-flight cancellation.
  Notification cancelled_notification_;
  CancellationToken next_cancellation_token_ GUARDED_BY(mu_);
  gtl::FlatMap<CancellationToken, CancelCallback> callbacks_ GUARDED_BY(mu_);
};
|
134 |
+
|
135 |
+
} // namespace tensorflow
|
136 |
+
|
137 |
+
#endif // TENSORFLOW_FRAMEWORK_CANCELLATION_H_
|
cancellation_test.cc
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/cancellation.h"
|
17 |
+
|
18 |
+
#include <vector>
|
19 |
+
#include "tensorflow/core/lib/core/notification.h"
|
20 |
+
#include "tensorflow/core/lib/core/threadpool.h"
|
21 |
+
#include "tensorflow/core/platform/test.h"
|
22 |
+
|
23 |
+
namespace tensorflow {
|
24 |
+
|
25 |
+
// A callback that is registered and then deregistered must never fire, even
// when the manager is destroyed (the destructor cancels any survivors).
TEST(Cancellation, SimpleNoCancel) {
  bool is_cancelled = false;
  {
    CancellationManager manager;
    auto token = manager.get_cancellation_token();
    EXPECT_TRUE(manager.RegisterCallback(
        token, [&is_cancelled]() { is_cancelled = true; }));
    EXPECT_TRUE(manager.DeregisterCallback(token));
  }
  EXPECT_FALSE(is_cancelled);
}
|
37 |
+
|
38 |
+
// StartCancel() invokes a registered callback synchronously.
TEST(Cancellation, SimpleCancel) {
  bool is_cancelled = false;
  CancellationManager manager;
  auto token = manager.get_cancellation_token();
  EXPECT_TRUE(manager.RegisterCallback(
      token, [&is_cancelled]() { is_cancelled = true; }));
  manager.StartCancel();
  EXPECT_TRUE(is_cancelled);
}
|
49 |
+
|
50 |
+
// Registration after cancellation must be refused (the caller then owns
// cleanup), so the null callback here is never invoked.
TEST(Cancellation, CancelBeforeRegister) {
  CancellationManager manager;
  auto token = manager.get_cancellation_token();
  manager.StartCancel();
  EXPECT_FALSE(manager.RegisterCallback(token, nullptr));
}
|
58 |
+
|
59 |
+
// Deregistering after cancellation returns false: the callback already ran.
TEST(Cancellation, DeregisterAfterCancel) {
  bool is_cancelled = false;
  CancellationManager manager;
  auto token = manager.get_cancellation_token();
  EXPECT_TRUE(manager.RegisterCallback(
      token, [&is_cancelled]() { is_cancelled = true; }));
  manager.StartCancel();
  EXPECT_TRUE(is_cancelled);
  EXPECT_FALSE(manager.DeregisterCallback(token));
}
|
72 |
+
|
73 |
+
// All callbacks registered before StartCancel() fire; a registration
// attempted after cancellation is refused and its callback never runs.
TEST(Cancellation, CancelMultiple) {
  bool is_cancelled_1 = false, is_cancelled_2 = false, is_cancelled_3 = false;
  CancellationManager* manager = new CancellationManager();
  auto token_1 = manager->get_cancellation_token();
  bool registered_1 = manager->RegisterCallback(
      token_1, [&is_cancelled_1]() { is_cancelled_1 = true; });
  EXPECT_TRUE(registered_1);
  auto token_2 = manager->get_cancellation_token();
  bool registered_2 = manager->RegisterCallback(
      token_2, [&is_cancelled_2]() { is_cancelled_2 = true; });
  EXPECT_TRUE(registered_2);
  // Nothing fires until cancellation actually starts.
  EXPECT_FALSE(is_cancelled_1);
  EXPECT_FALSE(is_cancelled_2);
  manager->StartCancel();
  EXPECT_TRUE(is_cancelled_1);
  EXPECT_TRUE(is_cancelled_2);
  EXPECT_FALSE(is_cancelled_3);
  // Late registration: refused, and the callback must not be invoked.
  auto token_3 = manager->get_cancellation_token();
  bool registered_3 = manager->RegisterCallback(
      token_3, [&is_cancelled_3]() { is_cancelled_3 = true; });
  EXPECT_FALSE(registered_3);
  EXPECT_FALSE(is_cancelled_3);
  delete manager;
}
|
97 |
+
|
98 |
+
// Cross-thread visibility of IsCancelled(): worker threads busy-spin on the
// flag and must all observe the StartCancel() issued by the main thread
// (exercises the release store / acquire load pairing on is_cancelled_).
TEST(Cancellation, IsCancelled) {
  CancellationManager* cm = new CancellationManager();
  thread::ThreadPool w(Env::Default(), "test", 4);
  std::vector<Notification> done(8);
  for (size_t i = 0; i < done.size(); ++i) {
    Notification* n = &done[i];
    w.Schedule([n, cm]() {
      // Busy-wait until the cancellation flag becomes visible to this thread.
      while (!cm->IsCancelled()) {
      }
      n->Notify();
    });
  }
  // Give the workers time to start spinning before cancelling.
  Env::Default()->SleepForMicroseconds(1000000 /* 1 second */);
  cm->StartCancel();
  for (size_t i = 0; i < done.size(); ++i) {
    done[i].WaitForNotification();
  }
  delete cm;
}
|
117 |
+
|
118 |
+
} // namespace tensorflow
|
common_shape_fns.cc
ADDED
@@ -0,0 +1,1399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
#include "tensorflow/core/framework/common_shape_fns.h"
|
16 |
+
#include "tensorflow/core/framework/attr_value.pb.h"
|
17 |
+
|
18 |
+
namespace tensorflow {
|
19 |
+
|
20 |
+
// Computes the output length and explicit before/after padding of a 1-D
// windowed (convolution/pooling) operation.
//
// Args:
//   input_size: length of the input dimension.
//   filter_size: length of the filter window.
//   dilation_rate: spacing between filter taps; must be >= 1.
//   stride: step between window applications; must be > 0.
//   padding_type: VALID (no padding) or SAME (pad to cover every input
//     element).
//   output_size / padding_before / padding_after: outputs.
//
// Returns InvalidArgument for a non-positive stride, a dilation rate < 1,
// or when the computed output size would be negative.
Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
                                      int64 dilation_rate, int64 stride,
                                      Padding padding_type, int64* output_size,
                                      int64* padding_before,
                                      int64* padding_after) {
  if (stride <= 0) {
    return errors::InvalidArgument("Stride must be > 0, but got ", stride);
  }
  if (dilation_rate < 1) {
    return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
                                   dilation_rate);
  }

  // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2.
  // A dilated filter of size k spans (k - 1) * rate + 1 input elements.
  const int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1;
  switch (padding_type) {
    case Padding::VALID:
      *output_size = (input_size - effective_filter_size + stride) / stride;
      *padding_before = *padding_after = 0;
      break;
    case Padding::SAME: {
      *output_size = (input_size + stride - 1) / stride;
      // std::max<int64> instead of std::max(0LL, ...): the 0LL literal
      // assumed int64 is `long long`, which fails to compile (ambiguous
      // overload) on platforms where int64 is `long`.
      const int64 padding_needed = std::max<int64>(
          0, (*output_size - 1) * stride + effective_filter_size - input_size);
      // For odd values of total padding, add more padding at the 'right'
      // side of the given dimension.
      *padding_before = padding_needed / 2;
      *padding_after = padding_needed - *padding_before;
      break;
    }
  }
  if (*output_size < 0) {
    return errors::InvalidArgument("computed output size would be negative");
  }
  return Status::OK();
}
|
56 |
+
|
57 |
+
// Legacy entry point: identical to GetWindowedOutputSizeVerboseV2 with a
// dilation rate of 1 (i.e. a dense, undilated filter).
Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
                                    int64 stride, Padding padding_type,
                                    int64* output_size, int64* padding_before,
                                    int64* padding_after) {
  return GetWindowedOutputSizeVerboseV2(input_size, filter_size,
                                        /*dilation_rate=*/1, stride,
                                        padding_type, output_size,
                                        padding_before, padding_after);
}
|
66 |
+
|
67 |
+
// Convenience wrapper around GetWindowedOutputSizeVerbose that reports only
// the 'before' padding (written to *padding_size); the 'after' padding is
// computed and discarded.
Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
                             Padding padding_type, int64* output_size,
                             int64* padding_size) {
  int64 padding_after_unused;
  return GetWindowedOutputSizeVerbose(input_size, filter_size, stride,
                                      padding_type, output_size, padding_size,
                                      &padding_after_unused);
}
|
75 |
+
|
76 |
+
// Dilation-aware variant of GetWindowedOutputSize: reports only the 'before'
// padding (written to *padding_size); the 'after' padding is discarded.
Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
                               int64 dilation_rate, int64 stride,
                               Padding padding_type, int64* output_size,
                               int64* padding_size) {
  int64 padding_after_unused;
  return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate,
                                        stride, padding_type, output_size,
                                        padding_size, &padding_after_unused);
}
|
85 |
+
|
86 |
+
// Computes the output size and padding of a 3-D windowed operation by
// applying GetWindowedOutputSize independently to each spatial dimension.
// Returns the first error encountered, if any.
Status Get3dOutputSize(const std::array<int64, 3>& input,
                       const std::array<int64, 3>& window,
                       const std::array<int64, 3>& strides,
                       Padding padding_type, std::array<int64, 3>* output_ptr,
                       std::array<int64, 3>* padding_ptr) {
  for (size_t dim = 0; dim < input.size(); ++dim) {
    TF_RETURN_IF_ERROR(
        GetWindowedOutputSize(input[dim], window[dim], strides[dim],
                              padding_type, &(*output_ptr)[dim],
                              &(*padding_ptr)[dim]));
  }
  return Status::OK();
}
|
98 |
+
|
99 |
+
// Dilation-aware variant of Get3dOutputSize: applies GetWindowedOutputSizeV2
// to each of the three spatial dimensions independently, with a per-dimension
// dilation rate. Returns the first error encountered, if any.
Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
                         const std::array<int64, 3>& window,
                         const std::array<int64, 3>& dilations,
                         const std::array<int64, 3>& strides,
                         Padding padding_type, std::array<int64, 3>* output_ptr,
                         std::array<int64, 3>* padding_ptr) {
  for (size_t i = 0; i < input.size(); ++i) {
    TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
        input[i], window[i], dilations[i], strides[i], padding_type,
        &(*output_ptr)[i], &(*padding_ptr)[i]));
  }
  return Status::OK();
}
|
112 |
+
|
113 |
+
namespace shape_inference {
|
114 |
+
|
115 |
+
// The V2 version computes windowed output size with arbitrary dilation_rate,
// while the original version only handles the cases where dilation_rates equal
// to 1.
//
// Symbolic (DimensionHandle) counterpart of GetWindowedOutputSizeVerboseV2:
// computes the output dimension of a windowed op during shape inference,
// where the input size may be unknown.
//   VALID: output = ceil((input - effective_window + 1) / stride),
//          effective_window = (filter - 1) * dilation_rate + 1.
//   SAME:  output = ceil(input / stride).
Status GetWindowedOutputSizeFromDimsV2(
    shape_inference::InferenceContext* c,
    shape_inference::DimensionHandle input_size,
    shape_inference::DimensionOrConstant filter_size, int64 dilation_rate,
    int64 stride, Padding padding_type,
    shape_inference::DimensionHandle* output_size) {
  if (stride <= 0) {
    return errors::InvalidArgument("Stride must be > 0, but got ", stride);
  }

  if (dilation_rate < 1) {
    return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
                                   dilation_rate);
  }

  // See also the parallel implementation in GetWindowedOutputSizeVerbose.
  switch (padding_type) {
    case Padding::VALID:
      if (dilation_rate > 1) {
        // effective window = (filter_size - 1) * dilation_rate + 1.
        DimensionHandle window_size;
        TF_RETURN_IF_ERROR(
            c->Subtract(c->MakeDim(filter_size), 1, &window_size));
        TF_RETURN_IF_ERROR(
            c->Multiply(window_size, dilation_rate, &window_size));
        TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
        TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
      } else {
        TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
      }
      // Adding (stride) then flooring-dividing by stride implements
      // ceil((input - window + 1) / stride).
      TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
    case Padding::SAME:
      // ceil(input / stride), independent of the filter size.
      TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
  }
  return Status::OK();
}
|
159 |
+
|
160 |
+
// Dilation-free convenience wrapper around GetWindowedOutputSizeFromDimsV2
// (dilation_rate fixed at 1).
Status GetWindowedOutputSizeFromDims(
    shape_inference::InferenceContext* c,
    shape_inference::DimensionHandle input_size,
    shape_inference::DimensionOrConstant filter_size, int64 stride,
    Padding padding_type, shape_inference::DimensionHandle* output_size) {
  return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
                                         /*dilation_rate=*/1, stride,
                                         padding_type, output_size);
}
|
169 |
+
|
170 |
+
// Shape function for ops whose single output has exactly the shape of the
// first input (e.g. element-wise unary ops).
Status UnchangedShape(shape_inference::InferenceContext* c) {
  c->set_output(0, c->input(0));
  return Status::OK();
}
|
174 |
+
|
175 |
+
// Shape function for MatMul: both inputs must be rank-2; the inner
// dimensions (after applying transpose_a / transpose_b) must agree, and the
// output is [outer(a), outer(b)].
Status MatMulShape(shape_inference::InferenceContext* c) {
  ShapeHandle a;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));

  ShapeHandle b;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));

  bool transpose_a, transpose_b;
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
  // The outer dimensions swap when the corresponding operand is transposed.
  DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
  DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);

  // Validate that the inner shapes are compatible.
  DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
  DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
  DimensionHandle merged;
  TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));

  c->set_output(0, c->Matrix(output_rows, output_cols));
  return Status::OK();
}
|
197 |
+
|
198 |
+
// Shape function for BiasAdd: the output has the input's shape with the
// channel dimension merged against the length of the rank-1 bias vector.
// The channel dimension is the last one for NHWC (the default) and the
// third-from-last for NCHW.
Status BiasAddShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;

  // Fetch the data_format attribute, which may not exist.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  // NCHW needs at least rank 3 so that dim -3 (the channel) exists.
  if (s.ok() && data_format == "NCHW") {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  }

  ShapeHandle bias_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
  DimensionHandle bias_dim = c->Dim(bias_shape, 0);

  // If rank unknown, return unknown shape.
  if (!c->RankKnown(input_shape)) {
    c->set_output(0, c->UnknownShape());
    return Status::OK();
  }

  // Output has the same shape as the input, and matches the length of
  // the bias in its bias dimension.
  ShapeHandle output_shape;
  if (s.ok() && data_format == "NCHW") {
    // Merge the length of bias_shape into the third to last dimension
    ShapeHandle first;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -3, &first));

    ShapeHandle last;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, -2, &last));

    DimensionHandle input_bias_dim = c->Dim(input_shape, -3);
    DimensionHandle merged_bias_dim;
    TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
    ShapeHandle merged_bias = c->Vector(merged_bias_dim);

    // Stitch the output back together: [leading dims, merged channel,
    // trailing spatial dims].
    ShapeHandle temp;
    TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
    TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
  } else {
    // NHWC (or unspecified): the channel is simply the last dimension.
    ShapeHandle all_but_bias;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));

    DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
    DimensionHandle merged_bias_dim;
    TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));

    ShapeHandle merged_bias = c->Vector(merged_bias_dim);
    TF_RETURN_IF_ERROR(
        c->Concatenate(all_but_bias, merged_bias, &output_shape));
  }

  c->set_output(0, output_shape);
  return Status::OK();
}
|
256 |
+
|
257 |
+
// Shape function for BiasAddGrad: the output is a vector whose length is the
// input's channel dimension (dim -3 for NCHW, dim -1 otherwise).
Status BiasAddGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  // Fetch the data_format attribute, which may not exist.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  if (s.ok() && data_format == "NCHW") {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
    c->set_output(0, c->Vector(c->Dim(input_shape, -3)));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
    c->set_output(0, c->Vector(c->Dim(input_shape, -1)));
  }

  return Status::OK();
}
|
273 |
+
|
274 |
+
// Validates format-specific constraints on a tensor shape during inference.
// Currently this only checks that for NCHW_VECT_C the innermost
// "vector-of-channels" dimension has size exactly 4.
// NOTE(review): tensor_name is currently unused; presumably intended for
// error messages.
Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
                                     const ShapeHandle shape_handle,
                                     const string& tensor_name,
                                     shape_inference::InferenceContext* c) {
  if (tensor_format == FORMAT_NCHW_VECT_C) {
    // Check that the vect dim has size 4.
    const int num_dims = c->Rank(shape_handle);
    DimensionHandle vect_dim = c->Dim(
        shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format));
    DimensionHandle unused_vect_dim;
    TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim));
  }

  return Status::OK();
}
|
289 |
+
|
290 |
+
// Builds a ShapeHandle from batch (N), spatial, and channel (C) dimensions
// laid out according to `format`. For NCHW_VECT_C the inner feature
// dimension is fixed to 4.
Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
                           const std::vector<DimensionOrConstant>& spatial,
                           DimensionOrConstant C, ShapeHandle* out,
                           shape_inference::InferenceContext* context) {
  const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
  std::vector<DimensionHandle> dims_actual(num_dims);
  dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N);
  int outer_c_index = GetTensorFeatureDimIndex(num_dims, format);
  dims_actual[outer_c_index] = context->MakeDim(C);
  if (format == FORMAT_NCHW_VECT_C) {
    // The vectorized-channel layout always carries 4 channels per vector.
    dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
        context->MakeDim(4);
  }
  for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) {
    dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
        context->MakeDim(spatial[spatial_dim]);
  }
  *out = context->MakeShape(dims_actual);
  return Status::OK();
}
|
310 |
+
|
311 |
+
// Decomposes `shape` (laid out per `format`) into its batch, spatial, and
// channel dimensions. For NCHW_VECT_C the reported channel dimension is the
// product of the outer and inner feature dimensions.
Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
                           DimensionHandle* batch_dim,
                           gtl::MutableArraySlice<DimensionHandle> spatial_dims,
                           DimensionHandle* filter_dim,
                           InferenceContext* context) {
  const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
  // Batch.
  *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
  // Spatial.
  for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
       ++spatial_dim_index) {
    spatial_dims[spatial_dim_index] = context->Dim(
        shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
  }
  // Channel.
  *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
  if (format == FORMAT_NCHW_VECT_C) {
    // Recombine the split channel: outer_feature * inner_feature (vect size).
    TF_RETURN_IF_ERROR(context->Multiply(
        *filter_dim,
        context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
        filter_dim));
  }
  return Status::OK();
}
|
335 |
+
|
336 |
+
// Inverse of DimensionsFromShape: assembles batch, spatial, and channel
// dimensions into a ShapeHandle laid out per `format`. For NCHW_VECT_C the
// channel count is split into outer = filter_dim / 4 and inner = 4, and the
// division must be exact.
Status ShapeFromDimensions(DimensionHandle batch_dim,
                           gtl::ArraySlice<DimensionHandle> spatial_dims,
                           DimensionHandle filter_dim, TensorFormat format,
                           InferenceContext* context, ShapeHandle* shape) {
  const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
  std::vector<DimensionHandle> out_dims(rank);

  // Batch.
  out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
  // Spatial.
  for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
       ++spatial_dim_index) {
    out_dims[tensorflow::GetTensorSpatialDimIndex(
        rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
  }
  // Channel.
  if (format == tensorflow::FORMAT_NCHW_VECT_C) {
    // When format is NCHW_VECT_C, factor the feature map count
    // into the outer feature count and the inner feature count (=4).
    TF_RETURN_IF_ERROR(context->Divide(
        filter_dim, 4, /*evenly_divisible=*/true,
        &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
    out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4);
  } else {
    out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
  }

  *shape = context->MakeShape(out_dims);
  return tensorflow::Status::OK();
}
|
366 |
+
|
367 |
+
// Shape function for Conv2D. Supports NHWC/NCHW/NCHW_VECT_C data formats and
// HWIO/OIHW/OIHW_VECT_I filter formats; validates the strides/dilations
// attributes, checks input-vs-filter channel agreement, and computes the
// spatial output dimensions via GetWindowedOutputSizeFromDimsV2.
Status Conv2DShape(shape_inference::InferenceContext* c) {
  // data_format / filter_format are optional attributes; fall back to the
  // canonical defaults when absent.
  string data_format_str, filter_format_str;
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    data_format_str = "NHWC";
  }
  if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
    filter_format_str = "HWIO";
  }

  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  FilterTensorFormat filter_format;
  if (!FilterFormatFromString(filter_format_str, &filter_format)) {
    return errors::InvalidArgument("Invalid filter format string: ",
                                   filter_format_str);
  }

  constexpr int num_spatial_dims = 2;
  // Rank is 4 for NHWC/NCHW, 5 for NCHW_VECT_C.
  const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
  ShapeHandle conv_input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
  TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
      data_format, conv_input_shape, "conv_input", c));

  // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));

  std::vector<int32> dilations;
  TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));

  if (dilations.size() != 4) {
    return errors::InvalidArgument(
        "Conv2D requires the dilation attribute to contain 4 values, but got: ",
        dilations.size());
  }

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
  if (strides.size() != 4) {
    return errors::InvalidArgument("Conv2D on data format ", data_format_str,
                                   " requires the stride attribute to contain"
                                   " 4 values, but got: ",
                                   strides.size());
  }

  const int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  const int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H');
  const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W');

  DimensionHandle batch_size_dim;
  DimensionHandle input_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format,
                                         &batch_size_dim, &input_spatial_dims,
                                         &input_depth_dim, c));

  DimensionHandle output_depth_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
  DimensionHandle filter_rows_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
  DimensionHandle filter_cols_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
  DimensionHandle filter_input_depth_dim;
  if (filter_format == FORMAT_OIHW_VECT_I) {
    // The filter's input channels are split across the outer 'I' dimension
    // and the inner vector dimension; multiply them back together.
    TF_RETURN_IF_ERROR(c->Multiply(
        c->Dim(filter_shape,
               GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
        c->Dim(filter_shape,
               GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
        &filter_input_depth_dim));
  } else {
    filter_input_depth_dim = c->Dim(
        filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
  }

  // Check that the input tensor and the filter tensor agree on the input
  // channel count.
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(
      c->Merge(input_depth_dim, filter_input_depth_dim, &unused));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  DimensionHandle output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
      padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
      padding, &output_cols));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(
      ShapeFromDimensions(batch_size_dim, {output_rows, output_cols},
                          output_depth_dim, data_format, c, &output_shape));
  c->set_output(0, output_shape);
  return Status::OK();
}
|
475 |
+
|
476 |
+
// TODO(mjanusz): Unify all conv/pooling shape functions.
|
477 |
+
// Shape function for Conv3D: infers the 5-D output shape from a 5-D input
// (NDHWC, or NCDHW which is first canonicalized to NDHWC) and a 5-D DHWIO
// filter. Validates the strides attribute and the input/filter channel
// agreement, then computes each spatial output dimension.
Status Conv3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));

  // data_format is optional; when absent (or not "NCDHW") the input is
  // treated as NDHWC.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Conv3D requires the stride attribute to contain 5 values, but got: ",
        strides.size());
  }

  int32 stride_planes, stride_rows, stride_cols;
  if (s.ok() && data_format == "NCDHW") {
    // Convert input_shape to NDHWC.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    // NCDHW strides are ordered [N, C, D, H, W]: index 3 is rows (H) and
    // index 4 is cols (W). (The previous code had rows and cols swapped.)
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
  } else {
    // NDHWC strides are ordered [N, D, H, W, C].
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
  }

  // From here on input_shape is NDHWC.
  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);

  // Filter layout is [planes, rows, cols, in_channels, out_channels].
  DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
  DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
  DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
  DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);

  // Input channels must agree between the input and the filter.
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(
      c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
  DimensionHandle output_planes, output_rows, output_cols;

  TF_RETURN_IF_ERROR(
      GetWindowedOutputSizeFromDims(c, in_planes_dim, filter_planes_dim,
                                    stride_planes, padding, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));

  // Emit the output in the caller-requested layout.
  ShapeHandle output_shape;
  if (s.ok() && data_format == "NCDHW") {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }
  c->set_output(0, output_shape);
  return Status::OK();
}
|
548 |
+
|
549 |
+
// Shape function for DepthwiseConv2dNative: output channels =
// input_channels * depth_multiplier (filter layout is
// [rows, cols, in_channels, depth_multiplier]). Supports NHWC (default)
// and NCHW data formats.
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D requires the stride attribute to contain 4 values, "
        "but got: ",
        strides.size());
  }

  // data_format is optional; default behavior is NHWC.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);
  int32 stride_rows;
  int32 stride_cols;
  if (s.ok() && data_format == "NCHW") {
    // Canonicalize input shape to NHWC so the shape inference code below can
    // process it.
    input_shape =
        c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
                       c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
    stride_rows = strides[2];
    stride_cols = strides[3];
  } else {
    stride_rows = strides[1];
    stride_cols = strides[2];
  }

  // input_shape is NHWC from here on.
  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 2);

  DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
  DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
  DimensionHandle input_depth = c->Dim(filter_shape, 2);
  DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);

  // Check that the input depths are compatible.
  TF_RETURN_IF_ERROR(
      c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));

  DimensionHandle output_depth;
  TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_rows, output_cols;

  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));

  // Emit the output in the caller-requested layout.
  ShapeHandle output_shape;
  if (data_format == "NCHW") {
    output_shape =
        c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
  } else {
    output_shape =
        c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
  }
  c->set_output(0, output_shape);
  return Status::OK();
}
|
622 |
+
|
623 |
+
// Shape function for AvgPool: validates the 4-value ksize and strides
// attributes and computes the pooled spatial dimensions. Supports
// NHWC (default), NCHW, and NCHW_VECT_C data formats.
Status AvgPoolShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    // Reject malformed format strings instead of leaving data_format
    // uninitialized (the previous code ignored the return value).
    if (!FormatFromString(data_format_str, &data_format)) {
      return errors::InvalidArgument("Invalid data format string: ",
                                     data_format_str);
    }
  } else {
    data_format = FORMAT_NHWC;
  }

  // NCHW_VECT_C inputs carry an extra inner channel dimension (rank 5).
  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "AvgPool requires the stride attribute to contain 4 values, but got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "AvgPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.

  DimensionHandle output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols}, depth_dim,
                                         &output_shape, c));
  c->set_output(0, output_shape);
  return Status::OK();
}
|
691 |
+
|
692 |
+
// Shape function for FusedBatchNorm. Input 0 is the 4-D x tensor; inputs
// 1..n are rank-1 channel vectors (scale, offset, and — when is_training is
// false — mean and variance), all merged against x's channel dimension.
// Outputs: y (same shape as x) plus four channel-length vectors.
Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));

  bool is_training;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
  // In training mode mean/variance are computed, so only scale and offset
  // are inputs; in inference mode mean and variance are also inputs.
  int number_inputs = (is_training) ? 3 : 5;
  string data_format;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
  DimensionHandle channel_dim =
      (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1);

  // covers scale, offset, and if is_training is false, mean, variance
  for (int i = 1; i < number_inputs; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle y;
  if (data_format == "NHWC") {
    TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y));
  } else {
    TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y));
  }
  c->set_output(0, y);
  // Outputs 1-4 (batch mean/variance and reserve spaces) are all
  // channel-length vectors.
  ShapeHandle vector_shape = c->Vector(channel_dim);
  c->set_output(1, vector_shape);
  c->set_output(2, vector_shape);
  c->set_output(3, vector_shape);
  c->set_output(4, vector_shape);
  return Status::OK();
}
|
725 |
+
|
726 |
+
// Shape function for FusedBatchNormGrad. Inputs: y_backprop (4-D), x (4-D),
// and three rank-1 channel vectors (scale and the two reserve spaces), all
// merged against the channel dimension. Outputs: x_backprop (shape of
// y_backprop), scale/offset backprops, and two reserve-space placeholders.
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle y_backprop;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));

  bool is_training;
  string data_format;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
  // Channel dim position depends on layout: last for NHWC, second for NCHW.
  DimensionHandle channel_dim =
      (data_format == "NHWC") ? c->Dim(y_backprop, 3) : c->Dim(y_backprop, 1);
  if (data_format == "NHWC") {
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim));
  } else {
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim));
  }

  // covers scale, mean (reserve_space_1), variance (reserve_space_2)
  for (int i = 2; i < 5; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle x_backprop;
  if (data_format == "NHWC") {
    TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop));
  } else {
    TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop));
  }
  c->set_output(0, x_backprop);
  c->set_output(1, c->Vector(channel_dim));
  c->set_output(2, c->Vector(channel_dim));
  // Set the correct shapes for reserve_spaces
  // so that gradients can be performed when
  // the op is in a symbolic condition.
  if (is_training) {
    c->set_output(3, c->Vector(0));
    c->set_output(4, c->Vector(0));
  } else {
    c->set_output(3, c->Vector(channel_dim));
    c->set_output(4, c->Vector(channel_dim));
  }
  return Status::OK();
}
|
772 |
+
|
773 |
+
// Shape function for MaxPool: validates the 4-value ksize and strides
// attributes and computes the pooled output dimensions. Unlike AvgPool,
// MaxPool also permits pooling/striding over the channel ('C') dimension.
// Supports NHWC (default), NCHW, and NCHW_VECT_C data formats.
Status MaxPoolShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    // Reject malformed format strings instead of leaving data_format
    // uninitialized (the previous code ignored the return value).
    if (!FormatFromString(data_format_str, &data_format)) {
      return errors::InvalidArgument("Invalid data format string: ",
                                     data_format_str);
    }
  } else {
    data_format = FORMAT_NHWC;
  }

  // NCHW_VECT_C inputs carry an extra inner channel dimension (rank 5).
  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_depth = GetTensorDim(strides, data_format, 'C');
  int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
  int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle in_depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  ShapeHandle output_shape;
  DimensionHandle output_rows, output_cols, output_depth;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));

  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols},
                                         output_depth, &output_shape, c));

  c->set_output(0, output_shape);
  return Status::OK();
}
|
842 |
+
|
843 |
+
// Shape function shared by MaxPool and MaxPoolV2.
//
// <num_inputs> is the total input count the calling op declares. When
// c->num_inputs() + 2 == num_inputs, ksize/strides come from attrs
// (MaxPool); otherwise they are the last two runtime inputs (MaxPoolV2).
Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
  // Determine the data format; fall back to NHWC if the attr is absent.
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    FormatFromString(data_format_str, &data_format);
  } else {
    data_format = FORMAT_NHWC;
  }

  // NCHW_VECT_C carries a 5th packed-channel dimension; other formats are 4D.
  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> kernel_sizes;
  std::vector<int32> strides;

  if (c->num_inputs() + 2 == num_inputs) {
    // ksize/strides are attrs (MaxPool).
    TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));

    TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  } else {
    // Verify shape of ksize and strides input.
    ShapeHandle size;
    DimensionHandle unused;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));

    // Without constant-folded ksize/strides values the output is unknown.
    const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
    if (kernel_sizes_tensor == nullptr) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    }
    kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
    auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
    std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
                kernel_sizes.begin());

    const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
    if (strides_tensor == nullptr) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    }
    strides.resize(strides_tensor->shape().num_elements());
    auto strides_vec = strides_tensor->flat<int32>();
    std::copy_n(&strides_vec(0), strides.size(), strides.begin());
  }

  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but "
        "got: ",
        strides.size());
  }
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  // Pull out the per-dimension window/stride values for the chosen format.
  int32 stride_depth = GetTensorDim(strides, data_format, 'C');
  int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
  int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle in_depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // Window the rows, cols, and (for pooling across channels) depth.
  ShapeHandle output_shape;
  DimensionHandle output_rows, output_cols, output_depth;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));

  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols},
                                         output_depth, &output_shape, c));

  c->set_output(0, output_shape);
  return Status::OK();
}
|
944 |
+
|
945 |
+
// Shape function for 3D pooling ops (AvgPool3D/MaxPool3D): windows the three
// spatial dimensions of a rank-5 input with the given ksize/strides/padding.
// Supports NDHWC (default) and NCDHW data formats.
Status Pool3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));

  // data_format is optional; a missing attr means the NDHWC branch below.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D ops require the stride attribute to contain 5 values, but "
        "got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D requires the ksize attribute to contain 5 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_planes, stride_rows, stride_cols;
  int32 kernel_planes, kernel_rows, kernel_cols;

  if (s.ok() && data_format == "NCDHW") {
    // Convert input_shape to NDHWC.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    // In NCDHW the spatial values sit at positions 2..4 of ksize/strides.
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
    kernel_planes = kernel_sizes[2];
    kernel_rows = kernel_sizes[3];
    kernel_cols = kernel_sizes[4];
  } else {
    // NDHWC: spatial values sit at positions 1..3.
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
    kernel_planes = kernel_sizes[1];
    kernel_rows = kernel_sizes[2];
    kernel_cols = kernel_sizes[3];
  }

  // input_shape is in NDHWC order at this point regardless of data_format.
  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
  DimensionHandle output_depth_dim = c->Dim(input_shape, 4);

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_planes, output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  // Emit the output in the caller's requested layout.
  ShapeHandle output_shape;
  if (data_format == "NCDHW") {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }

  c->set_output(0, output_shape);
  return Status::OK();
}
|
1026 |
+
|
1027 |
+
// Marks every output of the op as having a fully unknown shape.
Status UnknownShape(shape_inference::InferenceContext* c) {
  const int num_outputs = c->num_outputs();
  for (int out_idx = 0; out_idx < num_outputs; ++out_idx) {
    c->set_output(out_idx, c->UnknownShape());
  }
  return Status::OK();
}
|
1033 |
+
|
1034 |
+
template <typename T>
|
1035 |
+
Status ReductionShapeHelper(const Tensor* reduction_indices_t,
|
1036 |
+
const int32 input_rank,
|
1037 |
+
std::set<int64>& true_indices) {
|
1038 |
+
auto reduction_indices = reduction_indices_t->flat<T>();
|
1039 |
+
for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
|
1040 |
+
const T reduction_index = reduction_indices(i);
|
1041 |
+
if (reduction_index < -input_rank || reduction_index >= input_rank) {
|
1042 |
+
return errors::InvalidArgument("Invalid reduction dimension ",
|
1043 |
+
reduction_index, " for input with ",
|
1044 |
+
input_rank, " dimensions.");
|
1045 |
+
}
|
1046 |
+
|
1047 |
+
auto wrapped_index = reduction_index;
|
1048 |
+
if (wrapped_index < 0) {
|
1049 |
+
wrapped_index += input_rank;
|
1050 |
+
}
|
1051 |
+
|
1052 |
+
true_indices.insert(wrapped_index);
|
1053 |
+
}
|
1054 |
+
return Status::OK();
|
1055 |
+
}
|
1056 |
+
|
1057 |
+
// Shape function for reduction ops (Sum, Mean, Max, ...): the axes named by
// input(1) are removed from input(0)'s shape, or kept as size-1 dimensions
// when the keep_dims attr is set.
Status ReductionShape(InferenceContext* c) {
  ShapeHandle input = c->input(0);

  ShapeHandle indices;
  // Older versions of TensorFlow accidentally allowed higher rank tensors like
  // [[1,2]] or [[1],[2]] to represent axis=[1,2].
  if (c->graph_def_version() < 21) {
    indices = c->input(1);
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
  }

  bool keep_dims;
  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));

  const Tensor* reduction_indices_t = c->input_tensor(1);
  if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
    // If we do not have the reduction values at runtime, or the
    // rank of the input, we don't know the output shape.

    if (keep_dims && c->RankKnown(input)) {
      // output rank matches input input if <keep_dims>.
      c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
      return Status::OK();
    } else {
      return shape_inference::UnknownShape(c);
    }
  }

  // Collect the (wrapped, validated) reduction axes; the axis tensor may be
  // either int32 or int64.
  const int32 input_rank = c->Rank(input);
  std::set<int64> true_indices;
  if (reduction_indices_t->dtype() == DataType::DT_INT32) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
                                                   input_rank, true_indices));
  } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t,
                                                   input_rank, true_indices));
  } else {
    return errors::InvalidArgument(
        "reduction_indices can only be int32 or int64");
  }

  // Reduced dims become 1 when keep_dims, otherwise they are dropped; all
  // other dims pass through unchanged.
  std::vector<DimensionHandle> dims;
  for (int i = 0; i < input_rank; ++i) {
    if (true_indices.count(i) > 0) {
      if (keep_dims) {
        dims.emplace_back(c->MakeDim(1));
      }
    } else {
      dims.emplace_back(c->Dim(input, i));
    }
  }

  c->set_output(0, c->MakeShape(dims));
  return Status::OK();
}
|
1113 |
+
|
1114 |
+
// Shared shape function for Concat/ConcatV2. The values to concatenate live
// at inputs [start_value_index, end_value_index) and the concat axis is the
// scalar at input <dim_index>.
Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
                         int end_value_index, int dim_index) {
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
  const Tensor* concat_dim_t = c->input_tensor(dim_index);
  if (concat_dim_t == nullptr) {
    // Return an unknown shape with same rank as inputs, or an unknown rank
    // if no input's rank is known.

    // Find rank.
    int32 rank = InferenceContext::kUnknownRank;
    for (int i = start_value_index; i < end_value_index; ++i) {
      if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
      if (rank != InferenceContext::kUnknownRank) {
        break;
      }
    }
    if (rank == InferenceContext::kUnknownRank) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    } else if (rank == 0) {
      return errors::InvalidArgument(
          "Can't concatenate scalars (use tf.stack instead)");
    } else {
      for (int i = start_value_index; i < end_value_index; ++i) {
        // Check that all the inputs are of the correct rank.
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
      }
    }
    // Build result of <rank> different unknown dims.
    std::vector<DimensionHandle> dims;
    dims.reserve(rank);
    for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
    c->set_output(0, c->MakeShape(dims));
    return Status::OK();
  }

  // Merge all the non-concat dims, and sum the concat dim to make an output
  // shape.
  const int32 concat_dim = concat_dim_t->scalar<int32>()();

  // Minimum required number of dimensions.
  const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;

  // Seed the before/middle/after pieces from the last value input, then fold
  // the remaining inputs in (merging the outer pieces, summing the middle).
  ShapeHandle output_before;
  ShapeHandle output_after;

  ShapeHandle input = c->input(end_value_index - 1);
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
  DimensionHandle output_middle = c->Dim(input, concat_dim);
  if (concat_dim == -1) {
    // concat_dim == -1 names the last axis, so there is nothing after it.
    output_after = c->Scalar();  // no dimensions.
  } else {
    TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
  }

  for (int i = end_value_index - 2; i >= start_value_index; --i) {
    ShapeHandle before;
    ShapeHandle after;
    input = c->input(i);
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
    TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
    DimensionHandle middle = c->Dim(input, concat_dim);
    if (concat_dim == -1) {
      after = c->Scalar();
    } else {
      TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
    }

    TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
    TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
    TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
  }

  // Stitch before + summed concat dim + after back into one shape.
  ShapeHandle s;
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_before, c->Vector(output_middle), &s));
  TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
  c->set_output(0, s);
  return Status::OK();
}
|
1196 |
+
|
1197 |
+
// Shape function for Concat: the concat axis is input 0 and the
// <num_inputs_to_concat> value tensors follow it.
Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
  constexpr int kDimIndex = 0;         // Axis scalar comes first.
  constexpr int kStartValueIndex = 1;  // Values start right after the axis.
  const int end_value_index = kStartValueIndex + num_inputs_to_concat;
  return ConcatShapeHelper(c, kStartValueIndex, end_value_index, kDimIndex);
}
|
1202 |
+
|
1203 |
+
// Shape function for ConcatV2: the value tensors come first and the concat
// axis is the final input.
Status ConcatV2Shape(InferenceContext* c) {
  const int axis_index = c->num_inputs() - 1;
  return ConcatShapeHelper(c, /*start_value_index=*/0,
                           /*end_value_index=*/axis_index,
                           /*dim_index=*/axis_index);
}
|
1208 |
+
|
1209 |
+
// Shape function for elementwise binary ops with NumPy-style broadcasting:
// dimensions are aligned from the right, size-1 dims broadcast against the
// other operand, and unknown dims are resolved conservatively.
Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
  ShapeHandle shape_x = c->input(0);
  ShapeHandle shape_y = c->input(1);
  if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
    c->set_output(0, c->UnknownShape());
    return Status::OK();
  }
  const int32 rank_x = c->Rank(shape_x);
  const int32 rank_y = c->Rank(shape_y);
  const int32 rank_out = std::max(rank_x, rank_y);

  // To compute the broadcast dimensions, we zip together shape_x and shape_y
  // and
  // pad with 1 to make them the same length.
  std::vector<DimensionHandle> dims;
  // dim_one is only constructed when the ranks differ; it is the implicit
  // leading size-1 dim of the shorter shape.
  DimensionHandle dim_one;
  if (rank_x != rank_y) dim_one = c->MakeDim(1);
  for (int i = 0; i < rank_out; ++i) {
    const auto dim_x = i < (rank_out - rank_x)
                           ? dim_one
                           : c->Dim(shape_x, i - (rank_out - rank_x));
    const bool dim_y_is_one = (i < (rank_out - rank_y));
    const auto dim_y =
        dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
    if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
      // One or both dimensions is unknown.
      //
      // - If either dimension is greater than 1, we assume that the program is
      // correct, and the other dimension will be broadcast to match it.
      // TODO(cwhipkey): For shape inference, if we eliminate the shape checks
      // in C++ op code, we must still assert that the unknown dim is either 1
      // or the same as the known dim.
      // - If either dimension is 1, the other dimension is the output.
      if (c->Value(dim_x) > 1) {
        dims.push_back(dim_x);
      } else if (c->Value(dim_y) > 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_x) == 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_y) == 1) {
        dims.push_back(dim_x);
      } else if (dim_y.SameHandle(dim_x)) {
        // Same unknown dimension on both sides: forward it.
        dims.push_back(dim_x);
      } else {
        dims.push_back(c->UnknownDim());
      }
    } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
      // Both dims known and at least one is 1: the other side wins.
      if (c->Value(dim_x) == 1 && !dim_y_is_one) {
        // We will broadcast dim_x to dim_y.
        dims.push_back(dim_y);
      } else {
        DCHECK_EQ(c->Value(dim_y), 1);
        // We will broadcast dim_y to dim_x.
        dims.push_back(dim_x);
      }
    } else {
      // Both dims known and neither is 1: they must agree exactly.
      DimensionHandle dim;
      TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
      dims.push_back(dim);
    }
  }

  c->set_output(0, c->MakeShape(dims));
  return Status::OK();
}
|
1274 |
+
|
1275 |
+
// Shape function for random ops: input 0 is a shape vector that fully
// determines the output shape.
Status RandomShape(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle result_shape;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &result_shape));
  c->set_output(0, result_shape);
  return Status::OK();
}
|
1281 |
+
|
1282 |
+
// Validates the structural consistency of a SparseTensor triple:
// indices must be rank-2 [N, ndims], values rank-1 [N], and shape rank-1
// [ndims]. N and ndims are cross-checked where statically known.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle values_shape, ShapeHandle shape_shape) {
  // Validate ranks.
  ShapeHandle ignored;
  TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &ignored));
  TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &ignored));
  TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &ignored));

  // Number of elements in indices and values must match.
  const DimensionHandle indices_count_dim = c->Dim(indices_shape, 0);
  const DimensionHandle values_count_dim = c->Dim(values_shape, 0);
  if (c->ValueKnown(indices_count_dim) && c->ValueKnown(values_count_dim)) {
    const int64 num_index_elements = c->Value(indices_count_dim);
    const int64 num_values_elements = c->Value(values_count_dim);
    if (num_index_elements != num_values_elements) {
      return errors::InvalidArgument("Number of elements in index (",
                                     num_index_elements, ") and values (",
                                     num_values_elements, ") do not match.");
    }
  }

  // Rank embedded in indices must match shape.
  const DimensionHandle embedded_rank_dim = c->Dim(indices_shape, 1);
  const DimensionHandle dense_rank_dim = c->Dim(shape_shape, 0);
  if (c->ValueKnown(embedded_rank_dim) && c->ValueKnown(dense_rank_dim)) {
    const int64 index_rank = c->Value(embedded_rank_dim);
    const int32 shape_rank = c->Value(dense_rank_dim);
    if (index_rank != shape_rank) {
      return errors::InvalidArgument("Index rank (", index_rank,
                                     ") and shape rank (", shape_rank,
                                     ") do not match.");
    }
  }

  return Status::OK();
}
|
1322 |
+
|
1323 |
+
// Shape function for ScatterNd-style update ops: checks that the outer dims
// of indices match the outer dims of updates, and that the inner dims of
// updates match the trailing dims of the input being scattered into.
// Handles both dense inputs and resource-variable handles (input 0).
Status ScatterNdUpdateShape(InferenceContext* c) {
  ShapeHandle input_shape = c->input(0);
  // For resource variables the variable's shape is carried in the handle's
  // shape-and-type data rather than in input(0) itself.
  if (c->input_handle_shapes_and_types(0) != nullptr) {
    input_shape = (*c->input_handle_shapes_and_types(0))[0].shape;
  }
  ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
  ShapeHandle updates_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));

  // Scattering anything into an empty output is an error.
  if (c->Value(c->NumElements(input_shape)) == 0 &&
      (c->Value(c->NumElements(indices_shape)) > 0 ||
       c->Value(c->NumElements(updates_shape)) > 0)) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty output shape");
  }

  if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
    // indices has shape [outer..., index_size]; each row of <index_size>
    // numbers addresses one slice of the input.
    const int64 num_outer_dims = c->Rank(indices_shape) - 1;
    const DimensionHandle index_size = c->Dim(indices_shape, -1);

    // We can only do more validation if the last dimension of indices
    // is a known value.
    if (c->ValueKnown(index_size)) {
      const int64 ix = c->Value(index_size);
      ShapeHandle unused;
      ShapeHandle prefix_indices;
      TF_RETURN_IF_ERROR(
          c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices));
      ShapeHandle prefix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates));

      // Outer dims of indices and updates must agree.
      Status s = c->Merge(prefix_indices, prefix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "The outer ", num_outer_dims, " dimensions of indices.shape=",
            c->DebugString(indices_shape), " must match the outer ",
            num_outer_dims, " dimensions of updates.shape=",
            c->DebugString(updates_shape), ": ", s.error_message());
      }

      // The inner dims of updates must match the input dims past the first
      // <ix> indexed dimensions.
      ShapeHandle input_suffix;
      TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix));
      ShapeHandle suffix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, num_outer_dims, &suffix_updates));
      s = c->Merge(input_suffix, suffix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "The inner ", c->Rank(input_shape) - ix,
            " dimensions of input.shape=", c->DebugString(input_shape),
            " must match the inner ", c->Rank(updates_shape) - num_outer_dims,
            " dimensions of updates.shape=", c->DebugString(updates_shape),
            ": ", s.error_message());
      }
    }
  }

  // Resource-variable variants produce no dense output; dense variants
  // forward the (unchanged) input shape.
  if (c->input_handle_shapes_and_types(0) == nullptr) {
    c->set_output(0, input_shape);
  }
  return Status::OK();
}
|
1387 |
+
|
1388 |
+
// Shape function for ops whose output shape is given entirely by a "shape"
// attr (e.g. placeholders with an explicit shape).
Status ExplicitShape(InferenceContext* c) {
  PartialTensorShape attr_shape;
  TF_RETURN_IF_ERROR(c->GetAttr("shape", &attr_shape));
  ShapeHandle inferred;
  TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(attr_shape, &inferred));
  c->set_output(0, inferred);
  return Status::OK();
}
|
1396 |
+
|
1397 |
+
} // namespace shape_inference
|
1398 |
+
|
1399 |
+
} // namespace tensorflow
|
common_shape_fns.h
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
#ifndef THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
|
16 |
+
#define THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
|
17 |
+
|
18 |
+
#include <array>
|
19 |
+
|
20 |
+
#include "tensorflow/core/framework/shape_inference.h"
|
21 |
+
#include "tensorflow/core/util/padding.h"
|
22 |
+
#include "tensorflow/core/util/tensor_format.h"
|
23 |
+
|
24 |
+
namespace tensorflow {
|
25 |
+
|
26 |
+
// GetWindowedOutputSize(): Given an input tensor, kernel, stride and padding
|
27 |
+
// type, the function computes the output and padding dimensions.
|
28 |
+
//
|
29 |
+
// For example, ignoring batches or multiple features, a 1D convolution
|
30 |
+
// takes as input a 1D tensor of shape (H), and convolves it with a filter of
|
31 |
+
// shape (K).
|
32 |
+
//
|
33 |
+
// It also takes in a few additional parameters:
|
34 |
+
//
|
35 |
+
// Stride (S): the stride with which we apply the filters. This is the offset
|
36 |
+
// between locations where we apply the filters. A larger stride
|
37 |
+
// means that the output will be spatially smaller.
|
38 |
+
//
|
39 |
+
// Padding (P): the padding we apply to the input tensor along each
|
40 |
+
// dimension. This is usually used to make sure that the spatial dimensions
|
41 |
+
// do not shrink when we progress with convolutions. Two types of padding are
|
42 |
+
// often used:
|
43 |
+
// SAME: the pad value is computed so that the output will have size H/S.
|
44 |
+
// VALID: no padding is carried out.
|
45 |
+
// The padded area is zero-filled.
|
46 |
+
//
|
47 |
+
// The output dimensions for convolution and many other operations, when given
|
48 |
+
// all the parameters above, are as follows:
|
49 |
+
// - When Padding = SAME: the output size is (H'), where
|
50 |
+
// H' = ceil(float(H) / float(S))
|
51 |
+
// where ceil is the ceiling function. The number of padded cells
|
52 |
+
// is computed as:
|
53 |
+
// Pc = ((H' - 1) * S + K - H) / 2
|
54 |
+
// When the stride is 1, the expression simplifies to
|
55 |
+
// H' = H, Pc = (K-1)/2.
|
56 |
+
// This is where SAME comes from - the output has the same size as the input
|
57 |
+
// has.
|
58 |
+
//
|
59 |
+
// - When Padding = VALID: the output size is computed as
|
60 |
+
// H' = ceil(float(H - K + 1) / float(S))
|
61 |
+
// and the number of padded cells is always zero.
|
62 |
+
// When the stride is 1, the expression simplifies to
|
63 |
+
// H' = H-K+1.
|
64 |
+
//
|
65 |
+
// For convolution, mathematically, the output value at location (r')
|
66 |
+
// is the inner product of two vectors: the chunk of input at
|
67 |
+
// ((r'*S-Pr) : (r'*S-Pr+K)),
|
68 |
+
// and the filter.
|
69 |
+
//
|
70 |
+
// For 2D and 3D convolutions, the spatial dimensions are orthogonal, so the
|
71 |
+
// size and padding of each spatial dimension can be computed by calling
|
72 |
+
// GetWindowedOutputSize separately for each dimension.
|
73 |
+
//
|
74 |
+
Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
|
75 |
+
Padding padding_type, int64* output_size,
|
76 |
+
int64* padding_size);
|
77 |
+
|
78 |
+
// The V2 version computes the same outputs with arbitrary dilation_rate.
|
79 |
+
// The output dimensions are computed as follows:
|
80 |
+
// - When adding dilation_rate (D), we compute an effective filter size (K'):
|
81 |
+
// K' = (K - 1) * D + 1
|
82 |
+
// - When Padding = SAME: the output size is (H'), where
|
83 |
+
// H' = ceil(float(H) / float(S))
|
84 |
+
// where ceil is the ceiling function. The number of padded cells
|
85 |
+
// is computed as:
|
86 |
+
// Pc = ((H' - 1) * S + K' - H) / 2
|
87 |
+
// When the stride is 1, the expression simplifies to
|
88 |
+
// H' = H, Pc = (K'-1)/2.
|
89 |
+
// This is where SAME comes from - the output has the same size as the input
|
90 |
+
// has.
|
91 |
+
//
|
92 |
+
// - When Padding = VALID: the output size is computed as
|
93 |
+
// H' = ceil(float(H - K' + 1) / float(S))
|
94 |
+
// and the number of padded cells is always zero.
|
95 |
+
// When the stride is 1, the expression simplifies to
|
96 |
+
// H' = H-K'+1.
|
97 |
+
//
|
98 |
+
// TODO(b/67112639): Merge V2 versions and the original versions eventually.
|
99 |
+
Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
|
100 |
+
int64 dilation_rate, int64 stride,
|
101 |
+
Padding padding_type, int64* output_size,
|
102 |
+
int64* padding_size);
|
103 |
+
|
104 |
+
// Returns the same output dimensions as in GetWindowedOutputSize, but returns
|
105 |
+
// verbose padding dimensions (before/after). Any excess padding
|
106 |
+
// (caused by an odd padding size value) is added to the 'padding_after'
|
107 |
+
// dimension.
|
108 |
+
Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
|
109 |
+
int64 stride, Padding padding_type,
|
110 |
+
int64* output_size, int64* padding_before,
|
111 |
+
int64* padding_after);
|
112 |
+
|
113 |
+
// The V2 version computes the same outputs with arbitrary dilation_rate. For
|
114 |
+
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
|
115 |
+
Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
|
116 |
+
int64 dilation_rate, int64 stride,
|
117 |
+
Padding padding_type, int64* output_size,
|
118 |
+
int64* padding_before,
|
119 |
+
int64* padding_after);
|
120 |
+
|
121 |
+
// Given an input tensor, kernel, stride and padding type, populates the 3D size
|
122 |
+
// of the output tensor and padding to be applied to the input tensor at the
|
123 |
+
// lower end of every dimension. Use for 3D convolutions, where the input data
|
124 |
+
// is padded with zeros, as well as for 3D avg/max pooling, where the input data
|
125 |
+
// is padded with invalid values that are not considered for pooling.
|
126 |
+
Status Get3dOutputSize(const std::array<int64, 3>& input,
|
127 |
+
const std::array<int64, 3>& window,
|
128 |
+
const std::array<int64, 3>& strides,
|
129 |
+
Padding padding_type, std::array<int64, 3>* output_ptr,
|
130 |
+
std::array<int64, 3>* padding_ptr);
|
131 |
+
|
132 |
+
// The V2 version computes the same outputs with arbitrary dilation_rate. For
|
133 |
+
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
|
134 |
+
Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
|
135 |
+
const std::array<int64, 3>& window,
|
136 |
+
const std::array<int64, 3>& dilations,
|
137 |
+
const std::array<int64, 3>& strides,
|
138 |
+
Padding padding_type, std::array<int64, 3>* output_ptr,
|
139 |
+
std::array<int64, 3>* padding_ptr);
|
140 |
+
|
141 |
+
namespace shape_inference {
|
142 |
+
|
143 |
+
// Like GetWindowedOutputSize, but deals with DimensionHandles.
|
144 |
+
Status GetWindowedOutputSizeFromDims(InferenceContext* c,
|
145 |
+
DimensionHandle input_size,
|
146 |
+
DimensionOrConstant filter_size,
|
147 |
+
int64 stride, Padding padding_type,
|
148 |
+
DimensionHandle* output_size);
|
149 |
+
|
150 |
+
// The V2 version computes the same outputs with arbitrary dilation_rate. For
|
151 |
+
// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
|
152 |
+
Status GetWindowedOutputSizeFromDimsV2(InferenceContext* c,
|
153 |
+
DimensionHandle input_size,
|
154 |
+
DimensionOrConstant filter_size,
|
155 |
+
int64 dilation_rate, int64 stride,
|
156 |
+
Padding padding_type,
|
157 |
+
DimensionHandle* output_size);
|
158 |
+
|
159 |
+
// Transfers shape of input(0) to output(0).
|
160 |
+
Status UnchangedShape(shape_inference::InferenceContext* c);
|
161 |
+
|
162 |
+
// Transfers shape of input(0) to output(0), after asserting its rank is <rank>.
|
163 |
+
inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
|
164 |
+
int32 rank) {
|
165 |
+
ShapeHandle out;
|
166 |
+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out));
|
167 |
+
c->set_output(0, out);
|
168 |
+
return Status::OK();
|
169 |
+
}
|
170 |
+
|
171 |
+
// Transfers shape of input(0) to output(0), after asserting its rank >= <rank>.
|
172 |
+
inline Status UnchangedShapeWithRankAtLeast(
|
173 |
+
shape_inference::InferenceContext* c, int32 rank) {
|
174 |
+
ShapeHandle out;
|
175 |
+
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out));
|
176 |
+
c->set_output(0, out);
|
177 |
+
return Status::OK();
|
178 |
+
}
|
179 |
+
|
180 |
+
// Transfers shape of input(0) to output(0), after asserting its rank <= <rank>.
|
181 |
+
inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c,
|
182 |
+
int32 rank) {
|
183 |
+
ShapeHandle out;
|
184 |
+
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out));
|
185 |
+
c->set_output(0, out);
|
186 |
+
return Status::OK();
|
187 |
+
}
|
188 |
+
|
189 |
+
// Shape function for use with ops no outputs.
|
190 |
+
inline Status NoOutputs(shape_inference::InferenceContext* c) {
|
191 |
+
return Status::OK();
|
192 |
+
}
|
193 |
+
|
194 |
+
// Shape function for ops that output a single scalar value.
|
195 |
+
inline Status ScalarShape(shape_inference::InferenceContext* c) {
|
196 |
+
c->set_output(0, c->Scalar());
|
197 |
+
return Status::OK();
|
198 |
+
}
|
199 |
+
|
200 |
+
// Shape function for binary ops where both inputs and the output match.
|
201 |
+
inline Status MergeBothInputsShapeFn(InferenceContext* c) {
|
202 |
+
ShapeHandle out;
|
203 |
+
TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
|
204 |
+
c->set_output(0, out);
|
205 |
+
return Status::OK();
|
206 |
+
}
|
207 |
+
|
208 |
+
// Returns a new shape with the specified dims arranged in the specified
|
209 |
+
// format. The returned value is owned by this context.
|
210 |
+
// Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth.
|
211 |
+
Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
|
212 |
+
const std::vector<DimensionOrConstant>& spatial,
|
213 |
+
DimensionOrConstant C, ShapeHandle* out,
|
214 |
+
shape_inference::InferenceContext* context);
|
215 |
+
|
216 |
+
// Shape function for MatMul-like operations.
|
217 |
+
Status MatMulShape(shape_inference::InferenceContext* c);
|
218 |
+
|
219 |
+
// Shape function for BiasAdd-like operations.
|
220 |
+
Status BiasAddShape(shape_inference::InferenceContext* c);
|
221 |
+
|
222 |
+
// Shape function for BiasAddGrad-like operations.
|
223 |
+
Status BiasAddGradShape(shape_inference::InferenceContext* c);
|
224 |
+
|
225 |
+
// Shape function for Conv2D-like operations.
|
226 |
+
Status Conv2DShape(shape_inference::InferenceContext* c);
|
227 |
+
|
228 |
+
// Shape function for Conv3D-like operations.
|
229 |
+
Status Conv3DShape(shape_inference::InferenceContext* c);
|
230 |
+
|
231 |
+
// Shape function for DepthwiseConv2D-like operations.
|
232 |
+
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);
|
233 |
+
|
234 |
+
// Shape function for AvgPool-like operations.
|
235 |
+
Status AvgPoolShape(shape_inference::InferenceContext* c);
|
236 |
+
|
237 |
+
// Shape function for FusedBatchNorm and FusedBatchNormV2 operations.
|
238 |
+
Status FusedBatchNormShape(shape_inference::InferenceContext* c);
|
239 |
+
|
240 |
+
// Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations.
|
241 |
+
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c);
|
242 |
+
|
243 |
+
// Shape function for MaxPool-like operations.
|
244 |
+
Status MaxPoolShape(shape_inference::InferenceContext* c);
|
245 |
+
|
246 |
+
// Shape function for MaxPoolV2-like operations.
|
247 |
+
Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);
|
248 |
+
|
249 |
+
// Shape function for 3D Pooling operations.
|
250 |
+
Status Pool3DShape(shape_inference::InferenceContext* c);
|
251 |
+
|
252 |
+
// Shape function for use with ops whose output shapes are unknown.
|
253 |
+
Status UnknownShape(shape_inference::InferenceContext* c);
|
254 |
+
|
255 |
+
// Shape function for reduction operations.
|
256 |
+
Status ReductionShape(shape_inference::InferenceContext* c);
|
257 |
+
|
258 |
+
// Shape function for concat operations.
|
259 |
+
// <num_inputs_to_concat> is the number of inputs to concatenate and are taken
|
260 |
+
// from inputs
|
261 |
+
// [1,num_inputs_to_concat] of the op. Input 0 is the concat_dim input.
|
262 |
+
Status ConcatShape(shape_inference::InferenceContext* c,
|
263 |
+
int num_inputs_to_concat);
|
264 |
+
|
265 |
+
// Shape function for concat operations.
|
266 |
+
Status ConcatV2Shape(shape_inference::InferenceContext* c);
|
267 |
+
|
268 |
+
// Shape function for binary operators that broadcast their inputs.
|
269 |
+
// Tested by ops/math_ops_test.cc.
|
270 |
+
Status BroadcastBinaryOpShapeFn(InferenceContext* c);
|
271 |
+
|
272 |
+
// Shape function for random operations.
|
273 |
+
Status RandomShape(shape_inference::InferenceContext* c);
|
274 |
+
|
275 |
+
// Validates the 3 component tensors of a sparse tensor have the proper
|
276 |
+
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
|
277 |
+
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
|
278 |
+
ShapeHandle values_shape, ShapeHandle shape_shape);
|
279 |
+
|
280 |
+
// Shape function for ScatterNd update/add/sub/... operations.
|
281 |
+
Status ScatterNdUpdateShape(InferenceContext* c);
|
282 |
+
|
283 |
+
// Shape function for ops with an explicit "shape" attribute.
|
284 |
+
Status ExplicitShape(InferenceContext* c);
|
285 |
+
|
286 |
+
} // namespace shape_inference
|
287 |
+
|
288 |
+
} // namespace tensorflow
|
289 |
+
|
290 |
+
#endif // THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
|
common_shape_fns_test.cc
ADDED
@@ -0,0 +1,1131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
#include "tensorflow/core/framework/common_shape_fns.h"
|
16 |
+
|
17 |
+
#include "tensorflow/core/framework/fake_input.h"
|
18 |
+
#include "tensorflow/core/framework/node_def_builder.h"
|
19 |
+
#include "tensorflow/core/framework/op_def_builder.h"
|
20 |
+
#include "tensorflow/core/framework/shape_inference_testutil.h"
|
21 |
+
#include "tensorflow/core/framework/tensor_testutil.h"
|
22 |
+
#include "tensorflow/core/lib/core/status_test_util.h"
|
23 |
+
#include "tensorflow/core/platform/test.h"
|
24 |
+
|
25 |
+
namespace tensorflow {
|
26 |
+
namespace shape_inference {
|
27 |
+
|
28 |
+
namespace {
|
29 |
+
|
30 |
+
PartialTensorShape S(std::initializer_list<int64> dims) {
|
31 |
+
return PartialTensorShape(dims);
|
32 |
+
}
|
33 |
+
|
34 |
+
PartialTensorShape Unknown() { return PartialTensorShape(); }
|
35 |
+
|
36 |
+
OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
37 |
+
OpRegistrationData op_reg_data;
|
38 |
+
OpDefBuilder b("dummy");
|
39 |
+
for (int i = 0; i < num_inputs; ++i) {
|
40 |
+
b.Input(strings::StrCat("i", i, ": float"));
|
41 |
+
}
|
42 |
+
for (int i = 0; i < num_outputs; ++i) {
|
43 |
+
b.Output(strings::StrCat("o", i, ": float"));
|
44 |
+
}
|
45 |
+
CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
|
46 |
+
return op_reg_data.op_def;
|
47 |
+
}
|
48 |
+
|
49 |
+
} // namespace
|
50 |
+
|
51 |
+
TEST(CommonShapeFnsTest, NoOutputShapeTest) {
|
52 |
+
OpRegistrationData op_reg_data;
|
53 |
+
TF_CHECK_OK(OpDefBuilder("Assert")
|
54 |
+
.Input("condition: bool")
|
55 |
+
.Input("data: float")
|
56 |
+
.Finalize(&op_reg_data));
|
57 |
+
OpDef op_def = op_reg_data.op_def;
|
58 |
+
|
59 |
+
NodeDef def;
|
60 |
+
TF_CHECK_OK(NodeDefBuilder("test", "Assert")
|
61 |
+
.Input("condition", 0, DT_BOOL)
|
62 |
+
.Input({{"data", 0, DT_FLOAT}})
|
63 |
+
.Finalize(&def));
|
64 |
+
|
65 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({}), S({10})}, {},
|
66 |
+
{}, {});
|
67 |
+
TF_EXPECT_OK(NoOutputs(&c));
|
68 |
+
EXPECT_EQ(0, c.num_outputs());
|
69 |
+
}
|
70 |
+
|
71 |
+
TEST(CommonShapeFnsTest, ScalarShapeTest) {
|
72 |
+
OpRegistrationData op_reg_data;
|
73 |
+
TF_CHECK_OK(OpDefBuilder("L2Loss")
|
74 |
+
.Input("t: float")
|
75 |
+
.Output("t: float")
|
76 |
+
.Finalize(&op_reg_data));
|
77 |
+
OpDef op_def = op_reg_data.op_def;
|
78 |
+
|
79 |
+
NodeDef def;
|
80 |
+
TF_CHECK_OK(
|
81 |
+
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
|
82 |
+
|
83 |
+
{
|
84 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {});
|
85 |
+
TF_EXPECT_OK(ScalarShape(&c));
|
86 |
+
ShapeHandle output = c.output(0);
|
87 |
+
EXPECT_EQ(0, c.Rank(output));
|
88 |
+
}
|
89 |
+
|
90 |
+
{
|
91 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
92 |
+
{S({1, 23, 4, 4, 2})}, {}, {}, {});
|
93 |
+
TF_EXPECT_OK(ScalarShape(&c));
|
94 |
+
ShapeHandle output = c.output(0);
|
95 |
+
EXPECT_EQ(0, c.Rank(output));
|
96 |
+
}
|
97 |
+
}
|
98 |
+
|
99 |
+
TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
100 |
+
OpRegistrationData op_reg_data;
|
101 |
+
TF_CHECK_OK(OpDefBuilder("MatMul")
|
102 |
+
.Input("a: float")
|
103 |
+
.Input("b: float")
|
104 |
+
.Output("c: float")
|
105 |
+
.Attr("transpose_a:bool=false")
|
106 |
+
.Attr("transpose_b:bool=false")
|
107 |
+
.Finalize(&op_reg_data));
|
108 |
+
OpDef op_def = op_reg_data.op_def;
|
109 |
+
|
110 |
+
NodeDef def;
|
111 |
+
TF_CHECK_OK(NodeDefBuilder("test", "MatMul")
|
112 |
+
.Input("a", 0, DT_FLOAT)
|
113 |
+
.Input("b", 0, DT_FLOAT)
|
114 |
+
.Attr("transpose_a", false)
|
115 |
+
.Attr("transpose_b", false)
|
116 |
+
.Finalize(&def));
|
117 |
+
|
118 |
+
{
|
119 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
120 |
+
{S({2, 3}), S({3, 4})}, {}, {}, {});
|
121 |
+
TF_EXPECT_OK(MatMulShape(&c));
|
122 |
+
ShapeHandle output = c.output(0);
|
123 |
+
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
124 |
+
EXPECT_EQ(4, c.Value(c.Dim(output, 1)));
|
125 |
+
}
|
126 |
+
|
127 |
+
{
|
128 |
+
// Unknown inner dimension for one
|
129 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
130 |
+
{S({2, -1}), S({3, 4})}, {}, {}, {});
|
131 |
+
TF_EXPECT_OK(MatMulShape(&c));
|
132 |
+
ShapeHandle output = c.output(0);
|
133 |
+
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
134 |
+
EXPECT_EQ(4, c.Value(c.Dim(output, 1)));
|
135 |
+
}
|
136 |
+
|
137 |
+
{
|
138 |
+
// Invalid rank.
|
139 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2}), S({3, 4})},
|
140 |
+
{}, {}, {});
|
141 |
+
auto s = MatMulShape(&c);
|
142 |
+
EXPECT_FALSE(s.ok());
|
143 |
+
EXPECT_TRUE(
|
144 |
+
StringPiece(s.ToString())
|
145 |
+
.contains("Invalid argument: Shape must be rank 2 but is rank 1"));
|
146 |
+
}
|
147 |
+
|
148 |
+
{
|
149 |
+
// Unknown outer dimension
|
150 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
151 |
+
{S({2, 3}), S({3, -1})}, {}, {}, {});
|
152 |
+
TF_EXPECT_OK(MatMulShape(&c));
|
153 |
+
ShapeHandle output = c.output(0);
|
154 |
+
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
155 |
+
EXPECT_FALSE(c.ValueKnown(c.Dim(output, 1)));
|
156 |
+
}
|
157 |
+
|
158 |
+
{
|
159 |
+
// Inner shapes not compatible
|
160 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
161 |
+
{S({2, 5}), S({3, 4})}, {}, {}, {});
|
162 |
+
auto s = MatMulShape(&c);
|
163 |
+
EXPECT_FALSE(s.ok());
|
164 |
+
EXPECT_TRUE(
|
165 |
+
StringPiece(s.ToString())
|
166 |
+
.contains(
|
167 |
+
"Invalid argument: Dimensions must be equal, but are 5 and 3"));
|
168 |
+
}
|
169 |
+
|
170 |
+
{
|
171 |
+
// Inner shapes not compatible
|
172 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
173 |
+
{S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {});
|
174 |
+
auto s = MatMulShape(&c);
|
175 |
+
EXPECT_FALSE(s.ok());
|
176 |
+
EXPECT_TRUE(
|
177 |
+
StringPiece(s.ToString())
|
178 |
+
.contains("Invalid argument: Shape must be rank 2 but is rank 3"));
|
179 |
+
}
|
180 |
+
|
181 |
+
{
|
182 |
+
// transpose_a
|
183 |
+
TF_CHECK_OK(NodeDefBuilder("test", "MatMul")
|
184 |
+
.Input("a", 0, DT_FLOAT)
|
185 |
+
.Input("b", 0, DT_FLOAT)
|
186 |
+
.Attr("transpose_a", true)
|
187 |
+
.Attr("transpose_b", false)
|
188 |
+
.Attr("type", DT_FLOAT)
|
189 |
+
.Finalize(&def));
|
190 |
+
|
191 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
192 |
+
{S({3, 2}), S({3, 4})}, {}, {}, {});
|
193 |
+
auto s = MatMulShape(&c);
|
194 |
+
ShapeHandle output = c.output(0);
|
195 |
+
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
196 |
+
EXPECT_EQ(4, c.Value(c.Dim(output, 1)));
|
197 |
+
}
|
198 |
+
|
199 |
+
{
|
200 |
+
// transpose_b
|
201 |
+
TF_CHECK_OK(NodeDefBuilder("test", "MatMul")
|
202 |
+
.Input("a", 0, DT_FLOAT)
|
203 |
+
.Input("b", 0, DT_FLOAT)
|
204 |
+
.Attr("transpose_a", false)
|
205 |
+
.Attr("transpose_b", true)
|
206 |
+
.Attr("type", DT_FLOAT)
|
207 |
+
.Finalize(&def));
|
208 |
+
|
209 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
210 |
+
{S({2, 3}), S({4, 3})}, {}, {}, {});
|
211 |
+
auto s = MatMulShape(&c);
|
212 |
+
ShapeHandle output = c.output(0);
|
213 |
+
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
214 |
+
EXPECT_EQ(4, c.Value(c.Dim(output, 1)));
|
215 |
+
}
|
216 |
+
}
|
217 |
+
|
218 |
+
TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
219 |
+
OpRegistrationData op_reg_data;
|
220 |
+
TF_CHECK_OK(OpDefBuilder("BiasAdd")
|
221 |
+
.Input("a: float")
|
222 |
+
.Input("b: float")
|
223 |
+
.Output("c: float")
|
224 |
+
.Finalize(&op_reg_data));
|
225 |
+
|
226 |
+
OpDef op_def = op_reg_data.op_def;
|
227 |
+
NodeDef def;
|
228 |
+
TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
|
229 |
+
.Input("a", 0, DT_FLOAT)
|
230 |
+
.Input("b", 0, DT_FLOAT)
|
231 |
+
.Finalize(&def));
|
232 |
+
|
233 |
+
{
|
234 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
235 |
+
{S({2, 10}), S({10})}, {}, {}, {});
|
236 |
+
TF_EXPECT_OK(BiasAddShape(&c));
|
237 |
+
ShapeHandle output = c.output(0);
|
238 |
+
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
239 |
+
EXPECT_EQ(10, c.Value(c.Dim(output, 1)));
|
240 |
+
}
|
241 |
+
|
242 |
+
{
|
243 |
+
// Unknown ranks.
|
244 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
245 |
+
{Unknown(), Unknown()}, {}, {}, {});
|
246 |
+
TF_EXPECT_OK(BiasAddShape(&c));
|
247 |
+
ShapeHandle output = c.output(0);
|
248 |
+
EXPECT_FALSE(c.RankKnown(output));
|
249 |
+
}
|
250 |
+
|
251 |
+
{
|
252 |
+
// Rank > 2
|
253 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
254 |
+
{S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {});
|
255 |
+
TF_EXPECT_OK(BiasAddShape(&c));
|
256 |
+
ShapeHandle output = c.output(0);
|
257 |
+
EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
|
258 |
+
}
|
259 |
+
|
260 |
+
{
|
261 |
+
// NCHW format
|
262 |
+
TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
|
263 |
+
.Input("a", 0, DT_FLOAT)
|
264 |
+
.Input("b", 0, DT_FLOAT)
|
265 |
+
.Attr("data_format", "NCHW")
|
266 |
+
.Finalize(&def));
|
267 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
268 |
+
{S({2, 3, 4, 5}), S({3})}, {}, {}, {});
|
269 |
+
TF_EXPECT_OK(BiasAddShape(&c));
|
270 |
+
ShapeHandle output = c.output(0);
|
271 |
+
EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
|
272 |
+
}
|
273 |
+
|
274 |
+
{
|
275 |
+
// NCHW format with high input rank
|
276 |
+
TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
|
277 |
+
.Input("a", 0, DT_FLOAT)
|
278 |
+
.Input("b", 0, DT_FLOAT)
|
279 |
+
.Attr("data_format", "NCHW")
|
280 |
+
.Finalize(&def));
|
281 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
282 |
+
{S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {});
|
283 |
+
TF_EXPECT_OK(BiasAddShape(&c));
|
284 |
+
ShapeHandle output = c.output(0);
|
285 |
+
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
|
286 |
+
}
|
287 |
+
|
288 |
+
{
|
289 |
+
// NCHW format with input rank 3
|
290 |
+
TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
|
291 |
+
.Input("a", 0, DT_FLOAT)
|
292 |
+
.Input("b", 0, DT_FLOAT)
|
293 |
+
.Attr("data_format", "NCHW")
|
294 |
+
.Finalize(&def));
|
295 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
296 |
+
{S({10, 11, 12}), S({10})}, {}, {}, {});
|
297 |
+
TF_EXPECT_OK(BiasAddShape(&c));
|
298 |
+
ShapeHandle output = c.output(0);
|
299 |
+
EXPECT_EQ("[10,11,12]", c.DebugString(output));
|
300 |
+
}
|
301 |
+
|
302 |
+
{
|
303 |
+
// Input rank not high enough
|
304 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3}), S({3})}, {},
|
305 |
+
{}, {});
|
306 |
+
EXPECT_FALSE(BiasAddShape(&c).ok());
|
307 |
+
}
|
308 |
+
|
309 |
+
{
|
310 |
+
// NCHW rank not high enough
|
311 |
+
TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
|
312 |
+
.Input("a", 0, DT_FLOAT)
|
313 |
+
.Input("b", 0, DT_FLOAT)
|
314 |
+
.Attr("data_format", "NCHW")
|
315 |
+
.Finalize(&def));
|
316 |
+
// NCHW format
|
317 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3}), S({3})},
|
318 |
+
{}, {}, {});
|
319 |
+
EXPECT_FALSE(BiasAddShape(&c).ok());
|
320 |
+
}
|
321 |
+
}
|
322 |
+
|
323 |
+
TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
324 |
+
OpRegistrationData op_reg_data;
|
325 |
+
TF_CHECK_OK(OpDefBuilder("BiasAddGrad")
|
326 |
+
.Input("a: float")
|
327 |
+
.Output("b: float")
|
328 |
+
.Finalize(&op_reg_data));
|
329 |
+
|
330 |
+
OpDef op_def = op_reg_data.op_def;
|
331 |
+
NodeDef def;
|
332 |
+
TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
|
333 |
+
.Input("a", 0, DT_FLOAT)
|
334 |
+
.Finalize(&def));
|
335 |
+
|
336 |
+
{
|
337 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 10})}, {}, {},
|
338 |
+
{});
|
339 |
+
TF_EXPECT_OK(BiasAddGradShape(&c));
|
340 |
+
ShapeHandle output = c.output(0);
|
341 |
+
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
342 |
+
}
|
343 |
+
|
344 |
+
{
|
345 |
+
// Rank > 2
|
346 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({5, 7, 2, 10})},
|
347 |
+
{}, {}, {});
|
348 |
+
TF_EXPECT_OK(BiasAddGradShape(&c));
|
349 |
+
ShapeHandle output = c.output(0);
|
350 |
+
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
351 |
+
}
|
352 |
+
|
353 |
+
{
|
354 |
+
// NCHW format
|
355 |
+
TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
|
356 |
+
.Input("a", 0, DT_FLOAT)
|
357 |
+
.Attr("data_format", "NCHW")
|
358 |
+
.Finalize(&def));
|
359 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3, 4, 5})},
|
360 |
+
{}, {}, {});
|
361 |
+
TF_EXPECT_OK(BiasAddGradShape(&c));
|
362 |
+
ShapeHandle output = c.output(0);
|
363 |
+
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
|
364 |
+
}
|
365 |
+
|
366 |
+
{
|
367 |
+
// NCHW format with high input rank
|
368 |
+
TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
|
369 |
+
.Input("a", 0, DT_FLOAT)
|
370 |
+
.Attr("data_format", "NCHW")
|
371 |
+
.Finalize(&def));
|
372 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
373 |
+
{S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {});
|
374 |
+
TF_EXPECT_OK(BiasAddGradShape(&c));
|
375 |
+
ShapeHandle output = c.output(0);
|
376 |
+
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
|
377 |
+
}
|
378 |
+
|
379 |
+
{
|
380 |
+
// NCHW format with input rank 3
|
381 |
+
TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
|
382 |
+
.Input("a", 0, DT_FLOAT)
|
383 |
+
.Attr("data_format", "NCHW")
|
384 |
+
.Finalize(&def));
|
385 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({10, 11, 12})},
|
386 |
+
{}, {}, {});
|
387 |
+
TF_EXPECT_OK(BiasAddGradShape(&c));
|
388 |
+
ShapeHandle output = c.output(0);
|
389 |
+
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
390 |
+
}
|
391 |
+
|
392 |
+
{
|
393 |
+
// Input rank not high enough
|
394 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {},
|
395 |
+
{});
|
396 |
+
EXPECT_FALSE(BiasAddGradShape(&c).ok());
|
397 |
+
}
|
398 |
+
|
399 |
+
{
|
400 |
+
// NCHW rank not high enough
|
401 |
+
TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
|
402 |
+
.Input("a", 0, DT_FLOAT)
|
403 |
+
.Attr("data_format", "NCHW")
|
404 |
+
.Finalize(&def));
|
405 |
+
// NCHW format
|
406 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3})}, {}, {},
|
407 |
+
{});
|
408 |
+
EXPECT_FALSE(BiasAddGradShape(&c).ok());
|
409 |
+
}
|
410 |
+
}
|
411 |
+
|
412 |
+
TEST(CommonShapeFnsTest, Conv2DShapeTest) {
|
413 |
+
ShapeInferenceTestOp op("Conv2D");
|
414 |
+
auto set_op = [&op](const std::vector<int32>& strides, const string& padding,
|
415 |
+
const string& data_format, const string& filter_format) {
|
416 |
+
TF_CHECK_OK(NodeDefBuilder("test", "Conv2D")
|
417 |
+
.Input("input", 0, DT_FLOAT)
|
418 |
+
.Input("filter", 0, DT_FLOAT)
|
419 |
+
.Attr("strides", strides)
|
420 |
+
.Attr("padding", padding)
|
421 |
+
.Attr("data_format", data_format)
|
422 |
+
.Attr("filter_format", filter_format)
|
423 |
+
.Finalize(&op.node_def));
|
424 |
+
};
|
425 |
+
|
426 |
+
// Invalid rank for input
|
427 |
+
INFER_ERROR("must be rank 4", op, "[4,4];[2,1,1,1]");
|
428 |
+
// Invalid rank for filter
|
429 |
+
INFER_ERROR("must be rank 4", op, "[1,4,4,1];[2,1,1]");
|
430 |
+
|
431 |
+
// Invalid value for strides
|
432 |
+
set_op({{1, 1, 0, 1}}, "VALID", "NHWC", "HWIO");
|
433 |
+
INFER_ERROR("must be > 0", op, "[1,2,2,1];[1,1,1,1]");
|
434 |
+
|
435 |
+
// 1x1 filter
|
436 |
+
set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO");
|
437 |
+
INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
|
438 |
+
|
439 |
+
// 2x2 filter
|
440 |
+
set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO");
|
441 |
+
INFER_OK(op, "[1,2,2,1];[2,2,1,1]", "[d0_0,1,1,d1_3]");
|
442 |
+
|
443 |
+
// 3x3 input, 1x1 filter, 2x2 stride
|
444 |
+
set_op({{1, 2, 2, 1}}, "VALID", "NHWC", "HWIO");
|
445 |
+
INFER_OK(op, "[1,3,3,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
|
446 |
+
|
447 |
+
// 3x3 input, 1x1 filter, 2x1 stride
|
448 |
+
set_op({{1, 2, 1, 1}}, "VALID", "NHWC", "HWIO");
|
449 |
+
INFER_OK(op, "[1,3,3,1];[1,1,1,1]", "[d0_0,2,3,d1_3]");
|
450 |
+
|
451 |
+
// 4x4 input, 2x1 filter, 1x2 stride
|
452 |
+
set_op({{1, 1, 2, 1}}, "VALID", "NHWC", "HWIO");
|
453 |
+
INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]");
|
454 |
+
|
455 |
+
// Unknown dims in the critical fields lead to partial inference.
|
456 |
+
INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]");
|
457 |
+
INFER_OK(op, "[1,?,4,1];[2,1,1,1]", "[d0_0,?,2,d1_3]");
|
458 |
+
INFER_OK(op, "[1,4,?,1];[2,1,1,1]", "[d0_0,3,?,d1_3]");
|
459 |
+
INFER_OK(op, "[1,4,4,?];[2,1,1,1]", "[d0_0,3,2,d1_3]");
|
460 |
+
INFER_OK(op, "[1,4,4,1];[?,1,1,1]", "[d0_0,?,2,d1_3]");
|
461 |
+
INFER_OK(op, "[1,4,4,1];[2,?,1,1]", "[d0_0,3,?,d1_3]");
|
462 |
+
|
463 |
+
// input depths must match.
|
464 |
+
INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op,
|
465 |
+
"[1,2,2,10];[1,1,10000,20]");
|
466 |
+
|
467 |
+
// Tests for NCHW
|
468 |
+
// 1x1 filter
|
469 |
+
set_op({{1, 1, 1, 1}}, "VALID", "NCHW", "HWIO");
|
470 |
+
INFER_OK(op, "[1,1,2,2];[1,1,1,1]", "[d0_0,d1_3,2,2]");
|
471 |
+
|
472 |
+
// 2x2 filter
|
473 |
+
set_op({{1, 1, 1, 1}}, "VALID", "NCHW", "HWIO");
|
474 |
+
INFER_OK(op, "[1,1,2,2];[2,2,1,1]", "[d0_0,d1_3,1,1]");
|
475 |
+
|
476 |
+
// 3x3 input, 1x1 filter, 2x2 stride
|
477 |
+
set_op({{1, 1, 2, 2}}, "VALID", "NCHW", "HWIO");
|
478 |
+
INFER_OK(op, "[1,1,3,3];[1,1,1,1]", "[d0_0,d1_3,2,2]");
|
479 |
+
|
480 |
+
// 3x3 input, 1x1 filter, 2x1 stride
|
481 |
+
set_op({{1, 1, 2, 1}}, "VALID", "NCHW", "HWIO");
|
482 |
+
INFER_OK(op, "[1,1,3,3];[1,1,1,1]", "[d0_0,d1_3,2,3]");
|
483 |
+
|
484 |
+
// 4x4 input, 2x1 filter, 1x2 stride
|
485 |
+
set_op({{1, 1, 1, 2}}, "VALID", "NCHW", "HWIO");
|
486 |
+
INFER_OK(op, "[1,1,4,4];[2,1,1,1]", "[d0_0,d1_3,3,2]");
|
487 |
+
|
488 |
+
// Tests for NCHW_VECT_C
|
489 |
+
// 1x1 filter
|
490 |
+
set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
|
491 |
+
INFER_OK(op, "[1,1,2,2,4];[4,1,1,1,4]", "[d0_0,1,2,2,4]");
|
492 |
+
|
493 |
+
// 2x2 filter
|
494 |
+
set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
|
495 |
+
INFER_OK(op, "[1,1,2,2,4];[4,1,2,2,4]", "[d0_0,1,1,1,4]");
|
496 |
+
|
497 |
+
// 3x3 input, 1x1 filter, 2x2 stride
|
498 |
+
set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
|
499 |
+
INFER_OK(op, "[1,1,3,3,4];[8,1,1,1,4]", "[d0_0,2,2,2,4]");
|
500 |
+
|
501 |
+
// 3x3 input, 1x1 filter, 2x1 stride
|
502 |
+
set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
|
503 |
+
INFER_OK(op, "[1,1,3,3,4];[4,1,1,1,4]", "[d0_0,1,2,3,4]");
|
504 |
+
|
505 |
+
// 4x4 input, 2x1 filter, 1x2 stride
|
506 |
+
set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
|
507 |
+
INFER_OK(op, "[1,1,4,4,4];[4,1,2,1,4]", "[d0_0,1,3,2,4]");
|
508 |
+
|
509 |
+
// Some tests for "SAME" padding
|
510 |
+
|
511 |
+
// 4x4 input, 1x1 filter, 1x1 stride
|
512 |
+
set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO");
|
513 |
+
INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
|
514 |
+
|
515 |
+
// 3x3 input, 2x2 filter, 1x1 stride
|
516 |
+
set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO");
|
517 |
+
INFER_OK(op, "[1,3,3,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
|
518 |
+
|
519 |
+
// 4x4 input, 2x2 filter, 2x2 stride
|
520 |
+
set_op({{1, 2, 2, 1}}, "SAME", "NHWC", "HWIO");
|
521 |
+
INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,2,2,d1_3]");
|
522 |
+
|
523 |
+
// 4x4 input, 2x2 filter, 1x1 stride
|
524 |
+
set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO");
|
525 |
+
INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
|
526 |
+
|
527 |
+
// With stride 1x1 and SAME, unknown dims don't matter - filter dims except
|
528 |
+
// for output channels are ignored for output, so all inputs are carried
|
529 |
+
// through to output.
|
530 |
+
set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO");
|
531 |
+
INFER_OK(op, "[1,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
|
532 |
+
INFER_OK(op, "[1,?,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
|
533 |
+
INFER_OK(op, "[1,4,?,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
|
534 |
+
INFER_OK(op, "[1,4,4,?];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
|
535 |
+
INFER_OK(op, "[?,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
|
536 |
+
|
537 |
+
// With stride != 1, the input HW dims are divided to produce output dims.
|
538 |
+
set_op({{1, 2, 2, 1}}, "SAME", "NHWC", "HWIO");
|
539 |
+
INFER_OK(op, "[?,4,4,1];[?,?,?,?]", "[d0_0,2,2,d1_3]");
|
540 |
+
INFER_OK(op, "[1,?,4,1];[?,?,?,?]", "[d0_0,?,2,d1_3]");
|
541 |
+
INFER_OK(op, "[1,4,?,1];[?,?,?,?]", "[d0_0,2,?,d1_3]");
|
542 |
+
INFER_OK(op, "[1,4,4,?];[?,?,?,?]", "[d0_0,2,2,d1_3]");
|
543 |
+
}
|
544 |
+
|
545 |
+
TEST(CommonShapeFnsTest, Conv2DDilatedShapeTest) {
|
546 |
+
ShapeInferenceTestOp op("Conv2D");
|
547 |
+
auto set_op = [&op](const std::vector<int32>& dilations,
|
548 |
+
const std::vector<int32>& strides, const string& padding,
|
549 |
+
const string& data_format) {
|
550 |
+
TF_CHECK_OK(NodeDefBuilder("test", "Conv2D")
|
551 |
+
.Input("input", 0, DT_FLOAT)
|
552 |
+
.Input("filter", 0, DT_FLOAT)
|
553 |
+
.Attr("dilations", dilations)
|
554 |
+
.Attr("strides", strides)
|
555 |
+
.Attr("padding", padding)
|
556 |
+
.Attr("data_format", data_format)
|
557 |
+
.Finalize(&op.node_def));
|
558 |
+
};
|
559 |
+
|
560 |
+
// Invalid rank for dilation
|
561 |
+
set_op({{1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
|
562 |
+
INFER_ERROR("contain 4 values", op, "[1,2,2,1];[1,1,1,1]");
|
563 |
+
|
564 |
+
// Invalid value for dilation
|
565 |
+
set_op({{1, 0, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
|
566 |
+
INFER_ERROR("must be >= 1", op, "[1,2,2,1];[1,1,1,1]");
|
567 |
+
|
568 |
+
// Tests for NHWC
|
569 |
+
// 1x1 filter, 2x1 dilations, 1x1 strides
|
570 |
+
set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
|
571 |
+
INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
|
572 |
+
|
573 |
+
// 1x1 filter, 2x1 dilations, 2x1 strides
|
574 |
+
set_op({{1, 2, 1, 1}}, {{1, 2, 1, 1}}, "VALID", "NHWC");
|
575 |
+
INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,2,4,d1_3]");
|
576 |
+
|
577 |
+
// 1x1 filter, 2x1 dilations, 2x2 strides
|
578 |
+
set_op({{1, 2, 1, 1}}, {{1, 2, 2, 1}}, "VALID", "NHWC");
|
579 |
+
INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
|
580 |
+
|
581 |
+
// 3x3 filter, 2x1 dilations, 1x1 strides
|
582 |
+
set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
|
583 |
+
INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,1,3,d1_3]");
|
584 |
+
|
585 |
+
// 3x3 filter, 2x1 dilations, 2x1 strides
|
586 |
+
set_op({{1, 2, 1, 1}}, {{1, 2, 1, 1}}, "VALID", "NHWC");
|
587 |
+
INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,1,3,d1_3]");
|
588 |
+
|
589 |
+
// 3x3 filter, 1x2 dilations, 2x2 strides
|
590 |
+
set_op({{1, 1, 2, 1}}, {{1, 2, 2, 1}}, "VALID", "NHWC");
|
591 |
+
INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,2,1,d1_3]");
|
592 |
+
|
593 |
+
// Tests for NCHW
|
594 |
+
// 1x1 filter, 2x1 dilations, 1x1 strides
|
595 |
+
set_op({{1, 1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NCHW");
|
596 |
+
INFER_OK(op, "[1,1,2,2];[1,1,1,1]", "[d0_0,d1_3,2,2]");
|
597 |
+
|
598 |
+
// 1x1 filter, 2x1 dilations, 2x1 strides
|
599 |
+
set_op({{1, 1, 2, 1}}, {{1, 1, 2, 1}}, "VALID", "NCHW");
|
600 |
+
INFER_OK(op, "[1,1,4,4];[1,1,1,1]", "[d0_0,d1_3,2,4]");
|
601 |
+
|
602 |
+
// 1x1 filter, 2x1 dilations, 2x2 strides
|
603 |
+
set_op({{1, 1, 2, 1}}, {{1, 1, 2, 2}}, "VALID", "NCHW");
|
604 |
+
INFER_OK(op, "[1,1,4,4];[1,1,1,1]", "[d0_0,d1_3,2,2]");
|
605 |
+
|
606 |
+
// 3x3 filter, 2x1 dilations, 1x1 strides
|
607 |
+
set_op({{1, 1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NCHW");
|
608 |
+
INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,1,3]");
|
609 |
+
|
610 |
+
// 3x3 filter, 2x1 dilations, 2x1 strides
|
611 |
+
set_op({{1, 1, 2, 1}}, {{1, 1, 2, 1}}, "VALID", "NCHW");
|
612 |
+
INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,1,3]");
|
613 |
+
|
614 |
+
// 3x3 filter, 1x2 dilations, 2x2 strides
|
615 |
+
set_op({{1, 1, 1, 2}}, {{1, 1, 2, 2}}, "VALID", "NCHW");
|
616 |
+
INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,2,1]");
|
617 |
+
|
618 |
+
// Some tests for "SAME" padding
|
619 |
+
|
620 |
+
// 4x4 input, 1x1 filter, 2x1 dilations, 1x1 stride
|
621 |
+
set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC");
|
622 |
+
INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
|
623 |
+
|
624 |
+
// 3x3 input, 2x2 filter, 2x2 dilations, 1x1 stride
|
625 |
+
set_op({{1, 2, 2, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC");
|
626 |
+
INFER_OK(op, "[1,3,3,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
|
627 |
+
|
628 |
+
// 4x4 input, 2x2 filter, 1x2 dilations, 2x2 stride
|
629 |
+
set_op({{1, 1, 2, 1}}, {{1, 2, 2, 1}}, "SAME", "NHWC");
|
630 |
+
INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,2,2,d1_3]");
|
631 |
+
|
632 |
+
// 4x4 input, 2x2 filter, 2x2 dilations, 1x1 stride
|
633 |
+
set_op({{1, 2, 2, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC");
|
634 |
+
INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
|
635 |
+
}
|
636 |
+
|
637 |
+
TEST(CommonShapeFnsTest, Conv3DShapeTest) {
|
638 |
+
ShapeInferenceTestOp op("Conv3D");
|
639 |
+
auto set_op = [&op](const std::vector<int32>& strides,
|
640 |
+
const string& padding) {
|
641 |
+
TF_CHECK_OK(NodeDefBuilder("test", "Conv3D")
|
642 |
+
.Input("input", 0, DT_FLOAT)
|
643 |
+
.Input("filter", 0, DT_FLOAT)
|
644 |
+
.Attr("strides", strides)
|
645 |
+
.Attr("padding", padding)
|
646 |
+
.Finalize(&op.node_def));
|
647 |
+
};
|
648 |
+
|
649 |
+
// 1x1x1 filter
|
650 |
+
set_op({{1, 1, 1, 1, 1}}, "VALID");
|
651 |
+
INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]");
|
652 |
+
|
653 |
+
// Invalid rank for input
|
654 |
+
INFER_ERROR("must be rank 5", op, "[4,4];[2,1,1,1]");
|
655 |
+
// Invalid rank for filter
|
656 |
+
INFER_ERROR("must be rank 5", op, "[1,4,4,1];[2,1,1]");
|
657 |
+
|
658 |
+
// unknown dims in the critical fields give partial inference.
|
659 |
+
INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]");
|
660 |
+
INFER_OK(op, "[1,?,2,2,1];[1,1,1,1,1]", "[d0_0,?,2,2,d1_4]");
|
661 |
+
INFER_OK(op, "[1,2,?,2,1];[1,1,1,1,1]", "[d0_0,2,?,2,d1_4]");
|
662 |
+
INFER_OK(op, "[1,2,2,?,1];[1,1,1,1,1]", "[d0_0,2,2,?,d1_4]");
|
663 |
+
INFER_OK(op, "[1,2,2,2,1];[?,1,1,1,1]", "[d0_0,?,2,2,d1_4]");
|
664 |
+
INFER_OK(op, "[1,2,2,2,1];[1,?,1,1,1]", "[d0_0,2,?,2,d1_4]");
|
665 |
+
INFER_OK(op, "[1,2,2,2,1];[1,1,?,1,1]", "[d0_0,2,2,?,d1_4]");
|
666 |
+
INFER_OK(op, "[1,2,2,2,1];[1,1,1,?,1]", "[d0_0,2,2,2,d1_4]");
|
667 |
+
INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,?]", "[d0_0,2,2,2,d1_4]");
|
668 |
+
|
669 |
+
// input depths must match.
|
670 |
+
INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op,
|
671 |
+
"[1,2,2,2,10];[1,1,1,10000,20]");
|
672 |
+
|
673 |
+
// 2x2x2 filter
|
674 |
+
set_op({{1, 1, 1, 1, 1}}, "VALID");
|
675 |
+
INFER_OK(op, "[1,2,2,2,1];[2,2,2,1,1]", "[d0_0,1,1,1,d1_4]");
|
676 |
+
|
677 |
+
// 3x3 input, 1x1 filter, 2x2 stride
|
678 |
+
set_op({{1, 2, 2, 2, 1}}, "VALID");
|
679 |
+
INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]");
|
680 |
+
|
681 |
+
// 3x3 input, 1x1 filter, 2x1x1 stride
|
682 |
+
set_op({{1, 2, 1, 1, 1}}, "VALID");
|
683 |
+
INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,3,3,d1_4]");
|
684 |
+
|
685 |
+
// 4x4 input, 2x2 filter, 1x1 stride
|
686 |
+
set_op({{1, 1, 1, 1, 1}}, "SAME");
|
687 |
+
INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
|
688 |
+
|
689 |
+
// with SAME, filter doesn't matter except for last dim.
|
690 |
+
set_op({{1, 1, 1, 1, 1}}, "SAME");
|
691 |
+
INFER_OK(op, "[?,4,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
|
692 |
+
INFER_OK(op, "[1,?,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
|
693 |
+
INFER_OK(op, "[1,4,?,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
|
694 |
+
INFER_OK(op, "[1,4,4,?,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
|
695 |
+
INFER_OK(op, "[1,4,4,4,?];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
|
696 |
+
INFER_OK(op, "[1,4,4,4,1];[?,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
|
697 |
+
INFER_OK(op, "[1,4,4,4,1];[2,?,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
|
698 |
+
INFER_OK(op, "[1,4,4,4,1];[2,2,?,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
|
699 |
+
INFER_OK(op, "[1,4,4,4,1];[2,2,2,?,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
|
700 |
+
INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,?]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
|
701 |
+
|
702 |
+
// with SAME, and stride != 1, division happens to produce output.
|
703 |
+
set_op({{1, 2, 3, 4, 1}}, "SAME");
|
704 |
+
INFER_OK(op, "[1,4,9,4,1];[2,2,2,1,1]", "[d0_0,2,3,1,d1_4]");
|
705 |
+
INFER_OK(op, "[?,4,9,4,1];[2,2,2,1,1]", "[d0_0,2,3,1,d1_4]");
|
706 |
+
INFER_OK(op, "[1,?,9,4,1];[2,2,2,1,1]", "[d0_0,?,3,1,d1_4]");
|
707 |
+
INFER_OK(op, "[1,4,?,4,1];[2,2,2,1,1]", "[d0_0,2,?,1,d1_4]");
|
708 |
+
INFER_OK(op, "[1,4,9,?,1];[2,2,2,1,1]", "[d0_0,2,3,?,d1_4]");
|
709 |
+
INFER_OK(op, "[1,4,9,4,?];[2,2,2,1,1]", "[d0_0,2,3,1,d1_4]");
|
710 |
+
INFER_OK(op, "[1,4,9,4,1];[?,2,2,1,1]", "[d0_0,2,3,1,d1_4]");
|
711 |
+
INFER_OK(op, "[1,4,9,4,1];[2,?,2,1,1]", "[d0_0,2,3,1,d1_4]");
|
712 |
+
INFER_OK(op, "[1,4,9,4,1];[2,2,?,1,1]", "[d0_0,2,3,1,d1_4]");
|
713 |
+
INFER_OK(op, "[1,4,9,4,1];[2,2,2,?,1]", "[d0_0,2,3,1,d1_4]");
|
714 |
+
INFER_OK(op, "[1,4,9,4,1];[2,2,2,1,?]", "[d0_0,2,3,1,d1_4]");
|
715 |
+
}
|
716 |
+
|
717 |
+
TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) {
|
718 |
+
ShapeInferenceTestOp op("DepthwiseConv2dNative");
|
719 |
+
std::vector<int32> strides = {{1, 1, 1, 1}};
|
720 |
+
TF_CHECK_OK(NodeDefBuilder("test", "DepthwiseConv2dNative")
|
721 |
+
.Input("input", 0, DT_FLOAT)
|
722 |
+
.Input("filter", 0, DT_FLOAT)
|
723 |
+
.Attr("strides", strides)
|
724 |
+
.Attr("padding", "VALID")
|
725 |
+
.Attr("data_format", "NHWC")
|
726 |
+
.Finalize(&op.node_def));
|
727 |
+
|
728 |
+
// Most of DepthwiseConv2D is implicitly tested by Conv2D, so
|
729 |
+
// we test only the very-specific differences here.
|
730 |
+
|
731 |
+
// 1x1 filter, depth multiplication
|
732 |
+
INFER_OK(op, "[1,2,2,3];[1,1,3,4]", "[d0_0,2,2,12]");
|
733 |
+
|
734 |
+
// Input depths not compatible
|
735 |
+
INFER_ERROR("Dimensions must be equal, but are 3 and 12", op,
|
736 |
+
"[1,2,2,3];[1,1,12,4]");
|
737 |
+
|
738 |
+
// No unknown dims in the critical fields.
|
739 |
+
INFER_OK(op, "[1,2,2,3];[1,1,3,4]", "[d0_0,2,2,12]");
|
740 |
+
INFER_OK(op, "[1,?,2,3];[1,1,3,4]", "[d0_0,?,2,12]");
|
741 |
+
INFER_OK(op, "[1,2,?,3];[1,1,3,4]", "[d0_0,2,?,12]");
|
742 |
+
INFER_OK(op, "[1,2,2,3];[?,1,3,4]", "[d0_0,?,2,12]");
|
743 |
+
INFER_OK(op, "[1,2,2,3];[1,?,3,4]", "[d0_0,2,?,12]");
|
744 |
+
INFER_OK(op, "[1,2,2,3];[1,1,?,4]", "[d0_0,2,2,12]");
|
745 |
+
INFER_OK(op, "[1,2,2,?];[1,1,?,4]", "[d0_0,2,2,?]");
|
746 |
+
INFER_OK(op, "[1,2,2,3];[1,1,3,?]", "[d0_0,2,2,?]");
|
747 |
+
|
748 |
+
// Test for NCHW format.
|
749 |
+
TF_CHECK_OK(NodeDefBuilder("test", "DepthwiseConv2dNative")
|
750 |
+
.Input("input", 0, DT_FLOAT)
|
751 |
+
.Input("filter", 0, DT_FLOAT)
|
752 |
+
.Attr("strides", strides)
|
753 |
+
.Attr("padding", "VALID")
|
754 |
+
.Attr("data_format", "NCHW")
|
755 |
+
.Finalize(&op.node_def));
|
756 |
+
|
757 |
+
// 1x1 filter, depth multiplication
|
758 |
+
INFER_OK(op, "[1,3,2,2];[1,1,3,4]", "[d0_0,12,2,2]");
|
759 |
+
}
|
760 |
+
|
761 |
+
TEST(CommonShapeFnsTest, AvgPool2DShapeTest) {
|
762 |
+
ShapeInferenceTestOp op("AvgPool");
|
763 |
+
auto set_op = [&op](const std::vector<int32>& strides,
|
764 |
+
const std::vector<int32>& ksizes, const string& padding,
|
765 |
+
const string& data_format) {
|
766 |
+
TF_CHECK_OK(NodeDefBuilder("test", "AvgPool")
|
767 |
+
.Input("input", 0, DT_FLOAT)
|
768 |
+
.Attr("strides", strides)
|
769 |
+
.Attr("ksize", ksizes)
|
770 |
+
.Attr("padding", padding)
|
771 |
+
.Attr("data_format", data_format)
|
772 |
+
.Finalize(&op.node_def));
|
773 |
+
};
|
774 |
+
|
775 |
+
// Most of the functionality is tested by conv-like shapes,
|
776 |
+
// so we check the very-specific avgpooling features here.
|
777 |
+
|
778 |
+
// 1x1 filter, 1x1 stride
|
779 |
+
set_op({1, 1, 1, 1}, {1, 1, 1, 1}, "VALID", "NHWC");
|
780 |
+
INFER_OK(op, "[1,2,2,1]", "[d0_0,2,2,d0_3]");
|
781 |
+
|
782 |
+
// 4x4 input, 2x1 ksize, 1x2 stride
|
783 |
+
set_op({1, 1, 2, 1}, {1, 2, 1, 1}, "VALID", "NHWC");
|
784 |
+
INFER_OK(op, "[1,4,4,1]", "[d0_0,3,2,d0_3]");
|
785 |
+
|
786 |
+
// 4x4 input, 2x1 ksize, 1x2 stride
|
787 |
+
// unknown dims in the critical fields lead to partial inference.
|
788 |
+
// Assumes NHWC format.
|
789 |
+
INFER_OK(op, "[1,?,4,1]", "[d0_0,?,2,d0_3]");
|
790 |
+
INFER_OK(op, "[1,4,?,1]", "[d0_0,3,?,d0_3]");
|
791 |
+
|
792 |
+
// 4x4 input, 2x1 ksize, 1x2 stride, NCHW format
|
793 |
+
set_op({{1, 1, 1, 2}}, {1, 1, 2, 1}, "VALID", "NCHW");
|
794 |
+
INFER_OK(op, "[1,1,4,4]", "[d0_0,d0_1,3,2]");
|
795 |
+
|
796 |
+
// 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C test
|
797 |
+
set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "VALID", "NCHW_VECT_C");
|
798 |
+
INFER_OK(op, "[2,3,5,7,4]", "[d0_0,d0_1,4,6,4]");
|
799 |
+
INFER_OK(op, "[5,7,?,?,4]", "[d0_0,d0_1,?,?,4]");
|
800 |
+
INFER_OK(op, "[?,?,?,?,4]", "[d0_0,d0_1,?,?,4]");
|
801 |
+
INFER_ERROR("Dimension must be 4 but is 3", op, "[2,5,7,11,3]");
|
802 |
+
|
803 |
+
// Invalid rank for input
|
804 |
+
INFER_ERROR("Shape must be rank", op, "[4,4]");
|
805 |
+
}
|
806 |
+
|
807 |
+
TEST(CommonShapeFnsTest, MaxPool2DShapeTest) {
|
808 |
+
ShapeInferenceTestOp op("MaxPool");
|
809 |
+
auto set_op = [&op](const std::vector<int32>& strides,
|
810 |
+
const std::vector<int32>& ksizes, const string& padding,
|
811 |
+
const string& data_format) {
|
812 |
+
TF_CHECK_OK(NodeDefBuilder("test", "MaxPool")
|
813 |
+
.Input("input", 0, DT_FLOAT)
|
814 |
+
.Attr("strides", strides)
|
815 |
+
.Attr("ksize", ksizes)
|
816 |
+
.Attr("padding", padding)
|
817 |
+
.Attr("data_format", data_format)
|
818 |
+
.Finalize(&op.node_def));
|
819 |
+
};
|
820 |
+
|
821 |
+
// Most of the functionality is tested by conv-like shapes,
|
822 |
+
// so we check the very-specific maxpooling features here,
|
823 |
+
// namely depthwise kernel and striding.
|
824 |
+
|
825 |
+
// all 1 strides, depth 2 filter
|
826 |
+
set_op({1, 1, 1, 1}, {1, 1, 1, 2}, "VALID", "NHWC");
|
827 |
+
INFER_OK(op, "[1,2,2,2]", "[d0_0,2,2,1]");
|
828 |
+
|
829 |
+
// depth 3 stride, 1x1x1 filter, NCHW
|
830 |
+
set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW");
|
831 |
+
INFER_OK(op, "[1,7,5,5]", "[d0_0,3,5,5]");
|
832 |
+
|
833 |
+
// 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C tests
|
834 |
+
set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "SAME", "NCHW_VECT_C");
|
835 |
+
INFER_OK(op, "[2,3,5,7,4]", "[d0_0,d0_1,d0_2,d0_3,4]");
|
836 |
+
INFER_OK(op, "[5,7,?,?,4]", "[d0_0,d0_1,d0_2,d0_3,4]");
|
837 |
+
INFER_OK(op, "[?,?,?,?,4]", "[d0_0,d0_1,d0_2,d0_3,4]");
|
838 |
+
INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8]");
|
839 |
+
}
|
840 |
+
|
841 |
+
TEST(CommonShapeFnsTest, MaxPoolV22DShapeTest) {
|
842 |
+
ShapeInferenceTestOp op("MaxPoolV2");
|
843 |
+
Tensor ksizes_tensor, strides_tensor;
|
844 |
+
auto set_op = [&op, &ksizes_tensor, &strides_tensor](
|
845 |
+
const std::vector<int32>& strides,
|
846 |
+
const std::vector<int32>& ksizes, const string& padding,
|
847 |
+
const string& data_format) {
|
848 |
+
TF_CHECK_OK(NodeDefBuilder("test", "MaxPoolV2")
|
849 |
+
.Input("input", 0, DT_FLOAT)
|
850 |
+
.Input("ksize", 1, DT_INT32)
|
851 |
+
.Input("strides", 2, DT_INT32)
|
852 |
+
.Attr("padding", padding)
|
853 |
+
.Attr("data_format", data_format)
|
854 |
+
.Finalize(&op.node_def));
|
855 |
+
ksizes_tensor = test::AsTensor<int32>(ksizes);
|
856 |
+
op.input_tensors.resize(3);
|
857 |
+
op.input_tensors[0] = nullptr;
|
858 |
+
op.input_tensors[1] = &ksizes_tensor;
|
859 |
+
strides_tensor = test::AsTensor<int32>(strides);
|
860 |
+
op.input_tensors[2] = &strides_tensor;
|
861 |
+
};
|
862 |
+
|
863 |
+
// Most of the functionality is tested by conv-like shapes,
|
864 |
+
// so we check the very-specific maxpooling features here,
|
865 |
+
// namely depthwise kernel and striding.
|
866 |
+
|
867 |
+
// all 1 strides, depth 2 filter
|
868 |
+
set_op({1, 1, 1, 1}, {1, 1, 1, 2}, "VALID", "NHWC");
|
869 |
+
INFER_OK(op, "[1,2,2,2];[4];[4]", "[d0_0,2,2,1]");
|
870 |
+
|
871 |
+
// depth 3 stride, 1x1x1 filter, NCHW
|
872 |
+
set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW");
|
873 |
+
INFER_OK(op, "[1,7,5,5];[4];[4]", "[d0_0,3,5,5]");
|
874 |
+
|
875 |
+
// 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C tests
|
876 |
+
set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "SAME", "NCHW_VECT_C");
|
877 |
+
INFER_OK(op, "[2,3,5,7,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]");
|
878 |
+
INFER_OK(op, "[5,7,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]");
|
879 |
+
INFER_OK(op, "[?,?,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]");
|
880 |
+
INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8];[4];[4]");
|
881 |
+
}
|
882 |
+
|
883 |
+
TEST(CommonShapeFnsTest, Pool3DShapeTest) {
|
884 |
+
ShapeInferenceTestOp op("MaxPool3D");
|
885 |
+
auto set_op = [&op](const std::vector<int32>& strides,
|
886 |
+
const std::vector<int32>& ksizes, const string& padding) {
|
887 |
+
TF_CHECK_OK(NodeDefBuilder("test", "MaxPool3D")
|
888 |
+
.Input("input", 0, DT_FLOAT)
|
889 |
+
.Attr("strides", strides)
|
890 |
+
.Attr("ksize", ksizes)
|
891 |
+
.Attr("padding", padding)
|
892 |
+
.Finalize(&op.node_def));
|
893 |
+
};
|
894 |
+
|
895 |
+
// Most of the functionality is tested by conv-like shapes,
|
896 |
+
// so we check that we handle the extra dimension properly.
|
897 |
+
|
898 |
+
// 2x3x4 stride, 1x1x1 filter.
|
899 |
+
set_op({1, 2, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID");
|
900 |
+
INFER_OK(op, "[1,24,24,24,1]", "[d0_0,12,8,6,d0_4]");
|
901 |
+
|
902 |
+
// Test partially known dimensions
|
903 |
+
set_op({1, 1, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID");
|
904 |
+
INFER_OK(op, "[1,?,24,24,1]", "[d0_0,?,8,6,d0_4]");
|
905 |
+
}
|
906 |
+
|
907 |
+
TEST(CommonShapeFnsTest, UnknownShapeTest) {
|
908 |
+
{
|
909 |
+
// Single output
|
910 |
+
ShapeInferenceTestOp op("QueueDequeue");
|
911 |
+
TF_CHECK_OK(NodeDefBuilder("test", "QueueDequeue")
|
912 |
+
.Input("handle", 0, DT_STRING_REF)
|
913 |
+
.Attr("component_types", {DT_FLOAT})
|
914 |
+
.Finalize(&op.node_def));
|
915 |
+
INFER_OK(op, "[1]", "?");
|
916 |
+
}
|
917 |
+
|
918 |
+
{
|
919 |
+
// Multiple outputs
|
920 |
+
ShapeInferenceTestOp op("QueueDequeue");
|
921 |
+
TF_CHECK_OK(NodeDefBuilder("test", "QueueDequeue")
|
922 |
+
.Input("handle", 0, DT_STRING_REF)
|
923 |
+
.Attr("component_types", {DT_FLOAT, DT_FLOAT, DT_STRING})
|
924 |
+
.Finalize(&op.node_def));
|
925 |
+
INFER_OK(op, "[1]", "?;?;?");
|
926 |
+
}
|
927 |
+
}
|
928 |
+
|
929 |
+
TEST(CommonShapeFnsTest, Reduce_ShapeFn) {
|
930 |
+
ShapeInferenceTestOp op("Sum");
|
931 |
+
op.input_tensors.resize(2);
|
932 |
+
|
933 |
+
TF_ASSERT_OK(NodeDefBuilder("test", "Sum")
|
934 |
+
.Input("input", 0, DT_FLOAT)
|
935 |
+
.Input("reduction_indices", 1, DT_INT32)
|
936 |
+
.Attr("keep_dims", false)
|
937 |
+
.Finalize(&op.node_def));
|
938 |
+
|
939 |
+
// Reduction indices not available, so output is unknown.
|
940 |
+
INFER_OK(op, "[2,4,5];[2]", "?");
|
941 |
+
INFER_OK(op, "?;[2]", "?");
|
942 |
+
|
943 |
+
Tensor indices = test::AsTensor<int32>({1, 2});
|
944 |
+
op.input_tensors[1] = &indices;
|
945 |
+
|
946 |
+
// Reduction indices available
|
947 |
+
INFER_OK(op, "[2,4,5];[2]", "[d0_0]");
|
948 |
+
|
949 |
+
// Wrapped indices
|
950 |
+
indices = test::AsTensor<int32>({-1, -2});
|
951 |
+
op.input_tensors[1] = &indices;
|
952 |
+
INFER_OK(op, "[2,4,5];[2]", "[d0_0]");
|
953 |
+
|
954 |
+
// Scalar
|
955 |
+
indices = test::AsScalar<int32>(0);
|
956 |
+
op.input_tensors[1] = &indices;
|
957 |
+
INFER_OK(op, "[2,4,5];[]", "[d0_1,d0_2]");
|
958 |
+
|
959 |
+
indices = test::AsScalar<int32>(-4);
|
960 |
+
op.input_tensors[1] = &indices;
|
961 |
+
INFER_ERROR("Invalid reduction dimension", op, "[2,4,5];[]");
|
962 |
+
|
963 |
+
// Empty reduction indices
|
964 |
+
indices = test::AsTensor<int32>({});
|
965 |
+
op.input_tensors[1] = &indices;
|
966 |
+
INFER_OK(op, "[2,4,5];[0]", "[d0_0,d0_1,d0_2]");
|
967 |
+
|
968 |
+
// Keep dims = true
|
969 |
+
TF_ASSERT_OK(NodeDefBuilder("test", "Sum")
|
970 |
+
.Input("input", 0, DT_FLOAT)
|
971 |
+
.Input("reduction_indices", 1, DT_INT32)
|
972 |
+
.Attr("keep_dims", true)
|
973 |
+
.Finalize(&op.node_def));
|
974 |
+
indices = test::AsTensor<int32>({-1, -2});
|
975 |
+
op.input_tensors[1] = &indices;
|
976 |
+
INFER_OK(op, "[2,4,5];[2]", "[d0_0, 1, 1]");
|
977 |
+
|
978 |
+
// input rank is known, but reduction indices are not (with keep_dim=true).
|
979 |
+
// The output rank matches input rank (because of keep_dims=true).
|
980 |
+
op.input_tensors[1] = nullptr;
|
981 |
+
INFER_OK(op, "[?,?,?];?", "[?,?,?]");
|
982 |
+
INFER_OK(op, "[?,?,?];[2]", "[?,?,?]");
|
983 |
+
|
984 |
+
// Reduction indices with too many dimensions.
|
985 |
+
INFER_ERROR("must be at most rank 1 but is rank 2", op, "[?,?,?];[?,?]");
|
986 |
+
// With older graph-def version, this is allowed.
|
987 |
+
op.graph_def_version = 20;
|
988 |
+
INFER_OK(op, "[?,?,?];[?,?]", "[?,?,?]");
|
989 |
+
// And when the tensor is specified, it's still allowed.
|
990 |
+
op.input_tensors[1] = &indices;
|
991 |
+
indices = test::AsTensor<int32>({-1, -2}, TensorShape({2, 1}));
|
992 |
+
INFER_OK(op, "[2,4,5];[2,1]", "[d0_0, 1, 1]");
|
993 |
+
indices = test::AsTensor<int32>({-1, -2}, TensorShape({1, 2}));
|
994 |
+
INFER_OK(op, "[2,4,5];[1,2]", "[d0_0, 1, 1]");
|
995 |
+
}
|
996 |
+
|
997 |
+
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
|
998 |
+
NodeDef def;
|
999 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
1000 |
+
{Unknown(), Unknown(), Unknown()}, {}, {}, {});
|
1001 |
+
EXPECT_EQ(3, c.num_inputs());
|
1002 |
+
EXPECT_EQ(1, c.num_outputs());
|
1003 |
+
|
1004 |
+
auto indices = c.input(0);
|
1005 |
+
auto values = c.input(1);
|
1006 |
+
auto shape = c.input(2);
|
1007 |
+
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
1008 |
+
}
|
1009 |
+
|
1010 |
+
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
|
1011 |
+
NodeDef def;
|
1012 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
1013 |
+
{S({-1, -1}), S({-1}), S({-1})}, {}, {}, {});
|
1014 |
+
EXPECT_EQ(3, c.num_inputs());
|
1015 |
+
EXPECT_EQ(1, c.num_outputs());
|
1016 |
+
|
1017 |
+
auto indices = c.input(0);
|
1018 |
+
auto values = c.input(1);
|
1019 |
+
auto shape = c.input(2);
|
1020 |
+
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
1021 |
+
}
|
1022 |
+
|
1023 |
+
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
|
1024 |
+
NodeDef def;
|
1025 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
1026 |
+
{S({-1}), S({-1}), S({-1})}, {}, {}, {});
|
1027 |
+
EXPECT_EQ(3, c.num_inputs());
|
1028 |
+
EXPECT_EQ(1, c.num_outputs());
|
1029 |
+
|
1030 |
+
auto indices = c.input(0);
|
1031 |
+
auto values = c.input(1);
|
1032 |
+
auto shape = c.input(2);
|
1033 |
+
EXPECT_EQ(error::INVALID_ARGUMENT,
|
1034 |
+
ValidateSparseTensor(&c, indices, values, shape).code());
|
1035 |
+
}
|
1036 |
+
|
1037 |
+
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
|
1038 |
+
NodeDef def;
|
1039 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
1040 |
+
{S({5, 3}), S({4}), S({3})}, {}, {}, {});
|
1041 |
+
EXPECT_EQ(3, c.num_inputs());
|
1042 |
+
EXPECT_EQ(1, c.num_outputs());
|
1043 |
+
|
1044 |
+
auto indices = c.input(0);
|
1045 |
+
auto values = c.input(1);
|
1046 |
+
auto shape = c.input(2);
|
1047 |
+
EXPECT_EQ(error::INVALID_ARGUMENT,
|
1048 |
+
ValidateSparseTensor(&c, indices, values, shape).code());
|
1049 |
+
}
|
1050 |
+
|
1051 |
+
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
|
1052 |
+
NodeDef def;
|
1053 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
1054 |
+
{S({5, 3}), S({5}), S({4})}, {}, {}, {});
|
1055 |
+
EXPECT_EQ(3, c.num_inputs());
|
1056 |
+
EXPECT_EQ(1, c.num_outputs());
|
1057 |
+
|
1058 |
+
auto indices = c.input(0);
|
1059 |
+
auto values = c.input(1);
|
1060 |
+
auto shape = c.input(2);
|
1061 |
+
EXPECT_EQ(error::INVALID_ARGUMENT,
|
1062 |
+
ValidateSparseTensor(&c, indices, values, shape).code());
|
1063 |
+
}
|
1064 |
+
|
1065 |
+
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
1066 |
+
NodeDef def;
|
1067 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
1068 |
+
{S({-1, 3}), S({5}), S({3})}, {}, {}, {});
|
1069 |
+
EXPECT_EQ(3, c.num_inputs());
|
1070 |
+
EXPECT_EQ(1, c.num_outputs());
|
1071 |
+
|
1072 |
+
auto indices = c.input(0);
|
1073 |
+
auto values = c.input(1);
|
1074 |
+
auto shape = c.input(2);
|
1075 |
+
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
1076 |
+
}
|
1077 |
+
|
1078 |
+
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
|
1079 |
+
NodeDef def;
|
1080 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
1081 |
+
{S({5, 3}), S({-1}), S({3})}, {}, {}, {});
|
1082 |
+
EXPECT_EQ(3, c.num_inputs());
|
1083 |
+
EXPECT_EQ(1, c.num_outputs());
|
1084 |
+
|
1085 |
+
auto indices = c.input(0);
|
1086 |
+
auto values = c.input(1);
|
1087 |
+
auto shape = c.input(2);
|
1088 |
+
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
1089 |
+
}
|
1090 |
+
|
1091 |
+
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
|
1092 |
+
NodeDef def;
|
1093 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
1094 |
+
{S({5, -1}), S({5}), S({3})}, {}, {}, {});
|
1095 |
+
EXPECT_EQ(3, c.num_inputs());
|
1096 |
+
EXPECT_EQ(1, c.num_outputs());
|
1097 |
+
|
1098 |
+
auto indices = c.input(0);
|
1099 |
+
auto values = c.input(1);
|
1100 |
+
auto shape = c.input(2);
|
1101 |
+
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
1102 |
+
}
|
1103 |
+
|
1104 |
+
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
|
1105 |
+
NodeDef def;
|
1106 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
1107 |
+
{S({5, 3}), S({5}), S({-1})}, {}, {}, {});
|
1108 |
+
EXPECT_EQ(3, c.num_inputs());
|
1109 |
+
EXPECT_EQ(1, c.num_outputs());
|
1110 |
+
|
1111 |
+
auto indices = c.input(0);
|
1112 |
+
auto values = c.input(1);
|
1113 |
+
auto shape = c.input(2);
|
1114 |
+
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
1115 |
+
}
|
1116 |
+
|
1117 |
+
TEST(CommonShapeFnsTest, ValidateSparseTensor) {
|
1118 |
+
NodeDef def;
|
1119 |
+
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
1120 |
+
{S({5, 3}), S({5}), S({3})}, {}, {}, {});
|
1121 |
+
EXPECT_EQ(3, c.num_inputs());
|
1122 |
+
EXPECT_EQ(1, c.num_outputs());
|
1123 |
+
|
1124 |
+
auto indices = c.input(0);
|
1125 |
+
auto values = c.input(1);
|
1126 |
+
auto shape = c.input(2);
|
1127 |
+
TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
|
1128 |
+
}
|
1129 |
+
|
1130 |
+
} // namespace shape_inference
|
1131 |
+
} // namespace tensorflow
|
control_flow.h
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
|
18 |
+
|
19 |
+
#include "tensorflow/core/lib/hash/hash.h"
|
20 |
+
#include "tensorflow/core/platform/logging.h"
|
21 |
+
#include "tensorflow/core/platform/types.h"
|
22 |
+
|
23 |
+
namespace tensorflow {
|
24 |
+
|
25 |
+
const uint64 kIllegalFrameId = ~0uLL;
|
26 |
+
const int64 kIllegalIterId = -1;
|
27 |
+
|
28 |
+
// For the purpose of control flow, every tensor produced by TensorFlow is
|
29 |
+
// conceptually tagged by a 'FrameAndIter'. FrameAndIter consists of a
|
30 |
+
// 'frame_id' and an 'iter_id'. The tensor value it represents is produced
|
31 |
+
// in the frame with frame_id at the iteration of iter_id.
|
32 |
+
struct FrameAndIter {
|
33 |
+
uint64 frame_id = kIllegalFrameId;
|
34 |
+
int64 iter_id = kIllegalIterId;
|
35 |
+
|
36 |
+
FrameAndIter() {}
|
37 |
+
|
38 |
+
FrameAndIter(uint64 frame, int64 iter) {
|
39 |
+
frame_id = frame;
|
40 |
+
iter_id = iter;
|
41 |
+
}
|
42 |
+
|
43 |
+
bool operator==(const FrameAndIter& other) const {
|
44 |
+
return (frame_id == other.frame_id && iter_id == other.iter_id);
|
45 |
+
}
|
46 |
+
};
|
47 |
+
|
48 |
+
struct FrameAndIterHash {
|
49 |
+
size_t operator()(const FrameAndIter& key) const {
|
50 |
+
// Make sure there are no padding bytes that we don't want
|
51 |
+
CHECK_EQ(sizeof(uint64) + sizeof(int64), sizeof(FrameAndIter));
|
52 |
+
return Hash64(reinterpret_cast<const char*>(&key), sizeof(FrameAndIter));
|
53 |
+
}
|
54 |
+
};
|
55 |
+
|
56 |
+
} // namespace tensorflow
|
57 |
+
|
58 |
+
#endif // TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
|
cost_graph.proto
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
syntax = "proto3";
|
2 |
+
|
3 |
+
package tensorflow;
|
4 |
+
option cc_enable_arenas = true;
|
5 |
+
option java_outer_classname = "CostGraphProtos";
|
6 |
+
option java_multiple_files = true;
|
7 |
+
option java_package = "org.tensorflow.framework";
|
8 |
+
|
9 |
+
import "tensorflow/core/framework/tensor_shape.proto";
|
10 |
+
import "tensorflow/core/framework/types.proto";
|
11 |
+
|
12 |
+
message CostGraphDef {
|
13 |
+
message Node {
|
14 |
+
// The name of the node. Names are globally unique.
|
15 |
+
string name = 1;
|
16 |
+
|
17 |
+
// The device of the node. Can be empty if the node is mapped to the
|
18 |
+
// default partition or partitioning hasn't been run yet.
|
19 |
+
string device = 2;
|
20 |
+
|
21 |
+
// The id of the node. Node ids are only unique inside a partition.
|
22 |
+
int32 id = 3;
|
23 |
+
|
24 |
+
// Inputs of this node. They must be executed before this node can be
|
25 |
+
// executed. An input is a particular output of another node, specified
|
26 |
+
// by the node id and the output index.
|
27 |
+
message InputInfo {
|
28 |
+
int32 preceding_node = 1;
|
29 |
+
int32 preceding_port = 2;
|
30 |
+
}
|
31 |
+
repeated InputInfo input_info = 4;
|
32 |
+
|
33 |
+
// Outputs of this node.
|
34 |
+
message OutputInfo {
|
35 |
+
int64 size = 1;
|
36 |
+
// If >= 0, the output is an alias of an input. Note that an alias input
|
37 |
+
// may itself be an alias. The algorithm will therefore need to follow
|
38 |
+
// those pointers.
|
39 |
+
int64 alias_input_port = 2;
|
40 |
+
TensorShapeProto shape = 3;
|
41 |
+
DataType dtype = 4;
|
42 |
+
}
|
43 |
+
repeated OutputInfo output_info = 5;
|
44 |
+
|
45 |
+
// Temporary memory used by this node.
|
46 |
+
int64 temporary_memory_size = 6;
|
47 |
+
|
48 |
+
int64 host_temp_memory_size = 10;
|
49 |
+
int64 device_temp_memory_size = 11;
|
50 |
+
int64 host_persistent_memory_size = 12;
|
51 |
+
int64 device_persistent_memory_size = 16;
|
52 |
+
|
53 |
+
// Estimate of the computational cost of this node, in microseconds.
|
54 |
+
int64 compute_cost = 9;
|
55 |
+
|
56 |
+
// Analytical estimate of the computational cost of this node, in
|
57 |
+
// microseconds.
|
58 |
+
int64 compute_time = 14;
|
59 |
+
|
60 |
+
// Analytical estimate of the memory access cost of this node, in
|
61 |
+
// microseconds.
|
62 |
+
int64 memory_time = 15;
|
63 |
+
|
64 |
+
// If true, the output is permanent: it can't be discarded, because this
|
65 |
+
// node is part of the "final output". Nodes may depend on final nodes.
|
66 |
+
bool is_final = 7;
|
67 |
+
|
68 |
+
// Ids of the control inputs for this node.
|
69 |
+
repeated int32 control_input = 8;
|
70 |
+
}
|
71 |
+
repeated Node node = 1;
|
72 |
+
}
|
device_attributes.proto
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
syntax = "proto3";
|
2 |
+
|
3 |
+
package tensorflow;
|
4 |
+
option cc_enable_arenas = true;
|
5 |
+
option java_outer_classname = "DeviceAttributesProtos";
|
6 |
+
option java_multiple_files = true;
|
7 |
+
option java_package = "org.tensorflow.framework";
|
8 |
+
|
9 |
+
message DeviceLocality {
|
10 |
+
// Optional bus locality of device. Default value of 0 means
|
11 |
+
// no specific locality. Specific localities are indexed from 1.
|
12 |
+
int32 bus_id = 1;
|
13 |
+
};
|
14 |
+
|
15 |
+
message DeviceAttributes {
|
16 |
+
// Fully specified name of the device within a cluster.
|
17 |
+
string name = 1;
|
18 |
+
|
19 |
+
// String representation of device_type.
|
20 |
+
string device_type = 2;
|
21 |
+
|
22 |
+
// Memory capacity of device in bytes.
|
23 |
+
int64 memory_limit = 4;
|
24 |
+
|
25 |
+
// Platform-specific data about device that may be useful
|
26 |
+
// for supporting efficient data transfers.
|
27 |
+
DeviceLocality locality = 5;
|
28 |
+
|
29 |
+
// A device is assigned a global unique number each time it is
|
30 |
+
// initialized. "incarnation" should never be 0.
|
31 |
+
fixed64 incarnation = 6;
|
32 |
+
|
33 |
+
// String representation of the physical device that this device maps to.
|
34 |
+
string physical_device_desc = 7;
|
35 |
+
}
|
device_base.cc
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/device_base.h"
|
17 |
+
|
18 |
+
namespace tensorflow {
|
19 |
+
|
20 |
+
DeviceBase::~DeviceBase() {}
|
21 |
+
|
22 |
+
const DeviceAttributes& DeviceBase::attributes() const {
|
23 |
+
LOG(FATAL) << "Device does not implement attributes()";
|
24 |
+
}
|
25 |
+
|
26 |
+
const string& DeviceBase::name() const {
|
27 |
+
LOG(FATAL) << "Device does not implement name()";
|
28 |
+
}
|
29 |
+
|
30 |
+
} // namespace tensorflow
|
device_base.h
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
|
18 |
+
|
19 |
+
#include <memory>
|
20 |
+
#include <string>
|
21 |
+
#include <unordered_map>
|
22 |
+
|
23 |
+
#include "tensorflow/core/framework/tensor.h"
|
24 |
+
#include "tensorflow/core/lib/core/errors.h"
|
25 |
+
#include "tensorflow/core/lib/core/refcount.h"
|
26 |
+
#include "tensorflow/core/lib/core/status.h"
|
27 |
+
#include "tensorflow/core/lib/core/stringpiece.h"
|
28 |
+
#include "tensorflow/core/platform/logging.h"
|
29 |
+
|
30 |
+
namespace Eigen {
|
31 |
+
struct ThreadPoolDevice;
|
32 |
+
#ifdef TENSORFLOW_USE_SYCL
|
33 |
+
struct SyclDevice;
|
34 |
+
#endif
|
35 |
+
} // end namespace Eigen
|
36 |
+
|
37 |
+
namespace perftools {
|
38 |
+
namespace gputools {
|
39 |
+
class Stream;
|
40 |
+
} // namespace gputools
|
41 |
+
} // namespace perftools
|
42 |
+
|
43 |
+
namespace tensorflow {
|
44 |
+
|
45 |
+
class Device;
|
46 |
+
class DeviceAttributes;
|
47 |
+
class Env;
|
48 |
+
class EventMgr;
|
49 |
+
class OpKernelContext;
|
50 |
+
class ResourceMgr;
|
51 |
+
class TensorProto;
|
52 |
+
|
53 |
+
namespace thread {
|
54 |
+
class ThreadPool;
|
55 |
+
}
|
56 |
+
|
57 |
+
// A wrapper for an Eigen Gpu Device that includes per-op state. The
|
58 |
+
// class is defined even for non-GPU devices since the
|
59 |
+
// OpKernelContext::Params structure wants to fill it in.
|
60 |
+
class PerOpGpuDevice {
|
61 |
+
public:
|
62 |
+
virtual ~PerOpGpuDevice() {}
|
63 |
+
virtual const Eigen::GpuDevice& device() const = 0;
|
64 |
+
};
|
65 |
+
|
66 |
+
// A class that devices can subclass to pass around
|
67 |
+
// Device-specific context to OpKernels.
|
68 |
+
class DeviceContext : public core::RefCounted {
|
69 |
+
public:
|
70 |
+
~DeviceContext() override {}
|
71 |
+
virtual perftools::gputools::Stream* stream() const { return nullptr; }
|
72 |
+
virtual void MaintainLifetimeOnStream(
|
73 |
+
const Tensor* t, perftools::gputools::Stream* stream) const {}
|
74 |
+
|
75 |
+
// "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into
|
76 |
+
// "device_tensor" which is on a GPU device "device". "device_tensor"
|
77 |
+
// must be allocated to be of the same size as "cpu_tensor".
|
78 |
+
virtual void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
79 |
+
Tensor* device_tensor,
|
80 |
+
StatusCallback done) const {
|
81 |
+
done(errors::Internal("Unrecognized device type in CPU-to-device Copy"));
|
82 |
+
}
|
83 |
+
|
84 |
+
// "device_tensor" is a tensor on a non-CPU device. Copies
|
85 |
+
// device_tensor into "cpu_tensor". "cpu_tensor" must be allocated
|
86 |
+
// to be of the same size as "device_tensor".
|
87 |
+
virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
88 |
+
StringPiece tensor_name, Device* device,
|
89 |
+
Tensor* cpu_tensor, StatusCallback done) {
|
90 |
+
done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
|
91 |
+
}
|
92 |
+
};
|
93 |
+
|
94 |
+
// map[i] is the DeviceContext* for the node with id i, if i < map.size().
|
95 |
+
typedef std::vector<DeviceContext*> DeviceContextMap;
|
96 |
+
|
97 |
+
class DeviceBase {
|
98 |
+
public:
|
99 |
+
explicit DeviceBase(Env* env) : env_(env) {}
|
100 |
+
virtual ~DeviceBase();
|
101 |
+
|
102 |
+
Env* env() const { return env_; }
|
103 |
+
|
104 |
+
// Override this to return true for devices that require an Op's
|
105 |
+
// compute method to save references to the temporary tensors it
|
106 |
+
// allocates until the Op execution completes
|
107 |
+
virtual bool RequiresRecordingAccessedTensors() const { return false; }
|
108 |
+
|
109 |
+
struct CpuWorkerThreads {
|
110 |
+
int num_threads = 0;
|
111 |
+
thread::ThreadPool* workers = nullptr;
|
112 |
+
};
|
113 |
+
|
114 |
+
// Does not take ownership.
|
115 |
+
void set_tensorflow_cpu_worker_threads(CpuWorkerThreads* t) {
|
116 |
+
cpu_worker_threads_ = t;
|
117 |
+
}
|
118 |
+
|
119 |
+
virtual const CpuWorkerThreads* tensorflow_cpu_worker_threads() const {
|
120 |
+
CHECK(cpu_worker_threads_ != nullptr);
|
121 |
+
return cpu_worker_threads_;
|
122 |
+
}
|
123 |
+
|
124 |
+
// "stream" is used in special circumstances (such as the
|
125 |
+
// constructors of Ops) where there is no available OpKernelContext.
|
126 |
+
// "default_context" is used by OpKernelContext whenever a device does not
|
127 |
+
// supply a DeviceContext for an op in FillContextMap (e.g. when only
|
128 |
+
// using a single stream.)
|
129 |
+
// "event_mgr" is used to delay deallocation of temporary GPU buffers.
|
130 |
+
// TODO(pbar) Work out how to move this out of DeviceBase.
|
131 |
+
struct GpuDeviceInfo {
|
132 |
+
// Make sure all the defaults are NULL, so we can spot missing assignments.
|
133 |
+
perftools::gputools::Stream* stream = nullptr;
|
134 |
+
DeviceContext* default_context = nullptr;
|
135 |
+
EventMgr* event_mgr = nullptr;
|
136 |
+
int gpu_id = -1;
|
137 |
+
};
|
138 |
+
|
139 |
+
// Does not take ownership.
|
140 |
+
void set_tensorflow_gpu_device_info(GpuDeviceInfo* g) {
|
141 |
+
gpu_device_info_ = g;
|
142 |
+
}
|
143 |
+
|
144 |
+
virtual const GpuDeviceInfo* tensorflow_gpu_device_info() const {
|
145 |
+
return gpu_device_info_;
|
146 |
+
}
|
147 |
+
|
148 |
+
// The preferred thread pool for this device. If it is nullptr, the system
|
149 |
+
// automatically assigns a thread pool for execution.
|
150 |
+
virtual thread::ThreadPool* tensorflow_device_thread_pool() {
|
151 |
+
return device_thread_pool_;
|
152 |
+
}
|
153 |
+
|
154 |
+
// Does not take ownership.
|
155 |
+
void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) {
|
156 |
+
eigen_cpu_device_ = d;
|
157 |
+
}
|
158 |
+
|
159 |
+
#ifdef TENSORFLOW_USE_SYCL
|
160 |
+
void set_eigen_sycl_device(Eigen::SyclDevice* d) { eigen_sycl_device_ = d; }
|
161 |
+
#endif
|
162 |
+
|
163 |
+
// Return the Allocator implementation to use based on the allocator
|
164 |
+
// attributes requested. See allocator.h for more details.
|
165 |
+
virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {
|
166 |
+
LOG(FATAL) << "GetAllocator() is not implemented.";
|
167 |
+
return nullptr;
|
168 |
+
}
|
169 |
+
|
170 |
+
// Return the Allocator implementation to use based on the allocator
|
171 |
+
// attributes requested and the supplied resource manager. By
|
172 |
+
// default this ignores the resource manager and calls the base
|
173 |
+
// implementation but devices can override if they want to consult
|
174 |
+
// the resource manager when choosing the allocator.
|
175 |
+
virtual Allocator* GetStepAllocator(AllocatorAttributes attr,
|
176 |
+
ResourceMgr* /*step_resource_manager*/) {
|
177 |
+
return GetAllocator(attr);
|
178 |
+
}
|
179 |
+
|
180 |
+
virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() {
|
181 |
+
CHECK(eigen_cpu_device_ != nullptr);
|
182 |
+
return eigen_cpu_device_;
|
183 |
+
}
|
184 |
+
|
185 |
+
#ifdef TENSORFLOW_USE_SYCL
|
186 |
+
virtual const Eigen::SyclDevice* eigen_sycl_device() const {
|
187 |
+
CHECK(eigen_sycl_device_ != nullptr);
|
188 |
+
return eigen_sycl_device_;
|
189 |
+
}
|
190 |
+
#endif
|
191 |
+
|
192 |
+
// Caller owns the return value. The OpKernelContext calls this even
|
193 |
+
// for devices that do not implement an eigen_gpu_device. Overridden
|
194 |
+
// by GPU devices to return a derived type.
|
195 |
+
virtual PerOpGpuDevice* MakeGpuDevice() { return nullptr; }
|
196 |
+
|
197 |
+
virtual DeviceBase* UnderlyingDevice() { return this; }
|
198 |
+
virtual const DeviceBase* UnderlyingDevice() const { return this; }
|
199 |
+
|
200 |
+
// This is overridden by GPU devices to reinitialize the derived
|
201 |
+
// type returned by MakeGpuDevice.
|
202 |
+
virtual void ReinitializeGpuDevice(OpKernelContext* /*context*/,
|
203 |
+
PerOpGpuDevice* /*device*/,
|
204 |
+
DeviceContext* /*dc*/,
|
205 |
+
Allocator* /*allocator*/) {}
|
206 |
+
|
207 |
+
// Unimplemented by default
|
208 |
+
virtual const DeviceAttributes& attributes() const;
|
209 |
+
virtual const string& name() const;
|
210 |
+
|
211 |
+
// Materializes the given TensorProto into 'tensor' stored in Device
|
212 |
+
// memory. Most devices will want to override this.
|
213 |
+
//
|
214 |
+
// TODO(vrv): We should be able to put this function into
|
215 |
+
// OpKernelContext and handle the copies from device memory via send
|
216 |
+
// and receive nodes, instead of requiring that each device handle
|
217 |
+
// the copies here as well as in copy ops.
|
218 |
+
virtual Status MakeTensorFromProto(const TensorProto& tensor_proto,
|
219 |
+
const AllocatorAttributes alloc_attrs,
|
220 |
+
Tensor* tensor) {
|
221 |
+
return errors::Internal("Device does not implement MakeTensorFromProto()");
|
222 |
+
}
|
223 |
+
|
224 |
+
protected:
|
225 |
+
// Does not take ownership.
|
226 |
+
void set_tensorflow_device_thread_pool(thread::ThreadPool* thread_pool) {
|
227 |
+
device_thread_pool_ = thread_pool;
|
228 |
+
}
|
229 |
+
|
230 |
+
private:
|
231 |
+
Env* const env_;
|
232 |
+
CpuWorkerThreads* cpu_worker_threads_ = nullptr;
|
233 |
+
GpuDeviceInfo* gpu_device_info_ = nullptr;
|
234 |
+
thread::ThreadPool* device_thread_pool_ = nullptr;
|
235 |
+
Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr;
|
236 |
+
#ifdef TENSORFLOW_USE_SYCL
|
237 |
+
Eigen::SyclDevice* eigen_sycl_device_ = nullptr;
|
238 |
+
#endif
|
239 |
+
};
|
240 |
+
|
241 |
+
} // namespace tensorflow
|
242 |
+
|
243 |
+
#endif // TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
|
fake_input.cc
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/fake_input.h"
|
17 |
+
|
18 |
+
#include <vector>
|
19 |
+
#include "tensorflow/core/framework/attr_value.pb.h"
|
20 |
+
#include "tensorflow/core/framework/node_def_util.h"
|
21 |
+
#include "tensorflow/core/framework/op_def.pb.h"
|
22 |
+
#include "tensorflow/core/framework/op_def_util.h"
|
23 |
+
#include "tensorflow/core/lib/core/errors.h"
|
24 |
+
#include "tensorflow/core/lib/core/status.h"
|
25 |
+
|
26 |
+
namespace tensorflow {
|
27 |
+
namespace {
|
28 |
+
|
29 |
+
class FakeInputImpl {
|
30 |
+
public:
|
31 |
+
FakeInputImpl(const OpDef* op_def, int in_index, const NodeDef* node_def,
|
32 |
+
NodeDefBuilder* builder);
|
33 |
+
void SetN(int n);
|
34 |
+
void SetDataType(DataType dt);
|
35 |
+
void SetTypeList(DataTypeSlice dts);
|
36 |
+
Status AddInputToBuilder();
|
37 |
+
|
38 |
+
private:
|
39 |
+
static string FakeNodeName(int in_index);
|
40 |
+
Status GetN(int* n) const;
|
41 |
+
Status GetDataType(DataType* dt) const;
|
42 |
+
void NSources(int n, DataType dt) const;
|
43 |
+
void SourceList(DataTypeSlice dts) const;
|
44 |
+
|
45 |
+
const OpDef* const op_def_;
|
46 |
+
const OpDef::ArgDef* const arg_;
|
47 |
+
const string in_node_;
|
48 |
+
const NodeDef* const node_def_;
|
49 |
+
NodeDefBuilder* const builder_;
|
50 |
+
|
51 |
+
bool n_specified_;
|
52 |
+
int n_;
|
53 |
+
bool dt_specified_;
|
54 |
+
DataType dt_;
|
55 |
+
bool dts_specified_;
|
56 |
+
DataTypeSlice dts_;
|
57 |
+
};
|
58 |
+
|
59 |
+
FakeInputImpl::FakeInputImpl(const OpDef* op_def, int in_index,
|
60 |
+
const NodeDef* node_def, NodeDefBuilder* builder)
|
61 |
+
: op_def_(op_def),
|
62 |
+
arg_(&op_def->input_arg(in_index)),
|
63 |
+
in_node_(FakeNodeName(in_index)),
|
64 |
+
node_def_(node_def),
|
65 |
+
builder_(builder),
|
66 |
+
n_specified_(false),
|
67 |
+
dt_specified_(false),
|
68 |
+
dts_specified_(false) {}
|
69 |
+
|
70 |
+
void FakeInputImpl::SetN(int n) {
|
71 |
+
n_specified_ = true;
|
72 |
+
n_ = n;
|
73 |
+
}
|
74 |
+
|
75 |
+
void FakeInputImpl::SetDataType(DataType dt) {
|
76 |
+
dt_specified_ = true;
|
77 |
+
dt_ = dt;
|
78 |
+
}
|
79 |
+
|
80 |
+
void FakeInputImpl::SetTypeList(DataTypeSlice dts) {
|
81 |
+
dts_specified_ = true;
|
82 |
+
dts_ = dts;
|
83 |
+
}
|
84 |
+
|
85 |
+
Status FakeInputImpl::AddInputToBuilder() {
|
86 |
+
if (dts_specified_) {
|
87 |
+
SourceList(dts_);
|
88 |
+
|
89 |
+
} else if (n_specified_ || !arg_->number_attr().empty()) {
|
90 |
+
int n;
|
91 |
+
TF_RETURN_IF_ERROR(GetN(&n));
|
92 |
+
|
93 |
+
DataType dt;
|
94 |
+
if (n > 0) {
|
95 |
+
TF_RETURN_IF_ERROR(GetDataType(&dt));
|
96 |
+
} else {
|
97 |
+
dt = DT_FLOAT;
|
98 |
+
}
|
99 |
+
|
100 |
+
NSources(n, dt);
|
101 |
+
} else {
|
102 |
+
if (!dt_specified_ && !arg_->type_list_attr().empty()) {
|
103 |
+
DataTypeVector dts;
|
104 |
+
Status status = GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts);
|
105 |
+
if (!status.ok()) {
|
106 |
+
return errors::InvalidArgument(
|
107 |
+
"Could not infer list of types for input '", arg_->name(), "': ",
|
108 |
+
status.error_message());
|
109 |
+
}
|
110 |
+
SourceList(dts);
|
111 |
+
return Status::OK();
|
112 |
+
}
|
113 |
+
|
114 |
+
DataType dt;
|
115 |
+
TF_RETURN_IF_ERROR(GetDataType(&dt));
|
116 |
+
builder_->Input(in_node_, 0, dt);
|
117 |
+
}
|
118 |
+
return Status::OK();
|
119 |
+
}
|
120 |
+
|
121 |
+
// static
|
122 |
+
string FakeInputImpl::FakeNodeName(int in_index) {
|
123 |
+
char c = 'a' + (in_index % 26);
|
124 |
+
return string(&c, 1);
|
125 |
+
}
|
126 |
+
|
127 |
+
Status FakeInputImpl::GetN(int* n) const {
|
128 |
+
if (n_specified_) {
|
129 |
+
*n = n_;
|
130 |
+
} else {
|
131 |
+
Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n);
|
132 |
+
if (!status.ok()) {
|
133 |
+
return errors::InvalidArgument("Could not infer length of input '",
|
134 |
+
arg_->name(), "': ",
|
135 |
+
status.error_message());
|
136 |
+
}
|
137 |
+
}
|
138 |
+
return Status::OK();
|
139 |
+
}
|
140 |
+
|
141 |
+
Status FakeInputImpl::GetDataType(DataType* dt) const {
|
142 |
+
if (dt_specified_) {
|
143 |
+
*dt = dt_;
|
144 |
+
return Status::OK(); // Ignore is_ref field of arg_.
|
145 |
+
} else if (arg_->type() != DT_INVALID) {
|
146 |
+
*dt = arg_->type();
|
147 |
+
} else if (!arg_->type_attr().empty()) {
|
148 |
+
Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt);
|
149 |
+
if (!status.ok()) {
|
150 |
+
// Check if the type attr has a default
|
151 |
+
const OpDef::AttrDef* attr = FindAttr(arg_->type_attr(), *op_def_);
|
152 |
+
if (attr && attr->has_default_value()) {
|
153 |
+
*dt = attr->default_value().type();
|
154 |
+
} else {
|
155 |
+
return errors::InvalidArgument("Could not infer type for input '",
|
156 |
+
arg_->name(), "': ",
|
157 |
+
status.error_message());
|
158 |
+
}
|
159 |
+
}
|
160 |
+
} else {
|
161 |
+
return errors::InvalidArgument("No type or type_attr field in arg '",
|
162 |
+
arg_->name(), "'");
|
163 |
+
}
|
164 |
+
if (arg_->is_ref()) {
|
165 |
+
*dt = MakeRefType(*dt);
|
166 |
+
}
|
167 |
+
return Status::OK();
|
168 |
+
}
|
169 |
+
|
170 |
+
void FakeInputImpl::NSources(int n, DataType dt) const {
|
171 |
+
std::vector<NodeDefBuilder::NodeOut> srcs;
|
172 |
+
srcs.reserve(n);
|
173 |
+
for (int i = 0; i < n; ++i) {
|
174 |
+
srcs.emplace_back(in_node_, i, dt);
|
175 |
+
}
|
176 |
+
builder_->Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
|
177 |
+
}
|
178 |
+
|
179 |
+
void FakeInputImpl::SourceList(DataTypeSlice dts) const {
|
180 |
+
std::vector<NodeDefBuilder::NodeOut> srcs;
|
181 |
+
srcs.reserve(dts.size());
|
182 |
+
for (size_t i = 0; i < dts.size(); ++i) {
|
183 |
+
srcs.emplace_back(in_node_, i, dts[i]);
|
184 |
+
}
|
185 |
+
builder_->Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
|
186 |
+
}
|
187 |
+
|
188 |
+
} // namespace
|
189 |
+
|
190 |
+
// Public interface ------------------------------------------------------------
|
191 |
+
|
192 |
+
FakeInputFunctor FakeInput() {
|
193 |
+
return [](const OpDef& op_def, int in_index, const NodeDef& node_def,
|
194 |
+
NodeDefBuilder* builder) {
|
195 |
+
FakeInputImpl impl(&op_def, in_index, &node_def, builder);
|
196 |
+
return impl.AddInputToBuilder();
|
197 |
+
};
|
198 |
+
}
|
199 |
+
|
200 |
+
FakeInputFunctor FakeInput(DataType dt) {
|
201 |
+
return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
|
202 |
+
NodeDefBuilder* builder) {
|
203 |
+
FakeInputImpl impl(&op_def, in_index, &node_def, builder);
|
204 |
+
impl.SetDataType(dt);
|
205 |
+
return impl.AddInputToBuilder();
|
206 |
+
};
|
207 |
+
}
|
208 |
+
|
209 |
+
FakeInputFunctor FakeInput(int n) {
|
210 |
+
return [n](const OpDef& op_def, int in_index, const NodeDef& node_def,
|
211 |
+
NodeDefBuilder* builder) {
|
212 |
+
FakeInputImpl impl(&op_def, in_index, &node_def, builder);
|
213 |
+
impl.SetN(n);
|
214 |
+
return impl.AddInputToBuilder();
|
215 |
+
};
|
216 |
+
}
|
217 |
+
|
218 |
+
FakeInputFunctor FakeInput(int n, DataType dt) {
|
219 |
+
return [n, dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
|
220 |
+
NodeDefBuilder* builder) {
|
221 |
+
FakeInputImpl impl(&op_def, in_index, &node_def, builder);
|
222 |
+
impl.SetN(n);
|
223 |
+
impl.SetDataType(dt);
|
224 |
+
return impl.AddInputToBuilder();
|
225 |
+
};
|
226 |
+
}
|
227 |
+
|
228 |
+
FakeInputFunctor FakeInput(DataTypeSlice dts) {
|
229 |
+
// Make a copy to ensure the data will still be around when the lambda is
|
230 |
+
// called.
|
231 |
+
DataTypeVector dtv(dts.begin(), dts.end());
|
232 |
+
return [dtv](const OpDef& op_def, int in_index, const NodeDef& node_def,
|
233 |
+
NodeDefBuilder* builder) {
|
234 |
+
FakeInputImpl impl(&op_def, in_index, &node_def, builder);
|
235 |
+
impl.SetTypeList(dtv);
|
236 |
+
return impl.AddInputToBuilder();
|
237 |
+
};
|
238 |
+
}
|
239 |
+
|
240 |
+
} // namespace tensorflow
|
fake_input.h
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
|
18 |
+
|
19 |
+
#include "tensorflow/core/framework/node_def_builder.h"
|
20 |
+
#include "tensorflow/core/framework/types.h"
|
21 |
+
|
22 |
+
namespace tensorflow {
|
23 |
+
|
24 |
+
// These functions return values that may be passed to
|
25 |
+
// NodeDefBuilder::Input() to add an input for a test. Use them when
|
26 |
+
// you don't care about the node names/output indices providing the
|
27 |
+
// input. They also allow you to omit the input types and/or
|
28 |
+
// list length when they may be inferred.
|
29 |
+
FakeInputFunctor FakeInput(); // Infer everything
|
30 |
+
FakeInputFunctor FakeInput(DataType dt);
|
31 |
+
FakeInputFunctor FakeInput(int n); // List of length n
|
32 |
+
FakeInputFunctor FakeInput(int n, DataType dt);
|
33 |
+
FakeInputFunctor FakeInput(DataTypeSlice dts);
|
34 |
+
inline FakeInputFunctor FakeInput(std::initializer_list<DataType> dts) {
|
35 |
+
return FakeInput(DataTypeSlice(dts));
|
36 |
+
}
|
37 |
+
|
38 |
+
} // namespace tensorflow
|
39 |
+
|
40 |
+
#endif // TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
|
function.cc
ADDED
@@ -0,0 +1,1322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/function.h"
|
17 |
+
|
18 |
+
#include <map>
|
19 |
+
#include <unordered_map>
|
20 |
+
#include <utility>
|
21 |
+
#include <vector>
|
22 |
+
|
23 |
+
#include "tensorflow/core/framework/common_shape_fns.h"
|
24 |
+
#include "tensorflow/core/framework/function.pb_text.h"
|
25 |
+
#include "tensorflow/core/framework/graph.pb.h"
|
26 |
+
#include "tensorflow/core/framework/node_def.pb.h"
|
27 |
+
#include "tensorflow/core/framework/node_def_util.h"
|
28 |
+
#include "tensorflow/core/framework/op.h"
|
29 |
+
#include "tensorflow/core/graph/graph.h"
|
30 |
+
#include "tensorflow/core/lib/core/errors.h"
|
31 |
+
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
32 |
+
#include "tensorflow/core/lib/gtl/map_util.h"
|
33 |
+
#include "tensorflow/core/util/equal_graph_def.h"
|
34 |
+
|
35 |
+
namespace tensorflow {
|
36 |
+
|
37 |
+
// Extracts the actual type from "attr_values" based on its definition
|
38 |
+
// "arg_def".
|
39 |
+
//
|
40 |
+
// If "arg_def" is a N*T type, *is_type_list is set to false, and
|
41 |
+
// *dtypes is set to be a vector of size N and each element is T.
|
42 |
+
//
|
43 |
+
// If "arg_def" is a list(type), *is_type_list is set to true, and
|
44 |
+
// *dtypes is set to be a vector of types specified in attrs for
|
45 |
+
// arg_def.
|
46 |
+
//
|
47 |
+
// Otherwise (arg_def is a simple type T), *is_type_list is set to
|
48 |
+
// false, and *dtypes is set to a single element vector, whose only
|
49 |
+
// element is T.
|
50 |
+
Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
|
51 |
+
bool* is_type_list, DataTypeVector* dtypes) {
|
52 |
+
dtypes->clear();
|
53 |
+
if (!arg_def.type_list_attr().empty()) {
|
54 |
+
const AttrValue* v = attrs.Find(arg_def.type_list_attr());
|
55 |
+
if (v == nullptr) {
|
56 |
+
return errors::NotFound("type attr not found: ",
|
57 |
+
arg_def.type_list_attr());
|
58 |
+
}
|
59 |
+
*is_type_list = true;
|
60 |
+
for (int i = 0; i < v->list().type_size(); ++i) {
|
61 |
+
dtypes->push_back(v->list().type(i));
|
62 |
+
}
|
63 |
+
return Status::OK();
|
64 |
+
}
|
65 |
+
|
66 |
+
*is_type_list = false;
|
67 |
+
int num = 1;
|
68 |
+
if (!arg_def.number_attr().empty()) {
|
69 |
+
const AttrValue* v = attrs.Find(arg_def.number_attr());
|
70 |
+
if (v == nullptr) {
|
71 |
+
return errors::NotFound("type attr not found: ", arg_def.type_attr());
|
72 |
+
}
|
73 |
+
num = v->i();
|
74 |
+
}
|
75 |
+
|
76 |
+
DataType dtype;
|
77 |
+
if (arg_def.type() != DT_INVALID) {
|
78 |
+
dtype = arg_def.type();
|
79 |
+
} else if (arg_def.type_attr().empty()) {
|
80 |
+
dtype = DT_INVALID;
|
81 |
+
} else {
|
82 |
+
const AttrValue* v = attrs.Find(arg_def.type_attr());
|
83 |
+
if (v == nullptr) {
|
84 |
+
return errors::NotFound("type attr not found: ", arg_def.type_attr());
|
85 |
+
}
|
86 |
+
dtype = v->type();
|
87 |
+
}
|
88 |
+
dtypes->resize(num, dtype);
|
89 |
+
return Status::OK();
|
90 |
+
}
|
91 |
+
|
92 |
+
namespace {
|
93 |
+
|
94 |
+
template <typename T>
|
95 |
+
void AddAttr(const string& name, const T& val, NodeDef* ndef) {
|
96 |
+
SetAttrValue(val, &((*ndef->mutable_attr())[name]));
|
97 |
+
}
|
98 |
+
|
// Verifies that `attr_values` supplies a well-typed value for every attr
// declared in the signature `sig`. Extra attrs in `attr_values` that the
// signature does not declare are currently tolerated (see #if 0 below).
Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
  // attr_values should specify all attrs defined in fdef.
  for (const auto& a : sig.attr()) {
    const AttrValue* v = attr_values.Find(a.name());
    if (!v) {
      return errors::NotFound("Attr ", a.name(), " is not found from ",
                              SummarizeOpDef(sig));
    }
    // Check the supplied value against the declared attr type string
    // (e.g. "int", "type", "list(type)").
    Status status = AttrValueHasType(*v, a.type());
    if (!status.ok()) {
      // Annotate the error with which attr failed before propagating.
      errors::AppendToMessage(&status, "for attr '", a.name(), "'");
      return status;
    }
  }

// TODO(josh11b): Enable this code once it works with function gradients.
// Right now the C++ function gradient code assumes it can pass
// all the attrs of the function to the gradient, and any attrs that
// the gradient doesn't care about will be ignored.
#if 0
  if (attr_values.size() != sig.attr_size()) {
    for (const auto& a : attr_values) {
      // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
      bool found = false;
      for (const auto& s : sig.attr()) {
        if (a.first == s.name()) {
          found = true;
          break;
        }
      }
      if (!found) {
        return errors::NotFound("Attr ", a.first, " is not found in ",
                                SummarizeOpDef(sig));
      }
    }
  }
#endif

  return Status::OK();
}
139 |
+
|
140 |
+
// A helper class for instantiating functions. This contains shared information
|
141 |
+
// like the resulting graph and node name index.
|
class FunctionInstantiationHelper {
 public:
  FunctionInstantiationHelper(GetFunctionSignature get_function,
                              InstantiationResult* result)
      : get_function_(std::move(get_function)), result_(*result) {
    // Start from a clean slate; any nodes previously in the result are
    // discarded.
    result_.nodes.clear();
  }

  // Builds index for nodes that can be used as node's input arguments.
  // Emits one "_Arg" node per resolved dtype of `arg_def` into the result
  // graph and records both the aggregate name ("x") and per-element names
  // ("x:0", "x:1", ...) in the lookup index.
  Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
                            AttrSlice attr_values) {
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(
        ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
    CHECK_GE(dtypes.size(), size_t{1});
    // The new _Arg nodes are appended starting at the current node count.
    int arg_index = result_.nodes.size();
    TF_RETURN_IF_ERROR(
        AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
    // Creates dtypes.size() nodes in the graph.
    for (size_t i = 0; i < dtypes.size(); ++i) {
      TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
                                 {true, arg_index, 0, false, {dtypes[i]}}));
      DCHECK_EQ(arg_index, result_.nodes.size());
      // Multi-element args get a "_<i>" suffix so node names stay unique.
      string name = arg_def.name();
      if (dtypes.size() > 1) {
        strings::StrAppend(&name, "_", i);
      }
      NodeDef* gnode = AddNode(name);
      gnode->set_op("_Arg");
      AddAttr("T", dtypes[i], gnode);
      AddAttr("index", arg_index, gnode);
      result_.arg_types.push_back(dtypes[i]);
      ++arg_index;
    }
    return Status::OK();
  }

  // Registers `node`'s outputs in the lookup index so later nodes can name
  // them as inputs. `arg_index` is the graph index this node will occupy.
  // Records "node:outname" for each declared output plus flattened
  // "node:outname:j" entries for the individual elements.
  Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
                              const int arg_index) {
    const OpDef* node_sig = nullptr;
    TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
    if (node_sig->output_arg_size() == 0) {
      // No outputs: index only the bare node name (usable as a control dep).
      return AddItem(node.name(), {false, arg_index, 0, false, {}});
    }
    const int num_retval = node_sig->output_arg_size();
    int start = 0;  // running flat output slot across all output args
    bool is_type_list;
    DataTypeVector dtypes;
    for (int i = 0; i < num_retval; ++i) {
      TF_RETURN_IF_ERROR(
          ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
      // Note that we rely on the backwards-compatibility test enforcing
      // that output_arg(*).name() doesn't change here.
      const string base_name =
          strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
      TF_RETURN_IF_ERROR(
          AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
      for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
        TF_RETURN_IF_ERROR(
            AddItem(strings::StrCat(base_name, ":", j),
                    {false, arg_index, start + j, false, {dtypes[j]}}));
      }
      start += dtypes.size();
    }
    return Status::OK();
  }

  // Emits one graph node for function-body node `fnode`, wiring its data
  // inputs and control dependencies via the previously built index, and
  // copying `attrs` onto the new node.
  Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
    const OpDef* fnode_sig = nullptr;
    TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
    NodeDef* gnode = AddNode(fnode.name());
    gnode->set_op(fnode.op());
    gnode->set_device(fnode.device());
    int gnode_idx = nodes_.size() - 1;

    // Input
    const int num_args = fnode_sig->input_arg_size();
    bool is_type_list;  // ignored
    DataTypeVector dtypes;
    int fnode_arg_index = 0;  // next fnode.input(...) entry to consume
    for (int i = 0; i < num_args; ++i) {
      TF_RETURN_IF_ERROR(
          ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
      // Consume inputs (indexed by fnode_arg_index) until we have
      // matched each element of dtypes (indexed by j).
      for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
        if (fnode_arg_index >= fnode.input_size()) {
          // Should never happen if we computed dtypes correctly.
          return errors::InvalidArgument(
              "Attempt to access beyond input size: ", fnode_arg_index,
              " >= ", fnode.input_size());
        }
        // Look up the next input.
        const string& input_name = fnode.input(fnode_arg_index);
        const auto* item = GetItemOrNull(input_name);
        if (item == nullptr) {
          return errors::InvalidArgument(
              "input ", input_name, " is not found: ", SummarizeNodeDef(fnode));
        }
        if (item->dtypes.size() > dtypes.size() - j) {
          return errors::InvalidArgument("Input ", input_name, " too long for ",
                                         fnode_sig->input_arg(i).name());
        }
        // Match up all the elements of this input (indexed by k) with
        // elements of dtypes (advancing j).
        for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
          if (item->dtypes[k] != dtypes[j]) {
            return errors::InvalidArgument(
                "input ", fnode_sig->input_arg(i).name(), "[", j,
                "] expected type ", DataTypeString(dtypes[j]),
                " != ", DataTypeString(item->dtypes[k]), ", the type of ",
                input_name, "[", k, "]");
          }
          if (item->is_func_arg) {
            // Function args occupy one graph node per element, so element k
            // is node nid + k, output 0.
            AddInput(gnode_idx, item->nid + k, 0);
          } else {
            // Body-node outputs are outputs idx + k of the single node nid.
            AddInput(gnode_idx, item->nid, item->idx + k);
          }
        }
      }
    }

    // Control deps.
    // All remaining inputs after the data inputs must be "^name" entries.
    for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
      const string& input = fnode.input(i);
      if (input.empty() || input[0] != '^') {
        return errors::InvalidArgument("Expected input[", i, "] == '", input,
                                       "' to be a control input.");
      }
      int nid = -1;
      const string node_name = input.substr(1);
      const string node_colon = node_name + ":";
      // ';' is the character after ':' in ASCII, giving an exclusive upper
      // bound for all keys of the form "node_name:..."
      const string node_colon_bound = node_name + ";";
      // index_ is a map sorted lexicographically, so the key we are looking for
      // must lie in the range [node_name, node_colon_bound).
      auto it = index_.lower_bound(node_name);
      while (it != index_.end() && it->first <= node_colon_bound) {
        if (it->first == node_name ||
            tensorflow::StringPiece(it->first).starts_with(node_colon)) {
          nid = it->second.nid;
          break;
        }
        ++it;
      }
      if (nid == -1) {
        return errors::InvalidArgument("input[", i, "] == '", input,
                                       "', is not found.");
      }
      AddDep(gnode_idx, nid);
    }

    // Attrs.
    for (const auto& p : attrs) {
      (*gnode->mutable_attr())[p.first] = p.second;
    }

    return Status::OK();
  }

  // Emits "_Retval" node(s) for return value `ret_def`, validating that the
  // body node named in `ret_map` produces exactly the expected dtypes.
  // `*ret_index` is the running output index, advanced per emitted element.
  Status AddReturnNode(
      const OpDef::ArgDef& ret_def, AttrSlice attrs,
      const ::tensorflow::protobuf::Map<string, string>& ret_map,
      int* ret_index) {
    auto ret_iter = ret_map.find(ret_def.name());
    if (ret_iter == ret_map.end()) {
      return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
    }
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
    CHECK_GE(dtypes.size(), size_t{1});
    const auto* item = GetItemOrNull(ret_iter->second);
    if (item == nullptr) {
      return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
                                     ret_iter->second, " is not found.");
    }
    if (dtypes != item->dtypes) {
      return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
                                     " : ", DataTypeVectorString(dtypes),
                                     " vs. ",
                                     DataTypeVectorString(item->dtypes));
    }
    for (size_t i = 0; i < dtypes.size(); ++i) {
      // Multi-element returns get a "_<i>" suffix to keep names unique.
      string name = strings::StrCat(ret_def.name(), "_RetVal");
      if (dtypes.size() > 1) {
        strings::StrAppend(&name, "_", i);
      }
      NodeDef* gnode = AddNode(name);
      gnode->set_op("_Retval");
      AddInput(nodes_.size() - 1, item->nid, item->idx + i);
      AddAttr("T", dtypes[i], gnode);
      AddAttr("index", (*ret_index)++, gnode);
      result_.ret_types.push_back(dtypes[i]);
    }
    return Status::OK();
  }

  // Adds the actual node inputs to the result graph by converting indexes to
  // the node names.
  void AddNodeInputs() {
    for (int i = 0; i < result_.nodes.size(); i++) {
      NodeInfo& node_info = nodes_[i];
      for (const auto& p : node_info.data_inputs) {
        result_.nodes[i].add_input(Name(p.first, p.second));
      }
      for (int index : node_info.control_inputs) {
        result_.nodes[i].add_input(Dep(index));
      }
    }
  }

 private:
  // This is used to build a small index for all names that can be used as a
  // node's input arguments.
  //
  // If is_func_arg is true, the name is a function's argument. In
  // this case, the produced graph def has node[nid:nid + dtype.size()].
  //
  // Otherwise, the name is a function body's node return value. In
  // this case, the produced graph def has one node node[nid] and
  // the node's output index [idx ... idx + num) corresponds to the
  // named outputs.
  //
  // In all cases, "dtype" specifies the data type.
  struct NameInfoItem {
    bool is_func_arg;
    int nid;
    int idx;
    bool is_type_list;
    DataTypeVector dtypes;
  };

  // Adds an item into the input name index.
  // Fails if `name` was already registered (duplicate arg/ret name).
  Status AddItem(const string& name, const NameInfoItem& item) {
    if (!index_.insert({name, item}).second) {
      return errors::InvalidArgument(
          strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
                          " name: "),
          name);
    }
    return Status::OK();
  }

  // Returns the index entry for `name`, or nullptr if not registered.
  const NameInfoItem* GetItemOrNull(const string& name) const {
    return gtl::FindOrNull(index_, name);
  }

  // Control-dependency input string ("^name") for node `node_index`.
  string Dep(int node_index) const {
    return strings::StrCat("^", Name(node_index));
  }

  string Name(int node_index) const {
    CHECK_LT(node_index, nodes_.size());
    return nodes_[node_index].name;
  }

  // Data input string: "name" for output 0, else "name:output_index".
  string Name(int node_index, int output_index) const {
    if (output_index == 0) {
      return Name(node_index);
    } else {
      return strings::StrCat(Name(node_index), ":", output_index);
    }
  }

  // Appends a new node to both the result graph and the parallel nodes_
  // bookkeeping vector; the two must stay the same length.
  NodeDef* AddNode(const string& name) {
    result_.nodes.emplace_back();
    NodeDef* gnode = &result_.nodes.back();
    gnode->set_name(name);
    nodes_.push_back({name, {}, {}});
    CHECK_EQ(result_.nodes.size(), nodes_.size());
    return gnode;
  }

  // Records that node `node_index` reads output `output_index` of node
  // `output_node`; materialized later by AddNodeInputs().
  void AddInput(int node_index, int output_node, int output_index) {
    CHECK_LT(node_index, nodes_.size());
    nodes_[node_index].data_inputs.push_back(
        std::make_pair(output_node, output_index));
  }

  // Records a control dependency of node `node_index` on node `dep_index`.
  void AddDep(int node_index, int dep_index) {
    CHECK_LT(node_index, nodes_.size());
    nodes_[node_index].control_inputs.push_back(dep_index);
  }

  GetFunctionSignature get_function_;
  InstantiationResult& result_;
  // A small index for all names that can be used as a node's input arguments.
  std::map<string, NameInfoItem> index_;
  // This contains information about a node in the new graph including the node
  // names and input nodes' indexes.
  struct NodeInfo {
    string name;
    // Data inputs where <n, k> means arg k of node n.
    std::vector<std::pair<int, int>> data_inputs;
    // Control inputs (dependencies).
    std::vector<int> control_inputs;
  };
  // nodes_[i] is the information about result_.nodes[i].
  std::vector<NodeInfo> nodes_;
};
443 |
+
|
444 |
+
// Various helpers Print(proto) to print relevant protos to ascii.
|
445 |
+
string Print(const OpDef::ArgDef& arg) {
|
446 |
+
string out;
|
447 |
+
strings::StrAppend(&out, arg.name(), ":");
|
448 |
+
if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
|
449 |
+
if (!arg.number_attr().empty()) {
|
450 |
+
strings::StrAppend(&out, arg.number_attr(), "*");
|
451 |
+
}
|
452 |
+
if (arg.type() != DT_INVALID) {
|
453 |
+
strings::StrAppend(&out, DataTypeString(arg.type()));
|
454 |
+
} else {
|
455 |
+
strings::StrAppend(&out, arg.type_attr());
|
456 |
+
}
|
457 |
+
if (arg.is_ref()) strings::StrAppend(&out, ")");
|
458 |
+
return out;
|
459 |
+
}
|
460 |
+
|
461 |
+
// TODO(josh11b): Merge this with SummarizeAttrValue().
|
462 |
+
string Print(const AttrValue& attr_value) {
|
463 |
+
if (attr_value.value_case() == AttrValue::kType) {
|
464 |
+
return DataTypeString(attr_value.type());
|
465 |
+
} else if ((attr_value.value_case() == AttrValue::kList) &&
|
466 |
+
(attr_value.list().type_size() > 0)) {
|
467 |
+
string ret = "{";
|
468 |
+
for (int i = 0; i < attr_value.list().type_size(); ++i) {
|
469 |
+
if (i > 0) strings::StrAppend(&ret, ", ");
|
470 |
+
strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
|
471 |
+
}
|
472 |
+
strings::StrAppend(&ret, "}");
|
473 |
+
return ret;
|
474 |
+
} else if (attr_value.value_case() == AttrValue::kFunc) {
|
475 |
+
if (attr_value.func().attr_size() == 0) {
|
476 |
+
return attr_value.func().name();
|
477 |
+
}
|
478 |
+
std::vector<string> entries;
|
479 |
+
for (auto p : attr_value.func().attr()) {
|
480 |
+
entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
|
481 |
+
}
|
482 |
+
std::sort(entries.begin(), entries.end());
|
483 |
+
return strings::StrCat(attr_value.func().name(), "[",
|
484 |
+
str_util::Join(entries, ", "), "]");
|
485 |
+
}
|
486 |
+
return SummarizeAttrValue(attr_value);
|
487 |
+
}
|
488 |
+
|
489 |
+
// TODO(josh11b): Merge this with SummarizeNodeDef().
|
490 |
+
string Print(const NodeDef& n) {
|
491 |
+
string out;
|
492 |
+
strings::StrAppend(&out, n.name(), " = ", n.op());
|
493 |
+
if (n.attr_size() > 0) {
|
494 |
+
std::vector<string> entries;
|
495 |
+
for (auto& a : n.attr()) {
|
496 |
+
entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
|
497 |
+
}
|
498 |
+
std::sort(entries.begin(), entries.end());
|
499 |
+
strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
|
500 |
+
}
|
501 |
+
strings::StrAppend(&out, "(");
|
502 |
+
std::vector<StringPiece> dat;
|
503 |
+
std::vector<string> dep;
|
504 |
+
for (StringPiece s : n.input()) {
|
505 |
+
if (s.Consume("^")) {
|
506 |
+
dep.push_back(s.ToString());
|
507 |
+
} else {
|
508 |
+
dat.push_back(s);
|
509 |
+
}
|
510 |
+
}
|
511 |
+
strings::StrAppend(&out, str_util::Join(dat, ", "), ")");
|
512 |
+
if (!dep.empty()) {
|
513 |
+
strings::StrAppend(&out, " @ ", str_util::Join(dep, ", "));
|
514 |
+
}
|
515 |
+
return out;
|
516 |
+
}
|
517 |
+
|
518 |
+
string Print(const FunctionDef& fdef) {
|
519 |
+
string out;
|
520 |
+
const OpDef& sig = fdef.signature();
|
521 |
+
strings::StrAppend(&out, "\n", sig.name());
|
522 |
+
if (sig.attr_size() > 0) {
|
523 |
+
strings::StrAppend(&out, "[");
|
524 |
+
for (int i = 0; i < sig.attr_size(); ++i) {
|
525 |
+
const auto& a = sig.attr(i);
|
526 |
+
if (i > 0) strings::StrAppend(&out, ", ");
|
527 |
+
if (a.type() == "type") {
|
528 |
+
strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
|
529 |
+
} else {
|
530 |
+
strings::StrAppend(&out, a.name(), ":", a.type());
|
531 |
+
}
|
532 |
+
}
|
533 |
+
strings::StrAppend(&out, "]");
|
534 |
+
}
|
535 |
+
strings::StrAppend(&out, "(");
|
536 |
+
for (int i = 0; i < sig.input_arg_size(); ++i) {
|
537 |
+
if (i > 0) strings::StrAppend(&out, ", ");
|
538 |
+
strings::StrAppend(&out, Print(sig.input_arg(i)));
|
539 |
+
}
|
540 |
+
strings::StrAppend(&out, ") -> (");
|
541 |
+
for (int i = 0; i < sig.output_arg_size(); ++i) {
|
542 |
+
if (i > 0) strings::StrAppend(&out, ", ");
|
543 |
+
strings::StrAppend(&out, Print(sig.output_arg(i)));
|
544 |
+
}
|
545 |
+
strings::StrAppend(&out, ") {\n");
|
546 |
+
for (const auto& n : fdef.node_def()) {
|
547 |
+
strings::StrAppend(&out, " ", Print(n), "\n");
|
548 |
+
}
|
549 |
+
for (const auto& r : fdef.ret()) {
|
550 |
+
strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
|
551 |
+
}
|
552 |
+
strings::StrAppend(&out, "}\n");
|
553 |
+
return out;
|
554 |
+
}
|
555 |
+
|
556 |
+
string Print(gtl::ArraySlice<const NodeDef*> nodes) {
|
557 |
+
std::vector<const NodeDef*> arg;
|
558 |
+
std::vector<const NodeDef*> ret;
|
559 |
+
std::vector<const NodeDef*> body;
|
560 |
+
for (const NodeDef* n : nodes) {
|
561 |
+
if (n->op() == "_Arg") {
|
562 |
+
arg.push_back(n);
|
563 |
+
} else if (n->op() == "_Retval") {
|
564 |
+
ret.push_back(n);
|
565 |
+
} else {
|
566 |
+
body.push_back(n);
|
567 |
+
}
|
568 |
+
}
|
569 |
+
auto comp = [](const NodeDef* x, const NodeDef* y) {
|
570 |
+
int xi;
|
571 |
+
TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
|
572 |
+
int yi;
|
573 |
+
TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
|
574 |
+
return xi < yi;
|
575 |
+
};
|
576 |
+
std::sort(arg.begin(), arg.end(), comp);
|
577 |
+
std::sort(ret.begin(), ret.end(), comp);
|
578 |
+
string out;
|
579 |
+
strings::StrAppend(&out, "\n(");
|
580 |
+
auto get_type = [](const NodeDef& n) {
|
581 |
+
DataType dt;
|
582 |
+
if (!GetNodeAttr(n, "T", &dt).ok()) {
|
583 |
+
dt = DT_INVALID;
|
584 |
+
}
|
585 |
+
return DataTypeString(dt);
|
586 |
+
};
|
587 |
+
for (size_t i = 0; i < arg.size(); ++i) {
|
588 |
+
const NodeDef* n = arg[i];
|
589 |
+
if (i > 0) strings::StrAppend(&out, ", ");
|
590 |
+
CHECK_GE(n->attr_size(), 2);
|
591 |
+
strings::StrAppend(&out, n->name(), ":", get_type(*n));
|
592 |
+
}
|
593 |
+
strings::StrAppend(&out, ") -> (");
|
594 |
+
for (size_t i = 0; i < ret.size(); ++i) {
|
595 |
+
const NodeDef* n = ret[i];
|
596 |
+
if (i > 0) strings::StrAppend(&out, ", ");
|
597 |
+
CHECK_LE(2, n->attr_size());
|
598 |
+
CHECK_EQ(1, n->input_size());
|
599 |
+
strings::StrAppend(&out, n->input(0), ":", get_type(*n));
|
600 |
+
}
|
601 |
+
strings::StrAppend(&out, ") {\n");
|
602 |
+
for (size_t i = 0; i < body.size(); ++i) {
|
603 |
+
strings::StrAppend(&out, " ", Print(*body[i]), "\n");
|
604 |
+
}
|
605 |
+
strings::StrAppend(&out, "}\n");
|
606 |
+
return out;
|
607 |
+
}
|
608 |
+
|
609 |
+
Status AddDefaultAttrs(const string& op,
|
610 |
+
const GetFunctionSignature& get_function,
|
611 |
+
AttrValueMap* attrs) {
|
612 |
+
const OpDef* op_def = nullptr;
|
613 |
+
TF_RETURN_IF_ERROR(get_function(op, &op_def));
|
614 |
+
AttrSlice attr_slice(attrs);
|
615 |
+
for (const auto& attr_def : op_def->attr()) {
|
616 |
+
if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
|
617 |
+
if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
|
618 |
+
return errors::Internal("Somehow duplicated: ", attr_def.name());
|
619 |
+
}
|
620 |
+
}
|
621 |
+
}
|
622 |
+
return Status::OK();
|
623 |
+
}
|
624 |
+
|
625 |
+
} // end namespace
|
626 |
+
|
// Instantiates `fdef` with the concrete attr bindings in `attr_values`,
// producing a flat node list in `*result`. Runs in ordered passes:
//   1. validate attrs against the signature,
//   2. emit _Arg nodes and index the function arguments,
//   3. bind every node attr (substituting placeholders, adding defaults),
//   4. index all body-node outputs, 5. emit the body nodes,
//   6. emit _Retval nodes, 7. materialize the recorded input edges.
// The pass order matters: indexes built in earlier passes are consumed by
// later ones.
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
                           GetFunctionSignature get_function,
                           InstantiationResult* result) {
  VLOG(3) << "Instantiation Function: " << Print(fdef);

  const OpDef& sig = fdef.signature();
  TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));

  FunctionInstantiationHelper helper(get_function, result);
  Status s;
  for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
    s = helper.BuildInputArgIndex(arg_def, attr_values);
    if (!s.ok()) {
      errors::AppendToMessage(&s, "In ", Print(arg_def));
      return s;
    }
  }

  // Looks up `name` in the caller-supplied attr bindings; used to resolve
  // placeholder attrs in the function body.
  auto substitute = [attr_values](StringPiece name, AttrValue* val) {
    if (const AttrValue* v = attr_values.Find(name)) {
      *val = *v;
      return true;
    }
    return false;
  };

  // Makes a copy of all attrs in fdef and substitutes placeholders.
  // After this step, every attr is bound to a concrete value.
  std::vector<AttrValueMap> node_attrs;
  node_attrs.resize(fdef.node_def_size());
  for (int i = 0; i < fdef.node_def_size(); ++i) {
    // NOTE: `attr` is deliberately taken by value -- it is mutated by
    // SubstitutePlaceholders before being inserted, so a reference to the
    // proto's own entry would not work here.
    for (auto attr : fdef.node_def(i).attr()) {
      if (!SubstitutePlaceholders(substitute, &attr.second)) {
        return errors::InvalidArgument("Failed to bind all placeholders in ",
                                       SummarizeAttrValue(attr.second));
      }
      if (!node_attrs[i].insert(attr).second) {
        return errors::Internal("Somehow duplicated: ", attr.first);
      }
    }
    TF_RETURN_IF_ERROR(
        AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
  }

  // Index all body-node outputs before emitting any body node, so forward
  // references between body nodes resolve. The graph index of body node i
  // is offset by the _Arg nodes already emitted (result->nodes.size()).
  for (int i = 0; i < fdef.node_def_size(); ++i) {
    s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
                                    result->nodes.size() + i);
    if (!s.ok()) {
      errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
      return s;
    }
  }
  // Emits one node for each fdef.node_def.
  for (int i = 0; i < fdef.node_def_size(); ++i) {
    s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
    if (!s.ok()) {
      errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
      return s;
    }
  }

  // Emits nodes for the function's return values.
  int ret_index = 0;
  for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
    s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), &ret_index);
    if (!s.ok()) {
      errors::AppendToMessage(&s, "In function output ", Print(ret_def));
      return s;
    }
  }

  // Adds the actual node inputs using the input indexes.
  helper.AddNodeInputs();

  return Status::OK();
}
703 |
+
|
704 |
+
string DebugString(const FunctionDef& func_def) { return Print(func_def); }
|
705 |
+
|
706 |
+
string DebugString(const GraphDef& instantiated_func_def) {
|
707 |
+
std::vector<const NodeDef*> ptrs;
|
708 |
+
for (const NodeDef& n : instantiated_func_def.node()) {
|
709 |
+
ptrs.push_back(&n);
|
710 |
+
}
|
711 |
+
return Print(ptrs);
|
712 |
+
}
|
713 |
+
|
714 |
+
string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
|
715 |
+
std::vector<const NodeDef*> ptrs;
|
716 |
+
for (const NodeDef& n : instantiated_func_nodes) {
|
717 |
+
ptrs.push_back(&n);
|
718 |
+
}
|
719 |
+
return Print(ptrs);
|
720 |
+
}
|
721 |
+
|
722 |
+
string DebugStringWhole(const GraphDef& gdef) {
|
723 |
+
string ret;
|
724 |
+
for (const auto& fdef : gdef.library().function()) {
|
725 |
+
strings::StrAppend(&ret, Print(fdef));
|
726 |
+
}
|
727 |
+
strings::StrAppend(&ret, "\n");
|
728 |
+
for (const auto& ndef : gdef.node()) {
|
729 |
+
strings::StrAppend(&ret, Print(ndef), "\n");
|
730 |
+
}
|
731 |
+
return ret;
|
732 |
+
}
|
733 |
+
|
734 |
+
namespace {
|
735 |
+
|
736 |
+
// Returns the name -> attr mapping of fdef's attrs that have a value set. In
|
737 |
+
// Python, it's possible to access unset attrs, which returns a default value
|
738 |
+
// and adds an unset attr to the map.
|
739 |
+
std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
|
740 |
+
std::map<string, AttrValue> set_attrs;
|
741 |
+
for (auto pair : fdef.attr()) {
|
742 |
+
if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
|
743 |
+
set_attrs[pair.first] = pair.second;
|
744 |
+
}
|
745 |
+
}
|
746 |
+
return set_attrs;
|
747 |
+
}
|
748 |
+
|
749 |
+
} // end namespace
|
750 |
+
|
751 |
+
bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
|
752 |
+
if (!OpDefEqual(f1.signature(), f2.signature())) return false;
|
753 |
+
|
754 |
+
std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
|
755 |
+
std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
|
756 |
+
if (f1_attrs.size() != f2_attrs.size()) return false;
|
757 |
+
for (auto iter1 : f1_attrs) {
|
758 |
+
auto iter2 = f2_attrs.find(iter1.first);
|
759 |
+
if (iter2 == f2_attrs.end()) return false;
|
760 |
+
if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
|
761 |
+
}
|
762 |
+
|
763 |
+
if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
|
764 |
+
return false;
|
765 |
+
}
|
766 |
+
|
767 |
+
std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
|
768 |
+
std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
|
769 |
+
if (ret1 != ret2) return false;
|
770 |
+
|
771 |
+
return true;
|
772 |
+
}
|
773 |
+
|
774 |
+
uint64 FunctionDefHash(const FunctionDef& fdef) {
|
775 |
+
// signature
|
776 |
+
uint64 h = OpDefHash(fdef.signature());
|
777 |
+
|
778 |
+
// attrs
|
779 |
+
std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
|
780 |
+
for (const auto& p : attrs) {
|
781 |
+
h = Hash64(p.first.data(), p.first.size(), h);
|
782 |
+
h = Hash64Combine(AttrValueHash(p.second), h);
|
783 |
+
}
|
784 |
+
|
785 |
+
// node defs
|
786 |
+
h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
|
787 |
+
|
788 |
+
// output names
|
789 |
+
std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
|
790 |
+
for (const auto& p : ret) {
|
791 |
+
h = Hash64(p.first.data(), p.first.size(), h);
|
792 |
+
h = Hash64(p.second.data(), p.second.size(), h);
|
793 |
+
}
|
794 |
+
|
795 |
+
return h;
|
796 |
+
}
|
797 |
+
|
798 |
+
string Canonicalize(const string& funcname, AttrSlice attrs) {
|
799 |
+
std::vector<string> entries;
|
800 |
+
entries.reserve(attrs.size());
|
801 |
+
for (auto p : attrs) {
|
802 |
+
entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
|
803 |
+
}
|
804 |
+
std::sort(entries.begin(), entries.end());
|
805 |
+
return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
|
806 |
+
}
|
807 |
+
|
// Pre-sizes one slot per declared argument and return value; the slots are
// filled later via SetArgs() / SetRetval().
FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
                                     DataTypeSlice ret_types)
    : arg_types_(arg_types.begin(), arg_types.end()),
      ret_types_(ret_types.begin(), ret_types.end()) {
  args_.resize(arg_types_.size());
  rets_.resize(ret_types_.size());
}
815 |
+
|
816 |
+
FunctionCallFrame::~FunctionCallFrame() {}
|
817 |
+
|
818 |
+
// Copies `args` into the frame's argument slots.
// Fails with InvalidArgument if the count or any dtype does not match the
// arg_types_ this frame was constructed with.
Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
  // Input type checks.
  if (args.size() != arg_types_.size()) {
    return errors::InvalidArgument("Expects ", arg_types_.size(),
                                   " arguments, but ", args.size(),
                                   " is provided");
  }
  for (size_t idx = 0; idx < args.size(); ++idx) {
    const Tensor& arg = args[idx];
    if (arg.dtype() != arg_types_[idx]) {
      return errors::InvalidArgument(
          "Expects arg[", idx, "] to be ", DataTypeString(arg_types_[idx]),
          " but ", DataTypeString(arg.dtype()), " is provided");
    }
    args_[idx] = arg;
  }
  return Status::OK();
}
|
835 |
+
|
836 |
+
// Copies all return values into *rets.
// Fails with Internal if any retval slot has not been set yet.
Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
  rets->clear();
  rets->reserve(rets_.size());
  for (size_t idx = 0; idx < rets_.size(); ++idx) {
    const auto& slot = rets_[idx];
    // Guard clause: an unset slot is an internal error.
    if (!slot.has_val) {
      return errors::Internal("Retval[", idx, "] does not have value");
    }
    rets->push_back(slot.val);
  }
  return Status::OK();
}
|
849 |
+
|
850 |
+
// Moves all return values into *rets, leaving the frame's retval tensors
// in a moved-from state.  Fails with Internal if any slot is unset.
Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets) {
  rets->clear();
  rets->reserve(rets_.size());
  for (size_t idx = 0; idx < rets_.size(); ++idx) {
    // Guard clause: an unset slot is an internal error.
    if (!rets_[idx].has_val) {
      return errors::Internal("Retval[", idx, "] does not have value");
    }
    rets->emplace_back(std::move(rets_[idx].val));
  }
  return Status::OK();
}
|
862 |
+
|
863 |
+
// Copies argument `index` into *val.
// Fails with InvalidArgument when `index` is outside [0, num_args()).
Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
  const bool in_range =
      index >= 0 && static_cast<size_t>(index) < args_.size();
  if (!in_range) {
    return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
                                   args_.size(), ")");
  }
  *val = args_[index];
  return Status::OK();
}
|
871 |
+
|
872 |
+
// Stores `val` as return value `index`.  Each slot may be set exactly once.
// Fails with InvalidArgument on a bad index or dtype mismatch, and with
// Internal if the slot was already set.
Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
  if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
    return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
                                   rets_.size(), ")");
  }
  if (val.dtype() != ret_types_[index]) {
    return errors::InvalidArgument(
        "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
        ", but ", DataTypeString(val.dtype()), " is provided.");
  }
  Retval& slot = rets_[index];
  // Guard clause: double-set is an internal error.
  if (slot.has_val) {
    return errors::Internal("Retval[", index, "] has already been set.");
  }
  slot.has_val = true;
  slot.val = val;
  return Status::OK();
}
|
891 |
+
|
892 |
+
// Bundles a copy of `fdef_in` with OpRegistrationData built from its
// signature so the function can be looked up like a regular op.
FunctionLibraryDefinition::FunctionDefAndOpRegistration::
    FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
    : fdef(fdef_in),
      // Exact shape inference for functions is handled by ShapeRefiner.
      // Here we pass a dummy shape inference function for legacy code paths.
      op_registration_data(fdef.signature(), shape_inference::UnknownShape,
                           true /* is_function */) {}
|
899 |
+
|
900 |
+
// Copy constructor: shares the default op registry pointer, copies the
// gradient map in the initializer list, and re-adds every function
// definition so each gets freshly built registration data.  CHECK-fails
// if re-adding fails (cannot happen for a self-consistent `other`).
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const FunctionLibraryDefinition& other)
    : default_registry_(other.default_registry_), func_grad_(other.func_grad_) {
  for (const auto& it : other.function_defs_) {
    TF_CHECK_OK(AddFunctionDef(it.second->fdef));
  }
}
|
907 |
+
|
908 |
+
// Builds a library from `def_lib`, resolving unknown ops through
// `default_registry` (not owned; must outlive this object).
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const OpRegistryInterface* default_registry,
    const FunctionDefLibrary& def_lib)
    : default_registry_(default_registry),
      function_defs_(def_lib.function_size()) {
  for (const auto& fdef : def_lib.function()) {
    // On duplicate names, the later definition silently wins.
    function_defs_[fdef.signature().name()].reset(
        new FunctionDefAndOpRegistration(fdef));
  }
  for (const auto& grad : def_lib.gradient()) {
    func_grad_[grad.function_name()] = grad.gradient_func();
  }
}
|
922 |
+
|
923 |
+
// Out-of-line empty destructor; maps and unique_ptrs clean up themselves.
FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
|
924 |
+
|
925 |
+
// Returns the FunctionDef registered under `name`, or nullptr if absent.
// The returned pointer is owned by this library and stays valid until the
// function is removed or the library is destroyed.
const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
  const auto it = function_defs_.find(name);
  return it == function_defs_.end() ? nullptr : &it->second->fdef;
}
|
933 |
+
|
934 |
+
// Adds `fdef` to the library; identical duplicates are a no-op.  See
// AddFunctionDefHelper for details — the "newly added" flag is discarded.
Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
  bool added;
  return AddFunctionDefHelper(fdef, &added);
}
|
938 |
+
|
939 |
+
// Adds `fdef` unless an identical definition already exists; sets *added
// to true only when a new entry was actually inserted.  Fails if a
// *different* function, or a primitive op, already uses the name.
Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
                                                       bool* added) {
  *added = false;
  // operator[] default-constructs a null entry on first use, so a non-null
  // pointer means this name is already registered.
  std::unique_ptr<FunctionDefAndOpRegistration>* entry =
      &function_defs_[fdef.signature().name()];
  if (*entry != nullptr) {
    if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
      return errors::InvalidArgument(
          "Cannot add function '", fdef.signature().name(),
          "' because a different function with the same name already "
          "exists.");
    }
    // Ignore duplicate FunctionDefs
    return Status::OK();
  }
  const OpDef* op_def;
  // NOTE(review): on this op-name-collision error path the freshly created
  // null map entry for the name is left behind — presumably harmless since
  // lookups treat null as absent; confirm.
  if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
    return errors::InvalidArgument(
        "Cannot add function '", fdef.signature().name(),
        "' because an op with the same name already exists.");
  }
  entry->reset(new FunctionDefAndOpRegistration(fdef));
  *added = true;
  return Status::OK();
}
|
964 |
+
|
965 |
+
// Registers `grad`; identical duplicates are a no-op.  See
// AddGradientDefHelper for details — the "newly added" flag is discarded.
Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
  bool added;
  return AddGradientDefHelper(grad, &added);
}
|
969 |
+
|
970 |
+
// Records grad.gradient_func() as the gradient of grad.function_name(),
// setting *added only on a new insertion.  Re-registering the identical
// mapping is a no-op; registering a *different* gradient is an error.
Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
                                                       bool* added) {
  *added = false;
  // operator[] default-constructs an empty string on first use; an empty
  // entry therefore means "no gradient registered yet".
  string* entry = &func_grad_[grad.function_name()];
  if (!entry->empty()) {
    if (*entry != grad.gradient_func()) {
      return errors::InvalidArgument(
          "Cannot assign gradient function '", grad.gradient_func(), "' to '",
          grad.function_name(), "' because it already has gradient function ",
          "'", *entry, "'");
    }
    // Ignore duplicate GradientDefs
    return Status::OK();
  }
  *entry = grad.gradient_func();
  *added = true;
  return Status::OK();
}
|
988 |
+
|
989 |
+
// Merges every function and gradient from `other` into this library.
// Transactional: on any error, everything this call added is removed and
// the library is left unchanged.  Identical duplicates are tolerated.
Status FunctionLibraryDefinition::AddLibrary(
    const FunctionLibraryDefinition& other) {
  // Remember the funcs and grads that we added successfully so that
  // we can roll them back on error.
  std::vector<string> funcs;
  std::vector<string> funcs_with_grads;
  Status s;
  bool added;
  // Bind by const reference: the original `auto iter` copied the loop
  // element each iteration (clang-tidy performance-for-range-copy); a
  // reference works whether the map iterator yields a pair or a proxy.
  for (const auto& iter : other.function_defs_) {
    s = AddFunctionDefHelper(iter.second->fdef, &added);
    if (!s.ok()) {
      Remove(funcs, funcs_with_grads);
      return s;
    }
    if (added) {
      funcs.push_back(iter.second->fdef.signature().name());
    }
  }
  for (const auto& iter : other.func_grad_) {
    GradientDef grad;
    grad.set_function_name(iter.first);
    grad.set_gradient_func(iter.second);
    s = AddGradientDefHelper(grad, &added);
    if (!s.ok()) {
      Remove(funcs, funcs_with_grads);
      return s;
    }
    if (added) {
      funcs_with_grads.push_back(grad.function_name());
    }
  }
  return Status::OK();
}
|
1022 |
+
|
1023 |
+
// Merges every function and gradient from the `lib_def` proto into this
// library.  Transactional: on any error, everything this call added is
// removed and the library is left unchanged.
Status FunctionLibraryDefinition::AddLibrary(
    const FunctionDefLibrary& lib_def) {
  // Remember the funcs and grads that we added successfully so that
  // we can roll them back on error.
  std::vector<string> funcs;
  std::vector<string> funcs_with_grads;
  Status s;
  bool added;
  for (const FunctionDef& fdef : lib_def.function()) {
    s = AddFunctionDefHelper(fdef, &added);
    if (!s.ok()) {
      // Undo partial progress before reporting the error.
      Remove(funcs, funcs_with_grads);
      return s;
    }
    if (added) {
      funcs.push_back(fdef.signature().name());
    }
  }
  for (const GradientDef& grad : lib_def.gradient()) {
    s = AddGradientDefHelper(grad, &added);
    if (!s.ok()) {
      Remove(funcs, funcs_with_grads);
      return s;
    }
    if (added) {
      funcs_with_grads.push_back(grad.function_name());
    }
  }
  return Status::OK();
}
|
1053 |
+
|
1054 |
+
// Erases the function registered under `func`.
// DCHECK-fails in debug builds if the function is not present.
void FunctionLibraryDefinition::RemoveFunction(const string& func) {
  const auto& entry = function_defs_.find(func);
  DCHECK(entry != function_defs_.end());
  function_defs_.erase(entry);
}
|
1059 |
+
|
1060 |
+
// Erases the gradient mapping registered for `func`.
// DCHECK-fails in debug builds if no gradient is registered.
void FunctionLibraryDefinition::RemoveGradient(const string& func) {
  const auto& entry = func_grad_.find(func);
  DCHECK(entry != func_grad_.end());
  func_grad_.erase(entry);
}
|
1065 |
+
|
1066 |
+
void FunctionLibraryDefinition::Remove(
|
1067 |
+
const std::vector<string>& funcs,
|
1068 |
+
const std::vector<string>& funcs_with_grads) {
|
1069 |
+
for (const string& f : funcs) {
|
1070 |
+
RemoveFunction(f);
|
1071 |
+
}
|
1072 |
+
for (const string& f : funcs_with_grads) {
|
1073 |
+
RemoveGradient(f);
|
1074 |
+
}
|
1075 |
+
}
|
1076 |
+
|
1077 |
+
// Returns the gradient function's name for `func`, or "" if none is set.
string FunctionLibraryDefinition::FindGradient(const string& func) const {
  return gtl::FindWithDefault(func_grad_, func, "");
}
|
1080 |
+
|
1081 |
+
// OpRegistryInterface override: resolves `op` first against this library's
// functions, then falls back to the default op registry.
Status FunctionLibraryDefinition::LookUp(
    const string& op, const OpRegistrationData** op_reg_data) const {
  const auto it = function_defs_.find(op);
  if (it == function_defs_.end()) {
    // Not a library function; defer to the primitive-op registry.
    return default_registry_->LookUp(op, op_reg_data);
  }
  *op_reg_data = &it->second->op_registration_data;
  return Status::OK();
}
|
1090 |
+
|
1091 |
+
// Resolves the FunctionDef whose attrs apply to `ndef`:
//  - an ordinary function call resolves to the called function;
//  - a SymbolicGradient node resolves to the forward function's registered
//    gradient if one exists, otherwise to the forward function itself.
// Returns nullptr if nothing matches.
const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
    const NodeDef& ndef) const {
  if (ndef.op() != kGradientOp) {
    // If 'ndef' calls a function and the function's def has the attr,
    // returns it.
    return Find(ndef.op());
  }

  // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
  // Foo's attributes.
  const NameAttrList* forward_func_attrs;
  if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
    return nullptr;
  }
  const string& func_name = forward_func_attrs->name();
  const string& grad_name = FindGradient(func_name);
  // If 'func' has a user-defined gradient function, uses the grad
  // function's attrs to see if noinline is specified. Otherwise,
  // uses func's attrs.
  if (!grad_name.empty()) {
    return Find(grad_name);
  }
  return Find(func_name);
}
|
1115 |
+
|
1116 |
+
// Serializes the library (all functions plus all gradient mappings) back
// into a FunctionDefLibrary proto.
FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
  FunctionDefLibrary lib;
  for (const auto& entry : function_defs_) {
    *lib.add_function() = entry.second->fdef;
  }
  for (const auto& grad : func_grad_) {
    GradientDef* gdef = lib.add_gradient();
    gdef->set_function_name(grad.first);
    gdef->set_gradient_func(grad.second);
  }
  return lib;
}
|
1128 |
+
|
1129 |
+
// Looks up attr `attr` on the function (or its gradient) that `ndef`
// invokes — see GetAttrImpl for the resolution rules.  Returns
// InvalidArgument if no function is found or the attr is absent.
template <typename T>
Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
                                          const string& attr, T* value) const {
  const FunctionDef* fdef = GetAttrImpl(ndef);
  if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
    return Status::OK();
  }
  return errors::InvalidArgument("Attr ", attr, " is not defined.");
}
|
1138 |
+
|
1139 |
+
template <typename T>
|
1140 |
+
Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
|
1141 |
+
T* value) const {
|
1142 |
+
return GetAttr(node.def(), attr, value);
|
1143 |
+
}
|
1144 |
+
|
1145 |
+
// Explicitly instantiate the GetAttr templates for the attr value types
// callers actually use (string and bool), keeping the definitions in this
// translation unit.
#define GET_ATTR(T)                                          \
  template Status FunctionLibraryDefinition::GetAttr(const Node&,         \
                                                     const string&, T*) const; \
  template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,      \
                                                     const string&, T*) const;
GET_ATTR(string)
GET_ATTR(bool)
#undef GET_ATTR
|
1153 |
+
|
1154 |
+
// Initializes `proto` from a string spec: values of the form "$name"
// become attr-value placeholders named `name`; anything else is stored as
// a plain string attr value.
void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
  const bool is_placeholder = val.size() >= 2 && val[0] == '$';
  if (is_placeholder) {
    // Strip the leading '$'.
    proto.set_placeholder(val.data() + 1, val.size() - 1);
  } else {
    SetAttrValue(val, &proto);
  }
}
|
1161 |
+
|
1162 |
+
// Builds an AttrValue.func referencing function `name` with the given
// attribute bindings.
FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
    const string& name,
    gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
  AttrValueWrapper wrapper;
  auto* func = wrapper.proto.mutable_func();
  func->set_name(name);
  for (const auto& attr : attrs) {
    func->mutable_attr()->insert({attr.first, attr.second.proto});
  }
  return wrapper;
}
|
1172 |
+
|
1173 |
+
// Converts this helper Node into a NodeDef.  The first `ret` entry doubles
// as the node name; `dep` entries become control inputs ("^name").
NodeDef FunctionDefHelper::Node::ToNodeDef() const {
  NodeDef ndef;
  ndef.set_op(this->op);
  ndef.set_name(this->ret[0]);
  for (const auto& kv : this->attr) {
    ndef.mutable_attr()->insert({kv.first, kv.second.proto});
  }
  for (const string& input : this->arg) {
    ndef.add_input(input);
  }
  for (const string& ctrl : this->dep) {
    ndef.add_input(strings::StrCat("^", ctrl));
  }
  return ndef;
}
|
1188 |
+
|
1189 |
+
/* static */
|
1190 |
+
FunctionDef FunctionDefHelper::Create(
|
1191 |
+
const string& function_name, gtl::ArraySlice<string> in_def,
|
1192 |
+
gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
|
1193 |
+
gtl::ArraySlice<Node> node_def,
|
1194 |
+
gtl::ArraySlice<std::pair<string, string>> ret_def) {
|
1195 |
+
FunctionDef fdef;
|
1196 |
+
|
1197 |
+
// Signature
|
1198 |
+
OpDefBuilder b(function_name);
|
1199 |
+
for (const auto& i : in_def) b.Input(i);
|
1200 |
+
for (const auto& o : out_def) b.Output(o);
|
1201 |
+
for (const auto& a : attr_def) b.Attr(a);
|
1202 |
+
|
1203 |
+
OpRegistrationData op_reg_data;
|
1204 |
+
TF_CHECK_OK(b.Finalize(&op_reg_data));
|
1205 |
+
fdef.mutable_signature()->Swap(&op_reg_data.op_def);
|
1206 |
+
|
1207 |
+
// Function body
|
1208 |
+
for (const auto& n : node_def) {
|
1209 |
+
*(fdef.add_node_def()) = n.ToNodeDef();
|
1210 |
+
}
|
1211 |
+
|
1212 |
+
// Returns
|
1213 |
+
for (const auto& r : ret_def) {
|
1214 |
+
fdef.mutable_ret()->insert({r.first, r.second});
|
1215 |
+
}
|
1216 |
+
return fdef;
|
1217 |
+
}
|
1218 |
+
|
1219 |
+
/* static */
// Legacy builder using implicit output naming: each node's `ret` entries
// name its outputs, and a running `ret_index` maps those legacy names to
// real "node:output:index" endpoints so later nodes and the function's
// returns can reference them.  CHECK-fails on any unresolvable name or
// unregistered op.
FunctionDef FunctionDefHelper::Define(const string& name,
                                      gtl::ArraySlice<string> arg_def,
                                      gtl::ArraySlice<string> ret_def,
                                      gtl::ArraySlice<string> attr_def,
                                      gtl::ArraySlice<Node> node_def) {
  FunctionDef fdef;
  OpDefBuilder b(name);
  for (const auto& a : arg_def) b.Input(a);
  for (const auto& r : ret_def) b.Output(r);
  for (const auto& a : attr_def) b.Attr(a);

  OpRegistrationData op_reg_data;
  TF_CHECK_OK(b.Finalize(&op_reg_data));
  fdef.mutable_signature()->Swap(&op_reg_data.op_def);

  // Mapping from legacy output names to NodeDef outputs.
  // Function inputs map to themselves.
  std::unordered_map<string, string> ret_index;
  for (const auto& a : fdef.signature().input_arg()) {
    ret_index[a.name()] = a.name();
  }

  // For looking up OpDefs
  auto* op_def_registry = OpRegistry::Global();

  // Function body
  for (const auto& src : node_def) {
    NodeDef* n = fdef.add_node_def();
    n->set_op(src.op);
    n->set_name(src.ret[0]);  // first ret entry doubles as the node name
    for (const auto& a : src.attr) {
      n->mutable_attr()->insert({a.first, a.second.proto});
    }
    // Data inputs must already be known in ret_index (an earlier node's
    // output or a function input).
    for (const string& a : src.arg) {
      const auto iter = ret_index.find(a);
      CHECK(iter != ret_index.end()) << "Node input '" << a << "' in '"
                                     << src.ret[0] << "' of " << name;
      n->add_input(iter->second);
    }
    // Control dependencies.
    for (const string& d : src.dep) {
      n->add_input(strings::StrCat("^", d));
    }

    // Add the outputs of this node to ret_index.
    const OpDef* op_def = nullptr;
    TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
    CHECK(op_def != nullptr) << n->op();
    NameRangeMap output_names;
    TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
    for (const auto& o : output_names) {
      CHECK_LE(o.second.second, src.ret.size())
          << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
          << "' of " << name;
      // Map each legacy ret name to the "<node>:<output>:<offset>" endpoint.
      for (int i = o.second.first; i < o.second.second; ++i) {
        ret_index[src.ret[i]] =
            strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
      }
    }
  }

  // Returns: resolve each declared output through ret_index.
  for (const auto& r : fdef.signature().output_arg()) {
    const auto iter = ret_index.find(r.name());
    CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
    fdef.mutable_ret()->insert({r.name(), iter->second});
  }
  return fdef;
}
|
1287 |
+
|
1288 |
+
// Defines an anonymous function (the name is irrelevant; "_" is used).
FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
                                      gtl::ArraySlice<string> ret_def,
                                      gtl::ArraySlice<string> attr_def,
                                      gtl::ArraySlice<Node> node_def) {
  return Define("_", arg_def, ret_def, attr_def, node_def);
}
|
1294 |
+
|
1295 |
+
namespace gradient {
|
1296 |
+
|
1297 |
+
typedef std::unordered_map<string, Creator> OpGradFactory;
|
1298 |
+
|
1299 |
+
OpGradFactory* GetOpGradFactory() {
|
1300 |
+
static OpGradFactory* factory = new OpGradFactory;
|
1301 |
+
return factory;
|
1302 |
+
}
|
1303 |
+
|
1304 |
+
bool RegisterOp(const string& op, Creator func) {
|
1305 |
+
CHECK(GetOpGradFactory()->insert({op, func}).second)
|
1306 |
+
<< "Duplicated gradient for " << op;
|
1307 |
+
return true;
|
1308 |
+
}
|
1309 |
+
|
1310 |
+
Status GetOpGradientCreator(const string& op, Creator* creator) {
|
1311 |
+
auto fac = GetOpGradFactory();
|
1312 |
+
auto iter = fac->find(op);
|
1313 |
+
if (iter == fac->end()) {
|
1314 |
+
return errors::NotFound("No gradient defined for op: ", op);
|
1315 |
+
}
|
1316 |
+
*creator = iter->second;
|
1317 |
+
return Status::OK();
|
1318 |
+
}
|
1319 |
+
|
1320 |
+
} // end namespace gradient
|
1321 |
+
|
1322 |
+
} // end namespace tensorflow
|
function.h
ADDED
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_FUNCTION_H_
|
18 |
+
|
19 |
+
#include <vector>
|
20 |
+
#include "tensorflow/core/framework/attr_value.pb.h"
|
21 |
+
#include "tensorflow/core/framework/attr_value_util.h"
|
22 |
+
#include "tensorflow/core/framework/function.pb.h"
|
23 |
+
#include "tensorflow/core/framework/node_def_util.h"
|
24 |
+
#include "tensorflow/core/framework/op.h"
|
25 |
+
#include "tensorflow/core/framework/selective_registration.h"
|
26 |
+
#include "tensorflow/core/framework/types.h"
|
27 |
+
#include "tensorflow/core/lib/gtl/flatmap.h"
|
28 |
+
#include "tensorflow/core/lib/hash/hash.h"
|
29 |
+
#include "tensorflow/core/platform/env.h"
|
30 |
+
#include "tensorflow/core/platform/macros.h"
|
31 |
+
#include "tensorflow/core/platform/protobuf.h"
|
32 |
+
|
33 |
+
namespace tensorflow {
|
34 |
+
|
35 |
+
class CancellationManager;
|
36 |
+
class GraphDef;
|
37 |
+
class OpKernel;
|
38 |
+
class ResourceMgr;
|
39 |
+
class Rendezvous;
|
40 |
+
class ScopedStepContainer;
|
41 |
+
class StepStatsCollector;
|
42 |
+
class Node;
|
43 |
+
|
44 |
+
// FunctionDefHelper::Create is a convenient helper to construct a
|
45 |
+
// FunctionDef proto.
|
46 |
+
// E.g.,
|
47 |
+
// FunctionDef my_func = FunctionDefHelper::Create(
|
48 |
+
// "my_func_name",
|
49 |
+
// {"x:T", "y:T" /* one string per argument */},
|
50 |
+
// {"z:T" /* one string per return value */},
|
51 |
+
// {"T: {float, double}" /* one string per attribute */},
|
52 |
+
// {
|
53 |
+
// {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
|
54 |
+
// /* one entry per function node */
|
55 |
+
// },
|
56 |
+
// /* Mapping between function returns and function node outputs. */
|
57 |
+
// {{"z", "o:z"}});
|
58 |
+
//
|
59 |
+
// For the old Function::Node approach, use FunctionDefHelper::Define()
|
60 |
+
// E.g.,
|
61 |
+
// FunctionDef my_func = FunctionDefHelper::Define(
|
62 |
+
// "my_func_name",
|
63 |
+
// {"x:T", "y:T" /* one string per argument */},
|
64 |
+
// {"z:T" /* one string per return value */},
|
65 |
+
// {"T: {float, double}" /* one string per attribute */},
|
66 |
+
// {
|
67 |
+
// {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
|
68 |
+
// /* one entry per function node */
|
69 |
+
// });
|
70 |
+
// Static helpers for constructing FunctionDef protos from initializer
// lists — see the usage examples in the comment above.
class FunctionDefHelper {
 public:
  // AttrValueWrapper has copy constructors for the type T so that
  // it's easy to construct a simple AttrValue proto.
  //
  // If T is a string type (const char*, string, or StringPiece), and
  // it starts with "$", we construct a AttrValue of "placeholder".
  //
  // E.g.,
  //   std::<string, AttrValueWrapper> x = {"T", "$T"}
  // is a named attr value placeholder.
  struct AttrValueWrapper {
    AttrValue proto;

    AttrValueWrapper() {}

    // Implicit by design: allows brace-init like {"T", DT_FLOAT}.
    template <typename T>
    AttrValueWrapper(T val) {  // NOLINT(runtime/explicit)
      SetAttrValue(val, &proto);
    }

   private:
    // Shared body for the string-typed constructor specializations below;
    // handles the "$placeholder" convention.
    void InitFromString(StringPiece val);
  };

  // Constructs an AttrValue.func given the "name" and "attrs".
  static AttrValueWrapper FunctionRef(
      const string& name,
      gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
  static AttrValueWrapper FunctionRef(const string& name) {
    return FunctionRef(name, {});
  }

  // Node is used to construct FunctionDef.Node using initialization
  // lists. E.g.,
  //  Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}};  // z = x * y
  struct Node {
    // When constructing a NodeDef, the first entry in ret is used as
    // the node name, the remaining values are ignored.
    std::vector<string> ret;
    string op;
    std::vector<string> arg;
    std::vector<std::pair<string, AttrValueWrapper>> attr;
    std::vector<string> dep;  // control dependencies ("^name" inputs)

    NodeDef ToNodeDef() const;
  };

  // The Create() function uses the new NodeDef field. `ret_def`
  // holds a mapping from the function output names from `out_def` to
  // the node outputs from `node_def`.
  static FunctionDef Create(const string& function_name,
                            gtl::ArraySlice<string> in_def,
                            gtl::ArraySlice<string> out_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def,
                            gtl::ArraySlice<std::pair<string, string>> ret_def);

  // The two Define() functions use the old FunctionDef::Node field.
  // TODO(josh11b): Get rid of these and transition to the one above.
  static FunctionDef Define(const string& function_name,
                            gtl::ArraySlice<string> arg_def,
                            gtl::ArraySlice<string> ret_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def);

  // Defines an anonymous function. I.e., its name is not relevant.
  static FunctionDef Define(gtl::ArraySlice<string> arg_def,
                            gtl::ArraySlice<string> ret_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def);

  // Helpers to construct a constant scalar.
  template <typename T>
  static Node Const(const string& name, const T& val) {
    Node n = {{name}, "Const"};
    const DataType dtype = DataTypeToEnum<T>::value;
    n.attr.push_back({"dtype", dtype});
    Tensor t(dtype, TensorShape({}));
    t.scalar<T>()() = val;
    n.attr.push_back({"value", t});
    return n;
  }

  // Helper to construct a rank-1 constant from `vals`.
  template <typename T>
  static Node Const(const string& name, gtl::ArraySlice<T> vals) {
    Node n = {{name}, "Const"};
    const DataType dtype = DataTypeToEnum<T>::value;
    n.attr.push_back({"dtype", dtype});
    int64 num = vals.size();
    Tensor t(dtype, TensorShape({num}));
    for (size_t i = 0; i < vals.size(); ++i) {
      t.flat<T>()(i) = vals[i];
    }
    n.attr.push_back({"value", t});
    return n;
  }
};
|
168 |
+
|
169 |
+
// String-typed constructor specializations: all route through
// InitFromString so that values beginning with '$' become placeholders.
template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
  InitFromString(val);
}

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
    const string& val) {
  InitFromString(val);
}

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
  InitFromString(val);
}
|
184 |
+
|
185 |
+
// Instantiate a function.
|
186 |
+
//
|
187 |
+
// "fdef" encodes a TF function with some attrs in fdef.signature.attr
|
188 |
+
// containing placeholders. InstantiateFunction binds these
|
189 |
+
// placeholders and produces an instantiated function encoded in
|
190 |
+
// "result.gdef". The value to substitute a placeholder is given by
|
191 |
+
// "attr_values", which is a map from a placeholder name to an attr
|
192 |
+
// value.
|
193 |
+
//
|
194 |
+
// InstantiateFunction calls "get_function" to find signatures of other
|
195 |
+
// functions and primitive ops.
|
196 |
+
|
197 |
+
// GetFunctionSignature(func name, opdef) returns OK if the func name is found
|
198 |
+
// and opdef is filled with a pointer to the corresponding signature
|
199 |
+
// (a OpDef proto). Otherwise, returns an error.
|
200 |
+
typedef std::function<Status(const string&, const OpDef**)>
|
201 |
+
GetFunctionSignature;
|
202 |
+
|
203 |
+
// Output of InstantiateFunction: the instantiated argument/return dtypes
// and the expanded body as a flat list of NodeDefs.
struct InstantiationResult {
  DataTypeVector arg_types;
  DataTypeVector ret_types;
  std::vector<NodeDef> nodes;
};
|
208 |
+
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
209 |
+
GetFunctionSignature get_function,
|
210 |
+
InstantiationResult* result);
|
211 |
+
|
212 |
+
// Returns a debug string for a function definition.
|
213 |
+
//
|
214 |
+
// The returned text is multiple-line. It is intended to be
|
215 |
+
// human-readable rather than being friendly to parsers. It is _NOT_
|
216 |
+
// intended to be the canonical string representation of "func_def".
|
217 |
+
// Particularly, it may not include all information presented in
|
218 |
+
// "func_def" (e.g., comments, description of the function arguments,
|
219 |
+
// etc.)
|
220 |
+
string DebugString(const FunctionDef& func_def);
|
221 |
+
string DebugString(const GraphDef& instantiated_func_def);
|
222 |
+
string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
|
223 |
+
|
224 |
+
// Returns a debug string for a top level graph (the main program and
|
225 |
+
// its supporting functions defined in its library).
|
226 |
+
string DebugStringWhole(const GraphDef& gdef);
|
227 |
+
|
228 |
+
// Returns true if f1 == f2. Compares all fields, including descriptions. Order
|
229 |
+
// of NodeDefs doesn't matter.
|
230 |
+
bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
|
231 |
+
|
232 |
+
// Return a hash of `fdef` that is consistent with FunctionDefsEqual method.
|
233 |
+
// In other words, if two fdefs compare equal, their hash values will be the
|
234 |
+
// same.
|
235 |
+
uint64 FunctionDefHash(const FunctionDef& fdef);
|
236 |
+
|
237 |
+
// Returns a canonicalized string for the instantiation of the
|
238 |
+
// function of the given "name" and attributes "attrs".
|
239 |
+
//
|
240 |
+
// The returned string is guaranteed to be stable within one address
|
241 |
+
// space. But it may be change as the implementation
|
242 |
+
// evolves. Therefore, it should not be persisted or compared across
|
243 |
+
// address spaces.
|
244 |
+
string Canonicalize(const string& funcname, AttrSlice attrs);
|
245 |
+
|
246 |
+
// Interface through which a function call exchanges data with its caller:
// the callee reads its inputs via GetArg() and delivers its outputs via
// SetRetval().
class CallFrameInterface {
 public:
  virtual ~CallFrameInterface() {}

  // Number of function arguments / return values held by this frame.
  virtual size_t num_args() const = 0;
  virtual size_t num_retvals() const = 0;

  // Retrieves the index-th argument into *val.
  virtual Status GetArg(int index, Tensor* val) const = 0;
  // Records the index-th return value of the function.
  virtual Status SetRetval(int index, const Tensor& val) = 0;
};
|
256 |
+
|
257 |
+
// Represents a function call frame. I.e., the data structure used to
// pass arguments to a function and retrieve its results.
//
// Runtime must arrange accesses to one FunctionCallFrame s.t.
//   1. SetArgs() happens before any GetArg();
//   2. GetRetvals happens after all SetRetval();
class FunctionCallFrame : public CallFrameInterface {
 public:
  FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
  ~FunctionCallFrame();

  // Caller methods.
  Status SetArgs(gtl::ArraySlice<Tensor> args);
  Status GetRetvals(std::vector<Tensor>* rets) const;
  Status ConsumeRetvals(std::vector<Tensor>* rets);

  size_t num_args() const override { return arg_types_.size(); }
  size_t num_retvals() const override { return ret_types_.size(); }

  // Callee methods.
  Status GetArg(int index, Tensor* val) const override;
  Status SetRetval(int index, const Tensor& val) override;

 private:
  DataTypeVector arg_types_;  // Expected type of each argument.
  DataTypeVector ret_types_;  // Expected type of each return value.
  gtl::InlinedVector<Tensor, 4> args_;
  // One return-value slot; has_val records whether SetRetval() filled it.
  struct Retval {
    bool has_val = false;
    Tensor val;
  };
  gtl::InlinedVector<Retval, 4> rets_;

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
};
|
292 |
+
|
293 |
+
// Helper to maintain a map between function names in a given
// FunctionDefLibrary and function definitions.
class FunctionLibraryDefinition : public OpRegistryInterface {
 public:
  explicit FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
  FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
                            const FunctionDefLibrary& lib_def);
  ~FunctionLibraryDefinition() override;

  // Copy-assignment is disallowed; use the copy constructor instead.
  FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
      delete;

  // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
  // returns its definition proto.
  const FunctionDef* Find(const string& func) const;

  // Adds function definition 'fdef' to this function library.
  // Returns status 'ok' on success, or error otherwise. This is a no-op if
  // 'fdef' already exists in this function library.
  // If 'fdef' is successfully added to the library, it will be accessible
  // from 'LookUp' and included in the proto returned by 'ToProto'.
  // This operation is atomic.
  Status AddFunctionDef(const FunctionDef& fdef);

  // Adds gradient definition 'grad' to this function library.
  // This is a no-op if 'grad' already exists in this function library.
  // If 'grad' is successfully added, it will be accessible via 'FindGradient'
  // and included in the proto returned by 'ToProto'.
  // This operation is atomic.
  Status AddGradientDef(const GradientDef& grad);

  // Adds the functions and gradients in 'other' to this function library.
  // Duplicate functions and gradients are ignored.
  // This operation is atomic.
  Status AddLibrary(const FunctionLibraryDefinition& other);

  // Adds the functions and gradients in 'lib_def' to this function library.
  // Duplicate functions and gradients are ignored.
  // This operation is atomic.
  Status AddLibrary(const FunctionDefLibrary& lib_def);

  // If the gradient function for 'func' is specified explicitly in
  // the library, returns the gradient function name. Otherwise,
  // returns an empty string.
  string FindGradient(const string& func) const;

  // OpRegistryInterface method. Useful for constructing a Graph.
  //
  // If "op" is defined in the library, returns its signature.
  // Otherwise, assume "op" is a primitive op and returns its op
  // signature and shape inference function.
  Status LookUp(const string& op_type_name,
                const OpRegistrationData** op_reg_data) const override;

  // Op name used to represent symbolic gradients, and the attr under
  // which the wrapped function is stored on such nodes.
  static constexpr const char* const kGradientOp = "SymbolicGradient";
  static constexpr const char* const kFuncAttr = "f";

  // Given a node def 'ndef', inspects attributes of the callee
  // function to derive the attribute 'value' for 'attr'. Returns OK
  // iff the attribute is given by the function's definition.
  // TODO(irving): Remove; keep only the const Node& version.
  template <typename T>
  Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const;

  // Given a node, inspects attributes of the callee function to derive the
  // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
  // function's definition.
  template <typename T>
  Status GetAttr(const Node& node, const string& attr, T* value) const;

  // Returns a proto representation of the state of this function library.
  FunctionDefLibrary ToProto() const;

  // Registry consulted for ops that are not defined in this library.
  const OpRegistryInterface* default_registry() const {
    return default_registry_;
  }

 private:
  // Shape inference for functions is handled separately by ShapeRefiner.

  // A function definition paired with the op registration data derived
  // from its signature.
  struct FunctionDefAndOpRegistration {
    FunctionDefAndOpRegistration(const FunctionDef& fdef_in);

    FunctionDef fdef;
    OpRegistrationData op_registration_data;
  };

  // Same as AddFunctionDef/AddGradientDef except these methods set
  // `added` to true if the `fdef`/`grad` were actually added to this.
  Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added);
  Status AddGradientDefHelper(const GradientDef& grad, bool* added);

  const OpRegistryInterface* const default_registry_;
  gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>>
      function_defs_;
  gtl::FlatMap<string, string> func_grad_;  // function name -> gradient name.

  // Helper function for GetAttr. Returns the FunctionDef* to get the
  // attr from.
  const FunctionDef* GetAttrImpl(const NodeDef& ndef) const;

  // Remove function `func` from the library. `func` must be in the library.
  void RemoveFunction(const string& func);

  // Remove gradient of function `func` from the library. `func` must have
  // a gradient.
  void RemoveGradient(const string& func);

  // Remove all functions in `funcs` and all gradients of
  // functions in `funcs_with_grads` from this library.
  void Remove(const std::vector<string>& funcs,
              const std::vector<string>& funcs_with_grads);
};
|
406 |
+
|
407 |
+
// Forward declare. Defined in common_runtime/function.h
struct FunctionBody;

// Forward declare. Defined in common_runtime/device.h
class Device;

// Runtime interface for instantiating and running functions within a
// single process/device.
class FunctionLibraryRuntime {
 public:
  virtual ~FunctionLibraryRuntime() {}

  // Instantiate a function with the given "attrs".
  //
  // Returns OK and fills in "handle" if the instantiation succeeds.
  // Otherwise returns an error and "handle" is undefined.
  typedef uint64 Handle;
  virtual Status Instantiate(const string& function_name, AttrSlice attrs,
                             Handle* handle) = 0;

  // Releases state associated with the handle.
  virtual Status ReleaseHandle(Handle handle) = 0;

  // Returns the function body for the instantiated function given its
  // handle 'h'. Returns nullptr if "h" is not found.
  //
  // *this keeps the ownership of the returned object, which remains alive
  // as long as *this.
  virtual const FunctionBody* GetFunctionBody(Handle h) = 0;

  // Asynchronously invokes the instantiated function identified by
  // "handle".
  //
  // If function execution succeeds, "done" is called with OK and
  // "*rets" is filled with the function's return values. Otherwise,
  // "done" is called with an error status.
  //
  // Does not take ownership of "rets".
  // In the cross-process scenario, runner isn't used for making the Async
  // RPC calls.
  struct Options {
    // The id of the step that is calling this function.
    int64 step_id = 0;
    Rendezvous* rendezvous = nullptr;
    CancellationManager* cancellation_manager = nullptr;
    ScopedStepContainer* step_container = nullptr;
    StepStatsCollector* stats_collector = nullptr;

    // If set, used to schedule closures that execute the function's work.
    std::function<void(std::function<void()>)>* runner = nullptr;

    // Parameters for remote function execution.
    bool remote_execution = false;
    string source_device = "";  // Fully specified device name.

    // Allocator attributes specifying where the args are / rets should be put.
    // These should either be {} or match the length of args / retvals. If {},
    // the default allocator attributes will be assumed for all args / retvals.
    std::vector<AllocatorAttributes> args_alloc_attrs;
    std::vector<AllocatorAttributes> rets_alloc_attrs;

    // If true, we create a new IntraProcessRendezvous, else use the existing
    // one.
    bool create_rendezvous = false;
  };
  typedef std::function<void(const Status&)> DoneCallback;
  virtual void Run(const Options& opts, Handle handle,
                   gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                   DoneCallback done) = 0;
  // Overload that exchanges args/rets through a caller-provided call frame.
  virtual void Run(const Options& opts, Handle handle,
                   CallFrameInterface* call_frame, DoneCallback done) = 0;

  // Creates a "kernel" for the given node def "ndef".
  //
  // If succeeds, returns OK and the caller takes the ownership of the
  // returned "*kernel". Otherwise, returns an error.
  virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0;

  // Returns true iff 'function_name' names a stateful function.
  virtual bool IsStateful(const string& function_name) = 0;

  // Returns the device on which the function executes.
  virtual Device* device() = 0;

  // Returns the function library definition that backs this runtime.
  virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const = 0;

  // Returns the environment on which the function executes.
  virtual Env* env() = 0;

  // Returns a debug string showing the definition of the function of
  // 'handle'.
  virtual string DebugString(Handle handle) = 0;

  // Returns the graph version number.
  virtual int graph_def_version() = 0;

  typedef uint64 LocalHandle;
};
|
504 |
+
|
505 |
+
// Sentinel values marking a handle that does not refer to any
// instantiated function. Note: Handle/LocalHandle are uint64, so the -1
// initializer wraps to the maximum representable value.
const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;

// Factory hook allowing callers to customize how an OpKernel is created
// for a given NodeDef.
typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&,
                             std::unique_ptr<OpKernel>*)>
    CustomKernelCreator;
|
510 |
+
|
511 |
+
// Used to instantiate and run functions in a distributed system.
class DistributedFunctionLibraryRuntime {
 public:
  virtual ~DistributedFunctionLibraryRuntime() {}

  // The _target attr in attrs determines where the function is instantiated.
  virtual Status Instantiate(const string& function_name,
                             const FunctionLibraryDefinition& lib_def,
                             AttrSlice attrs,
                             FunctionLibraryRuntime::LocalHandle* handle) = 0;

  // Runs the previously instantiated function identified by `handle`.
  // opts.runner isn't used for execution.
  virtual void Run(const FunctionLibraryRuntime::Options& opts,
                   FunctionLibraryRuntime::LocalHandle handle,
                   gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                   FunctionLibraryRuntime::DoneCallback done) = 0;
};
|
528 |
+
|
529 |
+
// Extracts the actual type from "attrs" based on the arg's definition
// "arg_def".
//
// If "arg_def" is a N*T type, *is_type_list is set to false, and
// *dtypes is set to be a vector of size N and each element is T.
//
// If "arg_def" is a list(type), *is_type_list is set to true, and
// *dtypes is set to be a vector of types specified in attrs for
// arg_def.
//
// Otherwise (arg_def is a simple type T), *is_type_list is set to
// false, and *dtypes is set to a single element vector, whose only
// element is T.
Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
                  bool* is_type_list, DataTypeVector* dtypes);
|
544 |
+
|
545 |
+
// To register a gradient function for a builtin op, one should use
|
546 |
+
// REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
|
547 |
+
//
|
548 |
+
// Typically, the c++ grad factory is a plain function that can be
|
549 |
+
// converted into ::tensorflow::gradient::Creator, which is
|
550 |
+
// std::function<Status(const AttrSlice&, FunctionDef*)>.
|
551 |
+
//
|
552 |
+
// A ::tensorflow::gradient::Creator should populate in FunctionDef* with a
|
553 |
+
// definition of a brain function that computes the gradient for the
|
554 |
+
// <op_name> when the <op_name> is instantiated with the given attrs.
|
555 |
+
//
|
556 |
+
// E.g.,
|
557 |
+
//
|
558 |
+
// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
|
559 |
+
// bool transpose_a;
|
560 |
+
// TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
|
561 |
+
// bool transpose_b;
|
562 |
+
// TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
|
563 |
+
// DataType dtype;
|
564 |
+
// TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
|
565 |
+
// if (!transpose_a && !transpose_b) {
|
566 |
+
// *g = FunctionDefHelper::Define(
|
567 |
+
// "MatMulGrad",
|
568 |
+
// {"x:T ", "y:T", "dz:T"}, // Inputs to this function
|
569 |
+
// {"dx:T", "dy:T"}, // Outputs from this function
|
570 |
+
// {"T: {float, double}"}, // Attributes needed by this function
|
571 |
+
// {
|
572 |
+
// {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
|
573 |
+
// {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
|
574 |
+
// {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
|
575 |
+
// {{"dy"}, "MatMul", {"x_", "dz"}, {{"T", "$T"}}},
|
576 |
+
// });
|
577 |
+
// } else {
|
578 |
+
// ... ...
|
579 |
+
// }
|
580 |
+
// return Status::OK();
|
581 |
+
// }
|
582 |
+
//
|
583 |
+
// NOTE: $T is substituted with the type variable "T" when the
|
584 |
+
// gradient function MatMul is instantiated.
|
585 |
+
//
|
586 |
+
// TODO(zhifengc): Better documentation somewhere.
|
587 |
+
|
588 |
+
// Macros to define a gradient function factory for a primitive
// operation.
#define REGISTER_OP_GRADIENT(name, fn) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)

// Registers that "name" deliberately has no gradient (nullptr creator).
#define REGISTER_OP_NO_GRADIENT(name) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)

// Extra level of indirection so __COUNTER__ is expanded before pasting.
#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
  REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)

#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)                 \
  static bool unused_grad_##ctr = SHOULD_REGISTER_OP_GRADIENT && \
                                  ::tensorflow::gradient::RegisterOp(name, fn)
|
602 |
+
|
603 |
+
namespace gradient {
|
604 |
+
// Register a gradient creator for the "op".
|
605 |
+
typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
|
606 |
+
bool RegisterOp(const string& op, Creator func);
|
607 |
+
|
608 |
+
// Returns OK the gradient creator for the "op" is found (may be
|
609 |
+
// nullptr if REGISTER_OP_NO_GRADIENT is used.
|
610 |
+
Status GetOpGradientCreator(const string& op, Creator* creator);
|
611 |
+
};
|
612 |
+
|
613 |
+
// Declare explicit instantiations of GetAttr for the types below; the
// `extern template` form means the definitions are provided in exactly
// one translation unit elsewhere.
#define GET_ATTR(T)                                          \
  extern template Status FunctionLibraryDefinition::GetAttr( \
      const Node&, const string&, T*) const;                 \
  extern template Status FunctionLibraryDefinition::GetAttr( \
      const NodeDef&, const string&, T*) const;
GET_ATTR(string)
GET_ATTR(bool)
#undef GET_ATTR
|
622 |
+
|
623 |
+
} // end namespace tensorflow
|
624 |
+
|
625 |
+
#endif // TENSORFLOW_FRAMEWORK_FUNCTION_H_
|
function.proto
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
syntax = "proto3";
|
2 |
+
|
3 |
+
package tensorflow;
|
4 |
+
option cc_enable_arenas = true;
|
5 |
+
option java_outer_classname = "FunctionProtos";
|
6 |
+
option java_multiple_files = true;
|
7 |
+
option java_package = "org.tensorflow.framework";
|
8 |
+
|
9 |
+
import "tensorflow/core/framework/attr_value.proto";
|
10 |
+
import "tensorflow/core/framework/node_def.proto";
|
11 |
+
import "tensorflow/core/framework/op_def.proto";
|
12 |
+
|
13 |
+
// A library is a set of named functions.
message FunctionDefLibrary {
  repeated FunctionDef function = 1;  // Function definitions.
  repeated GradientDef gradient = 2;  // Gradient mappings for the functions.
}
|
18 |
+
|
19 |
+
// A function can be instantiated when the runtime can bind every attr
// with a value. When a GraphDef has a call to a function, it must
// have binding for every attr defined in the signature.
//
// TODO(zhifengc):
//   * device spec, etc.
message FunctionDef {
  // The definition of the function's name, arguments, return values,
  // attrs etc.
  OpDef signature = 1;

  // Attributes specific to this function definition.
  map<string, AttrValue> attr = 5;

  // NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21.

  // In both of the following fields, there is the need to specify an
  // output that is used as either the input to another node (in
  // `node_def`) or as a return value of the function (in `ret`).
  // Unlike the NodeDefs in GraphDef, we need to be able to specify a
  // list in some cases (instead of just single outputs). Also, we
  // need to be able to deal with lists of unknown length (so the
  // output index may not be known at function definition time). So
  // we use the following format instead:
  // * "fun_in" where "fun_in" is the name of a function input arg in
  //   the `signature` field above. This represents that input, whether
  //   it is a single tensor or a list.
  // * "fun_in:0" gives the first element of a function input arg (a
  //   non-list input is considered a list of length 1 for these
  //   purposes).
  // * "node:out" where "node" is the name of a node in `node_def` and
  //   "out" is the name one of its op's output arguments (the name
  //   comes from the OpDef of the node's op). This represents that
  //   node's output, whether it is a single tensor or a list.
  //   Note: We enforce that an op's output arguments are never
  //   renamed in the backwards-compatibility test.
  // * "node:out:0" gives the first element of a node output arg (a
  //   non-list output is considered a list of length 1 for these
  //   purposes).
  //
  // NOT CURRENTLY SUPPORTED (but may be in the future):
  // * "node:out:-1" gives last element in a node output list
  // * "node:out:1:" gives a list with all but the first element in a
  //   node output list
  // * "node:out::-1" gives a list with all but the last element in a
  //   node output list

  // The body of the function. Unlike the NodeDefs in a GraphDef, attrs
  // may have values of type `placeholder` and the `input` field uses
  // the "output" format above.

  // By convention, "op" in node_def is resolved by consulting with a
  // user-defined library first. If not resolved, "func" is assumed to
  // be a builtin op.
  repeated NodeDef node_def = 3;

  // A mapping from the output arg names from `signature` to the
  // outputs from `node_def` that should be returned by the function.
  map<string, string> ret = 4;
}
|
79 |
+
|
80 |
+
// GradientDef defines the gradient function of a function defined in
// a function library.
//
// A gradient function g (specified by gradient_func) for a function f
// (specified by function_name) must follow the following:
//
// The function 'f' must be a numerical function which takes N inputs
// and produces M outputs. Its gradient function 'g', which is a
// function taking N + M inputs and produces N outputs.
//
// I.e. if we have
//    (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// then, g is
//    (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
//                                       dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
// loss function). dL/dx_i is the partial derivative of L with respect
// to x_i.
message GradientDef {
  string function_name = 1;  // The function name.
  string gradient_func = 2;  // The gradient function's name.
}
|
function_test.cc
ADDED
@@ -0,0 +1,1339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/function.h"
|
17 |
+
#include <vector>
|
18 |
+
#include "tensorflow/core/framework/function.pb.h"
|
19 |
+
#include "tensorflow/core/framework/function_testlib.h"
|
20 |
+
#include "tensorflow/core/framework/op.h"
|
21 |
+
#include "tensorflow/core/framework/tensor_testutil.h"
|
22 |
+
#include "tensorflow/core/kernels/ops_util.h"
|
23 |
+
#include "tensorflow/core/lib/core/status_test_util.h"
|
24 |
+
#include "tensorflow/core/lib/gtl/array_slice.h"
|
25 |
+
#include "tensorflow/core/lib/strings/str_util.h"
|
26 |
+
#include "tensorflow/core/lib/strings/strcat.h"
|
27 |
+
#include "tensorflow/core/platform/test.h"
|
28 |
+
#include "tensorflow/core/platform/types.h"
|
29 |
+
|
30 |
+
namespace tensorflow {
|
31 |
+
namespace {
|
32 |
+
|
33 |
+
// A helper class to make AttrSlice from initializer lists
class Attrs {
 public:
  Attrs(const std::initializer_list<  // NOLINT(runtime/explicit)
        std::pair<string, FunctionDefHelper::AttrValueWrapper>>
            attrs) {
    // Copy each (name, value) pair into the backing AttrValueMap.
    for (const auto& aval : attrs) {
      map_.insert({aval.first, aval.second.proto});
    }
  }

  operator AttrSlice() { return AttrSlice(&map_); }  // NOLINT(runtime/explicit)

 private:
  AttrValueMap map_;  // Owns the values; the AttrSlice only borrows it.
};
|
49 |
+
|
50 |
+
typedef FunctionDefHelper FDH;
|
51 |
+
|
52 |
+
// Resolves "op" to its OpDef signature using the process-global registry.
Status GetOpSig(const string& op, const OpDef** sig) {
  auto* const global_registry = OpRegistry::Global();
  return global_registry->LookUpOpDef(op, sig);
}
|
55 |
+
|
56 |
+
// Op registered for use by the tests below; its kernel (if any) is not
// needed here since only the signature is exercised.
REGISTER_OP("One")
    .Output("y: T")
    .Attr("T: {float, double, int32, int64}")
    .Doc(R"doc(
Returns a tensor with a single element (1) of type T.

y: A scalar in type T.

)doc");
|
65 |
+
|
66 |
+
TEST(TFunc, SquarePlusOne) {
  auto fdef = FDH::Create(
      // Name
      "SquarePlusOne",
      // Inputs
      {"x: T"},
      // Outputs
      {"y: T"},
      // Attrs
      {"T: {float, double, int32, int64}"},
      // Nodes
      {// a = Square<T>(x)
       {{"a"}, "Square", {"x"}, {{"T", "$T"}}},
       // o = One<T>()
       // NOTE: We can also have a Cast<Tin, Tout>(x) instead.
       {{"o"}, "One", {}, {{"T", "$T"}}},
       // y = Add<T>(a, o)
       {{"y"}, "Add", {"a:y", "o:y"}, {{"T", "$T"}}}},
      // Returns
      {{"y", "y:z:0"}});

  // Golden debug string for the uninstantiated function definition.
  const char* e = R"P(
SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
  a = Square[T=$T](x)
  o = One[T=$T]()
  y = Add[T=$T](a:y, o:y)
  return y = y:z:0
}
)P";
  EXPECT_EQ(DebugString(fdef), e);

  // Instantiate one with T=float
  InstantiationResult result;
  TF_ASSERT_OK(
      InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
  // Golden debug string after instantiation: attr placeholders are bound
  // and inputs use plain node names.
  const char* e2 = R"P(
(x:float) -> (y:float) {
  a = Square[T=float](x)
  o = One[T=float]()
  y = Add[T=float](a, o)
}
)P";
  EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
  EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
  EXPECT_EQ(DebugString(result.nodes), e2);
}
|
112 |
+
|
113 |
+
TEST(TFunc, ControlDep) {
|
114 |
+
auto fdef = FDH::Create(
|
115 |
+
// Name
|
116 |
+
"ControlDep",
|
117 |
+
// Inputs
|
118 |
+
{"x: int32"},
|
119 |
+
// Outputs
|
120 |
+
{"y: int32"},
|
121 |
+
// Attrs
|
122 |
+
{},
|
123 |
+
// Nodes
|
124 |
+
{// a = Identity<int32>(x)
|
125 |
+
{{"a"}, "Identity", {"x"}, {{"T", DT_INT32}}},
|
126 |
+
// o = NoOp(^a)
|
127 |
+
{{"o"}, "NoOp", {"^a"}, {}},
|
128 |
+
// y = Identity<int32>(a, ^o)
|
129 |
+
{{"y"}, "Identity", {"a:output:0", "^o"}, {{"T", DT_INT32}}}},
|
130 |
+
// Returns
|
131 |
+
{{"y", "y:output:0"}});
|
132 |
+
|
133 |
+
const char* e = R"P(
|
134 |
+
ControlDep(x:int32) -> (y:int32) {
|
135 |
+
a = Identity[T=int32](x)
|
136 |
+
o = NoOp() @ a
|
137 |
+
y = Identity[T=int32](a:output:0) @ o
|
138 |
+
return y = y:output:0
|
139 |
+
}
|
140 |
+
)P";
|
141 |
+
EXPECT_EQ(DebugString(fdef), e);
|
142 |
+
|
143 |
+
// Instantiate one with T=float
|
144 |
+
InstantiationResult result;
|
145 |
+
TF_ASSERT_OK(
|
146 |
+
InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
|
147 |
+
const char* e2 = R"P(
|
148 |
+
(x:int32) -> (y:int32) {
|
149 |
+
a = Identity[T=int32](x)
|
150 |
+
o = NoOp() @ a
|
151 |
+
y = Identity[T=int32](a) @ o
|
152 |
+
}
|
153 |
+
)P";
|
154 |
+
EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32}));
|
155 |
+
EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32}));
|
156 |
+
EXPECT_EQ(DebugString(result.nodes), e2);
|
157 |
+
}
|
158 |
+
|
159 |
+
REGISTER_OP("HasDefaultType")
|
160 |
+
.Output("out: T")
|
161 |
+
.Attr("T: {float, double, int32, int64} = DT_FLOAT");
|
162 |
+
|
163 |
+
// This verifies that a function using an op before a type attr (with
|
164 |
+
// a default) is added, still works. This is important for backwards
|
165 |
+
// compatibility.
|
166 |
+
TEST(TFunc, MissingTypeAttr) {
|
167 |
+
auto fdef = FDH::Create(
|
168 |
+
// Name
|
169 |
+
"BackCompat",
|
170 |
+
// Args
|
171 |
+
{},
|
172 |
+
// Return values
|
173 |
+
{"y: float"},
|
174 |
+
// Attrs
|
175 |
+
{},
|
176 |
+
// Nodes
|
177 |
+
{// y = HasDefaultType(x), T missing, defaults to float
|
178 |
+
{{"a"}, "HasDefaultType", {}, {}}},
|
179 |
+
// Returns
|
180 |
+
{{"y", "a:out:0"}});
|
181 |
+
|
182 |
+
const char* e = R"P(
|
183 |
+
BackCompat() -> (y:float) {
|
184 |
+
a = HasDefaultType()
|
185 |
+
return y = a:out:0
|
186 |
+
}
|
187 |
+
)P";
|
188 |
+
EXPECT_EQ(DebugString(fdef), e);
|
189 |
+
|
190 |
+
InstantiationResult result;
|
191 |
+
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
|
192 |
+
// Should get T=float from Op's default.
|
193 |
+
const char* e2 = R"P(
|
194 |
+
() -> (a:float) {
|
195 |
+
a = HasDefaultType[T=float]()
|
196 |
+
}
|
197 |
+
)P";
|
198 |
+
EXPECT_EQ(result.arg_types, DataTypeVector());
|
199 |
+
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
200 |
+
EXPECT_EQ(DebugString(result.nodes), e2);
|
201 |
+
}
|
202 |
+
|
203 |
+
TEST(TFunc, NTimesT) {
|
204 |
+
auto fdef = FDH::Create(
|
205 |
+
// Name
|
206 |
+
"NTimesT",
|
207 |
+
// Inputs
|
208 |
+
{"x: float", "y: float"},
|
209 |
+
// Outputs
|
210 |
+
{"z: float"},
|
211 |
+
// Attrs
|
212 |
+
{},
|
213 |
+
// Nodes
|
214 |
+
{// a = AddN<N=2>(x, y)
|
215 |
+
{{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 2}}}},
|
216 |
+
// Returns
|
217 |
+
{{"z", "a:sum:0"}});
|
218 |
+
|
219 |
+
const char* e = R"P(
|
220 |
+
NTimesT(x:float, y:float) -> (z:float) {
|
221 |
+
a = AddN[N=2, T=float](x, y)
|
222 |
+
return z = a:sum:0
|
223 |
+
}
|
224 |
+
)P";
|
225 |
+
EXPECT_EQ(DebugString(fdef), e);
|
226 |
+
|
227 |
+
InstantiationResult result;
|
228 |
+
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
|
229 |
+
const char* e2 = R"P(
|
230 |
+
(x:float, y:float) -> (a:float) {
|
231 |
+
a = AddN[N=2, T=float](x, y)
|
232 |
+
}
|
233 |
+
)P";
|
234 |
+
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT}));
|
235 |
+
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
236 |
+
EXPECT_EQ(DebugString(result.nodes), e2);
|
237 |
+
}
|
238 |
+
|
239 |
+
// NOTE: This is the simplest Map op. It takes a f:T->U.
|
240 |
+
REGISTER_OP("Map")
|
241 |
+
.Input("x: N * T")
|
242 |
+
.Output("y: N * U")
|
243 |
+
.Attr("T: type")
|
244 |
+
.Attr("U: type")
|
245 |
+
.Attr("N: int >= 1")
|
246 |
+
// .Attr("func: func_name_with_attr")
|
247 |
+
.Doc(R"doc(
|
248 |
+
Applies the 'func' on every input. I.e.,
|
249 |
+
|
250 |
+
y[i] = func<...>(x[i])
|
251 |
+
|
252 |
+
x: N tensors, each of type T;
|
253 |
+
y: N tensors, each of type U;
|
254 |
+
|
255 |
+
)doc");
|
256 |
+
|
257 |
+
TEST(TFunc, AddSquared) {
|
258 |
+
auto fdef = FDH::Create(
|
259 |
+
// Name
|
260 |
+
"AddSquared",
|
261 |
+
// Args
|
262 |
+
{"x: N*T"},
|
263 |
+
// Return values
|
264 |
+
{"y: T"},
|
265 |
+
// Attrs
|
266 |
+
{"N:int", "T:{float, double, int32, int64}"},
|
267 |
+
// Nodes
|
268 |
+
{// a = Map<func=Square<$T>,T=$T,U=$T,N=$N>(x)
|
269 |
+
{{"a"},
|
270 |
+
"Map",
|
271 |
+
{"x"},
|
272 |
+
{{"func", FDH::FunctionRef("Square", {{"T", "$T"}})},
|
273 |
+
{"T", "$T"},
|
274 |
+
{"U", "$T"},
|
275 |
+
{"N", "$N"}}},
|
276 |
+
// y = AddN<N=$N,T=$T>(a)
|
277 |
+
{{"y"}, "AddN", {"a:y"}, {{"N", "$N"}, {"T", "$T"}}}},
|
278 |
+
{{"y", "y:sum"}});
|
279 |
+
|
280 |
+
const char* e = R"P(
|
281 |
+
AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
|
282 |
+
a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x)
|
283 |
+
y = AddN[N=$N, T=$T](a:y)
|
284 |
+
return y = y:sum
|
285 |
+
}
|
286 |
+
)P";
|
287 |
+
EXPECT_EQ(DebugString(fdef), e);
|
288 |
+
|
289 |
+
// Instantiate one with T=float
|
290 |
+
InstantiationResult result;
|
291 |
+
TF_ASSERT_OK(InstantiateFunction(fdef, Attrs({{"N", 3}, {"T", DT_FLOAT}}),
|
292 |
+
GetOpSig, &result));
|
293 |
+
const char* e2 = R"P(
|
294 |
+
(x_0:float, x_1:float, x_2:float) -> (y:float) {
|
295 |
+
a = Map[N=3, T=float, U=float, func=Square[T=float]](x_0, x_1, x_2)
|
296 |
+
y = AddN[N=3, T=float](a, a:1, a:2)
|
297 |
+
}
|
298 |
+
)P";
|
299 |
+
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT}));
|
300 |
+
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
301 |
+
EXPECT_EQ(DebugString(result.nodes), e2);
|
302 |
+
}
|
303 |
+
|
304 |
+
TEST(TFunc, ControlDeps) {
|
305 |
+
auto fdef = FDH::Define(
|
306 |
+
// Name
|
307 |
+
"ControlDeps",
|
308 |
+
// Args
|
309 |
+
{"x: float"},
|
310 |
+
// Return values
|
311 |
+
{},
|
312 |
+
// Attrs
|
313 |
+
{},
|
314 |
+
// Nodes
|
315 |
+
{
|
316 |
+
{{"a"}, "One", {}, {{"T", DT_FLOAT}}, {"x"}},
|
317 |
+
{{"u"}, "NoOp", {}, {}, {"a"}},
|
318 |
+
{{"b"}, "One", {}, {{"T", DT_FLOAT}}, {"u"}},
|
319 |
+
{{"v"}, "NoOp", {}, {}, {"b"}},
|
320 |
+
{{"c"}, "One", {}, {{"T", DT_FLOAT}}, {"a", "v"}},
|
321 |
+
});
|
322 |
+
const char* e = R"P(
|
323 |
+
ControlDeps(x:float) -> () {
|
324 |
+
a = One[T=float]() @ x
|
325 |
+
u = NoOp() @ a
|
326 |
+
b = One[T=float]() @ u
|
327 |
+
v = NoOp() @ b
|
328 |
+
c = One[T=float]() @ a, v
|
329 |
+
}
|
330 |
+
)P";
|
331 |
+
EXPECT_EQ(DebugString(fdef), e);
|
332 |
+
|
333 |
+
InstantiationResult result;
|
334 |
+
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
|
335 |
+
const char* e2 = R"P(
|
336 |
+
(x:float) -> () {
|
337 |
+
a = One[T=float]() @ x
|
338 |
+
u = NoOp() @ a
|
339 |
+
b = One[T=float]() @ u
|
340 |
+
v = NoOp() @ b
|
341 |
+
c = One[T=float]() @ a, v
|
342 |
+
}
|
343 |
+
)P";
|
344 |
+
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
|
345 |
+
EXPECT_EQ(result.ret_types, DataTypeVector({}));
|
346 |
+
EXPECT_EQ(DebugString(result.nodes), e2);
|
347 |
+
}
|
348 |
+
|
349 |
+
TEST(TFunc, XTimesTwo) {
|
350 |
+
auto expect = R"P(
|
351 |
+
XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
352 |
+
two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
|
353 |
+
scale = Cast[DstT=$T, SrcT=int64](two:output:0)
|
354 |
+
y = Mul[T=$T](x, scale:y:0)
|
355 |
+
return y = y:z:0
|
356 |
+
}
|
357 |
+
)P";
|
358 |
+
EXPECT_EQ(expect, DebugString(test::function::XTimesTwo()));
|
359 |
+
}
|
360 |
+
|
361 |
+
TEST(TFunc, WXPlusB) {
|
362 |
+
auto expect = R"P(
|
363 |
+
WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
|
364 |
+
mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
|
365 |
+
y = Add[T=$T](mm:product:0, b)
|
366 |
+
return y = y:z:0
|
367 |
+
}
|
368 |
+
)P";
|
369 |
+
EXPECT_EQ(expect, DebugString(test::function::WXPlusB()));
|
370 |
+
}
|
371 |
+
|
372 |
+
TEST(TFunc, Body_TypeList) {
|
373 |
+
const Tensor kZero = test::AsScalar<int32>(0);
|
374 |
+
auto fdef = FDH::Create(
|
375 |
+
// Name
|
376 |
+
"Test",
|
377 |
+
// Args
|
378 |
+
{"i:float"},
|
379 |
+
// Return values
|
380 |
+
{"o:float"},
|
381 |
+
// Attrs
|
382 |
+
{},
|
383 |
+
// Nodes
|
384 |
+
{{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}},
|
385 |
+
{{"s"},
|
386 |
+
"Split",
|
387 |
+
{"zero:output:0", "i"},
|
388 |
+
{{"num_split", 4}, {"T", DT_FLOAT}}},
|
389 |
+
{{"l"}, "Mul", {"s:output:0", "s:output:1"}, {{"T", DT_FLOAT}}},
|
390 |
+
{{"r"}, "Mul", {"s:output:2", "s:output:3"}, {{"T", DT_FLOAT}}},
|
391 |
+
{{"x"},
|
392 |
+
"_ListToArray",
|
393 |
+
{"l:z", "r:z"},
|
394 |
+
{{"N", 2},
|
395 |
+
{"T", DT_FLOAT},
|
396 |
+
{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
|
397 |
+
{{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
|
398 |
+
{{"o", "o:sum:0"}});
|
399 |
+
|
400 |
+
const char* e = R"P(
|
401 |
+
Test(i:float) -> (o:float) {
|
402 |
+
zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
|
403 |
+
s = Split[T=float, num_split=4](zero:output:0, i)
|
404 |
+
l = Mul[T=float](s:output:0, s:output:1)
|
405 |
+
r = Mul[T=float](s:output:2, s:output:3)
|
406 |
+
x = _ListToArray[N=2, T=float, Tin={float, float}](l:z, r:z)
|
407 |
+
o = AddN[N=2, T=float](x:output)
|
408 |
+
return o = o:sum:0
|
409 |
+
}
|
410 |
+
)P";
|
411 |
+
EXPECT_EQ(DebugString(fdef), e);
|
412 |
+
|
413 |
+
InstantiationResult result;
|
414 |
+
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
|
415 |
+
const char* e2 = R"P(
|
416 |
+
(i:float) -> (o:float) {
|
417 |
+
zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
|
418 |
+
s = Split[T=float, num_split=4](zero, i)
|
419 |
+
l = Mul[T=float](s, s:1)
|
420 |
+
r = Mul[T=float](s:2, s:3)
|
421 |
+
x = _ListToArray[N=2, T=float, Tin={float, float}](l, r)
|
422 |
+
o = AddN[N=2, T=float](x, x:1)
|
423 |
+
}
|
424 |
+
)P";
|
425 |
+
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
|
426 |
+
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
427 |
+
EXPECT_EQ(DebugString(result.nodes), e2);
|
428 |
+
}
|
429 |
+
|
430 |
+
REGISTER_OP("Cond")
|
431 |
+
.Input("input: Tin")
|
432 |
+
.Output("output: out_types")
|
433 |
+
.Attr("Tin: list(type)")
|
434 |
+
.Attr("out_types: list(type)")
|
435 |
+
.Attr("cond: func")
|
436 |
+
.Attr("then_branch: func")
|
437 |
+
.Attr("else_branch: func")
|
438 |
+
.Doc(R"doc(
|
439 |
+
output = Cond(input) ? then_branch(input) : else_branch(input)
|
440 |
+
|
441 |
+
cond: A function takes 'input' and returns a scalar.
|
442 |
+
then_branch: A function takes 'input' and returns 'output'.
|
443 |
+
else_branch: A function takes 'input' and returns 'output'.
|
444 |
+
)doc");
|
445 |
+
|
446 |
+
TEST(TFunc, Body_Array_List_Converter) {
|
447 |
+
auto fdef = FDH::Define(
|
448 |
+
// Name
|
449 |
+
"MySelect",
|
450 |
+
// Args
|
451 |
+
{"x:float"},
|
452 |
+
// Return values
|
453 |
+
{"z:float"},
|
454 |
+
// Attrs
|
455 |
+
{},
|
456 |
+
// Nodes
|
457 |
+
{
|
458 |
+
{{"y"},
|
459 |
+
"Cond",
|
460 |
+
{"x"},
|
461 |
+
{{"Tin", DataTypeSlice{DT_FLOAT}},
|
462 |
+
{"out_types", DataTypeSlice{DT_FLOAT}},
|
463 |
+
{"cond", FDH::FunctionRef("MyCond")},
|
464 |
+
{"then_branch", FDH::FunctionRef("MyThen")},
|
465 |
+
{"else_branch", FDH::FunctionRef("MyElse")}}},
|
466 |
+
{{"z"},
|
467 |
+
"Cond",
|
468 |
+
{"y", "y"},
|
469 |
+
{{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
|
470 |
+
{"out_types", DataTypeSlice{DT_FLOAT}},
|
471 |
+
{"cond", FDH::FunctionRef("MyCond2")},
|
472 |
+
{"then_branch", FDH::FunctionRef("MyThen2")},
|
473 |
+
{"else_branch", FDH::FunctionRef("MyElse2")}}},
|
474 |
+
});
|
475 |
+
|
476 |
+
const char* e = R"P(
|
477 |
+
MySelect(x:float) -> (z:float) {
|
478 |
+
y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
|
479 |
+
z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y:output:0, y:output:0)
|
480 |
+
return z = z:output:0
|
481 |
+
}
|
482 |
+
)P";
|
483 |
+
EXPECT_EQ(DebugString(fdef), e);
|
484 |
+
|
485 |
+
InstantiationResult result;
|
486 |
+
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
|
487 |
+
const char* e2 = R"P(
|
488 |
+
(x:float) -> (z:float) {
|
489 |
+
y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
|
490 |
+
z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y)
|
491 |
+
}
|
492 |
+
)P";
|
493 |
+
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
|
494 |
+
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
495 |
+
EXPECT_EQ(DebugString(result.nodes), e2);
|
496 |
+
}
|
497 |
+
|
498 |
+
static void HasError(const Status& s, const string& substr) {
|
499 |
+
EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
|
500 |
+
<< ">>" << s << "<<, expected substring >>" << substr << "<<";
|
501 |
+
}
|
502 |
+
|
503 |
+
TEST(InstantiateErrors, Not_Sufficient_Attrs) {
|
504 |
+
auto fdef =
|
505 |
+
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
|
506 |
+
InstantiationResult result;
|
507 |
+
HasError(
|
508 |
+
InstantiateFunction(fdef, Attrs({{"U", DT_FLOAT}}), GetOpSig, &result),
|
509 |
+
"Attr T is not found from ");
|
510 |
+
}
|
511 |
+
|
512 |
+
#if 0 // TODO(josh11b): Enable this test once having an extra attr is an error.
|
513 |
+
TEST(InstantiateErrors, Too_Many_Attrs) {
|
514 |
+
auto fdef =
|
515 |
+
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
|
516 |
+
InstantiationResult result;
|
517 |
+
HasError(InstantiateFunction(fdef, Attrs({{"T", DT_INT32}, {"U", DT_FLOAT}}),
|
518 |
+
GetOpSig, &result),
|
519 |
+
"Attr U is not found in ");
|
520 |
+
}
|
521 |
+
#endif
|
522 |
+
|
523 |
+
TEST(InstantiateErrors, AttrValue_Value_Placeholder) {
|
524 |
+
auto fdef =
|
525 |
+
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
|
526 |
+
InstantiationResult result;
|
527 |
+
HasError(
|
528 |
+
InstantiateFunction(fdef, Attrs({{"T", "$bad"}}), GetOpSig, &result),
|
529 |
+
"AttrValue had value with unexpected type 'placeholder'\n\tfor attr 'T'");
|
530 |
+
}
|
531 |
+
|
532 |
+
TEST(InstantiateErrors, Unbounded_Attr) {
|
533 |
+
auto fdef = FDH::Define("test", {}, {}, {"T:{float, double, int32, int64}"},
|
534 |
+
{
|
535 |
+
{{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}},
|
536 |
+
});
|
537 |
+
InstantiationResult result;
|
538 |
+
HasError(
|
539 |
+
InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result),
|
540 |
+
"Failed to bind all placeholders");
|
541 |
+
}
|
542 |
+
|
543 |
+
TEST(InstantiateErrors, DupArgs) {
|
544 |
+
auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {});
|
545 |
+
InstantiationResult result;
|
546 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
547 |
+
"Duplicated arg name");
|
548 |
+
}
|
549 |
+
|
550 |
+
TEST(InstantiateErrors, Dup_Node_Names) {
|
551 |
+
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
|
552 |
+
{
|
553 |
+
{{"y"}, "One", {}, {{"T", DT_FLOAT}}},
|
554 |
+
{{"y"}, "One", {}, {{"T", DT_FLOAT}}},
|
555 |
+
});
|
556 |
+
InstantiationResult result;
|
557 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
558 |
+
"Duplicated ret name");
|
559 |
+
}
|
560 |
+
|
561 |
+
TEST(InstantiateErrors, Node_Arg_Notfound) {
|
562 |
+
auto fdef = FDH::Create("test", {"x:float"}, {}, {},
|
563 |
+
{
|
564 |
+
{{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}},
|
565 |
+
},
|
566 |
+
{});
|
567 |
+
InstantiationResult result;
|
568 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
569 |
+
"input z is not found");
|
570 |
+
}
|
571 |
+
|
572 |
+
TEST(InstantiateErrors, Node_Arg_TypeMismatch) {
|
573 |
+
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
|
574 |
+
{
|
575 |
+
{{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
|
576 |
+
});
|
577 |
+
InstantiationResult result;
|
578 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
579 |
+
"input x[0] expected type int32 != float, the type of x[0]");
|
580 |
+
}
|
581 |
+
|
582 |
+
TEST(InstantiateErrors, Node_Arg_ControlMissing) {
|
583 |
+
auto fdef =
|
584 |
+
FDH::Define("test", {"x:float"}, {}, {},
|
585 |
+
{
|
586 |
+
{{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}},
|
587 |
+
});
|
588 |
+
InstantiationResult result;
|
589 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
590 |
+
"input[2] == '^z', is not found.");
|
591 |
+
}
|
592 |
+
|
593 |
+
TEST(InstantiateErrors, FuncRet_Missing) {
|
594 |
+
auto fdef = FDH::Create("test", {}, {"y: float"}, {},
|
595 |
+
{
|
596 |
+
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
|
597 |
+
},
|
598 |
+
{});
|
599 |
+
InstantiationResult result;
|
600 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
601 |
+
"Return y missing");
|
602 |
+
}
|
603 |
+
|
604 |
+
TEST(InstantiateErrors, FuncRet_NotFound) {
|
605 |
+
auto fdef = FDH::Create("test", {}, {"y: float"}, {},
|
606 |
+
{
|
607 |
+
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
|
608 |
+
},
|
609 |
+
{{"y", "z"}});
|
610 |
+
InstantiationResult result;
|
611 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
612 |
+
"Return y -> z is not found");
|
613 |
+
}
|
614 |
+
|
615 |
+
TEST(InstantiateErrors, FuncRet_NameMismatch) {
|
616 |
+
auto fdef = FDH::Create("test", {}, {"y: float"}, {},
|
617 |
+
{
|
618 |
+
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
|
619 |
+
},
|
620 |
+
{{"z", "x:y:0"}});
|
621 |
+
InstantiationResult result;
|
622 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
623 |
+
"Return y missing");
|
624 |
+
}
|
625 |
+
|
626 |
+
// TODO(josh11b): Make this an error.
|
627 |
+
// TEST(InstantiateErrors, FuncRet_Extra) {
|
628 |
+
// auto fdef = FDH::Create("test", {}, {"y: float"}, {},
|
629 |
+
// {
|
630 |
+
// {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
|
631 |
+
// },
|
632 |
+
// {{"y", "x:y:0"}, {"z", "x:y:0"}});
|
633 |
+
// InstantiationResult result;
|
634 |
+
// HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
635 |
+
// "ret is not found");
|
636 |
+
// }
|
637 |
+
|
638 |
+
TEST(InstantiateErrors, FuncRet_TypeMismatch) {
|
639 |
+
auto fdef = FDH::Define("test", {}, {"y: float"}, {},
|
640 |
+
{
|
641 |
+
{{"y"}, "One", {}, {{"T", DT_DOUBLE}}},
|
642 |
+
});
|
643 |
+
InstantiationResult result;
|
644 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
645 |
+
"Invalid ret types y : float vs. double\n\tIn function output y");
|
646 |
+
}
|
647 |
+
|
648 |
+
TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
|
649 |
+
auto fdef = FDH::Create(
|
650 |
+
// Name
|
651 |
+
"MySelect",
|
652 |
+
// Args
|
653 |
+
{"x: float"},
|
654 |
+
// Return values
|
655 |
+
{"y: float"},
|
656 |
+
// Attrs
|
657 |
+
{},
|
658 |
+
// Nodes
|
659 |
+
{
|
660 |
+
{{"y"},
|
661 |
+
"Cond",
|
662 |
+
{"x", "x"},
|
663 |
+
{{"tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
|
664 |
+
{"cond", FDH::FunctionRef("MyCond2")},
|
665 |
+
{"then_branch", FDH::FunctionRef("MyThen2")},
|
666 |
+
{"else_branch", FDH::FunctionRef("MyElse2")}}},
|
667 |
+
},
|
668 |
+
{{"y", "y:output"}});
|
669 |
+
InstantiationResult result;
|
670 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
671 |
+
"type attr not found: out_types");
|
672 |
+
}
|
673 |
+
|
674 |
+
TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
|
675 |
+
auto fdef = FDH::Create(
|
676 |
+
// Name
|
677 |
+
"MySelect",
|
678 |
+
// Args
|
679 |
+
{"x: float"},
|
680 |
+
// Return values
|
681 |
+
{"y: float"},
|
682 |
+
// Attrs
|
683 |
+
{},
|
684 |
+
// Nodes
|
685 |
+
{
|
686 |
+
{{"y"},
|
687 |
+
"Cond",
|
688 |
+
{"x", "x"},
|
689 |
+
{{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
|
690 |
+
{"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
|
691 |
+
{"cond", FDH::FunctionRef("MyCond2")},
|
692 |
+
{"then_branch", FDH::FunctionRef("MyThen2")},
|
693 |
+
{"else_branch", FDH::FunctionRef("MyElse2")}}},
|
694 |
+
},
|
695 |
+
{{"y", "y:output"}});
|
696 |
+
InstantiationResult result;
|
697 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
698 |
+
"Invalid ret types");
|
699 |
+
}
|
700 |
+
|
701 |
+
TEST(InstantiateErrors, TypeList_Missing_Arg) {
|
702 |
+
auto fdef = FDH::Create(
|
703 |
+
// Name
|
704 |
+
"MySelect",
|
705 |
+
// Args
|
706 |
+
{"x: float"},
|
707 |
+
// Return values
|
708 |
+
{"y: float"},
|
709 |
+
// Attrs
|
710 |
+
{},
|
711 |
+
// Nodes
|
712 |
+
{
|
713 |
+
{{"y"},
|
714 |
+
"Cond",
|
715 |
+
{"x", "unknown"},
|
716 |
+
{{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
|
717 |
+
{"out_types", DataTypeSlice{DT_FLOAT}},
|
718 |
+
{"cond", FDH::FunctionRef("MyCond2")},
|
719 |
+
{"then_branch", FDH::FunctionRef("MyThen2")},
|
720 |
+
{"else_branch", FDH::FunctionRef("MyElse2")}}},
|
721 |
+
},
|
722 |
+
{{"y", "y:output"}});
|
723 |
+
InstantiationResult result;
|
724 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
725 |
+
"input unknown is not found");
|
726 |
+
}
|
727 |
+
|
728 |
+
TEST(InstantiateErrors, TooManyInputs) {
|
729 |
+
auto fdef = FDH::Create(
|
730 |
+
// Name
|
731 |
+
"TooManyInputs",
|
732 |
+
// Inputs
|
733 |
+
{"x: float", "y: float"},
|
734 |
+
// Outputs
|
735 |
+
{"z: float"},
|
736 |
+
// Attrs
|
737 |
+
{},
|
738 |
+
// Nodes
|
739 |
+
{// a = AddN<N=2>(x, y, x)
|
740 |
+
{{"a"}, "AddN", {"x", "y", "x"}, {{"T", DT_FLOAT}, {"N", 2}}}},
|
741 |
+
// Returns
|
742 |
+
{{"z", "a:sum:0"}});
|
743 |
+
|
744 |
+
InstantiationResult result;
|
745 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
746 |
+
"Expected input[2] == 'x' to be a control input.");
|
747 |
+
}
|
748 |
+
|
749 |
+
TEST(InstantiateErrors, TooFewInputs) {
|
750 |
+
auto fdef = FDH::Create(
|
751 |
+
// Name
|
752 |
+
"TooFewInputs",
|
753 |
+
// Inputs
|
754 |
+
{"x: float", "y: float"},
|
755 |
+
// Outputs
|
756 |
+
{"z: float"},
|
757 |
+
// Attrs
|
758 |
+
{},
|
759 |
+
// Nodes
|
760 |
+
{// a = AddN<N=3>(x, y)
|
761 |
+
{{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 3}}}},
|
762 |
+
// Returns
|
763 |
+
{{"z", "a:sum:0"}});
|
764 |
+
|
765 |
+
InstantiationResult result;
|
766 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
767 |
+
"Attempt to access beyond input size: 2 >= 2");
|
768 |
+
}
|
769 |
+
|
770 |
+
TEST(InstantiateErrors, TooManyInputsFromArray1) {
|
771 |
+
auto fdef = FDH::Create(
|
772 |
+
// Name
|
773 |
+
"TooManyInputsFromArray",
|
774 |
+
// Inputs
|
775 |
+
{"x: float", "y: float"},
|
776 |
+
// Outputs
|
777 |
+
{"z: float"},
|
778 |
+
// Attrs
|
779 |
+
{},
|
780 |
+
// Nodes
|
781 |
+
{// a = _ListToArray(x,y)
|
782 |
+
{{"a"},
|
783 |
+
"_ListToArray",
|
784 |
+
{"x", "y"},
|
785 |
+
{{"N", 2},
|
786 |
+
{"T", DT_FLOAT},
|
787 |
+
{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
|
788 |
+
// b = AddN<N=2>(a, y)
|
789 |
+
{{"b"}, "AddN", {"a:output", "y"}, {{"T", DT_FLOAT}, {"N", 2}}}},
|
790 |
+
// Returns
|
791 |
+
{{"z", "a:sum:0"}});
|
792 |
+
|
793 |
+
InstantiationResult result;
|
794 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
795 |
+
"Expected input[1] == 'y' to be a control input.");
|
796 |
+
}
|
797 |
+
|
798 |
+
TEST(InstantiateErrors, TooManyInputsFromArray2) {
|
799 |
+
auto fdef = FDH::Create(
|
800 |
+
// Name
|
801 |
+
"TooManyInputsFromArray",
|
802 |
+
// Inputs
|
803 |
+
{"x: float", "y: float"},
|
804 |
+
// Outputs
|
805 |
+
{"z: float"},
|
806 |
+
// Attrs
|
807 |
+
{},
|
808 |
+
// Nodes
|
809 |
+
{// a = _ListToArray(x,y)
|
810 |
+
{{"a"},
|
811 |
+
"_ListToArray",
|
812 |
+
{"x", "y"},
|
813 |
+
{{"N", 2},
|
814 |
+
{"T", DT_FLOAT},
|
815 |
+
{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
|
816 |
+
// b = AddN<N=2>(x, a)
|
817 |
+
{{"b"}, "AddN", {"x", "a:output"}, {{"T", DT_FLOAT}, {"N", 2}}}},
|
818 |
+
// Returns
|
819 |
+
{{"z", "a:sum:0"}});
|
820 |
+
|
821 |
+
InstantiationResult result;
|
822 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
823 |
+
"Input a:output too long for inputs");
|
824 |
+
}
|
825 |
+
|
826 |
+
TEST(InstantiateErrors, TypeMismatch) {
|
827 |
+
auto fdef = FDH::Create(
|
828 |
+
// Name
|
829 |
+
"TypeMismatch",
|
830 |
+
// Inputs
|
831 |
+
{"x: float", "y: int32"},
|
832 |
+
// Outputs
|
833 |
+
{"z: float"},
|
834 |
+
// Attrs
|
835 |
+
{},
|
836 |
+
// Nodes
|
837 |
+
{// a = AddN<N=2>(x, y)
|
838 |
+
{{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 3}}}},
|
839 |
+
// Returns
|
840 |
+
{{"z", "a:sum:0"}});
|
841 |
+
|
842 |
+
InstantiationResult result;
|
843 |
+
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
844 |
+
"input inputs[1] expected type float != int32, the type of y[0]");
|
845 |
+
}
|
846 |
+
|
847 |
+
TEST(FunctionCallFrame, Void_Void) {
|
848 |
+
FunctionCallFrame frame({}, {});
|
849 |
+
TF_EXPECT_OK(frame.SetArgs({}));
|
850 |
+
auto a = test::AsTensor<float>({100});
|
851 |
+
HasError(frame.SetArgs({a}), "Invalid argument");
|
852 |
+
Tensor v;
|
853 |
+
HasError(frame.GetArg(0, &v), "Invalid argument");
|
854 |
+
HasError(frame.SetRetval(0, v), "Invalid argument");
|
855 |
+
std::vector<Tensor> rets;
|
856 |
+
TF_EXPECT_OK(frame.GetRetvals(&rets));
|
857 |
+
EXPECT_EQ(rets.size(), 0);
|
858 |
+
}
|
859 |
+
|
860 |
+
TEST(FunctionCallFrame, Float_Float_Float) {
|
861 |
+
FunctionCallFrame frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
|
862 |
+
HasError(frame.SetArgs({}), "Invalid argument: Expects 2 arguments");
|
863 |
+
auto a = test::AsTensor<float>({100});
|
864 |
+
auto b = test::AsTensor<float>({200});
|
865 |
+
auto c = test::AsTensor<int64>({300});
|
866 |
+
HasError(frame.SetArgs({a, c}),
|
867 |
+
"Invalid argument: Expects arg[1] to be float");
|
868 |
+
TF_EXPECT_OK(frame.SetArgs({a, b}));
|
869 |
+
|
870 |
+
Tensor v;
|
871 |
+
HasError(frame.GetArg(-1, &v), "Invalid argument");
|
872 |
+
HasError(frame.GetArg(2, &v), "Invalid argument");
|
873 |
+
TF_EXPECT_OK(frame.GetArg(0, &v));
|
874 |
+
test::ExpectTensorEqual<float>(a, v);
|
875 |
+
TF_EXPECT_OK(frame.GetArg(1, &v));
|
876 |
+
test::ExpectTensorEqual<float>(b, v);
|
877 |
+
|
878 |
+
v = test::AsTensor<float>({-100});
|
879 |
+
HasError(frame.SetRetval(-1, v), "Invalid argument");
|
880 |
+
HasError(frame.SetRetval(1, v), "Invalid argument");
|
881 |
+
HasError(frame.SetRetval(0, test::AsTensor<int64>({-100})),
|
882 |
+
"Invalid argument: Expects ret[0] to be float");
|
883 |
+
|
884 |
+
std::vector<Tensor> rets;
|
885 |
+
HasError(frame.GetRetvals(&rets), "does not have value");
|
886 |
+
TF_EXPECT_OK(frame.SetRetval(0, v));
|
887 |
+
HasError(frame.SetRetval(0, v), "has already been set");
|
888 |
+
|
889 |
+
TF_EXPECT_OK(frame.GetRetvals(&rets));
|
890 |
+
EXPECT_EQ(rets.size(), 1);
|
891 |
+
test::ExpectTensorEqual<float>(rets[0], v);
|
892 |
+
}
|
893 |
+
|
894 |
+
TEST(Canonicalize, Basic) {
|
895 |
+
EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT},
|
896 |
+
{"transpose_a", false},
|
897 |
+
{"transpose_b", false}})),
|
898 |
+
"MatMul[T=float,transpose_a=false,transpose_b=false]");
|
899 |
+
EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT},
|
900 |
+
{"transpose_b", false},
|
901 |
+
{"transpose_a", false}})),
|
902 |
+
"MatMul[T=float,transpose_a=false,transpose_b=false]");
|
903 |
+
EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_DOUBLE},
|
904 |
+
{"transpose_b", true},
|
905 |
+
{"transpose_a", false}})),
|
906 |
+
"MatMul[T=double,transpose_a=false,transpose_b=true]");
|
907 |
+
}
|
908 |
+
|
909 |
+
TEST(FunctionLibraryDefinitionTest, Find) {
|
910 |
+
FunctionDefLibrary proto;
|
911 |
+
*proto.add_function() = test::function::XTimesTwo();
|
912 |
+
FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
|
913 |
+
|
914 |
+
EXPECT_EQ(lib_def.Find("XTimes16"), nullptr);
|
915 |
+
|
916 |
+
auto expect = R"P(
|
917 |
+
XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
918 |
+
two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
|
919 |
+
scale = Cast[DstT=$T, SrcT=int64](two:output:0)
|
920 |
+
y = Mul[T=$T](x, scale:y:0)
|
921 |
+
return y = y:z:0
|
922 |
+
}
|
923 |
+
)P";
|
924 |
+
auto found = lib_def.Find("XTimesTwo");
|
925 |
+
ASSERT_NE(found, nullptr);
|
926 |
+
EXPECT_EQ(expect, DebugString(*found));
|
927 |
+
}
|
928 |
+
|
929 |
+
TEST(FunctionLibraryDefinitionTest, LookUp) {
|
930 |
+
FunctionDefLibrary proto;
|
931 |
+
*proto.add_function() = test::function::XTimesTwo();
|
932 |
+
FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
|
933 |
+
|
934 |
+
const OpDef* op_def;
|
935 |
+
EXPECT_TRUE(!lib_def.LookUpOpDef("XTimes16", &op_def).ok());
|
936 |
+
|
937 |
+
TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &op_def));
|
938 |
+
ASSERT_NE(op_def, nullptr);
|
939 |
+
EXPECT_EQ(op_def->DebugString(),
|
940 |
+
test::function::XTimesTwo().signature().DebugString());
|
941 |
+
|
942 |
+
const OpRegistrationData* op_reg_data;
|
943 |
+
TF_EXPECT_OK(lib_def.LookUp("XTimesTwo", &op_reg_data));
|
944 |
+
ASSERT_NE(op_reg_data, nullptr);
|
945 |
+
// Shape inference function is initialized to UnknownShape.
|
946 |
+
ASSERT_NE(op_reg_data->shape_inference_fn, nullptr);
|
947 |
+
}
|
948 |
+
|
949 |
+
TEST(FunctionLibraryDefinitionTest, AddFunctionDef) {
|
950 |
+
// Add one function to the proto lib before constructing 'lib_def'.
|
951 |
+
FunctionDefLibrary proto;
|
952 |
+
*proto.add_function() = test::function::XTimesTwo();
|
953 |
+
FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
|
954 |
+
|
955 |
+
// Add a new function def to the library.
|
956 |
+
TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB()));
|
957 |
+
|
958 |
+
// Test lookup of first function.
|
959 |
+
const OpDef* first;
|
960 |
+
TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &first));
|
961 |
+
ASSERT_NE(first, nullptr);
|
962 |
+
EXPECT_EQ(first->DebugString(),
|
963 |
+
test::function::XTimesTwo().signature().DebugString());
|
964 |
+
|
965 |
+
// Test lookup of second function.
|
966 |
+
const OpDef* second;
|
967 |
+
TF_EXPECT_OK(lib_def.LookUpOpDef("WXPlusB", &second));
|
968 |
+
ASSERT_NE(second, nullptr);
|
969 |
+
EXPECT_EQ(second->DebugString(),
|
970 |
+
test::function::WXPlusB().signature().DebugString());
|
971 |
+
|
972 |
+
// Can't add function with same name as existing op
|
973 |
+
FunctionDef fdef = test::function::XTimesTwo();
|
974 |
+
fdef.mutable_signature()->set_name("Add");
|
975 |
+
Status s = lib_def.AddFunctionDef(fdef);
|
976 |
+
EXPECT_FALSE(s.ok());
|
977 |
+
EXPECT_EQ(s.error_message(),
|
978 |
+
"Cannot add function 'Add' because an op with the same name "
|
979 |
+
"already exists.");
|
980 |
+
|
981 |
+
// Already-added functions don't produce error
|
982 |
+
TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::XTimesTwo()));
|
983 |
+
TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB()));
|
984 |
+
}
|
985 |
+
|
986 |
+
TEST(FunctionLibraryDefinitionTest, AddGradientDef) {
|
987 |
+
// AddGradientDef() doesn't check that functions referenced exist (yet?)
|
988 |
+
FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary());
|
989 |
+
|
990 |
+
// Test adding a gradient (XTimesFour isn't a valid grad function for
|
991 |
+
// XTimesTwo but that's ok for now)
|
992 |
+
GradientDef grad;
|
993 |
+
grad.set_function_name(test::function::XTimesTwo().signature().name());
|
994 |
+
grad.set_gradient_func(test::function::XTimesFour().signature().name());
|
995 |
+
TF_EXPECT_OK(lib_def.AddGradientDef(grad));
|
996 |
+
|
997 |
+
// Already-added gradients don't produce error
|
998 |
+
TF_EXPECT_OK(lib_def.AddGradientDef(grad));
|
999 |
+
|
1000 |
+
// Test that adding a duplicate gradient fails
|
1001 |
+
grad.set_gradient_func(test::function::XTimes16().signature().name());
|
1002 |
+
Status s = lib_def.AddGradientDef(grad);
|
1003 |
+
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
|
1004 |
+
EXPECT_EQ(s.error_message(),
|
1005 |
+
"Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because "
|
1006 |
+
"it already has gradient function 'XTimesFour'");
|
1007 |
+
}
|
1008 |
+
|
1009 |
+
TEST(FunctionLibraryDefinitionTest, AddLibrary) {
|
1010 |
+
// Create lib def with single function
|
1011 |
+
FunctionDefLibrary proto;
|
1012 |
+
*proto.add_function() = test::function::XTimesTwo();
|
1013 |
+
FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
|
1014 |
+
|
1015 |
+
// Add gradient
|
1016 |
+
GradientDef grad;
|
1017 |
+
grad.set_function_name(test::function::XTimesTwo().signature().name());
|
1018 |
+
grad.set_gradient_func(test::function::XTimesFour().signature().name());
|
1019 |
+
TF_EXPECT_OK(lib_def.AddGradientDef(grad));
|
1020 |
+
|
1021 |
+
// Error if you try to add conflicting function
|
1022 |
+
proto.Clear();
|
1023 |
+
FunctionDef fdef = test::function::XTimesFour();
|
1024 |
+
fdef.mutable_signature()->set_name(
|
1025 |
+
test::function::XTimesTwo().signature().name());
|
1026 |
+
*proto.add_function() = fdef;
|
1027 |
+
FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto);
|
1028 |
+
Status s = lib_def.AddLibrary(lib_def2);
|
1029 |
+
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
|
1030 |
+
EXPECT_EQ(s.error_message(),
|
1031 |
+
"Cannot add function 'XTimesTwo' because a different function with "
|
1032 |
+
"the same name already exists.");
|
1033 |
+
|
1034 |
+
// Error if you try to add conflicting gradient
|
1035 |
+
proto.Clear();
|
1036 |
+
grad.set_gradient_func(test::function::XTimes16().signature().name());
|
1037 |
+
*proto.add_gradient() = grad;
|
1038 |
+
FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto);
|
1039 |
+
s = lib_def.AddLibrary(lib_def3);
|
1040 |
+
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
|
1041 |
+
EXPECT_EQ(s.error_message(),
|
1042 |
+
"Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because "
|
1043 |
+
"it already has gradient function 'XTimesFour'");
|
1044 |
+
|
1045 |
+
// No conflicting functions or gradients OK
|
1046 |
+
proto.Clear();
|
1047 |
+
*proto.add_function() = test::function::XTimesFour();
|
1048 |
+
grad.set_function_name(test::function::XTimes16().signature().name());
|
1049 |
+
*proto.add_gradient() = grad;
|
1050 |
+
FunctionLibraryDefinition lib_def4(OpRegistry::Global(), proto);
|
1051 |
+
TF_EXPECT_OK(lib_def.AddLibrary(lib_def4));
|
1052 |
+
|
1053 |
+
// OK to add the same functions and gradients twice
|
1054 |
+
TF_EXPECT_OK(lib_def.AddLibrary(lib_def));
|
1055 |
+
}
|
1056 |
+
|
1057 |
+
GradientDef MakeGradDef(const string& f, const string& g) {
|
1058 |
+
GradientDef grad;
|
1059 |
+
grad.set_function_name(f);
|
1060 |
+
grad.set_gradient_func(g);
|
1061 |
+
return grad;
|
1062 |
+
}
|
1063 |
+
|
1064 |
+
TEST(FunctionLibraryDefinitionTest, AddLibrary_Atomic) {
|
1065 |
+
// Create lib def containing two functions with equal names
|
1066 |
+
FunctionDefLibrary proto;
|
1067 |
+
const string x2_name = test::function::XTimesTwo().signature().name();
|
1068 |
+
const string x4_name = test::function::XTimesFour().signature().name();
|
1069 |
+
*proto.add_function() = test::function::XTimesTwo();
|
1070 |
+
FunctionDef fdef = test::function::XTimesFour();
|
1071 |
+
fdef.mutable_signature()->set_name(x2_name);
|
1072 |
+
*proto.add_function() = fdef;
|
1073 |
+
FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary());
|
1074 |
+
|
1075 |
+
// Try adding the two functions to lib_def
|
1076 |
+
Status s = lib_def.AddLibrary(proto);
|
1077 |
+
EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code());
|
1078 |
+
EXPECT_EQ(
|
1079 |
+
"Cannot add function 'XTimesTwo' because a different function with "
|
1080 |
+
"the same name already exists.",
|
1081 |
+
s.error_message());
|
1082 |
+
|
1083 |
+
// Verify that none of the functions are added
|
1084 |
+
EXPECT_TRUE(lib_def.Find(x2_name) == nullptr);
|
1085 |
+
|
1086 |
+
// Fix the name in proto but add two gradient names for it
|
1087 |
+
proto.mutable_function(1)->mutable_signature()->set_name(x4_name);
|
1088 |
+
*proto.add_gradient() = MakeGradDef(x2_name, x4_name);
|
1089 |
+
*proto.add_gradient() = MakeGradDef(x2_name, "SecondGradName");
|
1090 |
+
|
1091 |
+
// Try adding the library and check that nothing was added
|
1092 |
+
s = lib_def.AddLibrary(proto);
|
1093 |
+
EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code());
|
1094 |
+
EXPECT_EQ(s.error_message(),
|
1095 |
+
"Cannot assign gradient function 'SecondGradName' to 'XTimesTwo' "
|
1096 |
+
"because it already has gradient function 'XTimesFour'");
|
1097 |
+
EXPECT_TRUE(lib_def.Find(x2_name) == nullptr);
|
1098 |
+
EXPECT_EQ(0, lib_def.ToProto().function_size());
|
1099 |
+
EXPECT_EQ(0, lib_def.ToProto().gradient_size());
|
1100 |
+
}
|
1101 |
+
|
1102 |
+
TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_FuncConflict) {
|
1103 |
+
const string x2_name = test::function::XTimesTwo().signature().name();
|
1104 |
+
const string x4_name = test::function::XTimesFour().signature().name();
|
1105 |
+
const string wx_name = test::function::WXPlusB().signature().name();
|
1106 |
+
|
1107 |
+
// Create FunctionLibraryDefinition with
|
1108 |
+
// (func = XTimesTwo, grad = XTimesFour)
|
1109 |
+
FunctionDefLibrary proto;
|
1110 |
+
*proto.add_function() = test::function::XTimesTwo();
|
1111 |
+
*proto.add_gradient() = MakeGradDef(x2_name, x4_name);
|
1112 |
+
FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
|
1113 |
+
EXPECT_EQ(1, lib_def.ToProto().function_size());
|
1114 |
+
EXPECT_EQ(1, lib_def.ToProto().gradient_size());
|
1115 |
+
|
1116 |
+
// Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo)
|
1117 |
+
// and function (name = XTimesTwo, body = XTimeFour)
|
1118 |
+
FunctionDefLibrary proto2;
|
1119 |
+
*proto2.add_function() = test::function::WXPlusB();
|
1120 |
+
*proto2.add_gradient() = MakeGradDef(wx_name, x2_name);
|
1121 |
+
*proto2.add_function() = test::function::XTimesFour();
|
1122 |
+
proto2.mutable_function(1)->mutable_signature()->set_name(x2_name);
|
1123 |
+
FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2);
|
1124 |
+
|
1125 |
+
// Verify that adding lib_def2 will fail because of function conflict
|
1126 |
+
// and WXPlusB is not added.
|
1127 |
+
Status s = lib_def.AddLibrary(lib_def2);
|
1128 |
+
EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code());
|
1129 |
+
EXPECT_EQ(
|
1130 |
+
"Cannot add function 'XTimesTwo' because a different function "
|
1131 |
+
"with the same name already exists.",
|
1132 |
+
s.error_message());
|
1133 |
+
EXPECT_TRUE(lib_def.Find(wx_name) == nullptr);
|
1134 |
+
EXPECT_EQ(1, lib_def.ToProto().function_size());
|
1135 |
+
EXPECT_EQ(1, lib_def.ToProto().gradient_size());
|
1136 |
+
}
|
1137 |
+
|
1138 |
+
TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_GradConflict) {
|
1139 |
+
const string x2_name = test::function::XTimesTwo().signature().name();
|
1140 |
+
const string x4_name = test::function::XTimesFour().signature().name();
|
1141 |
+
const string wx_name = test::function::WXPlusB().signature().name();
|
1142 |
+
|
1143 |
+
// Create FunctionLibraryDefinition with
|
1144 |
+
// (func = XTimesTwo, grad = XTimesFour)
|
1145 |
+
FunctionDefLibrary proto;
|
1146 |
+
*proto.add_function() = test::function::XTimesTwo();
|
1147 |
+
*proto.add_gradient() = MakeGradDef(x2_name, x4_name);
|
1148 |
+
FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
|
1149 |
+
EXPECT_EQ(1, lib_def.ToProto().function_size());
|
1150 |
+
EXPECT_EQ(1, lib_def.ToProto().gradient_size());
|
1151 |
+
|
1152 |
+
// Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo)
|
1153 |
+
// and (func = XTimesTwo, grad = WXPlusB)
|
1154 |
+
FunctionDefLibrary proto2;
|
1155 |
+
*proto2.add_function() = test::function::WXPlusB();
|
1156 |
+
*proto2.add_gradient() = MakeGradDef(wx_name, x2_name);
|
1157 |
+
*proto2.add_function() = test::function::XTimesTwo();
|
1158 |
+
*proto2.add_gradient() = MakeGradDef(x2_name, wx_name);
|
1159 |
+
FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2);
|
1160 |
+
|
1161 |
+
// Verify that adding lib_def2 will fail because of gradient conflict
|
1162 |
+
// and WXPlusB is not added.
|
1163 |
+
Status s = lib_def.AddLibrary(lib_def2);
|
1164 |
+
EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code());
|
1165 |
+
EXPECT_EQ(
|
1166 |
+
"Cannot assign gradient function 'WXPlusB' to 'XTimesTwo'"
|
1167 |
+
" because it already has gradient function 'XTimesFour'",
|
1168 |
+
s.error_message());
|
1169 |
+
EXPECT_TRUE(lib_def.Find(wx_name) == nullptr);
|
1170 |
+
EXPECT_EQ(1, lib_def.ToProto().function_size());
|
1171 |
+
EXPECT_EQ(1, lib_def.ToProto().gradient_size());
|
1172 |
+
}
|
1173 |
+
|
1174 |
+
TEST(FunctionLibraryDefinitionTest, ToProto) {
|
1175 |
+
FunctionDefLibrary proto1;
|
1176 |
+
*proto1.add_function() = test::function::XTimesTwo();
|
1177 |
+
*proto1.add_function() = test::function::WXPlusB();
|
1178 |
+
FunctionLibraryDefinition lib_def1(OpRegistry::Global(), proto1);
|
1179 |
+
|
1180 |
+
// Call 'ToProto' and make sure both protos have the same function lib size.
|
1181 |
+
FunctionDefLibrary proto2 = lib_def1.ToProto();
|
1182 |
+
EXPECT_EQ(proto1.function_size(), proto2.function_size());
|
1183 |
+
|
1184 |
+
// Initialize 'lib_def2' with proto returned by 'ToProto' call.
|
1185 |
+
FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2);
|
1186 |
+
|
1187 |
+
// Test that the first function exists in both libraries.
|
1188 |
+
const OpDef *f1, *f2, *f3, *f4;
|
1189 |
+
TF_EXPECT_OK(lib_def1.LookUpOpDef("XTimesTwo", &f1));
|
1190 |
+
TF_EXPECT_OK(lib_def2.LookUpOpDef("XTimesTwo", &f2));
|
1191 |
+
EXPECT_EQ(f1->DebugString(), f2->DebugString());
|
1192 |
+
|
1193 |
+
// Test that the second function exists in both libraries.
|
1194 |
+
TF_EXPECT_OK(lib_def1.LookUpOpDef("WXPlusB", &f3));
|
1195 |
+
TF_EXPECT_OK(lib_def2.LookUpOpDef("WXPlusB", &f4));
|
1196 |
+
EXPECT_EQ(f3->DebugString(), f4->DebugString());
|
1197 |
+
}
|
1198 |
+
|
1199 |
+
TEST(FunctionLibraryDefinitionTest, GetAttr_FuncNoAttr) {
|
1200 |
+
FunctionDefLibrary proto;
|
1201 |
+
*proto.add_function() = test::function::XTimesTwo();
|
1202 |
+
FunctionLibraryDefinition lib(OpRegistry::Global(), proto);
|
1203 |
+
|
1204 |
+
NodeDef ndef;
|
1205 |
+
bool annotation;
|
1206 |
+
|
1207 |
+
// Not a function.
|
1208 |
+
ndef.set_op("Matmul");
|
1209 |
+
EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
|
1210 |
+
|
1211 |
+
// A function. No attr defined.
|
1212 |
+
ndef.set_op("XTimesTwo");
|
1213 |
+
EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
|
1214 |
+
|
1215 |
+
// ndef defines the attr. But we don't care.
|
1216 |
+
AddNodeAttr("annotation", true, &ndef);
|
1217 |
+
EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
|
1218 |
+
}
|
1219 |
+
|
1220 |
+
template <typename T>
|
1221 |
+
void SetAttrValue(FunctionDef* fdef, const string& attr, const T& value) {
|
1222 |
+
AttrValue attr_value;
|
1223 |
+
SetAttrValue(value, &attr_value);
|
1224 |
+
fdef->mutable_attr()->insert({attr, attr_value});
|
1225 |
+
}
|
1226 |
+
|
1227 |
+
TEST(FunctionLibraryDefinitionTest, GetAttr_FuncWithAttr) {
|
1228 |
+
FunctionDefLibrary proto;
|
1229 |
+
auto fdef = proto.add_function();
|
1230 |
+
*fdef = test::function::XTimesTwo();
|
1231 |
+
SetAttrValue(fdef, "annotation", true);
|
1232 |
+
SetAttrValue(fdef, "options", "some string data");
|
1233 |
+
FunctionLibraryDefinition lib(OpRegistry::Global(), proto);
|
1234 |
+
|
1235 |
+
NodeDef ndef;
|
1236 |
+
bool annotation;
|
1237 |
+
|
1238 |
+
// A function. No attr defined in ndef.
|
1239 |
+
ndef.set_op("XTimesTwo");
|
1240 |
+
TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation));
|
1241 |
+
EXPECT_EQ(annotation, true);
|
1242 |
+
|
1243 |
+
string str;
|
1244 |
+
TF_EXPECT_OK(lib.GetAttr(ndef, "options", &str));
|
1245 |
+
EXPECT_EQ(str, "some string data");
|
1246 |
+
}
|
1247 |
+
|
1248 |
+
TEST(FunctionLibraryDefinitionTest, GetAttr_Gradient) {
|
1249 |
+
FunctionDefLibrary proto;
|
1250 |
+
auto fdef = proto.add_function();
|
1251 |
+
*fdef = test::function::XTimesTwo();
|
1252 |
+
SetAttrValue(fdef, "annotation", true);
|
1253 |
+
*fdef = test::function::WXPlusB();
|
1254 |
+
SetAttrValue(fdef, "annotation", false);
|
1255 |
+
auto func_grad = proto.add_gradient();
|
1256 |
+
func_grad->set_function_name("XTimesTwo");
|
1257 |
+
func_grad->set_gradient_func("WXPlusB");
|
1258 |
+
FunctionLibraryDefinition lib(OpRegistry::Global(), proto);
|
1259 |
+
|
1260 |
+
NodeDef ndef;
|
1261 |
+
ndef.set_op(FunctionLibraryDefinition::kGradientOp);
|
1262 |
+
|
1263 |
+
bool annotation;
|
1264 |
+
EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
|
1265 |
+
|
1266 |
+
NameAttrList nal;
|
1267 |
+
nal.set_name("XTimesTwo");
|
1268 |
+
AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef);
|
1269 |
+
TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation));
|
1270 |
+
EXPECT_EQ(annotation, false); // XTimesTwo's gradient is WXPlusB.
|
1271 |
+
|
1272 |
+
nal.set_name("WXPlusB");
|
1273 |
+
ndef.clear_attr();
|
1274 |
+
AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef);
|
1275 |
+
TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation));
|
1276 |
+
EXPECT_EQ(annotation, false); // WXPlusB has no custom gradient.
|
1277 |
+
}
|
1278 |
+
|
1279 |
+
// TODO(skyewm): this could be more thorough
|
1280 |
+
TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) {
|
1281 |
+
// Equal functions
|
1282 |
+
const FunctionDef fdef1 = test::function::XTimesTwo();
|
1283 |
+
FunctionDef fdef2 = test::function::XTimesTwo();
|
1284 |
+
uint64 hash1 = FunctionDefHash(fdef1);
|
1285 |
+
EXPECT_TRUE(FunctionDefsEqual(fdef1, fdef2));
|
1286 |
+
EXPECT_EQ(hash1, FunctionDefHash(fdef2));
|
1287 |
+
|
1288 |
+
// Different functions
|
1289 |
+
fdef2 = test::function::XTimesFour();
|
1290 |
+
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
|
1291 |
+
EXPECT_NE(hash1, FunctionDefHash(fdef2));
|
1292 |
+
|
1293 |
+
// Different signatures
|
1294 |
+
fdef2 = test::function::XTimesTwo();
|
1295 |
+
fdef2.mutable_signature()->mutable_input_arg(0)->set_name("foo");
|
1296 |
+
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
|
1297 |
+
EXPECT_NE(hash1, FunctionDefHash(fdef2));
|
1298 |
+
|
1299 |
+
// Descriptions must be equal
|
1300 |
+
fdef2 = test::function::XTimesTwo();
|
1301 |
+
fdef2.mutable_signature()->mutable_input_arg(0)->set_description("foo");
|
1302 |
+
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
|
1303 |
+
EXPECT_NE(hash1, FunctionDefHash(fdef2));
|
1304 |
+
|
1305 |
+
// Different NodeDefs
|
1306 |
+
fdef2 = test::function::XTimesTwo();
|
1307 |
+
NodeDef* ndef = fdef2.add_node_def();
|
1308 |
+
*ndef = fdef2.node_def(0);
|
1309 |
+
ndef->set_name("new_name");
|
1310 |
+
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
|
1311 |
+
EXPECT_NE(hash1, FunctionDefHash(fdef2));
|
1312 |
+
|
1313 |
+
// Different return values
|
1314 |
+
fdef2 = test::function::XTimesTwo();
|
1315 |
+
(*fdef2.mutable_ret())["y"] = "y:z:1"; // originally is "y:z:0"
|
1316 |
+
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
|
1317 |
+
EXPECT_NE(hash1, FunctionDefHash(fdef2));
|
1318 |
+
|
1319 |
+
// Different attributes
|
1320 |
+
fdef2 = test::function::XTimesTwo();
|
1321 |
+
SetAttrValue(&fdef2, "ExtraAttr", true);
|
1322 |
+
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
|
1323 |
+
EXPECT_NE(hash1, FunctionDefHash(fdef2));
|
1324 |
+
|
1325 |
+
// Multiple equivalent attributes; the two functions should be equal.
|
1326 |
+
fdef2 = test::function::XTimesTwo();
|
1327 |
+
FunctionDef fdef3 = test::function::XTimesTwo();
|
1328 |
+
SetAttrValue(&fdef2, "Foo", true);
|
1329 |
+
SetAttrValue(&fdef3, "Foo", true);
|
1330 |
+
SetAttrValue(&fdef2, "Bar", 123);
|
1331 |
+
SetAttrValue(&fdef3, "Bar", 123);
|
1332 |
+
SetAttrValue(&fdef2, "Baz", "abc");
|
1333 |
+
SetAttrValue(&fdef3, "Baz", "abc");
|
1334 |
+
EXPECT_TRUE(FunctionDefsEqual(fdef2, fdef3));
|
1335 |
+
EXPECT_EQ(FunctionDefHash(fdef2), FunctionDefHash(fdef3));
|
1336 |
+
}
|
1337 |
+
|
1338 |
+
} // end namespace
|
1339 |
+
} // end namespace tensorflow
|
function_testlib.cc
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/function_testlib.h"
|
17 |
+
|
18 |
+
#include "tensorflow/core/framework/function.h"
|
19 |
+
#include "tensorflow/core/framework/node_def.pb.h"
|
20 |
+
#include "tensorflow/core/framework/tensor_testutil.h"
|
21 |
+
#include "tensorflow/core/framework/versions.pb.h"
|
22 |
+
#include "tensorflow/core/lib/core/threadpool.h"
|
23 |
+
#include "tensorflow/core/public/version.h"
|
24 |
+
|
25 |
+
namespace tensorflow {
|
26 |
+
namespace test {
|
27 |
+
namespace function {
|
28 |
+
|
29 |
+
typedef FunctionDefHelper FDH;
|
30 |
+
|
31 |
+
GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
|
32 |
+
gtl::ArraySlice<FunctionDef> funcs) {
|
33 |
+
GraphDef g;
|
34 |
+
VersionDef* versions = g.mutable_versions();
|
35 |
+
versions->set_producer(TF_GRAPH_DEF_VERSION);
|
36 |
+
versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
|
37 |
+
for (const auto& n : nodes) {
|
38 |
+
*(g.add_node()) = n;
|
39 |
+
}
|
40 |
+
auto lib = g.mutable_library();
|
41 |
+
for (const auto& f : funcs) {
|
42 |
+
*(lib->add_function()) = f;
|
43 |
+
}
|
44 |
+
return g;
|
45 |
+
}
|
46 |
+
|
47 |
+
// Helper to construct a NodeDef.
|
48 |
+
NodeDef NDef(const string& name, const string& op,
|
49 |
+
gtl::ArraySlice<string> inputs,
|
50 |
+
gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs,
|
51 |
+
const string& device) {
|
52 |
+
NodeDef n;
|
53 |
+
n.set_name(name);
|
54 |
+
n.set_op(op);
|
55 |
+
for (const auto& in : inputs) n.add_input(in);
|
56 |
+
n.set_device(device);
|
57 |
+
for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
|
58 |
+
return n;
|
59 |
+
}
|
60 |
+
|
61 |
+
FunctionDef NonZero() {
|
62 |
+
return FDH::Define(
|
63 |
+
// Name
|
64 |
+
"NonZero",
|
65 |
+
// Args
|
66 |
+
{"x:T"},
|
67 |
+
// Return values
|
68 |
+
{"y:T"},
|
69 |
+
// Attr def
|
70 |
+
{"T:{float, double, int32, int64, string}"},
|
71 |
+
// Nodes
|
72 |
+
{
|
73 |
+
{{"y"}, "Identity", {"x"}, {{"T", "$T"}}},
|
74 |
+
});
|
75 |
+
}
|
76 |
+
|
77 |
+
FunctionDef XTimesTwo() {
|
78 |
+
const Tensor kTwo = test::AsScalar<int64>(2);
|
79 |
+
return FDH::Define(
|
80 |
+
// Name
|
81 |
+
"XTimesTwo",
|
82 |
+
// Args
|
83 |
+
{"x: T"},
|
84 |
+
// Return values
|
85 |
+
{"y: T"},
|
86 |
+
// Attr def
|
87 |
+
{"T: {float, double, int32, int64}"},
|
88 |
+
// Nodes
|
89 |
+
{
|
90 |
+
{{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
|
91 |
+
{{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
|
92 |
+
{{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
|
93 |
+
});
|
94 |
+
}
|
95 |
+
|
96 |
+
FunctionDef XTimesTwoInt32() {
|
97 |
+
const Tensor kTwo = test::AsScalar<int64>(2);
|
98 |
+
return FDH::Define(
|
99 |
+
// Name
|
100 |
+
"XTimesTwoInt32",
|
101 |
+
// Args
|
102 |
+
{"x: int32"},
|
103 |
+
// Return values
|
104 |
+
{"y: int32"}, {},
|
105 |
+
// Nodes
|
106 |
+
{
|
107 |
+
{{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
|
108 |
+
{{"scale"},
|
109 |
+
"Cast",
|
110 |
+
{"two"},
|
111 |
+
{{"SrcT", DT_INT64}, {"DstT", DT_INT32}}},
|
112 |
+
{{"y"}, "Mul", {"x", "scale"}, {{"T", DT_INT32}}},
|
113 |
+
});
|
114 |
+
}
|
115 |
+
|
116 |
+
FunctionDef XTimesFour() {
|
117 |
+
return FDH::Create(
|
118 |
+
// Name
|
119 |
+
"XTimesFour",
|
120 |
+
// Args
|
121 |
+
{"x: T"},
|
122 |
+
// Return values
|
123 |
+
{"y: T"},
|
124 |
+
// Attr def
|
125 |
+
{"T: {float, double, int32, int64}"},
|
126 |
+
// Nodes
|
127 |
+
{
|
128 |
+
{{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
|
129 |
+
{{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}},
|
130 |
+
},
|
131 |
+
{{"y", "y:y:0"}});
|
132 |
+
}
|
133 |
+
|
134 |
+
FunctionDef XTimes16() {
|
135 |
+
return FDH::Create(
|
136 |
+
// Name
|
137 |
+
"XTimes16",
|
138 |
+
// Args
|
139 |
+
{"x: T"},
|
140 |
+
// Return values
|
141 |
+
{"y: T"},
|
142 |
+
// Attr def
|
143 |
+
{"T: {float, double, int32, int64}"},
|
144 |
+
// Nodes
|
145 |
+
{
|
146 |
+
{{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
|
147 |
+
{{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}},
|
148 |
+
},
|
149 |
+
{{"y", "y:y:0"}});
|
150 |
+
}
|
151 |
+
|
152 |
+
FunctionDef WXPlusB(){return FDH::Define(
|
153 |
+
// Name
|
154 |
+
"WXPlusB",
|
155 |
+
// Args
|
156 |
+
{"w: T", "x: T", "b: T"},
|
157 |
+
// Return values
|
158 |
+
{"y: T"},
|
159 |
+
// Attr def
|
160 |
+
{"T: {float, double}"},
|
161 |
+
// Nodes
|
162 |
+
{
|
163 |
+
{{"mm"},
|
164 |
+
"MatMul",
|
165 |
+
{"w", "x"},
|
166 |
+
{
|
167 |
+
{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false},
|
168 |
+
#ifdef INTEL_MKL
|
169 |
+
}},
|
170 |
+
#else
|
171 |
+
{"_kernel", "eigen"}}},
|
172 |
+
#endif
|
173 |
+
{
|
174 |
+
{"y"}, "Add", {"mm", "b"}, {
|
175 |
+
{ "T", "$T" }
|
176 |
+
}
|
177 |
+
}
|
178 |
+
});
|
179 |
+
}
|
180 |
+
|
181 |
+
FunctionDef Swap() {
|
182 |
+
return FDH::Define(
|
183 |
+
// Name
|
184 |
+
"Swap",
|
185 |
+
// Args
|
186 |
+
{"i0: T", "i1: T"},
|
187 |
+
// Return values
|
188 |
+
{"o0: T", "o1: T"},
|
189 |
+
// Attr def
|
190 |
+
{"T: {float, double}"},
|
191 |
+
// Nodes
|
192 |
+
{{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}},
|
193 |
+
{{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}});
|
194 |
+
}
|
195 |
+
|
196 |
+
void FunctionTestSchedClosure(std::function<void()> fn) {
|
197 |
+
static thread::ThreadPool* w =
|
198 |
+
new thread::ThreadPool(Env::Default(), "Test", 8);
|
199 |
+
w->Schedule(std::move(fn));
|
200 |
+
}
|
201 |
+
|
202 |
+
} // end namespace function
|
203 |
+
} // end namespace test
|
204 |
+
} // end namespace tensorflow
|
function_testlib.h
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
|
18 |
+
|
19 |
+
#include <string>
|
20 |
+
|
21 |
+
#include "tensorflow/core/framework/attr_value_util.h"
|
22 |
+
#include "tensorflow/core/framework/function.h"
|
23 |
+
#include "tensorflow/core/framework/function.pb.h"
|
24 |
+
#include "tensorflow/core/framework/graph.pb.h"
|
25 |
+
#include "tensorflow/core/framework/node_def.pb.h"
|
26 |
+
#include "tensorflow/core/lib/gtl/array_slice.h"
|
27 |
+
#include "tensorflow/core/platform/types.h"
|
28 |
+
|
29 |
+
namespace tensorflow {
|
30 |
+
namespace test {
|
31 |
+
namespace function {
|
32 |
+
|
33 |
+
// A helper class to make AttrSlice from initializer lists
|
34 |
+
class Attrs {
|
35 |
+
public:
|
36 |
+
Attrs(const std::initializer_list< // NOLINT(runtime/explicit)
|
37 |
+
std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) {
|
38 |
+
for (const auto& aval : attrs) {
|
39 |
+
map_.insert({aval.first, aval.second.proto});
|
40 |
+
}
|
41 |
+
}
|
42 |
+
|
43 |
+
operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit)
|
44 |
+
|
45 |
+
private:
|
46 |
+
AttrValueMap map_;
|
47 |
+
};
|
48 |
+
|
49 |
+
// Helper to construct a NodeDef.
|
50 |
+
NodeDef NDef(
|
51 |
+
const string& name, const string& op, gtl::ArraySlice<string> inputs,
|
52 |
+
gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
|
53 |
+
attrs = {},
|
54 |
+
const string& device = "");
|
55 |
+
|
56 |
+
// Helper to construct a GraphDef proto.
|
57 |
+
GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
|
58 |
+
gtl::ArraySlice<FunctionDef> funcs = {});
|
59 |
+
|
60 |
+
// For testing convenience, we provide a few simple functions that can
|
61 |
+
// be easily executed and tested.
|
62 |
+
|
63 |
+
// x:T -> x * 2.
|
64 |
+
FunctionDef XTimesTwo();
|
65 |
+
|
66 |
+
// x:T -> x * 2, where x is int32.
|
67 |
+
FunctionDef XTimesTwoInt32();
|
68 |
+
|
69 |
+
// x:T -> (x * 2) * 2.
|
70 |
+
FunctionDef XTimesFour();
|
71 |
+
|
72 |
+
// x:T -> ((x * 2) * 2) * 2.
|
73 |
+
FunctionDef XTimes16();
|
74 |
+
|
75 |
+
// w:T, x:T, b:T -> MatMul(w, x) + b
|
76 |
+
FunctionDef WXPlusB();
|
77 |
+
|
78 |
+
// x:T -> x:T, T is a type which we automatically converts to a bool.
|
79 |
+
FunctionDef NonZero();
|
80 |
+
|
81 |
+
// x:T, y:T -> y:T, x:T
|
82 |
+
FunctionDef Swap();
|
83 |
+
|
84 |
+
void FunctionTestSchedClosure(std::function<void()> fn);
|
85 |
+
|
86 |
+
} // end namespace function
|
87 |
+
} // end namespace test
|
88 |
+
} // end namespace tensorflow
|
89 |
+
|
90 |
+
#endif // TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
|
graph.proto
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
syntax = "proto3";
|
2 |
+
|
3 |
+
package tensorflow;
|
4 |
+
option cc_enable_arenas = true;
|
5 |
+
option java_outer_classname = "GraphProtos";
|
6 |
+
option java_multiple_files = true;
|
7 |
+
option java_package = "org.tensorflow.framework";
|
8 |
+
|
9 |
+
import "tensorflow/core/framework/node_def.proto";
|
10 |
+
import "tensorflow/core/framework/function.proto";
|
11 |
+
import "tensorflow/core/framework/versions.proto";
|
12 |
+
|
13 |
+
// Represents the graph of operations
|
14 |
+
message GraphDef {
|
15 |
+
repeated NodeDef node = 1;
|
16 |
+
|
17 |
+
// Compatibility versions of the graph. See core/public/version.h for version
|
18 |
+
// history. The GraphDef version is distinct from the TensorFlow version, and
|
19 |
+
// each release of TensorFlow will support a range of GraphDef versions.
|
20 |
+
VersionDef versions = 4;
|
21 |
+
|
22 |
+
// Deprecated single version field; use versions above instead. Since all
|
23 |
+
// GraphDef changes before "versions" was introduced were forward
|
24 |
+
// compatible, this field is entirely ignored.
|
25 |
+
int32 version = 3 [deprecated = true];
|
26 |
+
|
27 |
+
// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
|
28 |
+
//
|
29 |
+
// "library" provides user-defined functions.
|
30 |
+
//
|
31 |
+
// Naming:
|
32 |
+
// * library.function.name are in a flat namespace.
|
33 |
+
// NOTE: We may need to change it to be hierarchical to support
|
34 |
+
// different orgs. E.g.,
|
35 |
+
// { "/google/nn", { ... }},
|
36 |
+
// { "/google/vision", { ... }}
|
37 |
+
// { "/org_foo/module_bar", { ... }}
|
38 |
+
// map<string, FunctionDefLib> named_lib;
|
39 |
+
// * If node[i].op is the name of one function in "library",
|
40 |
+
// node[i] is deemed as a function call. Otherwise, node[i].op
|
41 |
+
// must be a primitive operation supported by the runtime.
|
42 |
+
//
|
43 |
+
//
|
44 |
+
// Function call semantics:
|
45 |
+
//
|
46 |
+
// * The callee may start execution as soon as some of its inputs
|
47 |
+
// are ready. The caller may want to use Tuple() mechanism to
|
48 |
+
// ensure all inputs are ready in the same time.
|
49 |
+
//
|
50 |
+
// * The consumer of return values may start executing as soon as
|
51 |
+
// the return values the consumer depends on are ready. The
|
52 |
+
// consumer may want to use Tuple() mechanism to ensure the
|
53 |
+
// consumer does not start until all return values of the callee
|
54 |
+
// function are ready.
|
55 |
+
FunctionDefLibrary library = 2;
|
56 |
+
};
|
graph_def_util.cc
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/graph_def_util.h"
|
17 |
+
|
18 |
+
#include <set>
|
19 |
+
#include <unordered_map>
|
20 |
+
#include <unordered_set>
|
21 |
+
#include <vector>
|
22 |
+
|
23 |
+
#include "tensorflow/core/framework/attr_value.pb.h"
|
24 |
+
#include "tensorflow/core/framework/function.pb.h"
|
25 |
+
#include "tensorflow/core/framework/graph.pb.h"
|
26 |
+
#include "tensorflow/core/framework/node_def.pb.h"
|
27 |
+
#include "tensorflow/core/framework/node_def_util.h"
|
28 |
+
#include "tensorflow/core/framework/op_def_util.h"
|
29 |
+
#include "tensorflow/core/framework/versions.pb_text.h"
|
30 |
+
#include "tensorflow/core/lib/core/errors.h"
|
31 |
+
#include "tensorflow/core/lib/core/status.h"
|
32 |
+
#include "tensorflow/core/lib/strings/strcat.h"
|
33 |
+
|
34 |
+
namespace tensorflow {
|
35 |
+
|
36 |
+
// Renders `graph_def` as a compact human-readable string: the versions
// proto first, then one summarized line per node, each ";\n"-terminated.
string SummarizeGraphDef(const GraphDef& graph_def) {
  string summary;
  strings::StrAppend(&summary, "versions = ",
                     ProtoShortDebugString(graph_def.versions()), ";\n");
  for (const NodeDef& node_def : graph_def.node()) {
    strings::StrAppend(&summary, SummarizeNodeDef(node_def), ";\n");
  }
  return summary;
}
|
45 |
+
|
46 |
+
// Validates the external syntax of every node in `graph_def`,
// failing fast on the first node that is malformed.
Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
  for (const NodeDef& node_def : graph_def.node()) {
    const Status status = ValidateExternalNodeDefSyntax(node_def);
    if (!status.ok()) {
      return status;
    }
  }
  return Status::OK();
}
|
52 |
+
|
53 |
+
// Fills in default attr values (as declared by `op_registry`) for every
// node of `graph_def` starting at index `node_offset`.
//
// Returns InvalidArgument if `node_offset` is out of range, or the lookup
// error if any node's op is unknown to `op_registry`.
//
// Fix: the original only rejected node_offset > node_size(); a negative
// offset would start the loop at a negative index and call
// mutable_node(i) out of bounds (undefined behavior). Reject it up front.
Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
                                 const OpRegistryInterface& op_registry,
                                 int node_offset) {
  if (node_offset < 0 || node_offset > graph_def->node_size()) {
    return errors::InvalidArgument(
        "Tried to add default attrs to GraphDef "
        "starting at offset ",
        node_offset, " with total nodes in graph: ", graph_def->node_size());
  }

  for (int i = node_offset; i < graph_def->node_size(); ++i) {
    NodeDef* node_def = graph_def->mutable_node(i);
    const OpDef* op_def;
    // The node's op must be registered, otherwise its defaults are unknown.
    TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(node_def->op(), &op_def));
    AddDefaultsToNodeDef(*op_def, node_def);
  }

  return Status::OK();
}
|
72 |
+
|
73 |
+
// Strips from `node_def` every attr that simultaneously:
//   * does not start with '_' (underscore attrs are internal and preserved),
//   * is unknown to the consumer's OpDef, and
//   * equals the default value declared by the producer's OpDef.
// Such attrs were introduced by a newer producer and can be dropped so an
// older consumer can still run the node. Each removed attr is recorded as
// an (op name, attr name) pair in `*op_attr_removed` when it is non-null.
//
// Returns an error if the node's op is unknown to either registry, or if
// an attr on the node is missing from the producer's OpDef.
static Status RemoveNewDefaultAttrsFromNodeDef(
    NodeDef* node_def, const OpRegistryInterface& consumer_op_registry,
    const OpRegistryInterface& producer_op_registry,
    std::set<std::pair<string, string>>* op_attr_removed) {
  const OpDef* producer_op_def;
  const OpDef* consumer_op_def;
  TF_RETURN_IF_ERROR(
      producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
  TF_RETURN_IF_ERROR(
      consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));

  std::vector<string> to_remove;
  for (const auto& attr : node_def->attr()) {
    // If the attr is not in consumer_op_def and doesn't start with '_'...
    if (!StringPiece(attr.first).starts_with("_") &&
        FindAttr(attr.first, *consumer_op_def) == nullptr) {
      const OpDef::AttrDef* producer_attr_def =
          FindAttr(attr.first, *producer_op_def);
      if (producer_attr_def == nullptr) {
        // The attr came from neither registry: the graph is inconsistent
        // with the producer's op definition, so fail loudly.
        return errors::InvalidArgument(
            "Attr '", attr.first, "' missing in producer's OpDef: ",
            SummarizeOpDef(*producer_op_def), " but found in node: ",
            SummarizeNodeDef(*node_def));
      }
      // ...and it has the same value as the default in producer,
      if (producer_attr_def->has_default_value() &&
          AreAttrValuesEqual(producer_attr_def->default_value(), attr.second)) {
        // then we will remove it below.
        to_remove.emplace_back(attr.first);
      }
    }
  }
  // We separate identifying which attrs should be removed from
  // actually removing them to avoid invalidating the loop iterators
  // above.
  for (const string& attr_name : to_remove) {
    node_def->mutable_attr()->erase(attr_name);
    if (op_attr_removed != nullptr) {
      op_attr_removed->insert(std::make_pair(node_def->op(), attr_name));
    }
  }

  return Status::OK();
}
|
117 |
+
|
118 |
+
static bool IsFunction(const GraphDef& graph_def, const string& op_name) {
|
119 |
+
for (const auto& func_def : graph_def.library().function()) {
|
120 |
+
if (op_name == func_def.signature().name()) return true;
|
121 |
+
}
|
122 |
+
return false;
|
123 |
+
}
|
124 |
+
|
125 |
+
// See graph_def_util.h. Applies RemoveNewDefaultAttrsFromNodeDef() to the
// graph's top-level nodes and to every node nested inside its library
// functions. Nodes whose op is itself a function in the graph's library
// are skipped, since function names are not present in the op registries.
Status RemoveNewDefaultAttrsFromGraphDef(
    GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry,
    const OpRegistryInterface& producer_op_registry,
    std::set<std::pair<string, string>>* op_attr_removed) {
  // TODO(joshL): Make IsFunction() faster by collecting the names of
  // all functions as a preprocessing step.
  // First pass: the graph's own (top-level) nodes.
  for (int n = 0; n < graph_def->node_size(); ++n) {
    NodeDef* node_def = graph_def->mutable_node(n);
    if (!IsFunction(*graph_def, node_def->op())) {
      TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
          node_def, consumer_op_registry, producer_op_registry,
          op_attr_removed));
    }
  }
  // Second pass: nodes inside each function of the library.
  for (int f = 0; f < graph_def->library().function_size(); ++f) {
    FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f);
    for (int n = 0; n < func_def->node_def_size(); ++n) {
      NodeDef* node_def = func_def->mutable_node_def(n);
      if (!IsFunction(*graph_def, node_def->op())) {
        // TODO(josh11b): Better handling of attrs with placeholder values.
        TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
            node_def, consumer_op_registry, producer_op_registry,
            op_attr_removed));
      }
    }
  }

  return Status::OK();
}
|
154 |
+
|
155 |
+
// Computes the set of primitive (non-function) op names used by
// `graph_def`, including ops reached transitively through functions in
// its library. Uses an explicit worklist, so mutually/self-recursive
// functions are handled without infinite recursion.
void OpsUsedByGraph(const GraphDef& graph_def,
                    std::set<string>* ops_used_in_graph) {
  // Map function names to definitions.
  std::unordered_map<string, const FunctionDef*> name_to_function;
  for (const auto& function : graph_def.library().function()) {
    name_to_function.insert(
        std::make_pair(function.signature().name(), &function));
  }

  // Collect the sorted list of op names. Since functions can reference
  // functions, we need a recursive traversal.
  std::set<string> used_ops;  // Includes both primitive ops and functions
  std::vector<const FunctionDef*> functions_to_process;  // A subset of used_ops
  // Collect the logic to mark an op in a lambda; it'll be used twice below.
  const auto mark_op_as_used = [&used_ops, &functions_to_process,
                                &name_to_function](const string& op) {
    // insert().second is true only on first insertion, so each op (and
    // thus each function) is processed at most once.
    if (used_ops.insert(op).second) {
      // If it's a function, we'll need to process further
      const auto it = name_to_function.find(op);
      if (it != name_to_function.end()) {
        functions_to_process.push_back(it->second);
      }
    }
  };
  // Seed with ops used directly by the graph's top-level nodes...
  for (const auto& node : graph_def.node()) {
    mark_op_as_used(node.op());
  }
  // ...then drain the worklist of reachable functions; newly discovered
  // functions are appended as we go.
  while (!functions_to_process.empty()) {
    const FunctionDef* fun = functions_to_process.back();
    functions_to_process.pop_back();
    for (const auto& node : fun->node_def()) {
      mark_op_as_used(node.op());
    }
  }

  // Filter out function names to produce output.
  // TODO(josh11b): Change the above code to produce this directly.
  ops_used_in_graph->clear();
  for (const string& op_name : used_ops) {
    if (name_to_function.find(op_name) == name_to_function.end()) {
      ops_used_in_graph->insert(op_name);
    }
  }
}
|
199 |
+
|
200 |
+
// Builds an OpList containing, in sorted name order, a
// description-stripped OpDef for every primitive op used by `graph_def`
// (functions are excluded by OpsUsedByGraph()).
Status StrippedOpListForGraph(const GraphDef& graph_def,
                              const OpRegistryInterface& op_registry,
                              OpList* stripped_op_list) {
  std::set<string> used_ops;
  OpsUsedByGraph(graph_def, &used_ops);

  // `used_ops` is a std::set, so iterating yields op names in sorted order.
  stripped_op_list->clear_op();
  for (const string& op_name : used_ops) {
    const OpDef* op_def;
    TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def));
    OpDef* stripped_op = stripped_op_list->add_op();
    *stripped_op = *op_def;  // proto assignment == CopyFrom()
    RemoveDescriptionsFromOpDef(stripped_op);
  }
  return Status::OK();
}
|
217 |
+
|
218 |
+
} // namespace tensorflow
|
graph_def_util.h
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
|
18 |
+
|
19 |
+
#include <set>
|
20 |
+
#include "tensorflow/core/framework/op.h"
|
21 |
+
#include "tensorflow/core/lib/core/status.h"
|
22 |
+
|
23 |
+
namespace tensorflow {
|
24 |
+
|
25 |
+
// Forward declare proto so that its symbols can be removed from .so exports
|
26 |
+
class GraphDef;
|
27 |
+
|
28 |
+
// Produce a human-readable version of a GraphDef that is more concise
|
29 |
+
// than a text-format proto.
|
30 |
+
string SummarizeGraphDef(const GraphDef& graph_def);
|
31 |
+
|
32 |
+
// Validates the syntax of a GraphDef provided externally.
|
33 |
+
//
|
34 |
+
// The following is an EBNF-style syntax for GraphDef objects. Note that
|
35 |
+
// Node objects are actually specified as tensorflow::NodeDef protocol buffers,
|
36 |
+
// which contain many other fields that are not (currently) validated.
|
37 |
+
//
|
38 |
+
// Graph = Node *
|
39 |
+
// Node = NodeName, Inputs
|
40 |
+
// Inputs = ( DataInput * ), ( ControlInput * )
|
41 |
+
// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ?
|
42 |
+
// ControlInput = "^", NodeName
|
43 |
+
// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] *
|
44 |
+
Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def);
|
45 |
+
|
46 |
+
// Adds default attributes to NodeDefs in 'graph_def' starting
|
47 |
+
// from the 'node_offset' node in 'graph_def'.
|
48 |
+
//
|
49 |
+
// Default attributes are defined by 'op_registry'.
|
50 |
+
//
|
51 |
+
// Returns OK on success, an error if 'graph_def' has a NodeDef
|
52 |
+
// that cannot be found in 'op_registry'.
|
53 |
+
//
|
54 |
+
// REQUIRES: 'graph_def' and 'op_registry' are not nullptr.
|
55 |
+
Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
|
56 |
+
const OpRegistryInterface& op_registry,
|
57 |
+
int node_offset);
|
58 |
+
|
59 |
+
// Remove attrs from 'graph_def' that have the default value according
|
60 |
+
// to 'producer_op_registry', but don't exist according to
|
61 |
+
// 'consumer_op_registry'. This can allow 'graph_def' to run on the
|
62 |
+
// consumer even if consumer was built at an earlier CL (before an
|
63 |
+
// attr with a default was added). Note that this will not affect
|
64 |
+
// attrs with non-default values, so you must run a
|
65 |
+
// ValidateGraphDef...() function to see if the result is in fact
|
66 |
+
// compatible. If not nullptr, the op/attr pairs that were removed
|
67 |
+
// are added to '*op_attr_removed'.
|
68 |
+
//
|
69 |
+
// Expected usage, for a producer that wants to prepare a graph for
|
70 |
+
// a consumer:
|
71 |
+
// // For each consumer, update 'graph_def':
|
72 |
+
// OpListOpRegistry consumer_op_registry(consumer_server_op_list);
|
73 |
+
// std::set<std::pair<string, string>> op_attr_removed;
|
74 |
+
// TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromGraphDef(
|
75 |
+
// &graph_def, consumer_op_registry, *OpRegistry::Global(),
|
76 |
+
// &op_attr_removed));
|
77 |
+
// // Validate that each consumer can understand the resulting 'graph_def'
|
78 |
+
// TF_RETURN_IF_ERROR(graph::ValidateGraphDefAgainstOpRegistry(
|
79 |
+
// graph_def, consumer_op_registry));
|
80 |
+
// // Consumer can use 'graph_def', and 'op_attr_removed' summarizes
|
81 |
+
// // what changes had to be made to 'graph_def' for it to work.
|
82 |
+
//
|
83 |
+
// Expected usage, for a consumer that has a graph and a
|
84 |
+
// (optionally-stripped) op_list from a producer (say from a call to
|
85 |
+
// StrippedOpListForGraph(), or in the MetaGraphDef):
|
86 |
+
// OpListOpRegistry producer_op_registry(producer_stripped_op_list);
|
87 |
+
// TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromGraphDef(
|
88 |
+
// &graph_def, *OpRegistry::Global(), producer_op_registry, nullptr));
|
89 |
+
Status RemoveNewDefaultAttrsFromGraphDef(
|
90 |
+
GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry,
|
91 |
+
const OpRegistryInterface& producer_op_registry,
|
92 |
+
std::set<std::pair<string, string>>* op_attr_removed);
|
93 |
+
|
94 |
+
// Two functions that collect the ops used by a graph.
|
95 |
+
//
|
96 |
+
// This returns the ops used as a set of strings.
|
97 |
+
void OpsUsedByGraph(const GraphDef& graph_def,
|
98 |
+
std::set<string>* ops_used_in_graph);
|
99 |
+
|
100 |
+
// This function computes the stripped_op_list field of MetaGraphDef
|
101 |
+
// and similar protos. The op_registry should contain the ops used to
|
102 |
+
// produce graph_def. The resulting stripped_op_list can be
|
103 |
+
// communicated from the producer to the consumer, which can use
|
104 |
+
// RemoveNewDefaultAttrsFromGraphDef() to improve forwards compatibility
|
105 |
+
// (using an OpListOpRegistry as indicated in the example above).
|
106 |
+
//
|
107 |
+
// Most users will pass *OpRegistry::Global() for op_registry to strip against
|
108 |
+
// the list of ops registered in this process.
|
109 |
+
Status StrippedOpListForGraph(const GraphDef& graph_def,
|
110 |
+
const OpRegistryInterface& op_registry,
|
111 |
+
OpList* stripped_op_list);
|
112 |
+
|
113 |
+
} // namespace tensorflow
|
114 |
+
|
115 |
+
#endif // TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
|
graph_def_util_test.cc
ADDED
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/graph_def_util.h"
|
17 |
+
|
18 |
+
#include "tensorflow/core/framework/function.h"
|
19 |
+
#include "tensorflow/core/framework/graph.pb.h"
|
20 |
+
#include "tensorflow/core/framework/node_def_builder.h"
|
21 |
+
#include "tensorflow/core/framework/op.h"
|
22 |
+
#include "tensorflow/core/framework/op_def.pb.h"
|
23 |
+
#include "tensorflow/core/framework/op_def_builder.h"
|
24 |
+
#include "tensorflow/core/lib/core/status_test_util.h"
|
25 |
+
#include "tensorflow/core/platform/test.h"
|
26 |
+
#include "tensorflow/core/util/equal_graph_def.h"
|
27 |
+
|
28 |
+
namespace tensorflow {
|
29 |
+
namespace {
|
30 |
+
|
31 |
+
// Finalizes builder `b` and copies only the resulting OpDef into
// `*op_def`, discarding the rest of the registration data.
Status FinalizeOpDef(const OpDefBuilder& b, OpDef* op_def) {
  OpRegistrationData reg_data;
  const Status status = b.Finalize(&reg_data);
  *op_def = reg_data.op_def;
  return status;
}
|
37 |
+
|
38 |
+
// Producer and consumer have default for an attr -> graph unchanged.
|
39 |
+
TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) {
  // One registry plays both producer and consumer; the op's attr "a" has
  // a default known to both sides.
  OpList op_list;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("NoChangeWithDefault").Attr("a: int = 12"),
                    op_list.add_op()));
  OpListOpRegistry registry(&op_list);

  GraphDef graph_def;
  TF_ASSERT_OK(NodeDefBuilder("ncwd", "NoChangeWithDefault", &registry)
                   .Finalize(graph_def.add_node()));
  GraphDef expected_graph_def = graph_def;

  std::set<std::pair<string, string>> op_attr_removed;
  TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
                                                 &op_attr_removed));

  // The consumer knows the attr, so nothing is stripped.
  TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
  EXPECT_TRUE(op_attr_removed.empty());
}
|
58 |
+
|
59 |
+
// Producer and consumer both have an attr -> graph unchanged.
|
60 |
+
TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) {
  // Attr "a" has no default; the node sets it explicitly.
  OpList op_list;
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("NoChangeNoDefault").Attr("a: int"),
                             op_list.add_op()));
  OpListOpRegistry registry(&op_list);

  GraphDef graph_def;
  TF_ASSERT_OK(NodeDefBuilder("ncnd", "NoChangeNoDefault", &registry)
                   .Attr("a", 42)
                   .Finalize(graph_def.add_node()));
  GraphDef expected_graph_def = graph_def;

  std::set<std::pair<string, string>> op_attr_removed;
  TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
                                                 &op_attr_removed));

  // Attrs known to the consumer are never removed, default or not.
  TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
  EXPECT_TRUE(op_attr_removed.empty());
}
|
79 |
+
|
80 |
+
// Producer has default for an attr that the consumer does not know
|
81 |
+
// about, and the produced graph has the default value for the attr ->
|
82 |
+
// attr removed from graph (and so able to be consumed).
|
83 |
+
TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) {
  // Consumer does not know attr "a" at all...
  OpList consumer_op_list;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op()));
  OpListOpRegistry consumer_registry(&consumer_op_list);

  // ...but the producer declares it with a default of 17.
  OpList producer_op_list;
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
                             producer_op_list.add_op()));
  OpListOpRegistry producer_registry(&producer_op_list);

  // The produced node therefore carries a = 17 (the default).
  GraphDef produced_graph_def;
  TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &producer_registry)
                   .Finalize(produced_graph_def.add_node()));

  std::set<std::pair<string, string>> op_attr_removed;
  TF_ASSERT_OK(
      RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
                                        producer_registry, &op_attr_removed));

  // The default-valued attr is stripped, yielding the consumer's view.
  GraphDef expected_graph_def;
  TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &consumer_registry)
                   .Finalize(expected_graph_def.add_node()));
  TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);

  // And the removal is reported.
  std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}});
  EXPECT_EQ(expected_removed, op_attr_removed);
}
|
111 |
+
|
112 |
+
// Producer has default for an attr that the consumer does not know
|
113 |
+
// about, graph sets the attr to a value different from the default ->
|
114 |
+
// graph unchanged (but not able to be consumed by consumer).
|
115 |
+
TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) {
  // Consumer does not know attr "a"; producer defaults it to 17.
  OpList consumer_op_list;
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"),
                             consumer_op_list.add_op()));
  OpListOpRegistry consumer_registry(&consumer_op_list);

  OpList producer_op_list;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
                    producer_op_list.add_op()));
  OpListOpRegistry producer_registry(&producer_op_list);

  // The node overrides the default (a = 9), so stripping would lose
  // information.
  GraphDef produced_graph_def;
  TF_ASSERT_OK(NodeDefBuilder("changed_from_default", "ChangedFromDefault",
                              &producer_registry)
                   .Attr("a", 9)
                   .Finalize(produced_graph_def.add_node()));
  GraphDef expected_graph_def = produced_graph_def;

  std::set<std::pair<string, string>> op_attr_removed;
  TF_ASSERT_OK(
      RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
                                        producer_registry, &op_attr_removed));

  // Non-default values are kept (graph is unchanged, though the consumer
  // may not be able to run it).
  TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
  EXPECT_TRUE(op_attr_removed.empty());
}
|
142 |
+
|
143 |
+
// Attrs starting with underscores should not be removed.
|
144 |
+
TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) {
  // Leading-underscore attrs are reserved for internal use and must
  // survive stripping even when the consumer does not know them.
  OpList consumer_op_list;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("Underscore"), consumer_op_list.add_op()));
  OpListOpRegistry consumer_registry(&consumer_op_list);

  OpList producer_op_list;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("Underscore"), producer_op_list.add_op()));
  // Add the _underscore attr manually since OpDefBuilder would complain
  OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr();
  attr->set_name("_underscore");
  attr->set_type("int");
  attr->mutable_default_value()->set_i(17);
  OpListOpRegistry producer_registry(&producer_op_list);

  // Node carries the default value, which would normally be removed.
  GraphDef produced_graph_def;
  TF_ASSERT_OK(NodeDefBuilder("node", "Underscore", &producer_registry)
                   .Attr("_underscore", 17)
                   .Finalize(produced_graph_def.add_node()));
  GraphDef expected_graph_def = produced_graph_def;

  std::set<std::pair<string, string>> op_attr_removed;
  TF_ASSERT_OK(
      RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
                                        producer_registry, &op_attr_removed));

  // Graph unchanged: the underscore prefix exempts the attr.
  TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
  EXPECT_EQ(op_attr_removed.size(), 0);
}
|
174 |
+
|
175 |
+
TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) {
  // Stripping must also apply inside function bodies in the graph's
  // library, while leaving calls to library functions themselves alone.
  OpList consumer_op_list;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op()));
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"),
                             consumer_op_list.add_op()));
  OpListOpRegistry consumer_registry(&consumer_op_list);

  OpList producer_op_list;
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
                             producer_op_list.add_op()));
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
                    producer_op_list.add_op()));
  OpListOpRegistry producer_registry(&producer_op_list);

  // Function body: node "x" uses the default (17), node "y" overrides (99).
  GraphDef produced_graph_def;
  *produced_graph_def.mutable_library()->add_function() =
      FunctionDefHelper::Create(
          "my_func", {}, {}, {},
          {{{"x"}, "UsesDefault", {}, {{"a", 17}}},
           {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
          {});
  OpList function_op_list;
  *function_op_list.add_op() =
      produced_graph_def.library().function(0).signature();
  OpListOpRegistry function_registry(&function_op_list);
  TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry)
                   .Finalize(produced_graph_def.add_node()));

  std::set<std::pair<string, string>> op_attr_removed;
  TF_ASSERT_OK(
      RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
                                        producer_registry, &op_attr_removed));

  // Only the default-valued attr on "x" disappears from the function body.
  GraphDef expected_graph_def;
  *expected_graph_def.mutable_library()->add_function() =
      FunctionDefHelper::Create(
          "my_func", {}, {}, {},
          {{{"x"}, "UsesDefault", {}, {}},
           {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
          {});
  TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry)
                   .Finalize(expected_graph_def.add_node()));
  TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
  EXPECT_EQ(expected_graph_def.library().DebugString(),
            produced_graph_def.library().DebugString());

  std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}});
  EXPECT_EQ(expected_removed, op_attr_removed);
}
|
226 |
+
|
227 |
+
TEST(StrippedOpListForGraphTest, FlatTest) {
  // Make four ops
  OpList op_list;
  for (const string& op : {"A", "B", "C", "D"}) {
    OpDef* op_def = op_list.add_op();
    op_def->set_name(op);
    op_def->set_summary("summary");
    op_def->set_description("description");
    op_def->set_is_commutative(op == "B");
  }

  // Make a graph which uses two ops once and twice, respectively.
  // The result should be independent of the ordering.
  const string graph_ops[4][3] = {
      {"C", "B", "B"}, {"B", "C", "B"}, {"B", "B", "C"}, {"C", "C", "B"}};
  // Exercise both direct node usage and usage through a library function.
  for (const bool use_function : {false, true}) {
    for (int order = 0; order < 4; order++) {
      GraphDef graph_def;
      if (use_function) {
        // Put the ops inside function F and call F from the graph.
        FunctionDef* function_def = graph_def.mutable_library()->add_function();
        function_def->mutable_signature()->set_name("F");
        for (const string& op : graph_ops[order]) {
          function_def->add_node_def()->set_op(op);
        }
        graph_def.add_node()->set_op("F");
      } else {
        for (const string& op : graph_ops[order]) {
          string name = strings::StrCat("name", graph_def.node_size());
          NodeDef* node = graph_def.add_node();
          node->set_name(name);
          node->set_op(op);
        }
      }

      // Strip the op list
      OpList stripped_op_list;
      TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
                                          &stripped_op_list));

      // We should have exactly two ops: B and C.
      ASSERT_EQ(stripped_op_list.op_size(), 2);
      for (int i = 0; i < 2; i++) {
        const OpDef& op = stripped_op_list.op(i);
        // Sorted order, with summaries/descriptions stripped; only B is
        // commutative.
        EXPECT_EQ(op.name(), i ? "C" : "B");
        EXPECT_EQ(op.summary(), "");
        EXPECT_EQ(op.description(), "");
        EXPECT_EQ(op.is_commutative(), !i);
      }

      // Should get the same result using OpsUsedByGraph().
      std::set<string> used_ops;
      OpsUsedByGraph(graph_def, &used_ops);
      ASSERT_EQ(std::set<string>({"B", "C"}), used_ops);
    }
  }
}
|
283 |
+
|
284 |
+
TEST(StrippedOpListForGraphTest, NestedFunctionTest) {
  // Make a primitive op A.
  OpList op_list;
  op_list.add_op()->set_name("A");

  // Run once with a simple call chain and once with (mutually) recursive
  // functions, which must not cause an infinite traversal.
  for (const bool recursive : {false, true}) {
    // Call A from function B, and B from function C.
    GraphDef graph_def;
    FunctionDef* b = graph_def.mutable_library()->add_function();
    FunctionDef* c = graph_def.mutable_library()->add_function();
    b->mutable_signature()->set_name("B");
    c->mutable_signature()->set_name("C");
    b->add_node_def()->set_op("A");
    c->add_node_def()->set_op("B");
    if (recursive) {
      b->add_node_def()->set_op("B");
      c->add_node_def()->set_op("C");
    }

    // Use C in the graph.
    graph_def.add_node()->set_op("C");

    // The stripped op list should contain just A.
    OpList stripped_op_list;
    TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
                                        &stripped_op_list));
    ASSERT_EQ(stripped_op_list.op_size(), 1);
    ASSERT_EQ(stripped_op_list.op(0).name(), "A");

    // Should get the same result using OpsUsedByGraph().
    std::set<string> used_ops;
    OpsUsedByGraph(graph_def, &used_ops);
    ASSERT_EQ(std::set<string>({"A"}), used_ops);
  }
}
|
319 |
+
|
320 |
+
} // namespace
|
321 |
+
} // namespace tensorflow
|
graph_transfer_info.proto
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
syntax = "proto3";
|
2 |
+
|
3 |
+
package tensorflow;
|
4 |
+
option cc_enable_arenas = true;
|
5 |
+
option java_outer_classname = "GraphTransferInfoProto";
|
6 |
+
option java_multiple_files = true;
|
7 |
+
option java_package = "org.tensorflow.framework";
|
8 |
+
|
9 |
+
import "tensorflow/core/framework/types.proto";
|
10 |
+
|
11 |
+
// Protocol buffer representing a TensorFlow graph in the form transferred
// to and executed on a remote processor (e.g. the Hexagon DSP): node
// metadata, inlined constants, wiring, and graph-level inputs/outputs.
|
14 |
+
// Describes a graph (nodes, constants, wiring, and graph-level I/O) in
// the form consumed by a graph-transfer runtime.
message GraphTransferInfo {
  // Where the transferred graph will be executed.
  enum Destination {
    NOP = 0;      // No destination; graph is not transferred.
    HEXAGON = 1;  // Hexagon DSP.
  }
  // Reference to one output of another node, by node id and output port.
  message NodeInput {
    int32 node_id = 1;
    int32 output_port = 2;
  }
  // Metadata for one (non-constant) node of the transferred graph.
  message NodeInfo {
    string name = 1;
    // Node id unique within the transferred graph.
    int32 node_id = 2;
    // Op type name.
    string type_name = 3;
    // Presumably the op id understood by the target SoC runtime — TODO confirm.
    int32 soc_op_id = 4;
    int32 padding_id = 5;
    int32 input_count = 6;
    int32 output_count = 7;
  };
  // A constant node, with its shape and raw tensor bytes inlined.
  message ConstNodeInfo {
    string name = 1;
    int32 node_id = 2;
    repeated int64 shape = 3;
    bytes data = 4;
    DataType dtype = 5;
  };
  // The full list of inputs feeding one node.
  message NodeInputInfo {
    int32 node_id = 1;
    repeated NodeInput node_input = 2;
  };
  // Per-output maximum byte sizes for one node's outputs.
  message NodeOutputInfo {
    int32 node_id = 1;
    repeated int32 max_byte_size = 2;
  };
  // Name, shape and type of a graph-level input.
  message GraphInputNodeInfo {
    string name = 1;
    repeated int64 shape = 2;
    DataType dtype = 3;
  }

  // Name, shape and type of a graph-level output.
  message GraphOutputNodeInfo {
    string name = 1;
    repeated int64 shape = 2;
    DataType dtype = 3;
  }

  repeated NodeInfo node_info = 1;
  repeated ConstNodeInfo const_node_info = 2;
  repeated NodeInputInfo node_input_info = 3;
  repeated NodeOutputInfo node_output_info = 4;
  // Input node parameters of the transferred graph.
  repeated GraphInputNodeInfo graph_input_node_info = 5;
  // Output node parameters of the transferred graph.
  repeated GraphOutputNodeInfo graph_output_node_info = 6;
  // Destination of the graph transfer.
  Destination destination = 7;
};
|
iterator.proto
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
syntax = "proto3";
|
2 |
+
|
3 |
+
package tensorflow;
|
4 |
+
option cc_enable_arenas = true;
|
5 |
+
option java_outer_classname = "IteratorProtos";
|
6 |
+
option java_multiple_files = true;
|
7 |
+
option java_package = "org.tensorflow.util";
|
8 |
+
|
9 |
+
// Protocol buffer representing the metadata for an iterator's state stored
// as a Variant tensor.
message IteratorStateMetadata {
  // A user-specified version string.
  string version = 1;

  // Keys for tensors in the VariantTensorDataProto.
  // Each key names one tensor of the serialized iterator state.
  repeated string keys = 2;
}
|
kernel_def.proto
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
syntax = "proto3";
|
2 |
+
|
3 |
+
package tensorflow;
|
4 |
+
option cc_enable_arenas = true;
|
5 |
+
option java_outer_classname = "KernelDefProtos";
|
6 |
+
option java_multiple_files = true;
|
7 |
+
option java_package = "org.tensorflow.framework";
|
8 |
+
|
9 |
+
import "tensorflow/core/framework/attr_value.proto";
|
10 |
+
|
11 |
+
// Describes one registered kernel: the op it implements, the device type
// it runs on, and the attr constraints under which it is selected.
message KernelDef {
  // Must match the name of an Op.
  string op = 1;

  // Type of device this kernel runs on.
  string device_type = 2;

  // Restricts this kernel to ops whose attr `name` takes one of the
  // listed values.
  message AttrConstraint {
    // Name of an attr from the Op.
    string name = 1;

    // A list of values that this kernel supports for this attr.
    // Like OpDef.AttrDef.allowed_values, except for kernels instead of Ops.
    AttrValue allowed_values = 2;
  }
  // All constraints must be satisfied for this kernel to apply.
  repeated AttrConstraint constraint = 3;

  // Names of the Op's input_/output_args that reside in host memory
  // instead of device memory.
  repeated string host_memory_arg = 4;

  // This allows experimental kernels to be registered for an op that
  // won't be used unless the user specifies a "_kernel" attr with
  // value matching this.
  string label = 5;
}
|
kernel_def_builder.cc
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/kernel_def_builder.h"
|
17 |
+
#include "tensorflow/core/framework/attr_value.pb.h"
|
18 |
+
#include "tensorflow/core/framework/kernel_def.pb_text.h"
|
19 |
+
#include "tensorflow/core/framework/kernel_def.pb.h"
|
20 |
+
|
21 |
+
namespace tensorflow {
|
22 |
+
|
23 |
+
// Starts a KernelDef for the op with the given name.  Ownership of the
// pending KernelDef is released to the caller via Build().
KernelDefBuilder::KernelDefBuilder(const char* op_name) {
  kernel_def_ = new KernelDef;
  kernel_def_->set_op(op_name);
}
|
27 |
+
|
28 |
+
// In debug builds, verify Build() was called: Build() nulls out
// kernel_def_ when transferring ownership, so a non-null pointer here
// means the KernelDef would leak.
KernelDefBuilder::~KernelDefBuilder() {
  DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
}
|
31 |
+
|
32 |
+
// Records the device type (e.g. "CPU", "GPU") this kernel supports.
// Returns *this for chaining.
KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
  kernel_def_->set_device_type(device_type);
  return *this;
}
|
36 |
+
|
37 |
+
// Restricts attr `attr_name` to the given set of allowed data types.
// Returns *this for chaining.
KernelDefBuilder& KernelDefBuilder::TypeConstraint(
    const char* attr_name, gtl::ArraySlice<DataType> allowed) {
  auto* attr_constraint = kernel_def_->add_constraint();
  attr_constraint->set_name(attr_name);
  auto* type_list =
      attr_constraint->mutable_allowed_values()->mutable_list();
  for (const DataType data_type : allowed) {
    type_list->add_type(data_type);
  }
  return *this;
}
|
47 |
+
|
48 |
+
// Like the ArraySlice overload, but constrains `attr_name` to a single
// allowed type.  Returns *this for chaining.
KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name,
                                                   DataType allowed) {
  auto* constraint = kernel_def_->add_constraint();
  constraint->set_name(attr_name);
  constraint->mutable_allowed_values()->mutable_list()->add_type(allowed);
  return *this;
}
|
55 |
+
|
56 |
+
// Marks the named input/output arg as residing in host memory rather
// than device memory.  Returns *this for chaining.
KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) {
  kernel_def_->add_host_memory_arg(arg_name);
  return *this;
}
|
60 |
+
|
61 |
+
// Sets the "_kernel" label used to select experimental kernels.  May only
// be called once per builder; a second call is a fatal CHECK failure.
KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
  CHECK_EQ(kernel_def_->label(), "")
      << "Trying to set a kernel's label a second time: '" << label
      << "' in: " << ProtoShortDebugString(*kernel_def_);
  kernel_def_->set_label(label);
  return *this;
}
|
68 |
+
|
69 |
+
// Returns the accumulated KernelDef and transfers ownership to the
// caller.  Nulling kernel_def_ both prevents reuse of the builder and
// satisfies the destructor's leak check.
const KernelDef* KernelDefBuilder::Build() {
  KernelDef* r = kernel_def_;
  kernel_def_ = nullptr;
  return r;
}
|
74 |
+
|
75 |
+
} // namespace tensorflow
|
kernel_def_builder.h
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
|
18 |
+
|
19 |
+
#include "tensorflow/core/framework/types.h"
|
20 |
+
#include "tensorflow/core/lib/gtl/array_slice.h"
|
21 |
+
#include "tensorflow/core/platform/macros.h"
|
22 |
+
#include "tensorflow/core/platform/types.h"
|
23 |
+
|
24 |
+
namespace tensorflow {
|
25 |
+
|
26 |
+
// Forward declare proto so that kernels don't need to depend on it
|
27 |
+
class KernelDef;
|
28 |
+
|
29 |
+
// Builder class passed to the REGISTER_KERNEL_BUILDER() macro.
|
30 |
+
class KernelDefBuilder {
 public:
  // Starts with just the name field set.
  // Caller MUST call Build() and take ownership of the result.
  explicit KernelDefBuilder(const char* op_name);
  // DCHECK-fails in debug builds if Build() was never called (the
  // pending KernelDef would otherwise leak).
  ~KernelDefBuilder();

  // Required: specify the type of device this kernel supports.
  // Returns *this.
  KernelDefBuilder& Device(const char* device_type);
  // KernelDefBuilder& Device(DeviceType device_type);

  // Specify that this kernel supports a limited set of values for a
  // particular type or list(type) attr (a further restriction than
  // what the Op allows).
  // Returns *this.
  KernelDefBuilder& TypeConstraint(const char* attr_name,
                                   gtl::ArraySlice<DataType> allowed);

  // Like TypeConstraint but supports just a single type.
  KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);

  // Like TypeConstraint, but (a) gets the type from a template parameter
  // and (b) only supports a constraint to a single type.
  template <class T>
  KernelDefBuilder& TypeConstraint(const char* attr_name);
  // TODO(josh11b): Support other types of attr constraints as needed.

  // Specify that this kernel requires/provides an input/output arg
  // in host memory (instead of the default, device memory).
  // Returns *this.
  KernelDefBuilder& HostMemory(const char* arg_name);

  // Specify that this kernel requires a particular value for the
  // "_kernel" attr.  May only be specified once.  Returns *this.
  KernelDefBuilder& Label(const char* label);

  // Returns a pointer to a KernelDef with fields set based on the
  // above calls to this instance.
  // Caller takes ownership of the result.
  const KernelDef* Build();

 private:
  // Owned until released by Build(); nulled out at that point.
  KernelDef* kernel_def_;

  TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
};
|
77 |
+
|
78 |
+
// IMPLEMENTATION
|
79 |
+
|
80 |
+
// Maps the template parameter to its DataType enum value and forwards to
// the single-type TypeConstraint overload.
template <class T>
KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name) {
  return this->TypeConstraint(attr_name, DataTypeToEnum<T>::v());
}
|
84 |
+
|
85 |
+
} // namespace tensorflow
|
86 |
+
|
87 |
+
#endif // TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
|
kernel_def_builder_test.cc
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/kernel_def_builder.h"
|
17 |
+
|
18 |
+
#include "tensorflow/core/framework/kernel_def.pb.h"
|
19 |
+
#include "tensorflow/core/platform/protobuf.h"
|
20 |
+
#include "tensorflow/core/platform/test.h"
|
21 |
+
|
22 |
+
namespace tensorflow {
|
23 |
+
namespace {
|
24 |
+
|
25 |
+
// Builds a minimal KernelDef and compares it against a parsed expectation.
TEST(KernelDefBuilderTest, Basic) {
  const KernelDef* def = KernelDefBuilder("A").Device(DEVICE_CPU).Build();
  KernelDef expected;
  // Check the parse result so a malformed text proto fails loudly instead
  // of silently comparing against a default-constructed message.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "op: 'A' device_type: 'CPU'", &expected));
  EXPECT_EQ(def->DebugString(), expected.DebugString());
  delete def;  // Build() transfers ownership to the caller.
}
|
33 |
+
|
34 |
+
// Exercises the template, repeated single-type, and multi-type forms of
// TypeConstraint.  Each text-proto parse is ASSERTed so a malformed
// expectation fails the test directly instead of yielding an empty proto.
TEST(KernelDefBuilderTest, TypeConstraint) {
  const KernelDef* def = KernelDefBuilder("B")
                             .Device(DEVICE_GPU)
                             .TypeConstraint<float>("T")
                             .Build();
  KernelDef expected;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(R"proto(
    op: 'B' device_type: 'GPU'
    constraint { name: 'T' allowed_values { list { type: DT_FLOAT } } } )proto",
                                                    &expected));

  EXPECT_EQ(def->DebugString(), expected.DebugString());
  delete def;

  // Two independent single-type constraints on different attrs.
  def = KernelDefBuilder("C")
            .Device(DEVICE_GPU)
            .TypeConstraint<int32>("U")
            .TypeConstraint<bool>("V")
            .Build();

  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(R"proto(
    op: 'C' device_type: 'GPU'
    constraint { name: 'U' allowed_values { list { type: DT_INT32 } } }
    constraint { name: 'V' allowed_values { list { type: DT_BOOL } } } )proto",
                                                    &expected));
  EXPECT_EQ(def->DebugString(), expected.DebugString());
  delete def;

  // One constraint allowing multiple types.
  def = KernelDefBuilder("D")
            .Device(DEVICE_CPU)
            .TypeConstraint("W", {DT_DOUBLE, DT_STRING})
            .Build();
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(R"proto(
    op: 'D' device_type: 'CPU'
    constraint { name: 'W'
      allowed_values { list { type: [DT_DOUBLE, DT_STRING] } } } )proto",
                                                    &expected));
  EXPECT_EQ(def->DebugString(), expected.DebugString());
  delete def;
}
|
74 |
+
|
75 |
+
// Verifies that HostMemory() accumulates host_memory_arg entries in order.
TEST(KernelDefBuilderTest, HostMemory) {
  const KernelDef* def = KernelDefBuilder("E")
                             .Device(DEVICE_GPU)
                             .HostMemory("in")
                             .HostMemory("out")
                             .Build();
  KernelDef expected;
  // ASSERT on the parse so a bad text proto fails loudly rather than
  // comparing against a default-constructed message.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "op: 'E' device_type: 'GPU' "
      "host_memory_arg: ['in', 'out']",
      &expected));
  EXPECT_EQ(def->DebugString(), expected.DebugString());
  delete def;  // Build() transfers ownership to the caller.
}
|
89 |
+
|
90 |
+
} // namespace
|
91 |
+
} // namespace tensorflow
|
load_library.cc
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include <memory>
|
17 |
+
#include <unordered_set>
|
18 |
+
|
19 |
+
#include "tensorflow/core/framework/op.h"
|
20 |
+
#include "tensorflow/core/framework/op_kernel.h"
|
21 |
+
#include "tensorflow/core/lib/core/errors.h"
|
22 |
+
#include "tensorflow/core/platform/env.h"
|
23 |
+
#include "tensorflow/core/platform/mem.h"
|
24 |
+
|
25 |
+
namespace tensorflow {
|
26 |
+
|
27 |
+
namespace {
|
28 |
+
|
29 |
+
struct Library {
|
30 |
+
void* handle = nullptr;
|
31 |
+
OpList op_list;
|
32 |
+
};
|
33 |
+
|
34 |
+
} // namespace
|
35 |
+
|
36 |
+
// Load a dynamic library.
|
37 |
+
// On success, returns the handle to library in result, copies the serialized
|
38 |
+
// OpList of OpDefs registered in the library to *buf and the length to *len,
|
39 |
+
// and returns OK from the function. Otherwise return nullptr in result
|
40 |
+
// and an error status from the function, leaving buf and len untouched.
|
41 |
+
//
|
42 |
+
// If `library_filename` has already been loaded, we return a cached handle
|
43 |
+
// and OpList. Ops and kernels are registered as globals when a library is
|
44 |
+
// loaded for the first time. Without caching, every subsequent load would not
|
45 |
+
// perform initialization again, so the OpList would be empty.
|
46 |
+
// Loads the dynamic library at `library_filename`, returning its handle in
// *result and a serialized OpList of the ops it registered in *buf/*len.
// Results are cached per filename: a library's op/kernel registrations run
// only on the first load, so later loads must reuse the recorded OpList.
Status LoadLibrary(const char* library_filename, void** result,
                   const void** buf, size_t* len) {
  // Process-wide cache, guarded by mu.  LINKER_INITIALIZED avoids a
  // dynamic-initialization race on the mutex itself.
  static mutex mu(LINKER_INITIALIZED);
  static std::unordered_map<string, Library> loaded_libs;
  Env* env = Env::Default();
  Library library;
  // Names of ops registered by *this* library, used to distinguish a
  // harmless re-registration of a foreign op from a real duplicate.
  std::unordered_set<string> seen_op_names;
  {
    mutex_lock lock(mu);
    if (loaded_libs.find(library_filename) != loaded_libs.end()) {
      // Fast path: reuse the cached handle and OpList.
      library = loaded_libs[library_filename];
    } else {
      // Flush pending registrations so the watcher below observes only
      // the ops registered by this library.
      Status s = OpRegistry::Global()->ProcessRegistrations();
      if (!s.ok()) {
        return s;
      }
      // Watcher: collect each op this library registers into
      // library.op_list, recording its name in seen_op_names.
      TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(
          [&library, &seen_op_names](const Status& s,
                                     const OpDef& opdef) -> Status {
            if (errors::IsAlreadyExists(s)) {
              if (seen_op_names.find(opdef.name()) == seen_op_names.end()) {
                // Over writing a registration of an op not in this custom op
                // library. Treat this as not an error.
                return Status::OK();
              }
            }
            if (s.ok()) {
              *library.op_list.add_op() = opdef;
              seen_op_names.insert(opdef.name());
            }
            return s;
          }));
      // Batch the registrations triggered by the load so they are all
      // processed (and seen by the watcher) together below.
      OpRegistry::Global()->DeferRegistrations();
      s = env->LoadLibrary(library_filename, &library.handle);
      if (s.ok()) {
        s = OpRegistry::Global()->ProcessRegistrations();
      }
      if (!s.ok()) {
        // On failure, drop the deferred registrations and detach the
        // watcher before propagating the error.
        OpRegistry::Global()->ClearDeferredRegistrations();
        TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr));
        return s;
      }
      TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr));

      loaded_libs[library_filename] = library;
    }
  }
  // Serialize the collected OpList into a buffer allocated with
  // port::Malloc; the caller takes ownership of *buf.
  string str;
  library.op_list.SerializeToString(&str);
  char* str_buf = reinterpret_cast<char*>(port::Malloc(str.length()));
  memcpy(str_buf, str.data(), str.length());
  *buf = str_buf;
  *len = str.length();

  *result = library.handle;
  return Status::OK();
}
|
103 |
+
|
104 |
+
} // namespace tensorflow
|
log_memory.cc
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/log_memory.h"
|
17 |
+
|
18 |
+
#include "tensorflow/core/framework/log_memory.pb_text.h"
|
19 |
+
#include "tensorflow/core/framework/log_memory.pb.h"
|
20 |
+
|
21 |
+
namespace tensorflow {
|
22 |
+
|
23 |
+
// Label prepended to every memory log line so they can be grepped out of
// the INFO log.
const string LogMemory::kLogMemoryLabel = "__LOG_MEMORY__";

// Memory logging is on whenever verbose logging level 1 is enabled.
bool LogMemory::IsEnabled() { return VLOG_IS_ON(1); }
|
26 |
+
|
27 |
+
namespace {
|
28 |
+
|
29 |
+
// Write the proto entry to LOG(INFO).
|
30 |
+
template <typename T>
|
31 |
+
void OutputToLog(const T& proto) {
|
32 |
+
string type_name = proto.GetTypeName();
|
33 |
+
const size_t index = type_name.find_last_of(".");
|
34 |
+
if (index != string::npos) type_name = type_name.substr(index + 1);
|
35 |
+
LOG(INFO) << LogMemory::kLogMemoryLabel << " " << type_name << " { "
|
36 |
+
<< ProtoShortDebugString(proto) << " }";
|
37 |
+
}
|
38 |
+
|
39 |
+
} // namespace
|
40 |
+
|
41 |
+
void LogMemory::RecordStep(const int64 step_id, const string& handle) {
|
42 |
+
MemoryLogStep step;
|
43 |
+
step.set_step_id(step_id);
|
44 |
+
step.set_handle(handle);
|
45 |
+
OutputToLog(step);
|
46 |
+
}
|
47 |
+
|
48 |
+
void LogMemory::RecordTensorAllocation(const string& kernel_name,
|
49 |
+
const int64 step_id,
|
50 |
+
const Tensor& tensor) {
|
51 |
+
MemoryLogTensorAllocation allocation;
|
52 |
+
allocation.set_step_id(step_id);
|
53 |
+
allocation.set_kernel_name(kernel_name);
|
54 |
+
tensor.FillDescription(allocation.mutable_tensor());
|
55 |
+
OutputToLog(allocation);
|
56 |
+
}
|
57 |
+
|
58 |
+
void LogMemory::RecordTensorDeallocation(const int64 allocation_id,
|
59 |
+
const string& allocator_name) {
|
60 |
+
MemoryLogTensorDeallocation deallocation;
|
61 |
+
deallocation.set_allocation_id(allocation_id);
|
62 |
+
deallocation.set_allocator_name(allocator_name);
|
63 |
+
OutputToLog(deallocation);
|
64 |
+
}
|
65 |
+
|
66 |
+
void LogMemory::RecordTensorOutput(const string& kernel_name,
|
67 |
+
const int64 step_id, const int index,
|
68 |
+
const Tensor& tensor) {
|
69 |
+
MemoryLogTensorOutput output;
|
70 |
+
output.set_step_id(step_id);
|
71 |
+
output.set_kernel_name(kernel_name);
|
72 |
+
output.set_index(index);
|
73 |
+
tensor.FillDescription(output.mutable_tensor());
|
74 |
+
OutputToLog(output);
|
75 |
+
}
|
76 |
+
|
77 |
+
void LogMemory::RecordRawAllocation(const string& operation,
|
78 |
+
const int64 step_id, size_t num_bytes,
|
79 |
+
void* ptr, Allocator* allocator) {
|
80 |
+
MemoryLogRawAllocation allocation;
|
81 |
+
allocation.set_step_id(step_id);
|
82 |
+
allocation.set_operation(operation);
|
83 |
+
allocation.set_num_bytes(static_cast<int64>(num_bytes));
|
84 |
+
allocation.set_ptr(reinterpret_cast<uintptr_t>(ptr));
|
85 |
+
allocation.set_allocation_id(allocator->AllocationId(ptr));
|
86 |
+
allocation.set_allocator_name(allocator->Name());
|
87 |
+
OutputToLog(allocation);
|
88 |
+
}
|
89 |
+
|
90 |
+
void LogMemory::RecordRawDeallocation(const string& operation,
|
91 |
+
const int64 step_id, void* ptr,
|
92 |
+
Allocator* allocator, bool deferred) {
|
93 |
+
MemoryLogRawDeallocation deallocation;
|
94 |
+
deallocation.set_step_id(step_id);
|
95 |
+
deallocation.set_operation(operation);
|
96 |
+
deallocation.set_allocation_id(allocator->AllocationId(ptr));
|
97 |
+
deallocation.set_allocator_name(allocator->Name());
|
98 |
+
deallocation.set_deferred(deferred);
|
99 |
+
OutputToLog(deallocation);
|
100 |
+
}
|
101 |
+
|
102 |
+
} // namespace tensorflow
|
log_memory.h
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
|
18 |
+
|
19 |
+
#include "tensorflow/core/framework/tensor.h"
|
20 |
+
#include "tensorflow/core/platform/protobuf.h"
|
21 |
+
|
22 |
+
namespace tensorflow {
|
23 |
+
|
24 |
+
// LogMemory contains methods for recording memory allocations and
|
25 |
+
// frees, associating each allocation with a step identified by a
|
26 |
+
// process-wide id. For now, logging is enabled whenever VLOG_IS_ON(1)
|
27 |
+
// for the log_memory module.
|
28 |
+
//
|
29 |
+
// Limitations: We don't log memory allocations by Eigen on the CPU
|
30 |
+
// since that would require major changes to plumb through to the
|
31 |
+
// Eigen::{DefaultDevice,ThreadPoolDevice} allocate and deallocate
|
32 |
+
// methods. We do log Eigen allocations on GPU since the plumbing was
|
33 |
+
// already in place.
|
34 |
+
class LogMemory {
 public:
  // Allocations sometimes happen outside any computation step, and
  // SpecialStepIds lists the ids used for those steps.
  enum SpecialStepIds {
    // Used when performing a just-in-time constant folding optimization.
    CONSTANT_FOLDING_STEP_ID = -1,
    // Used when constructing an Op kernel before executing a step.
    OP_KERNEL_CONSTRUCTION_STEP_ID = -2,
    // Used when allocating a tensor buffer from external code, e.g.,
    // the C API.
    EXTERNAL_TENSOR_ALLOCATION_STEP_ID = -3,
    // Used when allocating a buffer for network transfer.
    NETWORK_BUFFER_STEP_ID = -4,
    // Used when allocating a buffer to fill a Proto from the GPU.
    PROTO_BUFFER_STEP_ID = -5,
    // Used when allocating a Tensor where the caller has not indicated
    // the step.
    UNKNOWN_STEP_ID = -6,
  };

  // Label prepended to every memory log line, so the entries can be
  // filtered out of the INFO log.
  static const string kLogMemoryLabel;

  // Test to see if memory logging is enabled. For now, logging is
  // enabled whenever VLOG_IS_ON(1) for the log_memory module.
  static bool IsEnabled();

  // Log the beginning of a step.
  static void RecordStep(int64 step_id, const string& handle);

  // Log a tensor buffer allocation. The name indicates which kernel
  // made the allocation. If the allocation is made through an
  // OpKernelContext the step_id indicates which step is executing,
  // otherwise step_id is one of the SpecialStepIds defined in
  // op_kernel.h, e.g. Op Kernel construction or an optimization pass
  // such as constant folding.
  static void RecordTensorAllocation(const string& kernel_name, int64 step_id,
                                     const Tensor& tensor);

  // Log a tensor buffer deallocation. The deallocation is triggered
  // when the buffer's refcount falls to zero, and the tracking
  // mechanism does not associate it with a particular step or
  // kernel. The allocation_id/allocator_name should match a
  // corresponding tensor previously passed in to
  // RecordTensorAllocation.
  static void RecordTensorDeallocation(int64 allocation_id,
                                       const string& allocator_name);

  // Log the use of a tensor as an output from a kernel.
  static void RecordTensorOutput(const string& kernel_name, int64 step_id,
                                 int index, const Tensor& tensor);

  // Log a "raw" allocation, which is just a buffer sized in
  // bytes. The Eigen allocator, and memory copies, record their
  // allocations this way, since they do not allocate TensorFlow
  // tensors. The operation is set to the OpKernel name if this is
  // called from within an Op execution, otherwise it indicates an
  // operation such as memcpy. The step_id if >=0 indicates which step
  // is executing, otherwise step_id is one of the SpecialStepIds
  // defined in op_kernel.h, e.g. Op Kernel construction or an
  // optimization pass such as constant folding.
  static void RecordRawAllocation(const string& operation, int64 step_id,
                                  size_t num_bytes, void* ptr,
                                  Allocator* allocator);

  // Log a "raw" deallocation of a buffer. When deferred is true, the
  // buffer won't be used again, but a GPU kernel may still be
  // enqueued using the buffer. A deferred deallocation should always
  // be followed by a matching non-deferred deallocation when the
  // buffer is actually returned and can be reused.
  static void RecordRawDeallocation(const string& operation, int64 step_id,
                                    void* ptr, Allocator* allocator,
                                    bool deferred);
};
|
108 |
+
|
109 |
+
} // namespace tensorflow
|
110 |
+
|
111 |
+
#endif // TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
|
log_memory.proto
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
syntax = "proto3";

package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "LogMemoryProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "tensorflow/core/framework/tensor_description.proto";

// Marks the start of a step; subsequent allocation/deallocation events
// that carry the same step_id belong to this step.
message MemoryLogStep {
  // Process-unique step id.
  int64 step_id = 1;

  // Handle describing the feeds and fetches of the step.
  string handle = 2;
}

// Records the allocation of a tensor buffer by a kernel.
message MemoryLogTensorAllocation {
  // Process-unique step id.
  int64 step_id = 1;

  // Name of the kernel making the allocation as set in GraphDef,
  // e.g., "affine2/weights/Assign".
  string kernel_name = 2;

  // Allocated tensor details.
  TensorDescription tensor = 3;
}

// Records the deallocation of a tensor buffer; matched to an earlier
// MemoryLogTensorAllocation via allocation_id/allocator_name.
message MemoryLogTensorDeallocation {
  // Id of the tensor buffer being deallocated, used to match to a
  // corresponding allocation.
  int64 allocation_id = 1;

  // Name of the allocator used.
  string allocator_name = 2;
}

// Records a tensor being emitted as a kernel output.
message MemoryLogTensorOutput {
  // Process-unique step id.
  int64 step_id = 1;

  // Name of the kernel producing an output as set in GraphDef, e.g.,
  // "affine2/weights/Assign".
  string kernel_name = 2;

  // Index of the output being set.
  int32 index = 3;

  // Output tensor details.
  TensorDescription tensor = 4;
}

// Records a raw (non-tensor) buffer allocation, e.g. from the Eigen
// allocator or a memory copy.
message MemoryLogRawAllocation {
  // Process-unique step id.
  int64 step_id = 1;

  // Name of the operation making the allocation.
  string operation = 2;

  // Number of bytes in the allocation.
  int64 num_bytes = 3;

  // Address of the allocation.
  uint64 ptr = 4;

  // Id of the tensor buffer being allocated, used to match to a
  // corresponding deallocation.
  int64 allocation_id = 5;

  // Name of the allocator used.
  string allocator_name = 6;
}

// Records a raw buffer deallocation; matched to a MemoryLogRawAllocation
// via allocation_id/allocator_name.
message MemoryLogRawDeallocation {
  // Process-unique step id.
  int64 step_id = 1;

  // Name of the operation making the deallocation.
  string operation = 2;

  // Id of the tensor buffer being deallocated, used to match to a
  // corresponding allocation.
  int64 allocation_id = 3;

  // Name of the allocator used.
  string allocator_name = 4;

  // True if the deallocation is queued and will be performed later,
  // e.g. for GPU lazy freeing of buffers.
  bool deferred = 5;
}
lookup_interface.cc
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/lookup_interface.h"
|
17 |
+
|
18 |
+
#include "tensorflow/core/framework/tensor_shape.h"
|
19 |
+
#include "tensorflow/core/lib/core/errors.h"
|
20 |
+
|
21 |
+
namespace tensorflow {
namespace lookup {

// Returns OK iff `shape` ends with the table's key_shape(), i.e. key
// tensors may carry arbitrary extra leading (batch) dimensions in front
// of the per-key shape.
Status LookupInterface::CheckKeyShape(const TensorShape& shape) {
  if (!TensorShapeUtils::EndsWith(shape, key_shape())) {
    return errors::InvalidArgument("Input key shape ", shape.DebugString(),
                                   " must end with the table's key shape ",
                                   key_shape().DebugString());
  }
  return Status::OK();
}

// Verifies that the key/value tensor dtypes match the table's declared
// key_dtype()/value_dtype(). Shape is not checked here.
Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys,
                                              const Tensor& values) {
  if (keys.dtype() != key_dtype()) {
    return errors::InvalidArgument("Key must be type ", key_dtype(),
                                   " but got ", keys.dtype());
  }
  if (values.dtype() != value_dtype()) {
    return errors::InvalidArgument("Value must be type ", value_dtype(),
                                   " but got ", values.dtype());
  }
  return Status::OK();
}

// Shared validation for Insert/Import: dtypes must match the table, the
// keys must end with key_shape(), and the values shape must equal the
// keys' batch dimensions followed by value_shape().
Status LookupInterface::CheckKeyAndValueTensorsHelper(const Tensor& keys,
                                                      const Tensor& values) {
  TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values));
  TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape()));

  // Expected value shape = keys.shape() with the trailing key_shape()
  // dims stripped (leaving only the batch dims), then value_shape()
  // appended.
  TensorShape expected_value_shape = keys.shape();
  for (int i = 0; i < key_shape().dims(); ++i) {
    expected_value_shape.RemoveDim(expected_value_shape.dims() - 1);
  }
  expected_value_shape.AppendShape(value_shape());
  if (values.shape() != expected_value_shape) {
    return errors::InvalidArgument(
        "Expected shape ", expected_value_shape.DebugString(),
        " for value, got ", values.shape().DebugString());
  }
  return Status::OK();
}

// Insert and Import currently share the same key/value validation; kept
// as separate virtual entry points so subclasses can diverge.
Status LookupInterface::CheckKeyAndValueTensorsForInsert(const Tensor& keys,
                                                         const Tensor& values) {
  return CheckKeyAndValueTensorsHelper(keys, values);
}

Status LookupInterface::CheckKeyAndValueTensorsForImport(const Tensor& keys,
                                                         const Tensor& values) {
  return CheckKeyAndValueTensorsHelper(keys, values);
}

// Validates Find() arguments: key dtype/shape as above, and the
// default_value must have exactly the table's value_shape() (no batch
// dims — a single default is shared across all lookups).
Status LookupInterface::CheckFindArguments(const Tensor& key,
                                           const Tensor& default_value) {
  TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
  TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
  if (default_value.shape() != value_shape()) {
    return errors::InvalidArgument(
        "Expected shape ", value_shape().DebugString(),
        " for default value, got ", default_value.shape().DebugString());
  }
  return Status::OK();
}

}  // namespace lookup
}  // namespace tensorflow
lookup_interface.h
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#ifndef TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
|
17 |
+
#define TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
|
18 |
+
|
19 |
+
#include "tensorflow/core/framework/resource_mgr.h"
|
20 |
+
#include "tensorflow/core/framework/tensor.h"
|
21 |
+
#include "tensorflow/core/lib/core/status.h"
|
22 |
+
|
23 |
+
namespace tensorflow {

class OpKernelContext;

namespace lookup {

// Forward declaration so we can define GetInitializableLookupTable() in
// LookupInterface.
class InitializableLookupTable;

// Lookup interface for batch lookups used by table lookup ops.
// Concrete tables declare their key/value dtypes and shapes; the
// non-virtual Check* helpers validate caller-supplied tensors against
// those declarations.
class LookupInterface : public ResourceBase {
 public:
  // Performs batch lookups: for every element in the key tensor, Find returns
  // the corresponding value into the values tensor.
  // If an element is not present in the table, the given default value is used.
  //
  // For tables that require initialization, Find is available once the table
  // is marked as initialized.
  //
  // Returns the following statuses:
  // - OK: when the find finishes successfully.
  // - FailedPrecondition: if the table is not initialized.
  // - InvalidArgument: if any of the preconditions on the lookup key or value
  //   fails.
  // - In addition, other implementations may provide another non-OK status
  //   specific to their failure modes.
  virtual Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values,
                      const Tensor& default_value) = 0;

  // Inserts elements into the table. Each element of the key tensor is
  // associated with the corresponding element in the value tensor.
  // This method is only implemented in mutable tables that can be updated over
  // the execution of the graph. It returns Status::NotImplemented for read-only
  // tables that are initialized once before they can be looked up.
  //
  // Returns the following statuses:
  // - OK: when the insert finishes successfully.
  // - InvalidArgument: if any of the preconditions on the lookup key or value
  //   fails.
  // - Unimplemented: if the table does not support insertions.
  virtual Status Insert(OpKernelContext* ctx, const Tensor& keys,
                        const Tensor& values) = 0;

  // Returns the number of elements in the table.
  virtual size_t size() const = 0;

  // Exports the values of the table to two tensors named keys and values.
  // Note that the shape of the tensors is completely up to the implementation
  // of the table and can be different than the tensors used for the Insert
  // function above.
  virtual Status ExportValues(OpKernelContext* ctx) = 0;

  // Imports previously exported keys and values.
  // As mentioned above, the shape of the keys and values tensors are determined
  // by the ExportValues function above and can be different than for the
  // Insert function.
  virtual Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
                              const Tensor& values) = 0;

  // Returns the data type of the key.
  virtual DataType key_dtype() const = 0;

  // Returns the data type of the value.
  virtual DataType value_dtype() const = 0;

  // Returns the shape of a key in the table.
  virtual TensorShape key_shape() const = 0;

  // Returns the shape of a value in the table.
  virtual TensorShape value_shape() const = 0;

  // Check format of the key and value tensors for the Insert function.
  // Returns OK if all the following requirements are satisfied, otherwise it
  // returns InvalidArgument:
  // - DataType of the tensor keys equals to the table key_dtype
  // - DataType of the tensor values equals to the table value_dtype
  // - the values tensor has the required shape given keys and the tables's
  //   value shape.
  virtual Status CheckKeyAndValueTensorsForInsert(const Tensor& keys,
                                                  const Tensor& values);

  // Similar to the function above but instead checks eligibility for the
  // Import function.
  virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys,
                                                  const Tensor& values);

  // Check the arguments of a find operation. Returns OK if all the following
  // requirements are satisfied, otherwise it returns InvalidArgument:
  // - DataType of the tensor keys equals to the table key_dtype
  // - DataType of the tensor default_value equals to the table value_dtype
  // - the default_value tensor shape matches the table's value shape.
  Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);

  string DebugString() override {
    return strings::StrCat("A lookup table of size: ", size());
  }

  // Returns an InitializableLookupTable, a subclass of LookupInterface, if the
  // current object is an InitializableLookupTable. Otherwise, returns nullptr.
  virtual InitializableLookupTable* GetInitializableLookupTable() {
    return nullptr;
  }

 protected:
  // Refcounted via ResourceBase; destruction only through Unref.
  virtual ~LookupInterface() = default;

  // Makes sure that the key and value tensor DataType's match the table
  // key_dtype and value_dtype.
  Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values);

  // Makes sure that the provided shape is consistent with the table keys
  // shape (i.e. ends with it).
  Status CheckKeyShape(const TensorShape& shape);

 private:
  // Shared shape/dtype validation used by both the Insert and Import checks.
  Status CheckKeyAndValueTensorsHelper(const Tensor& keys,
                                       const Tensor& values);
};

}  // namespace lookup
}  // namespace tensorflow
144 |
+
|
145 |
+
#endif // TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
|
memory_types.cc
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
==============================================================================*/
|
15 |
+
|
16 |
+
#include "tensorflow/core/framework/memory_types.h"
|
17 |
+
|
18 |
+
#include <utility>
|
19 |
+
|
20 |
+
#include "tensorflow/core/framework/kernel_def.pb.h"
|
21 |
+
#include "tensorflow/core/framework/node_def.pb.h"
|
22 |
+
#include "tensorflow/core/framework/node_def_util.h"
|
23 |
+
#include "tensorflow/core/framework/op_kernel.h"
|
24 |
+
#include "tensorflow/core/framework/types.h"
|
25 |
+
#include "tensorflow/core/lib/core/errors.h"
|
26 |
+
#include "tensorflow/core/platform/types.h"
|
27 |
+
|
28 |
+
namespace tensorflow {

namespace {
// Returns the largest endpoint of anything in the name_map, i.e. the
// total number of input (or output) endpoints the ranges cover.
int GetTotal(const NameRangeMap& name_map) {
  int total = 0;
  for (const auto& item : name_map) {
    total = std::max(total, item.second.second);
  }
  return total;
}

// Fills memory_types for either input or output, setting everything
// to DEVICE_MEMORY except those args in host_memory_args. Removes
// elements of host_memory_args that were used (the survivors are kept,
// compacted in place, for a later pass over the other name map).
void MemoryTypesHelper(const NameRangeMap& name_map,
                       std::vector<string>* host_memory_args,
                       MemoryTypeVector* memory_types) {
  // Update args that have been marked as in "HOST_MEMORY".
  size_t keep = 0;
  for (size_t i = 0; i < host_memory_args->size(); ++i) {
    auto iter = name_map.find((*host_memory_args)[i]);
    if (iter != name_map.end()) {
      // Mark every endpoint in this arg's [first, second) range.
      for (int j = iter->second.first; j < iter->second.second; ++j) {
        (*memory_types)[j] = HOST_MEMORY;
      }
    } else {
      // (*host_memory_args)[i] not found, save it for the next pass.
      if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i];
      ++keep;
    }
  }
  host_memory_args->resize(keep);
}

// Best-effort memory type from dtype alone: int32 and always-on-host
// dtypes (e.g. resource/string) live in host memory.
MemoryType MTypeFromDType(const DataType dtype) {
  return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY
                                                            : DEVICE_MEMORY;
}

}  // namespace

// Computes the per-input and per-output memory types (host vs device)
// for `ndef` on `device_type`, combining the kernel's HostMemory
// declarations, dtype-based rules, and _input_hostmem/_output_hostmem
// node attrs.
Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
                          const DeviceType& device_type, const NodeDef& ndef,
                          MemoryTypeVector* inp_mtypes,
                          MemoryTypeVector* out_mtypes) {
  // Look up the Op registered for this op name.
  const OpDef* op_def;
  TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(ndef.op(), &op_def));

  // Look up the Kernel registered for this node def.
  const KernelDef* kdef = nullptr;
  Status status =
      FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */);

  DataTypeVector inp_dtypes;
  DataTypeVector out_dtypes;
  TF_RETURN_IF_ERROR(
      InOutTypesForNode(ndef, *op_def, &inp_dtypes, &out_dtypes));

  inp_mtypes->clear();
  out_mtypes->clear();

  // For functions (which have no KernelDef) and their gradients, we can only
  // best-effort derive the memory type from the data type. For now, we assume
  // int32 is always on host memory and other types are always on device memory.
  // TODO(zhifengc,phawkins): We should do type inference over function bodies
  // to derive the correct input/output memory types. We should also split
  // host-memory and non host-memory arguments into separate type lists.
  if (!status.ok() || ndef.op() == "SymbolicGradient") {
    for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t));
    for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t));
    return Status::OK();
  }
  // Past this point FindKernelDef succeeded, so kdef is non-null.

  // Gets the input/output names and their corresponding endpoint ranges.
  NameRangeMap inp_names;
  NameRangeMap out_names;
  TF_RETURN_IF_ERROR(NameRangesForNode(ndef, *op_def, &inp_names, &out_names));

  // Now that we know the size, fill with the default 'DEVICE_MEMORY'.
  inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY);
  out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY);

  // Fills in host memory types based on the kernel def. The same
  // host_memory_args list is threaded through both passes; any names
  // left over matched neither inputs nor outputs.
  const auto& from_proto = kdef->host_memory_arg();
  std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
  MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes);
  MemoryTypesHelper(out_names, &host_memory_args, out_mtypes);
  if (!host_memory_args.empty()) {
    return errors::InvalidArgument(
        "HostMemory args '", str_util::Join(host_memory_args, "', '"),
        "' not found in OpDef: ", SummarizeOpDef(*op_def));
  }
  CHECK_LE(inp_mtypes->size(), inp_dtypes.size());
  CHECK_LE(out_mtypes->size(), out_dtypes.size());

  // Mark e.g. all resource and string types as host memory, regardless
  // of what the kernel def declared.
  for (int i = 0; i < inp_mtypes->size(); ++i) {
    if (DataTypeAlwaysOnHost(inp_dtypes[i])) {
      (*inp_mtypes)[i] = HOST_MEMORY;
    }
  }
  for (int i = 0; i < out_mtypes->size(); ++i) {
    if (DataTypeAlwaysOnHost(out_dtypes[i])) {
      (*out_mtypes)[i] = HOST_MEMORY;
    }
  }

  // Per-node overrides: indices listed in these attrs are forced to
  // host memory; out-of-range indices are silently ignored.
  std::vector<int32> hostmem_attr;
  if (GetNodeAttr(ndef, "_input_hostmem", &hostmem_attr).ok()) {
    for (int32 i : hostmem_attr) {
      if (0 <= i && i < inp_mtypes->size()) {
        (*inp_mtypes)[i] = HOST_MEMORY;
      }
    }
  }
  if (GetNodeAttr(ndef, "_output_hostmem", &hostmem_attr).ok()) {
    for (int32 i : hostmem_attr) {
      if (0 <= i && i < out_mtypes->size()) {
        (*out_mtypes)[i] = HOST_MEMORY;
      }
    }
  }

  return Status::OK();
}

}  // namespace tensorflow