sssdtgvg committed on
Commit 5178306
Parent: 8e5dd95

Upload 161 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full set.
allocation_description.proto ADDED
@@ -0,0 +1,27 @@
+ syntax = "proto3";
+
+ package tensorflow;
+ option cc_enable_arenas = true;
+ option java_outer_classname = "AllocationDescriptionProtos";
+ option java_multiple_files = true;
+ option java_package = "org.tensorflow.framework";
+
+ message AllocationDescription {
+   // Total number of bytes requested
+   int64 requested_bytes = 1;
+
+   // Total number of bytes allocated if known
+   int64 allocated_bytes = 2;
+
+   // Name of the allocator used
+   string allocator_name = 3;
+
+   // Identifier of the allocated buffer if known
+   int64 allocation_id = 4;
+
+   // Set if this tensor only has one remaining reference
+   bool has_single_reference = 5;
+
+   // Address of the allocation.
+   uint64 ptr = 6;
+ };
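
For orientation (not part of the commit), here is a minimal sketch of how the proto3-generated C++ setters for this message might be used by an allocator reporting a tracked allocation. The function and the concrete values are hypothetical; the setter names are the standard ones protoc generates for these fields.

#include "tensorflow/core/framework/allocation_description.pb.h"

void FillDescription(tensorflow::AllocationDescription* proto) {
  proto->set_requested_bytes(1024);       // bytes the caller asked for
  proto->set_allocated_bytes(1280);       // actual size, if the allocator knows it
  proto->set_allocator_name("cpu");
  proto->set_allocation_id(42);           // nonzero only if the allocator tracks IDs
  proto->set_has_single_reference(false);
  proto->set_ptr(0);                      // address, when meaningful to report
}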
allocator.cc ADDED
@@ -0,0 +1,130 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/allocator.h"
+
+ #include "tensorflow/core/framework/allocator_registry.h"
+ #include "tensorflow/core/framework/log_memory.h"
+ #include "tensorflow/core/framework/tracking_allocator.h"
+ #include "tensorflow/core/lib/strings/stringprintf.h"
+ #include "tensorflow/core/platform/mem.h"
+ #include "tensorflow/core/platform/mutex.h"
+ #include "tensorflow/core/platform/types.h"
+
+ namespace tensorflow {
+
+ void AllocatorStats::Clear() {
+   this->num_allocs = 0;
+   this->bytes_in_use = 0;
+   this->max_bytes_in_use = 0;
+   this->max_alloc_size = 0;
+   this->bytes_limit = 0;
+ }
+
+ string AllocatorStats::DebugString() const {
+   return strings::Printf(
+       "Limit:        %20lld\n"
+       "InUse:        %20lld\n"
+       "MaxInUse:     %20lld\n"
+       "NumAllocs:    %20lld\n"
+       "MaxAllocSize: %20lld\n",
+       this->bytes_limit, this->bytes_in_use, this->max_bytes_in_use,
+       this->num_allocs, this->max_alloc_size);
+ }
+
+ constexpr size_t Allocator::kAllocatorAlignment;
+
+ Allocator::~Allocator() {}
+
+ void RunResourceCtor(ResourceHandle* p, size_t n) {
+   for (size_t i = 0; i < n; ++p, ++i) new (p) ResourceHandle();
+ }
+
+ void RunResourceDtor(ResourceHandle* p, size_t n) {
+   for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
+ }
+
+ // If true, cpu allocator collects more stats.
+ static bool cpu_allocator_collect_stats = false;
+ // If true, cpu allocator collects full stats.
+ static bool cpu_allocator_collect_full_stats = false;
+
+ void EnableCPUAllocatorStats(bool enable) {
+   cpu_allocator_collect_stats = enable;
+ }
+ void EnableCPUAllocatorFullStats(bool enable) {
+   cpu_allocator_collect_full_stats = enable;
+ }
+
+ class CPUAllocator : public Allocator {
+  public:
+   CPUAllocator() {}
+
+   ~CPUAllocator() override {}
+
+   string Name() override { return "cpu"; }
+
+   void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+     void* p = port::AlignedMalloc(num_bytes, alignment);
+     if (cpu_allocator_collect_stats) {
+       const std::size_t alloc_size = port::MallocExtension_GetAllocatedSize(p);
+       mutex_lock l(mu_);
+       ++stats_.num_allocs;
+       stats_.bytes_in_use += alloc_size;
+       stats_.max_bytes_in_use =
+           std::max<int64>(stats_.max_bytes_in_use, stats_.bytes_in_use);
+       stats_.max_alloc_size =
+           std::max<int64>(stats_.max_alloc_size, alloc_size);
+     }
+     return p;
+   }
+
+   void DeallocateRaw(void* ptr) override {
+     if (cpu_allocator_collect_stats) {
+       const std::size_t alloc_size =
+           port::MallocExtension_GetAllocatedSize(ptr);
+       mutex_lock l(mu_);
+       stats_.bytes_in_use -= alloc_size;
+     }
+     port::AlignedFree(ptr);
+   }
+
+   void GetStats(AllocatorStats* stats) override {
+     mutex_lock l(mu_);
+     *stats = stats_;
+   }
+
+   size_t AllocatedSizeSlow(void* ptr) override {
+     return port::MallocExtension_GetAllocatedSize(ptr);
+   }
+
+  private:
+   mutex mu_;
+   AllocatorStats stats_ GUARDED_BY(mu_);
+
+   TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator);
+ };
+
+ Allocator* cpu_allocator() {
+   static Allocator* cpu_alloc = AllocatorRegistry::Global()->GetAllocator();
+   if (cpu_allocator_collect_full_stats && !cpu_alloc->TracksAllocationSizes()) {
+     cpu_alloc = new TrackingAllocator(cpu_alloc, true);
+   }
+   return cpu_alloc;
+ }
+
+ REGISTER_MEM_ALLOCATOR("DefaultCPUAllocator", 100, CPUAllocator);
+
+ }  // namespace tensorflow
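
A minimal usage sketch of the APIs defined above (a stand-alone snippet, not part of the commit): enable stat collection, allocate through the process-wide CPU allocator, and dump the counters. Every call used here is declared in this file or in allocator.h.

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/logging.h"

void DumpCpuAllocatorStats() {
  tensorflow::EnableCPUAllocatorStats(true);
  tensorflow::Allocator* a = tensorflow::cpu_allocator();
  void* p = a->AllocateRaw(tensorflow::Allocator::kAllocatorAlignment, 4096);
  tensorflow::AllocatorStats stats;
  a->GetStats(&stats);               // snapshot taken under the allocator's mutex
  LOG(INFO) << stats.DebugString();  // Limit/InUse/MaxInUse/NumAllocs/MaxAllocSize
  a->DeallocateRaw(p);
  tensorflow::EnableCPUAllocatorStats(false);
}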
allocator.h ADDED
@@ -0,0 +1,394 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
+ #define TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
+
+ #include <stdlib.h>
+
+ #include <limits>
+
+ #include "tensorflow/core/framework/numeric_types.h"
+ #include "tensorflow/core/framework/resource_handle.h"
+ #include "tensorflow/core/framework/type_traits.h"
+ #include "tensorflow/core/framework/variant.h"
+ #include "tensorflow/core/platform/logging.h"
+ #include "tensorflow/core/platform/types.h"
+
+ namespace tensorflow {
+
+ // Attributes for a single allocation call. Different calls to the same
+ // allocator could potentially have different allocation attributes.
+ struct AllocationAttributes {
+   // If the first attempt to allocate the memory fails, the allocation
+   // should return immediately without retrying. An example use case is
+   // optional scratch spaces where a failure has only a performance impact.
+   bool no_retry_on_failure = false;
+   // If a Tensor is allocated without the following set to true, then
+   // it is logged as an unknown allocation. During execution Tensors
+   // should be allocated through the OpKernelContext, which records
+   // which Op is performing the allocation and sets this flag to true.
+   bool allocation_will_be_logged = false;
+ };
+
+ // Runtime statistics collected by an allocator.
+ struct AllocatorStats {
+   int64 num_allocs;        // Number of allocations.
+   int64 bytes_in_use;      // Number of bytes in use.
+   int64 max_bytes_in_use;  // The maximum bytes in use.
+   int64 max_alloc_size;    // The max single allocation seen.
+
+   // The upper limit of what the allocator can allocate, if such a limit
+   // is known. Certain allocators may return 0 to indicate the limit is
+   // unknown.
+   int64 bytes_limit;
+
+   AllocatorStats() { Clear(); }
+
+   void Clear();
+   string DebugString() const;
+ };
+
+ // Allocator is an abstract interface for allocating and deallocating
+ // device memory.
+ class Allocator {
+  public:
+ #ifdef EIGEN_VECTORIZE_AVX512
+   // Align to a 64-byte boundary.
+   static constexpr size_t kAllocatorAlignment = 64;
+ #else
+   // Align to a 32-byte boundary.
+   static constexpr size_t kAllocatorAlignment = 32;
+ #endif
+
+   virtual ~Allocator();
+
+   // Returns a string identifying this allocator.
+   virtual string Name() = 0;
+
+   // Returns an uninitialized block of memory that is "num_bytes" bytes
+   // in size. The returned pointer is guaranteed to be aligned to a
+   // multiple of "alignment" bytes.
+   // REQUIRES: "alignment" is a power of 2.
+   virtual void* AllocateRaw(size_t alignment, size_t num_bytes) = 0;
+
+   // Returns an uninitialized block of memory that is "num_bytes" bytes
+   // in size with the specified allocation attributes. The returned pointer
+   // is guaranteed to be aligned to a multiple of "alignment" bytes.
+   // REQUIRES: "alignment" is a power of 2.
+   virtual void* AllocateRaw(size_t alignment, size_t num_bytes,
+                             const AllocationAttributes& allocation_attr) {
+     // The default behavior is to use the implementation without any
+     // allocation attributes.
+     return AllocateRaw(alignment, num_bytes);
+   }
+
+   // Deallocates a block of memory pointed to by "ptr".
+   // REQUIRES: "ptr" was previously returned by a call to AllocateRaw.
+   virtual void DeallocateRaw(void* ptr) = 0;
+
+   // Convenience functions to do typed allocation. C++ constructors
+   // and destructors are invoked for complex types if necessary,
+   // depending on the concrete Allocator implementation. May return
+   // NULL if the tensor has too many elements to represent in a single
+   // allocation.
+   template <typename T>
+   T* Allocate(size_t num_elements) {
+     return Allocate<T>(num_elements, AllocationAttributes());
+   }
+
+   template <typename T>
+   T* Allocate(size_t num_elements,
+               const AllocationAttributes& allocation_attr) {
+     // TODO(jeff): Do we need to allow clients to pass in alignment
+     // requirements?
+
+     if (num_elements > (std::numeric_limits<size_t>::max() / sizeof(T))) {
+       return NULL;
+     }
+
+     void* p = AllocateRaw(kAllocatorAlignment, sizeof(T) * num_elements,
+                           allocation_attr);
+     T* typed_p = reinterpret_cast<T*>(p);
+     if (typed_p) RunCtor<T>(typed_p, num_elements);
+     return typed_p;
+   }
+
+   template <typename T>
+   void Deallocate(T* ptr, size_t num_elements) {
+     if (ptr) {
+       RunDtor<T>(ptr, num_elements);
+       DeallocateRaw(ptr);
+     }
+   }
+
+   // Returns true if this allocator tracks the sizes of allocations.
+   // RequestedSize and AllocatedSize must be overridden if
+   // TracksAllocationSizes is overridden to return true.
+   virtual bool TracksAllocationSizes() { return false; }
+
+   // Returns true if this allocator requires tensors with 0 elements
+   // to allocate buffers. This is false for most allocators, but may
+   // be used by special-case allocators that want to track tensor
+   // usage.
+   virtual bool ShouldAllocateEmptyTensors() { return false; }
+
+   // Returns the user-requested size of the data allocated at
+   // 'ptr'. Note that the actual buffer allocated might be larger
+   // than requested, but this function returns the size requested by
+   // the user.
+   //
+   // REQUIRES: TracksAllocationSizes() is true.
+   //
+   // REQUIRES: 'ptr!=nullptr' and points to a buffer previously
+   // allocated by this allocator.
+   virtual size_t RequestedSize(void* ptr) {
+     CHECK(false) << "allocator doesn't track sizes";
+     return size_t(0);
+   }
+
+   // Returns the allocated size of the buffer at 'ptr' if known,
+   // otherwise returns RequestedSize(ptr). AllocatedSize(ptr) is
+   // guaranteed to be >= RequestedSize(ptr).
+   //
+   // REQUIRES: TracksAllocationSizes() is true.
+   //
+   // REQUIRES: 'ptr!=nullptr' and points to a buffer previously
+   // allocated by this allocator.
+   virtual size_t AllocatedSize(void* ptr) { return RequestedSize(ptr); }
+
+   // Returns either 0 or an identifier assigned to the buffer at 'ptr'
+   // when the buffer was returned by AllocateRaw. If non-zero, the
+   // identifier differs from every other ID assigned by this
+   // allocator.
+   //
+   // REQUIRES: TracksAllocationSizes() is true.
+   //
+   // REQUIRES: 'ptr!=nullptr' and points to a buffer previously
+   // allocated by this allocator.
+   virtual int64 AllocationId(void* ptr) { return 0; }
+
+   // Returns the allocated size of the buffer at 'ptr' if known,
+   // otherwise returns 0. This method can be called when
+   // TracksAllocationSizes() is false, but can be extremely slow.
+   //
+   // REQUIRES: 'ptr!=nullptr' and points to a buffer previously
+   // allocated by this allocator.
+   virtual size_t AllocatedSizeSlow(void* ptr) {
+     if (TracksAllocationSizes()) {
+       return AllocatedSize(ptr);
+     }
+     return 0;
+   }
+
+   // Fills in 'stats' with statistics collected by this allocator.
+   virtual void GetStats(AllocatorStats* stats) { stats->Clear(); }
+
+  private:
+   // No constructors or destructors are run for simple types.
+   template <typename T>
+   void RunCtor(T* p, size_t n) {
+     static_assert(is_simple_type<T>::value, "T is not a simple type.");
+   }
+
+   template <typename T>
+   void RunDtor(T* p, size_t n) {}
+
+   // Custom constructors and destructors that can be overridden for
+   // non-standard allocators.
+
+   // Runs string's default constructor for p[0], p[1], ..., p[n-1].
+   virtual void RunStringCtor(string* p, size_t n) {
+     for (size_t i = 0; i < n; ++p, ++i) new (p) string();
+   }
+
+   // Runs string's default destructor for p[0], p[1], ..., p[n-1].
+   virtual void RunStringDtor(string* p, size_t n) {
+     for (size_t i = 0; i < n; ++p, ++i) p->~string();
+   }
+
+   // Runs ResourceHandle's default constructor for p[0], p[1], ..., p[n-1].
+   virtual void RunResourceCtor(ResourceHandle* p, size_t n) {
+     for (size_t i = 0; i < n; ++p, ++i) new (p) ResourceHandle();
+   }
+
+   // Runs ResourceHandle's default destructor for p[0], p[1], ..., p[n-1].
+   virtual void RunResourceDtor(ResourceHandle* p, size_t n) {
+     for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
+   }
+
+   virtual void RunVariantCtor(Variant* p, size_t n) {
+     for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
+   }
+
+   virtual void RunVariantDtor(Variant* p, size_t n) {
+     for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
+   }
+
+   // TODO(jeff): Maybe provide some interface to give info about
+   // current allocation state (total number of bytes available for
+   // allocation, number of bytes free on device, etc.)
+ };
+
+ // Allocator-specific constructors and destructors are used for
+ // strings.
+ template <>
+ inline void Allocator::RunCtor(string* p, size_t n) {
+   RunStringCtor(p, n);
+ }
+
+ template <>
+ inline void Allocator::RunDtor(string* p, size_t n) {
+   RunStringDtor(p, n);
+ }
+
+ template <>
+ inline void Allocator::RunCtor(ResourceHandle* p, size_t n) {
+   RunResourceCtor(p, n);
+ }
+
+ template <>
+ inline void Allocator::RunDtor(ResourceHandle* p, size_t n) {
+   RunResourceDtor(p, n);
+ }
+
+ template <>
+ inline void Allocator::RunCtor(Variant* p, size_t n) {
+   RunVariantCtor(p, n);
+ }
+
+ template <>
+ inline void Allocator::RunDtor(Variant* p, size_t n) {
+   RunVariantDtor(p, n);
+ }
+
+ // An implementation of Allocator that delegates all calls to another
+ // Allocator.
+ //
+ // Useful to clients who want to override part of the functionality of
+ // another allocator.
+ class AllocatorWrapper : public Allocator {
+  public:
+   explicit AllocatorWrapper(Allocator* wrapped) : wrapped_(wrapped) {}
+
+   ~AllocatorWrapper() override {}
+
+   // Returns the wrapped allocator to which all calls are delegated.
+   Allocator* wrapped() const { return wrapped_; }
+
+   string Name() override { return wrapped_->Name(); }
+
+   void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+     return wrapped_->AllocateRaw(alignment, num_bytes);
+   }
+
+   void* AllocateRaw(size_t alignment, size_t num_bytes,
+                     const AllocationAttributes& allocation_attr) override {
+     return wrapped_->AllocateRaw(alignment, num_bytes, allocation_attr);
+   }
+
+   void DeallocateRaw(void* ptr) override { wrapped_->DeallocateRaw(ptr); }
+
+   bool TracksAllocationSizes() override {
+     return wrapped_->TracksAllocationSizes();
+   }
+
+   bool ShouldAllocateEmptyTensors() override {
+     return wrapped_->ShouldAllocateEmptyTensors();
+   }
+
+   size_t RequestedSize(void* ptr) override {
+     return wrapped_->RequestedSize(ptr);
+   }
+
+   size_t AllocatedSize(void* ptr) override {
+     return wrapped_->AllocatedSize(ptr);
+   }
+
+   int64 AllocationId(void* ptr) override { return wrapped_->AllocationId(ptr); }
+
+   size_t AllocatedSizeSlow(void* ptr) override {
+     return wrapped_->AllocatedSizeSlow(ptr);
+   }
+
+  private:
+   Allocator* const wrapped_;
+ };
+
+ // A tensorflow Op may need access to different kinds of memory that
+ // are not simply a function of the device to which the Op has been
+ // assigned. For example, an Op executing on a GPU may still need
+ // to allocate CPU RAM for some purpose. Internal to the tensorflow
+ // runtime we may choose to allocate CPU RAM from special regions
+ // that have been prepared for higher performance in some use
+ // contexts, e.g. doing DMA with particular devices. For these
+ // reasons, the Device interface does not expose just one memory
+ // Allocator, but instead provides an accessor that takes a
+ // specification of the desired memory attributes in order to select
+ // an Allocator.
+ //
+ // Example use:
+ //   // Allocator for ordinary device memory:
+ //   Allocator* a = allocator(AllocatorAttributes());
+ //   ...
+ //   // Allocator for CPU RAM, regardless of where Op is executing:
+ //   AllocatorAttributes attr;
+ //   attr.set_on_host(true);
+ //   Allocator* a = allocator(attr);
+ struct AllocatorAttributes {
+   void set_on_host(bool v) { value |= (static_cast<int>(v)); }
+   bool on_host() const { return value & 0x1; }
+   void set_nic_compatible(bool v) { value |= (static_cast<int>(v) << 1); }
+   bool nic_compatible() const { return value & (0x1 << 1); }
+   void set_gpu_compatible(bool v) { value |= (static_cast<int>(v) << 2); }
+   bool gpu_compatible() const { return value & (0x1 << 2); }
+   void Merge(AllocatorAttributes other) { value |= other.value; }
+   // Returns true if the fields set in *this are a subset of, or equal to,
+   // those set in other.
+   bool IsEqualOrLessRestrictiveThan(const AllocatorAttributes& other) const {
+     return (value | other.value) == other.value;
+   }
+
+   // NOTE: The upper 8 bits of the value are reserved for
+   // device-specific uses. Implementors of a device can interpret these
+   // upper 8 bits in device-specific ways, and ops implemented for those
+   // devices are responsible for setting those 8 bits appropriately.
+   uint32 value = 0;
+ };
+
+ // Returns a trivial implementation of Allocator, which uses the system
+ // default malloc. The returned allocator is a process singleton.
+ Allocator* cpu_allocator();
+
+ // If 'enable' is true, the process-wide cpu allocator collects
+ // AllocatorStats. By default, it's disabled.
+ void EnableCPUAllocatorStats(bool enable);
+
+ // If 'enable' is true, the process-wide cpu allocator collects full
+ // statistics. By default, it's disabled.
+ void EnableCPUAllocatorFullStats(bool enable);
+
+ // Abstract interface of an object that does the underlying suballoc/free of
+ // memory for a higher-level allocator.
+ class SubAllocator {
+  public:
+   virtual ~SubAllocator() {}
+   virtual void* Alloc(size_t alignment, size_t num_bytes) = 0;
+   virtual void Free(void* ptr, size_t num_bytes) = 0;
+ };
+
+ }  // namespace tensorflow
+
+ #endif  // TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
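
The typed Allocate/Deallocate pair above is what gives element types like string their constructor and destructor handling. A short stand-alone sketch of that contract, using only the interface declared in this header (the function is hypothetical):

#include "tensorflow/core/framework/allocator.h"

void TypedAllocationDemo(tensorflow::Allocator* a) {
  // Simple types: no constructors run; Allocate performs an
  // overflow-checked multiply before calling AllocateRaw.
  float* f = a->Allocate<float>(1024);
  a->Deallocate(f, 1024);

  // string elements: the RunCtor/RunDtor specializations placement-new
  // and then destroy each of the 16 strings around the raw allocation.
  tensorflow::string* s = a->Allocate<tensorflow::string>(16);
  a->Deallocate(s, 16);
}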
allocator_registry.cc ADDED
@@ -0,0 +1,80 @@
+ /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include <string>
+
+ #include "tensorflow/core/framework/allocator_registry.h"
+ #include "tensorflow/core/platform/logging.h"
+
+ namespace tensorflow {
+
+ // static
+ AllocatorRegistry* AllocatorRegistry::Global() {
+   static AllocatorRegistry* global_allocator_registry = new AllocatorRegistry;
+   return global_allocator_registry;
+ }
+
+ Allocator* AllocatorRegistry::GetRegisteredAllocator(const string& name,
+                                                      int priority) {
+   for (auto entry : allocators_) {
+     if (!name.compare(entry.name) && priority == entry.priority) {
+       return entry.allocator;
+     }
+   }
+   return nullptr;
+ }
+
+ void AllocatorRegistry::Register(const string& name, int priority,
+                                  Allocator* allocator) {
+   CHECK(!name.empty()) << "Need a valid name for Allocator";
+   CHECK_GE(priority, 0) << "Priority needs to be non-negative";
+
+   Allocator* existing = GetRegisteredAllocator(name, priority);
+   if (existing != nullptr) {
+     // A registration is a duplicate if the name and priority match; a
+     // duplicate must also have a matching Allocator::Name().
+     CHECK_EQ(existing->Name(), allocator->Name())
+         << "Allocator with name: [" << name << "], type [" << existing->Name()
+         << "], priority: [" << priority
+         << "] already registered. Choose a different name to register "
+         << "an allocator of type " << allocator->Name();
+
+     // The allocator names match, so we can just return.
+     // It should be safe to delete the allocator since the caller
+     // gives up ownership of it.
+     delete allocator;
+     return;
+   }
+
+   AllocatorRegistryEntry tmp_entry;
+   tmp_entry.name = name;
+   tmp_entry.priority = priority;
+   tmp_entry.allocator = allocator;
+
+   allocators_.push_back(tmp_entry);
+   int high_pri = -1;
+   for (auto entry : allocators_) {
+     if (high_pri < entry.priority) {
+       m_curr_allocator_ = entry.allocator;
+       high_pri = entry.priority;
+     }
+   }
+ }
+
+ Allocator* AllocatorRegistry::GetAllocator() {
+   return CHECK_NOTNULL(m_curr_allocator_);
+ }
+
+ }  // namespace tensorflow
allocator_registry.h ADDED
@@ -0,0 +1,80 @@
+ /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ // Classes to maintain a static registry of memory allocators.
+ #ifndef TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_
+ #define TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_
+
+ #include <string>
+ #include <vector>
+
+ #include "tensorflow/core/framework/allocator.h"
+
+ namespace tensorflow {
+
+ // A global AllocatorRegistry is used to hold allocators for CPU backends.
+ class AllocatorRegistry {
+  public:
+   // Adds an allocator to the registry. Caller releases ownership of
+   // 'allocator'.
+   void Register(const string& name, int priority, Allocator* allocator);
+
+   // Returns the allocator with the highest priority. If multiple allocators
+   // share the same highest priority, returns one of them.
+   Allocator* GetAllocator();
+
+   // Returns the global registry of allocators.
+   static AllocatorRegistry* Global();
+
+  private:
+   typedef struct {
+     string name;
+     int priority;
+     Allocator* allocator;  // not owned
+   } AllocatorRegistryEntry;
+
+   // Returns the Allocator registered for 'name' and 'priority',
+   // or 'nullptr' if not found.
+   Allocator* GetRegisteredAllocator(const string& name, int priority);
+
+   std::vector<AllocatorRegistryEntry> allocators_;
+   Allocator* m_curr_allocator_;  // not owned
+ };
+
+ namespace allocator_registration {
+
+ class AllocatorRegistration {
+  public:
+   AllocatorRegistration(const string& name, int priority,
+                         Allocator* allocator) {
+     AllocatorRegistry::Global()->Register(name, priority, allocator);
+   }
+ };
+
+ }  // namespace allocator_registration
+
+ #define REGISTER_MEM_ALLOCATOR(name, priority, allocator) \
+   REGISTER_MEM_ALLOCATOR_UNIQ_HELPER(__COUNTER__, name, priority, allocator)
+
+ #define REGISTER_MEM_ALLOCATOR_UNIQ_HELPER(ctr, name, priority, allocator) \
+   REGISTER_MEM_ALLOCATOR_UNIQ(ctr, name, priority, allocator)
+
+ #define REGISTER_MEM_ALLOCATOR_UNIQ(ctr, name, priority, allocator) \
+   static allocator_registration::AllocatorRegistration              \
+       register_allocator_##ctr(name, priority, new allocator)
+
+ }  // namespace tensorflow
+
+ #endif  // TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_
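
As a concrete illustration of the registration macro (the allocator name, class, and priority here are hypothetical; the real default registration appears in allocator.cc above), registering a custom CPU allocator at a priority above 100 would make GetAllocator() return it instead of the default:

// my_allocator.cc -- hypothetical translation unit, a minimal sketch
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {

class MyPoolAllocator : public Allocator {
 public:
  string Name() override { return "my_pool"; }
  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
    return port::AlignedMalloc(num_bytes, alignment);
  }
  void DeallocateRaw(void* ptr) override { port::AlignedFree(ptr); }
};

// Priority 200 outranks the default CPUAllocator registered at 100.
REGISTER_MEM_ALLOCATOR("MyPoolAllocator", 200, MyPoolAllocator);

}  // namespace tensorflow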
allocator_test.cc ADDED
@@ -0,0 +1,186 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/allocator.h"
+
+ #include <algorithm>
+ #include <vector>
+
+ #include "tensorflow/core/platform/logging.h"
+ #include "tensorflow/core/platform/test.h"
+ #include "tensorflow/core/platform/test_benchmark.h"
+
+ namespace tensorflow {
+
+ static void CheckStats(Allocator* a, int64 num_allocs, int64 bytes_in_use,
+                        int64 max_bytes_in_use, int64 max_alloc_size) {
+   AllocatorStats stats;
+   a->GetStats(&stats);
+   LOG(INFO) << "Alloc stats: \n" << stats.DebugString();
+ #if defined(PLATFORM_GOOGLE) && defined(NDEBUG)
+   // NOTE: allocator stats expectations depend on the system malloc,
+   // and can vary as that changes.
+   static const int64 kSlop = 5 * 1024;
+   EXPECT_GT(stats.bytes_in_use, bytes_in_use - kSlop);
+   EXPECT_LT(stats.bytes_in_use, bytes_in_use + kSlop);
+   EXPECT_GT(stats.max_bytes_in_use, max_bytes_in_use - kSlop);
+   EXPECT_LT(stats.max_bytes_in_use, max_bytes_in_use + kSlop);
+   EXPECT_EQ(stats.num_allocs, num_allocs);
+   EXPECT_EQ(stats.max_alloc_size, max_alloc_size);
+ #endif
+ }
+
+ TEST(AllocatorAttributesTest, AllCombos) {
+   for (bool on_host : {false, true}) {
+     for (bool nic_compatible : {false, true}) {
+       for (bool gpu_compatible : {false, true}) {
+         AllocatorAttributes aa;
+         aa.set_on_host(on_host);
+         aa.set_nic_compatible(nic_compatible);
+         aa.set_gpu_compatible(gpu_compatible);
+         EXPECT_EQ(on_host, aa.on_host());
+         EXPECT_EQ(nic_compatible, aa.nic_compatible());
+         EXPECT_EQ(gpu_compatible, aa.gpu_compatible());
+       }
+     }
+   }
+ }
+
+ TEST(AllocatorAttributesTest, IsEqualOrLessRestrictiveThan) {
+   AllocatorAttributes a, b;
+   EXPECT_TRUE(a.IsEqualOrLessRestrictiveThan(b));
+   EXPECT_TRUE(a.IsEqualOrLessRestrictiveThan(a));
+   EXPECT_TRUE(b.IsEqualOrLessRestrictiveThan(b));
+
+   b.set_gpu_compatible(true);
+   // The set of flags in b is not a subset of those in a.
+   EXPECT_TRUE(a.IsEqualOrLessRestrictiveThan(b));
+   EXPECT_FALSE(b.IsEqualOrLessRestrictiveThan(a));
+   EXPECT_TRUE(a.IsEqualOrLessRestrictiveThan(a));
+   EXPECT_TRUE(b.IsEqualOrLessRestrictiveThan(b));
+
+   a.set_nic_compatible(true);
+   // Neither a nor b is a subset of the other.
+   EXPECT_FALSE(a.IsEqualOrLessRestrictiveThan(b));
+   EXPECT_FALSE(b.IsEqualOrLessRestrictiveThan(a));
+
+   a.set_gpu_compatible(true);
+   // The set of flags in b is a proper subset of those in a.
+   EXPECT_TRUE(b.IsEqualOrLessRestrictiveThan(a));
+   EXPECT_FALSE(a.IsEqualOrLessRestrictiveThan(b));
+ }
+
+ TEST(CPUAllocatorTest, Simple) {
+   EnableCPUAllocatorStats(true);
+   Allocator* a = cpu_allocator();
+   std::vector<void*> ptrs;
+   for (int s = 1; s < 1024; s++) {
+     void* raw = a->AllocateRaw(1, s);
+     ptrs.push_back(raw);
+   }
+   std::sort(ptrs.begin(), ptrs.end());
+   CheckStats(a, 1023, 552640, 552640, 1024);
+   for (size_t i = 0; i < ptrs.size(); i++) {
+     if (i > 0) {
+       CHECK_NE(ptrs[i], ptrs[i - 1]);  // No dups
+     }
+     a->DeallocateRaw(ptrs[i]);
+   }
+   CheckStats(a, 1023, 0, 552640, 1024);
+   float* t1 = a->Allocate<float>(1024);
+   double* t2 = a->Allocate<double>(1048576);
+   CheckStats(a, 1025, 1048576 * sizeof(double) + 1024 * sizeof(float),
+              1048576 * sizeof(double) + 1024 * sizeof(float),
+              1048576 * sizeof(double));
+
+   a->Deallocate(t1, 1024);
+   a->Deallocate(t2, 1048576);
+
+   CheckStats(a, 1025, 0, 1048576 * sizeof(double) + 1024 * sizeof(float),
+              1048576 * sizeof(double));
+   EnableCPUAllocatorStats(false);
+ }
+
+ // Define a struct that we will use to observe behavior in the unit tests.
+ struct TestStruct {
+   int x;  // not used; just here to make sure sizeof(TestStruct) > 1
+ };
+
+ TEST(CPUAllocatorTest, CheckStructSize) { CHECK_GT(sizeof(TestStruct), 1); }
+
+ TEST(CPUAllocatorTest, AllocateOverflowMaxSizeT) {
+   Allocator* a = cpu_allocator();
+
+   // The maximum size_t value will definitely overflow.
+   size_t count_to_allocate = std::numeric_limits<size_t>::max();
+   TestStruct* const test_pointer = a->Allocate<TestStruct>(count_to_allocate);
+
+   CHECK_EQ(test_pointer, reinterpret_cast<TestStruct*>(NULL));
+ }
+
+ TEST(CPUAllocatorTest, AllocateOverflowSmallest) {
+   Allocator* a = cpu_allocator();
+
+   // count_to_allocate is the smallest count that will cause overflow.
+   const size_t count_to_allocate =
+       (std::numeric_limits<size_t>::max() / sizeof(TestStruct)) + 1;
+   TestStruct* const test_pointer = a->Allocate<TestStruct>(count_to_allocate);
+
+   CHECK_EQ(test_pointer, reinterpret_cast<TestStruct*>(NULL));
+ }
+
+ TEST(CPUAllocatorTest, Sizes) {
+   Allocator* a = cpu_allocator();
+
+   EXPECT_EQ(false, a->TracksAllocationSizes());
+ }
+
+ namespace {
+
+ AllocatorAttributes DeviceAllocatorAttribute() {
+   AllocatorAttributes attr;
+   attr.value |= (0x1 << 24);
+   return attr;
+ }
+
+ bool HasDeviceAllocatorAttribute(const AllocatorAttributes& attr) {
+   return attr.value & (0x1 << 24);
+ }
+
+ }  // namespace
+
+ TEST(CustomAllocatorAttributes, TestSetterAndGetter) {
+   AllocatorAttributes attr = DeviceAllocatorAttribute();
+   EXPECT_TRUE(HasDeviceAllocatorAttribute(attr));
+   EXPECT_FALSE(HasDeviceAllocatorAttribute(AllocatorAttributes()));
+ }
+
+ static void BM_Allocation(int iters, int arg) {
+   Allocator* a = cpu_allocator();
+   // Exercise a few different allocation sizes.
+   std::vector<int> sizes = {256, 4096, 16384, 524288, 512, 1048576};
+   int size_index = 0;
+
+   if (arg) EnableCPUAllocatorStats(true);
+   while (--iters > 0) {
+     int bytes = sizes[size_index++ % sizes.size()];
+     void* p = a->AllocateRaw(1, bytes);
+     a->DeallocateRaw(p);
+   }
+   if (arg) EnableCPUAllocatorStats(false);
+ }
+ BENCHMARK(BM_Allocation)->Arg(0)->Arg(1);
+
+ }  // namespace tensorflow
api_def.proto ADDED
@@ -0,0 +1,120 @@
+ // Defines the text format for including per-op API definition and
+ // overrides for client language op code generators.
+
+ syntax = "proto3";
+
+ package tensorflow;
+ option cc_enable_arenas = true;
+ option java_outer_classname = "ApiDefProtos";
+ option java_multiple_files = true;
+ option java_package = "org.tensorflow.framework";
+ import "tensorflow/core/framework/attr_value.proto";
+
+ // Used to specify and override the default API & behavior in the
+ // generated code for client languages, from what you would get from
+ // the OpDef alone. There will be a set of ApiDefs that are common
+ // to all client languages, and another set per client language.
+ // The per-client-language ApiDefs will inherit values from the
+ // common ApiDefs, which they can either replace or modify.
+ //
+ // We separate the API definition from the OpDef so we can evolve the
+ // API while remaining backwards compatible when interpreting old
+ // graphs. Overrides go in an "api_def.pbtxt" file with a text-format
+ // ApiDefs message.
+ //
+ // WARNING: Be *very* careful changing the API for any existing op --
+ // you can change the semantics of existing code. These changes may
+ // need to wait until a major release of TensorFlow to avoid breaking
+ // our compatibility promises.
+ message ApiDef {
+   // Name of the op (in the OpDef) to specify the API for.
+   string graph_op_name = 1;
+
+   enum Visibility {
+     // Normally this is "VISIBLE" unless you are inheriting a
+     // different value from another ApiDef.
+     DEFAULT_VISIBILITY = 0;
+     // Publicly visible in the API.
+     VISIBLE = 1;
+     // Do not include this op in the generated API. If visibility is
+     // set to 'SKIP', other fields are ignored for this op.
+     SKIP = 2;
+     // Hide this op by putting it into an internal namespace (or whatever
+     // is appropriate in the target language).
+     HIDDEN = 3;
+   }
+   Visibility visibility = 2;
+
+   // If you specify any endpoint, this will replace all of the
+   // inherited endpoints. The first endpoint should be the
+   // "canonical" endpoint, and should not be deprecated (unless all
+   // endpoints are deprecated).
+   message Endpoint {
+     // Name should be either like "CamelCaseName" or
+     // "Package.CamelCaseName". Client-language-specific ApiDefs may
+     // use a snake_case convention instead of CamelCase.
+     string name = 1;
+
+     // First GraphDef version at which the op is disallowed.
+     int32 deprecation_version = 2;
+   }
+   repeated Endpoint endpoint = 3;
+
+   message Arg {
+     string name = 1;
+
+     // Change the name used to access this arg in the API from what
+     // is used in the GraphDef. Note that these names in `backticks`
+     // will also be replaced in the summary & description fields.
+     string rename_to = 2;
+
+     // Note: this will replace any inherited arg doc. There is no
+     // current way of modifying arg descriptions (other than replacing
+     // them entirely) as can be done with op descriptions.
+     string description = 3;
+   }
+   repeated Arg in_arg = 4;
+   repeated Arg out_arg = 5;
+   // List of original in_arg names to specify a new argument order.
+   // The length of arg_order should be either zero, to keep the current
+   // order, or match the size of in_arg.
+   repeated string arg_order = 11;
+
+   // Description of the graph-construction-time configuration of this
+   // Op. That is to say, this describes the attr fields that will
+   // be specified in the NodeDef.
+   message Attr {
+     string name = 1;
+
+     // Change the name used to access this attr in the API from what
+     // is used in the GraphDef. Note that these names in `backticks`
+     // will also be replaced in the summary & description fields.
+     string rename_to = 2;
+
+     // Specify a new default value to use for this attr. This default
+     // will be used when creating new graphs, as opposed to the
+     // default in the OpDef, which will be used when interpreting old
+     // GraphDefs.
+     AttrValue default_value = 3;
+
+     // Note: this will replace any inherited attr doc; there is no current
+     // way of modifying attr descriptions as can be done with op
+     // descriptions.
+     string description = 4;
+   }
+   repeated Attr attr = 6;
+
+   // One-line human-readable description of what the Op does.
+   string summary = 7;
+
+   // Additional, longer human-readable description of what the Op does.
+   string description = 8;
+
+   // Modify an existing/inherited description by adding text to the beginning
+   // or end.
+   string description_prefix = 9;
+   string description_suffix = 10;
+ }
+
+ message ApiDefs {
+   repeated ApiDef op = 1;
+ }
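
A small, hypothetical api_def.pbtxt entry showing the text-format ApiDefs message these fields describe (the op and argument names are invented for illustration; the field names are the ones defined above):

op {
  graph_op_name: "MyMatMul"
  visibility: VISIBLE
  endpoint { name: "linalg.MyMatMul" }
  in_arg { name: "a" rename_to: "x" }
  attr { name: "transpose_a" description: "If true, `a` is transposed first." }
  summary: "Multiplies two matrices."
}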
attr_value.proto ADDED
@@ -0,0 +1,62 @@
+ syntax = "proto3";
+
+ package tensorflow;
+ option cc_enable_arenas = true;
+ option java_outer_classname = "AttrValueProtos";
+ option java_multiple_files = true;
+ option java_package = "org.tensorflow.framework";
+
+ import "tensorflow/core/framework/tensor.proto";
+ import "tensorflow/core/framework/tensor_shape.proto";
+ import "tensorflow/core/framework/types.proto";
+
+ // Protocol buffer representing the value for an attr used to configure an Op.
+ // Comment indicates the corresponding attr type. Only the field matching the
+ // attr type may be filled.
+ message AttrValue {
+   // LINT.IfChange
+   message ListValue {
+     repeated bytes s = 2;                        // "list(string)"
+     repeated int64 i = 3 [packed = true];        // "list(int)"
+     repeated float f = 4 [packed = true];        // "list(float)"
+     repeated bool b = 5 [packed = true];         // "list(bool)"
+     repeated DataType type = 6 [packed = true];  // "list(type)"
+     repeated TensorShapeProto shape = 7;         // "list(shape)"
+     repeated TensorProto tensor = 8;             // "list(tensor)"
+     repeated NameAttrList func = 9;              // "list(attr)"
+   }
+   // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc)
+
+   oneof value {
+     bytes s = 2;                 // "string"
+     int64 i = 3;                 // "int"
+     float f = 4;                 // "float"
+     bool b = 5;                  // "bool"
+     DataType type = 6;           // "type"
+     TensorShapeProto shape = 7;  // "shape"
+     TensorProto tensor = 8;      // "tensor"
+     ListValue list = 1;          // any "list(...)"
+
+     // "func" represents a function. func.name is a function's name or
+     // a primitive op's name. func.attr.first is the name of an attr
+     // defined for that function. func.attr.second is the value for
+     // that attr in the instantiation.
+     NameAttrList func = 10;
+
+     // This is a placeholder only used in nodes defined inside a
+     // function. It indicates the attr value will be supplied when
+     // the function is instantiated. For example, let us suppose a
+     // node "N" in function "FN". "N" has an attr "A" with value
+     // placeholder = "foo". When FN is instantiated with attr "foo"
+     // set to "bar", the instantiated node N's attr A will have been
+     // given the value "bar".
+     string placeholder = 9;
+   }
+ }
+
+ // A list of attr names and their values. The whole list is attached
+ // with a string name. E.g., MatMul[T=float].
+ message NameAttrList {
+   string name = 1;
+   map<string, AttrValue> attr = 2;
+ }
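
For reference, two hypothetical text-format AttrValue messages, one scalar and one list form (per the comment above, only one field of the oneof may be set, and list values live inside the ListValue submessage):

# An "int" attr:
i: 7

# A "list(string)" attr:
list { s: "a" s: "b" }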
attr_value_util.cc ADDED
@@ -0,0 +1,551 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/attr_value_util.h"
+
+ #include <string>
+ #include <vector>
+
+ #include "tensorflow/core/framework/attr_value.pb_text.h"
+ #include "tensorflow/core/framework/tensor.pb_text.h"
+ #include "tensorflow/core/framework/tensor_shape.pb.h"
+ #include "tensorflow/core/framework/types.h"
+ #include "tensorflow/core/framework/types.pb_text.h"
+ #include "tensorflow/core/lib/core/errors.h"
+ #include "tensorflow/core/lib/core/stringpiece.h"
+ #include "tensorflow/core/lib/hash/hash.h"
+ #include "tensorflow/core/lib/strings/str_util.h"
+ #include "tensorflow/core/platform/protobuf.h"
+
+ namespace tensorflow {
+ namespace {
+
+ string SummarizeString(const string& str) {
+   return strings::StrCat("\"", str_util::CEscape(str), "\"");
+ }
+
+ string SummarizeTensor(const TensorProto& tensor_proto) {
+   Tensor t;
+   if (!t.FromProto(tensor_proto)) {
+     return strings::StrCat(
+         "<Invalid TensorProto: ", ProtoShortDebugString(tensor_proto), ">");
+   }
+   return t.DebugString();
+ }
+
+ string SummarizeFunc(const NameAttrList& func) {
+   std::vector<string> entries;
+   for (auto p : func.attr()) {
+     entries.push_back(
+         strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
+   }
+   std::sort(entries.begin(), entries.end());
+   return strings::StrCat(func.name(), "[", str_util::Join(entries, ", "), "]");
+ }
+
+ }  // namespace
+
+ string SummarizeAttrValue(const AttrValue& attr_value) {
+   switch (attr_value.value_case()) {
+     case AttrValue::kS:
+       return SummarizeString(attr_value.s());
+     case AttrValue::kI:
+       return strings::StrCat(attr_value.i());
+     case AttrValue::kF:
+       return strings::StrCat(attr_value.f());
+     case AttrValue::kB:
+       return attr_value.b() ? "true" : "false";
+     case AttrValue::kType:
+       return EnumName_DataType(attr_value.type());
+     case AttrValue::kShape:
+       return PartialTensorShape::DebugString(attr_value.shape());
+     case AttrValue::kTensor:
+       return SummarizeTensor(attr_value.tensor());
+     case AttrValue::kList: {
+       string ret = "[";
+       if (attr_value.list().s_size() > 0) {
+         for (int i = 0; i < attr_value.list().s_size(); ++i) {
+           if (i > 0) strings::StrAppend(&ret, ", ");
+           strings::StrAppend(&ret, SummarizeString(attr_value.list().s(i)));
+         }
+       } else if (attr_value.list().i_size() > 0) {
+         for (int i = 0; i < attr_value.list().i_size(); ++i) {
+           if (i > 0) strings::StrAppend(&ret, ", ");
+           strings::StrAppend(&ret, attr_value.list().i(i));
+         }
+       } else if (attr_value.list().f_size() > 0) {
+         for (int i = 0; i < attr_value.list().f_size(); ++i) {
+           if (i > 0) strings::StrAppend(&ret, ", ");
+           strings::StrAppend(&ret, attr_value.list().f(i));
+         }
+       } else if (attr_value.list().b_size() > 0) {
+         for (int i = 0; i < attr_value.list().b_size(); ++i) {
+           if (i > 0) strings::StrAppend(&ret, ", ");
+           strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false");
+         }
+       } else if (attr_value.list().type_size() > 0) {
+         for (int i = 0; i < attr_value.list().type_size(); ++i) {
+           if (i > 0) strings::StrAppend(&ret, ", ");
+           strings::StrAppend(&ret,
+                              EnumName_DataType(attr_value.list().type(i)));
+         }
+       } else if (attr_value.list().shape_size() > 0) {
+         for (int i = 0; i < attr_value.list().shape_size(); ++i) {
+           if (i > 0) strings::StrAppend(&ret, ", ");
+           strings::StrAppend(
+               &ret, TensorShape::DebugString(attr_value.list().shape(i)));
+         }
+       } else if (attr_value.list().tensor_size() > 0) {
+         for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
+           if (i > 0) strings::StrAppend(&ret, ", ");
+           strings::StrAppend(&ret,
+                              SummarizeTensor(attr_value.list().tensor(i)));
+         }
+       } else if (attr_value.list().func_size() > 0) {
+         for (int i = 0; i < attr_value.list().func_size(); ++i) {
+           if (i > 0) strings::StrAppend(&ret, ", ");
+           strings::StrAppend(&ret, SummarizeFunc(attr_value.list().func(i)));
+         }
+       }
+
+       strings::StrAppend(&ret, "]");
+       return ret;
+     }
+     case AttrValue::kFunc: {
+       return SummarizeFunc(attr_value.func());
+     }
+     case AttrValue::kPlaceholder:
+       return strings::StrCat("$", attr_value.placeholder());
+     case AttrValue::VALUE_NOT_SET:
+       return "<Unknown AttrValue type>";
+   }
+   return "<Unknown AttrValue type>";  // Prevent missing return warning
+ }
+
+ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
+   int num_set = 0;
+
+ #define VALIDATE_FIELD(name, type_string, oneof_case)                         \
+   do {                                                                        \
+     if (attr_value.has_list()) {                                              \
+       if (attr_value.list().name##_size() > 0) {                              \
+         if (type != "list(" type_string ")") {                                \
+           return errors::InvalidArgument(                                     \
+               "AttrValue had value with type 'list(" type_string ")' when '", \
+               type, "' expected");                                            \
+         }                                                                     \
+         ++num_set;                                                            \
+       }                                                                       \
+     } else if (attr_value.value_case() == AttrValue::oneof_case) {            \
+       if (type != type_string) {                                              \
+         return errors::InvalidArgument(                                       \
+             "AttrValue had value with type '" type_string "' when '", type,   \
+             "' expected");                                                    \
+       }                                                                       \
+       ++num_set;                                                              \
+     }                                                                         \
+   } while (false)
+
+   VALIDATE_FIELD(s, "string", kS);
+   VALIDATE_FIELD(i, "int", kI);
+   VALIDATE_FIELD(f, "float", kF);
+   VALIDATE_FIELD(b, "bool", kB);
+   VALIDATE_FIELD(type, "type", kType);
+   VALIDATE_FIELD(shape, "shape", kShape);
+   VALIDATE_FIELD(tensor, "tensor", kTensor);
+   VALIDATE_FIELD(func, "func", kFunc);
+
+ #undef VALIDATE_FIELD
+
+   if (attr_value.value_case() == AttrValue::kPlaceholder) {
+     return errors::InvalidArgument(
+         "AttrValue had value with unexpected type 'placeholder'");
+   }
+
+   // If the attr type is 'list', we expect attr_value.has_list() to be
+   // true. However, proto3's attr_value.has_list() can be false when
+   // set to an empty list for GraphDef versions <= 4. So we simply
+   // check if has_list is false and some other field in attr_value is
+   // set to flag the error. This test can be made more strict once
+   // support for GraphDef versions <= 4 is dropped.
+   if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) {
+     if (num_set) {
+       return errors::InvalidArgument(
+           "AttrValue missing value with expected type '", type, "'");
+     } else {
+       // Indicate that we have a list, but an empty one.
+       ++num_set;
+     }
+   }
+
+   // Okay to have an empty list, but not to be missing a non-list value.
+   if (num_set == 0 && !StringPiece(type).starts_with("list(")) {
+     return errors::InvalidArgument(
+         "AttrValue missing value with expected type '", type, "'");
+   }
+
+   // Ref types and DT_INVALID are illegal, and DataTypes must
+   // be a valid enum type.
+   if (type == "type") {
+     if (!DataType_IsValid(attr_value.type())) {
+       return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
+                                      attr_value.type());
+     }
+     if (IsRefType(attr_value.type())) {
+       return errors::InvalidArgument(
+           "AttrValue must not have reference type value of ",
+           DataTypeString(attr_value.type()));
+     }
+     if (attr_value.type() == DT_INVALID) {
+       return errors::InvalidArgument("AttrValue has invalid DataType");
+     }
+   } else if (type == "list(type)") {
+     for (auto as_int : attr_value.list().type()) {
+       const DataType dtype = static_cast<DataType>(as_int);
+       if (!DataType_IsValid(dtype)) {
+         return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
+                                        as_int);
+       }
+       if (IsRefType(dtype)) {
+         return errors::InvalidArgument(
+             "AttrValue must not have reference type value of ",
+             DataTypeString(dtype));
+       }
+       if (dtype == DT_INVALID) {
+         return errors::InvalidArgument("AttrValue contains invalid DataType");
+       }
+     }
+   }
+
+   return Status::OK();
+ }
+
+ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
+   // Parse type.
+   string field_name;
+   bool is_list = type.Consume("list(");
+   if (type.Consume("string")) {
+     field_name = "s";
+   } else if (type.Consume("int")) {
+     field_name = "i";
+   } else if (type.Consume("float")) {
+     field_name = "f";
+   } else if (type.Consume("bool")) {
+     field_name = "b";
+   } else if (type.Consume("type")) {
+     field_name = "type";
+   } else if (type.Consume("shape")) {
+     field_name = "shape";
+   } else if (type.Consume("tensor")) {
+     field_name = "tensor";
+   } else if (type.Consume("func")) {
+     field_name = "func";
+   } else if (type.Consume("placeholder")) {
+     field_name = "placeholder";
+   } else {
+     return false;
+   }
+   if (is_list && !type.Consume(")")) {
+     return false;
+   }
+
+   // Construct a valid text proto message to parse.
+   string to_parse;
+   if (is_list) {
+     // The TextFormat parser considers "i: 7" to be the same as "i: [7]",
+     // but we only want to allow list values with [].
+     StringPiece cleaned = text;
+     str_util::RemoveLeadingWhitespace(&cleaned);
+     str_util::RemoveTrailingWhitespace(&cleaned);
+     if (cleaned.size() < 2 || cleaned[0] != '[' ||
+         cleaned[cleaned.size() - 1] != ']') {
+       return false;
+     }
+     cleaned.remove_prefix(1);
+     str_util::RemoveLeadingWhitespace(&cleaned);
+     if (cleaned.size() == 1) {
+       // User wrote "[]", so return an empty list without invoking the
+       // TextFormat parser, which returns an error for "i: []".
+       out->Clear();
+       out->mutable_list();
+       return true;
+     }
+     to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
+   } else {
+     to_parse = strings::StrCat(field_name, ": ", text);
+   }
+
+   return ProtoParseFromString(to_parse, out);
+ }
+
+ void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
+
+ #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
+   void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
+
+ #define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD)                       \
+   void SetAttrValue(ARG_TYPE value, AttrValue* out) {                     \
+     out->mutable_list()->Clear(); /* create list() even if value empty */ \
+     for (const auto& v : value) {                                         \
+       out->mutable_list()->add_##FIELD(v);                                \
+     }                                                                     \
+   }
+
+ #define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
+   DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD)        \
+   DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)
+
+ DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
+ DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
+ DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
+ DEFINE_SET_ATTR_VALUE_BOTH(int64, i)
+ DEFINE_SET_ATTR_VALUE_BOTH(int32, i)
+ DEFINE_SET_ATTR_VALUE_BOTH(float, f)
+ DEFINE_SET_ATTR_VALUE_BOTH(double, f)
+ DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
+ DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
+ DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
+ DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
+
+ void SetAttrValue(StringPiece value, AttrValue* out) {
+   out->set_s(value.data(), value.size());
+ }
+
+ void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
+   out->mutable_list()->Clear();  // Create list() even if value empty.
+   for (const auto& v : value) {
+     out->mutable_list()->add_s(v.data(), v.size());
+   }
+ }
+
+ void SetAttrValue(const TensorShape& value, AttrValue* out) {
+   value.AsProto(out->mutable_shape());
+ }
+
+ void SetAttrValue(const TensorShapeProto& value, AttrValue* out) {
+   *out->mutable_shape() = value;
+ }
+
+ void SetAttrValue(const PartialTensorShape& value, AttrValue* out) {
+   value.AsProto(out->mutable_shape());
+ }
+
+ void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
+   out->mutable_list()->Clear();  // Create list() even if value empty.
+   for (const auto& v : value) {
+     v.AsProto(out->mutable_list()->add_shape());
+   }
+ }
+
+ void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) {
+   out->mutable_list()->Clear();  // Create list() even if value empty.
+   for (const auto& v : value) {
+     *out->mutable_list()->add_shape() = v;
+   }
+ }
+
+ void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,
+                   AttrValue* out) {
+   out->mutable_list()->Clear();  // Create list() even if value empty.
+   for (const auto& v : value) {
+     v.AsProto(out->mutable_list()->add_shape());
+   }
+ }
+
+ void SetAttrValue(const Tensor& value, AttrValue* out) {
+   if (value.NumElements() > 1) {
+     value.AsProtoTensorContent(out->mutable_tensor());
+   } else {
+     value.AsProtoField(out->mutable_tensor());
+   }
+ }
+
+ void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
+   out->mutable_list()->Clear();  // Create list() even if value empty.
+   for (const auto& v : value) {
+     if (v.NumElements() > 1) {
+       v.AsProtoTensorContent(out->mutable_list()->add_tensor());
+     } else {
+       v.AsProtoField(out->mutable_list()->add_tensor());
+     }
+   }
+ }
+
+ void SetAttrValue(const TensorProto& value, AttrValue* out) {
+   *out->mutable_tensor() = value;
+ }
+
+ void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
+   out->mutable_list()->Clear();  // Create list() even if value empty.
+   for (const auto& v : value) {
+     *out->mutable_list()->add_tensor() = v;
+   }
+ }
+
+ void SetAttrValue(const NameAttrList& value, AttrValue* out) {
+   *out->mutable_func() = value;
+ }
+
+ void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
+   out->mutable_list()->Clear();  // Create list() even if value empty.
+   for (const auto& v : value) {
+     *out->mutable_list()->add_func() = v;
+   }
+ }
+
+ bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
+   // There are multiple equivalent representations of attr values containing
+   // TensorProtos. Compare them by constructing Tensors and serializing them
+   // back. Comparing Tensor objects is pretty tricky.
+   if (a.has_tensor() != b.has_tensor()) {
+     return false;
+   } else if (a.has_tensor() && b.has_tensor()) {
+     Tensor at(a.tensor().dtype());
+     bool success = at.FromProto(a.tensor());
+     DCHECK(success);
+
+     Tensor bt(b.tensor().dtype());
+     success = bt.FromProto(b.tensor());
+     DCHECK(success);
+
+     TensorProto ap;
+     at.AsProtoTensorContent(&ap);
+
+     TensorProto bp;
+     bt.AsProtoTensorContent(&bp);
+
+     string a_str, b_str;
+     SerializeToStringDeterministic(ap, &a_str);
+     SerializeToStringDeterministic(bp, &b_str);
432
+ return a_str == b_str;
433
+ }
434
+
435
+ // `func` field contains a nested AttrValue. Compare such AttrValues
436
+ // recursively.
437
+ if (a.has_func() != b.has_func()) {
438
+ return false;
439
+ } else if (a.has_func() && b.has_func()) {
440
+ const NameAttrList& af = a.func();
441
+ const NameAttrList& bf = b.func();
442
+ if (af.name() != bf.name()) return false;
443
+ std::unordered_map<string, AttrValue> am(af.attr().begin(),
444
+ af.attr().end());
445
+ for (const auto& bm_pair : bf.attr()) {
446
+ const auto& iter = am.find(bm_pair.first);
447
+ if (iter == am.end()) return false;
448
+ if (!AreAttrValuesEqual(iter->second, bm_pair.second)) return false;
449
+ am.erase(iter);
450
+ }
451
+ if (!am.empty()) return false;
452
+ return true;
453
+ }
454
+
455
+ // All other fields in AttrValue have deterministic representations.
456
+ // It is safe to compare their serialized strings.
457
+ string a_str, b_str;
458
+ SerializeToStringDeterministic(a, &a_str);
459
+ SerializeToStringDeterministic(b, &b_str);
460
+ return a_str == b_str;
461
+ }
462
+
463
+ uint64 AttrValueHash(const AttrValue& a) {
464
+ if (a.has_tensor()) {
465
+ // Deal with multiple representations by parsing TensorProto to
466
+ // Tensor and serializing it back. This is slow, but current use case
467
+ // don't need high efficiency.
468
+ Tensor tensor(a.tensor().dtype());
469
+ bool success = tensor.FromProto(a.tensor());
470
+ DCHECK(success);
471
+ TensorProto p;
472
+ tensor.AsProtoTensorContent(&p);
473
+ string s;
474
+ SerializeToStringDeterministic(p, &s);
475
+ return Hash64(s);
476
+ }
477
+ if (a.has_func()) {
478
+ const NameAttrList& func = a.func();
479
+ uint64 h = Hash64(func.name());
480
+ std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
481
+ for (const auto& pair : map) {
482
+ h = Hash64(pair.first.data(), pair.first.size(), h);
483
+ h = Hash64Combine(AttrValueHash(pair.second), h);
484
+ }
485
+ return h;
486
+ }
487
+
488
+ // If `a` is not a tensor or func, get a hash of serialized string.
489
+ string s;
490
+ SerializeToStringDeterministic(a, &s);
491
+ return Hash64(s);
492
+ }
493
+
494
+ bool HasPlaceHolder(const AttrValue& val) {
495
+ switch (val.value_case()) {
496
+ case AttrValue::kList: {
497
+ for (const NameAttrList& func : val.list().func()) {
498
+ for (const auto& p : func.attr()) {
499
+ if (HasPlaceHolder(p.second)) {
500
+ return true;
501
+ }
502
+ }
503
+ }
504
+ break;
505
+ }
506
+ case AttrValue::kFunc:
507
+ for (const auto& p : val.func().attr()) {
508
+ if (HasPlaceHolder(p.second)) {
509
+ return true;
510
+ }
511
+ }
512
+ break;
513
+ case AttrValue::kPlaceholder:
514
+ return true;
515
+ default:
516
+ break;
517
+ }
518
+ return false;
519
+ }
520
+
521
+ bool SubstitutePlaceholders(const SubstituteFunc& substitute,
522
+ AttrValue* value) {
523
+ switch (value->value_case()) {
524
+ case AttrValue::kList: {
525
+ for (NameAttrList& func : *value->mutable_list()->mutable_func()) {
526
+ for (auto& p : *func.mutable_attr()) {
527
+ if (!SubstitutePlaceholders(substitute, &p.second)) {
528
+ return false;
529
+ }
530
+ }
531
+ }
532
+ break;
533
+ }
534
+ case AttrValue::kFunc:
535
+ for (auto& p : *(value->mutable_func()->mutable_attr())) {
536
+ if (!SubstitutePlaceholders(substitute, &p.second)) {
537
+ return false;
538
+ }
539
+ }
540
+ break;
541
+ case AttrValue::kPlaceholder:
542
+ return substitute(value->placeholder(), value);
543
+ case AttrValue::VALUE_NOT_SET:
544
+ return false;
545
+ default:
546
+ break;
547
+ }
548
+ return true;
549
+ }
550
+
551
+ } // namespace tensorflow
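A minimal usage sketch of the functions above (illustrative only, not part of the upload; assumes the framework headers are included):

  // Parse the textual form of a list(int) attr and compare it against an
  // AttrValue built programmatically with SetAttrValue.
  AttrValue parsed, built;
  if (ParseAttrValue("list(int)", "[1, 2, 3]", &parsed)) {
    SetAttrValue(gtl::ArraySlice<int64>({1, 2, 3}), &built);
    DCHECK(AreAttrValuesEqual(parsed, built));  // both hold list { i: 1 i: 2 i: 3 }
  }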
attr_value_util.h ADDED
@@ -0,0 +1,116 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+ #define TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+
+ #include <functional>
+ #include <string>
+ #include <vector>
+
+ #include "tensorflow/core/framework/partial_tensor_shape.h"
+ #include "tensorflow/core/framework/tensor.h"
+ #include "tensorflow/core/framework/tensor_shape.h"
+ #include "tensorflow/core/framework/types.h"
+ #include "tensorflow/core/lib/core/status.h"
+ #include "tensorflow/core/lib/core/stringpiece.h"
+ #include "tensorflow/core/lib/gtl/array_slice.h"
+
+ namespace tensorflow {
+
+ // Forward declare protos so their symbols can be removed from .so exports.
+ class AttrValue;
+ class NameAttrList;
+
+ // A human-readable rendering of attr_value that is more concise than a
+ // text-format proto.
+ string SummarizeAttrValue(const AttrValue& attr_value);
+
+ // Generates an error if attr_value doesn't have the indicated attr type.
+ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type);
+
+ // Converts a text proto value from "text" into the field of *out
+ // indicated by "type" (e.g. from the type field of an AttrDef).
+ // Examples:
+ // * If type:"int" and text:"-14", then *out is set to "i: -14"
+ // * If type:"list(string)" and text:"['foo', 'bar']",
+ //   then *out is set to "list { s: ['foo', 'bar'] }"
+ // Returns true on success.
+ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out);
+
+ // Sets *out based on the type of value.
+ void SetAttrValue(const string& value, AttrValue* out);
+ void SetAttrValue(const char* value, AttrValue* out);
+ void SetAttrValue(StringPiece value, AttrValue* out);
+ void SetAttrValue(int64 value, AttrValue* out);
+ void SetAttrValue(int32 value, AttrValue* out);
+ void SetAttrValue(float value, AttrValue* out);
+ void SetAttrValue(double value, AttrValue* out);
+ void SetAttrValue(bool value, AttrValue* out);
+ void SetAttrValue(DataType value, AttrValue* out);
+ void SetAttrValue(const TensorShape& value, AttrValue* out);
+ void SetAttrValue(const TensorShapeProto& value, AttrValue* out);
+ void SetAttrValue(const PartialTensorShape& value, AttrValue* out);
+ void SetAttrValue(const Tensor& value, AttrValue* out);
+ void SetAttrValue(const TensorProto& value, AttrValue* out);
+ void SetAttrValue(const NameAttrList& value, AttrValue* out);
+
+ void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<StringPiece> value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<int64> value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<int32> value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<float> value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<double> value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<bool> value, AttrValue* out);
+ void SetAttrValue(const std::vector<bool>& value, AttrValue* out);
+ void SetAttrValue(std::initializer_list<bool> value, AttrValue* out);
+ void SetAttrValue(DataTypeSlice value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<TensorShape> value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<PartialTensorShape> value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<Tensor> value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out);
+ void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out);
+
+ void SetAttrValue(const AttrValue& value, AttrValue* out);
+
+ // Returns true if a and b have the same value.
+ bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b);
+
+ // Returns a hash of `a` that is consistent with AreAttrValuesEqual. In other
+ // words, if two AttrValues compare equal according to AreAttrValuesEqual,
+ // they will have the same hash value.
+ // Similarly to protobuf deterministic serialization, the hash value is
+ // guaranteed to be stable only for a given binary. In particular, one should
+ // probably not persist the returned value.
+ uint64 AttrValueHash(const AttrValue& a);
+
+ // Returns true if "val" has a placeholder.
+ bool HasPlaceHolder(const AttrValue& val);
+
+ // SubstitutePlaceholders recursively replaces placeholders in 'value'
+ // with an attr value by calling SubstituteFunc. Returns true iff all
+ // placeholders in "value" are replaced with a value.
+ //
+ // SubstituteFunc is given a placeholder string. If the placeholder is
+ // unknown, SubstituteFunc returns false. Otherwise, it overwrites the
+ // attr value and returns true.
+ using SubstituteFunc = std::function<bool(const string&, AttrValue*)>;
+ bool SubstitutePlaceholders(const SubstituteFunc& substitute, AttrValue* value);
+
+ }  // namespace tensorflow
+
+ #endif  // TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
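Because AttrValueHash is consistent with AreAttrValuesEqual, the pair can back a hash container keyed on AttrValue. A sketch under that assumption (the functor names are made up, and since the hash is only stable within one binary it should never be persisted):

  // Hypothetical adapters so AttrValue can key an unordered_map.
  struct AttrValueHasher {
    size_t operator()(const AttrValue& v) const { return AttrValueHash(v); }
  };
  struct AttrValueEq {
    bool operator()(const AttrValue& a, const AttrValue& b) const {
      return AreAttrValuesEqual(a, b);
    }
  };
  // Count how many times each distinct attr value has been seen.
  std::unordered_map<AttrValue, int, AttrValueHasher, AttrValueEq> counts;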
attr_value_util_test.cc ADDED
@@ -0,0 +1,195 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/attr_value_util.h"
+
+ #include <vector>
+ #include "tensorflow/core/framework/attr_value.pb.h"
+ #include "tensorflow/core/lib/core/status_test_util.h"
+ #include "tensorflow/core/platform/protobuf.h"
+ #include "tensorflow/core/platform/test.h"
+
+ namespace tensorflow {
+
+ // A few helpers to construct AttrValue protos.
+ template <typename T>
+ AttrValue V(T value) {
+   AttrValue ret;
+   SetAttrValue(value, &ret);
+   return ret;
+ }
+
+ AttrValue P(const string& p) {
+   AttrValue ret;
+   ret.set_placeholder(p);
+   return ret;
+ }
+
+ AttrValue F(const string& name,
+             std::vector<std::pair<string, AttrValue>> pairs) {
+   AttrValue ret;
+   ret.mutable_func()->set_name(name);
+   ret.mutable_func()->mutable_attr()->insert(pairs.begin(), pairs.end());
+   return ret;
+ }
+
+ AttrValue Fs(
+     std::vector<std::pair<string, std::vector<std::pair<string, AttrValue>>>>
+         funcs) {
+   AttrValue ret;
+   for (const auto& func : funcs) {
+     NameAttrList* entry = ret.mutable_list()->add_func();
+     entry->set_name(func.first);
+     entry->mutable_attr()->insert(func.second.begin(), func.second.end());
+   }
+   return ret;
+ }
+
+ TEST(AttrValueUtil, HasType) {
+   // OK.
+   EXPECT_TRUE(AttrValueHasType(V(123), "int").ok());
+   EXPECT_TRUE(AttrValueHasType(V(1.2), "float").ok());
+   EXPECT_TRUE(AttrValueHasType(V(DT_FLOAT), "type").ok());
+   EXPECT_TRUE(AttrValueHasType(F("f", {}), "func").ok());
+   EXPECT_TRUE(AttrValueHasType(Fs({{"f", {}}, {"g", {}}}), "list(func)").ok());
+
+   // Not OK.
+   EXPECT_FALSE(AttrValueHasType(V(123), "func").ok());
+   EXPECT_FALSE(AttrValueHasType(V(1.2), "int").ok());
+   EXPECT_FALSE(AttrValueHasType(V(DT_FLOAT), "shape").ok());
+   EXPECT_FALSE(AttrValueHasType(F("f", {}), "string").ok());
+   EXPECT_FALSE(AttrValueHasType(P("T"), "float").ok());
+   EXPECT_FALSE(AttrValueHasType(V(static_cast<DataType>(1000)), "type").ok());
+   std::vector<DataType> list_type({static_cast<DataType>(1000)});
+   EXPECT_FALSE(AttrValueHasType(V(list_type), "list(type)").ok());
+ }
+
+ SubstituteFunc ReplaceTWith(const AttrValue& val) {
+   return [val](const string& placeholder, AttrValue* target) {
+     if (placeholder == "T") {
+       *target = val;
+       return true;
+     } else {
+       return false;
+     }
+   };
+ }
+
+ TEST(AttrValueUtil, Basic) {
+   auto v = F("MatMul", {{"dtype", P("T")},
+                         {"transpose_a", V(false)},
+                         {"transpose_b", V(true)},
+                         {"use_cublas", V(true)}});
+   TF_EXPECT_OK(AttrValueHasType(v, "func"));
+   EXPECT_TRUE(HasPlaceHolder(v));
+
+   EXPECT_EQ(
+       SummarizeAttrValue(v),
+       "MatMul[dtype=$T, transpose_a=false, transpose_b=true, use_cublas=true]");
+
+   SubstitutePlaceholders(ReplaceTWith(V(DT_FLOAT)), &v);
+   EXPECT_TRUE(!HasPlaceHolder(v));
+   EXPECT_EQ(SummarizeAttrValue(v),
+             "MatMul[dtype=DT_FLOAT, transpose_a=false, transpose_b=true, "
+             "use_cublas=true]");
+ }
+
+ TEST(AttrValueUtil, Shaped) {
+   auto v =
+       F("OpRequiresShape", {{"shape_full", V(TensorShape({1, 0}))},
+                             {"shape_part", V(PartialTensorShape({-1, 1, 0}))}});
+   TF_EXPECT_OK(AttrValueHasType(v, "func"));
+   EXPECT_FALSE(HasPlaceHolder(v));
+
+   EXPECT_EQ(SummarizeAttrValue(v),
+             "OpRequiresShape[shape_full=[1,0], shape_part=[?,1,0]]");
+ }
+
+ TEST(AttrValueUtil, DeepAttr) {
+   auto v = Fs({{"f", {{"T", P("T")}}}, {"g", {{"T", P("T")}}}});
+   TF_EXPECT_OK(AttrValueHasType(v, "list(func)"));
+   EXPECT_TRUE(HasPlaceHolder(v));
+
+   for (int i = 0; i < 3; ++i) {
+     v = F("f", {{"T", P("T")}, {"F", v}});
+     EXPECT_TRUE(HasPlaceHolder(v));
+   }
+   EXPECT_EQ(SummarizeAttrValue(v),
+             "f[F=f[F=f[F=[f[T=$T], g[T=$T]], T=$T], T=$T], T=$T]");
+
+   SubstitutePlaceholders(ReplaceTWith(F("x", {})), &v);
+   EXPECT_TRUE(!HasPlaceHolder(v));
+   EXPECT_EQ(SummarizeAttrValue(v),
+             "f[F=f[F=f[F=[f[T=x[]], g[T=x[]]], T=x[]], T=x[]], T=x[]]");
+ }
+
+ AttrValue FromText(const string& text) {
+   AttrValue attr;
+   EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &attr));
+   return attr;
+ }
+
+ void ExpectDifferent(const AttrValue& a1, const AttrValue& a2) {
+   EXPECT_FALSE(AreAttrValuesEqual(a1, a2));
+   EXPECT_FALSE(AreAttrValuesEqual(a2, a1));
+   EXPECT_NE(AttrValueHash(a1), AttrValueHash(a2));
+ }
+
+ TEST(AttrValueEquality, StringAndFuncTensors) {
+   AttrValue a = FromText(R"(
+     tensor {
+       dtype: DT_STRING
+       tensor_shape {
+         dim {
+           size: 2
+         }
+       }
+       string_val: 'reader_dataset_ops_test/tmphtXHks/text_line.0.txt'
+       string_val: 'reader_dataset_ops_test/tmphtXHks/text_line.1.txt'
+     })");
+   EXPECT_TRUE(AreAttrValuesEqual(a, a));
+   EXPECT_EQ(AttrValueHash(a), AttrValueHash(a));
+
+   AttrValue b = a;
+   (*b.mutable_tensor()->mutable_string_val(0))[3] = '1';
+   ExpectDifferent(a, b);
+
+   AttrValue c1;
+   c1.mutable_func()->set_name("func_name");
+   (*c1.mutable_func()->mutable_attr())["attr1"] = a;
+   (*c1.mutable_func()->mutable_attr())["attr2"] = b;
+   EXPECT_TRUE(AreAttrValuesEqual(c1, c1));
+   EXPECT_EQ(AttrValueHash(c1), AttrValueHash(c1));
+
+   ExpectDifferent(c1, a);
+
+   AttrValue c2 = c1;
+   c2.mutable_func()->set_name("func_name2");
+   ExpectDifferent(c1, c2);
+
+   c2 = c1;
+   (*c2.mutable_func()->mutable_attr())["attr3"] = b;
+   ExpectDifferent(c1, c2);
+
+   c2 = c1;
+   (*c2.mutable_func()->mutable_attr())["attr2"] = a;
+   ExpectDifferent(c1, c2);
+
+   c2 = c1;
+   c2.mutable_func()->mutable_attr()->erase("attr2");
+   ExpectDifferent(c1, c2);
+ }
+
+ }  // namespace tensorflow
bfloat16.cc ADDED
@@ -0,0 +1,50 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/bfloat16.h"
+
+ namespace tensorflow {
+
+ void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) {
+   const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
+   uint16_t* q = reinterpret_cast<uint16_t*>(dst);
+ #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+   for (; size != 0; p += 2, q++, size--) {
+     *q = p[0];
+   }
+ #else
+   for (; size != 0; p += 2, q++, size--) {
+     *q = p[1];
+   }
+ #endif
+ }
+
+ void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) {
+   const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
+   uint16_t* q = reinterpret_cast<uint16_t*>(dst);
+ #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+   for (; size != 0; p++, q += 2, size--) {
+     q[0] = *p;
+     q[1] = 0;
+   }
+ #else
+   for (; size != 0; p++, q += 2, size--) {
+     q[0] = 0;
+     q[1] = *p;
+   }
+ #endif
+ }
+
+ }  // end namespace tensorflow
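Per element, the routines above just keep (or restore) the high half of the IEEE-754 bit pattern; here is a scalar, endian-neutral sketch using shifts instead of the pointer walk (the helper names are illustrative, not part of the upload):

  #include <cstdint>
  #include <cstring>

  // Truncate one float to bfloat16 bits: keep the high 16 bits of the
  // single-precision representation (sign, exponent, top 7 mantissa bits).
  uint16_t FloatToBFloat16Bits(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));  // reinterpret the float's bits
    return static_cast<uint16_t>(bits >> 16);
  }

  // Widen back: the dropped low 16 mantissa bits are simply zero.
  float BFloat16BitsToFloat(uint16_t b) {
    uint32_t bits = static_cast<uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
  }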
bfloat16.h ADDED
@@ -0,0 +1,62 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef TENSORFLOW_FRAMEWORK_BFLOAT16_H_
+ #define TENSORFLOW_FRAMEWORK_BFLOAT16_H_
+
+ #include "tensorflow/core/framework/numeric_types.h"
+ #include "tensorflow/core/platform/types.h"
+
+ #if defined(PLATFORM_WINDOWS)
+ #include "tensorflow/core/platform/windows/cpu_info.h"
+ #endif
+
+ // Compact 16-bit encoding of floating point numbers. This representation uses
+ // 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. It
+ // is assumed that floats are in IEEE 754 format, so the representation is just
+ // bits 16-31 of a single-precision float.
+ //
+ // NOTE: The IEEE floating point standard defines a float16 format that
+ // is different from this format (it has fewer bits of exponent and more
+ // bits of mantissa). We don't use that format here because conversion
+ // to/from 32-bit floats is more complex for that format, and the
+ // conversion for this format is very simple.
+ //
+ // Because of the existing IEEE float16 type, we do not name our representation
+ // "float16" but just use "uint16".
+ //
+ //     <-----our 16-bit float------->
+ // s e e e e e e e e f f f f f f f f f f f f f f f f f f f f f f f
+ // <------------------------------float-------------------------->
+ // 3 3             2 2             1 1                           0
+ // 1 0             3 2             5 4                           0
+ //
+ //
+ // This type only supports conversion back and forth with float.
+ //
+ // This file must be compilable by nvcc.
+ //
+ // The type is defined in framework/numeric_types.h.
+
+ namespace tensorflow {
+
+ // Conversion routines between an array of float and bfloat16 of
+ // "size".
+ void FloatToBFloat16(const float* src, bfloat16* dst, int64 size);
+ void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size);
+
+ }  // namespace tensorflow
+
+ #endif  // TENSORFLOW_FRAMEWORK_BFLOAT16_H_
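A worked instance of the layout above (also exercised by bfloat16_test.cc later in this upload): 12.0f has the single-precision bit pattern 0x41400000, i.e. sign 0, exponent 10000010 (130), mantissa 1000000...0. Its bfloat16 encoding is therefore the upper half, 0x4140, and converting back yields exactly 12.0f, since the truncated low 16 mantissa bits were all zero.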
bfloat16_test.cc ADDED
@@ -0,0 +1,158 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/bfloat16.h"
+
+ #include "tensorflow/core/framework/numeric_types.h"
+ #include "tensorflow/core/lib/core/casts.h"
+ #include "tensorflow/core/platform/test.h"
+ #include "tensorflow/core/platform/test_benchmark.h"
+
+ namespace tensorflow {
+ namespace {
+
+ TEST(Bfloat16Test, Simple) {
+   bfloat16 a(12);
+   // Floating point representation of 12: 0x41400000
+   EXPECT_EQ(0x4140, a.value);
+ }
+
+ float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
+                     uint32_t low_mantissa) {
+   return bit_cast<float>((sign << 31) + (exponent << 23) +
+                          (high_mantissa << 16) + low_mantissa);
+ }
+
+ struct Bfloat16TestParam {
+   float input;
+   float expected;
+ };
+
+ class Bfloat16Test : public ::testing::Test,
+                      public ::testing::WithParamInterface<Bfloat16TestParam> {};
+
+ TEST_P(Bfloat16Test, TruncateTest) {
+   bfloat16 a(GetParam().input);
+   if (std::isnan(GetParam().input)) {
+     EXPECT_TRUE(std::isnan(float(a)) || std::isinf(float(a)));
+     return;
+   }
+   EXPECT_EQ(GetParam().expected, float(a));
+ }
+
+ INSTANTIATE_TEST_CASE_P(
+     Bfloat16Test_Instantiation, Bfloat16Test,
+     ::testing::Values(
+         Bfloat16TestParam{
+             BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011),
+             BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+         Bfloat16TestParam{
+             BinaryToFloat(1, 0b10000000, 0b1001000, 0b1111010111000011),
+             BinaryToFloat(1, 0b10000000, 0b1001000, 0b0000000000000000)},
+         Bfloat16TestParam{
+             BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
+             BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+         Bfloat16TestParam{
+             BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000001),
+             BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000000)},
+         Bfloat16TestParam{
+             BinaryToFloat(0, 0b11111111, 0b1111111, 0b1111111111111111),
+             BinaryToFloat(0, 0b11111111, 0b1111111, 0b0000000000000000)},
+         Bfloat16TestParam{
+             BinaryToFloat(1, 0b10000000, 0b1001000, 0b1100000000000000),
+             BinaryToFloat(1, 0b10000000, 0b1001000, 0b0000000000000000)},
+         Bfloat16TestParam{
+             BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
+             BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+         Bfloat16TestParam{
+             BinaryToFloat(0, 0b10000000, 0b1001000, 0b0100000000000000),
+             BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+         Bfloat16TestParam{
+             BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
+             BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+         Bfloat16TestParam{
+             BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000),
+             BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)},
+         Bfloat16TestParam{
+             BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000),
+             BinaryToFloat(0, 0b00000000, 0b1111111, 0b0000000000000000)}));
+
+ TEST(Bfloat16Test, Conversion) {
+   float a[100];
+   for (int i = 0; i < 100; ++i) {
+     a[i] = i + 1.25;
+   }
+   bfloat16 b[100];
+   float c[100];
+   FloatToBFloat16(a, b, 100);
+   BFloat16ToFloat(b, c, 100);
+   for (int i = 0; i < 100; ++i) {
+     // The relative error should be less than 1/(2^7) since bfloat16
+     // has a 7-bit mantissa.
+     EXPECT_LE(fabs(c[i] - a[i]) / a[i], 1.0 / 128);
+   }
+ }
+
+ TEST(Bfloat16Test, Epsilon) {
+   EXPECT_LT(1.0f, static_cast<float>(bfloat16::epsilon() + bfloat16(1.0f)));
+   EXPECT_EQ(1.0f, static_cast<float>((bfloat16::epsilon() / bfloat16(2.0f)) +
+                                      bfloat16(1.0f)));
+ }
+
+ TEST(Bfloat16Test, Negate) {
+   EXPECT_EQ(-3.0f, static_cast<float>(-bfloat16(3.0f)));
+   EXPECT_EQ(4.5f, static_cast<float>(-bfloat16(-4.5f)));
+ }
+
+ static void BM_FloatToBFloat16(int iters) {
+   testing::StopTiming();
+   static const int N = 32 << 20;
+   const int64 tot = static_cast<int64>(iters) * N;
+   testing::ItemsProcessed(tot);
+   testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
+
+   float* inp = new float[N];
+   bfloat16* out = new bfloat16[N];
+
+   testing::StartTiming();
+   while (iters--) {
+     FloatToBFloat16(inp, out, N);
+   }
+   delete[] inp;
+   delete[] out;
+ }
+ BENCHMARK(BM_FloatToBFloat16);
+
+ static void BM_BFloat16ToFloat(int iters) {
+   testing::StopTiming();
+   static const int N = 32 << 20;
+   const int64 tot = static_cast<int64>(iters) * N;
+   testing::ItemsProcessed(tot);
+   testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
+
+   bfloat16* inp = new bfloat16[N];
+   float* out = new float[N];
+
+   testing::StartTiming();
+   while (iters--) {
+     BFloat16ToFloat(inp, out, N);
+   }
+   delete[] inp;
+   delete[] out;
+ }
+ BENCHMARK(BM_BFloat16ToFloat);
+
+ }  // namespace
+ }  // namespace tensorflow
cancellation.cc ADDED
@@ -0,0 +1,94 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/cancellation.h"
+
+ #include "tensorflow/core/lib/core/errors.h"
+ #include "tensorflow/core/platform/logging.h"
+
+ namespace tensorflow {
+
+ const CancellationToken CancellationManager::kInvalidToken = -1;
+
+ CancellationManager::CancellationManager()
+     : is_cancelling_(false),
+       is_cancelled_(false),
+       next_cancellation_token_(0) {}
+
+ void CancellationManager::StartCancel() {
+   gtl::FlatMap<CancellationToken, CancelCallback> callbacks_to_run;
+   {
+     mutex_lock l(mu_);
+     if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
+       return;
+     }
+     is_cancelling_ = true;
+     std::swap(callbacks_, callbacks_to_run);
+   }
+   // We call these callbacks without holding mu_, so that concurrent
+   // calls to DeregisterCallback, which can happen asynchronously, do
+   // not block. The callbacks remain valid because any concurrent call
+   // to DeregisterCallback will block until
+   // cancelled_notification_ is notified.
+   for (auto key_and_value : callbacks_to_run) {
+     key_and_value.second();
+   }
+   {
+     mutex_lock l(mu_);
+     is_cancelling_ = false;
+     is_cancelled_.store(true, std::memory_order_release);
+   }
+   cancelled_notification_.Notify();
+ }
+
+ CancellationToken CancellationManager::get_cancellation_token() {
+   mutex_lock l(mu_);
+   return next_cancellation_token_++;
+ }
+
+ bool CancellationManager::RegisterCallback(CancellationToken token,
+                                            CancelCallback callback) {
+   mutex_lock l(mu_);
+   CHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token";
+   bool should_register = !is_cancelled_ && !is_cancelling_;
+   if (should_register) {
+     std::swap(callbacks_[token], callback);
+   }
+   return should_register;
+ }
+
+ bool CancellationManager::DeregisterCallback(CancellationToken token) {
+   mu_.lock();
+   if (is_cancelled_) {
+     mu_.unlock();
+     return false;
+   } else if (is_cancelling_) {
+     mu_.unlock();
+     // Wait for all of the cancellation callbacks to be called. This
+     // wait ensures that the caller of DeregisterCallback does not
+     // return immediately and free objects that may be used in the
+     // execution of any currently pending callbacks in StartCancel.
+     cancelled_notification_.WaitForNotification();
+     return false;
+   } else {
+     callbacks_.erase(token);
+     mu_.unlock();
+     return true;
+   }
+ }
+
+ CancellationManager::~CancellationManager() { StartCancel(); }
+
+ }  // end namespace tensorflow
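A minimal end-to-end sketch of the lifecycle implemented above (standalone and illustrative, not part of the upload):

  CancellationManager cm;
  CancellationToken token = cm.get_cancellation_token();
  bool fired = false;
  if (cm.RegisterCallback(token, [&fired]() { fired = true; })) {
    cm.StartCancel();  // runs the callback synchronously; fired is now true
  } else {
    // The manager was already cancelled before registration; per the
    // contract documented in cancellation.h, the caller must do its own
    // cancellation cleanup here.
  }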
cancellation.h ADDED
@@ -0,0 +1,137 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_
+ #define TENSORFLOW_FRAMEWORK_CANCELLATION_H_
+
+ #include <atomic>
+ #include <functional>
+
+ #include "tensorflow/core/lib/core/notification.h"
+ #include "tensorflow/core/lib/core/status.h"
+ #include "tensorflow/core/lib/gtl/flatmap.h"
+ #include "tensorflow/core/lib/hash/hash.h"
+ #include "tensorflow/core/platform/mutex.h"
+ #include "tensorflow/core/platform/thread_annotations.h"
+ #include "tensorflow/core/platform/types.h"
+
+ namespace tensorflow {
+
+ // A token that can be used to register and deregister a
+ // CancelCallback with a CancellationManager.
+ //
+ // CancellationToken values must be created by a call to
+ // CancellationManager::get_cancellation_token.
+ typedef int64 CancellationToken;
+
+ // A callback that is invoked when a step is canceled.
+ //
+ // NOTE(mrry): See caveats about CancelCallback implementations in the
+ // comment for CancellationManager::RegisterCallback.
+ typedef std::function<void()> CancelCallback;
+
+ class CancellationManager {
+  public:
+   // A value that won't be returned by get_cancellation_token().
+   static const CancellationToken kInvalidToken;
+
+   CancellationManager();
+   ~CancellationManager();
+
+   // Run all callbacks associated with this manager.
+   void StartCancel();
+
+   // Returns true iff StartCancel() has been called.
+   bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
+
+   // Returns a token that must be used in calls to RegisterCallback
+   // and DeregisterCallback.
+   CancellationToken get_cancellation_token();
+
+   // Attempts to register the given callback to be invoked when this
+   // manager is cancelled. Returns true if the callback was
+   // registered; returns false if this manager was already cancelled,
+   // in which case the callback was not registered.
+   //
+   // If this method returns false, it is the caller's responsibility
+   // to perform any cancellation cleanup.
+   //
+   // This method is tricky to use correctly. The following usage pattern
+   // is recommended:
+   //
+   // class ObjectWithCancellableOperation {
+   //   mutex mu_;
+   //   void CancellableOperation(CancellationManager* cm,
+   //                             std::function<void(Status)> callback) {
+   //     bool already_cancelled;
+   //     CancellationToken token = cm->get_cancellation_token();
+   //     {
+   //       mutex_lock l(mu_);
+   //       already_cancelled = !cm->RegisterCallback(
+   //           token, [this, token]() { Cancel(token); });
+   //       if (!already_cancelled) {
+   //         // Issue asynchronous operation. Associate the pending operation
+   //         // with `token` in some object state, or provide another way for
+   //         // the Cancel method to look up the operation for cancellation.
+   //         // Ensure that `cm->DeregisterCallback(token)` is called without
+   //         // holding `mu_`, before `callback` is invoked.
+   //         // ...
+   //       }
+   //     }
+   //     if (already_cancelled) {
+   //       callback(errors::Cancelled("Operation was cancelled"));
+   //     }
+   //   }
+   //
+   //   void Cancel(CancellationToken token) {
+   //     mutex_lock l(mu_);
+   //     // Take action to cancel the operation with the given cancellation
+   //     // token.
+   //   }
+   // };
+   //
+   // NOTE(mrry): The caller should take care that (i) the calling code
+   // is robust to `callback` being invoked asynchronously (e.g. from
+   // another thread), (ii) `callback` is deregistered by a call to
+   // this->DeregisterCallback(token) when the operation completes
+   // successfully, and (iii) `callback` does not invoke any method
+   // on this cancellation manager. Furthermore, it is important that
+   // the eventual caller of the complementary DeregisterCallback does not
+   // hold any mutexes that are required by `callback`.
+   bool RegisterCallback(CancellationToken token, CancelCallback callback);
+
+   // Deregister the callback that, when registered, was associated
+   // with the given cancellation token. Returns true iff the callback
+   // was deregistered and will not be invoked; otherwise returns false
+   // after the callback has been invoked, blocking if necessary.
+   //
+   // NOTE(mrry): This method may block if cancellation is in progress.
+   // The caller of this method must not hold any mutexes that are required
+   // to invoke any cancellation callback that has been registered with this
+   // cancellation manager.
+   bool DeregisterCallback(CancellationToken token);
+
+  private:
+   bool is_cancelling_;
+   std::atomic_bool is_cancelled_;
+
+   mutex mu_;
+   Notification cancelled_notification_;
+   CancellationToken next_cancellation_token_ GUARDED_BY(mu_);
+   gtl::FlatMap<CancellationToken, CancelCallback> callbacks_ GUARDED_BY(mu_);
+ };
+
+ }  // namespace tensorflow
+
+ #endif  // TENSORFLOW_FRAMEWORK_CANCELLATION_H_
cancellation_test.cc ADDED
@@ -0,0 +1,118 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/cancellation.h"
+
+ #include <vector>
+ #include "tensorflow/core/lib/core/notification.h"
+ #include "tensorflow/core/lib/core/threadpool.h"
+ #include "tensorflow/core/platform/test.h"
+
+ namespace tensorflow {
+
+ TEST(Cancellation, SimpleNoCancel) {
+   bool is_cancelled = false;
+   CancellationManager* manager = new CancellationManager();
+   auto token = manager->get_cancellation_token();
+   bool registered = manager->RegisterCallback(
+       token, [&is_cancelled]() { is_cancelled = true; });
+   EXPECT_TRUE(registered);
+   bool deregistered = manager->DeregisterCallback(token);
+   EXPECT_TRUE(deregistered);
+   delete manager;
+   EXPECT_FALSE(is_cancelled);
+ }
+
+ TEST(Cancellation, SimpleCancel) {
+   bool is_cancelled = false;
+   CancellationManager* manager = new CancellationManager();
+   auto token = manager->get_cancellation_token();
+   bool registered = manager->RegisterCallback(
+       token, [&is_cancelled]() { is_cancelled = true; });
+   EXPECT_TRUE(registered);
+   manager->StartCancel();
+   EXPECT_TRUE(is_cancelled);
+   delete manager;
+ }
+
+ TEST(Cancellation, CancelBeforeRegister) {
+   CancellationManager* manager = new CancellationManager();
+   auto token = manager->get_cancellation_token();
+   manager->StartCancel();
+   bool registered = manager->RegisterCallback(token, nullptr);
+   EXPECT_FALSE(registered);
+   delete manager;
+ }
+
+ TEST(Cancellation, DeregisterAfterCancel) {
+   bool is_cancelled = false;
+   CancellationManager* manager = new CancellationManager();
+   auto token = manager->get_cancellation_token();
+   bool registered = manager->RegisterCallback(
+       token, [&is_cancelled]() { is_cancelled = true; });
+   EXPECT_TRUE(registered);
+   manager->StartCancel();
+   EXPECT_TRUE(is_cancelled);
+   bool deregistered = manager->DeregisterCallback(token);
+   EXPECT_FALSE(deregistered);
+   delete manager;
+ }
+
+ TEST(Cancellation, CancelMultiple) {
+   bool is_cancelled_1 = false, is_cancelled_2 = false, is_cancelled_3 = false;
+   CancellationManager* manager = new CancellationManager();
+   auto token_1 = manager->get_cancellation_token();
+   bool registered_1 = manager->RegisterCallback(
+       token_1, [&is_cancelled_1]() { is_cancelled_1 = true; });
+   EXPECT_TRUE(registered_1);
+   auto token_2 = manager->get_cancellation_token();
+   bool registered_2 = manager->RegisterCallback(
+       token_2, [&is_cancelled_2]() { is_cancelled_2 = true; });
+   EXPECT_TRUE(registered_2);
+   EXPECT_FALSE(is_cancelled_1);
+   EXPECT_FALSE(is_cancelled_2);
+   manager->StartCancel();
+   EXPECT_TRUE(is_cancelled_1);
+   EXPECT_TRUE(is_cancelled_2);
+   EXPECT_FALSE(is_cancelled_3);
+   auto token_3 = manager->get_cancellation_token();
+   bool registered_3 = manager->RegisterCallback(
+       token_3, [&is_cancelled_3]() { is_cancelled_3 = true; });
+   EXPECT_FALSE(registered_3);
+   EXPECT_FALSE(is_cancelled_3);
+   delete manager;
+ }
+
+ TEST(Cancellation, IsCancelled) {
+   CancellationManager* cm = new CancellationManager();
+   thread::ThreadPool w(Env::Default(), "test", 4);
+   std::vector<Notification> done(8);
+   for (size_t i = 0; i < done.size(); ++i) {
+     Notification* n = &done[i];
+     w.Schedule([n, cm]() {
+       while (!cm->IsCancelled()) {
+       }
+       n->Notify();
+     });
+   }
+   Env::Default()->SleepForMicroseconds(1000000 /* 1 second */);
+   cm->StartCancel();
+   for (size_t i = 0; i < done.size(); ++i) {
+     done[i].WaitForNotification();
+   }
+   delete cm;
+ }
+
+ }  // namespace tensorflow
common_shape_fns.cc ADDED
@@ -0,0 +1,1399 @@
+ /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+ #include "tensorflow/core/framework/common_shape_fns.h"
+ #include "tensorflow/core/framework/attr_value.pb.h"
+
+ namespace tensorflow {
+
+ Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
+                                       int64 dilation_rate, int64 stride,
+                                       Padding padding_type, int64* output_size,
+                                       int64* padding_before,
+                                       int64* padding_after) {
+   if (stride <= 0) {
+     return errors::InvalidArgument("Stride must be > 0, but got ", stride);
+   }
+   if (dilation_rate < 1) {
+     return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
+                                    dilation_rate);
+   }
+
+   // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2.
+   int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1;
+   switch (padding_type) {
+     case Padding::VALID:
+       *output_size = (input_size - effective_filter_size + stride) / stride;
+       *padding_before = *padding_after = 0;
+       break;
+     case Padding::SAME:
+       *output_size = (input_size + stride - 1) / stride;
+       const int64 padding_needed =
+           std::max(0LL, (*output_size - 1) * stride + effective_filter_size -
+                             input_size);
+       // For odd values of total padding, add more padding at the 'right'
+       // side of the given dimension.
+       *padding_before = padding_needed / 2;
+       *padding_after = padding_needed - *padding_before;
+       break;
+   }
+   if (*output_size < 0) {
+     return errors::InvalidArgument("computed output size would be negative");
+   }
+   return Status::OK();
+ }
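To make the two padding branches concrete, a worked instance (illustrative numbers; with dilation 1 the effective filter size equals the filter size): for input_size = 10, filter_size = 3, stride = 2, VALID gives *output_size = (10 - 3 + 2) / 2 = 4 with zero padding, while SAME gives *output_size = (10 + 2 - 1) / 2 = 5 and padding_needed = max(0, (5 - 1) * 2 + 3 - 10) = 1, split as padding_before = 0 and padding_after = 1 (the odd leftover goes on the right).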
+
+ Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
+                                     int64 stride, Padding padding_type,
+                                     int64* output_size, int64* padding_before,
+                                     int64* padding_after) {
+   return GetWindowedOutputSizeVerboseV2(input_size, filter_size,
+                                         /*dilation_rate=*/1, stride,
+                                         padding_type, output_size,
+                                         padding_before, padding_after);
+ }
+
+ Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
+                              Padding padding_type, int64* output_size,
+                              int64* padding_size) {
+   int64 padding_after_unused;
+   return GetWindowedOutputSizeVerbose(input_size, filter_size, stride,
+                                       padding_type, output_size, padding_size,
+                                       &padding_after_unused);
+ }
+
+ Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
+                                int64 dilation_rate, int64 stride,
+                                Padding padding_type, int64* output_size,
+                                int64* padding_size) {
+   int64 padding_after_unused;
+   return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate,
+                                         stride, padding_type, output_size,
+                                         padding_size, &padding_after_unused);
+ }
+
+ Status Get3dOutputSize(const std::array<int64, 3>& input,
+                        const std::array<int64, 3>& window,
+                        const std::array<int64, 3>& strides,
+                        Padding padding_type, std::array<int64, 3>* output_ptr,
+                        std::array<int64, 3>* padding_ptr) {
+   for (size_t i = 0; i < input.size(); ++i) {
+     TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i],
+                                              padding_type, &(*output_ptr)[i],
+                                              &(*padding_ptr)[i]));
+   }
+   return Status::OK();
+ }
+
+ Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
+                          const std::array<int64, 3>& window,
+                          const std::array<int64, 3>& dilations,
+                          const std::array<int64, 3>& strides,
+                          Padding padding_type, std::array<int64, 3>* output_ptr,
+                          std::array<int64, 3>* padding_ptr) {
+   for (size_t i = 0; i < input.size(); ++i) {
+     TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+         input[i], window[i], dilations[i], strides[i], padding_type,
+         &(*output_ptr)[i], &(*padding_ptr)[i]));
+   }
+   return Status::OK();
+ }
+
+ namespace shape_inference {
+
+ // The V2 version computes the windowed output size with an arbitrary
+ // dilation_rate, while the original version only handles the case where the
+ // dilation_rate equals 1.
+ Status GetWindowedOutputSizeFromDimsV2(
+     shape_inference::InferenceContext* c,
+     shape_inference::DimensionHandle input_size,
+     shape_inference::DimensionOrConstant filter_size, int64 dilation_rate,
+     int64 stride, Padding padding_type,
+     shape_inference::DimensionHandle* output_size) {
+   if (stride <= 0) {
+     return errors::InvalidArgument("Stride must be > 0, but got ", stride);
+   }
+
+   if (dilation_rate < 1) {
+     return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
+                                    dilation_rate);
+   }
+
+   // See also the parallel implementation in GetWindowedOutputSizeVerbose.
+   switch (padding_type) {
+     case Padding::VALID:
+       if (dilation_rate > 1) {
+         DimensionHandle window_size;
+         TF_RETURN_IF_ERROR(
+             c->Subtract(c->MakeDim(filter_size), 1, &window_size));
+         TF_RETURN_IF_ERROR(
+             c->Multiply(window_size, dilation_rate, &window_size));
+         TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
+         TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
+       } else {
+         TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
+       }
+       TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
+       TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
+                                    /*evenly_divisible=*/false, output_size));
+       break;
+     case Padding::SAME:
+       TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
+       TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
+                                    /*evenly_divisible=*/false, output_size));
+       break;
+   }
+   return Status::OK();
+ }
+
+ Status GetWindowedOutputSizeFromDims(
+     shape_inference::InferenceContext* c,
+     shape_inference::DimensionHandle input_size,
+     shape_inference::DimensionOrConstant filter_size, int64 stride,
+     Padding padding_type, shape_inference::DimensionHandle* output_size) {
+   return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
+                                          /*dilation_rate=*/1, stride,
+                                          padding_type, output_size);
+ }
+
+ Status UnchangedShape(shape_inference::InferenceContext* c) {
+   c->set_output(0, c->input(0));
+   return Status::OK();
+ }
+
+ Status MatMulShape(shape_inference::InferenceContext* c) {
+   ShapeHandle a;
+   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
+
+   ShapeHandle b;
+   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
+
+   bool transpose_a, transpose_b;
+   TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
+   TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
+   DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
+   DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);
+
+   // Validate that the inner shapes are compatible.
+   DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
+   DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
+   DimensionHandle merged;
+   TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));
+
+   c->set_output(0, c->Matrix(output_rows, output_cols));
+   return Status::OK();
+ }
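A worked instance of MatMulShape (illustrative shapes): with a = [4, 3], b = [3, 7], and both transpose attrs false, output_rows = 4, output_cols = 7, and Merge(3, 3) succeeds, so the output shape is [4, 7]. With b = [5, 7] instead, Merge(3, 5) fails and the shape function returns that error to the caller; an unknown inner dimension, by contrast, merges successfully with any value.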
197
+
198
+ Status BiasAddShape(shape_inference::InferenceContext* c) {
199
+ ShapeHandle input_shape;
200
+
201
+ // Fetch the data_format attribute, which may not exist.
202
+ string data_format;
203
+ Status s = c->GetAttr("data_format", &data_format);
204
+
205
+ if (s.ok() && data_format == "NCHW") {
206
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
207
+ } else {
208
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
209
+ }
210
+
211
+ ShapeHandle bias_shape;
212
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
213
+ DimensionHandle bias_dim = c->Dim(bias_shape, 0);
214
+
215
+ // If rank unknown, return unknown shape.
216
+ if (!c->RankKnown(input_shape)) {
217
+ c->set_output(0, c->UnknownShape());
218
+ return Status::OK();
219
+ }
220
+
221
+ // Output has the same shape as the input, and matches the length of
222
+ // the bias in its bias dimension.
223
+ ShapeHandle output_shape;
224
+ if (s.ok() && data_format == "NCHW") {
225
+ // Merge the length of bias_shape into the third to last dimension
226
+ ShapeHandle first;
227
+ TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -3, &first));
228
+
229
+ ShapeHandle last;
230
+ TF_RETURN_IF_ERROR(c->Subshape(input_shape, -2, &last));
231
+
232
+ DimensionHandle input_bias_dim = c->Dim(input_shape, -3);
233
+ DimensionHandle merged_bias_dim;
234
+ TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
235
+ ShapeHandle merged_bias = c->Vector(merged_bias_dim);
236
+
237
+ ShapeHandle temp;
238
+ TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
239
+ TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
240
+ } else {
241
+ ShapeHandle all_but_bias;
242
+ TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));
243
+
244
+ DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
245
+ DimensionHandle merged_bias_dim;
246
+ TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
247
+
248
+ ShapeHandle merged_bias = c->Vector(merged_bias_dim);
249
+ TF_RETURN_IF_ERROR(
250
+ c->Concatenate(all_but_bias, merged_bias, &output_shape));
251
+ }
252
+
253
+ c->set_output(0, output_shape);
254
+ return Status::OK();
255
+ }
256
+
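// [Editor's note] Illustrative sketch, not part of the original file: the
// BiasAddShape channel check above on fully known static shapes. The bias
// length is merged into the last dim by default, or the third-from-last dim
// for NCHW (assumes rank >= 3 for NCHW). The helper name is hypothetical.
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> BiasAddOutSketch(std::vector<int64_t> input,
                                      int64_t bias_len, bool nchw) {
  const size_t channel = nchw ? input.size() - 3 : input.size() - 1;
  if (input[channel] != bias_len)  // the Merge() above would fail
    throw std::invalid_argument("bias length does not match channel dim");
  return input;  // output shape equals the input shape
}
// BiasAddOutSketch({2, 16, 7, 7}, 16, /*nchw=*/true) -> {2, 16, 7, 7}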
257
+ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
258
+ ShapeHandle input_shape;
259
+ // Fetch the data_format attribute, which may not exist.
260
+ string data_format;
261
+ Status s = c->GetAttr("data_format", &data_format);
262
+
263
+ if (s.ok() && data_format == "NCHW") {
264
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
265
+ c->set_output(0, c->Vector(c->Dim(input_shape, -3)));
266
+ } else {
267
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
268
+ c->set_output(0, c->Vector(c->Dim(input_shape, -1)));
269
+ }
270
+
271
+ return Status::OK();
272
+ }
273
+
274
+ Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
275
+ const ShapeHandle shape_handle,
276
+ const string& tensor_name,
277
+ shape_inference::InferenceContext* c) {
278
+ if (tensor_format == FORMAT_NCHW_VECT_C) {
279
+ // Check that the vect dim has size 4.
280
+ const int num_dims = c->Rank(shape_handle);
281
+ DimensionHandle vect_dim = c->Dim(
282
+ shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format));
283
+ DimensionHandle unused_vect_dim;
284
+ TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim));
285
+ }
286
+
287
+ return Status::OK();
288
+ }
289
+
290
+ Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
291
+ const std::vector<DimensionOrConstant>& spatial,
292
+ DimensionOrConstant C, ShapeHandle* out,
293
+ shape_inference::InferenceContext* context) {
294
+ const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
295
+ std::vector<DimensionHandle> dims_actual(num_dims);
296
+ dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N);
297
+ int outer_c_index = GetTensorFeatureDimIndex(num_dims, format);
298
+ dims_actual[outer_c_index] = context->MakeDim(C);
299
+ if (format == FORMAT_NCHW_VECT_C) {
300
+ dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
301
+ context->MakeDim(4);
302
+ }
303
+ for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) {
304
+ dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
305
+ context->MakeDim(spatial[spatial_dim]);
306
+ }
307
+ *out = context->MakeShape(dims_actual);
308
+ return Status::OK();
309
+ }
310
+
311
+ Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
312
+ DimensionHandle* batch_dim,
313
+ gtl::MutableArraySlice<DimensionHandle> spatial_dims,
314
+ DimensionHandle* filter_dim,
315
+ InferenceContext* context) {
316
+ const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
317
+ // Batch.
318
+ *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
319
+ // Spatial.
320
+ for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
321
+ ++spatial_dim_index) {
322
+ spatial_dims[spatial_dim_index] = context->Dim(
323
+ shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
324
+ }
325
+ // Channel.
326
+ *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
327
+ if (format == FORMAT_NCHW_VECT_C) {
328
+ TF_RETURN_IF_ERROR(context->Multiply(
329
+ *filter_dim,
330
+ context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
331
+ filter_dim));
332
+ }
333
+ return Status::OK();
334
+ }
335
+
336
+ Status ShapeFromDimensions(DimensionHandle batch_dim,
337
+ gtl::ArraySlice<DimensionHandle> spatial_dims,
338
+ DimensionHandle filter_dim, TensorFormat format,
339
+ InferenceContext* context, ShapeHandle* shape) {
340
+ const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
341
+ std::vector<DimensionHandle> out_dims(rank);
342
+
343
+ // Batch.
344
+ out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
345
+ // Spatial.
346
+ for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
347
+ ++spatial_dim_index) {
348
+ out_dims[tensorflow::GetTensorSpatialDimIndex(
349
+ rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
350
+ }
351
+ // Channel.
352
+ if (format == tensorflow::FORMAT_NCHW_VECT_C) {
353
+ // When format is NCHW_VECT_C, factor the feature map count
354
+ // into the outer feature count and the inner feature count (=4).
355
+ TF_RETURN_IF_ERROR(context->Divide(
356
+ filter_dim, 4, /*evenly_divisible=*/true,
357
+ &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
358
+ out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4);
359
+ } else {
360
+ out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
361
+ }
362
+
363
+ *shape = context->MakeShape(out_dims);
364
+ return tensorflow::Status::OK();
365
+ }
366
+
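// [Editor's note] Illustrative sketch, not part of the original file:
// NCHW_VECT_C stores the channel count factored as (C / 4) outer channels
// times an inner vector of 4, which is what the evenly-divisible Divide()
// above computes. The helper name is hypothetical.
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> ToNchwVectCSketch(int64_t n, int64_t h, int64_t w,
                                       int64_t c) {
  if (c % 4 != 0)  // Divide(..., /*evenly_divisible=*/true) would fail.
    throw std::invalid_argument("channel count must be divisible by 4");
  return {n, c / 4, h, w, 4};
}
// ToNchwVectCSketch(8, 28, 28, 32) -> {8, 8, 28, 28, 4}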
367
+ Status Conv2DShape(shape_inference::InferenceContext* c) {
368
+ string data_format_str, filter_format_str;
369
+ if (!c->GetAttr("data_format", &data_format_str).ok()) {
370
+ data_format_str = "NHWC";
371
+ }
372
+ if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
373
+ filter_format_str = "HWIO";
374
+ }
375
+
376
+ TensorFormat data_format;
377
+ if (!FormatFromString(data_format_str, &data_format)) {
378
+ return errors::InvalidArgument("Invalid data format string: ",
379
+ data_format_str);
380
+ }
381
+ FilterTensorFormat filter_format;
382
+ if (!FilterFormatFromString(filter_format_str, &filter_format)) {
383
+ return errors::InvalidArgument("Invalid filter format string: ",
384
+ filter_format_str);
385
+ }
386
+
387
+ constexpr int num_spatial_dims = 2;
388
+ const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
389
+ ShapeHandle conv_input_shape;
390
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
391
+ TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
392
+ data_format, conv_input_shape, "conv_input", c));
393
+
394
+ // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
395
+ ShapeHandle filter_shape;
396
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
397
+ TF_RETURN_IF_ERROR(
398
+ CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));
399
+
400
+ std::vector<int32> dilations;
401
+ TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
402
+
403
+ if (dilations.size() != 4) {
404
+ return errors::InvalidArgument(
405
+ "Conv2D requires the dilation attribute to contain 4 values, but got: ",
406
+ dilations.size());
407
+ }
408
+
409
+ std::vector<int32> strides;
410
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
411
+
412
+ // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
413
+ if (strides.size() != 4) {
414
+ return errors::InvalidArgument("Conv2D on data format ", data_format_str,
415
+ " requires the stride attribute to contain"
416
+ " 4 values, but got: ",
417
+ strides.size());
418
+ }
419
+
420
+ const int32 stride_rows = GetTensorDim(strides, data_format, 'H');
421
+ const int32 stride_cols = GetTensorDim(strides, data_format, 'W');
422
+ const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H');
423
+ const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W');
424
+
425
+ DimensionHandle batch_size_dim;
426
+ DimensionHandle input_depth_dim;
427
+ gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
428
+ TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format,
429
+ &batch_size_dim, &input_spatial_dims,
430
+ &input_depth_dim, c));
431
+
432
+ DimensionHandle output_depth_dim = c->Dim(
433
+ filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
434
+ DimensionHandle filter_rows_dim = c->Dim(
435
+ filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
436
+ DimensionHandle filter_cols_dim = c->Dim(
437
+ filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
438
+ DimensionHandle filter_input_depth_dim;
439
+ if (filter_format == FORMAT_OIHW_VECT_I) {
440
+ TF_RETURN_IF_ERROR(c->Multiply(
441
+ c->Dim(filter_shape,
442
+ GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
443
+ c->Dim(filter_shape,
444
+ GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
445
+ &filter_input_depth_dim));
446
+ } else {
447
+ filter_input_depth_dim = c->Dim(
448
+ filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
449
+ }
450
+
451
+ // Check that the input tensor and the filter tensor agree on the input
452
+ // channel count.
453
+ DimensionHandle unused;
454
+ TF_RETURN_IF_ERROR(
455
+ c->Merge(input_depth_dim, filter_input_depth_dim, &unused));
456
+
457
+ Padding padding;
458
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
459
+
460
+ DimensionHandle output_rows, output_cols;
461
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
462
+ c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
463
+ padding, &output_rows));
464
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
465
+ c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
466
+ padding, &output_cols));
467
+
468
+ ShapeHandle output_shape;
469
+ TF_RETURN_IF_ERROR(
470
+ ShapeFromDimensions(batch_size_dim, {output_rows, output_cols},
471
+ output_depth_dim, data_format, c, &output_shape));
472
+ c->set_output(0, output_shape);
473
+ return Status::OK();
474
+ }
475
+
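// [Editor's note] Illustrative sketch, not part of the original file: the
// end-to-end effect of Conv2DShape for the default NHWC input / HWIO filter
// layout with VALID padding and dilation 1. The helper name is hypothetical.
#include <array>
#include <cstdint>

std::array<int64_t, 4> Conv2DOutSketch(std::array<int64_t, 4> in,    // {N,H,W,C}
                                       std::array<int64_t, 4> filt,  // {fH,fW,I,O}
                                       int64_t stride_h, int64_t stride_w) {
  // in[3] must equal filt[2]; Conv2DShape enforces this with Merge().
  const int64_t out_h = (in[1] - filt[0] + stride_h) / stride_h;
  const int64_t out_w = (in[2] - filt[1] + stride_w) / stride_w;
  return {in[0], out_h, out_w, filt[3]};
}
// Conv2DOutSketch({1, 28, 28, 3}, {5, 5, 3, 32}, 1, 1) -> {1, 24, 24, 32}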
476
+ // TODO(mjanusz): Unify all conv/pooling shape functions.
477
+ Status Conv3DShape(shape_inference::InferenceContext* c) {
478
+ ShapeHandle input_shape;
479
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
480
+ ShapeHandle filter_shape;
481
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));
482
+
483
+ string data_format;
484
+ Status s = c->GetAttr("data_format", &data_format);
485
+
486
+ std::vector<int32> strides;
487
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
488
+ if (strides.size() != 5) {
489
+ return errors::InvalidArgument(
490
+ "Conv3D requires the stride attribute to contain 5 values, but got: ",
491
+ strides.size());
492
+ }
493
+
494
+ int32 stride_planes, stride_rows, stride_cols;
495
+ if (s.ok() && data_format == "NCDHW") {
496
+ // Convert input_shape to NDHWC.
497
+ auto dim = [&](char dimension) {
498
+ return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
499
+ };
500
+ input_shape =
501
+ c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
502
+ stride_planes = strides[2];
503
+ stride_rows = strides[3];
504
+ stride_cols = strides[4];
505
+ } else {
506
+ stride_planes = strides[1];
507
+ stride_rows = strides[2];
508
+ stride_cols = strides[3];
509
+ }
510
+
511
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
512
+ DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
513
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
514
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
515
+
516
+ DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
517
+ DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
518
+ DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
519
+ DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);
520
+
521
+ DimensionHandle unused;
522
+ TF_RETURN_IF_ERROR(
523
+ c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused));
524
+
525
+ Padding padding;
526
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
527
+ DimensionHandle output_planes, output_rows, output_cols;
528
+
529
+ TF_RETURN_IF_ERROR(
530
+ GetWindowedOutputSizeFromDims(c, in_planes_dim, filter_planes_dim,
531
+ stride_planes, padding, &output_planes));
532
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
533
+ c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
534
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
535
+ c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
536
+
537
+ ShapeHandle output_shape;
538
+ if (data_format == "NCDHW") {
539
+ output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
540
+ output_planes, output_rows, output_cols});
541
+ } else {
542
+ output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
543
+ output_cols, output_depth_dim});
544
+ }
545
+ c->set_output(0, output_shape);
546
+ return Status::OK();
547
+ }
548
+
549
+ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
550
+ ShapeHandle input_shape;
551
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
552
+ ShapeHandle filter_shape;
553
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
554
+
555
+ std::vector<int32> strides;
556
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
557
+
558
+ if (strides.size() != 4) {
559
+ return errors::InvalidArgument(
560
+ "DepthwiseConv2D requires the stride attribute to contain 4 values, "
561
+ "but got: ",
562
+ strides.size());
563
+ }
564
+
565
+ string data_format;
566
+ Status s = c->GetAttr("data_format", &data_format);
567
+ int32 stride_rows;
568
+ int32 stride_cols;
569
+ if (s.ok() && data_format == "NCHW") {
570
+ // Canonicalize input shape to NHWC so the shape inference code below can
571
+ // process it.
572
+ input_shape =
573
+ c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
574
+ c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
575
+ stride_rows = strides[2];
576
+ stride_cols = strides[3];
577
+ } else {
578
+ stride_rows = strides[1];
579
+ stride_cols = strides[2];
580
+ }
581
+
582
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
583
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
584
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
585
+
586
+ DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
587
+ DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
588
+ DimensionHandle input_depth = c->Dim(filter_shape, 2);
589
+ DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);
590
+
591
+ // Check that the input depths are compatible.
592
+ TF_RETURN_IF_ERROR(
593
+ c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));
594
+
595
+ DimensionHandle output_depth;
596
+ TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));
597
+
598
+ Padding padding;
599
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
600
+
601
+ // TODO(mrry,shlens): Raise an error if the stride would cause
602
+ // information in the input to be ignored. This will require a change
603
+ // in the kernel implementation.
604
+ DimensionHandle output_rows, output_cols;
605
+
606
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
607
+ c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
608
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
609
+ c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
610
+
611
+ ShapeHandle output_shape;
612
+ if (data_format == "NCHW") {
613
+ output_shape =
614
+ c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
615
+ } else {
616
+ output_shape =
617
+ c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
618
+ }
619
+ c->set_output(0, output_shape);
620
+ return Status::OK();
621
+ }
622
+
623
+ Status AvgPoolShape(shape_inference::InferenceContext* c) {
624
+ string data_format_str;
625
+ TensorFormat data_format;
626
+ Status s = c->GetAttr("data_format", &data_format_str);
627
+ if (s.ok()) {
628
+ FormatFromString(data_format_str, &data_format);
629
+ } else {
630
+ data_format = FORMAT_NHWC;
631
+ }
632
+
633
+ const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
634
+ ShapeHandle input_shape;
635
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
636
+
637
+ TF_RETURN_IF_ERROR(
638
+ CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
639
+
640
+ std::vector<int32> strides;
641
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
642
+ if (strides.size() != 4) {
643
+ return errors::InvalidArgument(
644
+ "AvgPool requires the stride attribute to contain 4 values, but got: ",
645
+ strides.size());
646
+ }
647
+
648
+ std::vector<int32> kernel_sizes;
649
+ TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
650
+ if (kernel_sizes.size() != 4) {
651
+ return errors::InvalidArgument(
652
+ "AvgPool requires the ksize attribute to contain 4 values, but got: ",
653
+ kernel_sizes.size());
654
+ }
655
+
656
+ int32 stride_rows = GetTensorDim(strides, data_format, 'H');
657
+ int32 stride_cols = GetTensorDim(strides, data_format, 'W');
658
+ int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
659
+ int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
660
+
661
+ constexpr int num_spatial_dims = 2;
662
+ DimensionHandle batch_size_dim = c->Dim(
663
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
664
+ DimensionHandle in_rows_dim = c->Dim(
665
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
666
+ DimensionHandle in_cols_dim = c->Dim(
667
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
668
+ DimensionHandle depth_dim = c->Dim(
669
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
670
+
671
+ Padding padding;
672
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
673
+
674
+ // TODO(mrry,shlens): Raise an error if the stride would cause
675
+ // information in the input to be ignored. This will require a change
676
+ // in the kernel implementation.
677
+
678
+ DimensionHandle output_rows, output_cols;
679
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
680
+ c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
681
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
682
+ c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
683
+
684
+ ShapeHandle output_shape;
685
+ TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
686
+ {output_rows, output_cols}, depth_dim,
687
+ &output_shape, c));
688
+ c->set_output(0, output_shape);
689
+ return Status::OK();
690
+ }
691
+
692
+ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
693
+ ShapeHandle x;
694
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
695
+
696
+ bool is_training;
697
+ TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
698
+ int number_inputs = (is_training) ? 3 : 5;
699
+ string data_format;
700
+ TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
701
+ DimensionHandle channel_dim =
702
+ (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1);
703
+
704
+ // Covers scale and offset, plus mean and variance when is_training is false.
705
+ for (int i = 1; i < number_inputs; ++i) {
706
+ ShapeHandle vec;
707
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
708
+ TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
709
+ }
710
+
711
+ ShapeHandle y;
712
+ if (data_format == "NHWC") {
713
+ TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y));
714
+ } else {
715
+ TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y));
716
+ }
717
+ c->set_output(0, y);
718
+ ShapeHandle vector_shape = c->Vector(channel_dim);
719
+ c->set_output(1, vector_shape);
720
+ c->set_output(2, vector_shape);
721
+ c->set_output(3, vector_shape);
722
+ c->set_output(4, vector_shape);
723
+ return Status::OK();
724
+ }
725
+
726
+ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
727
+ ShapeHandle y_backprop;
728
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
729
+ ShapeHandle x;
730
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
731
+
732
+ bool is_training;
733
+ string data_format;
734
+ TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
735
+ TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
736
+ DimensionHandle channel_dim =
737
+ (data_format == "NHWC") ? c->Dim(y_backprop, 3) : c->Dim(y_backprop, 1);
738
+ if (data_format == "NHWC") {
739
+ TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim));
740
+ } else {
741
+ TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim));
742
+ }
743
+
744
+ // Covers scale, mean (reserve_space_1), and variance (reserve_space_2).
745
+ for (int i = 2; i < 5; ++i) {
746
+ ShapeHandle vec;
747
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
748
+ TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
749
+ }
750
+
751
+ ShapeHandle x_backprop;
752
+ if (data_format == "NHWC") {
753
+ TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop));
754
+ } else {
755
+ TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop));
756
+ }
757
+ c->set_output(0, x_backprop);
758
+ c->set_output(1, c->Vector(channel_dim));
759
+ c->set_output(2, c->Vector(channel_dim));
760
+ // Set the correct shapes for reserve_spaces
761
+ // so that gradients can be performed when
762
+ // the op is in a symbolic condition.
763
+ if (is_training) {
764
+ c->set_output(3, c->Vector(0));
765
+ c->set_output(4, c->Vector(0));
766
+ } else {
767
+ c->set_output(3, c->Vector(channel_dim));
768
+ c->set_output(4, c->Vector(channel_dim));
769
+ }
770
+ return Status::OK();
771
+ }
772
+
773
+ Status MaxPoolShape(shape_inference::InferenceContext* c) {
774
+ string data_format_str;
775
+ TensorFormat data_format;
776
+ Status s = c->GetAttr("data_format", &data_format_str);
777
+ if (s.ok()) {
778
+ FormatFromString(data_format_str, &data_format);
779
+ } else {
780
+ data_format = FORMAT_NHWC;
781
+ }
782
+
783
+ const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
784
+ ShapeHandle input_shape;
785
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
786
+
787
+ TF_RETURN_IF_ERROR(
788
+ CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
789
+
790
+ std::vector<int32> strides;
791
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
792
+ if (strides.size() != 4) {
793
+ return errors::InvalidArgument(
794
+ "MaxPool requires the stride attribute to contain 4 values, but got: ",
795
+ strides.size());
796
+ }
797
+
798
+ std::vector<int32> kernel_sizes;
799
+ TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
800
+ if (kernel_sizes.size() != 4) {
801
+ return errors::InvalidArgument(
802
+ "MaxPool requires the ksize attribute to contain 4 values, but got: ",
803
+ kernel_sizes.size());
804
+ }
805
+
806
+ int32 stride_depth = GetTensorDim(strides, data_format, 'C');
807
+ int32 stride_rows = GetTensorDim(strides, data_format, 'H');
808
+ int32 stride_cols = GetTensorDim(strides, data_format, 'W');
809
+ int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
810
+ int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
811
+ int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
812
+
813
+ constexpr int num_spatial_dims = 2;
814
+ DimensionHandle batch_size_dim = c->Dim(
815
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
816
+ DimensionHandle in_rows_dim = c->Dim(
817
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
818
+ DimensionHandle in_cols_dim = c->Dim(
819
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
820
+ DimensionHandle in_depth_dim = c->Dim(
821
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
822
+
823
+ Padding padding;
824
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
825
+
826
+ ShapeHandle output_shape;
827
+ DimensionHandle output_rows, output_cols, output_depth;
828
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
829
+ c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
830
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
831
+ c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
832
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
833
+ c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
834
+
835
+ TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
836
+ {output_rows, output_cols},
837
+ output_depth, &output_shape, c));
838
+
839
+ c->set_output(0, output_shape);
840
+ return Status::OK();
841
+ }
842
+
843
+ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
844
+ string data_format_str;
845
+ TensorFormat data_format;
846
+ Status s = c->GetAttr("data_format", &data_format_str);
847
+ if (s.ok()) {
848
+ FormatFromString(data_format_str, &data_format);
849
+ } else {
850
+ data_format = FORMAT_NHWC;
851
+ }
852
+
853
+ const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
854
+ ShapeHandle input_shape;
855
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
856
+
857
+ TF_RETURN_IF_ERROR(
858
+ CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
859
+
860
+ std::vector<int32> kernel_sizes;
861
+ std::vector<int32> strides;
862
+
863
+ if (c->num_inputs() + 2 == num_inputs) {
864
+ TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
865
+
866
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
867
+ } else {
868
+ // Verify shape of ksize and strides input.
869
+ ShapeHandle size;
870
+ DimensionHandle unused;
871
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
872
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
873
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
874
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
875
+
876
+ const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
877
+ if (kernel_sizes_tensor == nullptr) {
878
+ c->set_output(0, c->UnknownShape());
879
+ return Status::OK();
880
+ }
881
+ kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
882
+ auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
883
+ std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
884
+ kernel_sizes.begin());
885
+
886
+ const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
887
+ if (strides_tensor == nullptr) {
888
+ c->set_output(0, c->UnknownShape());
889
+ return Status::OK();
890
+ }
891
+ strides.resize(strides_tensor->shape().num_elements());
892
+ auto strides_vec = strides_tensor->flat<int32>();
893
+ std::copy_n(&strides_vec(0), strides.size(), strides.begin());
894
+ }
895
+
896
+ if (strides.size() != 4) {
897
+ return errors::InvalidArgument(
898
+ "MaxPool requires the stride attribute to contain 4 values, but "
899
+ "got: ",
900
+ strides.size());
901
+ }
902
+ if (kernel_sizes.size() != 4) {
903
+ return errors::InvalidArgument(
904
+ "MaxPool requires the ksize attribute to contain 4 values, but got: ",
905
+ kernel_sizes.size());
906
+ }
907
+
908
+ int32 stride_depth = GetTensorDim(strides, data_format, 'C');
909
+ int32 stride_rows = GetTensorDim(strides, data_format, 'H');
910
+ int32 stride_cols = GetTensorDim(strides, data_format, 'W');
911
+ int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
912
+ int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
913
+ int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
914
+
915
+ constexpr int num_spatial_dims = 2;
916
+ DimensionHandle batch_size_dim = c->Dim(
917
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
918
+ DimensionHandle in_rows_dim = c->Dim(
919
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
920
+ DimensionHandle in_cols_dim = c->Dim(
921
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
922
+ DimensionHandle in_depth_dim = c->Dim(
923
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
924
+
925
+ Padding padding;
926
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
927
+
928
+ ShapeHandle output_shape;
929
+ DimensionHandle output_rows, output_cols, output_depth;
930
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
931
+ c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
932
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
933
+ c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
934
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
935
+ c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
936
+
937
+ TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
938
+ {output_rows, output_cols},
939
+ output_depth, &output_shape, c));
940
+
941
+ c->set_output(0, output_shape);
942
+ return Status::OK();
943
+ }
944
+
945
+ Status Pool3DShape(shape_inference::InferenceContext* c) {
946
+ ShapeHandle input_shape;
947
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
948
+
949
+ string data_format;
950
+ Status s = c->GetAttr("data_format", &data_format);
951
+
952
+ std::vector<int32> strides;
953
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
954
+ if (strides.size() != 5) {
955
+ return errors::InvalidArgument(
956
+ "Pool3D ops require the stride attribute to contain 5 values, but "
957
+ "got: ",
958
+ strides.size());
959
+ }
960
+
961
+ std::vector<int32> kernel_sizes;
962
+ TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
963
+ if (kernel_sizes.size() != 5) {
964
+ return errors::InvalidArgument(
965
+ "Pool3D requires the ksize attribute to contain 5 values, but got: ",
966
+ kernel_sizes.size());
967
+ }
968
+
969
+ int32 stride_planes, stride_rows, stride_cols;
970
+ int32 kernel_planes, kernel_rows, kernel_cols;
971
+
972
+ if (s.ok() && data_format == "NCDHW") {
973
+ // Convert input_shape to NDHWC.
974
+ auto dim = [&](char dimension) {
975
+ return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
976
+ };
977
+ input_shape =
978
+ c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
979
+ stride_planes = strides[2];
980
+ stride_rows = strides[3];
981
+ stride_cols = strides[4];
982
+ kernel_planes = kernel_sizes[2];
983
+ kernel_rows = kernel_sizes[3];
984
+ kernel_cols = kernel_sizes[4];
985
+ } else {
986
+ stride_planes = strides[1];
987
+ stride_rows = strides[2];
988
+ stride_cols = strides[3];
989
+ kernel_planes = kernel_sizes[1];
990
+ kernel_rows = kernel_sizes[2];
991
+ kernel_cols = kernel_sizes[3];
992
+ }
993
+
994
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
995
+ DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
996
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
997
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
998
+ DimensionHandle output_depth_dim = c->Dim(input_shape, 4);
999
+
1000
+ Padding padding;
1001
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1002
+
1003
+ // TODO(mrry,shlens): Raise an error if the stride would cause
1004
+ // information in the input to be ignored. This will require a change
1005
+ // in the kernel implementation.
1006
+ DimensionHandle output_planes, output_rows, output_cols;
1007
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1008
+ c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
1009
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1010
+ c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1011
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1012
+ c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1013
+
1014
+ ShapeHandle output_shape;
1015
+ if (data_format == "NCDHW") {
1016
+ output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
1017
+ output_planes, output_rows, output_cols});
1018
+ } else {
1019
+ output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
1020
+ output_cols, output_depth_dim});
1021
+ }
1022
+
1023
+ c->set_output(0, output_shape);
1024
+ return Status::OK();
1025
+ }
1026
+
1027
+ Status UnknownShape(shape_inference::InferenceContext* c) {
1028
+ for (int i = 0; i < c->num_outputs(); ++i) {
1029
+ c->set_output(i, c->UnknownShape());
1030
+ }
1031
+ return Status::OK();
1032
+ }
1033
+
1034
+ template <typename T>
1035
+ Status ReductionShapeHelper(const Tensor* reduction_indices_t,
1036
+ const int32 input_rank,
1037
+ std::set<int64>& true_indices) {
1038
+ auto reduction_indices = reduction_indices_t->flat<T>();
1039
+ for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
1040
+ const T reduction_index = reduction_indices(i);
1041
+ if (reduction_index < -input_rank || reduction_index >= input_rank) {
1042
+ return errors::InvalidArgument("Invalid reduction dimension ",
1043
+ reduction_index, " for input with ",
1044
+ input_rank, " dimensions.");
1045
+ }
1046
+
1047
+ auto wrapped_index = reduction_index;
1048
+ if (wrapped_index < 0) {
1049
+ wrapped_index += input_rank;
1050
+ }
1051
+
1052
+ true_indices.insert(wrapped_index);
1053
+ }
1054
+ return Status::OK();
1055
+ }
1056
+
1057
+ Status ReductionShape(InferenceContext* c) {
1058
+ ShapeHandle input = c->input(0);
1059
+
1060
+ ShapeHandle indices;
1061
+ // Older versions of TensorFlow accidentally allowed higher rank tensors like
1062
+ // [[1,2]] or [[1],[2]] to represent axis=[1,2].
1063
+ if (c->graph_def_version() < 21) {
1064
+ indices = c->input(1);
1065
+ } else {
1066
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
1067
+ }
1068
+
1069
+ bool keep_dims;
1070
+ TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
1071
+
1072
+ const Tensor* reduction_indices_t = c->input_tensor(1);
1073
+ if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
1074
+ // If we do not have the reduction values at runtime, or the
1075
+ // rank of the input, we don't know the output shape.
1076
+
1077
+ if (keep_dims && c->RankKnown(input)) {
1078
+ // Output rank matches input rank if <keep_dims>.
1079
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
1080
+ return Status::OK();
1081
+ } else {
1082
+ return shape_inference::UnknownShape(c);
1083
+ }
1084
+ }
1085
+
1086
+ const int32 input_rank = c->Rank(input);
1087
+ std::set<int64> true_indices;
1088
+ if (reduction_indices_t->dtype() == DataType::DT_INT32) {
1089
+ TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
1090
+ input_rank, true_indices));
1091
+ } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
1092
+ TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t,
1093
+ input_rank, true_indices));
1094
+ } else {
1095
+ return errors::InvalidArgument(
1096
+ "reduction_indices can only be int32 or int64");
1097
+ }
1098
+
1099
+ std::vector<DimensionHandle> dims;
1100
+ for (int i = 0; i < input_rank; ++i) {
1101
+ if (true_indices.count(i) > 0) {
1102
+ if (keep_dims) {
1103
+ dims.emplace_back(c->MakeDim(1));
1104
+ }
1105
+ } else {
1106
+ dims.emplace_back(c->Dim(input, i));
1107
+ }
1108
+ }
1109
+
1110
+ c->set_output(0, c->MakeShape(dims));
1111
+ return Status::OK();
1112
+ }
1113
+
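// [Editor's note] Illustrative sketch, not part of the original file: the
// ReductionShape rule above on static shapes. Reduced axes are dropped, or
// kept as size 1 when keep_dims is set; negative axes wrap around. The
// helper name is hypothetical.
#include <cstdint>
#include <set>
#include <vector>

std::vector<int64_t> ReducedShapeSketch(const std::vector<int64_t>& in,
                                        const std::vector<int64_t>& axes,
                                        bool keep_dims) {
  const int64_t rank = static_cast<int64_t>(in.size());
  std::set<int64_t> reduced;
  for (int64_t a : axes) reduced.insert(a < 0 ? a + rank : a);  // wrap negatives
  std::vector<int64_t> out;
  for (int64_t i = 0; i < rank; ++i) {
    if (reduced.count(i)) {
      if (keep_dims) out.push_back(1);
    } else {
      out.push_back(in[i]);
    }
  }
  return out;
}
// ReducedShapeSketch({2, 3, 5}, {-1}, false)   -> {2, 3}
// ReducedShapeSketch({2, 3, 5}, {0, 2}, true)  -> {1, 3, 1}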
1114
+ Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
1115
+ int end_value_index, int dim_index) {
1116
+ ShapeHandle unused;
1117
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
1118
+ const Tensor* concat_dim_t = c->input_tensor(dim_index);
1119
+ if (concat_dim_t == nullptr) {
1120
+ // Return an unknown shape with same rank as inputs, or an unknown rank
1121
+ // if no input's rank is known.
1122
+
1123
+ // Find rank.
1124
+ int32 rank = InferenceContext::kUnknownRank;
1125
+ for (int i = start_value_index; i < end_value_index; ++i) {
1126
+ if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
1127
+ if (rank != InferenceContext::kUnknownRank) {
1128
+ break;
1129
+ }
1130
+ }
1131
+ if (rank == InferenceContext::kUnknownRank) {
1132
+ c->set_output(0, c->UnknownShape());
1133
+ return Status::OK();
1134
+ } else if (rank == 0) {
1135
+ return errors::InvalidArgument(
1136
+ "Can't concatenate scalars (use tf.stack instead)");
1137
+ } else {
1138
+ for (int i = start_value_index; i < end_value_index; ++i) {
1139
+ // Check that all the inputs are of the correct rank.
1140
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
1141
+ }
1142
+ }
1143
+ // Build result of <rank> different unknown dims.
1144
+ std::vector<DimensionHandle> dims;
1145
+ dims.reserve(rank);
1146
+ for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
1147
+ c->set_output(0, c->MakeShape(dims));
1148
+ return Status::OK();
1149
+ }
1150
+
1151
+ // Merge all the non-concat dims, and sum the concat dim to make an output
1152
+ // shape.
1153
+ const int32 concat_dim = concat_dim_t->scalar<int32>()();
1154
+
1155
+ // Minimum required number of dimensions.
1156
+ const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;
1157
+
1158
+ ShapeHandle output_before;
1159
+ ShapeHandle output_after;
1160
+
1161
+ ShapeHandle input = c->input(end_value_index - 1);
1162
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
1163
+ TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
1164
+ DimensionHandle output_middle = c->Dim(input, concat_dim);
1165
+ if (concat_dim == -1) {
1166
+ output_after = c->Scalar(); // no dimensions.
1167
+ } else {
1168
+ TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
1169
+ }
1170
+
1171
+ for (int i = end_value_index - 2; i >= start_value_index; --i) {
1172
+ ShapeHandle before;
1173
+ ShapeHandle after;
1174
+ input = c->input(i);
1175
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
1176
+ TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
1177
+ DimensionHandle middle = c->Dim(input, concat_dim);
1178
+ if (concat_dim == -1) {
1179
+ after = c->Scalar();
1180
+ } else {
1181
+ TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
1182
+ }
1183
+
1184
+ TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
1185
+ TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
1186
+ TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
1187
+ }
1188
+
1189
+ ShapeHandle s;
1190
+ TF_RETURN_IF_ERROR(
1191
+ c->Concatenate(output_before, c->Vector(output_middle), &s));
1192
+ TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
1193
+ c->set_output(0, s);
1194
+ return Status::OK();
1195
+ }
1196
+
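// [Editor's note] Illustrative sketch, not part of the original file: the
// merge-and-sum rule of ConcatShapeHelper above on static shapes. All
// inputs must agree outside the concat axis; the concat axis is summed.
// Assumes at least one input, equal ranks, and a non-negative axis (the
// real helper also accepts axis == -1). The helper name is hypothetical.
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> ConcatShapeSketch(
    const std::vector<std::vector<int64_t>>& inputs, int64_t axis) {
  std::vector<int64_t> out = inputs[0];
  for (size_t i = 1; i < inputs.size(); ++i) {
    for (size_t d = 0; d < out.size(); ++d) {
      if (static_cast<int64_t>(d) == axis) {
        out[d] += inputs[i][d];             // Add() on the concat dim
      } else if (out[d] != inputs[i][d]) {  // Merge() everywhere else
        throw std::invalid_argument("incompatible non-concat dimension");
      }
    }
  }
  return out;
}
// ConcatShapeSketch({{2, 3}, {2, 5}}, 1) -> {2, 8}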
1197
+ Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
1198
+ return ConcatShapeHelper(c, 1 /* start_value_index */,
1199
+ 1 + num_inputs_to_concat /* end_value_index */,
1200
+ 0 /* dim_index */);
1201
+ }
1202
+
1203
+ Status ConcatV2Shape(InferenceContext* c) {
1204
+ return ConcatShapeHelper(c, 0 /* start_value_index */,
1205
+ c->num_inputs() - 1 /* end_value_index */,
1206
+ c->num_inputs() - 1 /* dim_index */);
1207
+ }
1208
+
1209
+ Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
1210
+ ShapeHandle shape_x = c->input(0);
1211
+ ShapeHandle shape_y = c->input(1);
1212
+ if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
1213
+ c->set_output(0, c->UnknownShape());
1214
+ return Status::OK();
1215
+ }
1216
+ const int32 rank_x = c->Rank(shape_x);
1217
+ const int32 rank_y = c->Rank(shape_y);
1218
+ const int32 rank_out = std::max(rank_x, rank_y);
1219
+
1220
+ // To compute the broadcast dimensions, we zip together shape_x and shape_y,
1221
+ // padding the shorter shape with leading 1s so that both have length
1222
+ // rank_out.
1223
+ std::vector<DimensionHandle> dims;
1224
+ DimensionHandle dim_one;
1225
+ if (rank_x != rank_y) dim_one = c->MakeDim(1);
1226
+ for (int i = 0; i < rank_out; ++i) {
1227
+ const auto dim_x = i < (rank_out - rank_x)
1228
+ ? dim_one
1229
+ : c->Dim(shape_x, i - (rank_out - rank_x));
1230
+ const bool dim_y_is_one = (i < (rank_out - rank_y));
1231
+ const auto dim_y =
1232
+ dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
1233
+ if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
1234
+ // One or both dimensions is unknown.
1235
+ //
1236
+ // - If either dimension is greater than 1, we assume that the program is
1237
+ // correct, and the other dimension will be broadcast to match it.
1238
+ // TODO(cwhipkey): For shape inference, if we eliminate the shape checks
1239
+ // in C++ op code, we must still assert that the unknown dim is either 1
1240
+ // or the same as the known dim.
1241
+ // - If either dimension is 1, the other dimension is the output.
1242
+ if (c->Value(dim_x) > 1) {
1243
+ dims.push_back(dim_x);
1244
+ } else if (c->Value(dim_y) > 1) {
1245
+ dims.push_back(dim_y);
1246
+ } else if (c->Value(dim_x) == 1) {
1247
+ dims.push_back(dim_y);
1248
+ } else if (c->Value(dim_y) == 1) {
1249
+ dims.push_back(dim_x);
1250
+ } else if (dim_y.SameHandle(dim_x)) {
1251
+ dims.push_back(dim_x);
1252
+ } else {
1253
+ dims.push_back(c->UnknownDim());
1254
+ }
1255
+ } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
1256
+ if (c->Value(dim_x) == 1 && !dim_y_is_one) {
1257
+ // We will broadcast dim_x to dim_y.
1258
+ dims.push_back(dim_y);
1259
+ } else {
1260
+ DCHECK_EQ(c->Value(dim_y), 1);
1261
+ // We will broadcast dim_y to dim_x.
1262
+ dims.push_back(dim_x);
1263
+ }
1264
+ } else {
1265
+ DimensionHandle dim;
1266
+ TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
1267
+ dims.push_back(dim);
1268
+ }
1269
+ }
1270
+
1271
+ c->set_output(0, c->MakeShape(dims));
1272
+ return Status::OK();
1273
+ }
1274
+
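// [Editor's note] Illustrative sketch, not part of the original file: the
// broadcasting rule above on fully known static shapes (the unknown-dim
// cases are omitted). The shorter shape is left-padded with 1s, then each
// dim pair must match or one side must be 1. The helper name is
// hypothetical.
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> BroadcastShapeSketch(std::vector<int64_t> x,
                                          std::vector<int64_t> y) {
  const size_t rank = std::max(x.size(), y.size());
  x.insert(x.begin(), rank - x.size(), 1);  // pad with leading 1s
  y.insert(y.begin(), rank - y.size(), 1);
  std::vector<int64_t> out(rank);
  for (size_t i = 0; i < rank; ++i) {
    if (x[i] == y[i] || y[i] == 1) out[i] = x[i];
    else if (x[i] == 1)            out[i] = y[i];
    else throw std::invalid_argument("incompatible shapes");
  }
  return out;
}
// BroadcastShapeSketch({2, 1, 5}, {3, 5}) -> {2, 3, 5}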
1275
+ Status RandomShape(shape_inference::InferenceContext* c) {
1276
+ shape_inference::ShapeHandle out;
1277
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1278
+ c->set_output(0, out);
1279
+ return Status::OK();
1280
+ }
1281
+
1282
+ Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
1283
+ ShapeHandle values_shape, ShapeHandle shape_shape) {
1284
+ // Validate ranks.
1285
+ ShapeHandle unused_shape;
1286
+ TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
1287
+ TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
1288
+ TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));
1289
+
1290
+ // Number of elements in indices and values must match.
1291
+ DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
1292
+ if (c->ValueKnown(num_index_elements_dim)) {
1293
+ DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
1294
+ if (c->ValueKnown(num_values_elements_dim)) {
1295
+ int64 num_index_elements = c->Value(num_index_elements_dim);
1296
+ int64 num_values_elements = c->Value(num_values_elements_dim);
1297
+ if (num_index_elements != num_values_elements) {
1298
+ return errors::InvalidArgument("Number of elements in index (",
1299
+ num_index_elements, ") and values (",
1300
+ num_values_elements, ") do not match.");
1301
+ }
1302
+ }
1303
+ }
1304
+
1305
+ // Rank embedded in indices must match shape.
1306
+ DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
1307
+ if (c->ValueKnown(index_rank_dim)) {
1308
+ DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
1309
+ if (c->ValueKnown(shape_rank_dim)) {
1310
+ int64 index_rank = c->Value(index_rank_dim);
1311
+ int32 shape_rank = c->Value(shape_rank_dim);
1312
+ if (index_rank != shape_rank) {
1313
+ return errors::InvalidArgument("Index rank (", index_rank,
1314
+ ") and shape rank (", shape_rank,
1315
+ ") do not match.");
1316
+ }
1317
+ }
1318
+ }
1319
+
1320
+ return Status::OK();
1321
+ }
1322
+
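// [Editor's note] Illustrative sketch, not part of the original file: the
// two cross-checks ValidateSparseTensor performs above, on static shapes
// (the rank assertions are handled separately). For a SparseTensor with nnz
// nonzeros in a rank-r dense shape, indices is [nnz, r], values is [nnz],
// and shape is [r]. The helper name is hypothetical.
#include <array>
#include <cstdint>

bool SparseShapesConsistentSketch(
    std::array<int64_t, 2> indices_shape,  // [nnz, r]
    int64_t values_len,                    // nnz
    int64_t dense_rank) {                  // r
  return indices_shape[0] == values_len && indices_shape[1] == dense_rank;
}
// A 3x4 matrix with 2 nonzeros: indices [2, 2], values [2], shape [2]
// SparseShapesConsistentSketch({2, 2}, 2, 2) -> true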
1323
+ Status ScatterNdUpdateShape(InferenceContext* c) {
1324
+ ShapeHandle input_shape = c->input(0);
1325
+ if (c->input_handle_shapes_and_types(0) != nullptr) {
1326
+ input_shape = (*c->input_handle_shapes_and_types(0))[0].shape;
1327
+ }
1328
+ ShapeHandle indices_shape;
1329
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
1330
+ ShapeHandle updates_shape;
1331
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
1332
+
1333
+ if (c->Value(c->NumElements(input_shape)) == 0 &&
1334
+ (c->Value(c->NumElements(indices_shape)) > 0 ||
1335
+ c->Value(c->NumElements(updates_shape)) > 0)) {
1336
+ return errors::InvalidArgument(
1337
+ "Indices and updates specified for empty output shape");
1338
+ }
1339
+
1340
+ if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
1341
+ const int64 num_outer_dims = c->Rank(indices_shape) - 1;
1342
+ const DimensionHandle index_size = c->Dim(indices_shape, -1);
1343
+
1344
+ // We can only do more validation if the last dimension of indices
1345
+ // is a known value.
1346
+ if (c->ValueKnown(index_size)) {
1347
+ const int64 ix = c->Value(index_size);
1348
+ ShapeHandle unused;
1349
+ ShapeHandle prefix_indices;
1350
+ TF_RETURN_IF_ERROR(
1351
+ c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices));
1352
+ ShapeHandle prefix_updates;
1353
+ TF_RETURN_IF_ERROR(
1354
+ c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates));
1355
+
1356
+ Status s = c->Merge(prefix_indices, prefix_updates, &unused);
1357
+ if (!s.ok()) {
1358
+ return errors::InvalidArgument(
1359
+ "The outer ", num_outer_dims, " dimensions of indices.shape=",
1360
+ c->DebugString(indices_shape), " must match the outer ",
1361
+ num_outer_dims, " dimensions of updates.shape=",
1362
+ c->DebugString(updates_shape), ": ", s.error_message());
1363
+ }
1364
+
1365
+ ShapeHandle input_suffix;
1366
+ TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix));
1367
+ ShapeHandle suffix_updates;
1368
+ TF_RETURN_IF_ERROR(
1369
+ c->Subshape(updates_shape, num_outer_dims, &suffix_updates));
1370
+ s = c->Merge(input_suffix, suffix_updates, &unused);
1371
+ if (!s.ok()) {
1372
+ return errors::InvalidArgument(
1373
+ "The inner ", c->Rank(input_shape) - ix,
1374
+ " dimensions of input.shape=", c->DebugString(input_shape),
1375
+ " must match the inner ", c->Rank(updates_shape) - num_outer_dims,
1376
+ " dimensions of updates.shape=", c->DebugString(updates_shape),
1377
+ ": ", s.error_message());
1378
+ }
1379
+ }
1380
+ }
1381
+
1382
+ if (c->input_handle_shapes_and_types(0) == nullptr) {
1383
+ c->set_output(0, input_shape);
1384
+ }
1385
+ return Status::OK();
1386
+ }
1387
+
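// [Editor's note] Illustrative sketch, not part of the original file: the
// prefix/suffix rule checked above on static shapes. With
// indices = [outer..., ix], updates must equal [outer...] + input[ix:]
// (assumes ix <= rank of input). The helper name is hypothetical.
#include <cstdint>
#include <vector>

bool ScatterNdShapesOkSketch(const std::vector<int64_t>& input,
                             const std::vector<int64_t>& indices,
                             const std::vector<int64_t>& updates) {
  const size_t ix = static_cast<size_t>(indices.back());
  const size_t outer = indices.size() - 1;  // num_outer_dims
  if (updates.size() != outer + input.size() - ix) return false;
  for (size_t i = 0; i < outer; ++i)          // outer dims mirror indices
    if (updates[i] != indices[i]) return false;
  for (size_t i = ix; i < input.size(); ++i)  // inner dims mirror input[ix:]
    if (updates[outer + i - ix] != input[i]) return false;
  return true;
}
// ScatterNdShapesOkSketch({8, 4}, {5, 1}, {5, 4}) -> true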
1388
+ Status ExplicitShape(InferenceContext* c) {
1389
+ PartialTensorShape shape;
1390
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
1391
+ ShapeHandle output_shape;
1392
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape));
1393
+ c->set_output(0, output_shape);
1394
+ return Status::OK();
1395
+ }
1396
+
1397
+ } // namespace shape_inference
1398
+
1399
+ } // namespace tensorflow
common_shape_fns.h ADDED
@@ -0,0 +1,290 @@
1
+ /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+ #ifndef THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
16
+ #define THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
17
+
18
+ #include <array>
19
+
20
+ #include "tensorflow/core/framework/shape_inference.h"
21
+ #include "tensorflow/core/util/padding.h"
22
+ #include "tensorflow/core/util/tensor_format.h"
23
+
24
+ namespace tensorflow {
25
+
26
+ // GetWindowedOutputSize(): Given an input tensor, kernel, stride and padding
27
+ // type, the function computes the output and padding dimensions.
28
+ //
29
+ // For example, ignoring batches or multiple features, a 1D convolution
30
+ // takes as input a 1D tensor of shape (H), and convolves it with a filter of
31
+ // shape (K).
32
+ //
33
+ // It also takes in a few additional parameters:
34
+ //
35
+ // Stride (S): the stride with which we apply the filters. This is the offset
36
+ // between locations where we apply the filters. A larger stride
37
+ // means that the output will be spatially smaller.
38
+ //
39
+ // Padding (P): the padding we apply to the input tensor along each
40
+ // dimension. This is usually used to make sure that the spatial dimensions
41
+ // do not shrink when we progress with convolutions. Two types of padding are
42
+ // often used:
43
+ // SAME: the pad value is computed so that the output will have size H/S.
44
+ // VALID: no padding is carried out.
45
+ // The padded area is zero-filled.
46
+ //
47
+ // The output dimensions for convolution and many other operations, when given
48
+ // all the parameters above, are as follows:
49
+ // - When Padding = SAME: the output size is (H'), where
50
+ // H' = ceil(float(H) / float(S))
51
+ // where ceil is the ceiling function. The number of padded cells
52
+ // is computed as:
53
+ // Pc = ((H' - 1) * S + K - H) / 2
54
+ // When the stride is 1, the expression simplifies to
55
+ // H' = H, Pc = (K-1)/2.
56
+ // This is where the name SAME comes from: with stride 1, the output has the
57
+ // same size as the input.
58
+ //
59
+ // - When Padding = VALID: the output size is computed as
60
+ // H' = ceil(float(H - K + 1) / float(S))
61
+ // and the number of padded cells is always zero.
62
+ // When the stride is 1, the expression simplifies to
63
+ // H' = H-K+1.
64
+ //
65
+ // For convolution, mathematically, the output value at location (r')
66
+ // is the inner product of two vectors: the chunk of input at
67
+ // ((r'*S-Pr) : (r'*S-Pr+K)),
68
+ // and the filter.
69
+ //
70
+ // For 2D and 3D convolutions, the spatial dimensions are orthogonal, so the
71
+ // size and padding of each spatial dimension can be computed by calling
72
+ // GetWindowedOutputSize separately for each dimension.
73
+ //
74
+ Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
75
+ Padding padding_type, int64* output_size,
76
+ int64* padding_size);
77
+
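// [Editor's note] Worked numeric check of the formulas above, values chosen
// purely for illustration (H = 7, K = 3, S = 2):
//   SAME : H' = ceil(7 / 2) = 4, Pc = ((4 - 1) * 2 + 3 - 7) / 2 = 1
//   VALID: H' = ceil((7 - 3 + 1) / 2) = 3, and no padding is applied.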
78
+ // The V2 version computes the same outputs with arbitrary dilation_rate.
79
+ // The output dimensions are computed as follows:
80
+ // - When adding dilation_rate (D), we compute an effective filter size (K'):
81
+ // K' = (K - 1) * D + 1
82
+ // - When Padding = SAME: the output size is (H'), where
83
+ // H' = ceil(float(H) / float(S))
84
+ // where ceil is the ceiling function. The number of padded cells
85
+ // is computed as:
86
+ // Pc = ((H' - 1) * S + K' - H) / 2
87
+ // When the stride is 1, the expression simplifies to
88
+ // H' = H, Pc = (K'-1)/2.
89
+ // This is where the name SAME comes from: with stride 1, the output has the
90
+ // same size as the input.
91
+ //
92
+ // - When Padding = VALID: the output size is computed as
93
+ // H' = ceil(float(H - K' + 1) / float(S))
94
+ // and the number of padded cells is always zero.
95
+ // When the stride is 1, the expression simplifies to
96
+ // H' = H-K'+1.
97
+ //
98
+ // TODO(b/67112639): Merge V2 versions and the original versions eventually.
99
+ Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
100
+ int64 dilation_rate, int64 stride,
101
+ Padding padding_type, int64* output_size,
102
+ int64* padding_size);
103
+
104
+ // Returns the same output dimensions as in GetWindowedOutputSize, but returns
105
+ // verbose padding dimensions (before/after). Any excess padding
106
+ // (caused by an odd padding size value) is added to the 'padding_after'
107
+ // dimension.
108
+ Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
109
+ int64 stride, Padding padding_type,
110
+ int64* output_size, int64* padding_before,
111
+ int64* padding_after);
112
+
113
+ // The V2 version computes the same outputs with arbitrary dilation_rate. For
114
+ // detailed equations, refer to the comments for GetWindowedOutputSizeV2().
115
+ Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
116
+ int64 dilation_rate, int64 stride,
117
+ Padding padding_type, int64* output_size,
118
+ int64* padding_before,
119
+ int64* padding_after);
120
+
121
+ // Given an input tensor, kernel, stride and padding type, populates the 3D size
122
+ // of the output tensor and padding to be applied to the input tensor at the
123
+ // lower end of every dimension. Use for 3D convolutions, where the input data
124
+ // is padded with zeros, as well as for 3D avg/max pooling, where the input data
125
+ // is padded with invalid values that are not considered for pooling.
126
+ Status Get3dOutputSize(const std::array<int64, 3>& input,
127
+ const std::array<int64, 3>& window,
128
+ const std::array<int64, 3>& strides,
129
+ Padding padding_type, std::array<int64, 3>* output_ptr,
130
+ std::array<int64, 3>* padding_ptr);
131
+
132
+ // The V2 version computes the same outputs with arbitrary dilation_rate. For
133
+ // detailed equations, refer to the comments for GetWindowedOutputSizeV2().
134
+ Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
135
+ const std::array<int64, 3>& window,
136
+ const std::array<int64, 3>& dilations,
137
+ const std::array<int64, 3>& strides,
138
+ Padding padding_type, std::array<int64, 3>* output_ptr,
139
+ std::array<int64, 3>* padding_ptr);
140
+
141
+ namespace shape_inference {
142
+
143
+ // Like GetWindowedOutputSize, but deals with DimensionHandles.
144
+ Status GetWindowedOutputSizeFromDims(InferenceContext* c,
145
+ DimensionHandle input_size,
146
+ DimensionOrConstant filter_size,
147
+ int64 stride, Padding padding_type,
148
+ DimensionHandle* output_size);
149
+
150
+ // The V2 version computes the same outputs with arbitrary dilation_rate. For
151
+ // detailed equations, refer to the comments for GetWindowedOutputSizeV2().
152
+ Status GetWindowedOutputSizeFromDimsV2(InferenceContext* c,
153
+ DimensionHandle input_size,
154
+ DimensionOrConstant filter_size,
155
+ int64 dilation_rate, int64 stride,
156
+ Padding padding_type,
157
+ DimensionHandle* output_size);
158
+
159
+ // Transfers shape of input(0) to output(0).
160
+ Status UnchangedShape(shape_inference::InferenceContext* c);
161
+
162
+ // Transfers shape of input(0) to output(0), after asserting its rank is <rank>.
163
+ inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
164
+ int32 rank) {
165
+ ShapeHandle out;
166
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out));
167
+ c->set_output(0, out);
168
+ return Status::OK();
169
+ }
170
+
171
+ // Transfers shape of input(0) to output(0), after asserting its rank >= <rank>.
172
+ inline Status UnchangedShapeWithRankAtLeast(
173
+ shape_inference::InferenceContext* c, int32 rank) {
174
+ ShapeHandle out;
175
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out));
176
+ c->set_output(0, out);
177
+ return Status::OK();
178
+ }
179
+
180
+ // Transfers shape of input(0) to output(0), after asserting its rank <= <rank>.
181
+ inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c,
182
+ int32 rank) {
183
+ ShapeHandle out;
184
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out));
185
+ c->set_output(0, out);
186
+ return Status::OK();
187
+ }
188
+
189
+ // Shape function for use with ops with no outputs.
190
+ inline Status NoOutputs(shape_inference::InferenceContext* c) {
191
+ return Status::OK();
192
+ }
193
+
194
+ // Shape function for ops that output a single scalar value.
195
+ inline Status ScalarShape(shape_inference::InferenceContext* c) {
196
+ c->set_output(0, c->Scalar());
197
+ return Status::OK();
198
+ }
199
+
200
+ // Shape function for binary ops where both inputs and the output match.
201
+ inline Status MergeBothInputsShapeFn(InferenceContext* c) {
202
+ ShapeHandle out;
203
+ TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
204
+ c->set_output(0, out);
205
+ return Status::OK();
206
+ }
207
+
208
+ // Returns a new shape with the specified dims arranged in the specified
209
+ // format. The returned value is owned by this context.
210
+ // Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth.
211
+ Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
212
+ const std::vector<DimensionOrConstant>& spatial,
213
+ DimensionOrConstant C, ShapeHandle* out,
214
+ shape_inference::InferenceContext* context);
215
+
216
+ // Shape function for MatMul-like operations.
217
+ Status MatMulShape(shape_inference::InferenceContext* c);
218
+
219
+ // Shape function for BiasAdd-like operations.
220
+ Status BiasAddShape(shape_inference::InferenceContext* c);
221
+
222
+ // Shape function for BiasAddGrad-like operations.
223
+ Status BiasAddGradShape(shape_inference::InferenceContext* c);
224
+
225
+ // Shape function for Conv2D-like operations.
226
+ Status Conv2DShape(shape_inference::InferenceContext* c);
227
+
228
+ // Shape function for Conv3D-like operations.
229
+ Status Conv3DShape(shape_inference::InferenceContext* c);
230
+
231
+ // Shape function for DepthwiseConv2D-like operations.
232
+ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);
233
+
234
+ // Shape function for AvgPool-like operations.
235
+ Status AvgPoolShape(shape_inference::InferenceContext* c);
236
+
237
+ // Shape function for FusedBatchNorm and FusedBatchNormV2 operations.
238
+ Status FusedBatchNormShape(shape_inference::InferenceContext* c);
239
+
240
+ // Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations.
241
+ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c);
242
+
243
+ // Shape function for MaxPool-like operations.
244
+ Status MaxPoolShape(shape_inference::InferenceContext* c);
245
+
246
+ // Shape function for MaxPoolV2-like operations.
247
+ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);
248
+
249
+ // Shape function for 3D Pooling operations.
250
+ Status Pool3DShape(shape_inference::InferenceContext* c);
251
+
252
+ // Shape function for use with ops whose output shapes are unknown.
253
+ Status UnknownShape(shape_inference::InferenceContext* c);
254
+
255
+ // Shape function for reduction operations.
256
+ Status ReductionShape(shape_inference::InferenceContext* c);
257
+
258
+ // Shape function for concat operations.
259
+ // <num_inputs_to_concat> is the number of inputs to concatenate; they are
261
+ // taken from inputs [1,num_inputs_to_concat] of the op. Input 0 is the
262
+ // concat_dim input.
262
+ Status ConcatShape(shape_inference::InferenceContext* c,
263
+ int num_inputs_to_concat);
264
+
265
+ // Shape function for concat operations.
266
+ Status ConcatV2Shape(shape_inference::InferenceContext* c);
267
+
268
+ // Shape function for binary operators that broadcast their inputs.
269
+ // Tested by ops/math_ops_test.cc.
270
+ Status BroadcastBinaryOpShapeFn(InferenceContext* c);
271
+
272
+ // Shape function for random operations.
273
+ Status RandomShape(shape_inference::InferenceContext* c);
274
+
275
+ // Validates that the 3 component tensors of a sparse tensor have the proper
276
+ // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
277
+ Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
278
+ ShapeHandle values_shape, ShapeHandle shape_shape);
279
+
280
+ // Shape function for ScatterNd update/add/sub/... operations.
281
+ Status ScatterNdUpdateShape(InferenceContext* c);
282
+
283
+ // Shape function for ops with an explicit "shape" attribute.
284
+ Status ExplicitShape(InferenceContext* c);
285
+
286
+ } // namespace shape_inference
287
+
288
+ } // namespace tensorflow
289
+
290
+ #endif // THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
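These declarations are the menu of reusable shape functions that op registrations plug into. A hedged sketch of how one is wired up (the op name "MyScalarLoss" is invented for illustration; REGISTER_OP and SetShapeFn are TensorFlow's standard op-registration API):

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// The output is always a scalar regardless of the input shapes, so
// ScalarShape (declared above) applies directly. "MyScalarLoss" is a
// hypothetical op name used only for this example.
REGISTER_OP("MyScalarLoss")
    .Input("predictions: float")
    .Input("labels: float")
    .Output("loss: float")
    .SetShapeFn(shape_inference::ScalarShape);

}  // namespace tensorflow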
common_shape_fns_test.cc ADDED
@@ -0,0 +1,1131 @@
1
+ /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+ #include "tensorflow/core/framework/common_shape_fns.h"
16
+
17
+ #include "tensorflow/core/framework/fake_input.h"
18
+ #include "tensorflow/core/framework/node_def_builder.h"
19
+ #include "tensorflow/core/framework/op_def_builder.h"
20
+ #include "tensorflow/core/framework/shape_inference_testutil.h"
21
+ #include "tensorflow/core/framework/tensor_testutil.h"
22
+ #include "tensorflow/core/lib/core/status_test_util.h"
23
+ #include "tensorflow/core/platform/test.h"
24
+
25
+ namespace tensorflow {
26
+ namespace shape_inference {
27
+
28
+ namespace {
29
+
30
+ PartialTensorShape S(std::initializer_list<int64> dims) {
31
+ return PartialTensorShape(dims);
32
+ }
33
+
34
+ PartialTensorShape Unknown() { return PartialTensorShape(); }
35
+
36
+ OpDef MakeOpDef(int num_inputs, int num_outputs) {
37
+ OpRegistrationData op_reg_data;
38
+ OpDefBuilder b("dummy");
39
+ for (int i = 0; i < num_inputs; ++i) {
40
+ b.Input(strings::StrCat("i", i, ": float"));
41
+ }
42
+ for (int i = 0; i < num_outputs; ++i) {
43
+ b.Output(strings::StrCat("o", i, ": float"));
44
+ }
45
+ CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
46
+ return op_reg_data.op_def;
47
+ }
48
+
49
+ } // namespace
50
+
51
+ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
52
+ OpRegistrationData op_reg_data;
53
+ TF_CHECK_OK(OpDefBuilder("Assert")
54
+ .Input("condition: bool")
55
+ .Input("data: float")
56
+ .Finalize(&op_reg_data));
57
+ OpDef op_def = op_reg_data.op_def;
58
+
59
+ NodeDef def;
60
+ TF_CHECK_OK(NodeDefBuilder("test", "Assert")
61
+ .Input("condition", 0, DT_BOOL)
62
+ .Input({{"data", 0, DT_FLOAT}})
63
+ .Finalize(&def));
64
+
65
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({}), S({10})}, {},
66
+ {}, {});
67
+ TF_EXPECT_OK(NoOutputs(&c));
68
+ EXPECT_EQ(0, c.num_outputs());
69
+ }
70
+
71
+ TEST(CommonShapeFnsTest, ScalarShapeTest) {
72
+ OpRegistrationData op_reg_data;
73
+ TF_CHECK_OK(OpDefBuilder("L2Loss")
74
+ .Input("t: float")
75
+ .Output("t: float")
76
+ .Finalize(&op_reg_data));
77
+ OpDef op_def = op_reg_data.op_def;
78
+
79
+ NodeDef def;
80
+ TF_CHECK_OK(
81
+ NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
82
+
83
+ {
84
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {});
85
+ TF_EXPECT_OK(ScalarShape(&c));
86
+ ShapeHandle output = c.output(0);
87
+ EXPECT_EQ(0, c.Rank(output));
88
+ }
89
+
90
+ {
91
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
92
+ {S({1, 23, 4, 4, 2})}, {}, {}, {});
93
+ TF_EXPECT_OK(ScalarShape(&c));
94
+ ShapeHandle output = c.output(0);
95
+ EXPECT_EQ(0, c.Rank(output));
96
+ }
97
+ }
98
+
99
+ TEST(CommonShapeFnsTest, MatMulShapeTest) {
100
+ OpRegistrationData op_reg_data;
101
+ TF_CHECK_OK(OpDefBuilder("MatMul")
102
+ .Input("a: float")
103
+ .Input("b: float")
104
+ .Output("c: float")
105
+ .Attr("transpose_a:bool=false")
106
+ .Attr("transpose_b:bool=false")
107
+ .Finalize(&op_reg_data));
108
+ OpDef op_def = op_reg_data.op_def;
109
+
110
+ NodeDef def;
111
+ TF_CHECK_OK(NodeDefBuilder("test", "MatMul")
112
+ .Input("a", 0, DT_FLOAT)
113
+ .Input("b", 0, DT_FLOAT)
114
+ .Attr("transpose_a", false)
115
+ .Attr("transpose_b", false)
116
+ .Finalize(&def));
117
+
118
+ {
119
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
120
+ {S({2, 3}), S({3, 4})}, {}, {}, {});
121
+ TF_EXPECT_OK(MatMulShape(&c));
122
+ ShapeHandle output = c.output(0);
123
+ EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
124
+ EXPECT_EQ(4, c.Value(c.Dim(output, 1)));
125
+ }
126
+
127
+ {
128
+ // Unknown inner dimension for one input.
129
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
130
+ {S({2, -1}), S({3, 4})}, {}, {}, {});
131
+ TF_EXPECT_OK(MatMulShape(&c));
132
+ ShapeHandle output = c.output(0);
133
+ EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
134
+ EXPECT_EQ(4, c.Value(c.Dim(output, 1)));
135
+ }
136
+
137
+ {
138
+ // Invalid rank.
139
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2}), S({3, 4})},
140
+ {}, {}, {});
141
+ auto s = MatMulShape(&c);
142
+ EXPECT_FALSE(s.ok());
143
+ EXPECT_TRUE(
144
+ StringPiece(s.ToString())
145
+ .contains("Invalid argument: Shape must be rank 2 but is rank 1"));
146
+ }
147
+
148
+ {
149
+ // Unknown outer dimension
150
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
151
+ {S({2, 3}), S({3, -1})}, {}, {}, {});
152
+ TF_EXPECT_OK(MatMulShape(&c));
153
+ ShapeHandle output = c.output(0);
154
+ EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
155
+ EXPECT_FALSE(c.ValueKnown(c.Dim(output, 1)));
156
+ }
157
+
158
+ {
159
+ // Inner shapes not compatible
160
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
161
+ {S({2, 5}), S({3, 4})}, {}, {}, {});
162
+ auto s = MatMulShape(&c);
163
+ EXPECT_FALSE(s.ok());
164
+ EXPECT_TRUE(
165
+ StringPiece(s.ToString())
166
+ .contains(
167
+ "Invalid argument: Dimensions must be equal, but are 5 and 3"));
168
+ }
169
+
170
+ {
171
+ // Inner shapes not compatible
172
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
173
+ {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {});
174
+ auto s = MatMulShape(&c);
175
+ EXPECT_FALSE(s.ok());
176
+ EXPECT_TRUE(
177
+ StringPiece(s.ToString())
178
+ .contains("Invalid argument: Shape must be rank 2 but is rank 3"));
179
+ }
180
+
181
+ {
182
+ // transpose_a
183
+ TF_CHECK_OK(NodeDefBuilder("test", "MatMul")
184
+ .Input("a", 0, DT_FLOAT)
185
+ .Input("b", 0, DT_FLOAT)
186
+ .Attr("transpose_a", true)
187
+ .Attr("transpose_b", false)
188
+ .Attr("type", DT_FLOAT)
189
+ .Finalize(&def));
190
+
191
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
192
+ {S({3, 2}), S({3, 4})}, {}, {}, {});
193
+ auto s = MatMulShape(&c);
194
+ ShapeHandle output = c.output(0);
195
+ EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
196
+ EXPECT_EQ(4, c.Value(c.Dim(output, 1)));
197
+ }
198
+
199
+ {
200
+ // transpose_b
201
+ TF_CHECK_OK(NodeDefBuilder("test", "MatMul")
202
+ .Input("a", 0, DT_FLOAT)
203
+ .Input("b", 0, DT_FLOAT)
204
+ .Attr("transpose_a", false)
205
+ .Attr("transpose_b", true)
206
+ .Attr("type", DT_FLOAT)
207
+ .Finalize(&def));
208
+
209
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
210
+ {S({2, 3}), S({4, 3})}, {}, {}, {});
211
+ auto s = MatMulShape(&c);
212
+ ShapeHandle output = c.output(0);
213
+ EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
214
+ EXPECT_EQ(4, c.Value(c.Dim(output, 1)));
215
+ }
216
+ }
217
+
218
+ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
219
+ OpRegistrationData op_reg_data;
220
+ TF_CHECK_OK(OpDefBuilder("BiasAdd")
221
+ .Input("a: float")
222
+ .Input("b: float")
223
+ .Output("c: float")
224
+ .Finalize(&op_reg_data));
225
+
226
+ OpDef op_def = op_reg_data.op_def;
227
+ NodeDef def;
228
+ TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
229
+ .Input("a", 0, DT_FLOAT)
230
+ .Input("b", 0, DT_FLOAT)
231
+ .Finalize(&def));
232
+
233
+ {
234
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
235
+ {S({2, 10}), S({10})}, {}, {}, {});
236
+ TF_EXPECT_OK(BiasAddShape(&c));
237
+ ShapeHandle output = c.output(0);
238
+ EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
239
+ EXPECT_EQ(10, c.Value(c.Dim(output, 1)));
240
+ }
241
+
242
+ {
243
+ // Unknown ranks.
244
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
245
+ {Unknown(), Unknown()}, {}, {}, {});
246
+ TF_EXPECT_OK(BiasAddShape(&c));
247
+ ShapeHandle output = c.output(0);
248
+ EXPECT_FALSE(c.RankKnown(output));
249
+ }
250
+
251
+ {
252
+ // Rank > 2
253
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
254
+ {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {});
255
+ TF_EXPECT_OK(BiasAddShape(&c));
256
+ ShapeHandle output = c.output(0);
257
+ EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
258
+ }
259
+
260
+ {
261
+ // NCHW format
262
+ TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
263
+ .Input("a", 0, DT_FLOAT)
264
+ .Input("b", 0, DT_FLOAT)
265
+ .Attr("data_format", "NCHW")
266
+ .Finalize(&def));
267
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
268
+ {S({2, 3, 4, 5}), S({3})}, {}, {}, {});
269
+ TF_EXPECT_OK(BiasAddShape(&c));
270
+ ShapeHandle output = c.output(0);
271
+ EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
272
+ }
273
+
274
+ {
275
+ // NCHW format with high input rank
276
+ TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
277
+ .Input("a", 0, DT_FLOAT)
278
+ .Input("b", 0, DT_FLOAT)
279
+ .Attr("data_format", "NCHW")
280
+ .Finalize(&def));
281
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
282
+ {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {});
283
+ TF_EXPECT_OK(BiasAddShape(&c));
284
+ ShapeHandle output = c.output(0);
285
+ EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
286
+ }
287
+
288
+ {
289
+ // NCHW format with input rank 3
290
+ TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
291
+ .Input("a", 0, DT_FLOAT)
292
+ .Input("b", 0, DT_FLOAT)
293
+ .Attr("data_format", "NCHW")
294
+ .Finalize(&def));
295
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
296
+ {S({10, 11, 12}), S({10})}, {}, {}, {});
297
+ TF_EXPECT_OK(BiasAddShape(&c));
298
+ ShapeHandle output = c.output(0);
299
+ EXPECT_EQ("[10,11,12]", c.DebugString(output));
300
+ }
301
+
302
+ {
303
+ // Input rank not high enough
304
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3}), S({3})}, {},
305
+ {}, {});
306
+ EXPECT_FALSE(BiasAddShape(&c).ok());
307
+ }
308
+
309
+ {
310
+ // NCHW rank not high enough
311
+ TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
312
+ .Input("a", 0, DT_FLOAT)
313
+ .Input("b", 0, DT_FLOAT)
314
+ .Attr("data_format", "NCHW")
315
+ .Finalize(&def));
316
+ // NCHW format
317
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3}), S({3})},
318
+ {}, {}, {});
319
+ EXPECT_FALSE(BiasAddShape(&c).ok());
320
+ }
321
+ }
322
+
323
+ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
324
+ OpRegistrationData op_reg_data;
325
+ TF_CHECK_OK(OpDefBuilder("BiasAddGrad")
326
+ .Input("a: float")
327
+ .Output("b: float")
328
+ .Finalize(&op_reg_data));
329
+
330
+ OpDef op_def = op_reg_data.op_def;
331
+ NodeDef def;
332
+ TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
333
+ .Input("a", 0, DT_FLOAT)
334
+ .Finalize(&def));
335
+
336
+ {
337
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 10})}, {}, {},
338
+ {});
339
+ TF_EXPECT_OK(BiasAddGradShape(&c));
340
+ ShapeHandle output = c.output(0);
341
+ EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
342
+ }
343
+
344
+ {
345
+ // Rank > 2
346
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({5, 7, 2, 10})},
347
+ {}, {}, {});
348
+ TF_EXPECT_OK(BiasAddGradShape(&c));
349
+ ShapeHandle output = c.output(0);
350
+ EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
351
+ }
352
+
353
+ {
354
+ // NCHW format
355
+ TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
356
+ .Input("a", 0, DT_FLOAT)
357
+ .Attr("data_format", "NCHW")
358
+ .Finalize(&def));
359
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3, 4, 5})},
360
+ {}, {}, {});
361
+ TF_EXPECT_OK(BiasAddGradShape(&c));
362
+ ShapeHandle output = c.output(0);
363
+ EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
364
+ }
365
+
366
+ {
367
+ // NCHW format with high input rank
368
+ TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
369
+ .Input("a", 0, DT_FLOAT)
370
+ .Attr("data_format", "NCHW")
371
+ .Finalize(&def));
372
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
373
+ {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {});
374
+ TF_EXPECT_OK(BiasAddGradShape(&c));
375
+ ShapeHandle output = c.output(0);
376
+ EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
377
+ }
378
+
379
+ {
380
+ // NCHW format with input rank 3
381
+ TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
382
+ .Input("a", 0, DT_FLOAT)
383
+ .Attr("data_format", "NCHW")
384
+ .Finalize(&def));
385
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({10, 11, 12})},
386
+ {}, {}, {});
387
+ TF_EXPECT_OK(BiasAddGradShape(&c));
388
+ ShapeHandle output = c.output(0);
389
+ EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
390
+ }
391
+
392
+ {
393
+ // Input rank not high enough
394
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {},
395
+ {});
396
+ EXPECT_FALSE(BiasAddGradShape(&c).ok());
397
+ }
398
+
399
+ {
400
+ // NCHW rank not high enough
401
+ TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
402
+ .Input("a", 0, DT_FLOAT)
403
+ .Attr("data_format", "NCHW")
404
+ .Finalize(&def));
405
+ // NCHW format
406
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3})}, {}, {},
407
+ {});
408
+ EXPECT_FALSE(BiasAddGradShape(&c).ok());
409
+ }
410
+ }
411
+
412
+ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
413
+ ShapeInferenceTestOp op("Conv2D");
414
+ auto set_op = [&op](const std::vector<int32>& strides, const string& padding,
415
+ const string& data_format, const string& filter_format) {
416
+ TF_CHECK_OK(NodeDefBuilder("test", "Conv2D")
417
+ .Input("input", 0, DT_FLOAT)
418
+ .Input("filter", 0, DT_FLOAT)
419
+ .Attr("strides", strides)
420
+ .Attr("padding", padding)
421
+ .Attr("data_format", data_format)
422
+ .Attr("filter_format", filter_format)
423
+ .Finalize(&op.node_def));
424
+ };
425
+
426
+ // Invalid rank for input
427
+ INFER_ERROR("must be rank 4", op, "[4,4];[2,1,1,1]");
428
+ // Invalid rank for filter
429
+ INFER_ERROR("must be rank 4", op, "[1,4,4,1];[2,1,1]");
430
+
431
+ // Invalid value for strides
432
+ set_op({{1, 1, 0, 1}}, "VALID", "NHWC", "HWIO");
433
+ INFER_ERROR("must be > 0", op, "[1,2,2,1];[1,1,1,1]");
434
+
435
+ // 1x1 filter
436
+ set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO");
437
+ INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
438
+
439
+ // 2x2 filter
440
+ set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO");
441
+ INFER_OK(op, "[1,2,2,1];[2,2,1,1]", "[d0_0,1,1,d1_3]");
442
+
443
+ // 3x3 input, 1x1 filter, 2x2 stride
444
+ set_op({{1, 2, 2, 1}}, "VALID", "NHWC", "HWIO");
445
+ INFER_OK(op, "[1,3,3,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
446
+
447
+ // 3x3 input, 1x1 filter, 2x1 stride
448
+ set_op({{1, 2, 1, 1}}, "VALID", "NHWC", "HWIO");
449
+ INFER_OK(op, "[1,3,3,1];[1,1,1,1]", "[d0_0,2,3,d1_3]");
450
+
451
+ // 4x4 input, 2x1 filter, 1x2 stride
452
+ set_op({{1, 1, 2, 1}}, "VALID", "NHWC", "HWIO");
453
+ INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]");
454
+
455
+ // Unknown dims in the critical fields lead to partial inference.
456
+ INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]");
457
+ INFER_OK(op, "[1,?,4,1];[2,1,1,1]", "[d0_0,?,2,d1_3]");
458
+ INFER_OK(op, "[1,4,?,1];[2,1,1,1]", "[d0_0,3,?,d1_3]");
459
+ INFER_OK(op, "[1,4,4,?];[2,1,1,1]", "[d0_0,3,2,d1_3]");
460
+ INFER_OK(op, "[1,4,4,1];[?,1,1,1]", "[d0_0,?,2,d1_3]");
461
+ INFER_OK(op, "[1,4,4,1];[2,?,1,1]", "[d0_0,3,?,d1_3]");
462
+
463
+ // input depths must match.
464
+ INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op,
465
+ "[1,2,2,10];[1,1,10000,20]");
466
+
467
+ // Tests for NCHW
468
+ // 1x1 filter
469
+ set_op({{1, 1, 1, 1}}, "VALID", "NCHW", "HWIO");
470
+ INFER_OK(op, "[1,1,2,2];[1,1,1,1]", "[d0_0,d1_3,2,2]");
471
+
472
+ // 2x2 filter
473
+ set_op({{1, 1, 1, 1}}, "VALID", "NCHW", "HWIO");
474
+ INFER_OK(op, "[1,1,2,2];[2,2,1,1]", "[d0_0,d1_3,1,1]");
475
+
476
+ // 3x3 input, 1x1 filter, 2x2 stride
477
+ set_op({{1, 1, 2, 2}}, "VALID", "NCHW", "HWIO");
478
+ INFER_OK(op, "[1,1,3,3];[1,1,1,1]", "[d0_0,d1_3,2,2]");
479
+
480
+ // 3x3 input, 1x1 filter, 2x1 stride
481
+ set_op({{1, 1, 2, 1}}, "VALID", "NCHW", "HWIO");
482
+ INFER_OK(op, "[1,1,3,3];[1,1,1,1]", "[d0_0,d1_3,2,3]");
483
+
484
+ // 4x4 input, 2x1 filter, 1x2 stride
485
+ set_op({{1, 1, 1, 2}}, "VALID", "NCHW", "HWIO");
486
+ INFER_OK(op, "[1,1,4,4];[2,1,1,1]", "[d0_0,d1_3,3,2]");
487
+
488
+ // Tests for NCHW_VECT_C
489
+ // 1x1 filter
490
+ set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
491
+ INFER_OK(op, "[1,1,2,2,4];[4,1,1,1,4]", "[d0_0,1,2,2,4]");
492
+
493
+ // 2x2 filter
494
+ set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
495
+ INFER_OK(op, "[1,1,2,2,4];[4,1,2,2,4]", "[d0_0,1,1,1,4]");
496
+
497
+ // 3x3 input, 1x1 filter, 2x2 stride
498
+ set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
499
+ INFER_OK(op, "[1,1,3,3,4];[8,1,1,1,4]", "[d0_0,2,2,2,4]");
500
+
501
+ // 3x3 input, 1x1 filter, 2x1 stride
502
+ set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
503
+ INFER_OK(op, "[1,1,3,3,4];[4,1,1,1,4]", "[d0_0,1,2,3,4]");
504
+
505
+ // 4x4 input, 2x1 filter, 1x2 stride
506
+ set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
507
+ INFER_OK(op, "[1,1,4,4,4];[4,1,2,1,4]", "[d0_0,1,3,2,4]");
508
+
509
+ // Some tests for "SAME" padding
510
+
511
+ // 4x4 input, 1x1 filter, 1x1 stride
512
+ set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO");
513
+ INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
514
+
515
+ // 3x3 input, 2x2 filter, 1x1 stride
516
+ set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO");
517
+ INFER_OK(op, "[1,3,3,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
518
+
519
+ // 4x4 input, 2x2 filter, 2x2 stride
520
+ set_op({{1, 2, 2, 1}}, "SAME", "NHWC", "HWIO");
521
+ INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,2,2,d1_3]");
522
+
523
+ // 4x4 input, 2x2 filter, 1x1 stride
524
+ set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO");
525
+ INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
526
+
527
+ // With stride 1x1 and SAME, unknown dims don't matter - filter dims except
528
+ // for output channels are ignored for output, so all inputs are carried
529
+ // through to output.
530
+ set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO");
531
+ INFER_OK(op, "[1,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
532
+ INFER_OK(op, "[1,?,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
533
+ INFER_OK(op, "[1,4,?,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
534
+ INFER_OK(op, "[1,4,4,?];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
535
+ INFER_OK(op, "[?,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
536
+
537
+ // With stride != 1, the input HW dims are divided to produce output dims.
538
+ set_op({{1, 2, 2, 1}}, "SAME", "NHWC", "HWIO");
539
+ INFER_OK(op, "[?,4,4,1];[?,?,?,?]", "[d0_0,2,2,d1_3]");
540
+ INFER_OK(op, "[1,?,4,1];[?,?,?,?]", "[d0_0,?,2,d1_3]");
541
+ INFER_OK(op, "[1,4,?,1];[?,?,?,?]", "[d0_0,2,?,d1_3]");
542
+ INFER_OK(op, "[1,4,4,?];[?,?,?,?]", "[d0_0,2,2,d1_3]");
543
+ }
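// Editorial sketch (hedged): every expectation above follows the standard
// convolution output-size arithmetic, summarized by this illustrative
// helper. Nothing here is part of the original test file.
//   VALID: output = ceil((input - (filter - 1) * dilation) / stride)
//   SAME:  output = ceil(input / stride)
inline int64 ConvOutputSizeSketch(int64 input, int64 filter, int64 dilation,
                                  int64 stride, bool same_padding) {
  const int64 effective_filter = (filter - 1) * dilation + 1;
  if (same_padding) return (input + stride - 1) / stride;
  return (input - effective_filter + stride) / stride;  // VALID
}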
544
+
545
+ TEST(CommonShapeFnsTest, Conv2DDilatedShapeTest) {
546
+ ShapeInferenceTestOp op("Conv2D");
547
+ auto set_op = [&op](const std::vector<int32>& dilations,
548
+ const std::vector<int32>& strides, const string& padding,
549
+ const string& data_format) {
550
+ TF_CHECK_OK(NodeDefBuilder("test", "Conv2D")
551
+ .Input("input", 0, DT_FLOAT)
552
+ .Input("filter", 0, DT_FLOAT)
553
+ .Attr("dilations", dilations)
554
+ .Attr("strides", strides)
555
+ .Attr("padding", padding)
556
+ .Attr("data_format", data_format)
557
+ .Finalize(&op.node_def));
558
+ };
559
+
560
+ // Invalid rank for dilation
561
+ set_op({{1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
562
+ INFER_ERROR("contain 4 values", op, "[1,2,2,1];[1,1,1,1]");
563
+
564
+ // Invalid value for dilation
565
+ set_op({{1, 0, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
566
+ INFER_ERROR("must be >= 1", op, "[1,2,2,1];[1,1,1,1]");
567
+
568
+ // Tests for NHWC
569
+ // 1x1 filter, 2x1 dilations, 1x1 strides
570
+ set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
571
+ INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
572
+
573
+ // 1x1 filter, 2x1 dilations, 2x1 strides
574
+ set_op({{1, 2, 1, 1}}, {{1, 2, 1, 1}}, "VALID", "NHWC");
575
+ INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,2,4,d1_3]");
576
+
577
+ // 1x1 filter, 2x1 dilations, 2x2 strides
578
+ set_op({{1, 2, 1, 1}}, {{1, 2, 2, 1}}, "VALID", "NHWC");
579
+ INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
580
+
581
+ // 3x3 filter, 2x1 dilations, 1x1 strides
582
+ set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
583
+ INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,1,3,d1_3]");
584
+
585
+ // 3x3 filter, 2x1 dilations, 2x1 strides
586
+ set_op({{1, 2, 1, 1}}, {{1, 2, 1, 1}}, "VALID", "NHWC");
587
+ INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,1,3,d1_3]");
588
+
589
+ // 3x3 filter, 1x2 dilations, 2x2 strides
590
+ set_op({{1, 1, 2, 1}}, {{1, 2, 2, 1}}, "VALID", "NHWC");
591
+ INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,2,1,d1_3]");
592
+
593
+ // Tests for NCHW
594
+ // 1x1 filter, 2x1 dilations, 1x1 strides
595
+ set_op({{1, 1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NCHW");
596
+ INFER_OK(op, "[1,1,2,2];[1,1,1,1]", "[d0_0,d1_3,2,2]");
597
+
598
+ // 1x1 filter, 2x1 dilations, 2x1 strides
599
+ set_op({{1, 1, 2, 1}}, {{1, 1, 2, 1}}, "VALID", "NCHW");
600
+ INFER_OK(op, "[1,1,4,4];[1,1,1,1]", "[d0_0,d1_3,2,4]");
601
+
602
+ // 1x1 filter, 2x1 dilations, 2x2 strides
603
+ set_op({{1, 1, 2, 1}}, {{1, 1, 2, 2}}, "VALID", "NCHW");
604
+ INFER_OK(op, "[1,1,4,4];[1,1,1,1]", "[d0_0,d1_3,2,2]");
605
+
606
+ // 3x3 filter, 2x1 dilations, 1x1 strides
607
+ set_op({{1, 1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NCHW");
608
+ INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,1,3]");
609
+
610
+ // 3x3 filter, 2x1 dilations, 2x1 strides
611
+ set_op({{1, 1, 2, 1}}, {{1, 1, 2, 1}}, "VALID", "NCHW");
612
+ INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,1,3]");
613
+
614
+ // 3x3 filter, 1x2 dilations, 2x2 strides
615
+ set_op({{1, 1, 1, 2}}, {{1, 1, 2, 2}}, "VALID", "NCHW");
616
+ INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,2,1]");
617
+
618
+ // Some tests for "SAME" padding
619
+
620
+ // 4x4 input, 1x1 filter, 2x1 dilations, 1x1 stride
621
+ set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC");
622
+ INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
623
+
624
+ // 3x3 input, 2x2 filter, 2x2 dilations, 1x1 stride
625
+ set_op({{1, 2, 2, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC");
626
+ INFER_OK(op, "[1,3,3,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
627
+
628
+ // 4x4 input, 2x2 filter, 1x2 dilations, 2x2 stride
629
+ set_op({{1, 1, 2, 1}}, {{1, 2, 2, 1}}, "SAME", "NHWC");
630
+ INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,2,2,d1_3]");
631
+
632
+ // 4x4 input, 2x2 filter, 2x2 dilations, 1x1 stride
633
+ set_op({{1, 2, 2, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC");
634
+ INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
635
+ }
636
+
637
+ TEST(CommonShapeFnsTest, Conv3DShapeTest) {
638
+ ShapeInferenceTestOp op("Conv3D");
639
+ auto set_op = [&op](const std::vector<int32>& strides,
640
+ const string& padding) {
641
+ TF_CHECK_OK(NodeDefBuilder("test", "Conv3D")
642
+ .Input("input", 0, DT_FLOAT)
643
+ .Input("filter", 0, DT_FLOAT)
644
+ .Attr("strides", strides)
645
+ .Attr("padding", padding)
646
+ .Finalize(&op.node_def));
647
+ };
648
+
649
+ // 1x1x1 filter
650
+ set_op({{1, 1, 1, 1, 1}}, "VALID");
651
+ INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]");
652
+
653
+ // Invalid rank for input
654
+ INFER_ERROR("must be rank 5", op, "[4,4];[2,1,1,1]");
655
+ // Invalid rank for filter
656
+ INFER_ERROR("must be rank 5", op, "[1,4,4,1];[2,1,1]");
657
+
658
+ // Unknown dims in the critical fields give partial inference.
659
+ INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]");
660
+ INFER_OK(op, "[1,?,2,2,1];[1,1,1,1,1]", "[d0_0,?,2,2,d1_4]");
661
+ INFER_OK(op, "[1,2,?,2,1];[1,1,1,1,1]", "[d0_0,2,?,2,d1_4]");
662
+ INFER_OK(op, "[1,2,2,?,1];[1,1,1,1,1]", "[d0_0,2,2,?,d1_4]");
663
+ INFER_OK(op, "[1,2,2,2,1];[?,1,1,1,1]", "[d0_0,?,2,2,d1_4]");
664
+ INFER_OK(op, "[1,2,2,2,1];[1,?,1,1,1]", "[d0_0,2,?,2,d1_4]");
665
+ INFER_OK(op, "[1,2,2,2,1];[1,1,?,1,1]", "[d0_0,2,2,?,d1_4]");
666
+ INFER_OK(op, "[1,2,2,2,1];[1,1,1,?,1]", "[d0_0,2,2,2,d1_4]");
667
+ INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,?]", "[d0_0,2,2,2,d1_4]");
668
+
669
+ // input depths must match.
670
+ INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op,
671
+ "[1,2,2,2,10];[1,1,1,10000,20]");
672
+
673
+ // 2x2x2 filter
674
+ set_op({{1, 1, 1, 1, 1}}, "VALID");
675
+ INFER_OK(op, "[1,2,2,2,1];[2,2,2,1,1]", "[d0_0,1,1,1,d1_4]");
676
+
677
+ // 3x3 input, 1x1 filter, 2x2 stride
678
+ set_op({{1, 2, 2, 2, 1}}, "VALID");
679
+ INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]");
680
+
681
+ // 3x3 input, 1x1 filter, 2x1x1 stride
682
+ set_op({{1, 2, 1, 1, 1}}, "VALID");
683
+ INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,3,3,d1_4]");
684
+
685
+ // 4x4 input, 2x2 filter, 1x1 stride
686
+ set_op({{1, 1, 1, 1, 1}}, "SAME");
687
+ INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
688
+
689
+ // With SAME, filter dims don't matter except for the last (channel) dim.
690
+ set_op({{1, 1, 1, 1, 1}}, "SAME");
691
+ INFER_OK(op, "[?,4,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
692
+ INFER_OK(op, "[1,?,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
693
+ INFER_OK(op, "[1,4,?,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
694
+ INFER_OK(op, "[1,4,4,?,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
695
+ INFER_OK(op, "[1,4,4,4,?];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
696
+ INFER_OK(op, "[1,4,4,4,1];[?,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
697
+ INFER_OK(op, "[1,4,4,4,1];[2,?,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
698
+ INFER_OK(op, "[1,4,4,4,1];[2,2,?,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
699
+ INFER_OK(op, "[1,4,4,4,1];[2,2,2,?,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
700
+ INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,?]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
701
+
702
+ // With SAME and stride != 1, the input dims are divided to produce output dims.
703
+ set_op({{1, 2, 3, 4, 1}}, "SAME");
704
+ INFER_OK(op, "[1,4,9,4,1];[2,2,2,1,1]", "[d0_0,2,3,1,d1_4]");
705
+ INFER_OK(op, "[?,4,9,4,1];[2,2,2,1,1]", "[d0_0,2,3,1,d1_4]");
706
+ INFER_OK(op, "[1,?,9,4,1];[2,2,2,1,1]", "[d0_0,?,3,1,d1_4]");
707
+ INFER_OK(op, "[1,4,?,4,1];[2,2,2,1,1]", "[d0_0,2,?,1,d1_4]");
708
+ INFER_OK(op, "[1,4,9,?,1];[2,2,2,1,1]", "[d0_0,2,3,?,d1_4]");
709
+ INFER_OK(op, "[1,4,9,4,?];[2,2,2,1,1]", "[d0_0,2,3,1,d1_4]");
710
+ INFER_OK(op, "[1,4,9,4,1];[?,2,2,1,1]", "[d0_0,2,3,1,d1_4]");
711
+ INFER_OK(op, "[1,4,9,4,1];[2,?,2,1,1]", "[d0_0,2,3,1,d1_4]");
712
+ INFER_OK(op, "[1,4,9,4,1];[2,2,?,1,1]", "[d0_0,2,3,1,d1_4]");
713
+ INFER_OK(op, "[1,4,9,4,1];[2,2,2,?,1]", "[d0_0,2,3,1,d1_4]");
714
+ INFER_OK(op, "[1,4,9,4,1];[2,2,2,1,?]", "[d0_0,2,3,1,d1_4]");
715
+ }
716
+
717
+ TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) {
718
+ ShapeInferenceTestOp op("DepthwiseConv2dNative");
719
+ std::vector<int32> strides = {{1, 1, 1, 1}};
720
+ TF_CHECK_OK(NodeDefBuilder("test", "DepthwiseConv2dNative")
721
+ .Input("input", 0, DT_FLOAT)
722
+ .Input("filter", 0, DT_FLOAT)
723
+ .Attr("strides", strides)
724
+ .Attr("padding", "VALID")
725
+ .Attr("data_format", "NHWC")
726
+ .Finalize(&op.node_def));
727
+
728
+ // Most of DepthwiseConv2D is implicitly tested by Conv2D, so
729
+ // we test only the very-specific differences here.
730
+
731
+ // 1x1 filter, depth multiplication
732
+ INFER_OK(op, "[1,2,2,3];[1,1,3,4]", "[d0_0,2,2,12]");
733
+
734
+ // Input depths not compatible
735
+ INFER_ERROR("Dimensions must be equal, but are 3 and 12", op,
736
+ "[1,2,2,3];[1,1,12,4]");
737
+
738
+ // Unknown dims in the critical fields lead to partial inference.
739
+ INFER_OK(op, "[1,2,2,3];[1,1,3,4]", "[d0_0,2,2,12]");
740
+ INFER_OK(op, "[1,?,2,3];[1,1,3,4]", "[d0_0,?,2,12]");
741
+ INFER_OK(op, "[1,2,?,3];[1,1,3,4]", "[d0_0,2,?,12]");
742
+ INFER_OK(op, "[1,2,2,3];[?,1,3,4]", "[d0_0,?,2,12]");
743
+ INFER_OK(op, "[1,2,2,3];[1,?,3,4]", "[d0_0,2,?,12]");
744
+ INFER_OK(op, "[1,2,2,3];[1,1,?,4]", "[d0_0,2,2,12]");
745
+ INFER_OK(op, "[1,2,2,?];[1,1,?,4]", "[d0_0,2,2,?]");
746
+ INFER_OK(op, "[1,2,2,3];[1,1,3,?]", "[d0_0,2,2,?]");
747
+
748
+ // Test for NCHW format.
749
+ TF_CHECK_OK(NodeDefBuilder("test", "DepthwiseConv2dNative")
750
+ .Input("input", 0, DT_FLOAT)
751
+ .Input("filter", 0, DT_FLOAT)
752
+ .Attr("strides", strides)
753
+ .Attr("padding", "VALID")
754
+ .Attr("data_format", "NCHW")
755
+ .Finalize(&op.node_def));
756
+
757
+ // 1x1 filter, depth multiplication
758
+ INFER_OK(op, "[1,3,2,2];[1,1,3,4]", "[d0_0,12,2,2]");
759
+ }
760
+
761
+ TEST(CommonShapeFnsTest, AvgPool2DShapeTest) {
762
+ ShapeInferenceTestOp op("AvgPool");
763
+ auto set_op = [&op](const std::vector<int32>& strides,
764
+ const std::vector<int32>& ksizes, const string& padding,
765
+ const string& data_format) {
766
+ TF_CHECK_OK(NodeDefBuilder("test", "AvgPool")
767
+ .Input("input", 0, DT_FLOAT)
768
+ .Attr("strides", strides)
769
+ .Attr("ksize", ksizes)
770
+ .Attr("padding", padding)
771
+ .Attr("data_format", data_format)
772
+ .Finalize(&op.node_def));
773
+ };
774
+
775
+ // Most of the functionality is tested by conv-like shapes,
776
+ // so we check the very-specific avgpooling features here.
777
+
778
+ // 1x1 filter, 1x1 stride
779
+ set_op({1, 1, 1, 1}, {1, 1, 1, 1}, "VALID", "NHWC");
780
+ INFER_OK(op, "[1,2,2,1]", "[d0_0,2,2,d0_3]");
781
+
782
+ // 4x4 input, 2x1 ksize, 1x2 stride
783
+ set_op({1, 1, 2, 1}, {1, 2, 1, 1}, "VALID", "NHWC");
784
+ INFER_OK(op, "[1,4,4,1]", "[d0_0,3,2,d0_3]");
785
+
786
+ // 4x4 input, 2x1 ksize, 1x2 stride
787
+ // unknown dims in the critical fields lead to partial inference.
788
+ // Assumes NHWC format.
789
+ INFER_OK(op, "[1,?,4,1]", "[d0_0,?,2,d0_3]");
790
+ INFER_OK(op, "[1,4,?,1]", "[d0_0,3,?,d0_3]");
791
+
792
+ // 4x4 input, 2x1 ksize, 1x2 stride, NCHW format
793
+ set_op({{1, 1, 1, 2}}, {1, 1, 2, 1}, "VALID", "NCHW");
794
+ INFER_OK(op, "[1,1,4,4]", "[d0_0,d0_1,3,2]");
795
+
796
+ // 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C test
797
+ set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "VALID", "NCHW_VECT_C");
798
+ INFER_OK(op, "[2,3,5,7,4]", "[d0_0,d0_1,4,6,4]");
799
+ INFER_OK(op, "[5,7,?,?,4]", "[d0_0,d0_1,?,?,4]");
800
+ INFER_OK(op, "[?,?,?,?,4]", "[d0_0,d0_1,?,?,4]");
801
+ INFER_ERROR("Dimension must be 4 but is 3", op, "[2,5,7,11,3]");
802
+
803
+ // Invalid rank for input
804
+ INFER_ERROR("Shape must be rank", op, "[4,4]");
805
+ }
806
+
807
+ TEST(CommonShapeFnsTest, MaxPool2DShapeTest) {
808
+ ShapeInferenceTestOp op("MaxPool");
809
+ auto set_op = [&op](const std::vector<int32>& strides,
810
+ const std::vector<int32>& ksizes, const string& padding,
811
+ const string& data_format) {
812
+ TF_CHECK_OK(NodeDefBuilder("test", "MaxPool")
813
+ .Input("input", 0, DT_FLOAT)
814
+ .Attr("strides", strides)
815
+ .Attr("ksize", ksizes)
816
+ .Attr("padding", padding)
817
+ .Attr("data_format", data_format)
818
+ .Finalize(&op.node_def));
819
+ };
820
+
821
+ // Most of the functionality is tested by conv-like shapes,
822
+ // so we check the very-specific maxpooling features here,
823
+ // namely depthwise kernel and striding.
824
+
825
+ // all 1 strides, depth 2 filter
826
+ set_op({1, 1, 1, 1}, {1, 1, 1, 2}, "VALID", "NHWC");
827
+ INFER_OK(op, "[1,2,2,2]", "[d0_0,2,2,1]");
828
+
829
+ // depth 3 stride, 1x1x1 filter, NCHW
830
+ set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW");
831
+ INFER_OK(op, "[1,7,5,5]", "[d0_0,3,5,5]");
832
+
833
+ // 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C tests
834
+ set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "SAME", "NCHW_VECT_C");
835
+ INFER_OK(op, "[2,3,5,7,4]", "[d0_0,d0_1,d0_2,d0_3,4]");
836
+ INFER_OK(op, "[5,7,?,?,4]", "[d0_0,d0_1,d0_2,d0_3,4]");
837
+ INFER_OK(op, "[?,?,?,?,4]", "[d0_0,d0_1,d0_2,d0_3,4]");
838
+ INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8]");
839
+ }
840
+
841
+ TEST(CommonShapeFnsTest, MaxPoolV22DShapeTest) {
842
+ ShapeInferenceTestOp op("MaxPoolV2");
843
+ Tensor ksizes_tensor, strides_tensor;
844
+ auto set_op = [&op, &ksizes_tensor, &strides_tensor](
845
+ const std::vector<int32>& strides,
846
+ const std::vector<int32>& ksizes, const string& padding,
847
+ const string& data_format) {
848
+ TF_CHECK_OK(NodeDefBuilder("test", "MaxPoolV2")
849
+ .Input("input", 0, DT_FLOAT)
850
+ .Input("ksize", 1, DT_INT32)
851
+ .Input("strides", 2, DT_INT32)
852
+ .Attr("padding", padding)
853
+ .Attr("data_format", data_format)
854
+ .Finalize(&op.node_def));
855
+ ksizes_tensor = test::AsTensor<int32>(ksizes);
856
+ op.input_tensors.resize(3);
857
+ op.input_tensors[0] = nullptr;
858
+ op.input_tensors[1] = &ksizes_tensor;
859
+ strides_tensor = test::AsTensor<int32>(strides);
860
+ op.input_tensors[2] = &strides_tensor;
861
+ };
862
+
863
+ // Most of the functionality is tested by conv-like shapes,
864
+ // so we check the very-specific maxpooling features here,
865
+ // namely depthwise kernel and striding.
866
+
867
+ // all 1 strides, depth 2 filter
868
+ set_op({1, 1, 1, 1}, {1, 1, 1, 2}, "VALID", "NHWC");
869
+ INFER_OK(op, "[1,2,2,2];[4];[4]", "[d0_0,2,2,1]");
870
+
871
+ // depth 3 stride, 1x1x1 filter, NCHW
872
+ set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW");
873
+ INFER_OK(op, "[1,7,5,5];[4];[4]", "[d0_0,3,5,5]");
874
+
875
+ // 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C tests
876
+ set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "SAME", "NCHW_VECT_C");
877
+ INFER_OK(op, "[2,3,5,7,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]");
878
+ INFER_OK(op, "[5,7,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]");
879
+ INFER_OK(op, "[?,?,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]");
880
+ INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8];[4];[4]");
881
+ }
882
+
883
+ TEST(CommonShapeFnsTest, Pool3DShapeTest) {
884
+ ShapeInferenceTestOp op("MaxPool3D");
885
+ auto set_op = [&op](const std::vector<int32>& strides,
886
+ const std::vector<int32>& ksizes, const string& padding) {
887
+ TF_CHECK_OK(NodeDefBuilder("test", "MaxPool3D")
888
+ .Input("input", 0, DT_FLOAT)
889
+ .Attr("strides", strides)
890
+ .Attr("ksize", ksizes)
891
+ .Attr("padding", padding)
892
+ .Finalize(&op.node_def));
893
+ };
894
+
895
+ // Most of the functionality is tested by conv-like shapes,
896
+ // so we check that we handle the extra dimension properly.
897
+
898
+ // 2x3x4 stride, 1x1x1 filter.
899
+ set_op({1, 2, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID");
900
+ INFER_OK(op, "[1,24,24,24,1]", "[d0_0,12,8,6,d0_4]");
901
+
902
+ // Test partially known dimensions
903
+ set_op({1, 1, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID");
904
+ INFER_OK(op, "[1,?,24,24,1]", "[d0_0,?,8,6,d0_4]");
905
+ }
906
+
907
+ TEST(CommonShapeFnsTest, UnknownShapeTest) {
908
+ {
909
+ // Single output
910
+ ShapeInferenceTestOp op("QueueDequeue");
911
+ TF_CHECK_OK(NodeDefBuilder("test", "QueueDequeue")
912
+ .Input("handle", 0, DT_STRING_REF)
913
+ .Attr("component_types", {DT_FLOAT})
914
+ .Finalize(&op.node_def));
915
+ INFER_OK(op, "[1]", "?");
916
+ }
917
+
918
+ {
919
+ // Multiple outputs
920
+ ShapeInferenceTestOp op("QueueDequeue");
921
+ TF_CHECK_OK(NodeDefBuilder("test", "QueueDequeue")
922
+ .Input("handle", 0, DT_STRING_REF)
923
+ .Attr("component_types", {DT_FLOAT, DT_FLOAT, DT_STRING})
924
+ .Finalize(&op.node_def));
925
+ INFER_OK(op, "[1]", "?;?;?");
926
+ }
927
+ }
928
+
929
+ TEST(CommonShapeFnsTest, Reduce_ShapeFn) {
930
+ ShapeInferenceTestOp op("Sum");
931
+ op.input_tensors.resize(2);
932
+
933
+ TF_ASSERT_OK(NodeDefBuilder("test", "Sum")
934
+ .Input("input", 0, DT_FLOAT)
935
+ .Input("reduction_indices", 1, DT_INT32)
936
+ .Attr("keep_dims", false)
937
+ .Finalize(&op.node_def));
938
+
939
+ // Reduction indices not available, so output is unknown.
940
+ INFER_OK(op, "[2,4,5];[2]", "?");
941
+ INFER_OK(op, "?;[2]", "?");
942
+
943
+ Tensor indices = test::AsTensor<int32>({1, 2});
944
+ op.input_tensors[1] = &indices;
945
+
946
+ // Reduction indices available
947
+ INFER_OK(op, "[2,4,5];[2]", "[d0_0]");
948
+
949
+ // Wrapped indices
950
+ indices = test::AsTensor<int32>({-1, -2});
951
+ op.input_tensors[1] = &indices;
952
+ INFER_OK(op, "[2,4,5];[2]", "[d0_0]");
953
+
954
+ // Scalar
955
+ indices = test::AsScalar<int32>(0);
956
+ op.input_tensors[1] = &indices;
957
+ INFER_OK(op, "[2,4,5];[]", "[d0_1,d0_2]");
958
+
959
+ indices = test::AsScalar<int32>(-4);
960
+ op.input_tensors[1] = &indices;
961
+ INFER_ERROR("Invalid reduction dimension", op, "[2,4,5];[]");
962
+
963
+ // Empty reduction indices
964
+ indices = test::AsTensor<int32>({});
965
+ op.input_tensors[1] = &indices;
966
+ INFER_OK(op, "[2,4,5];[0]", "[d0_0,d0_1,d0_2]");
967
+
968
+ // Keep dims = true
969
+ TF_ASSERT_OK(NodeDefBuilder("test", "Sum")
970
+ .Input("input", 0, DT_FLOAT)
971
+ .Input("reduction_indices", 1, DT_INT32)
972
+ .Attr("keep_dims", true)
973
+ .Finalize(&op.node_def));
974
+ indices = test::AsTensor<int32>({-1, -2});
975
+ op.input_tensors[1] = &indices;
976
+ INFER_OK(op, "[2,4,5];[2]", "[d0_0, 1, 1]");
977
+
978
+ // input rank is known, but reduction indices are not (with keep_dim=true).
979
+ // The output rank matches input rank (because of keep_dims=true).
980
+ op.input_tensors[1] = nullptr;
981
+ INFER_OK(op, "[?,?,?];?", "[?,?,?]");
982
+ INFER_OK(op, "[?,?,?];[2]", "[?,?,?]");
983
+
984
+ // Reduction indices with too many dimensions.
985
+ INFER_ERROR("must be at most rank 1 but is rank 2", op, "[?,?,?];[?,?]");
986
+ // With older graph-def version, this is allowed.
987
+ op.graph_def_version = 20;
988
+ INFER_OK(op, "[?,?,?];[?,?]", "[?,?,?]");
989
+ // And when the tensor is specified, it's still allowed.
990
+ op.input_tensors[1] = &indices;
991
+ indices = test::AsTensor<int32>({-1, -2}, TensorShape({2, 1}));
992
+ INFER_OK(op, "[2,4,5];[2,1]", "[d0_0, 1, 1]");
993
+ indices = test::AsTensor<int32>({-1, -2}, TensorShape({1, 2}));
994
+ INFER_OK(op, "[2,4,5];[1,2]", "[d0_0, 1, 1]");
995
+ }
996
+
997
+ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
998
+ NodeDef def;
999
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
1000
+ {Unknown(), Unknown(), Unknown()}, {}, {}, {});
1001
+ EXPECT_EQ(3, c.num_inputs());
1002
+ EXPECT_EQ(1, c.num_outputs());
1003
+
1004
+ auto indices = c.input(0);
1005
+ auto values = c.input(1);
1006
+ auto shape = c.input(2);
1007
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
1008
+ }
1009
+
1010
+ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
1011
+ NodeDef def;
1012
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
1013
+ {S({-1, -1}), S({-1}), S({-1})}, {}, {}, {});
1014
+ EXPECT_EQ(3, c.num_inputs());
1015
+ EXPECT_EQ(1, c.num_outputs());
1016
+
1017
+ auto indices = c.input(0);
1018
+ auto values = c.input(1);
1019
+ auto shape = c.input(2);
1020
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
1021
+ }
1022
+
1023
+ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
1024
+ NodeDef def;
1025
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
1026
+ {S({-1}), S({-1}), S({-1})}, {}, {}, {});
1027
+ EXPECT_EQ(3, c.num_inputs());
1028
+ EXPECT_EQ(1, c.num_outputs());
1029
+
1030
+ auto indices = c.input(0);
1031
+ auto values = c.input(1);
1032
+ auto shape = c.input(2);
1033
+ EXPECT_EQ(error::INVALID_ARGUMENT,
1034
+ ValidateSparseTensor(&c, indices, values, shape).code());
1035
+ }
1036
+
1037
+ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
1038
+ NodeDef def;
1039
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
1040
+ {S({5, 3}), S({4}), S({3})}, {}, {}, {});
1041
+ EXPECT_EQ(3, c.num_inputs());
1042
+ EXPECT_EQ(1, c.num_outputs());
1043
+
1044
+ auto indices = c.input(0);
1045
+ auto values = c.input(1);
1046
+ auto shape = c.input(2);
1047
+ EXPECT_EQ(error::INVALID_ARGUMENT,
1048
+ ValidateSparseTensor(&c, indices, values, shape).code());
1049
+ }
1050
+
1051
+ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
1052
+ NodeDef def;
1053
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
1054
+ {S({5, 3}), S({5}), S({4})}, {}, {}, {});
1055
+ EXPECT_EQ(3, c.num_inputs());
1056
+ EXPECT_EQ(1, c.num_outputs());
1057
+
1058
+ auto indices = c.input(0);
1059
+ auto values = c.input(1);
1060
+ auto shape = c.input(2);
1061
+ EXPECT_EQ(error::INVALID_ARGUMENT,
1062
+ ValidateSparseTensor(&c, indices, values, shape).code());
1063
+ }
1064
+
1065
+ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
1066
+ NodeDef def;
1067
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
1068
+ {S({-1, 3}), S({5}), S({3})}, {}, {}, {});
1069
+ EXPECT_EQ(3, c.num_inputs());
1070
+ EXPECT_EQ(1, c.num_outputs());
1071
+
1072
+ auto indices = c.input(0);
1073
+ auto values = c.input(1);
1074
+ auto shape = c.input(2);
1075
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
1076
+ }
1077
+
1078
+ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
1079
+ NodeDef def;
1080
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
1081
+ {S({5, 3}), S({-1}), S({3})}, {}, {}, {});
1082
+ EXPECT_EQ(3, c.num_inputs());
1083
+ EXPECT_EQ(1, c.num_outputs());
1084
+
1085
+ auto indices = c.input(0);
1086
+ auto values = c.input(1);
1087
+ auto shape = c.input(2);
1088
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
1089
+ }
1090
+
1091
+ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
1092
+ NodeDef def;
1093
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
1094
+ {S({5, -1}), S({5}), S({3})}, {}, {}, {});
1095
+ EXPECT_EQ(3, c.num_inputs());
1096
+ EXPECT_EQ(1, c.num_outputs());
1097
+
1098
+ auto indices = c.input(0);
1099
+ auto values = c.input(1);
1100
+ auto shape = c.input(2);
1101
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
1102
+ }
1103
+
1104
+ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
1105
+ NodeDef def;
1106
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
1107
+ {S({5, 3}), S({5}), S({-1})}, {}, {}, {});
1108
+ EXPECT_EQ(3, c.num_inputs());
1109
+ EXPECT_EQ(1, c.num_outputs());
1110
+
1111
+ auto indices = c.input(0);
1112
+ auto values = c.input(1);
1113
+ auto shape = c.input(2);
1114
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
1115
+ }
1116
+
1117
+ TEST(CommonShapeFnsTest, ValidateSparseTensor) {
1118
+ NodeDef def;
1119
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
1120
+ {S({5, 3}), S({5}), S({3})}, {}, {}, {});
1121
+ EXPECT_EQ(3, c.num_inputs());
1122
+ EXPECT_EQ(1, c.num_outputs());
1123
+
1124
+ auto indices = c.input(0);
1125
+ auto values = c.input(1);
1126
+ auto shape = c.input(2);
1127
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
1128
+ }
1129
+
1130
+ } // namespace shape_inference
1131
+ } // namespace tensorflow
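For readers new to the harness above: INFER_OK takes the op, a semicolon-separated list of input shapes ('?' is an unknown dimension or rank), and the expected output, where "d<i>_<j>" means "same dimension as input i, dimension j" and "in<i>" means "same shape handle as input i". A hedged sketch of one more test in the file's own style, exercising UnchangedShape through the Identity op (which TensorFlow registers with that shape function):

TEST(CommonShapeFnsTest, UnchangedShapeViaIdentitySketch) {
  ShapeInferenceTestOp op("Identity");
  TF_CHECK_OK(NodeDefBuilder("test", "Identity")
                  .Input("input", 0, DT_FLOAT)
                  .Finalize(&op.node_def));
  // "in0" asserts the output is the very shape handle of input 0, so known
  // and unknown shapes alike pass straight through.
  INFER_OK(op, "[2,3,4]", "in0");
  INFER_OK(op, "?", "in0");
}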
control_flow.h ADDED
@@ -0,0 +1,58 @@
1
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #ifndef TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
17
+ #define TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
18
+
19
+ #include "tensorflow/core/lib/hash/hash.h"
20
+ #include "tensorflow/core/platform/logging.h"
21
+ #include "tensorflow/core/platform/types.h"
22
+
23
+ namespace tensorflow {
24
+
25
+ const uint64 kIllegalFrameId = ~0uLL;
26
+ const int64 kIllegalIterId = -1;
27
+
28
+ // For the purpose of control flow, every tensor produced by TensorFlow is
29
+ // conceptually tagged by a 'FrameAndIter'. FrameAndIter consists of a
30
+ // 'frame_id' and an 'iter_id'. The tensor value it represents is produced
31
+ // in the frame with frame_id at the iteration of iter_id.
32
+ struct FrameAndIter {
33
+ uint64 frame_id = kIllegalFrameId;
34
+ int64 iter_id = kIllegalIterId;
35
+
36
+ FrameAndIter() {}
37
+
38
+ FrameAndIter(uint64 frame, int64 iter) {
39
+ frame_id = frame;
40
+ iter_id = iter;
41
+ }
42
+
43
+ bool operator==(const FrameAndIter& other) const {
44
+ return (frame_id == other.frame_id && iter_id == other.iter_id);
45
+ }
46
+ };
47
+
48
+ struct FrameAndIterHash {
49
+ size_t operator()(const FrameAndIter& key) const {
50
+ // Make sure there are no padding bytes that we don't want
51
+ CHECK_EQ(sizeof(uint64) + sizeof(int64), sizeof(FrameAndIter));
52
+ return Hash64(reinterpret_cast<const char*>(&key), sizeof(FrameAndIter));
53
+ }
54
+ };
55
+
56
+ } // namespace tensorflow
57
+
58
+ #endif // TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
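FrameAndIterHash exists precisely so FrameAndIter can key hash containers. A minimal usage sketch (the counter map and helper are invented for illustration; only the header above is assumed):

#include <unordered_map>

#include "tensorflow/core/framework/control_flow.h"

namespace example {

using FrameIterCounts =
    std::unordered_map<tensorflow::FrameAndIter, int,
                       tensorflow::FrameAndIterHash>;

// Records one more tensor produced in the given frame and iteration.
void CountTensor(FrameIterCounts* counts, tensorflow::uint64 frame_id,
                 tensorflow::int64 iter_id) {
  (*counts)[tensorflow::FrameAndIter(frame_id, iter_id)]++;
}

}  // namespace example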
cost_graph.proto ADDED
@@ -0,0 +1,72 @@
1
+ syntax = "proto3";
2
+
3
+ package tensorflow;
4
+ option cc_enable_arenas = true;
5
+ option java_outer_classname = "CostGraphProtos";
6
+ option java_multiple_files = true;
7
+ option java_package = "org.tensorflow.framework";
8
+
9
+ import "tensorflow/core/framework/tensor_shape.proto";
10
+ import "tensorflow/core/framework/types.proto";
11
+
12
+ message CostGraphDef {
13
+ message Node {
14
+ // The name of the node. Names are globally unique.
15
+ string name = 1;
16
+
17
+ // The device of the node. Can be empty if the node is mapped to the
18
+ // default partition or partitioning hasn't been run yet.
19
+ string device = 2;
20
+
21
+ // The id of the node. Node ids are only unique inside a partition.
22
+ int32 id = 3;
23
+
24
+ // Inputs of this node. They must be executed before this node can be
25
+ // executed. An input is a particular output of another node, specified
26
+ // by the node id and the output index.
27
+ message InputInfo {
28
+ int32 preceding_node = 1;
29
+ int32 preceding_port = 2;
30
+ }
31
+ repeated InputInfo input_info = 4;
32
+
33
+ // Outputs of this node.
34
+ message OutputInfo {
35
+ int64 size = 1;
36
+ // If >= 0, the output is an alias of an input. Note that an alias input
37
+ // may itself be an alias. The algorithm will therefore need to follow
38
+ // those pointers.
39
+ int64 alias_input_port = 2;
40
+ TensorShapeProto shape = 3;
41
+ DataType dtype = 4;
42
+ }
43
+ repeated OutputInfo output_info = 5;
44
+
45
+ // Temporary memory used by this node.
46
+ int64 temporary_memory_size = 6;
47
+
48
+ int64 host_temp_memory_size = 10;
49
+ int64 device_temp_memory_size = 11;
50
+ int64 host_persistent_memory_size = 12;
51
+ int64 device_persistent_memory_size = 16;
52
+
53
+ // Estimate of the computational cost of this node, in microseconds.
54
+ int64 compute_cost = 9;
55
+
56
+ // Analytical estimate of the computational cost of this node, in
57
+ // microseconds.
58
+ int64 compute_time = 14;
59
+
60
+ // Analytical estimate of the memory access cost of this node, in
61
+ // microseconds.
62
+ int64 memory_time = 15;
63
+
64
+ // If true, the output is permanent: it can't be discarded, because this
65
+ // node is part of the "final output". Nodes may depend on final nodes.
66
+ bool is_final = 7;
67
+
68
+ // Ids of the control inputs for this node.
69
+ repeated int32 control_input = 8;
70
+ }
71
+ repeated Node node = 1;
72
+ }
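A hedged sketch of consuming this message from C++. The generated-header path follows TensorFlow's usual proto layout, and the aggregation helper is invented; the accessors are the standard protobuf-generated ones:

#include <cstdint>

#include "tensorflow/core/framework/cost_graph.pb.h"

// Sums the per-node compute_cost estimates, in microseconds.
int64_t TotalComputeCost(const tensorflow::CostGraphDef& cost_graph) {
  int64_t total_us = 0;
  for (const tensorflow::CostGraphDef::Node& node : cost_graph.node()) {
    total_us += node.compute_cost();
  }
  return total_us;
}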
device_attributes.proto ADDED
@@ -0,0 +1,35 @@
1
+ syntax = "proto3";
2
+
3
+ package tensorflow;
4
+ option cc_enable_arenas = true;
5
+ option java_outer_classname = "DeviceAttributesProtos";
6
+ option java_multiple_files = true;
7
+ option java_package = "org.tensorflow.framework";
8
+
9
+ message DeviceLocality {
10
+ // Optional bus locality of device. Default value of 0 means
11
+ // no specific locality. Specific localities are indexed from 1.
12
+ int32 bus_id = 1;
13
+ };
14
+
15
+ message DeviceAttributes {
16
+ // Fully specified name of the device within a cluster.
17
+ string name = 1;
18
+
19
+ // String representation of device_type.
20
+ string device_type = 2;
21
+
22
+ // Memory capacity of device in bytes.
23
+ int64 memory_limit = 4;
24
+
25
+ // Platform-specific data about device that may be useful
26
+ // for supporting efficient data transfers.
27
+ DeviceLocality locality = 5;
28
+
29
+ // A device is assigned a global unique number each time it is
30
+ // initialized. "incarnation" should never be 0.
31
+ fixed64 incarnation = 6;
32
+
33
+ // String representation of the physical device that this device maps to.
34
+ string physical_device_desc = 7;
35
+ }
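To show how the fields fit together, a hedged sketch that populates a DeviceAttributes message through the protobuf-generated setters (every value below is made up):

#include "tensorflow/core/framework/device_attributes.pb.h"

tensorflow::DeviceAttributes MakeExampleCpuAttributes() {
  tensorflow::DeviceAttributes attr;
  attr.set_name("/job:localhost/replica:0/task:0/device:CPU:0");
  attr.set_device_type("CPU");
  attr.set_memory_limit(256LL << 20);      // 256 MiB, arbitrary.
  attr.mutable_locality()->set_bus_id(0);  // 0 = no specific locality.
  attr.set_incarnation(1);                 // Must never be 0, per above.
  return attr;
}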
device_base.cc ADDED
@@ -0,0 +1,30 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/device_base.h"
+
+ namespace tensorflow {
+
+ DeviceBase::~DeviceBase() {}
+
+ const DeviceAttributes& DeviceBase::attributes() const {
+ LOG(FATAL) << "Device does not implement attributes()";
+ }
+
+ const string& DeviceBase::name() const {
+ LOG(FATAL) << "Device does not implement name()";
+ }
+
+ } // namespace tensorflow
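
Both default implementations above terminate with LOG(FATAL), so any concrete device whose attributes or name will be queried must override them. A minimal subclass sketch (hypothetical, not part of this commit):

#include <utility>

#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"

// Serves attributes() and name() from a stored proto so the fatal base
// implementations are never reached.
class SimpleDevice : public tensorflow::DeviceBase {
 public:
  SimpleDevice(tensorflow::Env* env, tensorflow::DeviceAttributes attrs)
      : DeviceBase(env), attrs_(std::move(attrs)) {}

  const tensorflow::DeviceAttributes& attributes() const override {
    return attrs_;
  }
  const tensorflow::string& name() const override { return attrs_.name(); }

 private:
  tensorflow::DeviceAttributes attrs_;
};
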
device_base.h ADDED
@@ -0,0 +1,243 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
+ #define TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
+
+ #include <memory>
+ #include <string>
+ #include <unordered_map>
+
+ #include "tensorflow/core/framework/tensor.h"
+ #include "tensorflow/core/lib/core/errors.h"
+ #include "tensorflow/core/lib/core/refcount.h"
+ #include "tensorflow/core/lib/core/status.h"
+ #include "tensorflow/core/lib/core/stringpiece.h"
+ #include "tensorflow/core/platform/logging.h"
+
+ namespace Eigen {
+ struct ThreadPoolDevice;
+ #ifdef TENSORFLOW_USE_SYCL
+ struct SyclDevice;
+ #endif
+ } // end namespace Eigen
+
+ namespace perftools {
+ namespace gputools {
+ class Stream;
+ } // namespace gputools
+ } // namespace perftools
+
+ namespace tensorflow {
+
+ class Device;
+ class DeviceAttributes;
+ class Env;
+ class EventMgr;
+ class OpKernelContext;
+ class ResourceMgr;
+ class TensorProto;
+
+ namespace thread {
+ class ThreadPool;
+ }
+
+ // A wrapper for an Eigen Gpu Device that includes per-op state. The
+ // class is defined even for non-GPU devices since the
+ // OpKernelContext::Params structure wants to fill it in.
+ class PerOpGpuDevice {
+ public:
+ virtual ~PerOpGpuDevice() {}
+ virtual const Eigen::GpuDevice& device() const = 0;
+ };
+
+ // A class that devices can subclass to pass around
+ // Device-specific context to OpKernels.
+ class DeviceContext : public core::RefCounted {
+ public:
+ ~DeviceContext() override {}
+ virtual perftools::gputools::Stream* stream() const { return nullptr; }
+ virtual void MaintainLifetimeOnStream(
+ const Tensor* t, perftools::gputools::Stream* stream) const {}
+
+ // "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into
+ // "device_tensor" which is on a GPU device "device". "device_tensor"
+ // must be allocated to be of the same size as "cpu_tensor".
+ virtual void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
+ Tensor* device_tensor,
+ StatusCallback done) const {
+ done(errors::Internal("Unrecognized device type in CPU-to-device Copy"));
+ }
+
+ // "device_tensor" is a tensor on a non-CPU device. Copies
+ // device_tensor into "cpu_tensor". "cpu_tensor" must be allocated
+ // to be of the same size as "device_tensor".
+ virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor,
+ StringPiece tensor_name, Device* device,
+ Tensor* cpu_tensor, StatusCallback done) {
+ done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
+ }
+ };
+
+ // map[i] is the DeviceContext* for the node with id i, if i < map.size().
+ typedef std::vector<DeviceContext*> DeviceContextMap;
+
+ class DeviceBase {
+ public:
+ explicit DeviceBase(Env* env) : env_(env) {}
+ virtual ~DeviceBase();
+
+ Env* env() const { return env_; }
+
+ // Override this to return true for devices that require an Op's
+ // compute method to save references to the temporary tensors it
+ // allocates until the Op execution completes
+ virtual bool RequiresRecordingAccessedTensors() const { return false; }
+
+ struct CpuWorkerThreads {
+ int num_threads = 0;
+ thread::ThreadPool* workers = nullptr;
+ };
+
+ // Does not take ownership.
+ void set_tensorflow_cpu_worker_threads(CpuWorkerThreads* t) {
+ cpu_worker_threads_ = t;
+ }
+
+ virtual const CpuWorkerThreads* tensorflow_cpu_worker_threads() const {
+ CHECK(cpu_worker_threads_ != nullptr);
+ return cpu_worker_threads_;
+ }
+
+ // "stream" is used in special circumstances (such as the
+ // constructors of Ops) where there is no available OpKernelContext.
+ // "default_context" is used by OpKernelContext whenever a device does not
+ // supply a DeviceContext for an op in FillContextMap (e.g. when only
+ // using a single stream.)
+ // "event_mgr" is used to delay deallocation of temporary GPU buffers.
+ // TODO(pbar) Work out how to move this out of DeviceBase.
+ struct GpuDeviceInfo {
+ // Make sure all the defaults are NULL, so we can spot missing assignments.
+ perftools::gputools::Stream* stream = nullptr;
+ DeviceContext* default_context = nullptr;
+ EventMgr* event_mgr = nullptr;
+ int gpu_id = -1;
+ };
+
+ // Does not take ownership.
+ void set_tensorflow_gpu_device_info(GpuDeviceInfo* g) {
+ gpu_device_info_ = g;
+ }
+
+ virtual const GpuDeviceInfo* tensorflow_gpu_device_info() const {
+ return gpu_device_info_;
+ }
+
+ // The preferred thread pool for this device. If it is nullptr, the system
+ // automatically assigns a thread pool for execution.
+ virtual thread::ThreadPool* tensorflow_device_thread_pool() {
+ return device_thread_pool_;
+ }
+
+ // Does not take ownership.
+ void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) {
+ eigen_cpu_device_ = d;
+ }
+
+ #ifdef TENSORFLOW_USE_SYCL
+ void set_eigen_sycl_device(Eigen::SyclDevice* d) { eigen_sycl_device_ = d; }
+ #endif
+
+ // Return the Allocator implementation to use based on the allocator
+ // attributes requested. See allocator.h for more details.
+ virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {
+ LOG(FATAL) << "GetAllocator() is not implemented.";
+ return nullptr;
+ }
+
+ // Return the Allocator implementation to use based on the allocator
+ // attributes requested and the supplied resource manager. By
+ // default this ignores the resource manager and calls the base
+ // implementation but devices can override if they want to consult
+ // the resource manager when choosing the allocator.
+ virtual Allocator* GetStepAllocator(AllocatorAttributes attr,
+ ResourceMgr* /*step_resource_manager*/) {
+ return GetAllocator(attr);
+ }
+
+ virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() {
+ CHECK(eigen_cpu_device_ != nullptr);
+ return eigen_cpu_device_;
+ }
+
+ #ifdef TENSORFLOW_USE_SYCL
+ virtual const Eigen::SyclDevice* eigen_sycl_device() const {
+ CHECK(eigen_sycl_device_ != nullptr);
+ return eigen_sycl_device_;
+ }
+ #endif
+
+ // Caller owns the return value. The OpKernelContext calls this even
+ // for devices that do not implement an eigen_gpu_device. Overridden
+ // by GPU devices to return a derived type.
+ virtual PerOpGpuDevice* MakeGpuDevice() { return nullptr; }
+
+ virtual DeviceBase* UnderlyingDevice() { return this; }
+ virtual const DeviceBase* UnderlyingDevice() const { return this; }
+
+ // This is overridden by GPU devices to reinitialize the derived
+ // type returned by MakeGpuDevice.
+ virtual void ReinitializeGpuDevice(OpKernelContext* /*context*/,
+ PerOpGpuDevice* /*device*/,
+ DeviceContext* /*dc*/,
+ Allocator* /*allocator*/) {}
+
+ // Unimplemented by default
+ virtual const DeviceAttributes& attributes() const;
+ virtual const string& name() const;
+
+ // Materializes the given TensorProto into 'tensor' stored in Device
+ // memory. Most devices will want to override this.
+ //
+ // TODO(vrv): We should be able to put this function into
+ // OpKernelContext and handle the copies from device memory via send
+ // and receive nodes, instead of requiring that each device handle
+ // the copies here as well as in copy ops.
+ virtual Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) {
+ return errors::Internal("Device does not implement MakeTensorFromProto()");
+ }
+
+ protected:
+ // Does not take ownership.
+ void set_tensorflow_device_thread_pool(thread::ThreadPool* thread_pool) {
+ device_thread_pool_ = thread_pool;
+ }
+
+ private:
+ Env* const env_;
+ CpuWorkerThreads* cpu_worker_threads_ = nullptr;
+ GpuDeviceInfo* gpu_device_info_ = nullptr;
+ thread::ThreadPool* device_thread_pool_ = nullptr;
+ Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr;
+ #ifdef TENSORFLOW_USE_SYCL
+ Eigen::SyclDevice* eigen_sycl_device_ = nullptr;
+ #endif
+ };
+
+ } // namespace tensorflow
+
+ #endif // TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
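
Like attributes() and name(), GetAllocator() above is fatal unless overridden. A sketch of a host-only override, assuming the process-wide cpu_allocator() declared in tensorflow/core/framework/allocator.h (illustrative, not part of this commit):

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"

class HostDevice : public tensorflow::DeviceBase {
 public:
  using DeviceBase::DeviceBase;  // inherit DeviceBase(Env*)

  // A pure host device can ignore the requested attributes and always
  // hand out the shared CPU allocator.
  tensorflow::Allocator* GetAllocator(
      tensorflow::AllocatorAttributes /*attr*/) override {
    return tensorflow::cpu_allocator();
  }
};
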
fake_input.cc ADDED
@@ -0,0 +1,240 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/fake_input.h"
+
+ #include <vector>
+ #include "tensorflow/core/framework/attr_value.pb.h"
+ #include "tensorflow/core/framework/node_def_util.h"
+ #include "tensorflow/core/framework/op_def.pb.h"
+ #include "tensorflow/core/framework/op_def_util.h"
+ #include "tensorflow/core/lib/core/errors.h"
+ #include "tensorflow/core/lib/core/status.h"
+
+ namespace tensorflow {
+ namespace {
+
+ class FakeInputImpl {
+ public:
+ FakeInputImpl(const OpDef* op_def, int in_index, const NodeDef* node_def,
+ NodeDefBuilder* builder);
+ void SetN(int n);
+ void SetDataType(DataType dt);
+ void SetTypeList(DataTypeSlice dts);
+ Status AddInputToBuilder();
+
+ private:
+ static string FakeNodeName(int in_index);
+ Status GetN(int* n) const;
+ Status GetDataType(DataType* dt) const;
+ void NSources(int n, DataType dt) const;
+ void SourceList(DataTypeSlice dts) const;
+
+ const OpDef* const op_def_;
+ const OpDef::ArgDef* const arg_;
+ const string in_node_;
+ const NodeDef* const node_def_;
+ NodeDefBuilder* const builder_;
+
+ bool n_specified_;
+ int n_;
+ bool dt_specified_;
+ DataType dt_;
+ bool dts_specified_;
+ DataTypeSlice dts_;
+ };
+
+ FakeInputImpl::FakeInputImpl(const OpDef* op_def, int in_index,
+ const NodeDef* node_def, NodeDefBuilder* builder)
+ : op_def_(op_def),
+ arg_(&op_def->input_arg(in_index)),
+ in_node_(FakeNodeName(in_index)),
+ node_def_(node_def),
+ builder_(builder),
+ n_specified_(false),
+ dt_specified_(false),
+ dts_specified_(false) {}
+
+ void FakeInputImpl::SetN(int n) {
+ n_specified_ = true;
+ n_ = n;
+ }
+
+ void FakeInputImpl::SetDataType(DataType dt) {
+ dt_specified_ = true;
+ dt_ = dt;
+ }
+
+ void FakeInputImpl::SetTypeList(DataTypeSlice dts) {
+ dts_specified_ = true;
+ dts_ = dts;
+ }
+
+ Status FakeInputImpl::AddInputToBuilder() {
+ if (dts_specified_) {
+ SourceList(dts_);
+
+ } else if (n_specified_ || !arg_->number_attr().empty()) {
+ int n;
+ TF_RETURN_IF_ERROR(GetN(&n));
+
+ DataType dt;
+ if (n > 0) {
+ TF_RETURN_IF_ERROR(GetDataType(&dt));
+ } else {
+ dt = DT_FLOAT;
+ }
+
+ NSources(n, dt);
+ } else {
+ if (!dt_specified_ && !arg_->type_list_attr().empty()) {
+ DataTypeVector dts;
+ Status status = GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts);
+ if (!status.ok()) {
+ return errors::InvalidArgument(
+ "Could not infer list of types for input '", arg_->name(), "': ",
+ status.error_message());
+ }
+ SourceList(dts);
+ return Status::OK();
+ }
+
+ DataType dt;
+ TF_RETURN_IF_ERROR(GetDataType(&dt));
+ builder_->Input(in_node_, 0, dt);
+ }
+ return Status::OK();
+ }
+
+ // static
+ string FakeInputImpl::FakeNodeName(int in_index) {
+ char c = 'a' + (in_index % 26);
+ return string(&c, 1);
+ }
+
+ Status FakeInputImpl::GetN(int* n) const {
+ if (n_specified_) {
+ *n = n_;
+ } else {
+ Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n);
+ if (!status.ok()) {
+ return errors::InvalidArgument("Could not infer length of input '",
+ arg_->name(), "': ",
+ status.error_message());
+ }
+ }
+ return Status::OK();
+ }
+
+ Status FakeInputImpl::GetDataType(DataType* dt) const {
+ if (dt_specified_) {
+ *dt = dt_;
+ return Status::OK(); // Ignore is_ref field of arg_.
+ } else if (arg_->type() != DT_INVALID) {
+ *dt = arg_->type();
+ } else if (!arg_->type_attr().empty()) {
+ Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt);
+ if (!status.ok()) {
+ // Check if the type attr has a default
+ const OpDef::AttrDef* attr = FindAttr(arg_->type_attr(), *op_def_);
+ if (attr && attr->has_default_value()) {
+ *dt = attr->default_value().type();
+ } else {
+ return errors::InvalidArgument("Could not infer type for input '",
+ arg_->name(), "': ",
+ status.error_message());
+ }
+ }
+ } else {
+ return errors::InvalidArgument("No type or type_attr field in arg '",
+ arg_->name(), "'");
+ }
+ if (arg_->is_ref()) {
+ *dt = MakeRefType(*dt);
+ }
+ return Status::OK();
+ }
+
+ void FakeInputImpl::NSources(int n, DataType dt) const {
+ std::vector<NodeDefBuilder::NodeOut> srcs;
+ srcs.reserve(n);
+ for (int i = 0; i < n; ++i) {
+ srcs.emplace_back(in_node_, i, dt);
+ }
+ builder_->Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
+ }
+
+ void FakeInputImpl::SourceList(DataTypeSlice dts) const {
+ std::vector<NodeDefBuilder::NodeOut> srcs;
+ srcs.reserve(dts.size());
+ for (size_t i = 0; i < dts.size(); ++i) {
+ srcs.emplace_back(in_node_, i, dts[i]);
+ }
+ builder_->Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
+ }
+
+ } // namespace
+
+ // Public interface ------------------------------------------------------------
+
+ FakeInputFunctor FakeInput() {
+ return [](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ return impl.AddInputToBuilder();
+ };
+ }
+
+ FakeInputFunctor FakeInput(DataType dt) {
+ return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetDataType(dt);
+ return impl.AddInputToBuilder();
+ };
+ }
+
+ FakeInputFunctor FakeInput(int n) {
+ return [n](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetN(n);
+ return impl.AddInputToBuilder();
+ };
+ }
+
+ FakeInputFunctor FakeInput(int n, DataType dt) {
+ return [n, dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetN(n);
+ impl.SetDataType(dt);
+ return impl.AddInputToBuilder();
+ };
+ }
+
+ FakeInputFunctor FakeInput(DataTypeSlice dts) {
+ // Make a copy to ensure the data will still be around when the lambda is
+ // called.
+ DataTypeVector dtv(dts.begin(), dts.end());
+ return [dtv](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetTypeList(dtv);
+ return impl.AddInputToBuilder();
+ };
+ }
+
+ } // namespace tensorflow
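
To make the expansion concrete: for an input declared as N * T, AddInputToBuilder() reaches NSources(), which feeds the builder n outputs of one fake node whose name comes from FakeNodeName() ('a' for input 0, 'b' for input 1, and so on). A hedged sketch of the observable effect, assuming a registered op "TestList" whose sole input is declared values: N*int32 (the op name is hypothetical, not part of this commit):

// Produces data inputs "a:0", "a:1", "a:2" on the finalized NodeDef.
tensorflow::NodeDef node_def;
TF_CHECK_OK(tensorflow::NodeDefBuilder("n", "TestList")
                .Input(tensorflow::FakeInput(3, tensorflow::DT_INT32))
                .Finalize(&node_def));
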
fake_input.h ADDED
@@ -0,0 +1,40 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
+ #define TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
+
+ #include "tensorflow/core/framework/node_def_builder.h"
+ #include "tensorflow/core/framework/types.h"
+
+ namespace tensorflow {
+
+ // These functions return values that may be passed to
+ // NodeDefBuilder::Input() to add an input for a test. Use them when
+ // you don't care about the node names/output indices providing the
+ // input. They also allow you to omit the input types and/or
+ // list length when they may be inferred.
+ FakeInputFunctor FakeInput(); // Infer everything
+ FakeInputFunctor FakeInput(DataType dt);
+ FakeInputFunctor FakeInput(int n); // List of length n
+ FakeInputFunctor FakeInput(int n, DataType dt);
+ FakeInputFunctor FakeInput(DataTypeSlice dts);
+ inline FakeInputFunctor FakeInput(std::initializer_list<DataType> dts) {
+ return FakeInput(DataTypeSlice(dts));
+ }
+
+ } // namespace tensorflow
+
+ #endif // TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
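
A typical test-side usage sketch (the op name "MyOp" is hypothetical; the pattern mirrors TensorFlow's kernel tests, which pass these functors straight to NodeDefBuilder::Input()):

#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"

// Builds a NodeDef for a registered op without wiring up real producer
// nodes; a type list can be spelled out, or omitted via FakeInput() when
// it is inferable from the op's signature.
tensorflow::NodeDef node_def;
TF_CHECK_OK(tensorflow::NodeDefBuilder("test_node", "MyOp")
                .Input(tensorflow::FakeInput(tensorflow::DT_FLOAT))
                .Input(tensorflow::FakeInput({tensorflow::DT_INT32,
                                              tensorflow::DT_STRING}))
                .Finalize(&node_def));
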
function.cc ADDED
@@ -0,0 +1,1322 @@
1
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #include "tensorflow/core/framework/function.h"
17
+
18
+ #include <map>
19
+ #include <unordered_map>
20
+ #include <utility>
21
+ #include <vector>
22
+
23
+ #include "tensorflow/core/framework/common_shape_fns.h"
24
+ #include "tensorflow/core/framework/function.pb_text.h"
25
+ #include "tensorflow/core/framework/graph.pb.h"
26
+ #include "tensorflow/core/framework/node_def.pb.h"
27
+ #include "tensorflow/core/framework/node_def_util.h"
28
+ #include "tensorflow/core/framework/op.h"
29
+ #include "tensorflow/core/graph/graph.h"
30
+ #include "tensorflow/core/lib/core/errors.h"
31
+ #include "tensorflow/core/lib/gtl/inlined_vector.h"
32
+ #include "tensorflow/core/lib/gtl/map_util.h"
33
+ #include "tensorflow/core/util/equal_graph_def.h"
34
+
35
+ namespace tensorflow {
36
+
37
+ // Extracts the actual type from "attr_values" based on its definition
38
+ // "arg_def".
39
+ //
40
+ // If "arg_def" is a N*T type, *is_type_list is set to false, and
41
+ // *dtypes is set to be a vector of size N and each element is T.
42
+ //
43
+ // If "arg_def" is a list(type), *is_type_list is set to true, and
44
+ // *dtypes is set to be a vector of types specified in attrs for
45
+ // arg_def.
46
+ //
47
+ // Otherwise (arg_def is a simple type T), *is_type_list is set to
48
+ // false, and *dtypes is set to a single element vector, whose only
49
+ // element is T.
50
+ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
51
+ bool* is_type_list, DataTypeVector* dtypes) {
52
+ dtypes->clear();
53
+ if (!arg_def.type_list_attr().empty()) {
54
+ const AttrValue* v = attrs.Find(arg_def.type_list_attr());
55
+ if (v == nullptr) {
56
+ return errors::NotFound("type attr not found: ",
57
+ arg_def.type_list_attr());
58
+ }
59
+ *is_type_list = true;
60
+ for (int i = 0; i < v->list().type_size(); ++i) {
61
+ dtypes->push_back(v->list().type(i));
62
+ }
63
+ return Status::OK();
64
+ }
65
+
66
+ *is_type_list = false;
67
+ int num = 1;
68
+ if (!arg_def.number_attr().empty()) {
69
+ const AttrValue* v = attrs.Find(arg_def.number_attr());
70
+ if (v == nullptr) {
71
+ return errors::NotFound("type attr not found: ", arg_def.type_attr());
72
+ }
73
+ num = v->i();
74
+ }
75
+
76
+ DataType dtype;
77
+ if (arg_def.type() != DT_INVALID) {
78
+ dtype = arg_def.type();
79
+ } else if (arg_def.type_attr().empty()) {
80
+ dtype = DT_INVALID;
81
+ } else {
82
+ const AttrValue* v = attrs.Find(arg_def.type_attr());
83
+ if (v == nullptr) {
84
+ return errors::NotFound("type attr not found: ", arg_def.type_attr());
85
+ }
86
+ dtype = v->type();
87
+ }
88
+ dtypes->resize(num, dtype);
89
+ return Status::OK();
90
+ }
91
+
92
+ namespace {
93
+
94
+ template <typename T>
95
+ void AddAttr(const string& name, const T& val, NodeDef* ndef) {
96
+ SetAttrValue(val, &((*ndef->mutable_attr())[name]));
97
+ }
98
+
99
+ Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
100
+ // attr_values should specify all attrs defined in fdef.
101
+ for (const auto& a : sig.attr()) {
102
+ const AttrValue* v = attr_values.Find(a.name());
103
+ if (!v) {
104
+ return errors::NotFound("Attr ", a.name(), " is not found from ",
105
+ SummarizeOpDef(sig));
106
+ }
107
+ Status status = AttrValueHasType(*v, a.type());
108
+ if (!status.ok()) {
109
+ errors::AppendToMessage(&status, "for attr '", a.name(), "'");
110
+ return status;
111
+ }
112
+ }
113
+
114
+ // TODO(josh11b): Enable this code once it works with function gradients.
115
+ // Right now the C++ function gradient code assumes it can pass
116
+ // all the attrs of the function to the gradient, and any attrs that
117
+ // the gradient doesn't care about will be ignored.
118
+ #if 0
119
+ if (attr_values.size() != sig.attr_size()) {
120
+ for (const auto& a : attr_values) {
121
+ // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
122
+ bool found = false;
123
+ for (const auto& s : sig.attr()) {
124
+ if (a.first == s.name()) {
125
+ found = true;
126
+ break;
127
+ }
128
+ }
129
+ if (!found) {
130
+ return errors::NotFound("Attr ", a.first, " is not found in ",
131
+ SummarizeOpDef(sig));
132
+ }
133
+ }
134
+ }
135
+ #endif
136
+
137
+ return Status::OK();
138
+ }
139
+
140
+ // A helper class for instantiating functions. This contains shared information
141
+ // like the resulting graph and node name index.
142
+ class FunctionInstantiationHelper {
143
+ public:
144
+ FunctionInstantiationHelper(GetFunctionSignature get_function,
145
+ InstantiationResult* result)
146
+ : get_function_(std ::move(get_function)), result_(*result) {
147
+ result_.nodes.clear();
148
+ }
149
+
150
+ // Builds index for nodes that can be used as node's input arguments.
151
+ Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
152
+ AttrSlice attr_values) {
153
+ bool is_type_list;
154
+ DataTypeVector dtypes;
155
+ TF_RETURN_IF_ERROR(
156
+ ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
157
+ CHECK_GE(dtypes.size(), size_t{1});
158
+ int arg_index = result_.nodes.size();
159
+ TF_RETURN_IF_ERROR(
160
+ AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
161
+ // Creates dtypes.size() nodes in the graph.
162
+ for (size_t i = 0; i < dtypes.size(); ++i) {
163
+ TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
164
+ {true, arg_index, 0, false, {dtypes[i]}}));
165
+ DCHECK_EQ(arg_index, result_.nodes.size());
166
+ string name = arg_def.name();
167
+ if (dtypes.size() > 1) {
168
+ strings::StrAppend(&name, "_", i);
169
+ }
170
+ NodeDef* gnode = AddNode(name);
171
+ gnode->set_op("_Arg");
172
+ AddAttr("T", dtypes[i], gnode);
173
+ AddAttr("index", arg_index, gnode);
174
+ result_.arg_types.push_back(dtypes[i]);
175
+ ++arg_index;
176
+ }
177
+ return Status::OK();
178
+ }
179
+
180
+ Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
181
+ const int arg_index) {
182
+ const OpDef* node_sig = nullptr;
183
+ TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
184
+ if (node_sig->output_arg_size() == 0) {
185
+ return AddItem(node.name(), {false, arg_index, 0, false, {}});
186
+ }
187
+ const int num_retval = node_sig->output_arg_size();
188
+ int start = 0;
189
+ bool is_type_list;
190
+ DataTypeVector dtypes;
191
+ for (int i = 0; i < num_retval; ++i) {
192
+ TF_RETURN_IF_ERROR(
193
+ ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
194
+ // Note that we rely on the backwards-compatibility test enforcing
195
+ // that output_arg(*).name() doesn't change here.
196
+ const string base_name =
197
+ strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
198
+ TF_RETURN_IF_ERROR(
199
+ AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
200
+ for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
201
+ TF_RETURN_IF_ERROR(
202
+ AddItem(strings::StrCat(base_name, ":", j),
203
+ {false, arg_index, start + j, false, {dtypes[j]}}));
204
+ }
205
+ start += dtypes.size();
206
+ }
207
+ return Status::OK();
208
+ }
209
+
210
+ Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
211
+ const OpDef* fnode_sig = nullptr;
212
+ TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
213
+ NodeDef* gnode = AddNode(fnode.name());
214
+ gnode->set_op(fnode.op());
215
+ gnode->set_device(fnode.device());
216
+ int gnode_idx = nodes_.size() - 1;
217
+
218
+ // Input
219
+ const int num_args = fnode_sig->input_arg_size();
220
+ bool is_type_list; // ignored
221
+ DataTypeVector dtypes;
222
+ int fnode_arg_index = 0;
223
+ for (int i = 0; i < num_args; ++i) {
224
+ TF_RETURN_IF_ERROR(
225
+ ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
226
+ // Consume inputs (indexed by fnode_arg_index) until we have
227
+ // matched each element of dtypes (indexed by j).
228
+ for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
229
+ if (fnode_arg_index >= fnode.input_size()) {
230
+ // Should never happen if we computed dtypes correctly.
231
+ return errors::InvalidArgument(
232
+ "Attempt to access beyond input size: ", fnode_arg_index,
233
+ " >= ", fnode.input_size());
234
+ }
235
+ // Look up the next input.
236
+ const string& input_name = fnode.input(fnode_arg_index);
237
+ const auto* item = GetItemOrNull(input_name);
238
+ if (item == nullptr) {
239
+ return errors::InvalidArgument(
240
+ "input ", input_name, " is not found: ", SummarizeNodeDef(fnode));
241
+ }
242
+ if (item->dtypes.size() > dtypes.size() - j) {
243
+ return errors::InvalidArgument("Input ", input_name, " too long for ",
244
+ fnode_sig->input_arg(i).name());
245
+ }
246
+ // Match up all the elements of this input (indexed by k) with
247
+ // elements of dtypes (advancing j).
248
+ for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
249
+ if (item->dtypes[k] != dtypes[j]) {
250
+ return errors::InvalidArgument(
251
+ "input ", fnode_sig->input_arg(i).name(), "[", j,
252
+ "] expected type ", DataTypeString(dtypes[j]),
253
+ " != ", DataTypeString(item->dtypes[k]), ", the type of ",
254
+ input_name, "[", k, "]");
255
+ }
256
+ if (item->is_func_arg) {
257
+ AddInput(gnode_idx, item->nid + k, 0);
258
+ } else {
259
+ AddInput(gnode_idx, item->nid, item->idx + k);
260
+ }
261
+ }
262
+ }
263
+ }
264
+
265
+ // Control deps.
266
+ for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
267
+ const string& input = fnode.input(i);
268
+ if (input.empty() || input[0] != '^') {
269
+ return errors::InvalidArgument("Expected input[", i, "] == '", input,
270
+ "' to be a control input.");
271
+ }
272
+ int nid = -1;
273
+ const string node_name = input.substr(1);
274
+ const string node_colon = node_name + ":";
275
+ const string node_colon_bound = node_name + ";";
276
+ // index_ is a map sorted lexicographically, so the key we are looking for
277
+ // must lie in the range [node_name, node_colon_bound).
278
+ auto it = index_.lower_bound(node_name);
279
+ while (it != index_.end() && it->first <= node_colon_bound) {
280
+ if (it->first == node_name ||
281
+ tensorflow::StringPiece(it->first).starts_with(node_colon)) {
282
+ nid = it->second.nid;
283
+ break;
284
+ }
285
+ ++it;
286
+ }
287
+ if (nid == -1) {
288
+ return errors::InvalidArgument("input[", i, "] == '", input,
289
+ "', is not found.");
290
+ }
291
+ AddDep(gnode_idx, nid);
292
+ }
293
+
294
+ // Attrs.
295
+ for (const auto& p : attrs) {
296
+ (*gnode->mutable_attr())[p.first] = p.second;
297
+ }
298
+
299
+ return Status::OK();
300
+ }
301
+
302
+ Status AddReturnNode(
303
+ const OpDef::ArgDef& ret_def, AttrSlice attrs,
304
+ const ::tensorflow::protobuf::Map<string, string>& ret_map,
305
+ int* ret_index) {
306
+ auto ret_iter = ret_map.find(ret_def.name());
307
+ if (ret_iter == ret_map.end()) {
308
+ return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
309
+ }
310
+ bool is_type_list;
311
+ DataTypeVector dtypes;
312
+ TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
313
+ CHECK_GE(dtypes.size(), size_t{1});
314
+ const auto* item = GetItemOrNull(ret_iter->second);
315
+ if (item == nullptr) {
316
+ return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
317
+ ret_iter->second, " is not found.");
318
+ }
319
+ if (dtypes != item->dtypes) {
320
+ return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
321
+ " : ", DataTypeVectorString(dtypes),
322
+ " vs. ",
323
+ DataTypeVectorString(item->dtypes));
324
+ }
325
+ for (size_t i = 0; i < dtypes.size(); ++i) {
326
+ string name = strings::StrCat(ret_def.name(), "_RetVal");
327
+ if (dtypes.size() > 1) {
328
+ strings::StrAppend(&name, "_", i);
329
+ }
330
+ NodeDef* gnode = AddNode(name);
331
+ gnode->set_op("_Retval");
332
+ AddInput(nodes_.size() - 1, item->nid, item->idx + i);
333
+ AddAttr("T", dtypes[i], gnode);
334
+ AddAttr("index", (*ret_index)++, gnode);
335
+ result_.ret_types.push_back(dtypes[i]);
336
+ }
337
+ return Status::OK();
338
+ }
339
+
340
+ // Adds the actual node inputs to the result graph by converting indexes to
341
+ // the node names.
342
+ void AddNodeInputs() {
343
+ for (int i = 0; i < result_.nodes.size(); i++) {
344
+ NodeInfo& node_info = nodes_[i];
345
+ for (const auto& p : node_info.data_inputs) {
346
+ result_.nodes[i].add_input(Name(p.first, p.second));
347
+ }
348
+ for (int index : node_info.control_inputs) {
349
+ result_.nodes[i].add_input(Dep(index));
350
+ }
351
+ }
352
+ }
353
+
354
+ private:
355
+ // This is used to build a small index for all names that can be used as a
356
+ // node's input arguments.
357
+ //
358
+ // If is_func_arg is true, the name is a function's argument. In
359
+ // this case, the produced graph def has node[nid:nid + dtype.size()].
360
+ //
361
+ // Otherwise, the name is a function body's node return value. In
362
+ // this case, the produced graph def has one node node[nid] and
363
+ // the node's output index [idx ... idx + num) corresponds to the
364
+ // named outputs.
365
+ //
366
+ // In all cases, "dtype" specifies the data type.
367
+ struct NameInfoItem {
368
+ bool is_func_arg;
369
+ int nid;
370
+ int idx;
371
+ bool is_type_list;
372
+ DataTypeVector dtypes;
373
+ };
374
+
375
+ // Adds an item into the input name index.
376
+ Status AddItem(const string& name, const NameInfoItem& item) {
377
+ if (!index_.insert({name, item}).second) {
378
+ return errors::InvalidArgument(
379
+ strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
380
+ " name: "),
381
+ name);
382
+ }
383
+ return Status::OK();
384
+ }
385
+
386
+ const NameInfoItem* GetItemOrNull(const string& name) const {
387
+ return gtl::FindOrNull(index_, name);
388
+ }
389
+
390
+ string Dep(int node_index) const {
391
+ return strings::StrCat("^", Name(node_index));
392
+ }
393
+
394
+ string Name(int node_index) const {
395
+ CHECK_LT(node_index, nodes_.size());
396
+ return nodes_[node_index].name;
397
+ }
398
+
399
+ string Name(int node_index, int output_index) const {
400
+ if (output_index == 0) {
401
+ return Name(node_index);
402
+ } else {
403
+ return strings::StrCat(Name(node_index), ":", output_index);
404
+ }
405
+ }
406
+
407
+ NodeDef* AddNode(const string& name) {
408
+ result_.nodes.emplace_back();
409
+ NodeDef* gnode = &result_.nodes.back();
410
+ gnode->set_name(name);
411
+ nodes_.push_back({name, {}, {}});
412
+ CHECK_EQ(result_.nodes.size(), nodes_.size());
413
+ return gnode;
414
+ }
415
+
416
+ void AddInput(int node_index, int output_node, int output_index) {
417
+ CHECK_LT(node_index, nodes_.size());
418
+ nodes_[node_index].data_inputs.push_back(
419
+ std::make_pair(output_node, output_index));
420
+ }
421
+
422
+ void AddDep(int node_index, int dep_index) {
423
+ CHECK_LT(node_index, nodes_.size());
424
+ nodes_[node_index].control_inputs.push_back(dep_index);
425
+ }
426
+
427
+ GetFunctionSignature get_function_;
428
+ InstantiationResult& result_;
429
+ // A small index for all names that can be used as a node's input arguments.
430
+ std::map<string, NameInfoItem> index_;
431
+ // This contains information about a node in the new graph including the node
432
+ // names and input nodes' indexes.
433
+ struct NodeInfo {
434
+ string name;
435
+ // Data inputs where <n, k> means arg k of node n.
436
+ std::vector<std::pair<int, int>> data_inputs;
437
+ // Control inputs (dependencies).
438
+ std::vector<int> control_inputs;
439
+ };
440
+ // nodes_[i] is the information about result_.nodes[i].
441
+ std::vector<NodeInfo> nodes_;
442
+ };
443
+
444
+ // Various helpers Print(proto) to print relevant protos to ascii.
445
+ string Print(const OpDef::ArgDef& arg) {
446
+ string out;
447
+ strings::StrAppend(&out, arg.name(), ":");
448
+ if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
449
+ if (!arg.number_attr().empty()) {
450
+ strings::StrAppend(&out, arg.number_attr(), "*");
451
+ }
452
+ if (arg.type() != DT_INVALID) {
453
+ strings::StrAppend(&out, DataTypeString(arg.type()));
454
+ } else {
455
+ strings::StrAppend(&out, arg.type_attr());
456
+ }
457
+ if (arg.is_ref()) strings::StrAppend(&out, ")");
458
+ return out;
459
+ }
460
+
461
+ // TODO(josh11b): Merge this with SummarizeAttrValue().
462
+ string Print(const AttrValue& attr_value) {
463
+ if (attr_value.value_case() == AttrValue::kType) {
464
+ return DataTypeString(attr_value.type());
465
+ } else if ((attr_value.value_case() == AttrValue::kList) &&
466
+ (attr_value.list().type_size() > 0)) {
467
+ string ret = "{";
468
+ for (int i = 0; i < attr_value.list().type_size(); ++i) {
469
+ if (i > 0) strings::StrAppend(&ret, ", ");
470
+ strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
471
+ }
472
+ strings::StrAppend(&ret, "}");
473
+ return ret;
474
+ } else if (attr_value.value_case() == AttrValue::kFunc) {
475
+ if (attr_value.func().attr_size() == 0) {
476
+ return attr_value.func().name();
477
+ }
478
+ std::vector<string> entries;
479
+ for (auto p : attr_value.func().attr()) {
480
+ entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
481
+ }
482
+ std::sort(entries.begin(), entries.end());
483
+ return strings::StrCat(attr_value.func().name(), "[",
484
+ str_util::Join(entries, ", "), "]");
485
+ }
486
+ return SummarizeAttrValue(attr_value);
487
+ }
488
+
489
+ // TODO(josh11b): Merge this with SummarizeNodeDef().
490
+ string Print(const NodeDef& n) {
491
+ string out;
492
+ strings::StrAppend(&out, n.name(), " = ", n.op());
493
+ if (n.attr_size() > 0) {
494
+ std::vector<string> entries;
495
+ for (auto& a : n.attr()) {
496
+ entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
497
+ }
498
+ std::sort(entries.begin(), entries.end());
499
+ strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
500
+ }
501
+ strings::StrAppend(&out, "(");
502
+ std::vector<StringPiece> dat;
503
+ std::vector<string> dep;
504
+ for (StringPiece s : n.input()) {
505
+ if (s.Consume("^")) {
506
+ dep.push_back(s.ToString());
507
+ } else {
508
+ dat.push_back(s);
509
+ }
510
+ }
511
+ strings::StrAppend(&out, str_util::Join(dat, ", "), ")");
512
+ if (!dep.empty()) {
513
+ strings::StrAppend(&out, " @ ", str_util::Join(dep, ", "));
514
+ }
515
+ return out;
516
+ }
517
+
518
+ string Print(const FunctionDef& fdef) {
519
+ string out;
520
+ const OpDef& sig = fdef.signature();
521
+ strings::StrAppend(&out, "\n", sig.name());
522
+ if (sig.attr_size() > 0) {
523
+ strings::StrAppend(&out, "[");
524
+ for (int i = 0; i < sig.attr_size(); ++i) {
525
+ const auto& a = sig.attr(i);
526
+ if (i > 0) strings::StrAppend(&out, ", ");
527
+ if (a.type() == "type") {
528
+ strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
529
+ } else {
530
+ strings::StrAppend(&out, a.name(), ":", a.type());
531
+ }
532
+ }
533
+ strings::StrAppend(&out, "]");
534
+ }
535
+ strings::StrAppend(&out, "(");
536
+ for (int i = 0; i < sig.input_arg_size(); ++i) {
537
+ if (i > 0) strings::StrAppend(&out, ", ");
538
+ strings::StrAppend(&out, Print(sig.input_arg(i)));
539
+ }
540
+ strings::StrAppend(&out, ") -> (");
541
+ for (int i = 0; i < sig.output_arg_size(); ++i) {
542
+ if (i > 0) strings::StrAppend(&out, ", ");
543
+ strings::StrAppend(&out, Print(sig.output_arg(i)));
544
+ }
545
+ strings::StrAppend(&out, ") {\n");
546
+ for (const auto& n : fdef.node_def()) {
547
+ strings::StrAppend(&out, " ", Print(n), "\n");
548
+ }
549
+ for (const auto& r : fdef.ret()) {
550
+ strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
551
+ }
552
+ strings::StrAppend(&out, "}\n");
553
+ return out;
554
+ }
555
+
556
+ string Print(gtl::ArraySlice<const NodeDef*> nodes) {
557
+ std::vector<const NodeDef*> arg;
558
+ std::vector<const NodeDef*> ret;
559
+ std::vector<const NodeDef*> body;
560
+ for (const NodeDef* n : nodes) {
561
+ if (n->op() == "_Arg") {
562
+ arg.push_back(n);
563
+ } else if (n->op() == "_Retval") {
564
+ ret.push_back(n);
565
+ } else {
566
+ body.push_back(n);
567
+ }
568
+ }
569
+ auto comp = [](const NodeDef* x, const NodeDef* y) {
570
+ int xi;
571
+ TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
572
+ int yi;
573
+ TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
574
+ return xi < yi;
575
+ };
576
+ std::sort(arg.begin(), arg.end(), comp);
577
+ std::sort(ret.begin(), ret.end(), comp);
578
+ string out;
579
+ strings::StrAppend(&out, "\n(");
580
+ auto get_type = [](const NodeDef& n) {
581
+ DataType dt;
582
+ if (!GetNodeAttr(n, "T", &dt).ok()) {
583
+ dt = DT_INVALID;
584
+ }
585
+ return DataTypeString(dt);
586
+ };
587
+ for (size_t i = 0; i < arg.size(); ++i) {
588
+ const NodeDef* n = arg[i];
589
+ if (i > 0) strings::StrAppend(&out, ", ");
590
+ CHECK_GE(n->attr_size(), 2);
591
+ strings::StrAppend(&out, n->name(), ":", get_type(*n));
592
+ }
593
+ strings::StrAppend(&out, ") -> (");
594
+ for (size_t i = 0; i < ret.size(); ++i) {
595
+ const NodeDef* n = ret[i];
596
+ if (i > 0) strings::StrAppend(&out, ", ");
597
+ CHECK_LE(2, n->attr_size());
598
+ CHECK_EQ(1, n->input_size());
599
+ strings::StrAppend(&out, n->input(0), ":", get_type(*n));
600
+ }
601
+ strings::StrAppend(&out, ") {\n");
602
+ for (size_t i = 0; i < body.size(); ++i) {
603
+ strings::StrAppend(&out, " ", Print(*body[i]), "\n");
604
+ }
605
+ strings::StrAppend(&out, "}\n");
606
+ return out;
607
+ }
608
+
609
+ Status AddDefaultAttrs(const string& op,
610
+ const GetFunctionSignature& get_function,
611
+ AttrValueMap* attrs) {
612
+ const OpDef* op_def = nullptr;
613
+ TF_RETURN_IF_ERROR(get_function(op, &op_def));
614
+ AttrSlice attr_slice(attrs);
615
+ for (const auto& attr_def : op_def->attr()) {
616
+ if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
617
+ if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
618
+ return errors::Internal("Somehow duplicated: ", attr_def.name());
619
+ }
620
+ }
621
+ }
622
+ return Status::OK();
623
+ }
624
+
625
+ } // end namespace
626
+
627
+ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
628
+ GetFunctionSignature get_function,
629
+ InstantiationResult* result) {
630
+ VLOG(3) << "Instantiation Function: " << Print(fdef);
631
+
632
+ const OpDef& sig = fdef.signature();
633
+ TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
634
+
635
+ FunctionInstantiationHelper helper(get_function, result);
636
+ Status s;
637
+ for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
638
+ s = helper.BuildInputArgIndex(arg_def, attr_values);
639
+ if (!s.ok()) {
640
+ errors::AppendToMessage(&s, "In ", Print(arg_def));
641
+ return s;
642
+ }
643
+ }
644
+
645
+ auto substitute = [attr_values](StringPiece name, AttrValue* val) {
646
+ if (const AttrValue* v = attr_values.Find(name)) {
647
+ *val = *v;
648
+ return true;
649
+ }
650
+ return false;
651
+ };
652
+
653
+ // Makes a copy of all attrs in fdef and substitutes placeholders.
654
+ // After this step, every attr is bound to a concrete value.
655
+ std::vector<AttrValueMap> node_attrs;
656
+ node_attrs.resize(fdef.node_def_size());
657
+ for (int i = 0; i < fdef.node_def_size(); ++i) {
658
+ for (auto attr : fdef.node_def(i).attr()) {
659
+ if (!SubstitutePlaceholders(substitute, &attr.second)) {
660
+ return errors::InvalidArgument("Failed to bind all placeholders in ",
661
+ SummarizeAttrValue(attr.second));
662
+ }
663
+ if (!node_attrs[i].insert(attr).second) {
664
+ return errors::Internal("Somehow duplicated: ", attr.first);
665
+ }
666
+ }
667
+ TF_RETURN_IF_ERROR(
668
+ AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
669
+ }
670
+
671
+ for (int i = 0; i < fdef.node_def_size(); ++i) {
672
+ s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
673
+ result->nodes.size() + i);
674
+ if (!s.ok()) {
675
+ errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
676
+ return s;
677
+ }
678
+ }
679
+ // Emits one node for each fdef.node_def.
680
+ for (int i = 0; i < fdef.node_def_size(); ++i) {
681
+ s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
682
+ if (!s.ok()) {
683
+ errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
684
+ return s;
685
+ }
686
+ }
687
+
688
+ // Emits nodes for the function's return values.
689
+ int ret_index = 0;
690
+ for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
691
+ s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), &ret_index);
692
+ if (!s.ok()) {
693
+ errors::AppendToMessage(&s, "In function output ", Print(ret_def));
694
+ return s;
695
+ }
696
+ }
697
+
698
+ // Adds the actual node inputs using the input indexes.
699
+ helper.AddNodeInputs();
700
+
701
+ return Status::OK();
702
+ }
703
+
704
+ string DebugString(const FunctionDef& func_def) { return Print(func_def); }
705
+
706
+ string DebugString(const GraphDef& instantiated_func_def) {
707
+ std::vector<const NodeDef*> ptrs;
708
+ for (const NodeDef& n : instantiated_func_def.node()) {
709
+ ptrs.push_back(&n);
710
+ }
711
+ return Print(ptrs);
712
+ }
713
+
714
+ string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
715
+ std::vector<const NodeDef*> ptrs;
716
+ for (const NodeDef& n : instantiated_func_nodes) {
717
+ ptrs.push_back(&n);
718
+ }
719
+ return Print(ptrs);
720
+ }
721
+
722
+ string DebugStringWhole(const GraphDef& gdef) {
723
+ string ret;
724
+ for (const auto& fdef : gdef.library().function()) {
725
+ strings::StrAppend(&ret, Print(fdef));
726
+ }
727
+ strings::StrAppend(&ret, "\n");
728
+ for (const auto& ndef : gdef.node()) {
729
+ strings::StrAppend(&ret, Print(ndef), "\n");
730
+ }
731
+ return ret;
732
+ }
733
+
734
+ namespace {
735
+
736
+ // Returns the name -> attr mapping of fdef's attrs that have a value set. In
737
+ // Python, it's possible to access unset attrs, which returns a default value
738
+ // and adds an unset attr to the map.
739
+ std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
740
+ std::map<string, AttrValue> set_attrs;
741
+ for (auto pair : fdef.attr()) {
742
+ if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
743
+ set_attrs[pair.first] = pair.second;
744
+ }
745
+ }
746
+ return set_attrs;
747
+ }
748
+
749
+ } // end namespace
750
+
751
+ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
752
+ if (!OpDefEqual(f1.signature(), f2.signature())) return false;
753
+
754
+ std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
755
+ std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
756
+ if (f1_attrs.size() != f2_attrs.size()) return false;
757
+ for (auto iter1 : f1_attrs) {
758
+ auto iter2 = f2_attrs.find(iter1.first);
759
+ if (iter2 == f2_attrs.end()) return false;
760
+ if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
761
+ }
762
+
763
+ if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
764
+ return false;
765
+ }
766
+
767
+ std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
768
+ std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
769
+ if (ret1 != ret2) return false;
770
+
771
+ return true;
772
+ }
773
+
774
+ uint64 FunctionDefHash(const FunctionDef& fdef) {
775
+ // signature
776
+ uint64 h = OpDefHash(fdef.signature());
777
+
778
+ // attrs
779
+ std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
780
+ for (const auto& p : attrs) {
781
+ h = Hash64(p.first.data(), p.first.size(), h);
782
+ h = Hash64Combine(AttrValueHash(p.second), h);
783
+ }
784
+
785
+ // node defs
786
+ h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
787
+
788
+ // output names
789
+ std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
790
+ for (const auto& p : ret) {
791
+ h = Hash64(p.first.data(), p.first.size(), h);
792
+ h = Hash64(p.second.data(), p.second.size(), h);
793
+ }
794
+
795
+ return h;
796
+ }
797
+
798
+ string Canonicalize(const string& funcname, AttrSlice attrs) {
799
+ std::vector<string> entries;
800
+ entries.reserve(attrs.size());
801
+ for (auto p : attrs) {
802
+ entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
803
+ }
804
+ std::sort(entries.begin(), entries.end());
805
+ return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
806
+ }
807
+
808
+ FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
809
+ DataTypeSlice ret_types)
810
+ : arg_types_(arg_types.begin(), arg_types.end()),
811
+ ret_types_(ret_types.begin(), ret_types.end()) {
812
+ args_.resize(arg_types_.size());
813
+ rets_.resize(ret_types_.size());
814
+ }
815
+
816
+ FunctionCallFrame::~FunctionCallFrame() {}
817
+
818
+ Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
819
+ // Input type checks.
820
+ if (args.size() != arg_types_.size()) {
821
+ return errors::InvalidArgument("Expects ", arg_types_.size(),
822
+ " arguments, but ", args.size(),
823
+ " is provided");
824
+ }
825
+ for (size_t i = 0; i < args.size(); ++i) {
826
+ if (arg_types_[i] != args[i].dtype()) {
827
+ return errors::InvalidArgument(
828
+ "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
829
+ DataTypeString(args[i].dtype()), " is provided");
830
+ }
831
+ args_[i] = args[i];
832
+ }
833
+ return Status::OK();
834
+ }
835
+
836
+ Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
837
+ rets->clear();
838
+ rets->reserve(rets_.size());
839
+ for (size_t i = 0; i < rets_.size(); ++i) {
840
+ const auto& item = rets_[i];
841
+ if (item.has_val) {
842
+ rets->push_back(item.val);
843
+ } else {
844
+ return errors::Internal("Retval[", i, "] does not have value");
845
+ }
846
+ }
847
+ return Status::OK();
848
+ }
849
+
850
+ Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets) {
851
+ rets->clear();
852
+ rets->reserve(rets_.size());
853
+ for (size_t i = 0; i < rets_.size(); ++i) {
854
+ if (rets_[i].has_val) {
855
+ rets->emplace_back(std::move(rets_[i].val));
856
+ } else {
857
+ return errors::Internal("Retval[", i, "] does not have value");
858
+ }
859
+ }
860
+ return Status::OK();
861
+ }
862
+
863
+ Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
864
+ if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
865
+ return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
866
+ args_.size(), ")");
867
+ }
868
+ *val = args_[index];
869
+ return Status::OK();
870
+ }
871
+
872
+ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
873
+ if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
874
+ return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
875
+ rets_.size(), ")");
876
+ }
877
+ if (val.dtype() != ret_types_[index]) {
878
+ return errors::InvalidArgument(
879
+ "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
880
+ ", but ", DataTypeString(val.dtype()), " is provided.");
881
+ }
882
+ Retval* item = &rets_[index];
883
+ if (!item->has_val) {
884
+ item->has_val = true;
885
+ item->val = val;
886
+ } else {
887
+ return errors::Internal("Retval[", index, "] has already been set.");
888
+ }
889
+ return Status::OK();
890
+ }
891
+
892
+ FunctionLibraryDefinition::FunctionDefAndOpRegistration::
893
+ FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
894
+ : fdef(fdef_in),
895
+ // Exact shape inference for functions is handled by ShapeRefiner.
896
+ // Here we pass a dummy shape inference function for legacy code paths.
897
+ op_registration_data(fdef.signature(), shape_inference::UnknownShape,
898
+ true /* is_function */) {}
899
+
900
+ FunctionLibraryDefinition::FunctionLibraryDefinition(
901
+ const FunctionLibraryDefinition& other)
902
+ : default_registry_(other.default_registry_), func_grad_(other.func_grad_) {
903
+ for (const auto& it : other.function_defs_) {
904
+ TF_CHECK_OK(AddFunctionDef(it.second->fdef));
905
+ }
906
+ }
907
+
908
+ FunctionLibraryDefinition::FunctionLibraryDefinition(
909
+ const OpRegistryInterface* default_registry,
910
+ const FunctionDefLibrary& def_lib)
911
+ : default_registry_(default_registry),
912
+ function_defs_(def_lib.function_size()) {
913
+ for (const auto& fdef : def_lib.function()) {
914
+ // The latter function definition wins.
915
+ auto& ptr = function_defs_[fdef.signature().name()];
916
+ ptr.reset(new FunctionDefAndOpRegistration(fdef));
917
+ }
918
+ for (const auto& grad : def_lib.gradient()) {
919
+ func_grad_[grad.function_name()] = grad.gradient_func();
920
+ }
921
+ }
922
+
923
+ FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
924
+
925
+ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
926
+ auto iter = function_defs_.find(name);
927
+ if (iter == function_defs_.end()) {
928
+ return nullptr;
929
+ } else {
930
+ return &iter->second->fdef;
931
+ }
932
+ }
933
+
934
+ Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
935
+ bool added;
936
+ return AddFunctionDefHelper(fdef, &added);
937
+ }
938
+
939
+ Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
940
+ bool* added) {
941
+ *added = false;
942
+ std::unique_ptr<FunctionDefAndOpRegistration>* entry =
943
+ &function_defs_[fdef.signature().name()];
944
+ if (*entry != nullptr) {
945
+ if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
946
+ return errors::InvalidArgument(
947
+ "Cannot add function '", fdef.signature().name(),
948
+ "' because a different function with the same name already "
949
+ "exists.");
950
+ }
951
+ // Ignore duplicate FunctionDefs
952
+ return Status::OK();
953
+ }
954
+ const OpDef* op_def;
955
+ if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
956
+ return errors::InvalidArgument(
957
+ "Cannot add function '", fdef.signature().name(),
958
+ "' because an op with the same name already exists.");
959
+ }
960
+ entry->reset(new FunctionDefAndOpRegistration(fdef));
961
+ *added = true;
962
+ return Status::OK();
963
+ }
964
+
965
+ Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
966
+ bool added;
967
+ return AddGradientDefHelper(grad, &added);
968
+ }
969
+
970
+ Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
971
+ bool* added) {
972
+ *added = false;
973
+ string* entry = &func_grad_[grad.function_name()];
974
+ if (!entry->empty()) {
975
+ if (*entry != grad.gradient_func()) {
976
+ return errors::InvalidArgument(
977
+ "Cannot assign gradient function '", grad.gradient_func(), "' to '",
978
+ grad.function_name(), "' because it already has gradient function ",
979
+ "'", *entry, "'");
980
+ }
981
+ // Ignore duplicate GradientDefs
982
+ return Status::OK();
983
+ }
984
+ *entry = grad.gradient_func();
985
+ *added = true;
986
+ return Status::OK();
987
+ }
988
+
989
+ Status FunctionLibraryDefinition::AddLibrary(
990
+ const FunctionLibraryDefinition& other) {
991
+ // Remember the funcs and grads that we added successfully so that
992
+ // we can roll them back on error.
993
+ std::vector<string> funcs;
994
+ std::vector<string> funcs_with_grads;
995
+ Status s;
996
+ bool added;
997
+ for (const auto& iter : other.function_defs_) {
998
+ s = AddFunctionDefHelper(iter.second->fdef, &added);
999
+ if (!s.ok()) {
1000
+ Remove(funcs, funcs_with_grads);
1001
+ return s;
1002
+ }
1003
+ if (added) {
1004
+ funcs.push_back(iter.second->fdef.signature().name());
1005
+ }
1006
+ }
1007
+ for (const auto& iter : other.func_grad_) {
1008
+ GradientDef grad;
1009
+ grad.set_function_name(iter.first);
1010
+ grad.set_gradient_func(iter.second);
1011
+ s = AddGradientDefHelper(grad, &added);
1012
+ if (!s.ok()) {
1013
+ Remove(funcs, funcs_with_grads);
1014
+ return s;
1015
+ }
1016
+ if (added) {
1017
+ funcs_with_grads.push_back(grad.function_name());
1018
+ }
1019
+ }
1020
+ return Status::OK();
1021
+ }
1022
+
1023
+ Status FunctionLibraryDefinition::AddLibrary(
1024
+ const FunctionDefLibrary& lib_def) {
1025
+ // Remember the funcs and grads that we added successfully so that
1026
+ // we can roll them back on error.
1027
+ std::vector<string> funcs;
1028
+ std::vector<string> funcs_with_grads;
1029
+ Status s;
1030
+ bool added;
1031
+ for (const FunctionDef& fdef : lib_def.function()) {
1032
+ s = AddFunctionDefHelper(fdef, &added);
1033
+ if (!s.ok()) {
1034
+ Remove(funcs, funcs_with_grads);
1035
+ return s;
1036
+ }
1037
+ if (added) {
1038
+ funcs.push_back(fdef.signature().name());
1039
+ }
1040
+ }
1041
+ for (const GradientDef& grad : lib_def.gradient()) {
1042
+ s = AddGradientDefHelper(grad, &added);
1043
+ if (!s.ok()) {
1044
+ Remove(funcs, funcs_with_grads);
1045
+ return s;
1046
+ }
1047
+ if (added) {
1048
+ funcs_with_grads.push_back(grad.function_name());
1049
+ }
1050
+ }
1051
+ return Status::OK();
1052
+ }
1053
+
1054
+ void FunctionLibraryDefinition::RemoveFunction(const string& func) {
1055
+ const auto& i = function_defs_.find(func);
1056
+ DCHECK(i != function_defs_.end());
1057
+ function_defs_.erase(i);
1058
+ }
1059
+
1060
+ void FunctionLibraryDefinition::RemoveGradient(const string& func) {
1061
+ const auto& i = func_grad_.find(func);
1062
+ DCHECK(i != func_grad_.end());
1063
+ func_grad_.erase(i);
1064
+ }
1065
+
1066
+ void FunctionLibraryDefinition::Remove(
1067
+ const std::vector<string>& funcs,
1068
+ const std::vector<string>& funcs_with_grads) {
1069
+ for (const string& f : funcs) {
1070
+ RemoveFunction(f);
1071
+ }
1072
+ for (const string& f : funcs_with_grads) {
1073
+ RemoveGradient(f);
1074
+ }
1075
+ }
1076
+
1077
+ string FunctionLibraryDefinition::FindGradient(const string& func) const {
1078
+ return gtl::FindWithDefault(func_grad_, func, "");
1079
+ }
1080
+
1081
+ Status FunctionLibraryDefinition::LookUp(
1082
+ const string& op, const OpRegistrationData** op_reg_data) const {
1083
+ auto iter = function_defs_.find(op);
1084
+ if (iter != function_defs_.end()) {
1085
+ *op_reg_data = &iter->second->op_registration_data;
1086
+ return Status::OK();
1087
+ }
1088
+ return default_registry_->LookUp(op, op_reg_data);
1089
+ }
1090
+
1091
+ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1092
+ const NodeDef& ndef) const {
1093
+ if (ndef.op() != kGradientOp) {
1094
+ // If 'ndef' calls a function and the function's def has the attr,
1095
+ // returns it.
1096
+ return Find(ndef.op());
1097
+ }
1098
+
1099
+ // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1100
+ // Foo's attributes.
1101
+ const NameAttrList* forward_func_attrs;
1102
+ if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
1103
+ return nullptr;
1104
+ }
1105
+ const string& func_name = forward_func_attrs->name();
1106
+ const string& grad_name = FindGradient(func_name);
1107
+ // If 'func' has a user-defined gradient function, uses the grad
1108
+ // function's attrs to see if noinline is specified. Otherwise,
1109
+ // uses func's attrs.
1110
+ if (!grad_name.empty()) {
1111
+ return Find(grad_name);
1112
+ }
1113
+ return Find(func_name);
1114
+ }
1115
+
1116
+ FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1117
+ FunctionDefLibrary lib;
1118
+ for (const auto& f : function_defs_) {
1119
+ *lib.add_function() = f.second->fdef;
1120
+ }
1121
+ for (const auto& g : func_grad_) {
1122
+ GradientDef* gd = lib.add_gradient();
1123
+ gd->set_function_name(g.first);
1124
+ gd->set_gradient_func(g.second);
1125
+ }
1126
+ return lib;
1127
+ }
1128
+
1129
+ template <typename T>
1130
+ Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1131
+ const string& attr, T* value) const {
1132
+ const FunctionDef* fdef = GetAttrImpl(ndef);
1133
+ if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
1134
+ return Status::OK();
1135
+ }
1136
+ return errors::InvalidArgument("Attr ", attr, " is not defined.");
1137
+ }
1138
+
1139
+ template <typename T>
1140
+ Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1141
+ T* value) const {
1142
+ return GetAttr(node.def(), attr, value);
1143
+ }
1144
+
1145
+ #define GET_ATTR(T) \
1146
+ template Status FunctionLibraryDefinition::GetAttr(const Node&, \
1147
+ const string&, T*) const; \
1148
+ template Status FunctionLibraryDefinition::GetAttr(const NodeDef&, \
1149
+ const string&, T*) const;
1150
+ GET_ATTR(string)
1151
+ GET_ATTR(bool)
1152
+ #undef GET_ATTR
1153
+
1154
+ void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1155
+ if (val.size() >= 2 && val[0] == '$') {
1156
+ proto.set_placeholder(val.data() + 1, val.size() - 1);
1157
+ } else {
1158
+ SetAttrValue(val, &proto);
1159
+ }
1160
+ }
1161
+
1162
+ FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1163
+ const string& name,
1164
+ gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1165
+ AttrValueWrapper ret;
1166
+ ret.proto.mutable_func()->set_name(name);
1167
+ for (const auto& a : attrs) {
1168
+ ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1169
+ }
1170
+ return ret;
1171
+ }
1172
+
1173
+ NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1174
+ NodeDef n;
1175
+ n.set_op(this->op);
1176
+ n.set_name(this->ret[0]);
1177
+ for (const auto& a : this->attr) {
1178
+ n.mutable_attr()->insert({a.first, a.second.proto});
1179
+ }
1180
+ for (const string& a : this->arg) {
1181
+ n.add_input(a);
1182
+ }
1183
+ for (const string& d : this->dep) {
1184
+ n.add_input(strings::StrCat("^", d));
1185
+ }
1186
+ return n;
1187
+ }
1188
+
1189
+ /* static */
1190
+ FunctionDef FunctionDefHelper::Create(
1191
+ const string& function_name, gtl::ArraySlice<string> in_def,
1192
+ gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1193
+ gtl::ArraySlice<Node> node_def,
1194
+ gtl::ArraySlice<std::pair<string, string>> ret_def) {
1195
+ FunctionDef fdef;
1196
+
1197
+ // Signature
1198
+ OpDefBuilder b(function_name);
1199
+ for (const auto& i : in_def) b.Input(i);
1200
+ for (const auto& o : out_def) b.Output(o);
1201
+ for (const auto& a : attr_def) b.Attr(a);
1202
+
1203
+ OpRegistrationData op_reg_data;
1204
+ TF_CHECK_OK(b.Finalize(&op_reg_data));
1205
+ fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1206
+
1207
+ // Function body
1208
+ for (const auto& n : node_def) {
1209
+ *(fdef.add_node_def()) = n.ToNodeDef();
1210
+ }
1211
+
1212
+ // Returns
1213
+ for (const auto& r : ret_def) {
1214
+ fdef.mutable_ret()->insert({r.first, r.second});
1215
+ }
1216
+ return fdef;
1217
+ }
1218
+
1219
+ /* static */
1220
+ FunctionDef FunctionDefHelper::Define(const string& name,
1221
+ gtl::ArraySlice<string> arg_def,
1222
+ gtl::ArraySlice<string> ret_def,
1223
+ gtl::ArraySlice<string> attr_def,
1224
+ gtl::ArraySlice<Node> node_def) {
1225
+ FunctionDef fdef;
1226
+ OpDefBuilder b(name);
1227
+ for (const auto& a : arg_def) b.Input(a);
1228
+ for (const auto& r : ret_def) b.Output(r);
1229
+ for (const auto& a : attr_def) b.Attr(a);
1230
+
1231
+ OpRegistrationData op_reg_data;
1232
+ TF_CHECK_OK(b.Finalize(&op_reg_data));
1233
+ fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1234
+
1235
+ // Mapping from legacy output names to NodeDef outputs.
1236
+ std::unordered_map<string, string> ret_index;
1237
+ for (const auto& a : fdef.signature().input_arg()) {
1238
+ ret_index[a.name()] = a.name();
1239
+ }
1240
+
1241
+ // For looking up OpDefs
1242
+ auto* op_def_registry = OpRegistry::Global();
1243
+
1244
+ // Function body
1245
+ for (const auto& src : node_def) {
1246
+ NodeDef* n = fdef.add_node_def();
1247
+ n->set_op(src.op);
1248
+ n->set_name(src.ret[0]);
1249
+ for (const auto& a : src.attr) {
1250
+ n->mutable_attr()->insert({a.first, a.second.proto});
1251
+ }
1252
+ for (const string& a : src.arg) {
1253
+ const auto iter = ret_index.find(a);
1254
+ CHECK(iter != ret_index.end()) << "Node input '" << a << "' in '"
1255
+ << src.ret[0] << "' of " << name;
1256
+ n->add_input(iter->second);
1257
+ }
1258
+ for (const string& d : src.dep) {
1259
+ n->add_input(strings::StrCat("^", d));
1260
+ }
1261
+
1262
+ // Add the outputs of this node to ret_index.
1263
+ const OpDef* op_def = nullptr;
1264
+ TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
1265
+ CHECK(op_def != nullptr) << n->op();
1266
+ NameRangeMap output_names;
1267
+ TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
1268
+ for (const auto& o : output_names) {
1269
+ CHECK_LE(o.second.second, src.ret.size())
1270
+ << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
1271
+ << "' of " << name;
1272
+ for (int i = o.second.first; i < o.second.second; ++i) {
1273
+ ret_index[src.ret[i]] =
1274
+ strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
1275
+ }
1276
+ }
1277
+ }
1278
+
1279
+ // Returns
1280
+ for (const auto& r : fdef.signature().output_arg()) {
1281
+ const auto iter = ret_index.find(r.name());
1282
+ CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
1283
+ fdef.mutable_ret()->insert({r.name(), iter->second});
1284
+ }
1285
+ return fdef;
1286
+ }
1287
+
1288
+ FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
1289
+ gtl::ArraySlice<string> ret_def,
1290
+ gtl::ArraySlice<string> attr_def,
1291
+ gtl::ArraySlice<Node> node_def) {
1292
+ return Define("_", arg_def, ret_def, attr_def, node_def);
1293
+ }
1294
+
1295
+ namespace gradient {
1296
+
1297
+ typedef std::unordered_map<string, Creator> OpGradFactory;
1298
+
1299
+ OpGradFactory* GetOpGradFactory() {
1300
+ static OpGradFactory* factory = new OpGradFactory;
1301
+ return factory;
1302
+ }
1303
+
1304
+ bool RegisterOp(const string& op, Creator func) {
1305
+ CHECK(GetOpGradFactory()->insert({op, func}).second)
1306
+ << "Duplicated gradient for " << op;
1307
+ return true;
1308
+ }
1309
+
1310
+ Status GetOpGradientCreator(const string& op, Creator* creator) {
1311
+ auto fac = GetOpGradFactory();
1312
+ auto iter = fac->find(op);
1313
+ if (iter == fac->end()) {
1314
+ return errors::NotFound("No gradient defined for op: ", op);
1315
+ }
1316
+ *creator = iter->second;
1317
+ return Status::OK();
1318
+ }
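+
+ // Illustrative sketch (hypothetical op/function names): a gradient
+ // registered in some translation unit via
+ //   REGISTER_OP_GRADIENT("MyOp", MyOpGrad);
+ // lands in the factory above, so a later call to
+ //   GetOpGradientCreator("MyOp", &creator);
+ // returns MyOpGrad through `creator`.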
1319
+
1320
+ } // end namespace gradient
1321
+
1322
+ } // end namespace tensorflow
function.h ADDED
@@ -0,0 +1,625 @@
1
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_
17
+ #define TENSORFLOW_FRAMEWORK_FUNCTION_H_
18
+
19
+ #include <vector>
20
+ #include "tensorflow/core/framework/attr_value.pb.h"
21
+ #include "tensorflow/core/framework/attr_value_util.h"
22
+ #include "tensorflow/core/framework/function.pb.h"
23
+ #include "tensorflow/core/framework/node_def_util.h"
24
+ #include "tensorflow/core/framework/op.h"
25
+ #include "tensorflow/core/framework/selective_registration.h"
26
+ #include "tensorflow/core/framework/types.h"
27
+ #include "tensorflow/core/lib/gtl/flatmap.h"
28
+ #include "tensorflow/core/lib/hash/hash.h"
29
+ #include "tensorflow/core/platform/env.h"
30
+ #include "tensorflow/core/platform/macros.h"
31
+ #include "tensorflow/core/platform/protobuf.h"
32
+
33
+ namespace tensorflow {
34
+
35
+ class CancellationManager;
36
+ class GraphDef;
37
+ class OpKernel;
38
+ class ResourceMgr;
39
+ class Rendezvous;
40
+ class ScopedStepContainer;
41
+ class StepStatsCollector;
42
+ class Node;
43
+
44
+ // FunctionDefHelper::Create is a convenient helper to construct a
45
+ // FunctionDef proto.
46
+ // E.g.,
47
+ // FunctionDef my_func = FunctionDefHelper::Create(
48
+ // "my_func_name",
49
+ // {"x:T", "y:T" /* one string per argument */},
50
+ // {"z:T" /* one string per return value */},
51
+ // {"T: {float, double}" /* one string per attribute */},
52
+ // {
53
+ // {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
54
+ // /* one entry per function node */
55
+ // },
56
+ // /* Mapping between function returns and function node outputs. */
57
+ // {{"z", "o:z"}});
58
+ //
59
+ // For the old Function::Node approach, use FunctionDefHelper::Define()
60
+ // E.g.,
61
+ // FunctionDef my_func = FunctionDefHelper::Define(
62
+ // "my_func_name",
63
+ // {"x:T", "y:T" /* one string per argument */},
64
+ // {"z:T" /* one string per return value */},
65
+ // {"T: {float, double}" /* one string per attribute */},
66
+ // {
67
+ // {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
68
+ // /* one entry per function node */
69
+ // });
70
+ class FunctionDefHelper {
71
+ public:
72
+ // AttrValueWrapper has copy constructors for the type T so that
73
+ // it's easy to construct a simple AttrValue proto.
74
+ //
75
+ // If T is a string type (const char*, string, or StringPiece), and
76
+ // it starts with "$", we construct an AttrValue of "placeholder".
77
+ //
78
+ // E.g.,
79
+ // std::pair<string, AttrValueWrapper> x = {"T", "$T"};
80
+ // is a named attr value placeholder.
81
+ struct AttrValueWrapper {
82
+ AttrValue proto;
83
+
84
+ AttrValueWrapper() {}
85
+
86
+ template <typename T>
87
+ AttrValueWrapper(T val) { // NOLINT(runtime/explicit)
88
+ SetAttrValue(val, &proto);
89
+ }
90
+
91
+ private:
92
+ void InitFromString(StringPiece val);
93
+ };
94
+
95
+ // Constructs an AttrValue.func given the "name" and "attrs".
96
+ static AttrValueWrapper FunctionRef(
97
+ const string& name,
98
+ gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
99
+ static AttrValueWrapper FunctionRef(const string& name) {
100
+ return FunctionRef(name, {});
101
+ }
102
+
103
+ // Node is used to construct FunctionDef.Node using initialization
104
+ // lists. E.g.,
105
+ // Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}; // z = x * y
106
+ struct Node {
107
+ // When constructing a NodeDef, the first entry in ret is used as
108
+ // the node name; the remaining values are ignored.
109
+ std::vector<string> ret;
110
+ string op;
111
+ std::vector<string> arg;
112
+ std::vector<std::pair<string, AttrValueWrapper>> attr;
113
+ std::vector<string> dep;
114
+
115
+ NodeDef ToNodeDef() const;
116
+ };
117
+
118
+ // The Create() function uses the new NodeDef field. `ret_def`
119
+ // holds a mapping from the function output names from `out_def` to
120
+ // the node outputs from `node_def`.
121
+ static FunctionDef Create(const string& function_name,
122
+ gtl::ArraySlice<string> in_def,
123
+ gtl::ArraySlice<string> out_def,
124
+ gtl::ArraySlice<string> attr_def,
125
+ gtl::ArraySlice<Node> node_def,
126
+ gtl::ArraySlice<std::pair<string, string>> ret_def);
127
+
128
+ // The two Define() functions use the old FunctionDef::Node field.
129
+ // TODO(josh11b): Get rid of these and transition to the one above.
130
+ static FunctionDef Define(const string& function_name,
131
+ gtl::ArraySlice<string> arg_def,
132
+ gtl::ArraySlice<string> ret_def,
133
+ gtl::ArraySlice<string> attr_def,
134
+ gtl::ArraySlice<Node> node_def);
135
+
136
+ // Defines an anonymous function. I.e., its name is not relevant.
137
+ static FunctionDef Define(gtl::ArraySlice<string> arg_def,
138
+ gtl::ArraySlice<string> ret_def,
139
+ gtl::ArraySlice<string> attr_def,
140
+ gtl::ArraySlice<Node> node_def);
141
+
142
+ // Helpers to construct a constant scalar.
143
+ template <typename T>
144
+ static Node Const(const string& name, const T& val) {
145
+ Node n = {{name}, "Const"};
146
+ const DataType dtype = DataTypeToEnum<T>::value;
147
+ n.attr.push_back({"dtype", dtype});
148
+ Tensor t(dtype, TensorShape({}));
149
+ t.scalar<T>()() = val;
150
+ n.attr.push_back({"value", t});
151
+ return n;
152
+ }
153
+
154
+ template <typename T>
155
+ static Node Const(const string& name, gtl::ArraySlice<T> vals) {
156
+ Node n = {{name}, "Const"};
157
+ const DataType dtype = DataTypeToEnum<T>::value;
158
+ n.attr.push_back({"dtype", dtype});
159
+ int64 num = vals.size();
160
+ Tensor t(dtype, TensorShape({num}));
161
+ for (size_t i = 0; i < vals.size(); ++i) {
162
+ t.flat<T>()(i) = vals[i];
163
+ }
164
+ n.attr.push_back({"value", t});
165
+ return n;
166
+ }
167
+ };
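+
+ // Illustrative sketch of the Const helper above: for example,
+ //   FunctionDefHelper::Const<int32>("two", 2)
+ // yields a node roughly equivalent to
+ //   {{"two"}, "Const", {}, {{"dtype", DT_INT32}, {"value", <scalar 2>}}}
+ // i.e., a Const node named "two" producing a scalar int32 tensor.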
168
+
169
+ template <>
170
+ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
171
+ InitFromString(val);
172
+ }
173
+
174
+ template <>
175
+ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
176
+ const string& val) {
177
+ InitFromString(val);
178
+ }
179
+
180
+ template <>
181
+ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
182
+ InitFromString(val);
183
+ }
184
+
185
+ // Instantiate a function.
186
+ //
187
+ // "fdef" encodes a TF function with some attrs in fdef.signature.attr
188
+ // containing placeholders. InstantiateFunction binds these
189
+ // placeholders and produces an instantiated function encoded in
190
+ // "result.gdef". The value to substitute a placeholder is given by
191
+ // "attr_values", which is a map from a placeholder name to an attr
192
+ // value.
193
+ //
194
+ // InstantiateFunction calls "get_function" to find signatures of other
195
+ // functions and primitive ops.
196
+
197
+ // GetFunctionSignature(func name, opdef) returns OK if the func name is found
198
+ // and opdef is filled with a pointer to the corresponding signature
199
+ // (a OpDef proto). Otherwise, returns an error.
200
+ typedef std::function<Status(const string&, const OpDef**)>
201
+ GetFunctionSignature;
202
+
203
+ struct InstantiationResult {
204
+ DataTypeVector arg_types;
205
+ DataTypeVector ret_types;
206
+ std::vector<NodeDef> nodes;
207
+ };
208
+ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
209
+ GetFunctionSignature get_function,
210
+ InstantiationResult* result);
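+
+ // A minimal usage sketch. It assumes `fdef` is a FunctionDef with a type
+ // placeholder "T", an attrs helper like the test-only Attrs class in
+ // function_test.cc, and a GetFunctionSignature callback such as the
+ // GetOpSig helper there:
+ //
+ //   InstantiationResult result;
+ //   TF_RETURN_IF_ERROR(InstantiateFunction(
+ //       fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
+ //   // result.nodes now holds the function body with "T" bound to float.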
211
+
212
+ // Returns a debug string for a function definition.
213
+ //
214
+ // The returned text is multi-line. It is intended to be
215
+ // human-readable rather than being friendly to parsers. It is _NOT_
216
+ // intended to be the canonical string representation of "func_def".
217
+ // Particularly, it may not include all information presented in
218
+ // "func_def" (e.g., comments, description of the function arguments,
219
+ // etc.)
220
+ string DebugString(const FunctionDef& func_def);
221
+ string DebugString(const GraphDef& instantiated_func_def);
222
+ string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
223
+
224
+ // Returns a debug string for a top level graph (the main program and
225
+ // its supporting functions defined in its library).
226
+ string DebugStringWhole(const GraphDef& gdef);
227
+
228
+ // Returns true if f1 == f2. Compares all fields, including descriptions. Order
229
+ // of NodeDefs doesn't matter.
230
+ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
231
+
232
+ // Return a hash of `fdef` that is consistent with FunctionDefsEqual method.
233
+ // In other words, if two fdefs compare equal, their hash values will be the
234
+ // same.
235
+ uint64 FunctionDefHash(const FunctionDef& fdef);
236
+
237
+ // Returns a canonicalized string for the instantiation of the
238
+ // function of the given "name" and attributes "attrs".
239
+ //
240
+ // The returned string is guaranteed to be stable within one address
241
+ // space. But it may change as the implementation
242
+ // evolves. Therefore, it should not be persisted or compared across
243
+ // address spaces.
244
+ string Canonicalize(const string& funcname, AttrSlice attrs);
245
+
246
+ class CallFrameInterface {
247
+ public:
248
+ virtual ~CallFrameInterface() {}
249
+
250
+ virtual size_t num_args() const = 0;
251
+ virtual size_t num_retvals() const = 0;
252
+
253
+ virtual Status GetArg(int index, Tensor* val) const = 0;
254
+ virtual Status SetRetval(int index, const Tensor& val) = 0;
255
+ };
256
+
257
+ // Represents a function call frame. I.e., the data structure used to
258
+ // pass arguments to a function and retrieve its results.
259
+ //
260
+ // Runtime must arrange accesses to one FunctionCallFrame s.t.
261
+ // 1. SetArgs() happens before any GetArg();
262
+ // 2. GetRetvals happens after all SetRetval();
263
+ class FunctionCallFrame : public CallFrameInterface {
264
+ public:
265
+ FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
266
+ ~FunctionCallFrame();
267
+
268
+ // Caller methods.
269
+ Status SetArgs(gtl::ArraySlice<Tensor> args);
270
+ Status GetRetvals(std::vector<Tensor>* rets) const;
271
+ Status ConsumeRetvals(std::vector<Tensor>* rets);
272
+
273
+ size_t num_args() const override { return arg_types_.size(); }
274
+ size_t num_retvals() const override { return ret_types_.size(); }
275
+
276
+ // Callee methods.
277
+ Status GetArg(int index, Tensor* val) const override;
278
+ Status SetRetval(int index, const Tensor& val) override;
279
+
280
+ private:
281
+ DataTypeVector arg_types_;
282
+ DataTypeVector ret_types_;
283
+ gtl::InlinedVector<Tensor, 4> args_;
284
+ struct Retval {
285
+ bool has_val = false;
286
+ Tensor val;
287
+ };
288
+ gtl::InlinedVector<Retval, 4> rets_;
289
+
290
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
291
+ };
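+
+ // A minimal caller-side sketch of the ordering contract above (the tensor
+ // values and the execution of the instantiated body are elided):
+ //
+ //   FunctionCallFrame frame({DT_FLOAT}, {DT_FLOAT});
+ //   TF_RETURN_IF_ERROR(frame.SetArgs({arg_tensor}));
+ //   ... run the body, which calls GetArg(0, ...) and SetRetval(0, ...) ...
+ //   std::vector<Tensor> rets;
+ //   TF_RETURN_IF_ERROR(frame.ConsumeRetvals(&rets));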
292
+
293
+ // Helper to maintain a map between function names in a given
294
+ // FunctionDefLibrary and function definitions.
295
+ class FunctionLibraryDefinition : public OpRegistryInterface {
296
+ public:
297
+ explicit FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
298
+ FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
299
+ const FunctionDefLibrary& lib_def);
300
+ ~FunctionLibraryDefinition() override;
301
+
302
+ FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
303
+ delete;
304
+
305
+ // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
306
+ // returns its definition proto.
307
+ const FunctionDef* Find(const string& func) const;
308
+
309
+ // Adds function definition 'fdef' to this function library.
310
+ // Returns status 'ok' on success, or error otherwise. This is a no-op if
311
+ // 'fdef' already exists in this function library.
312
+ // If 'fdef' is successfully added to the library, it will be accessible
313
+ // from 'LookUp' and included in the proto returned by 'ToProto'.
314
+ // This operation is atomic.
315
+ Status AddFunctionDef(const FunctionDef& fdef);
316
+
317
+ // Adds gradient definition 'grad' to this function library.
318
+ // This is a no-op if 'grad' already exists in this function library.
319
+ // If 'grad' is successfully added, it will be accessible via 'FindGradient'
320
+ // and included in the proto returned by 'ToProto'.
321
+ // This operation is atomic.
322
+ Status AddGradientDef(const GradientDef& grad);
323
+
324
+ // Adds the functions and gradients in 'other' to this function library.
325
+ // Duplicate functions and gradients are ignored.
326
+ // This operation is atomic.
327
+ Status AddLibrary(const FunctionLibraryDefinition& other);
328
+
329
+ // Adds the functions and gradients in 'lib_def' to this function library.
330
+ // Duplicate functions and gradients are ignored.
331
+ // This operation is atomic.
332
+ Status AddLibrary(const FunctionDefLibrary& lib_def);
333
+
334
+ // If the gradient function for 'func' is specified explicitly in
335
+ // the library, returns the gradient function name. Otherwise,
336
+ // returns an empty string.
337
+ string FindGradient(const string& func) const;
338
+
339
+ // OpRegistryInterface method. Useful for constructing a Graph.
340
+ //
341
+ // If "op" is defined in the library, returns its signature.
342
+ // Otherwise, assume "op" is a primitive op and returns its op
343
+ // signature and shape inference function.
344
+ Status LookUp(const string& op_type_name,
345
+ const OpRegistrationData** op_reg_data) const override;
346
+
347
+ static constexpr const char* const kGradientOp = "SymbolicGradient";
348
+ static constexpr const char* const kFuncAttr = "f";
349
+
350
+ // Given a node def 'ndef', inspects attributes of the callee
351
+ // function to derive the attribute 'value' for 'attr'. Returns OK
352
+ // iff the attribute is given by the function's definition.
353
+ // TODO(irving): Remove; keep only the const Node& version.
354
+ template <typename T>
355
+ Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const;
356
+
357
+ // Given a node, inspects attributes of the callee function to derive the
358
+ // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
359
+ // function's definition.
360
+ template <typename T>
361
+ Status GetAttr(const Node& node, const string& attr, T* value) const;
362
+
363
+ // Returns a proto representation of the state of this function library.
364
+ FunctionDefLibrary ToProto() const;
365
+
366
+ const OpRegistryInterface* default_registry() const {
367
+ return default_registry_;
368
+ }
369
+
370
+ private:
371
+ // Shape inference for functions is handled separately by ShapeRefiner.
372
+
373
+ struct FunctionDefAndOpRegistration {
374
+ FunctionDefAndOpRegistration(const FunctionDef& fdef_in);
375
+
376
+ FunctionDef fdef;
377
+ OpRegistrationData op_registration_data;
378
+ };
379
+
380
+ // Same as AddFunctionDef/AddGradientDef except these methods set
381
+ // `added` to true if the `fdef`/`grad` were actually added to this.
382
+ Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added);
383
+ Status AddGradientDefHelper(const GradientDef& grad, bool* added);
384
+
385
+ const OpRegistryInterface* const default_registry_;
386
+ gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>>
387
+ function_defs_;
388
+ gtl::FlatMap<string, string> func_grad_;
389
+
390
+ // Helper function for GetAttr. Returns the FunctionDef* to get the
391
+ // attr from.
392
+ const FunctionDef* GetAttrImpl(const NodeDef& ndef) const;
393
+
394
+ // Remove function `func` from the library. `func` must be in the library.
395
+ void RemoveFunction(const string& func);
396
+
397
+ // Remove gradient of function `func` from the library. `func` must have
398
+ // a gradient.
399
+ void RemoveGradient(const string& func);
400
+
401
+ // Remove all functions in `funcs` and all gradients of
402
+ // functions in `funcs_with_grads` from this library.
403
+ void Remove(const std::vector<string>& funcs,
404
+ const std::vector<string>& funcs_with_grads);
405
+ };
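+
+ // A small usage sketch (assumes `my_func` is a FunctionDef whose signature
+ // name is "MyFunc"):
+ //
+ //   FunctionLibraryDefinition lib(OpRegistry::Global(), FunctionDefLibrary());
+ //   TF_RETURN_IF_ERROR(lib.AddFunctionDef(my_func));
+ //   const FunctionDef* found = lib.Find("MyFunc");  // Non-null after Add.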
406
+
407
+ // Forward declare. Defined in common_runtime/function.h
408
+ struct FunctionBody;
409
+
410
+ // Forward declare. Defined in common_runtime/device.h
411
+ class Device;
412
+
413
+ class FunctionLibraryRuntime {
414
+ public:
415
+ virtual ~FunctionLibraryRuntime() {}
416
+
417
+ // Instantiate a function with the given "attrs".
418
+ //
419
+ // Returns OK and fills in "handle" if the instantiation succeeds.
420
+ // Otherwise returns an error and "handle" is undefined.
421
+ typedef uint64 Handle;
422
+ virtual Status Instantiate(const string& function_name, AttrSlice attrs,
423
+ Handle* handle) = 0;
424
+
425
+ // Releases state associated with the handle.
426
+ virtual Status ReleaseHandle(Handle handle) = 0;
427
+
428
+ // Returns the function body for the instantiated function given its
429
+ // handle 'h'. Returns nullptr if "h" is not found.
430
+ //
431
+ // *this keeps the ownership of the returned object, which remains alive
432
+ // as long as *this.
433
+ virtual const FunctionBody* GetFunctionBody(Handle h) = 0;
434
+
435
+ // Asynchronously invokes the instantiated function identified by
436
+ // "handle".
437
+ //
438
+ // If function execution succeeds, "done" is called with OK and
439
+ // "*rets" is filled with the function's return values. Otheriwse,
440
+ // "done" is called with an error status.
441
+ //
442
+ // Does not take ownership of "rets".
443
+ // In the cross-process scenario, runner isn't used for making the Async
444
+ // RPC calls.
445
+ struct Options {
446
+ // The id of the step that is calling this function.
447
+ int64 step_id = 0;
448
+ Rendezvous* rendezvous = nullptr;
449
+ CancellationManager* cancellation_manager = nullptr;
450
+ ScopedStepContainer* step_container = nullptr;
451
+ StepStatsCollector* stats_collector = nullptr;
452
+
453
+ std::function<void(std::function<void()>)>* runner = nullptr;
454
+
455
+ // Parameters for remote function execution.
456
+ bool remote_execution = false;
457
+ string source_device = ""; // Fully specified device name.
458
+
459
+ // Allocator attributes specifying where the args are / rets should be put.
460
+ // These should either be {} or match the length of args / retvals. If {},
461
+ // the default allocator attributes will be assumed for all args / retvals.
462
+ std::vector<AllocatorAttributes> args_alloc_attrs;
463
+ std::vector<AllocatorAttributes> rets_alloc_attrs;
464
+
465
+ // If true, we create a new IntraProcessRendezvous, else use the existing
466
+ // one.
467
+ bool create_rendezvous = false;
468
+ };
469
+ typedef std::function<void(const Status&)> DoneCallback;
470
+ virtual void Run(const Options& opts, Handle handle,
471
+ gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
472
+ DoneCallback done) = 0;
473
+ virtual void Run(const Options& opts, Handle handle,
474
+ CallFrameInterface* call_frame, DoneCallback done) = 0;
475
+
476
+ // Creates a "kernel" for the given node def "ndef".
477
+ //
478
+ // If succeeds, returns OK and the caller takes the ownership of the
479
+ // returned "*kernel". Otherwise, returns an error.
480
+ virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0;
481
+
482
+ // Returns true iff the function named 'function_name' is stateful.
483
+ virtual bool IsStateful(const string& function_name) = 0;
484
+
485
+ // Returns the device on which the function executes.
486
+ virtual Device* device() = 0;
487
+
488
+ // Returns the function library definition that backs this runtime.
489
+ virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
490
+ const = 0;
491
+
492
+ // Returns the environment on which the function executes.
493
+ virtual Env* env() = 0;
494
+
495
+ // Returns a debug string showing the definition of the function of
496
+ // 'handle'.
497
+ virtual string DebugString(Handle handle) = 0;
498
+
499
+ // Returns the graph version number.
500
+ virtual int graph_def_version() = 0;
501
+
502
+ typedef uint64 LocalHandle;
503
+ };
504
+
505
+ const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
506
+ const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
507
+ typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&,
508
+ std::unique_ptr<OpKernel>*)>
509
+ CustomKernelCreator;
510
+
511
+ // Used to instantiate and run functions in a distributed system.
512
+ class DistributedFunctionLibraryRuntime {
513
+ public:
514
+ virtual ~DistributedFunctionLibraryRuntime() {}
515
+
516
+ // The _target attr in attrs determines where the function is instantiated.
517
+ virtual Status Instantiate(const string& function_name,
518
+ const FunctionLibraryDefinition& lib_def,
519
+ AttrSlice attrs,
520
+ FunctionLibraryRuntime::LocalHandle* handle) = 0;
521
+
522
+ // opts.runner isn't used for execution.
523
+ virtual void Run(const FunctionLibraryRuntime::Options& opts,
524
+ FunctionLibraryRuntime::LocalHandle handle,
525
+ gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
526
+ FunctionLibraryRuntime::DoneCallback done) = 0;
527
+ };
528
+
529
+ // Extracts the actual type from "attr_values" based on its definition
530
+ // "arg_def".
531
+ //
532
+ // If "arg_def" is a N*T type, *is_type_list is set to false, and
533
+ // *dtypes is set to be a vector of size N and each element is T.
534
+ //
535
+ // If "arg_def" is a list(type), *is_type_list is set to true, and
536
+ // *dtypes is set to be a vector of types specified in attrs for
537
+ // arg_def.
538
+ //
539
+ // Otherwise (arg_def is a simple type T), *is_type_list is set to
540
+ // false, and *dtypes is set to a single element vector, whose only
541
+ // element is T.
542
+ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
543
+ bool* is_type_list, DataTypeVector* dtypes);
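+
+ // For example, given attrs {N=3, T=DT_FLOAT} and an arg_def declared as
+ // "x: N*T", *is_type_list is set to false and *dtypes becomes
+ // {DT_FLOAT, DT_FLOAT, DT_FLOAT}.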
544
+
545
+ // To register a gradient function for a builtin op, one should use
546
+ // REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
547
+ //
548
+ // Typically, the C++ grad factory is a plain function that can be
549
+ // converted into ::tensorflow::gradient::Creator, which is
550
+ // std::function<Status(const AttrSlice&, FunctionDef*)>.
551
+ //
552
+ // A ::tensorflow::gradient::Creator should populate the FunctionDef* with a
553
+ // definition of a function which computes the gradient for the
554
+ // <op_name> when the <op_name> is instantiated with the given attrs.
555
+ //
556
+ // E.g.,
557
+ //
558
+ // Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
559
+ // bool transpose_a;
560
+ // TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
561
+ // bool transpose_b;
562
+ // TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
563
+ // DataType dtype;
564
+ // TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
565
+ // if (!transpose_a && !transpose_b) {
566
+ // *g = FunctionDefHelper::Define(
567
+ // "MatMulGrad",
568
+ // {"x:T ", "y:T", "dz:T"}, // Inputs to this function
569
+ // {"dx:T", "dy:T"}, // Outputs from this function
570
+ // {"T: {float, double}"}, // Attributes needed by this function
571
+ // {
572
+ // {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
573
+ // {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
574
+ // {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
575
+ // {{"dy"}, "MatMul", {"x_", "dz"}, {{"T", "$T"}}},
576
+ // });
577
+ // } else {
578
+ // ... ...
579
+ // }
580
+ // return Status::OK();
581
+ // }
582
+ //
583
+ // NOTE: $T is substituted with the type variable "T" when the
584
+ // gradient function MatMul is instantiated.
585
+ //
586
+ // TODO(zhifengc): Better documentation somewhere.
587
+
588
+ // Macros to define a gradient function factory for a primitive
589
+ // operation.
590
+ #define REGISTER_OP_GRADIENT(name, fn) \
591
+ REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)
592
+
593
+ #define REGISTER_OP_NO_GRADIENT(name) \
594
+ REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)
595
+
596
+ #define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
597
+ REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)
598
+
599
+ #define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \
600
+ static bool unused_grad_##ctr = SHOULD_REGISTER_OP_GRADIENT && \
601
+ ::tensorflow::gradient::RegisterOp(name, fn)
602
+
603
+ namespace gradient {
604
+ // Register a gradient creator for the "op".
605
+ typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
606
+ bool RegisterOp(const string& op, Creator func);
607
+
608
+ // Returns OK if the gradient creator for the "op" is found (it may be
609
+ // nullptr if REGISTER_OP_NO_GRADIENT is used).
610
+ Status GetOpGradientCreator(const string& op, Creator* creator);
611
+ } // end namespace gradient
612
+
613
+ // Declare explicit instantiations of GetAttr
614
+ #define GET_ATTR(T) \
615
+ extern template Status FunctionLibraryDefinition::GetAttr( \
616
+ const Node&, const string&, T*) const; \
617
+ extern template Status FunctionLibraryDefinition::GetAttr( \
618
+ const NodeDef&, const string&, T*) const;
619
+ GET_ATTR(string)
620
+ GET_ATTR(bool)
621
+ #undef GET_ATTR
622
+
623
+ } // end namespace tensorflow
624
+
625
+ #endif // TENSORFLOW_FRAMEWORK_FUNCTION_H_
function.proto ADDED
@@ -0,0 +1,101 @@
1
+ syntax = "proto3";
2
+
3
+ package tensorflow;
4
+ option cc_enable_arenas = true;
5
+ option java_outer_classname = "FunctionProtos";
6
+ option java_multiple_files = true;
7
+ option java_package = "org.tensorflow.framework";
8
+
9
+ import "tensorflow/core/framework/attr_value.proto";
10
+ import "tensorflow/core/framework/node_def.proto";
11
+ import "tensorflow/core/framework/op_def.proto";
12
+
13
+ // A library is a set of named functions.
14
+ message FunctionDefLibrary {
15
+ repeated FunctionDef function = 1;
16
+ repeated GradientDef gradient = 2;
17
+ }
18
+
19
+ // A function can be instantiated when the runtime can bind every attr
20
+ // with a value. When a GraphDef has a call to a function, it must
21
+ // have a binding for every attr defined in the signature.
22
+ //
23
+ // TODO(zhifengc):
24
+ // * device spec, etc.
25
+ message FunctionDef {
26
+ // The definition of the function's name, arguments, return values,
27
+ // attrs etc.
28
+ OpDef signature = 1;
29
+
30
+ // Attributes specific to this function definition.
31
+ map<string, AttrValue> attr = 5;
32
+
33
+ // NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21.
34
+
35
+ // In both of the following fields, there is the need to specify an
36
+ // output that is used as either the input to another node (in
37
+ // `node_def`) or as a return value of the function (in `ret`).
38
+ // Unlike the NodeDefs in GraphDef, we need to be able to specify a
39
+ // list in some cases (instead of just single outputs). Also, we
40
+ // need to be able to deal with lists of unknown length (so the
41
+ // output index may not be known at function definition time). So
42
+ // we use the following format instead:
43
+ // * "fun_in" where "fun_in" is the name of a function input arg in
44
+ // the `signature` field above. This represents that input, whether
45
+ // it is a single tensor or a list.
46
+ // * "fun_in:0" gives the first element of a function input arg (a
47
+ // non-list input is considered a list of length 1 for these
48
+ // purposes).
49
+ // * "node:out" where "node" is the name of a node in `node_def` and
50
+ // "out" is the name one of its op's output arguments (the name
51
+ // comes from the OpDef of the node's op). This represents that
52
+ // node's output, whether it is a single tensor or a list.
53
+ // Note: We enforce that an op's output arguments are never
54
+ // renamed in the backwards-compatibility test.
55
+ // * "node:out:0" gives the first element of a node output arg (a
56
+ // non-list output is considered a list of length 1 for these
57
+ // purposes).
58
+ //
59
+ // NOT CURRENTLY SUPPORTED (but may be in the future):
60
+ // * "node:out:-1" gives last element in a node output list
61
+ // * "node:out:1:" gives a list with all but the first element in a
62
+ // node output list
63
+ // * "node:out::-1" gives a list with all but the last element in a
64
+ // node output list
65
+
66
+ // The body of the function. Unlike the NodeDefs in a GraphDef, attrs
67
+ // may have values of type `placeholder` and the `input` field uses
68
+ // the "output" format above.
69
+
70
+ // By convention, "op" in node_def is resolved by consulting with a
71
+ // user-defined library first. If not resolved, "op" is assumed to
72
+ // be a builtin op.
73
+ repeated NodeDef node_def = 3;
74
+
75
+ // A mapping from the output arg names from `signature` to the
76
+ // outputs from `node_def` that should be returned by the function.
77
+ map<string, string> ret = 4;
78
+ }
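+
+ // For example, a function whose signature declares an output "y" and whose
+ // body has a node "s" calling an op with an output arg named "z" would map
+ // the function output with an entry like:  ret { key: "y" value: "s:z:0" }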
79
+
80
+ // GradientDef defines the gradient function of a function defined in
81
+ // a function library.
82
+ //
83
+ // A gradient function g (specified by gradient_func) for a function f
84
+ // (specified by function_name) must satisfy the following:
85
+ //
86
+ // The function 'f' must be a numerical function which takes N inputs
87
+ // and produces M outputs. Its gradient function 'g' is a
88
+ // function taking N + M inputs and producing N outputs.
89
+ //
90
+ // I.e. if we have
91
+ // (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
92
+ // then, g is
93
+ // (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
94
+ // dL/dy1, dL/dy2, ..., dL/dy_M),
95
+ // where L is a scalar-valued function of (x1, x2, ..., x_N) (e.g., the
96
+ // loss function). dL/dx_i is the partial derivative of L with respect
97
+ // to x_i.
98
+ message GradientDef {
99
+ string function_name = 1; // The function name.
100
+ string gradient_func = 2; // The gradient function's name.
101
+ }
function_test.cc ADDED
@@ -0,0 +1,1339 @@
1
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #include "tensorflow/core/framework/function.h"
17
+ #include <vector>
18
+ #include "tensorflow/core/framework/function.pb.h"
19
+ #include "tensorflow/core/framework/function_testlib.h"
20
+ #include "tensorflow/core/framework/op.h"
21
+ #include "tensorflow/core/framework/tensor_testutil.h"
22
+ #include "tensorflow/core/kernels/ops_util.h"
23
+ #include "tensorflow/core/lib/core/status_test_util.h"
24
+ #include "tensorflow/core/lib/gtl/array_slice.h"
25
+ #include "tensorflow/core/lib/strings/str_util.h"
26
+ #include "tensorflow/core/lib/strings/strcat.h"
27
+ #include "tensorflow/core/platform/test.h"
28
+ #include "tensorflow/core/platform/types.h"
29
+
30
+ namespace tensorflow {
31
+ namespace {
32
+
33
+ // A helper class to make AttrSlice from initializer lists
34
+ class Attrs {
35
+ public:
36
+ Attrs(const std::initializer_list< // NOLINT(runtime/explicit)
37
+ std::pair<string, FunctionDefHelper::AttrValueWrapper>>
38
+ attrs) {
39
+ for (const auto& aval : attrs) {
40
+ map_.insert({aval.first, aval.second.proto});
41
+ }
42
+ }
43
+
44
+ operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit)
45
+
46
+ private:
47
+ AttrValueMap map_;
48
+ };
49
+
50
+ typedef FunctionDefHelper FDH;
51
+
52
+ Status GetOpSig(const string& op, const OpDef** sig) {
53
+ return OpRegistry::Global()->LookUpOpDef(op, sig);
54
+ }
55
+
56
+ REGISTER_OP("One")
57
+ .Output("y: T")
58
+ .Attr("T: {float, double, int32, int64}")
59
+ .Doc(R"doc(
60
+ Returns a tensor with a single element (1) of type T.
61
+
62
+ y: A scalar of type T.
63
+
64
+ )doc");
65
+
66
+ TEST(TFunc, SquarePlusOne) {
67
+ auto fdef = FDH::Create(
68
+ // Name
69
+ "SquarePlusOne",
70
+ // Inputs
71
+ {"x: T"},
72
+ // Outputs
73
+ {"y: T"},
74
+ // Attrs
75
+ {"T: {float, double, int32, int64}"},
76
+ // Nodes
77
+ {// a = Square<T>(x)
78
+ {{"a"}, "Square", {"x"}, {{"T", "$T"}}},
79
+ // o = One<T>()
80
+ // NOTE: We can also have a Cast<Tin, Tout>(x) instead.
81
+ {{"o"}, "One", {}, {{"T", "$T"}}},
82
+ // y = Add<T>(a, o)
83
+ {{"y"}, "Add", {"a:y", "o:y"}, {{"T", "$T"}}}},
84
+ // Returns
85
+ {{"y", "y:z:0"}});
86
+
87
+ const char* e = R"P(
88
+ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
89
+ a = Square[T=$T](x)
90
+ o = One[T=$T]()
91
+ y = Add[T=$T](a:y, o:y)
92
+ return y = y:z:0
93
+ }
94
+ )P";
95
+ EXPECT_EQ(DebugString(fdef), e);
96
+
97
+ // Instantiate one with T=float
98
+ InstantiationResult result;
99
+ TF_ASSERT_OK(
100
+ InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
101
+ const char* e2 = R"P(
102
+ (x:float) -> (y:float) {
103
+ a = Square[T=float](x)
104
+ o = One[T=float]()
105
+ y = Add[T=float](a, o)
106
+ }
107
+ )P";
108
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
109
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
110
+ EXPECT_EQ(DebugString(result.nodes), e2);
111
+ }
112
+
113
+ TEST(TFunc, ControlDep) {
114
+ auto fdef = FDH::Create(
115
+ // Name
116
+ "ControlDep",
117
+ // Inputs
118
+ {"x: int32"},
119
+ // Outputs
120
+ {"y: int32"},
121
+ // Attrs
122
+ {},
123
+ // Nodes
124
+ {// a = Identity<int32>(x)
125
+ {{"a"}, "Identity", {"x"}, {{"T", DT_INT32}}},
126
+ // o = NoOp(^a)
127
+ {{"o"}, "NoOp", {"^a"}, {}},
128
+ // y = Identity<int32>(a, ^o)
129
+ {{"y"}, "Identity", {"a:output:0", "^o"}, {{"T", DT_INT32}}}},
130
+ // Returns
131
+ {{"y", "y:output:0"}});
132
+
133
+ const char* e = R"P(
134
+ ControlDep(x:int32) -> (y:int32) {
135
+ a = Identity[T=int32](x)
136
+ o = NoOp() @ a
137
+ y = Identity[T=int32](a:output:0) @ o
138
+ return y = y:output:0
139
+ }
140
+ )P";
141
+ EXPECT_EQ(DebugString(fdef), e);
142
+
143
+ // Instantiate one with T=float
144
+ InstantiationResult result;
145
+ TF_ASSERT_OK(
146
+ InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
147
+ const char* e2 = R"P(
148
+ (x:int32) -> (y:int32) {
149
+ a = Identity[T=int32](x)
150
+ o = NoOp() @ a
151
+ y = Identity[T=int32](a) @ o
152
+ }
153
+ )P";
154
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32}));
155
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32}));
156
+ EXPECT_EQ(DebugString(result.nodes), e2);
157
+ }
158
+
159
+ REGISTER_OP("HasDefaultType")
160
+ .Output("out: T")
161
+ .Attr("T: {float, double, int32, int64} = DT_FLOAT");
162
+
163
+ // This verifies that a function using an op before a type attr (with
164
+ // a default) is added, still works. This is important for backwards
165
+ // compatibility.
166
+ TEST(TFunc, MissingTypeAttr) {
167
+ auto fdef = FDH::Create(
168
+ // Name
169
+ "BackCompat",
170
+ // Args
171
+ {},
172
+ // Return values
173
+ {"y: float"},
174
+ // Attrs
175
+ {},
176
+ // Nodes
177
+ {// a = HasDefaultType(), T missing, defaults to float
178
+ {{"a"}, "HasDefaultType", {}, {}}},
179
+ // Returns
180
+ {{"y", "a:out:0"}});
181
+
182
+ const char* e = R"P(
183
+ BackCompat() -> (y:float) {
184
+ a = HasDefaultType()
185
+ return y = a:out:0
186
+ }
187
+ )P";
188
+ EXPECT_EQ(DebugString(fdef), e);
189
+
190
+ InstantiationResult result;
191
+ TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
192
+ // Should get T=float from Op's default.
193
+ const char* e2 = R"P(
194
+ () -> (a:float) {
195
+ a = HasDefaultType[T=float]()
196
+ }
197
+ )P";
198
+ EXPECT_EQ(result.arg_types, DataTypeVector());
199
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
200
+ EXPECT_EQ(DebugString(result.nodes), e2);
201
+ }
202
+
203
+ TEST(TFunc, NTimesT) {
204
+ auto fdef = FDH::Create(
205
+ // Name
206
+ "NTimesT",
207
+ // Inputs
208
+ {"x: float", "y: float"},
209
+ // Outputs
210
+ {"z: float"},
211
+ // Attrs
212
+ {},
213
+ // Nodes
214
+ {// a = AddN<N=2>(x, y)
215
+ {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 2}}}},
216
+ // Returns
217
+ {{"z", "a:sum:0"}});
218
+
219
+ const char* e = R"P(
220
+ NTimesT(x:float, y:float) -> (z:float) {
221
+ a = AddN[N=2, T=float](x, y)
222
+ return z = a:sum:0
223
+ }
224
+ )P";
225
+ EXPECT_EQ(DebugString(fdef), e);
226
+
227
+ InstantiationResult result;
228
+ TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
229
+ const char* e2 = R"P(
230
+ (x:float, y:float) -> (a:float) {
231
+ a = AddN[N=2, T=float](x, y)
232
+ }
233
+ )P";
234
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT}));
235
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
236
+ EXPECT_EQ(DebugString(result.nodes), e2);
237
+ }
238
+
239
+ // NOTE: This is the simplest Map op. It takes a f:T->U.
240
+ REGISTER_OP("Map")
241
+ .Input("x: N * T")
242
+ .Output("y: N * U")
243
+ .Attr("T: type")
244
+ .Attr("U: type")
245
+ .Attr("N: int >= 1")
246
+ // .Attr("func: func_name_with_attr")
247
+ .Doc(R"doc(
248
+ Applies the 'func' on every input. I.e.,
249
+
250
+ y[i] = func<...>(x[i])
251
+
252
+ x: N tensors, each of type T;
253
+ y: N tensors, each of type U;
254
+
255
+ )doc");
256
+
257
+ TEST(TFunc, AddSquared) {
258
+ auto fdef = FDH::Create(
259
+ // Name
260
+ "AddSquared",
261
+ // Args
262
+ {"x: N*T"},
263
+ // Return values
264
+ {"y: T"},
265
+ // Attrs
266
+ {"N:int", "T:{float, double, int32, int64}"},
267
+ // Nodes
268
+ {// a = Map<func=Square<$T>,T=$T,U=$T,N=$N>(x)
269
+ {{"a"},
270
+ "Map",
271
+ {"x"},
272
+ {{"func", FDH::FunctionRef("Square", {{"T", "$T"}})},
273
+ {"T", "$T"},
274
+ {"U", "$T"},
275
+ {"N", "$N"}}},
276
+ // y = AddN<N=$N,T=$T>(a)
277
+ {{"y"}, "AddN", {"a:y"}, {{"N", "$N"}, {"T", "$T"}}}},
278
+ {{"y", "y:sum"}});
279
+
280
+ const char* e = R"P(
281
+ AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
282
+ a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x)
283
+ y = AddN[N=$N, T=$T](a:y)
284
+ return y = y:sum
285
+ }
286
+ )P";
287
+ EXPECT_EQ(DebugString(fdef), e);
288
+
289
+ // Instantiate one with T=float
290
+ InstantiationResult result;
291
+ TF_ASSERT_OK(InstantiateFunction(fdef, Attrs({{"N", 3}, {"T", DT_FLOAT}}),
292
+ GetOpSig, &result));
293
+ const char* e2 = R"P(
294
+ (x_0:float, x_1:float, x_2:float) -> (y:float) {
295
+ a = Map[N=3, T=float, U=float, func=Square[T=float]](x_0, x_1, x_2)
296
+ y = AddN[N=3, T=float](a, a:1, a:2)
297
+ }
298
+ )P";
299
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT}));
300
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
301
+ EXPECT_EQ(DebugString(result.nodes), e2);
302
+ }
303
+
304
+ TEST(TFunc, ControlDeps) {
305
+ auto fdef = FDH::Define(
306
+ // Name
307
+ "ControlDeps",
308
+ // Args
309
+ {"x: float"},
310
+ // Return values
311
+ {},
312
+ // Attrs
313
+ {},
314
+ // Nodes
315
+ {
316
+ {{"a"}, "One", {}, {{"T", DT_FLOAT}}, {"x"}},
317
+ {{"u"}, "NoOp", {}, {}, {"a"}},
318
+ {{"b"}, "One", {}, {{"T", DT_FLOAT}}, {"u"}},
319
+ {{"v"}, "NoOp", {}, {}, {"b"}},
320
+ {{"c"}, "One", {}, {{"T", DT_FLOAT}}, {"a", "v"}},
321
+ });
322
+ const char* e = R"P(
323
+ ControlDeps(x:float) -> () {
324
+ a = One[T=float]() @ x
325
+ u = NoOp() @ a
326
+ b = One[T=float]() @ u
327
+ v = NoOp() @ b
328
+ c = One[T=float]() @ a, v
329
+ }
330
+ )P";
331
+ EXPECT_EQ(DebugString(fdef), e);
332
+
333
+ InstantiationResult result;
334
+ TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
335
+ const char* e2 = R"P(
336
+ (x:float) -> () {
337
+ a = One[T=float]() @ x
338
+ u = NoOp() @ a
339
+ b = One[T=float]() @ u
340
+ v = NoOp() @ b
341
+ c = One[T=float]() @ a, v
342
+ }
343
+ )P";
344
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
345
+ EXPECT_EQ(result.ret_types, DataTypeVector({}));
346
+ EXPECT_EQ(DebugString(result.nodes), e2);
347
+ }
348
+
349
+ TEST(TFunc, XTimesTwo) {
350
+ auto expect = R"P(
351
+ XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
352
+ two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
353
+ scale = Cast[DstT=$T, SrcT=int64](two:output:0)
354
+ y = Mul[T=$T](x, scale:y:0)
355
+ return y = y:z:0
356
+ }
357
+ )P";
358
+ EXPECT_EQ(expect, DebugString(test::function::XTimesTwo()));
359
+ }
360
+
361
+ TEST(TFunc, WXPlusB) {
362
+ auto expect = R"P(
363
+ WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
364
+ mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
365
+ y = Add[T=$T](mm:product:0, b)
366
+ return y = y:z:0
367
+ }
368
+ )P";
369
+ EXPECT_EQ(expect, DebugString(test::function::WXPlusB()));
370
+ }
371
+
372
+ TEST(TFunc, Body_TypeList) {
373
+ const Tensor kZero = test::AsScalar<int32>(0);
374
+ auto fdef = FDH::Create(
375
+ // Name
376
+ "Test",
377
+ // Args
378
+ {"i:float"},
379
+ // Return values
380
+ {"o:float"},
381
+ // Attrs
382
+ {},
383
+ // Nodes
384
+ {{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}},
385
+ {{"s"},
386
+ "Split",
387
+ {"zero:output:0", "i"},
388
+ {{"num_split", 4}, {"T", DT_FLOAT}}},
389
+ {{"l"}, "Mul", {"s:output:0", "s:output:1"}, {{"T", DT_FLOAT}}},
390
+ {{"r"}, "Mul", {"s:output:2", "s:output:3"}, {{"T", DT_FLOAT}}},
391
+ {{"x"},
392
+ "_ListToArray",
393
+ {"l:z", "r:z"},
394
+ {{"N", 2},
395
+ {"T", DT_FLOAT},
396
+ {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
397
+ {{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
398
+ {{"o", "o:sum:0"}});
399
+
400
+ const char* e = R"P(
401
+ Test(i:float) -> (o:float) {
402
+ zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
403
+ s = Split[T=float, num_split=4](zero:output:0, i)
404
+ l = Mul[T=float](s:output:0, s:output:1)
405
+ r = Mul[T=float](s:output:2, s:output:3)
406
+ x = _ListToArray[N=2, T=float, Tin={float, float}](l:z, r:z)
407
+ o = AddN[N=2, T=float](x:output)
408
+ return o = o:sum:0
409
+ }
410
+ )P";
411
+ EXPECT_EQ(DebugString(fdef), e);
412
+
413
+ InstantiationResult result;
414
+ TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
415
+ const char* e2 = R"P(
416
+ (i:float) -> (o:float) {
417
+ zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
418
+ s = Split[T=float, num_split=4](zero, i)
419
+ l = Mul[T=float](s, s:1)
420
+ r = Mul[T=float](s:2, s:3)
421
+ x = _ListToArray[N=2, T=float, Tin={float, float}](l, r)
422
+ o = AddN[N=2, T=float](x, x:1)
423
+ }
424
+ )P";
425
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
426
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
427
+ EXPECT_EQ(DebugString(result.nodes), e2);
428
+ }
429
+
430
+ REGISTER_OP("Cond")
431
+ .Input("input: Tin")
432
+ .Output("output: out_types")
433
+ .Attr("Tin: list(type)")
434
+ .Attr("out_types: list(type)")
435
+ .Attr("cond: func")
436
+ .Attr("then_branch: func")
437
+ .Attr("else_branch: func")
438
+ .Doc(R"doc(
439
+ output = Cond(input) ? then_branch(input) : else_branch(input)
440
+
441
+ cond: A function that takes 'input' and returns a scalar.
442
+ then_branch: A function that takes 'input' and returns 'output'.
443
+ else_branch: A function that takes 'input' and returns 'output'.
444
+ )doc");
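+ // Because Tin and out_types are list(type) attrs, Cond's arity is fixed
+ // only when the enclosing function is instantiated; the tests below cover
+ // both the working case and the errors when these attrs are absent or
+ // inconsistent.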
445
+
446
+ TEST(TFunc, Body_Array_List_Converter) {
447
+ auto fdef = FDH::Define(
448
+ // Name
449
+ "MySelect",
450
+ // Args
451
+ {"x:float"},
452
+ // Return values
453
+ {"z:float"},
454
+ // Attrs
455
+ {},
456
+ // Nodes
457
+ {
458
+ {{"y"},
459
+ "Cond",
460
+ {"x"},
461
+ {{"Tin", DataTypeSlice{DT_FLOAT}},
462
+ {"out_types", DataTypeSlice{DT_FLOAT}},
463
+ {"cond", FDH::FunctionRef("MyCond")},
464
+ {"then_branch", FDH::FunctionRef("MyThen")},
465
+ {"else_branch", FDH::FunctionRef("MyElse")}}},
466
+ {{"z"},
467
+ "Cond",
468
+ {"y", "y"},
469
+ {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
470
+ {"out_types", DataTypeSlice{DT_FLOAT}},
471
+ {"cond", FDH::FunctionRef("MyCond2")},
472
+ {"then_branch", FDH::FunctionRef("MyThen2")},
473
+ {"else_branch", FDH::FunctionRef("MyElse2")}}},
474
+ });
475
+
476
+ const char* e = R"P(
477
+ MySelect(x:float) -> (z:float) {
478
+ y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
479
+ z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y:output:0, y:output:0)
480
+ return z = z:output:0
481
+ }
482
+ )P";
483
+ EXPECT_EQ(DebugString(fdef), e);
484
+
485
+ InstantiationResult result;
486
+ TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
487
+ const char* e2 = R"P(
488
+ (x:float) -> (z:float) {
489
+ y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
490
+ z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y)
491
+ }
492
+ )P";
493
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
494
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
495
+ EXPECT_EQ(DebugString(result.nodes), e2);
496
+ }
497
+
498
+ static void HasError(const Status& s, const string& substr) {
499
+ EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
500
+ << ">>" << s << "<<, expected substring >>" << substr << "<<";
501
+ }
502
+
503
+ TEST(InstantiateErrors, Not_Sufficient_Attrs) {
504
+ auto fdef =
505
+ FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
506
+ InstantiationResult result;
507
+ HasError(
508
+ InstantiateFunction(fdef, Attrs({{"U", DT_FLOAT}}), GetOpSig, &result),
509
+ "Attr T is not found from ");
510
+ }
511
+
512
+ #if 0 // TODO(josh11b): Enable this test once having an extra attr is an error.
513
+ TEST(InstantiateErrors, Too_Many_Attrs) {
514
+ auto fdef =
515
+ FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
516
+ InstantiationResult result;
517
+ HasError(InstantiateFunction(fdef, Attrs({{"T", DT_INT32}, {"U", DT_FLOAT}}),
518
+ GetOpSig, &result),
519
+ "Attr U is not found in ");
520
+ }
521
+ #endif
522
+
523
+ TEST(InstantiateErrors, AttrValue_Value_Placeholder) {
524
+ auto fdef =
525
+ FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
526
+ InstantiationResult result;
527
+ HasError(
528
+ InstantiateFunction(fdef, Attrs({{"T", "$bad"}}), GetOpSig, &result),
529
+ "AttrValue had value with unexpected type 'placeholder'\n\tfor attr 'T'");
530
+ }
531
+
532
+ TEST(InstantiateErrors, Unbounded_Attr) {
533
+ auto fdef = FDH::Define("test", {}, {}, {"T:{float, double, int32, int64}"},
534
+ {
535
+ {{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}},
536
+ });
537
+ InstantiationResult result;
538
+ HasError(
539
+ InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result),
540
+ "Failed to bind all placeholders");
541
+ }
542
+
543
+ TEST(InstantiateErrors, DupArgs) {
544
+ auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {});
545
+ InstantiationResult result;
546
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
547
+ "Duplicated arg name");
548
+ }
549
+
550
+ TEST(InstantiateErrors, Dup_Node_Names) {
551
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
552
+ {
553
+ {{"y"}, "One", {}, {{"T", DT_FLOAT}}},
554
+ {{"y"}, "One", {}, {{"T", DT_FLOAT}}},
555
+ });
556
+ InstantiationResult result;
557
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
558
+ "Duplicated ret name");
559
+ }
560
+
561
+ TEST(InstantiateErrors, Node_Arg_Notfound) {
562
+ auto fdef = FDH::Create("test", {"x:float"}, {}, {},
563
+ {
564
+ {{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}},
565
+ },
566
+ {});
567
+ InstantiationResult result;
568
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
569
+ "input z is not found");
570
+ }
571
+
572
+ TEST(InstantiateErrors, Node_Arg_TypeMismatch) {
573
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
574
+ {
575
+ {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
576
+ });
577
+ InstantiationResult result;
578
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
579
+ "input x[0] expected type int32 != float, the type of x[0]");
580
+ }
581
+
582
+ TEST(InstantiateErrors, Node_Arg_ControlMissing) {
583
+ auto fdef =
584
+ FDH::Define("test", {"x:float"}, {}, {},
585
+ {
586
+ {{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}},
587
+ });
588
+ InstantiationResult result;
589
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
590
+ "input[2] == '^z', is not found.");
591
+ }
592
+
593
+ TEST(InstantiateErrors, FuncRet_Missing) {
594
+ auto fdef = FDH::Create("test", {}, {"y: float"}, {},
595
+ {
596
+ {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
597
+ },
598
+ {});
599
+ InstantiationResult result;
600
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
601
+ "Return y missing");
602
+ }
603
+
604
+ TEST(InstantiateErrors, FuncRet_NotFound) {
605
+ auto fdef = FDH::Create("test", {}, {"y: float"}, {},
606
+ {
607
+ {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
608
+ },
609
+ {{"y", "z"}});
610
+ InstantiationResult result;
611
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
612
+ "Return y -> z is not found");
613
+ }
614
+
615
+ TEST(InstantiateErrors, FuncRet_NameMismatch) {
616
+ auto fdef = FDH::Create("test", {}, {"y: float"}, {},
617
+ {
618
+ {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
619
+ },
620
+ {{"z", "x:y:0"}});
621
+ InstantiationResult result;
622
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
623
+ "Return y missing");
624
+ }
625
+
626
+ // TODO(josh11b): Make this an error.
627
+ // TEST(InstantiateErrors, FuncRet_Extra) {
628
+ // auto fdef = FDH::Create("test", {}, {"y: float"}, {},
629
+ // {
630
+ // {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
631
+ // },
632
+ // {{"y", "x:y:0"}, {"z", "x:y:0"}});
633
+ // InstantiationResult result;
634
+ // HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
635
+ // "ret is not found");
636
+ // }
637
+
638
+ TEST(InstantiateErrors, FuncRet_TypeMismatch) {
639
+ auto fdef = FDH::Define("test", {}, {"y: float"}, {},
640
+ {
641
+ {{"y"}, "One", {}, {{"T", DT_DOUBLE}}},
642
+ });
643
+ InstantiationResult result;
644
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
645
+ "Invalid ret types y : float vs. double\n\tIn function output y");
646
+ }
647
+
648
+ TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
649
+ auto fdef = FDH::Create(
650
+ // Name
651
+ "MySelect",
652
+ // Args
653
+ {"x: float"},
654
+ // Return values
655
+ {"y: float"},
656
+ // Attrs
657
+ {},
658
+ // Nodes
659
+ {
660
+ {{"y"},
661
+ "Cond",
662
+ {"x", "x"},
663
+ {{"tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
664
+ {"cond", FDH::FunctionRef("MyCond2")},
665
+ {"then_branch", FDH::FunctionRef("MyThen2")},
666
+ {"else_branch", FDH::FunctionRef("MyElse2")}}},
667
+ },
668
+ {{"y", "y:output"}});
669
+ InstantiationResult result;
670
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
671
+ "type attr not found: out_types");
672
+ }
673
+
674
+ TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
675
+ auto fdef = FDH::Create(
676
+ // Name
677
+ "MySelect",
678
+ // Args
679
+ {"x: float"},
680
+ // Return values
681
+ {"y: float"},
682
+ // Attrs
683
+ {},
684
+ // Nodes
685
+ {
686
+ {{"y"},
687
+ "Cond",
688
+ {"x", "x"},
689
+ {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
690
+ {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
691
+ {"cond", FDH::FunctionRef("MyCond2")},
692
+ {"then_branch", FDH::FunctionRef("MyThen2")},
693
+ {"else_branch", FDH::FunctionRef("MyElse2")}}},
694
+ },
695
+ {{"y", "y:output"}});
696
+ InstantiationResult result;
697
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
698
+ "Invalid ret types");
699
+ }
700
+
701
+ TEST(InstantiateErrors, TypeList_Missing_Arg) {
702
+ auto fdef = FDH::Create(
703
+ // Name
704
+ "MySelect",
705
+ // Args
706
+ {"x: float"},
707
+ // Return values
708
+ {"y: float"},
709
+ // Attrs
710
+ {},
711
+ // Nodes
712
+ {
713
+ {{"y"},
714
+ "Cond",
715
+ {"x", "unknown"},
716
+ {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
717
+ {"out_types", DataTypeSlice{DT_FLOAT}},
718
+ {"cond", FDH::FunctionRef("MyCond2")},
719
+ {"then_branch", FDH::FunctionRef("MyThen2")},
720
+ {"else_branch", FDH::FunctionRef("MyElse2")}}},
721
+ },
722
+ {{"y", "y:output"}});
723
+ InstantiationResult result;
724
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
725
+ "input unknown is not found");
726
+ }
727
+
728
+ TEST(InstantiateErrors, TooManyInputs) {
729
+ auto fdef = FDH::Create(
730
+ // Name
731
+ "TooManyInputs",
732
+ // Inputs
733
+ {"x: float", "y: float"},
734
+ // Outputs
735
+ {"z: float"},
736
+ // Attrs
737
+ {},
738
+ // Nodes
739
+ {// a = AddN<N=2>(x, y, x)
740
+ {{"a"}, "AddN", {"x", "y", "x"}, {{"T", DT_FLOAT}, {"N", 2}}}},
741
+ // Returns
742
+ {{"z", "a:sum:0"}});
743
+
744
+ InstantiationResult result;
745
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
746
+ "Expected input[2] == 'x' to be a control input.");
747
+ }
748
+
749
+ TEST(InstantiateErrors, TooFewInputs) {
750
+ auto fdef = FDH::Create(
751
+ // Name
752
+ "TooFewInputs",
753
+ // Inputs
754
+ {"x: float", "y: float"},
755
+ // Outputs
756
+ {"z: float"},
757
+ // Attrs
758
+ {},
759
+ // Nodes
760
+ {// a = AddN<N=3>(x, y)
761
+ {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 3}}}},
762
+ // Returns
763
+ {{"z", "a:sum:0"}});
764
+
765
+ InstantiationResult result;
766
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
767
+ "Attempt to access beyond input size: 2 >= 2");
768
+ }
769
+
770
+ TEST(InstantiateErrors, TooManyInputsFromArray1) {
771
+ auto fdef = FDH::Create(
772
+ // Name
773
+ "TooManyInputsFromArray",
774
+ // Inputs
775
+ {"x: float", "y: float"},
776
+ // Outputs
777
+ {"z: float"},
778
+ // Attrs
779
+ {},
780
+ // Nodes
781
+ {// a = _ListToArray(x,y)
782
+ {{"a"},
783
+ "_ListToArray",
784
+ {"x", "y"},
785
+ {{"N", 2},
786
+ {"T", DT_FLOAT},
787
+ {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
788
+ // b = AddN<N=2>(a, y)
789
+ {{"b"}, "AddN", {"a:output", "y"}, {{"T", DT_FLOAT}, {"N", 2}}}},
790
+ // Returns
791
+ {{"z", "a:sum:0"}});
792
+
793
+ InstantiationResult result;
794
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
795
+ "Expected input[1] == 'y' to be a control input.");
796
+ }
797
+
798
+ TEST(InstantiateErrors, TooManyInputsFromArray2) {
799
+ auto fdef = FDH::Create(
800
+ // Name
801
+ "TooManyInputsFromArray",
802
+ // Inputs
803
+ {"x: float", "y: float"},
804
+ // Outputs
805
+ {"z: float"},
806
+ // Attrs
807
+ {},
808
+ // Nodes
809
+ {// a = _ListToArray(x,y)
810
+ {{"a"},
811
+ "_ListToArray",
812
+ {"x", "y"},
813
+ {{"N", 2},
814
+ {"T", DT_FLOAT},
815
+ {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
816
+ // b = AddN<N=2>(x, a)
817
+ {{"b"}, "AddN", {"x", "a:output"}, {{"T", DT_FLOAT}, {"N", 2}}}},
818
+ // Returns
819
+ {{"z", "a:sum:0"}});
820
+
821
+ InstantiationResult result;
822
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
823
+ "Input a:output too long for inputs");
824
+ }
825
+
826
+ TEST(InstantiateErrors, TypeMismatch) {
827
+ auto fdef = FDH::Create(
828
+ // Name
829
+ "TypeMismatch",
830
+ // Inputs
831
+ {"x: float", "y: int32"},
832
+ // Outputs
833
+ {"z: float"},
834
+ // Attrs
835
+ {},
836
+ // Nodes
837
+ {// a = AddN<N=3>(x, y)
838
+ {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 3}}}},
839
+ // Returns
840
+ {{"z", "a:sum:0"}});
841
+
842
+ InstantiationResult result;
843
+ HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
844
+ "input inputs[1] expected type float != int32, the type of y[0]");
845
+ }
846
+
847
+ TEST(FunctionCallFrame, Void_Void) {
848
+ FunctionCallFrame frame({}, {});
849
+ TF_EXPECT_OK(frame.SetArgs({}));
850
+ auto a = test::AsTensor<float>({100});
851
+ HasError(frame.SetArgs({a}), "Invalid argument");
852
+ Tensor v;
853
+ HasError(frame.GetArg(0, &v), "Invalid argument");
854
+ HasError(frame.SetRetval(0, v), "Invalid argument");
855
+ std::vector<Tensor> rets;
856
+ TF_EXPECT_OK(frame.GetRetvals(&rets));
857
+ EXPECT_EQ(rets.size(), 0);
858
+ }
859
+
860
+ TEST(FunctionCallFrame, Float_Float_Float) {
861
+ FunctionCallFrame frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
862
+ HasError(frame.SetArgs({}), "Invalid argument: Expects 2 arguments");
863
+ auto a = test::AsTensor<float>({100});
864
+ auto b = test::AsTensor<float>({200});
865
+ auto c = test::AsTensor<int64>({300});
866
+ HasError(frame.SetArgs({a, c}),
867
+ "Invalid argument: Expects arg[1] to be float");
868
+ TF_EXPECT_OK(frame.SetArgs({a, b}));
869
+
870
+ Tensor v;
871
+ HasError(frame.GetArg(-1, &v), "Invalid argument");
872
+ HasError(frame.GetArg(2, &v), "Invalid argument");
873
+ TF_EXPECT_OK(frame.GetArg(0, &v));
874
+ test::ExpectTensorEqual<float>(a, v);
875
+ TF_EXPECT_OK(frame.GetArg(1, &v));
876
+ test::ExpectTensorEqual<float>(b, v);
877
+
878
+ v = test::AsTensor<float>({-100});
879
+ HasError(frame.SetRetval(-1, v), "Invalid argument");
880
+ HasError(frame.SetRetval(1, v), "Invalid argument");
881
+ HasError(frame.SetRetval(0, test::AsTensor<int64>({-100})),
882
+ "Invalid argument: Expects ret[0] to be float");
883
+
884
+ std::vector<Tensor> rets;
885
+ HasError(frame.GetRetvals(&rets), "does not have value");
886
+ TF_EXPECT_OK(frame.SetRetval(0, v));
887
+ HasError(frame.SetRetval(0, v), "has already been set");
888
+
889
+ TF_EXPECT_OK(frame.GetRetvals(&rets));
890
+ EXPECT_EQ(rets.size(), 1);
891
+ test::ExpectTensorEqual<float>(rets[0], v);
892
+ }
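+ // Taken together, the two tests above pin down the call-frame protocol:
+ // args are set once up front, each retval must be set exactly once, and
+ // GetRetvals() fails until every retval has been assigned.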
893
+
894
+ TEST(Canonicalize, Basic) {
895
+ EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT},
896
+ {"transpose_a", false},
897
+ {"transpose_b", false}})),
898
+ "MatMul[T=float,transpose_a=false,transpose_b=false]");
899
+ EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT},
900
+ {"transpose_b", false},
901
+ {"transpose_a", false}})),
902
+ "MatMul[T=float,transpose_a=false,transpose_b=false]");
903
+ EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_DOUBLE},
904
+ {"transpose_b", true},
905
+ {"transpose_a", false}})),
906
+ "MatMul[T=double,transpose_a=false,transpose_b=true]");
907
+ }
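+ // The first two expectations only pass because Canonicalize() emits attrs
+ // sorted by name, so attr insertion order cannot change the canonical key.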
908
+
909
+ TEST(FunctionLibraryDefinitionTest, Find) {
910
+ FunctionDefLibrary proto;
911
+ *proto.add_function() = test::function::XTimesTwo();
912
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
913
+
914
+ EXPECT_EQ(lib_def.Find("XTimes16"), nullptr);
915
+
916
+ auto expect = R"P(
917
+ XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
918
+ two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
919
+ scale = Cast[DstT=$T, SrcT=int64](two:output:0)
920
+ y = Mul[T=$T](x, scale:y:0)
921
+ return y = y:z:0
922
+ }
923
+ )P";
924
+ auto found = lib_def.Find("XTimesTwo");
925
+ ASSERT_NE(found, nullptr);
926
+ EXPECT_EQ(expect, DebugString(*found));
927
+ }
928
+
929
+ TEST(FunctionLibraryDefinitionTest, LookUp) {
930
+ FunctionDefLibrary proto;
931
+ *proto.add_function() = test::function::XTimesTwo();
932
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
933
+
934
+ const OpDef* op_def;
935
+ EXPECT_TRUE(!lib_def.LookUpOpDef("XTimes16", &op_def).ok());
936
+
937
+ TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &op_def));
938
+ ASSERT_NE(op_def, nullptr);
939
+ EXPECT_EQ(op_def->DebugString(),
940
+ test::function::XTimesTwo().signature().DebugString());
941
+
942
+ const OpRegistrationData* op_reg_data;
943
+ TF_EXPECT_OK(lib_def.LookUp("XTimesTwo", &op_reg_data));
944
+ ASSERT_NE(op_reg_data, nullptr);
945
+ // Shape inference function is initialized to UnknownShape.
946
+ ASSERT_NE(op_reg_data->shape_inference_fn, nullptr);
947
+ }
948
+
949
+ TEST(FunctionLibraryDefinitionTest, AddFunctionDef) {
950
+ // Add one function to the proto lib before constructing 'lib_def'.
951
+ FunctionDefLibrary proto;
952
+ *proto.add_function() = test::function::XTimesTwo();
953
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
954
+
955
+ // Add a new function def to the library.
956
+ TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB()));
957
+
958
+ // Test lookup of first function.
959
+ const OpDef* first;
960
+ TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &first));
961
+ ASSERT_NE(first, nullptr);
962
+ EXPECT_EQ(first->DebugString(),
963
+ test::function::XTimesTwo().signature().DebugString());
964
+
965
+ // Test lookup of second function.
966
+ const OpDef* second;
967
+ TF_EXPECT_OK(lib_def.LookUpOpDef("WXPlusB", &second));
968
+ ASSERT_NE(second, nullptr);
969
+ EXPECT_EQ(second->DebugString(),
970
+ test::function::WXPlusB().signature().DebugString());
971
+
972
+ // Can't add function with same name as existing op
973
+ FunctionDef fdef = test::function::XTimesTwo();
974
+ fdef.mutable_signature()->set_name("Add");
975
+ Status s = lib_def.AddFunctionDef(fdef);
976
+ EXPECT_FALSE(s.ok());
977
+ EXPECT_EQ(s.error_message(),
978
+ "Cannot add function 'Add' because an op with the same name "
979
+ "already exists.");
980
+
981
+ // Already-added functions don't produce error
982
+ TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::XTimesTwo()));
983
+ TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB()));
984
+ }
985
+
986
+ TEST(FunctionLibraryDefinitionTest, AddGradientDef) {
987
+ // AddGradientDef() doesn't check that functions referenced exist (yet?)
988
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary());
989
+
990
+ // Test adding a gradient (XTimesFour isn't a valid grad function for
991
+ // XTimesTwo but that's ok for now)
992
+ GradientDef grad;
993
+ grad.set_function_name(test::function::XTimesTwo().signature().name());
994
+ grad.set_gradient_func(test::function::XTimesFour().signature().name());
995
+ TF_EXPECT_OK(lib_def.AddGradientDef(grad));
996
+
997
+ // Already-added gradients don't produce error
998
+ TF_EXPECT_OK(lib_def.AddGradientDef(grad));
999
+
1000
+ // Test that adding a duplicate gradient fails
1001
+ grad.set_gradient_func(test::function::XTimes16().signature().name());
1002
+ Status s = lib_def.AddGradientDef(grad);
1003
+ EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
1004
+ EXPECT_EQ(s.error_message(),
1005
+ "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because "
1006
+ "it already has gradient function 'XTimesFour'");
1007
+ }
1008
+
1009
+ TEST(FunctionLibraryDefinitionTest, AddLibrary) {
1010
+ // Create lib def with single function
1011
+ FunctionDefLibrary proto;
1012
+ *proto.add_function() = test::function::XTimesTwo();
1013
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
1014
+
1015
+ // Add gradient
1016
+ GradientDef grad;
1017
+ grad.set_function_name(test::function::XTimesTwo().signature().name());
1018
+ grad.set_gradient_func(test::function::XTimesFour().signature().name());
1019
+ TF_EXPECT_OK(lib_def.AddGradientDef(grad));
1020
+
1021
+ // Error if you try to add conflicting function
1022
+ proto.Clear();
1023
+ FunctionDef fdef = test::function::XTimesFour();
1024
+ fdef.mutable_signature()->set_name(
1025
+ test::function::XTimesTwo().signature().name());
1026
+ *proto.add_function() = fdef;
1027
+ FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto);
1028
+ Status s = lib_def.AddLibrary(lib_def2);
1029
+ EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
1030
+ EXPECT_EQ(s.error_message(),
1031
+ "Cannot add function 'XTimesTwo' because a different function with "
1032
+ "the same name already exists.");
1033
+
1034
+ // Error if you try to add conflicting gradient
1035
+ proto.Clear();
1036
+ grad.set_gradient_func(test::function::XTimes16().signature().name());
1037
+ *proto.add_gradient() = grad;
1038
+ FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto);
1039
+ s = lib_def.AddLibrary(lib_def3);
1040
+ EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
1041
+ EXPECT_EQ(s.error_message(),
1042
+ "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because "
1043
+ "it already has gradient function 'XTimesFour'");
1044
+
1045
+ // No conflicting functions or gradients OK
1046
+ proto.Clear();
1047
+ *proto.add_function() = test::function::XTimesFour();
1048
+ grad.set_function_name(test::function::XTimes16().signature().name());
1049
+ *proto.add_gradient() = grad;
1050
+ FunctionLibraryDefinition lib_def4(OpRegistry::Global(), proto);
1051
+ TF_EXPECT_OK(lib_def.AddLibrary(lib_def4));
1052
+
1053
+ // OK to add the same functions and gradients twice
1054
+ TF_EXPECT_OK(lib_def.AddLibrary(lib_def));
1055
+ }
1056
+
1057
+ GradientDef MakeGradDef(const string& f, const string& g) {
1058
+ GradientDef grad;
1059
+ grad.set_function_name(f);
1060
+ grad.set_gradient_func(g);
1061
+ return grad;
1062
+ }
1063
+
1064
+ TEST(FunctionLibraryDefinitionTest, AddLibrary_Atomic) {
1065
+ // Create lib def containing two functions with equal names
1066
+ FunctionDefLibrary proto;
1067
+ const string x2_name = test::function::XTimesTwo().signature().name();
1068
+ const string x4_name = test::function::XTimesFour().signature().name();
1069
+ *proto.add_function() = test::function::XTimesTwo();
1070
+ FunctionDef fdef = test::function::XTimesFour();
1071
+ fdef.mutable_signature()->set_name(x2_name);
1072
+ *proto.add_function() = fdef;
1073
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary());
1074
+
1075
+ // Try adding the two functions to lib_def
1076
+ Status s = lib_def.AddLibrary(proto);
1077
+ EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code());
1078
+ EXPECT_EQ(
1079
+ "Cannot add function 'XTimesTwo' because a different function with "
1080
+ "the same name already exists.",
1081
+ s.error_message());
1082
+
1083
+ // Verify that none of the functions are added
1084
+ EXPECT_TRUE(lib_def.Find(x2_name) == nullptr);
1085
+
1086
+ // Fix the name in proto but add two gradient names for it
1087
+ proto.mutable_function(1)->mutable_signature()->set_name(x4_name);
1088
+ *proto.add_gradient() = MakeGradDef(x2_name, x4_name);
1089
+ *proto.add_gradient() = MakeGradDef(x2_name, "SecondGradName");
1090
+
1091
+ // Try adding the library and check that nothing was added
1092
+ s = lib_def.AddLibrary(proto);
1093
+ EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code());
1094
+ EXPECT_EQ(s.error_message(),
1095
+ "Cannot assign gradient function 'SecondGradName' to 'XTimesTwo' "
1096
+ "because it already has gradient function 'XTimesFour'");
1097
+ EXPECT_TRUE(lib_def.Find(x2_name) == nullptr);
1098
+ EXPECT_EQ(0, lib_def.ToProto().function_size());
1099
+ EXPECT_EQ(0, lib_def.ToProto().gradient_size());
1100
+ }
1101
+
1102
+ TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_FuncConflict) {
1103
+ const string x2_name = test::function::XTimesTwo().signature().name();
1104
+ const string x4_name = test::function::XTimesFour().signature().name();
1105
+ const string wx_name = test::function::WXPlusB().signature().name();
1106
+
1107
+ // Create FunctionLibraryDefinition with
1108
+ // (func = XTimesTwo, grad = XTimesFour)
1109
+ FunctionDefLibrary proto;
1110
+ *proto.add_function() = test::function::XTimesTwo();
1111
+ *proto.add_gradient() = MakeGradDef(x2_name, x4_name);
1112
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
1113
+ EXPECT_EQ(1, lib_def.ToProto().function_size());
1114
+ EXPECT_EQ(1, lib_def.ToProto().gradient_size());
1115
+
1116
+ // Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo)
1117
+ // and function (name = XTimesTwo, body = XTimeFour)
1118
+ FunctionDefLibrary proto2;
1119
+ *proto2.add_function() = test::function::WXPlusB();
1120
+ *proto2.add_gradient() = MakeGradDef(wx_name, x2_name);
1121
+ *proto2.add_function() = test::function::XTimesFour();
1122
+ proto2.mutable_function(1)->mutable_signature()->set_name(x2_name);
1123
+ FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2);
1124
+
1125
+ // Verify that adding lib_def2 will fail because of function conflict
1126
+ // and WXPlusB is not added.
1127
+ Status s = lib_def.AddLibrary(lib_def2);
1128
+ EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code());
1129
+ EXPECT_EQ(
1130
+ "Cannot add function 'XTimesTwo' because a different function "
1131
+ "with the same name already exists.",
1132
+ s.error_message());
1133
+ EXPECT_TRUE(lib_def.Find(wx_name) == nullptr);
1134
+ EXPECT_EQ(1, lib_def.ToProto().function_size());
1135
+ EXPECT_EQ(1, lib_def.ToProto().gradient_size());
1136
+ }
1137
+
1138
+ TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_GradConflict) {
1139
+ const string x2_name = test::function::XTimesTwo().signature().name();
1140
+ const string x4_name = test::function::XTimesFour().signature().name();
1141
+ const string wx_name = test::function::WXPlusB().signature().name();
1142
+
1143
+ // Create FunctionLibraryDefinition with
1144
+ // (func = XTimesTwo, grad = XTimesFour)
1145
+ FunctionDefLibrary proto;
1146
+ *proto.add_function() = test::function::XTimesTwo();
1147
+ *proto.add_gradient() = MakeGradDef(x2_name, x4_name);
1148
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
1149
+ EXPECT_EQ(1, lib_def.ToProto().function_size());
1150
+ EXPECT_EQ(1, lib_def.ToProto().gradient_size());
1151
+
1152
+ // Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo)
1153
+ // and (func = XTimesTwo, grad = WXPlusB)
1154
+ FunctionDefLibrary proto2;
1155
+ *proto2.add_function() = test::function::WXPlusB();
1156
+ *proto2.add_gradient() = MakeGradDef(wx_name, x2_name);
1157
+ *proto2.add_function() = test::function::XTimesTwo();
1158
+ *proto2.add_gradient() = MakeGradDef(x2_name, wx_name);
1159
+ FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2);
1160
+
1161
+ // Verify that adding lib_def2 will fail because of gradient conflict
1162
+ // and WXPlusB is not added.
1163
+ Status s = lib_def.AddLibrary(lib_def2);
1164
+ EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code());
1165
+ EXPECT_EQ(
1166
+ "Cannot assign gradient function 'WXPlusB' to 'XTimesTwo'"
1167
+ " because it already has gradient function 'XTimesFour'",
1168
+ s.error_message());
1169
+ EXPECT_TRUE(lib_def.Find(wx_name) == nullptr);
1170
+ EXPECT_EQ(1, lib_def.ToProto().function_size());
1171
+ EXPECT_EQ(1, lib_def.ToProto().gradient_size());
1172
+ }
1173
+
1174
+ TEST(FunctionLibraryDefinitionTest, ToProto) {
1175
+ FunctionDefLibrary proto1;
1176
+ *proto1.add_function() = test::function::XTimesTwo();
1177
+ *proto1.add_function() = test::function::WXPlusB();
1178
+ FunctionLibraryDefinition lib_def1(OpRegistry::Global(), proto1);
1179
+
1180
+ // Call 'ToProto' and make sure both protos have the same function lib size.
1181
+ FunctionDefLibrary proto2 = lib_def1.ToProto();
1182
+ EXPECT_EQ(proto1.function_size(), proto2.function_size());
1183
+
1184
+ // Initialize 'lib_def2' with proto returned by 'ToProto' call.
1185
+ FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2);
1186
+
1187
+ // Test that the first function exists in both libraries.
1188
+ const OpDef *f1, *f2, *f3, *f4;
1189
+ TF_EXPECT_OK(lib_def1.LookUpOpDef("XTimesTwo", &f1));
1190
+ TF_EXPECT_OK(lib_def2.LookUpOpDef("XTimesTwo", &f2));
1191
+ EXPECT_EQ(f1->DebugString(), f2->DebugString());
1192
+
1193
+ // Test that the second function exists in both libraries.
1194
+ TF_EXPECT_OK(lib_def1.LookUpOpDef("WXPlusB", &f3));
1195
+ TF_EXPECT_OK(lib_def2.LookUpOpDef("WXPlusB", &f4));
1196
+ EXPECT_EQ(f3->DebugString(), f4->DebugString());
1197
+ }
1198
+
1199
+ TEST(FunctionLibraryDefinitionTest, GetAttr_FuncNoAttr) {
1200
+ FunctionDefLibrary proto;
1201
+ *proto.add_function() = test::function::XTimesTwo();
1202
+ FunctionLibraryDefinition lib(OpRegistry::Global(), proto);
1203
+
1204
+ NodeDef ndef;
1205
+ bool annotation;
1206
+
1207
+ // Not a function.
1208
+ ndef.set_op("Matmul");
1209
+ EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
1210
+
1211
+ // A function. No attr defined.
1212
+ ndef.set_op("XTimesTwo");
1213
+ EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
1214
+
1215
+ // ndef defines the attr. But we don't care.
1216
+ AddNodeAttr("annotation", true, &ndef);
1217
+ EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
1218
+ }
1219
+
1220
+ template <typename T>
1221
+ void SetAttrValue(FunctionDef* fdef, const string& attr, const T& value) {
1222
+ AttrValue attr_value;
1223
+ SetAttrValue(value, &attr_value);
1224
+ fdef->mutable_attr()->insert({attr, attr_value});
1225
+ }
1226
+
1227
+ TEST(FunctionLibraryDefinitionTest, GetAttr_FuncWithAttr) {
1228
+ FunctionDefLibrary proto;
1229
+ auto fdef = proto.add_function();
1230
+ *fdef = test::function::XTimesTwo();
1231
+ SetAttrValue(fdef, "annotation", true);
1232
+ SetAttrValue(fdef, "options", "some string data");
1233
+ FunctionLibraryDefinition lib(OpRegistry::Global(), proto);
1234
+
1235
+ NodeDef ndef;
1236
+ bool annotation;
1237
+
1238
+ // A function. No attr defined in ndef.
1239
+ ndef.set_op("XTimesTwo");
1240
+ TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation));
1241
+ EXPECT_EQ(annotation, true);
1242
+
1243
+ string str;
1244
+ TF_EXPECT_OK(lib.GetAttr(ndef, "options", &str));
1245
+ EXPECT_EQ(str, "some string data");
1246
+ }
1247
+
1248
+ TEST(FunctionLibraryDefinitionTest, GetAttr_Gradient) {
1249
+ FunctionDefLibrary proto;
1250
+ auto fdef = proto.add_function();
1251
+ *fdef = test::function::XTimesTwo();
1252
+ SetAttrValue(fdef, "annotation", true);
1253
+ fdef = proto.add_function();
+ *fdef = test::function::WXPlusB();
1254
+ SetAttrValue(fdef, "annotation", false);
1255
+ auto func_grad = proto.add_gradient();
1256
+ func_grad->set_function_name("XTimesTwo");
1257
+ func_grad->set_gradient_func("WXPlusB");
1258
+ FunctionLibraryDefinition lib(OpRegistry::Global(), proto);
1259
+
1260
+ NodeDef ndef;
1261
+ ndef.set_op(FunctionLibraryDefinition::kGradientOp);
1262
+
1263
+ bool annotation;
1264
+ EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
1265
+
1266
+ NameAttrList nal;
1267
+ nal.set_name("XTimesTwo");
1268
+ AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef);
1269
+ TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation));
1270
+ EXPECT_EQ(annotation, false); // XTimesTwo's gradient is WXPlusB.
1271
+
1272
+ nal.set_name("WXPlusB");
1273
+ ndef.clear_attr();
1274
+ AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef);
1275
+ TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation));
1276
+ EXPECT_EQ(annotation, false); // WXPlusB has no custom gradient.
1277
+ }
1278
+
1279
+ // TODO(skyewm): this could be more thorough
1280
+ TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) {
1281
+ // Equal functions
1282
+ const FunctionDef fdef1 = test::function::XTimesTwo();
1283
+ FunctionDef fdef2 = test::function::XTimesTwo();
1284
+ uint64 hash1 = FunctionDefHash(fdef1);
1285
+ EXPECT_TRUE(FunctionDefsEqual(fdef1, fdef2));
1286
+ EXPECT_EQ(hash1, FunctionDefHash(fdef2));
1287
+
1288
+ // Different functions
1289
+ fdef2 = test::function::XTimesFour();
1290
+ EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
1291
+ EXPECT_NE(hash1, FunctionDefHash(fdef2));
1292
+
1293
+ // Different signatures
1294
+ fdef2 = test::function::XTimesTwo();
1295
+ fdef2.mutable_signature()->mutable_input_arg(0)->set_name("foo");
1296
+ EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
1297
+ EXPECT_NE(hash1, FunctionDefHash(fdef2));
1298
+
1299
+ // Descriptions must be equal
1300
+ fdef2 = test::function::XTimesTwo();
1301
+ fdef2.mutable_signature()->mutable_input_arg(0)->set_description("foo");
1302
+ EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
1303
+ EXPECT_NE(hash1, FunctionDefHash(fdef2));
1304
+
1305
+ // Different NodeDefs
1306
+ fdef2 = test::function::XTimesTwo();
1307
+ NodeDef* ndef = fdef2.add_node_def();
1308
+ *ndef = fdef2.node_def(0);
1309
+ ndef->set_name("new_name");
1310
+ EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
1311
+ EXPECT_NE(hash1, FunctionDefHash(fdef2));
1312
+
1313
+ // Different return values
1314
+ fdef2 = test::function::XTimesTwo();
1315
+ (*fdef2.mutable_ret())["y"] = "y:z:1"; // originally is "y:z:0"
1316
+ EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
1317
+ EXPECT_NE(hash1, FunctionDefHash(fdef2));
1318
+
1319
+ // Different attributes
1320
+ fdef2 = test::function::XTimesTwo();
1321
+ SetAttrValue(&fdef2, "ExtraAttr", true);
1322
+ EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
1323
+ EXPECT_NE(hash1, FunctionDefHash(fdef2));
1324
+
1325
+ // Multiple equivalent attributes; the two functions should be equal.
1326
+ fdef2 = test::function::XTimesTwo();
1327
+ FunctionDef fdef3 = test::function::XTimesTwo();
1328
+ SetAttrValue(&fdef2, "Foo", true);
1329
+ SetAttrValue(&fdef3, "Foo", true);
1330
+ SetAttrValue(&fdef2, "Bar", 123);
1331
+ SetAttrValue(&fdef3, "Bar", 123);
1332
+ SetAttrValue(&fdef2, "Baz", "abc");
1333
+ SetAttrValue(&fdef3, "Baz", "abc");
1334
+ EXPECT_TRUE(FunctionDefsEqual(fdef2, fdef3));
1335
+ EXPECT_EQ(FunctionDefHash(fdef2), FunctionDefHash(fdef3));
1336
+ }
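+ // Every inequality case above also checks FunctionDefHash(), so the hash
+ // must cover signatures, node defs, returns, and attrs; the final case
+ // additionally shows it must be insensitive to attr insertion order.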
1337
+
1338
+ } // end namespace
1339
+ } // end namespace tensorflow
function_testlib.cc ADDED
@@ -0,0 +1,204 @@
1
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #include "tensorflow/core/framework/function_testlib.h"
17
+
18
+ #include "tensorflow/core/framework/function.h"
19
+ #include "tensorflow/core/framework/node_def.pb.h"
20
+ #include "tensorflow/core/framework/tensor_testutil.h"
21
+ #include "tensorflow/core/framework/versions.pb.h"
22
+ #include "tensorflow/core/lib/core/threadpool.h"
23
+ #include "tensorflow/core/public/version.h"
24
+
25
+ namespace tensorflow {
26
+ namespace test {
27
+ namespace function {
28
+
29
+ typedef FunctionDefHelper FDH;
30
+
31
+ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
32
+ gtl::ArraySlice<FunctionDef> funcs) {
33
+ GraphDef g;
34
+ VersionDef* versions = g.mutable_versions();
35
+ versions->set_producer(TF_GRAPH_DEF_VERSION);
36
+ versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
37
+ for (const auto& n : nodes) {
38
+ *(g.add_node()) = n;
39
+ }
40
+ auto lib = g.mutable_library();
41
+ for (const auto& f : funcs) {
42
+ *(lib->add_function()) = f;
43
+ }
44
+ return g;
45
+ }
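+ // A sketch of typical usage (node names here are hypothetical): wire a
+ // Placeholder into the XTimesTwo function defined below and attach that
+ // function to the graph's library:
+ //
+ //   GraphDef g = GDef({NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),
+ //                      NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}})},
+ //                     {XTimesTwo()});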
46
+
47
+ // Helper to construct a NodeDef.
48
+ NodeDef NDef(const string& name, const string& op,
49
+ gtl::ArraySlice<string> inputs,
50
+ gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs,
51
+ const string& device) {
52
+ NodeDef n;
53
+ n.set_name(name);
54
+ n.set_op(op);
55
+ for (const auto& in : inputs) n.add_input(in);
56
+ n.set_device(device);
57
+ for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
58
+ return n;
59
+ }
60
+
61
+ FunctionDef NonZero() {
62
+ return FDH::Define(
63
+ // Name
64
+ "NonZero",
65
+ // Args
66
+ {"x:T"},
67
+ // Return values
68
+ {"y:T"},
69
+ // Attr def
70
+ {"T:{float, double, int32, int64, string}"},
71
+ // Nodes
72
+ {
73
+ {{"y"}, "Identity", {"x"}, {{"T", "$T"}}},
74
+ });
75
+ }
76
+
77
+ FunctionDef XTimesTwo() {
78
+ const Tensor kTwo = test::AsScalar<int64>(2);
79
+ return FDH::Define(
80
+ // Name
81
+ "XTimesTwo",
82
+ // Args
83
+ {"x: T"},
84
+ // Return values
85
+ {"y: T"},
86
+ // Attr def
87
+ {"T: {float, double, int32, int64}"},
88
+ // Nodes
89
+ {
90
+ {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
91
+ {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
92
+ {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
93
+ });
94
+ }
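+ // The constant above is stored as an int64 and Cast to "$T" at runtime, so
+ // one function body serves every type the attr T allows instead of needing
+ // a separately typed Const per instantiation.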
95
+
96
+ FunctionDef XTimesTwoInt32() {
97
+ const Tensor kTwo = test::AsScalar<int64>(2);
98
+ return FDH::Define(
99
+ // Name
100
+ "XTimesTwoInt32",
101
+ // Args
102
+ {"x: int32"},
103
+ // Return values
104
+ {"y: int32"}, {},
105
+ // Nodes
106
+ {
107
+ {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
108
+ {{"scale"},
109
+ "Cast",
110
+ {"two"},
111
+ {{"SrcT", DT_INT64}, {"DstT", DT_INT32}}},
112
+ {{"y"}, "Mul", {"x", "scale"}, {{"T", DT_INT32}}},
113
+ });
114
+ }
115
+
116
+ FunctionDef XTimesFour() {
117
+ return FDH::Create(
118
+ // Name
119
+ "XTimesFour",
120
+ // Args
121
+ {"x: T"},
122
+ // Return values
123
+ {"y: T"},
124
+ // Attr def
125
+ {"T: {float, double, int32, int64}"},
126
+ // Nodes
127
+ {
128
+ {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
129
+ {{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}},
130
+ },
131
+ {{"y", "y:y:0"}});
132
+ }
133
+
134
+ FunctionDef XTimes16() {
135
+ return FDH::Create(
136
+ // Name
137
+ "XTimes16",
138
+ // Args
139
+ {"x: T"},
140
+ // Return values
141
+ {"y: T"},
142
+ // Attr def
143
+ {"T: {float, double, int32, int64}"},
144
+ // Nodes
145
+ {
146
+ {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
147
+ {{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}},
148
+ },
149
+ {{"y", "y:y:0"}});
150
+ }
151
+
152
+ FunctionDef WXPlusB() {
+   return FDH::Define(
153
+ // Name
154
+ "WXPlusB",
155
+ // Args
156
+ {"w: T", "x: T", "b: T"},
157
+ // Return values
158
+ {"y: T"},
159
+ // Attr def
160
+ {"T: {float, double}"},
161
+ // Nodes
162
+ {
163
+ {{"mm"},
164
+ "MatMul",
165
+ {"w", "x"},
166
+ {
167
+ {"T", "$T"}, {"transpose_a", false}, {"transpose_b", false},
168
+ #ifdef INTEL_MKL
169
+ }},
170
+ #else
171
+ {"_kernel", "eigen"}}},
172
+ #endif
173
+ {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}
178
+ });
179
+ }
180
+
181
+ FunctionDef Swap() {
182
+ return FDH::Define(
183
+ // Name
184
+ "Swap",
185
+ // Args
186
+ {"i0: T", "i1: T"},
187
+ // Return values
188
+ {"o0: T", "o1: T"},
189
+ // Attr def
190
+ {"T: {float, double}"},
191
+ // Nodes
192
+ {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}},
193
+ {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}});
194
+ }
195
+
196
+ void FunctionTestSchedClosure(std::function<void()> fn) {
197
+ static thread::ThreadPool* w =
198
+ new thread::ThreadPool(Env::Default(), "Test", 8);
199
+ w->Schedule(std::move(fn));
200
+ }
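+ // The pool is created once and, presumably deliberately, never deleted, so
+ // closures scheduled late in the test binary's life still have a live pool
+ // to run on.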
201
+
202
+ } // end namespace function
203
+ } // end namespace test
204
+ } // end namespace tensorflow
function_testlib.h ADDED
@@ -0,0 +1,90 @@
1
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #ifndef TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
17
+ #define TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
18
+
19
+ #include <string>
20
+
21
+ #include "tensorflow/core/framework/attr_value_util.h"
22
+ #include "tensorflow/core/framework/function.h"
23
+ #include "tensorflow/core/framework/function.pb.h"
24
+ #include "tensorflow/core/framework/graph.pb.h"
25
+ #include "tensorflow/core/framework/node_def.pb.h"
26
+ #include "tensorflow/core/lib/gtl/array_slice.h"
27
+ #include "tensorflow/core/platform/types.h"
28
+
29
+ namespace tensorflow {
30
+ namespace test {
31
+ namespace function {
32
+
33
+ // A helper class to make AttrSlice from initializer lists
34
+ class Attrs {
35
+ public:
36
+ Attrs(const std::initializer_list< // NOLINT(runtime/explicit)
37
+ std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) {
38
+ for (const auto& aval : attrs) {
39
+ map_.insert({aval.first, aval.second.proto});
40
+ }
41
+ }
42
+
43
+ operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit)
44
+
45
+ private:
46
+ AttrValueMap map_;
47
+ };
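+ // Note: the AttrSlice produced by the conversion above points into map_, so
+ // an Attrs object must outlive any AttrSlice made from it. The inline
+ // Attrs({...}) calls in the tests are safe because the temporary lasts for
+ // the whole full-expression.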
48
+
49
+ // Helper to construct a NodeDef.
50
+ NodeDef NDef(
51
+ const string& name, const string& op, gtl::ArraySlice<string> inputs,
52
+ gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
53
+ attrs = {},
54
+ const string& device = "");
55
+
56
+ // Helper to construct a GraphDef proto.
57
+ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
58
+ gtl::ArraySlice<FunctionDef> funcs = {});
59
+
60
+ // For testing convenience, we provide a few simple functions that can
61
+ // be easily executed and tested.
62
+
63
+ // x:T -> x * 2.
64
+ FunctionDef XTimesTwo();
65
+
66
+ // x:T -> x * 2, where x is int32.
67
+ FunctionDef XTimesTwoInt32();
68
+
69
+ // x:T -> (x * 2) * 2.
70
+ FunctionDef XTimesFour();
71
+
72
+ // x:T -> ((x * 2) * 2) * 2.
73
+ FunctionDef XTimes16();
74
+
75
+ // w:T, x:T, b:T -> MatMul(w, x) + b
76
+ FunctionDef WXPlusB();
77
+
78
+ // x:T -> x:T, where T is a type that we automatically convert to a bool.
79
+ FunctionDef NonZero();
80
+
81
+ // x:T, y:T -> y:T, x:T
82
+ FunctionDef Swap();
83
+
84
+ void FunctionTestSchedClosure(std::function<void()> fn);
85
+
86
+ } // end namespace function
87
+ } // end namespace test
88
+ } // end namespace tensorflow
89
+
90
+ #endif // TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
graph.proto ADDED
@@ -0,0 +1,56 @@
1
+ syntax = "proto3";
2
+
3
+ package tensorflow;
4
+ option cc_enable_arenas = true;
5
+ option java_outer_classname = "GraphProtos";
6
+ option java_multiple_files = true;
7
+ option java_package = "org.tensorflow.framework";
8
+
9
+ import "tensorflow/core/framework/node_def.proto";
10
+ import "tensorflow/core/framework/function.proto";
11
+ import "tensorflow/core/framework/versions.proto";
12
+
13
+ // Represents the graph of operations
14
+ message GraphDef {
15
+ repeated NodeDef node = 1;
16
+
17
+ // Compatibility versions of the graph. See core/public/version.h for version
18
+ // history. The GraphDef version is distinct from the TensorFlow version, and
19
+ // each release of TensorFlow will support a range of GraphDef versions.
20
+ VersionDef versions = 4;
21
+
22
+ // Deprecated single version field; use versions above instead. Since all
23
+ // GraphDef changes before "versions" was introduced were forward
24
+ // compatible, this field is entirely ignored.
25
+ int32 version = 3 [deprecated = true];
26
+
27
+ // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
28
+ //
29
+ // "library" provides user-defined functions.
30
+ //
31
+ // Naming:
32
+ // * library.function.name are in a flat namespace.
33
+ // NOTE: We may need to change it to be hierarchical to support
34
+ // different orgs. E.g.,
35
+ // { "/google/nn", { ... }},
36
+ // { "/google/vision", { ... }}
37
+ // { "/org_foo/module_bar", { ... }}
38
+ // map<string, FunctionDefLib> named_lib;
39
+ // * If node[i].op is the name of one function in "library",
40
+ // node[i] is deemed as a function call. Otherwise, node[i].op
41
+ // must be a primitive operation supported by the runtime.
42
+ //
43
+ //
44
+ // Function call semantics:
45
+ //
46
+ // * The callee may start execution as soon as some of its inputs
47
+ // are ready. The caller may want to use Tuple() mechanism to
48
+ // ensure all inputs are ready in the same time.
49
+ //
50
+ // * The consumer of return values may start executing as soon as
51
+ // the return values the consumer depends on are ready. The
52
+ // consumer may want to use Tuple() mechanism to ensure the
53
+ // consumer does not start until all return values of the callee
54
+ // function are ready.
55
+ FunctionDefLibrary library = 2;
56
+ };
graph_def_util.cc ADDED
@@ -0,0 +1,218 @@
1
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #include "tensorflow/core/framework/graph_def_util.h"
17
+
18
+ #include <set>
19
+ #include <unordered_map>
20
+ #include <unordered_set>
21
+ #include <vector>
22
+
23
+ #include "tensorflow/core/framework/attr_value.pb.h"
24
+ #include "tensorflow/core/framework/function.pb.h"
25
+ #include "tensorflow/core/framework/graph.pb.h"
26
+ #include "tensorflow/core/framework/node_def.pb.h"
27
+ #include "tensorflow/core/framework/node_def_util.h"
28
+ #include "tensorflow/core/framework/op_def_util.h"
29
+ #include "tensorflow/core/framework/versions.pb_text.h"
30
+ #include "tensorflow/core/lib/core/errors.h"
31
+ #include "tensorflow/core/lib/core/status.h"
32
+ #include "tensorflow/core/lib/strings/strcat.h"
33
+
34
+ namespace tensorflow {
35
+
36
+ string SummarizeGraphDef(const GraphDef& graph_def) {
37
+ string ret;
38
+ strings::StrAppend(&ret, "versions = ",
39
+ ProtoShortDebugString(graph_def.versions()), ";\n");
40
+ for (const NodeDef& node : graph_def.node()) {
41
+ strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
42
+ }
43
+ return ret;
44
+ }
45
+
46
+ Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
47
+ for (const NodeDef& node : graph_def.node()) {
48
+ TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node));
49
+ }
50
+ return Status::OK();
51
+ }
52
+
53
+ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
54
+ const OpRegistryInterface& op_registry,
55
+ int node_offset) {
56
+ if (node_offset > graph_def->node_size()) {
57
+ return errors::InvalidArgument(
58
+ "Tried to add default attrs to GraphDef "
59
+ "starting at offset ",
60
+ node_offset, " with total nodes in graph: ", graph_def->node_size());
61
+ }
62
+
63
+ for (int i = node_offset; i < graph_def->node_size(); ++i) {
64
+ NodeDef* node_def = graph_def->mutable_node(i);
65
+ const OpDef* op_def;
66
+ TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(node_def->op(), &op_def));
67
+ AddDefaultsToNodeDef(*op_def, node_def);
68
+ }
69
+
70
+ return Status::OK();
71
+ }
72
+
73
+ static Status RemoveNewDefaultAttrsFromNodeDef(
74
+ NodeDef* node_def, const OpRegistryInterface& consumer_op_registry,
75
+ const OpRegistryInterface& producer_op_registry,
76
+ std::set<std::pair<string, string>>* op_attr_removed) {
77
+ const OpDef* producer_op_def;
78
+ const OpDef* consumer_op_def;
79
+ TF_RETURN_IF_ERROR(
80
+ producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
81
+ TF_RETURN_IF_ERROR(
82
+ consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));
83
+
84
+ std::vector<string> to_remove;
85
+ for (const auto& attr : node_def->attr()) {
86
+ // If the attr is not in consumer_op_def and doesn't start with '_'...
87
+ if (!StringPiece(attr.first).starts_with("_") &&
88
+ FindAttr(attr.first, *consumer_op_def) == nullptr) {
89
+ const OpDef::AttrDef* producer_attr_def =
90
+ FindAttr(attr.first, *producer_op_def);
91
+ if (producer_attr_def == nullptr) {
92
+ return errors::InvalidArgument(
93
+ "Attr '", attr.first, "' missing in producer's OpDef: ",
94
+ SummarizeOpDef(*producer_op_def), " but found in node: ",
95
+ SummarizeNodeDef(*node_def));
96
+ }
97
+ // ...and it has the same value as the default in producer,
98
+ if (producer_attr_def->has_default_value() &&
99
+ AreAttrValuesEqual(producer_attr_def->default_value(), attr.second)) {
100
+ // then we will remove it below.
101
+ to_remove.emplace_back(attr.first);
102
+ }
103
+ }
104
+ }
105
+ // We separate identifying which attrs should be removed from
106
+ // actually removing them to avoid invalidating the loop iterators
107
+ // above.
108
+ for (const string& attr_name : to_remove) {
109
+ node_def->mutable_attr()->erase(attr_name);
110
+ if (op_attr_removed != nullptr) {
111
+ op_attr_removed->insert(std::make_pair(node_def->op(), attr_name));
112
+ }
113
+ }
114
+
115
+ return Status::OK();
116
+ }
117
+
118
+ static bool IsFunction(const GraphDef& graph_def, const string& op_name) {
119
+ for (const auto& func_def : graph_def.library().function()) {
120
+ if (op_name == func_def.signature().name()) return true;
121
+ }
122
+ return false;
123
+ }
124
+
125
+ Status RemoveNewDefaultAttrsFromGraphDef(
126
+ GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry,
127
+ const OpRegistryInterface& producer_op_registry,
128
+ std::set<std::pair<string, string>>* op_attr_removed) {
129
+ // TODO(josh11b): Make IsFunction() faster by collecting the names of
130
+ // all functions as a preprocessing step.
131
+ for (int n = 0; n < graph_def->node_size(); ++n) {
132
+ NodeDef* node_def = graph_def->mutable_node(n);
133
+ if (!IsFunction(*graph_def, node_def->op())) {
134
+ TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
135
+ node_def, consumer_op_registry, producer_op_registry,
136
+ op_attr_removed));
137
+ }
138
+ }
139
+ for (int f = 0; f < graph_def->library().function_size(); ++f) {
140
+ FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f);
141
+ for (int n = 0; n < func_def->node_def_size(); ++n) {
142
+ NodeDef* node_def = func_def->mutable_node_def(n);
143
+ if (!IsFunction(*graph_def, node_def->op())) {
144
+ // TODO(josh11b): Better handling of attrs with placeholder values.
145
+ TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
146
+ node_def, consumer_op_registry, producer_op_registry,
147
+ op_attr_removed));
148
+ }
149
+ }
150
+ }
151
+
152
+ return Status::OK();
153
+ }
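+ // Stripping attrs that still equal the producer's default keeps graphs
+ // forward compatible: a graph written by a newer producer (which added the
+ // attr with a default) can still be consumed by an older runtime that does
+ // not know the attr.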
154
+
155
+ void OpsUsedByGraph(const GraphDef& graph_def,
156
+ std::set<string>* ops_used_in_graph) {
157
+ // Map function names to definitions.
158
+ std::unordered_map<string, const FunctionDef*> name_to_function;
159
+ for (const auto& function : graph_def.library().function()) {
160
+ name_to_function.insert(
161
+ std::make_pair(function.signature().name(), &function));
162
+ }
163
+
164
+ // Collect the sorted list of op names. Since functions can reference
165
+ // functions, we need a recursive traversal.
166
+ std::set<string> used_ops; // Includes both primitive ops and functions
167
+ std::vector<const FunctionDef*> functions_to_process; // A subset of used_ops
168
+ // Collect the logic to mark an op in a lambda; it'll be used twice below.
169
+ const auto mark_op_as_used = [&used_ops, &functions_to_process,
170
+ &name_to_function](const string& op) {
171
+ if (used_ops.insert(op).second) {
172
+ // If it's a function, we'll need to process further
173
+ const auto it = name_to_function.find(op);
174
+ if (it != name_to_function.end()) {
175
+ functions_to_process.push_back(it->second);
176
+ }
177
+ }
178
+ };
179
+ for (const auto& node : graph_def.node()) {
180
+ mark_op_as_used(node.op());
181
+ }
182
+ while (!functions_to_process.empty()) {
183
+ const FunctionDef* fun = functions_to_process.back();
184
+ functions_to_process.pop_back();
185
+ for (const auto& node : fun->node_def()) {
186
+ mark_op_as_used(node.op());
187
+ }
188
+ }
189
+
190
+ // Filter out function names to produce output.
191
+ // TODO(josh11b): Change the above code to produce this directly.
192
+ ops_used_in_graph->clear();
193
+ for (const string& op_name : used_ops) {
194
+ if (name_to_function.find(op_name) == name_to_function.end()) {
195
+ ops_used_in_graph->insert(op_name);
196
+ }
197
+ }
198
+ }
199
+
200
+ Status StrippedOpListForGraph(const GraphDef& graph_def,
201
+ const OpRegistryInterface& op_registry,
202
+ OpList* stripped_op_list) {
203
+ std::set<string> used_ops;
204
+ OpsUsedByGraph(graph_def, &used_ops);
205
+
206
+ // Build the stripped op list in sorted order, ignoring functions.
207
+ stripped_op_list->clear_op();
208
+ for (const string& op_name : used_ops) {
209
+ const OpDef* op_def;
210
+ TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def));
211
+ OpDef* stripped_op = stripped_op_list->add_op();
212
+ stripped_op->CopyFrom(*op_def);
213
+ RemoveDescriptionsFromOpDef(stripped_op);
214
+ }
215
+ return Status::OK();
216
+ }
217
+
218
+ } // namespace tensorflow
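Editor's note: a minimal caller-side sketch (not part of this file; `my_graph_def` is a hypothetical, already-populated GraphDef) showing how the two functions above compose:

std::set<string> used_ops;
tensorflow::OpsUsedByGraph(my_graph_def, &used_ops);  // primitive ops only; functions filtered out

tensorflow::OpList stripped;
TF_CHECK_OK(tensorflow::StrippedOpListForGraph(
    my_graph_def, *tensorflow::OpRegistry::Global(), &stripped));
// 'stripped' now holds one OpDef per used primitive op, in sorted order,
// with summaries and descriptions removed by RemoveDescriptionsFromOpDef().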
graph_def_util.h ADDED
@@ -0,0 +1,115 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
+ #define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
+
+ #include <set>
+ #include "tensorflow/core/framework/op.h"
+ #include "tensorflow/core/lib/core/status.h"
+
+ namespace tensorflow {
+
+ // Forward declare proto so that its symbols can be removed from .so exports
+ class GraphDef;
+
+ // Produce a human-readable version of a GraphDef that is more concise
+ // than a text-format proto.
+ string SummarizeGraphDef(const GraphDef& graph_def);
+
+ // Validates the syntax of a GraphDef provided externally.
+ //
+ // The following is an EBNF-style syntax for GraphDef objects. Note that
+ // Node objects are actually specified as tensorflow::NodeDef protocol buffers,
+ // which contain many other fields that are not (currently) validated.
+ //
+ // Graph = Node *
+ // Node = NodeName, Inputs
+ // Inputs = ( DataInput * ), ( ControlInput * )
+ // DataInput = NodeName, ( ":", [1-9], [0-9] * ) ?
+ // ControlInput = "^", NodeName
+ // NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] *
+ Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def);
+
+ // Adds default attributes to NodeDefs in 'graph_def' starting
+ // from the 'node_offset' node in 'graph_def'.
+ //
+ // Default attributes are defined by 'op_registry'.
+ //
+ // Returns OK on success, an error if 'graph_def' has a NodeDef
+ // that cannot be found in 'op_registry'.
+ //
+ // REQUIRES: 'graph_def' and 'op_registry' are not nullptr.
+ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
+ const OpRegistryInterface& op_registry,
+ int node_offset);
+
+ // Remove attrs from 'graph_def' that have the default value according
+ // to 'producer_op_registry', but don't exist according to
+ // 'consumer_op_registry'. This can allow 'graph_def' to run on the
+ // consumer even if consumer was built at an earlier CL (before an
+ // attr with a default was added). Note that this will not affect
+ // attrs with non-default values, so you must run a
+ // ValidateGraphDef...() function to see if the result is in fact
+ // compatible. If not nullptr, the op/attr pairs that were removed
+ // are added to '*op_attr_removed'.
+ //
+ // Expected usage, for a producer that wants to prepare a graph for
+ // a consumer:
+ // // For each consumer, update 'graph_def':
+ // OpListOpRegistry consumer_op_registry(consumer_server_op_list);
+ // std::set<std::pair<string, string>> op_attr_removed;
+ // TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromGraphDef(
+ // &graph_def, consumer_op_registry, *OpRegistry::Global(),
+ // &op_attr_removed));
+ // // Validate that each consumer can understand the resulting 'graph_def'
+ // TF_RETURN_IF_ERROR(graph::ValidateGraphDefAgainstOpRegistry(
+ // graph_def, consumer_op_registry));
+ // // Consumer can use 'graph_def', and 'op_attr_removed' summarizes
+ // // what changes had to be made to 'graph_def' for it to work.
+ //
+ // Expected usage, for a consumer that has a graph and a
+ // (optionally-stripped) op_list from a producer (say from a call to
+ // StrippedOpListForGraph(), or in the MetaGraphDef):
+ // OpListOpRegistry producer_op_registry(producer_stripped_op_list);
+ // TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromGraphDef(
+ // &graph_def, *OpRegistry::Global(), producer_op_registry, nullptr));
+ Status RemoveNewDefaultAttrsFromGraphDef(
+ GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry,
+ const OpRegistryInterface& producer_op_registry,
+ std::set<std::pair<string, string>>* op_attr_removed);
+
+ // Two functions that collect the ops used by a graph.
+ //
+ // This returns the ops used as a set of strings.
+ void OpsUsedByGraph(const GraphDef& graph_def,
+ std::set<string>* ops_used_in_graph);
+
+ // This function computes the stripped_op_list field of MetaGraphDef
+ // and similar protos. The op_registry should contain the ops used to
+ // produce graph_def. The resulting stripped_op_list can be
+ // communicated from the producer to the consumer, which can use
+ // RemoveNewDefaultAttrsFromGraphDef() to improve forwards compatibility
+ // (using an OpListOpRegistry as indicated in the example above).
+ //
+ // Most users will pass *OpRegistry::Global() for op_registry to strip against
+ // the list of ops registered in this process.
+ Status StrippedOpListForGraph(const GraphDef& graph_def,
+ const OpRegistryInterface& op_registry,
+ OpList* stripped_op_list);
+
+ } // namespace tensorflow
+
+ #endif // TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
graph_def_util_test.cc ADDED
@@ -0,0 +1,321 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/graph_def_util.h"
+
+ #include "tensorflow/core/framework/function.h"
+ #include "tensorflow/core/framework/graph.pb.h"
+ #include "tensorflow/core/framework/node_def_builder.h"
+ #include "tensorflow/core/framework/op.h"
+ #include "tensorflow/core/framework/op_def.pb.h"
+ #include "tensorflow/core/framework/op_def_builder.h"
+ #include "tensorflow/core/lib/core/status_test_util.h"
+ #include "tensorflow/core/platform/test.h"
+ #include "tensorflow/core/util/equal_graph_def.h"
+
+ namespace tensorflow {
+ namespace {
+
+ Status FinalizeOpDef(const OpDefBuilder& b, OpDef* op_def) {
+ OpRegistrationData op_reg_data;
+ const Status s = b.Finalize(&op_reg_data);
+ *op_def = op_reg_data.op_def;
+ return s;
+ }
+
+ // Producer and consumer have default for an attr -> graph unchanged.
+ TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) {
+ OpList op_list;
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("NoChangeWithDefault").Attr("a: int = 12"),
+ op_list.add_op()));
+ OpListOpRegistry registry(&op_list);
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(NodeDefBuilder("ncwd", "NoChangeWithDefault", &registry)
+ .Finalize(graph_def.add_node()));
+ GraphDef expected_graph_def = graph_def;
+
+ std::set<std::pair<string, string>> op_attr_removed;
+ TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
+ &op_attr_removed));
+
+ TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
+ EXPECT_TRUE(op_attr_removed.empty());
+ }
+
+ // Producer and consumer both have an attr -> graph unchanged.
+ TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) {
+ OpList op_list;
+ TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("NoChangeNoDefault").Attr("a: int"),
+ op_list.add_op()));
+ OpListOpRegistry registry(&op_list);
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(NodeDefBuilder("ncnd", "NoChangeNoDefault", &registry)
+ .Attr("a", 42)
+ .Finalize(graph_def.add_node()));
+ GraphDef expected_graph_def = graph_def;
+
+ std::set<std::pair<string, string>> op_attr_removed;
+ TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
+ &op_attr_removed));
+
+ TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
+ EXPECT_TRUE(op_attr_removed.empty());
+ }
+
+ // Producer has default for an attr that the consumer does not know
+ // about, and the produced graph has the default value for the attr ->
+ // attr removed from graph (and so able to be consumed).
+ TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) {
+ OpList consumer_op_list;
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op()));
+ OpListOpRegistry consumer_registry(&consumer_op_list);
+
+ OpList producer_op_list;
+ TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
+ producer_op_list.add_op()));
+ OpListOpRegistry producer_registry(&producer_op_list);
+
+ GraphDef produced_graph_def;
+ TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &producer_registry)
+ .Finalize(produced_graph_def.add_node()));
+
+ std::set<std::pair<string, string>> op_attr_removed;
+ TF_ASSERT_OK(
+ RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
+ producer_registry, &op_attr_removed));
+
+ GraphDef expected_graph_def;
+ TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &consumer_registry)
+ .Finalize(expected_graph_def.add_node()));
+ TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
+
+ std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}});
+ EXPECT_EQ(expected_removed, op_attr_removed);
+ }
+
+ // Producer has default for an attr that the consumer does not know
+ // about, graph sets the attr to a value different from the default ->
+ // graph unchanged (but not able to be consumed by consumer).
+ TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) {
+ OpList consumer_op_list;
+ TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"),
+ consumer_op_list.add_op()));
+ OpListOpRegistry consumer_registry(&consumer_op_list);
+
+ OpList producer_op_list;
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
+ producer_op_list.add_op()));
+ OpListOpRegistry producer_registry(&producer_op_list);
+
+ GraphDef produced_graph_def;
+ TF_ASSERT_OK(NodeDefBuilder("changed_from_default", "ChangedFromDefault",
+ &producer_registry)
+ .Attr("a", 9)
+ .Finalize(produced_graph_def.add_node()));
+ GraphDef expected_graph_def = produced_graph_def;
+
+ std::set<std::pair<string, string>> op_attr_removed;
+ TF_ASSERT_OK(
+ RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
+ producer_registry, &op_attr_removed));
+
+ TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
+ EXPECT_TRUE(op_attr_removed.empty());
+ }
+
+ // Attrs starting with underscores should not be removed.
+ TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) {
+ OpList consumer_op_list;
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("Underscore"), consumer_op_list.add_op()));
+ OpListOpRegistry consumer_registry(&consumer_op_list);
+
+ OpList producer_op_list;
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("Underscore"), producer_op_list.add_op()));
+ // Add the _underscore attr manually since OpDefBuilder would complain
+ OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr();
+ attr->set_name("_underscore");
+ attr->set_type("int");
+ attr->mutable_default_value()->set_i(17);
+ OpListOpRegistry producer_registry(&producer_op_list);
+
+ GraphDef produced_graph_def;
+ TF_ASSERT_OK(NodeDefBuilder("node", "Underscore", &producer_registry)
+ .Attr("_underscore", 17)
+ .Finalize(produced_graph_def.add_node()));
+ GraphDef expected_graph_def = produced_graph_def;
+
+ std::set<std::pair<string, string>> op_attr_removed;
+ TF_ASSERT_OK(
+ RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
+ producer_registry, &op_attr_removed));
+
+ TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
+ EXPECT_EQ(op_attr_removed.size(), 0);
+ }
+
+ TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) {
+ OpList consumer_op_list;
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op()));
+ TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"),
+ consumer_op_list.add_op()));
+ OpListOpRegistry consumer_registry(&consumer_op_list);
+
+ OpList producer_op_list;
+ TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
+ producer_op_list.add_op()));
+ TF_ASSERT_OK(
+ FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
+ producer_op_list.add_op()));
+ OpListOpRegistry producer_registry(&producer_op_list);
+
+ GraphDef produced_graph_def;
+ *produced_graph_def.mutable_library()->add_function() =
+ FunctionDefHelper::Create(
+ "my_func", {}, {}, {},
+ {{{"x"}, "UsesDefault", {}, {{"a", 17}}},
+ {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
+ {});
+ OpList function_op_list;
+ *function_op_list.add_op() =
+ produced_graph_def.library().function(0).signature();
+ OpListOpRegistry function_registry(&function_op_list);
+ TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry)
+ .Finalize(produced_graph_def.add_node()));
+
+ std::set<std::pair<string, string>> op_attr_removed;
+ TF_ASSERT_OK(
+ RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
+ producer_registry, &op_attr_removed));
+
+ GraphDef expected_graph_def;
+ *expected_graph_def.mutable_library()->add_function() =
+ FunctionDefHelper::Create(
+ "my_func", {}, {}, {},
+ {{{"x"}, "UsesDefault", {}, {}},
+ {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
+ {});
+ TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry)
+ .Finalize(expected_graph_def.add_node()));
+ TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
+ EXPECT_EQ(expected_graph_def.library().DebugString(),
+ produced_graph_def.library().DebugString());
+
+ std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}});
+ EXPECT_EQ(expected_removed, op_attr_removed);
+ }
+
+ TEST(StrippedOpListForGraphTest, FlatTest) {
+ // Make four ops
+ OpList op_list;
+ for (const string& op : {"A", "B", "C", "D"}) {
+ OpDef* op_def = op_list.add_op();
+ op_def->set_name(op);
+ op_def->set_summary("summary");
+ op_def->set_description("description");
+ op_def->set_is_commutative(op == "B");
+ }
+
+ // Make a graph which uses two ops once and twice, respectively.
+ // The result should be independent of the ordering.
+ const string graph_ops[4][3] = {
+ {"C", "B", "B"}, {"B", "C", "B"}, {"B", "B", "C"}, {"C", "C", "B"}};
+ for (const bool use_function : {false, true}) {
+ for (int order = 0; order < 4; order++) {
+ GraphDef graph_def;
+ if (use_function) {
+ FunctionDef* function_def = graph_def.mutable_library()->add_function();
+ function_def->mutable_signature()->set_name("F");
+ for (const string& op : graph_ops[order]) {
+ function_def->add_node_def()->set_op(op);
+ }
+ graph_def.add_node()->set_op("F");
+ } else {
+ for (const string& op : graph_ops[order]) {
+ string name = strings::StrCat("name", graph_def.node_size());
+ NodeDef* node = graph_def.add_node();
+ node->set_name(name);
+ node->set_op(op);
+ }
+ }
+
+ // Strip the op list
+ OpList stripped_op_list;
+ TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
+ &stripped_op_list));
+
+ // We should have exactly two ops: B and C.
+ ASSERT_EQ(stripped_op_list.op_size(), 2);
+ for (int i = 0; i < 2; i++) {
+ const OpDef& op = stripped_op_list.op(i);
+ EXPECT_EQ(op.name(), i ? "C" : "B");
+ EXPECT_EQ(op.summary(), "");
+ EXPECT_EQ(op.description(), "");
+ EXPECT_EQ(op.is_commutative(), !i);
+ }
+
+ // Should get the same result using OpsUsedByGraph().
+ std::set<string> used_ops;
+ OpsUsedByGraph(graph_def, &used_ops);
+ ASSERT_EQ(std::set<string>({"B", "C"}), used_ops);
+ }
+ }
+ }
+
+ TEST(StrippedOpListForGraphTest, NestedFunctionTest) {
+ // Make a primitive op A.
+ OpList op_list;
+ op_list.add_op()->set_name("A");
+
+ for (const bool recursive : {false, true}) {
+ // Call A from function B, and B from function C.
+ GraphDef graph_def;
+ FunctionDef* b = graph_def.mutable_library()->add_function();
+ FunctionDef* c = graph_def.mutable_library()->add_function();
+ b->mutable_signature()->set_name("B");
+ c->mutable_signature()->set_name("C");
+ b->add_node_def()->set_op("A");
+ c->add_node_def()->set_op("B");
+ if (recursive) {
+ b->add_node_def()->set_op("B");
+ c->add_node_def()->set_op("C");
+ }
+
+ // Use C in the graph.
+ graph_def.add_node()->set_op("C");
+
+ // The stripped op list should contain just A.
+ OpList stripped_op_list;
+ TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
+ &stripped_op_list));
+ ASSERT_EQ(stripped_op_list.op_size(), 1);
+ ASSERT_EQ(stripped_op_list.op(0).name(), "A");
+
+ // Should get the same result using OpsUsedByGraph().
+ std::set<string> used_ops;
+ OpsUsedByGraph(graph_def, &used_ops);
+ ASSERT_EQ(std::set<string>({"A"}), used_ops);
+ }
+ }
+
+ } // namespace
+ } // namespace tensorflow
graph_transfer_info.proto ADDED
@@ -0,0 +1,68 @@
+ syntax = "proto3";
+
+ package tensorflow;
+ option cc_enable_arenas = true;
+ option java_outer_classname = "GraphTransferInfoProto";
+ option java_multiple_files = true;
+ option java_package = "org.tensorflow.framework";
+
+ import "tensorflow/core/framework/types.proto";
+
+ // Protocol buffer representing a graph to be transferred to a remote
+ // processor (see the Destination enum below, e.g. the Hexagon DSP).
+ message GraphTransferInfo {
+ enum Destination {
+ NOP = 0;
+ HEXAGON = 1;
+ }
+ message NodeInput {
+ int32 node_id = 1;
+ int32 output_port = 2;
+ }
+ message NodeInfo {
+ string name = 1;
+ int32 node_id = 2;
+ string type_name = 3;
+ int32 soc_op_id = 4;
+ int32 padding_id = 5;
+ int32 input_count = 6;
+ int32 output_count = 7;
+ };
+ message ConstNodeInfo {
+ string name = 1;
+ int32 node_id = 2;
+ repeated int64 shape = 3;
+ bytes data = 4;
+ DataType dtype = 5;
+ };
+ message NodeInputInfo {
+ int32 node_id = 1;
+ repeated NodeInput node_input = 2;
+ };
+ message NodeOutputInfo {
+ int32 node_id = 1;
+ repeated int32 max_byte_size = 2;
+ };
+ message GraphInputNodeInfo {
+ string name = 1;
+ repeated int64 shape = 2;
+ DataType dtype = 3;
+ }
+
+ message GraphOutputNodeInfo {
+ string name = 1;
+ repeated int64 shape = 2;
+ DataType dtype = 3;
+ }
+
+ repeated NodeInfo node_info = 1;
+ repeated ConstNodeInfo const_node_info = 2;
+ repeated NodeInputInfo node_input_info = 3;
+ repeated NodeOutputInfo node_output_info = 4;
+ // Input node parameters of the transferred graph
+ repeated GraphInputNodeInfo graph_input_node_info = 5;
+ repeated GraphOutputNodeInfo graph_output_node_info = 6;
+ // Destination of graph transfer
+ Destination destination = 7;
+ };
iterator.proto ADDED
@@ -0,0 +1,17 @@
+ syntax = "proto3";
+
+ package tensorflow;
+ option cc_enable_arenas = true;
+ option java_outer_classname = "IteratorProtos";
+ option java_multiple_files = true;
+ option java_package = "org.tensorflow.util";
+
+ // Protocol buffer representing the metadata for an iterator's state stored
+ // as a Variant tensor.
+ message IteratorStateMetadata {
+ // A user-specified version string.
+ string version = 1;
+
+ // Keys for tensors in the VariantTensorDataProto.
+ repeated string keys = 2;
+ }
kernel_def.proto ADDED
@@ -0,0 +1,36 @@
+ syntax = "proto3";
+
+ package tensorflow;
+ option cc_enable_arenas = true;
+ option java_outer_classname = "KernelDefProtos";
+ option java_multiple_files = true;
+ option java_package = "org.tensorflow.framework";
+
+ import "tensorflow/core/framework/attr_value.proto";
+
+ message KernelDef {
+ // Must match the name of an Op.
+ string op = 1;
+
+ // Type of device this kernel runs on.
+ string device_type = 2;
+
+ message AttrConstraint {
+ // Name of an attr from the Op.
+ string name = 1;
+
+ // A list of values that this kernel supports for this attr.
+ // Like OpDef.AttrDef.allowed_values, except for kernels instead of Ops.
+ AttrValue allowed_values = 2;
+ }
+ repeated AttrConstraint constraint = 3;
+
+ // Names of the Op's input_/output_args that reside in host memory
+ // instead of device memory.
+ repeated string host_memory_arg = 4;
+
+ // This allows experimental kernels to be registered for an op that
+ // won't be used unless the user specifies a "_kernel" attr with
+ // value matching this.
+ string label = 5;
+ }
kernel_def_builder.cc ADDED
@@ -0,0 +1,75 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/kernel_def_builder.h"
+ #include "tensorflow/core/framework/attr_value.pb.h"
+ #include "tensorflow/core/framework/kernel_def.pb_text.h"
+ #include "tensorflow/core/framework/kernel_def.pb.h"
+
+ namespace tensorflow {
+
+ KernelDefBuilder::KernelDefBuilder(const char* op_name) {
+ kernel_def_ = new KernelDef;
+ kernel_def_->set_op(op_name);
+ }
+
+ KernelDefBuilder::~KernelDefBuilder() {
+ DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
+ }
+
+ KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
+ kernel_def_->set_device_type(device_type);
+ return *this;
+ }
+
+ KernelDefBuilder& KernelDefBuilder::TypeConstraint(
+ const char* attr_name, gtl::ArraySlice<DataType> allowed) {
+ auto* constraint = kernel_def_->add_constraint();
+ constraint->set_name(attr_name);
+ auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
+ for (DataType dt : allowed) {
+ allowed_values->add_type(dt);
+ }
+ return *this;
+ }
+
+ KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name,
+ DataType allowed) {
+ auto* constraint = kernel_def_->add_constraint();
+ constraint->set_name(attr_name);
+ constraint->mutable_allowed_values()->mutable_list()->add_type(allowed);
+ return *this;
+ }
+
+ KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) {
+ kernel_def_->add_host_memory_arg(arg_name);
+ return *this;
+ }
+
+ KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
+ CHECK_EQ(kernel_def_->label(), "")
+ << "Trying to set a kernel's label a second time: '" << label
+ << "' in: " << ProtoShortDebugString(*kernel_def_);
+ kernel_def_->set_label(label);
+ return *this;
+ }
+
+ const KernelDef* KernelDefBuilder::Build() {
+ KernelDef* r = kernel_def_;
+ kernel_def_ = nullptr;
+ return r;
+ }
+
+ } // namespace tensorflow
kernel_def_builder.h ADDED
@@ -0,0 +1,87 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+ #define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+
+ #include "tensorflow/core/framework/types.h"
+ #include "tensorflow/core/lib/gtl/array_slice.h"
+ #include "tensorflow/core/platform/macros.h"
+ #include "tensorflow/core/platform/types.h"
+
+ namespace tensorflow {
+
+ // Forward declare proto so that kernels don't need to depend on it
+ class KernelDef;
+
+ // Builder class passed to the REGISTER_KERNEL_BUILDER() macro.
+ class KernelDefBuilder {
+ public:
+ // Starts with just the name field set.
+ // Caller MUST call Build() and take ownership of the result.
+ explicit KernelDefBuilder(const char* op_name);
+ ~KernelDefBuilder();
+
+ // Required: specify the type of device this kernel supports.
+ // Returns *this.
+ KernelDefBuilder& Device(const char* device_type);
+ // KernelDefBuilder& Device(DeviceType device_type);
+
+ // Specify that this kernel supports a limited set of values for a
+ // particular type or list(type) attr (a further restriction than
+ // what the Op allows).
+ // Returns *this.
+ KernelDefBuilder& TypeConstraint(const char* attr_name,
+ gtl::ArraySlice<DataType> allowed);
+
+ // Like TypeConstraint but supports just a single type.
+ KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);
+
+ // Like TypeConstraint, but (a) gets the type from a template parameter
+ // and (b) only supports a constraint to a single type.
+ template <class T>
+ KernelDefBuilder& TypeConstraint(const char* attr_name);
+ // TODO(josh11b): Support other types of attr constraints as needed.
+
+ // Specify that this kernel requires/provides an input/output arg
+ // in host memory (instead of the default, device memory).
+ // Returns *this.
+ KernelDefBuilder& HostMemory(const char* arg_name);
+
+ // Specify that this kernel requires a particular value for the
+ // "_kernel" attr. May only be specified once. Returns *this.
+ KernelDefBuilder& Label(const char* label);
+
+ // Returns a pointer to a KernelDef with fields set based on the
+ // above calls to this instance.
+ // Caller takes ownership of the result.
+ const KernelDef* Build();
+
+ private:
+ KernelDef* kernel_def_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
+ };
+
+ // IMPLEMENTATION
+
+ template <class T>
+ KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name) {
+ return this->TypeConstraint(attr_name, DataTypeToEnum<T>::v());
+ }
+
+ } // namespace tensorflow
+
+ #endif // TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
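Editor's note: a usage sketch of the builder chain declared above ("MyOp" is a hypothetical op name), mirroring the KernelDef that REGISTER_KERNEL_BUILDER would pass along:

const tensorflow::KernelDef* def =
    tensorflow::KernelDefBuilder("MyOp")
        .Device(tensorflow::DEVICE_GPU)       // required
        .TypeConstraint<float>("T")           // restrict attr "T" to float
        .HostMemory("shape")                  // arg "shape" stays in host memory
        .Build();
// Build() hands ownership to the caller and must be called exactly once:
// the destructor DCHECKs that kernel_def_ was released.
delete def;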
kernel_def_builder_test.cc ADDED
@@ -0,0 +1,91 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/kernel_def_builder.h"
+
+ #include "tensorflow/core/framework/kernel_def.pb.h"
+ #include "tensorflow/core/platform/protobuf.h"
+ #include "tensorflow/core/platform/test.h"
+
+ namespace tensorflow {
+ namespace {
+
+ TEST(KernelDefBuilderTest, Basic) {
+ const KernelDef* def = KernelDefBuilder("A").Device(DEVICE_CPU).Build();
+ KernelDef expected;
+ protobuf::TextFormat::ParseFromString("op: 'A' device_type: 'CPU'",
+ &expected);
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+ }
+
+ TEST(KernelDefBuilderTest, TypeConstraint) {
+ const KernelDef* def = KernelDefBuilder("B")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .Build();
+ KernelDef expected;
+ protobuf::TextFormat::ParseFromString(R"proto(
+ op: 'B' device_type: 'GPU'
+ constraint { name: 'T' allowed_values { list { type: DT_FLOAT } } } )proto",
+ &expected);
+
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+
+ def = KernelDefBuilder("C")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("U")
+ .TypeConstraint<bool>("V")
+ .Build();
+
+ protobuf::TextFormat::ParseFromString(R"proto(
+ op: 'C' device_type: 'GPU'
+ constraint { name: 'U' allowed_values { list { type: DT_INT32 } } }
+ constraint { name: 'V' allowed_values { list { type: DT_BOOL } } } )proto",
+ &expected);
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+
+ def = KernelDefBuilder("D")
+ .Device(DEVICE_CPU)
+ .TypeConstraint("W", {DT_DOUBLE, DT_STRING})
+ .Build();
+ protobuf::TextFormat::ParseFromString(R"proto(
+ op: 'D' device_type: 'CPU'
+ constraint { name: 'W'
+ allowed_values { list { type: [DT_DOUBLE, DT_STRING] } } } )proto",
+ &expected);
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+ }
+
+ TEST(KernelDefBuilderTest, HostMemory) {
+ const KernelDef* def = KernelDefBuilder("E")
+ .Device(DEVICE_GPU)
+ .HostMemory("in")
+ .HostMemory("out")
+ .Build();
+ KernelDef expected;
+ protobuf::TextFormat::ParseFromString(
+ "op: 'E' device_type: 'GPU' "
+ "host_memory_arg: ['in', 'out']",
+ &expected);
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+ }
+
+ } // namespace
+ } // namespace tensorflow
load_library.cc ADDED
@@ -0,0 +1,104 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include <memory>
+ #include <unordered_set>
+
+ #include "tensorflow/core/framework/op.h"
+ #include "tensorflow/core/framework/op_kernel.h"
+ #include "tensorflow/core/lib/core/errors.h"
+ #include "tensorflow/core/platform/env.h"
+ #include "tensorflow/core/platform/mem.h"
+
+ namespace tensorflow {
+
+ namespace {
+
+ struct Library {
+ void* handle = nullptr;
+ OpList op_list;
+ };
+
+ } // namespace
+
+ // Load a dynamic library.
+ // On success, returns the handle to the library in *result, copies the
+ // serialized OpList of OpDefs registered in the library to *buf and the
+ // length to *len, and returns OK from the function. Otherwise returns
+ // nullptr in *result and an error status from the function, leaving buf
+ // and len untouched.
+ //
+ // If `library_filename` has already been loaded, we return a cached handle
+ // and OpList. Ops and kernels are registered as globals when a library is
+ // loaded for the first time. Without caching, every subsequent load would not
+ // perform initialization again, so the OpList would be empty.
+ Status LoadLibrary(const char* library_filename, void** result,
+ const void** buf, size_t* len) {
+ static mutex mu(LINKER_INITIALIZED);
+ static std::unordered_map<string, Library> loaded_libs;
+ Env* env = Env::Default();
+ Library library;
+ std::unordered_set<string> seen_op_names;
+ {
+ mutex_lock lock(mu);
+ if (loaded_libs.find(library_filename) != loaded_libs.end()) {
+ library = loaded_libs[library_filename];
+ } else {
+ Status s = OpRegistry::Global()->ProcessRegistrations();
+ if (!s.ok()) {
+ return s;
+ }
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(
+ [&library, &seen_op_names](const Status& s,
+ const OpDef& opdef) -> Status {
+ if (errors::IsAlreadyExists(s)) {
+ if (seen_op_names.find(opdef.name()) == seen_op_names.end()) {
+ // Overwriting a registration of an op not in this custom op
+ // library. Treat this as not an error.
+ return Status::OK();
+ }
+ }
+ if (s.ok()) {
+ *library.op_list.add_op() = opdef;
+ seen_op_names.insert(opdef.name());
+ }
+ return s;
+ }));
+ OpRegistry::Global()->DeferRegistrations();
+ s = env->LoadLibrary(library_filename, &library.handle);
+ if (s.ok()) {
+ s = OpRegistry::Global()->ProcessRegistrations();
+ }
+ if (!s.ok()) {
+ OpRegistry::Global()->ClearDeferredRegistrations();
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr));
+ return s;
+ }
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr));
+
+ loaded_libs[library_filename] = library;
+ }
+ }
+ string str;
+ library.op_list.SerializeToString(&str);
+ char* str_buf = reinterpret_cast<char*>(port::Malloc(str.length()));
+ memcpy(str_buf, str.data(), str.length());
+ *buf = str_buf;
+ *len = str.length();
+
+ *result = library.handle;
+ return Status::OK();
+ }
+
+ } // namespace tensorflow
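Editor's note: a caller-side sketch of the function above ("my_ops.so" is a hypothetical path). Since *buf is allocated here with port::Malloc, the assumption in this sketch is that the caller releases it with port::Free once the OpList has been decoded:

void* handle = nullptr;
const void* buf = nullptr;
size_t len = 0;
TF_CHECK_OK(tensorflow::LoadLibrary("my_ops.so", &handle, &buf, &len));
tensorflow::OpList ops;
ops.ParseFromArray(buf, len);                 // OpDefs registered by the library
tensorflow::port::Free(const_cast<void*>(buf));  // assumed caller-owned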
log_memory.cc ADDED
@@ -0,0 +1,102 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/log_memory.h"
+
+ #include "tensorflow/core/framework/log_memory.pb_text.h"
+ #include "tensorflow/core/framework/log_memory.pb.h"
+
+ namespace tensorflow {
+
+ const string LogMemory::kLogMemoryLabel = "__LOG_MEMORY__";
+
+ bool LogMemory::IsEnabled() { return VLOG_IS_ON(1); }
+
+ namespace {
+
+ // Write the proto entry to LOG(INFO).
+ template <typename T>
+ void OutputToLog(const T& proto) {
+ string type_name = proto.GetTypeName();
+ const size_t index = type_name.find_last_of(".");
+ if (index != string::npos) type_name = type_name.substr(index + 1);
+ LOG(INFO) << LogMemory::kLogMemoryLabel << " " << type_name << " { "
+ << ProtoShortDebugString(proto) << " }";
+ }
+
+ } // namespace
+
+ void LogMemory::RecordStep(const int64 step_id, const string& handle) {
+ MemoryLogStep step;
+ step.set_step_id(step_id);
+ step.set_handle(handle);
+ OutputToLog(step);
+ }
+
+ void LogMemory::RecordTensorAllocation(const string& kernel_name,
+ const int64 step_id,
+ const Tensor& tensor) {
+ MemoryLogTensorAllocation allocation;
+ allocation.set_step_id(step_id);
+ allocation.set_kernel_name(kernel_name);
+ tensor.FillDescription(allocation.mutable_tensor());
+ OutputToLog(allocation);
+ }
+
+ void LogMemory::RecordTensorDeallocation(const int64 allocation_id,
+ const string& allocator_name) {
+ MemoryLogTensorDeallocation deallocation;
+ deallocation.set_allocation_id(allocation_id);
+ deallocation.set_allocator_name(allocator_name);
+ OutputToLog(deallocation);
+ }
+
+ void LogMemory::RecordTensorOutput(const string& kernel_name,
+ const int64 step_id, const int index,
+ const Tensor& tensor) {
+ MemoryLogTensorOutput output;
+ output.set_step_id(step_id);
+ output.set_kernel_name(kernel_name);
+ output.set_index(index);
+ tensor.FillDescription(output.mutable_tensor());
+ OutputToLog(output);
+ }
+
+ void LogMemory::RecordRawAllocation(const string& operation,
+ const int64 step_id, size_t num_bytes,
+ void* ptr, Allocator* allocator) {
+ MemoryLogRawAllocation allocation;
+ allocation.set_step_id(step_id);
+ allocation.set_operation(operation);
+ allocation.set_num_bytes(static_cast<int64>(num_bytes));
+ allocation.set_ptr(reinterpret_cast<uintptr_t>(ptr));
+ allocation.set_allocation_id(allocator->AllocationId(ptr));
+ allocation.set_allocator_name(allocator->Name());
+ OutputToLog(allocation);
+ }
+
+ void LogMemory::RecordRawDeallocation(const string& operation,
+ const int64 step_id, void* ptr,
+ Allocator* allocator, bool deferred) {
+ MemoryLogRawDeallocation deallocation;
+ deallocation.set_step_id(step_id);
+ deallocation.set_operation(operation);
+ deallocation.set_allocation_id(allocator->AllocationId(ptr));
+ deallocation.set_allocator_name(allocator->Name());
+ deallocation.set_deferred(deferred);
+ OutputToLog(deallocation);
+ }
+
+ } // namespace tensorflow
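Editor's note: given OutputToLog() above, every record is a single INFO line consisting of the label, the unqualified proto type name, and the proto's short debug string. For example, RecordStep(1, "test") emits a line of the form:

__LOG_MEMORY__ MemoryLogStep { step_id: 1 handle: "test" }

Consumers can recover the structured records by grepping logs for the "__LOG_MEMORY__" label and parsing the text-format body back into the corresponding proto.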
log_memory.h ADDED
@@ -0,0 +1,111 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
+ #define TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
+
+ #include "tensorflow/core/framework/tensor.h"
+ #include "tensorflow/core/platform/protobuf.h"
+
+ namespace tensorflow {
+
+ // LogMemory contains methods for recording memory allocations and
+ // frees, associating each allocation with a step identified by a
+ // process-wide id. For now, logging is enabled whenever VLOG_IS_ON(1)
+ // for the log_memory module.
+ //
+ // Limitations: We don't log memory allocations by Eigen on the CPU
+ // since that would require major changes to plumb through to the
+ // Eigen::{DefaultDevice,ThreadPoolDevice} allocate and deallocate
+ // methods. We do log Eigen allocations on GPU since the plumbing was
+ // already in place.
+ class LogMemory {
+ public:
+ // Allocations sometimes happen outside any computation step, and
+ // SpecialStepIds lists the ids used for those steps.
+ enum SpecialStepIds {
+ // Used when performing a just-in-time constant folding optimization.
+ CONSTANT_FOLDING_STEP_ID = -1,
+ // Used when constructing an Op kernel before executing a step.
+ OP_KERNEL_CONSTRUCTION_STEP_ID = -2,
+ // Used when allocating a tensor buffer from external code, e.g.,
+ // the C API.
+ EXTERNAL_TENSOR_ALLOCATION_STEP_ID = -3,
+ // Used when allocating a buffer for network transfer.
+ NETWORK_BUFFER_STEP_ID = -4,
+ // Used when allocating a buffer to fill a Proto from the GPU.
+ PROTO_BUFFER_STEP_ID = -5,
+ // Used when allocating a Tensor where the caller has not indicated
+ // the step.
+ UNKNOWN_STEP_ID = -6,
+ };
+
+ static const string kLogMemoryLabel;
+
+ // Test to see if memory logging is enabled. For now, logging is
+ // enabled whenever VLOG_IS_ON(1) for the log_memory module.
+ static bool IsEnabled();
+
+ // Log the beginning of a step.
+ static void RecordStep(int64 step_id, const string& handle);
+
+ // Log a tensor buffer allocation. The name indicates which kernel
+ // made the allocation. If the allocation is made through an
+ // OpKernelContext the step_id indicates which step is executing,
+ // otherwise step_id is one of the SpecialStepIds defined in
+ // op_kernel.h, e.g. Op Kernel construction or an optimization pass
+ // such as constant folding.
+ static void RecordTensorAllocation(const string& kernel_name, int64 step_id,
+ const Tensor& tensor);
+
+ // Log a tensor buffer deallocation. The deallocation is triggered
+ // when the buffer's refcount falls to zero, and the tracking
+ // mechanism does not associate it with a particular step or
+ // kernel. The allocation_id/allocator_name should match a
+ // corresponding tensor previously passed in to
+ // RecordTensorAllocation.
+ static void RecordTensorDeallocation(int64 allocation_id,
+ const string& allocator_name);
+
+ // Log the use of a tensor as an output from a kernel.
+ static void RecordTensorOutput(const string& kernel_name, int64 step_id,
+ int index, const Tensor& tensor);
+
+ // Log a "raw" allocation, which is just a buffer sized in
+ // bytes. The Eigen allocator, and memory copies, record their
+ // allocations this way, since they do not allocate TensorFlow
+ // tensors. The operation is set to the OpKernel name if this is
+ // called from within an Op execution, otherwise it indicates an
+ // operation such as memcpy. If step_id is >= 0 it indicates which
+ // step is executing, otherwise step_id is one of the SpecialStepIds
+ // defined in op_kernel.h, e.g. Op Kernel construction or an
+ // optimization pass such as constant folding.
+ static void RecordRawAllocation(const string& operation, int64 step_id,
+ size_t num_bytes, void* ptr,
+ Allocator* allocator);
+
+ // Log a "raw" deallocation of a buffer. When deferred is true, the
+ // buffer won't be used again, but a GPU kernel may still be
+ // enqueued using the buffer. A deferred deallocation should always
+ // be followed by a matching non-deferred deallocation when the
+ // buffer is actually returned and can be reused.
+ static void RecordRawDeallocation(const string& operation, int64 step_id,
+ void* ptr, Allocator* allocator,
+ bool deferred);
+ };
+
+ } // namespace tensorflow
+
+ #endif // TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
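Editor's note: a minimal gating sketch for the API above (step_id and tensor are hypothetical values from the caller's context). Checking IsEnabled() first avoids building the log proto when VLOG(1) is off:

if (tensorflow::LogMemory::IsEnabled()) {
  tensorflow::LogMemory::RecordTensorAllocation(
      "affine2/weights/Assign",  // kernel name, example taken from log_memory.proto
      step_id, tensor);          // assumed to come from the surrounding OpKernelContext
}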
log_memory.proto ADDED
@@ -0,0 +1,93 @@
+ syntax = "proto3";
+
+ package tensorflow;
+ option cc_enable_arenas = true;
+ option java_outer_classname = "LogMemoryProtos";
+ option java_multiple_files = true;
+ option java_package = "org.tensorflow.framework";
+
+ import "tensorflow/core/framework/tensor_description.proto";
+
+ message MemoryLogStep {
+ // Process-unique step id.
+ int64 step_id = 1;
+
+ // Handle describing the feeds and fetches of the step.
+ string handle = 2;
+ };
+
+ message MemoryLogTensorAllocation {
+ // Process-unique step id.
+ int64 step_id = 1;
+
+ // Name of the kernel making the allocation as set in GraphDef,
+ // e.g., "affine2/weights/Assign".
+ string kernel_name = 2;
+
+ // Allocated tensor details.
+ TensorDescription tensor = 3;
+ };
+
+ message MemoryLogTensorDeallocation {
+ // Id of the tensor buffer being deallocated, used to match to a
+ // corresponding allocation.
+ int64 allocation_id = 1;
+
+ // Name of the allocator used.
+ string allocator_name = 2;
+ };
+
+ message MemoryLogTensorOutput {
+ // Process-unique step id.
+ int64 step_id = 1;
+
+ // Name of the kernel producing an output as set in GraphDef, e.g.,
+ // "affine2/weights/Assign".
+ string kernel_name = 2;
+
+ // Index of the output being set.
+ int32 index = 3;
+
+ // Output tensor details.
+ TensorDescription tensor = 4;
+ }
+
+ message MemoryLogRawAllocation {
+ // Process-unique step id.
+ int64 step_id = 1;
+
+ // Name of the operation making the allocation.
+ string operation = 2;
+
+ // Number of bytes in the allocation.
+ int64 num_bytes = 3;
+
+ // Address of the allocation.
+ uint64 ptr = 4;
+
+ // Id of the tensor buffer being allocated, used to match to a
+ // corresponding deallocation.
+ int64 allocation_id = 5;
+
+ // Name of the allocator used.
+ string allocator_name = 6;
+ };
+
+ message MemoryLogRawDeallocation {
+ // Process-unique step id.
+ int64 step_id = 1;
+
+ // Name of the operation making the deallocation.
+ string operation = 2;
+
+ // Id of the tensor buffer being deallocated, used to match to a
+ // corresponding allocation.
+ int64 allocation_id = 3;
+
+ // Name of the allocator used.
+ string allocator_name = 4;
+
+ // True if the deallocation is queued and will be performed later,
+ // e.g. for GPU lazy freeing of buffers.
+ bool deferred = 5;
+ };
lookup_interface.cc ADDED
@@ -0,0 +1,87 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/lookup_interface.h"
+
+ #include "tensorflow/core/framework/tensor_shape.h"
+ #include "tensorflow/core/lib/core/errors.h"
+
+ namespace tensorflow {
+ namespace lookup {
+
+ Status LookupInterface::CheckKeyShape(const TensorShape& shape) {
+ if (!TensorShapeUtils::EndsWith(shape, key_shape())) {
+ return errors::InvalidArgument("Input key shape ", shape.DebugString(),
+ " must end with the table's key shape ",
+ key_shape().DebugString());
+ }
+ return Status::OK();
+ }
+
+ Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys,
+ const Tensor& values) {
+ if (keys.dtype() != key_dtype()) {
+ return errors::InvalidArgument("Key must be type ", key_dtype(),
+ " but got ", keys.dtype());
+ }
+ if (values.dtype() != value_dtype()) {
+ return errors::InvalidArgument("Value must be type ", value_dtype(),
+ " but got ", values.dtype());
+ }
+ return Status::OK();
+ }
+
+ Status LookupInterface::CheckKeyAndValueTensorsHelper(const Tensor& keys,
+ const Tensor& values) {
+ TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values));
+ TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape()));
+
+ TensorShape expected_value_shape = keys.shape();
+ for (int i = 0; i < key_shape().dims(); ++i) {
+ expected_value_shape.RemoveDim(expected_value_shape.dims() - 1);
+ }
+ expected_value_shape.AppendShape(value_shape());
+ if (values.shape() != expected_value_shape) {
+ return errors::InvalidArgument(
+ "Expected shape ", expected_value_shape.DebugString(),
+ " for value, got ", values.shape().DebugString());
+ }
+ return Status::OK();
+ }
+
+ Status LookupInterface::CheckKeyAndValueTensorsForInsert(const Tensor& keys,
+ const Tensor& values) {
+ return CheckKeyAndValueTensorsHelper(keys, values);
+ }
+
+ Status LookupInterface::CheckKeyAndValueTensorsForImport(const Tensor& keys,
+ const Tensor& values) {
+ return CheckKeyAndValueTensorsHelper(keys, values);
+ }
+
+ Status LookupInterface::CheckFindArguments(const Tensor& key,
+ const Tensor& default_value) {
+ TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
+ TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
+ if (default_value.shape() != value_shape()) {
+ return errors::InvalidArgument(
+ "Expected shape ", value_shape().DebugString(),
+ " for default value, got ", default_value.shape().DebugString());
+ }
+ return Status::OK();
+ }
+
+ } // namespace lookup
+ } // namespace tensorflow
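Editor's note: a worked example of the shape arithmetic in CheckKeyAndValueTensorsHelper() (illustrative values, not from this file). With key_shape() = [2] and value_shape() = [3], a keys tensor of shape [4, 2] passes CheckKeyShape() because it ends with [2]; the loop then removes key_shape().dims() = 1 trailing dimension, leaving [4], and appends value_shape() to get an expected values shape of [4, 3]. A values tensor of any other shape is rejected with InvalidArgument.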
lookup_interface.h ADDED
@@ -0,0 +1,145 @@
1
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
+ #define TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
+
+ #include "tensorflow/core/framework/resource_mgr.h"
+ #include "tensorflow/core/framework/tensor.h"
+ #include "tensorflow/core/lib/core/status.h"
+
+ namespace tensorflow {
+
+ class OpKernelContext;
+
+ namespace lookup {
+
+ // Forward declaration so we can define GetInitializableLookupTable() in
+ // LookupInterface.
+ class InitializableLookupTable;
+
+ // Lookup interface for batch lookups used by table lookup ops.
+ class LookupInterface : public ResourceBase {
+  public:
+   // Performs a batch lookup: for every element in the key tensor, Find
+   // writes the corresponding value into the values tensor. If an element
+   // is not present in the table, the given default value is used instead.
+   //
+   // For tables that require initialization, Find is available once the
+   // table is marked as initialized.
+   //
+   // Returns the following statuses:
+   // - OK: when the find finishes successfully.
+   // - FailedPrecondition: if the table is not initialized.
+   // - InvalidArgument: if any of the preconditions on the lookup key or
+   //   value fails.
+   // - In addition, other implementations may provide another non-OK status
+   //   specific to their failure modes.
+   virtual Status Find(OpKernelContext* ctx, const Tensor& keys,
+                       Tensor* values, const Tensor& default_value) = 0;
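+
+   // For example, a table with DT_STRING keys, DT_INT64 values, and scalar
+   // key_shape() and value_shape() maps a keys tensor of shape [N] to a
+   // values tensor of shape [N]; default_value must have shape
+   // value_shape().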
+
+   // Inserts elements into the table. Each element of the key tensor is
+   // associated with the corresponding element in the value tensor.
+   // This method is only implemented by mutable tables that can be updated
+   // over the execution of the graph; it returns an Unimplemented status
+   // for read-only tables that are initialized once before they can be
+   // looked up.
+   //
+   // Returns the following statuses:
+   // - OK: when the insert finishes successfully.
+   // - InvalidArgument: if any of the preconditions on the lookup key or
+   //   value fails.
+   // - Unimplemented: if the table does not support insertions.
+   virtual Status Insert(OpKernelContext* ctx, const Tensor& keys,
+                         const Tensor& values) = 0;
+
+   // Returns the number of elements in the table.
+   virtual size_t size() const = 0;
+
+   // Exports the values of the table to two tensors named keys and values.
+   // Note that the shape of the tensors is completely up to the
+   // implementation of the table and can be different from the tensors used
+   // for the Insert function above.
+   virtual Status ExportValues(OpKernelContext* ctx) = 0;
+
+   // Imports previously exported keys and values.
+   // As mentioned above, the shapes of the keys and values tensors are
+   // determined by the ExportValues function above and can be different
+   // from those for the Insert function.
+   virtual Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
+                               const Tensor& values) = 0;
+
+   // Returns the data type of the key.
+   virtual DataType key_dtype() const = 0;
+
+   // Returns the data type of the value.
+   virtual DataType value_dtype() const = 0;
+
+   // Returns the shape of a key in the table.
+   virtual TensorShape key_shape() const = 0;
+
+   // Returns the shape of a value in the table.
+   virtual TensorShape value_shape() const = 0;
+
+   // Checks the format of the key and value tensors for the Insert
+   // function. Returns OK if all of the following requirements are
+   // satisfied, otherwise returns InvalidArgument:
+   // - the DataType of the keys tensor equals the table's key_dtype;
+   // - the DataType of the values tensor equals the table's value_dtype;
+   // - the values tensor has the required shape given keys and the table's
+   //   value shape.
+   virtual Status CheckKeyAndValueTensorsForInsert(const Tensor& keys,
+                                                   const Tensor& values);
+
+   // Similar to the function above but instead checks eligibility for the
+   // Import function.
+   virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys,
+                                                   const Tensor& values);
+
+   // Checks the arguments of a find operation. Returns OK if all of the
+   // following requirements are satisfied, otherwise returns
+   // InvalidArgument:
+   // - the DataType of the keys tensor equals the table's key_dtype;
+   // - the DataType of the default_value tensor equals the table's
+   //   value_dtype;
+   // - the default_value tensor shape matches the table's value shape.
+   Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);
+
+   string DebugString() override {
+     return strings::StrCat("A lookup table of size: ", size());
+   }
+
+   // Returns an InitializableLookupTable, a subclass of LookupInterface, if
+   // the current object is an InitializableLookupTable. Otherwise, returns
+   // nullptr.
+   virtual InitializableLookupTable* GetInitializableLookupTable() {
+     return nullptr;
+   }
+
+  protected:
+   virtual ~LookupInterface() = default;
+
+   // Makes sure that the key and value tensor DataTypes match the table's
+   // key_dtype and value_dtype.
+   Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values);
+
+   // Makes sure that the provided shape is consistent with the table's key
+   // shape.
+   Status CheckKeyShape(const TensorShape& shape);
+
+  private:
+   Status CheckKeyAndValueTensorsHelper(const Tensor& keys,
+                                        const Tensor& values);
+ };
+
+ }  // namespace lookup
+ }  // namespace tensorflow
+
+ #endif  // TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
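
As a sketch of the contract above (a hypothetical in-memory table, not part of this commit; ToyTable and its members are illustrative, and <unordered_map> plus the errors header are assumed to be included), a concrete subclass mainly supplies Find/Insert and the dtype/shape accessors:

// Sketch only: a scalar string->int64 table backed by std::unordered_map.
// Assumes single-threaded use; a real table would need synchronization.
class ToyTable : public lookup::LookupInterface {
 public:
  Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values,
              const Tensor& default_value) override {
    const auto keys_flat = keys.flat<string>();
    auto values_flat = values->flat<int64>();
    const int64 missing = default_value.scalar<int64>()();
    for (int64 i = 0; i < keys_flat.size(); ++i) {
      auto it = table_.find(keys_flat(i));
      values_flat(i) = (it == table_.end()) ? missing : it->second;
    }
    return Status::OK();
  }

  Status Insert(OpKernelContext* ctx, const Tensor& keys,
                const Tensor& values) override {
    TF_RETURN_IF_ERROR(CheckKeyAndValueTensorsForInsert(keys, values));
    const auto keys_flat = keys.flat<string>();
    const auto values_flat = values.flat<int64>();
    for (int64 i = 0; i < keys_flat.size(); ++i) {
      table_[keys_flat(i)] = values_flat(i);
    }
    return Status::OK();
  }

  Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
                      const Tensor& values) override {
    return Insert(ctx, keys, values);
  }

  // Export is left out of the sketch; a real table would allocate two
  // output tensors through ctx and copy the map contents into them.
  Status ExportValues(OpKernelContext* ctx) override {
    return errors::Unimplemented("ExportValues is not part of this sketch");
  }

  size_t size() const override { return table_.size(); }
  DataType key_dtype() const override { return DT_STRING; }
  DataType value_dtype() const override { return DT_INT64; }
  TensorShape key_shape() const override { return TensorShape(); }
  TensorShape value_shape() const override { return TensorShape(); }

 private:
  std::unordered_map<string, int64> table_;
};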
memory_types.cc ADDED
@@ -0,0 +1,156 @@
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "tensorflow/core/framework/memory_types.h"
+
+ #include <utility>
+
+ #include "tensorflow/core/framework/kernel_def.pb.h"
+ #include "tensorflow/core/framework/node_def.pb.h"
+ #include "tensorflow/core/framework/node_def_util.h"
+ #include "tensorflow/core/framework/op_kernel.h"
+ #include "tensorflow/core/framework/types.h"
+ #include "tensorflow/core/lib/core/errors.h"
+ #include "tensorflow/core/platform/types.h"
+
+ namespace tensorflow {
+
+ namespace {
+ // Returns the largest endpoint of anything in the name_map.
+ int GetTotal(const NameRangeMap& name_map) {
+   int total = 0;
+   for (const auto& item : name_map) {
+     total = std::max(total, item.second.second);
+   }
+   return total;
+ }
+
+ // Marks the endpoints named in host_memory_args as HOST_MEMORY in
+ // memory_types (the caller has already filled memory_types with the
+ // DEVICE_MEMORY default). Removes the elements of host_memory_args that
+ // were matched, keeping the rest for a later pass over the other name map.
+ void MemoryTypesHelper(const NameRangeMap& name_map,
+                        std::vector<string>* host_memory_args,
+                        MemoryTypeVector* memory_types) {
+   // Update args that have been marked as being in HOST_MEMORY.
+   size_t keep = 0;
+   for (size_t i = 0; i < host_memory_args->size(); ++i) {
+     auto iter = name_map.find((*host_memory_args)[i]);
+     if (iter != name_map.end()) {
+       for (int j = iter->second.first; j < iter->second.second; ++j) {
+         (*memory_types)[j] = HOST_MEMORY;
+       }
+     } else {
+       // (*host_memory_args)[i] was not found; save it for the next pass.
+       if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i];
+       ++keep;
+     }
+   }
+   host_memory_args->resize(keep);
+ }
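+
+ // Example: if name_map maps "indices" to the half-open endpoint range
+ // [1, 3) and host_memory_args is {"indices", "foo"}, entries 1 and 2 of
+ // memory_types become HOST_MEMORY and host_memory_args is compacted to
+ // {"foo"} for the caller to match against the other name map.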
+
+ MemoryType MTypeFromDType(const DataType dtype) {
+   return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY
+                                                             : DEVICE_MEMORY;
+ }
+
+ }  // namespace
+
+ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
+                           const DeviceType& device_type, const NodeDef& ndef,
+                           MemoryTypeVector* inp_mtypes,
+                           MemoryTypeVector* out_mtypes) {
+   // Look up the Op registered for this op name.
+   const OpDef* op_def;
+   TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(ndef.op(), &op_def));
+
+   // Look up the Kernel registered for this node def.
+   const KernelDef* kdef = nullptr;
+   Status status =
+       FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */);
+
+   DataTypeVector inp_dtypes;
+   DataTypeVector out_dtypes;
+   TF_RETURN_IF_ERROR(
+       InOutTypesForNode(ndef, *op_def, &inp_dtypes, &out_dtypes));
+
+   inp_mtypes->clear();
+   out_mtypes->clear();
+
+   // For functions (which have no KernelDef) and their gradients, we can
+   // only derive the memory type from the data type on a best-effort basis.
+   // For now, we assume that int32 and the types for which
+   // DataTypeAlwaysOnHost() is true live in host memory, and that every
+   // other type lives in device memory.
+   // TODO(zhifengc,phawkins): We should do type inference over function
+   // bodies to derive the correct input/output memory types. We should also
+   // split host-memory and non host-memory arguments into separate type
+   // lists.
+   if (!status.ok() || ndef.op() == "SymbolicGradient") {
+     for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t));
+     for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t));
+     return Status::OK();
+   }
+
+   // Gets the input/output names and their corresponding endpoint ranges.
+   NameRangeMap inp_names;
+   NameRangeMap out_names;
+   TF_RETURN_IF_ERROR(NameRangesForNode(ndef, *op_def, &inp_names, &out_names));
+
+   // Now that we know the sizes, fill with the default 'DEVICE_MEMORY'.
+   inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY);
+   out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY);
+
+   // Fill in host memory types based on the kernel def.
+   const auto& from_proto = kdef->host_memory_arg();
+   std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
+   MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes);
+   MemoryTypesHelper(out_names, &host_memory_args, out_mtypes);
+   if (!host_memory_args.empty()) {
+     return errors::InvalidArgument(
+         "HostMemory args '", str_util::Join(host_memory_args, "', '"),
+         "' not found in OpDef: ", SummarizeOpDef(*op_def));
+   }
+   CHECK_LE(inp_mtypes->size(), inp_dtypes.size());
+   CHECK_LE(out_mtypes->size(), out_dtypes.size());
+
+   // Mark types that must always live on the host (e.g. resource and string
+   // tensors) as host memory.
+   for (int i = 0; i < inp_mtypes->size(); ++i) {
+     if (DataTypeAlwaysOnHost(inp_dtypes[i])) {
+       (*inp_mtypes)[i] = HOST_MEMORY;
+     }
+   }
+   for (int i = 0; i < out_mtypes->size(); ++i) {
+     if (DataTypeAlwaysOnHost(out_dtypes[i])) {
+       (*out_mtypes)[i] = HOST_MEMORY;
+     }
+   }
+
+   // A node may also force individual endpoints into host memory via the
+   // _input_hostmem/_output_hostmem attrs, which list endpoint indices.
+   std::vector<int32> hostmem_attr;
+   if (GetNodeAttr(ndef, "_input_hostmem", &hostmem_attr).ok()) {
+     for (int32 i : hostmem_attr) {
+       if (0 <= i && i < inp_mtypes->size()) {
+         (*inp_mtypes)[i] = HOST_MEMORY;
+       }
+     }
+   }
+   if (GetNodeAttr(ndef, "_output_hostmem", &hostmem_attr).ok()) {
+     for (int32 i : hostmem_attr) {
+       if (0 <= i && i < out_mtypes->size()) {
+         (*out_mtypes)[i] = HOST_MEMORY;
+       }
+     }
+   }
+
+   return Status::OK();
+ }
+
+ }  // namespace tensorflow
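
A minimal sketch of calling this function (hypothetical surrounding code, not part of this commit; assumes a NodeDef `ndef` has already been built, e.g. via NodeDefBuilder):

// Ask where each input/output of a node lives when placed on a GPU device.
MemoryTypeVector inp_mtypes;
MemoryTypeVector out_mtypes;
Status s = MemoryTypesForNode(OpRegistry::Global(), DeviceType(DEVICE_GPU),
                              ndef, &inp_mtypes, &out_mtypes);
if (s.ok()) {
  for (size_t i = 0; i < inp_mtypes.size(); ++i) {
    LOG(INFO) << "input " << i << " is in "
              << (inp_mtypes[i] == HOST_MEMORY ? "host" : "device")
              << " memory";
  }
}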