diff --git a/allocation_description.proto b/allocation_description.proto new file mode 100644 index 0000000000000000000000000000000000000000..bb1037c2dfe46a28865485cadbd5d5dcf1974d84 --- /dev/null +++ b/allocation_description.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AllocationDescriptionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +message AllocationDescription { + // Total number of bytes requested + int64 requested_bytes = 1; + + // Total number of bytes allocated if known + int64 allocated_bytes = 2; + + // Name of the allocator used + string allocator_name = 3; + + // Identifier of the allocated buffer if known + int64 allocation_id = 4; + + // Set if this tensor only has one remaining reference + bool has_single_reference = 5; + + // Address of the allocation. + uint64 ptr = 6; +}; diff --git a/allocator.cc b/allocator.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5dadf76daf8d351e509c4ae538b31abf00d9566 --- /dev/null +++ b/allocator.cc @@ -0,0 +1,130 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/allocator.h" + +#include "tensorflow/core/framework/allocator_registry.h" +#include "tensorflow/core/framework/log_memory.h" +#include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +void AllocatorStats::Clear() { + this->num_allocs = 0; + this->bytes_in_use = 0; + this->max_bytes_in_use = 0; + this->max_alloc_size = 0; + this->bytes_limit = 0; +} + +string AllocatorStats::DebugString() const { + return strings::Printf( + "Limit: %20lld\n" + "InUse: %20lld\n" + "MaxInUse: %20lld\n" + "NumAllocs: %20lld\n" + "MaxAllocSize: %20lld\n", + this->bytes_limit, this->bytes_in_use, this->max_bytes_in_use, + this->num_allocs, this->max_alloc_size); +} + +constexpr size_t Allocator::kAllocatorAlignment; + +Allocator::~Allocator() {} + +void RunResourceCtor(ResourceHandle* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) new (p) ResourceHandle(); +} + +void RunResourceDtor(ResourceHandle* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); +} + +// If true, cpu allocator collects more stats. +static bool cpu_allocator_collect_stats = false; +// If true, cpu allocator collects full stats. 
+static bool cpu_allocator_collect_full_stats = false; + +void EnableCPUAllocatorStats(bool enable) { + cpu_allocator_collect_stats = enable; +} +void EnableCPUAllocatorFullStats(bool enable) { + cpu_allocator_collect_full_stats = enable; +} + +class CPUAllocator : public Allocator { + public: + CPUAllocator() {} + + ~CPUAllocator() override {} + + string Name() override { return "cpu"; } + + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + void* p = port::AlignedMalloc(num_bytes, alignment); + if (cpu_allocator_collect_stats) { + const std::size_t alloc_size = port::MallocExtension_GetAllocatedSize(p); + mutex_lock l(mu_); + ++stats_.num_allocs; + stats_.bytes_in_use += alloc_size; + stats_.max_bytes_in_use = + std::max(stats_.max_bytes_in_use, stats_.bytes_in_use); + stats_.max_alloc_size = + std::max(stats_.max_alloc_size, alloc_size); + } + return p; + } + + void DeallocateRaw(void* ptr) override { + if (cpu_allocator_collect_stats) { + const std::size_t alloc_size = + port::MallocExtension_GetAllocatedSize(ptr); + mutex_lock l(mu_); + stats_.bytes_in_use -= alloc_size; + } + port::AlignedFree(ptr); + } + + void GetStats(AllocatorStats* stats) override { + mutex_lock l(mu_); + *stats = stats_; + } + + size_t AllocatedSizeSlow(void* ptr) override { + return port::MallocExtension_GetAllocatedSize(ptr); + } + + private: + mutex mu_; + AllocatorStats stats_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator); +}; + +Allocator* cpu_allocator() { + static Allocator* cpu_alloc = AllocatorRegistry::Global()->GetAllocator(); + if (cpu_allocator_collect_full_stats && !cpu_alloc->TracksAllocationSizes()) { + cpu_alloc = new TrackingAllocator(cpu_alloc, true); + } + return cpu_alloc; +} + +REGISTER_MEM_ALLOCATOR("DefaultCPUAllocator", 100, CPUAllocator); + +} // namespace tensorflow diff --git a/allocator.h b/allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..5e048a028d2dd9bf60722c3bab6a81330a16d2d8 --- /dev/null +++ b/allocator.h @@ -0,0 +1,394 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_ALLOCATOR_H_ +#define TENSORFLOW_FRAMEWORK_ALLOCATOR_H_ + +#include + +#include + +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Attributes for a single allocation call. Different calls to the same +// allocator could potentially have different allocation attributes. +struct AllocationAttributes { + // If the first attempt to allocate the memory fails, the allocation + // should return immediately without retrying. + // An example use case is optional scratch spaces where a failure + // has only performance impact. 
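+  // Illustrative sketch (not part of this header's contract; `allocator` and
+  // `scratch_bytes` are hypothetical): a kernel requesting optional scratch
+  // space might do
+  //   AllocationAttributes attr;
+  //   attr.no_retry_on_failure = true;
+  //   void* scratch = allocator->AllocateRaw(Allocator::kAllocatorAlignment,
+  //                                          scratch_bytes, attr);
+  //   if (scratch == nullptr) { /* take a slower, allocation-free path */ }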
+ bool no_retry_on_failure = false; + // If a Tensor is allocated without the following set to true, then + // it is logged as an unknown allocation. During execution Tensors + // should be allocated through the OpKernelContext which records + // which Op is performing the allocation, and sets this flag to + // true. + bool allocation_will_be_logged = false; +}; + +// Runtime statistics collected by an allocator. +struct AllocatorStats { + int64 num_allocs; // Number of allocations. + int64 bytes_in_use; // Number of bytes in use. + int64 max_bytes_in_use; // The maximum bytes in use. + int64 max_alloc_size; // The max single allocation seen. + + // The upper limit what the allocator can allocate, if such a limit + // is known. Certain allocator may return 0 to indicate the limit is + // unknown. + int64 bytes_limit; + + AllocatorStats() { Clear(); } + + void Clear(); + string DebugString() const; +}; + +// Allocator is an abstract interface for allocating and deallocating +// device memory. +class Allocator { + public: +#ifdef EIGEN_VECTORIZE_AVX512 + // Align to 64 byte boundary. + static constexpr size_t kAllocatorAlignment = 64; +#else + // Align to 32 byte boundary. + static constexpr size_t kAllocatorAlignment = 32; +#endif + + virtual ~Allocator(); + + // Return a string identifying this allocator + virtual string Name() = 0; + + // Return an uninitialized block of memory that is "num_bytes" bytes + // in size. The returned pointer is guaranteed to be aligned to a + // multiple of "alignment" bytes. + // REQUIRES: "alignment" is a power of 2. + virtual void* AllocateRaw(size_t alignment, size_t num_bytes) = 0; + + // Return an uninitialized block of memory that is "num_bytes" bytes + // in size with specified allocation attributes. The returned pointer is + // guaranteed to be aligned to a multiple of "alignment" bytes. + // REQUIRES: "alignment" is a power of 2. + virtual void* AllocateRaw(size_t alignment, size_t num_bytes, + const AllocationAttributes& allocation_attr) { + // The default behavior is to use the implementation without any allocation + // attributes. + return AllocateRaw(alignment, num_bytes); + } + + // Deallocate a block of memory pointer to by "ptr" + // REQUIRES: "ptr" was previously returned by a call to AllocateRaw + virtual void DeallocateRaw(void* ptr) = 0; + + // Convenience functions to do typed allocation. C++ constructors + // and destructors are invoked for complex types if necessary, + // depending on the concrete Allocator implementation. May return + // NULL if the tensor has too many elements to represent in a single + // allocation. + template + T* Allocate(size_t num_elements) { + return Allocate(num_elements, AllocationAttributes()); + } + + template + T* Allocate(size_t num_elements, + const AllocationAttributes& allocation_attr) { + // TODO(jeff): Do we need to allow clients to pass in alignment + // requirements? + + if (num_elements > (std::numeric_limits::max() / sizeof(T))) { + return NULL; + } + + void* p = AllocateRaw(kAllocatorAlignment, sizeof(T) * num_elements, + allocation_attr); + T* typed_p = reinterpret_cast(p); + if (typed_p) RunCtor(typed_p, num_elements); + return typed_p; + } + + template + void Deallocate(T* ptr, size_t num_elements) { + if (ptr) { + RunDtor(ptr, num_elements); + DeallocateRaw(ptr); + } + } + + // Returns true if this allocator tracks the sizes of allocations. + // RequestedSize and AllocatedSize must be overridden if + // TracksAllocationSizes is overridden to return true. 
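+  // (Illustrative note describing a typical override, not a requirement of
+  // this interface: a size-tracking allocator would return true here and
+  // keep a mutex-guarded map from pointer to its requested and allocated
+  // byte counts so that RequestedSize(), AllocatedSize() and AllocationId()
+  // below can answer lookups.)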
+ virtual bool TracksAllocationSizes() { return false; } + + // Returns true if this allocator requires tensors with 0 elements + // to allocate buffers. This is false for most allocators, but may + // be used by special-case allocators that want to track tensor + // usage. + virtual bool ShouldAllocateEmptyTensors() { return false; } + + // Returns the user-requested size of the data allocated at + // 'ptr'. Note that the actual buffer allocated might be larger + // than requested, but this function returns the size requested by + // the user. + // + // REQUIRES: TracksAllocationSizes() is true. + // + // REQUIRES: 'ptr!=nullptr' and points to a buffer previously + // allocated by this allocator. + virtual size_t RequestedSize(void* ptr) { + CHECK(false) << "allocator doesn't track sizes"; + return size_t(0); + } + + // Returns the allocated size of the buffer at 'ptr' if known, + // otherwise returns RequestedSize(ptr). AllocatedSize(ptr) is + // guaranteed to be >= RequestedSize(ptr). + // + // REQUIRES: TracksAllocationSizes() is true. + // + // REQUIRES: 'ptr!=nullptr' and points to a buffer previously + // allocated by this allocator. + virtual size_t AllocatedSize(void* ptr) { return RequestedSize(ptr); } + + // Returns either 0 or an identifier assigned to the buffer at 'ptr' + // when the buffer was returned by AllocateRaw. If non-zero, the + // identifier differs from every other ID assigned by this + // allocator. + // + // REQUIRES: TracksAllocationSizes() is true. + // + // REQUIRES: 'ptr!=nullptr' and points to a buffer previously + // allocated by this allocator. + virtual int64 AllocationId(void* ptr) { return 0; } + + // Returns the allocated size of the buffer at 'ptr' if known, + // otherwise returns 0. This method can be called when + // TracksAllocationSizes() is false, but can be extremely slow. + // + // REQUIRES: 'ptr!=nullptr' and points to a buffer previously + // allocated by this allocator. + virtual size_t AllocatedSizeSlow(void* ptr) { + if (TracksAllocationSizes()) { + return AllocatedSize(ptr); + } + return 0; + } + + // Fills in 'stats' with statistics collected by this allocator. + virtual void GetStats(AllocatorStats* stats) { stats->Clear(); } + + private: + // No constructors or destructors are run for simple types + template + void RunCtor(T* p, size_t n) { + static_assert(is_simple_type::value, "T is not a simple type."); + } + + template + void RunDtor(T* p, size_t n) {} + + // custom constructors and destructors that can be overridden for + // non-standard allocators + + // Runs string's default constructor for p[0], p[1], ..., p[n-1]. + virtual void RunStringCtor(string* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) new (p) string(); + } + + // Runs string's default destructor for p[0], p[1], ..., p[n-1]. + virtual void RunStringDtor(string* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) p->~string(); + } + + virtual void RunResourceCtor(ResourceHandle* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) new (p) ResourceHandle(); + } + + // Runs string's default destructor for p[0], p[1], ..., p[n-1]. 
+ virtual void RunResourceDtor(ResourceHandle* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); + } + + virtual void RunVariantCtor(Variant* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) new (p) Variant(); + } + + virtual void RunVariantDtor(Variant* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) p->~Variant(); + } + + // TODO(jeff): Maybe provide some interface to give info about + // current allocation state (total number of bytes available for + // allocation, number of bytes free on device, etc.) +}; + +// Allocator-specific constructors and destructors are used for +// strings +template <> +inline void Allocator::RunCtor(string* p, size_t n) { + RunStringCtor(p, n); +} + +template <> +inline void Allocator::RunDtor(string* p, size_t n) { + RunStringDtor(p, n); +} + +template <> +inline void Allocator::RunCtor(ResourceHandle* p, size_t n) { + RunResourceCtor(p, n); +} + +template <> +inline void Allocator::RunDtor(ResourceHandle* p, size_t n) { + RunResourceDtor(p, n); +} + +template <> +inline void Allocator::RunCtor(Variant* p, size_t n) { + RunVariantCtor(p, n); +} + +template <> +inline void Allocator::RunDtor(Variant* p, size_t n) { + RunVariantDtor(p, n); +} + +// An implementation of Allocator that delegates all calls to another Allocator. +// +// Useful to clients who want to override part of the functionality of another +// allocator. +class AllocatorWrapper : public Allocator { + public: + explicit AllocatorWrapper(Allocator* wrapped) : wrapped_(wrapped) {} + + ~AllocatorWrapper() override {} + + // Returns the wrapped allocator to which all calls are delegated. + Allocator* wrapped() const { return wrapped_; } + + string Name() override { return wrapped_->Name(); } + + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + return wrapped_->AllocateRaw(alignment, num_bytes); + } + + void* AllocateRaw(size_t alignment, size_t num_bytes, + const AllocationAttributes& allocation_attr) override { + return wrapped_->AllocateRaw(alignment, num_bytes, allocation_attr); + } + + void DeallocateRaw(void* ptr) override { wrapped_->DeallocateRaw(ptr); } + + bool TracksAllocationSizes() override { + return wrapped_->TracksAllocationSizes(); + } + + bool ShouldAllocateEmptyTensors() override { + return wrapped_->TracksAllocationSizes(); + } + + size_t RequestedSize(void* ptr) override { + return wrapped_->RequestedSize(ptr); + } + + size_t AllocatedSize(void* ptr) override { + return wrapped_->AllocatedSize(ptr); + } + + int64 AllocationId(void* ptr) override { return wrapped_->AllocationId(ptr); } + + size_t AllocatedSizeSlow(void* ptr) override { + return wrapped_->AllocatedSizeSlow(ptr); + } + + private: + Allocator* const wrapped_; +}; + +// A tensorflow Op may need access to different kinds of memory that +// are not simply a function of the device to which the Op has been +// assigned. For example, an Op executing on a GPU may still need +// to allocate CPU RAM for some purpose. Internal to the tensorflow +// runtime we may choose to allocate CPU ram from special regions +// that have been prepared for higher performance in some use +// contexts, e.g. doing DMA with particular devices. For these +// reasons, the Device interface does not expose just one memory +// Allocator, but instead provides an accessor that takes a +// specification of the desired memory attributes in order to select +// an Allocator. +// +// Example use: +// // Allocator for ordinary device memory: +// Allocator* a = allocator(AllocatorAttributes()); +// ... 
+// // Allocator for CPU RAM, regardless of where Op is executing: +// AllocatorAttributes attr; +// attr.set_on_host(true); +// Allocator* a = allocator(attr); +struct AllocatorAttributes { + void set_on_host(bool v) { value |= (static_cast(v)); } + bool on_host() const { return value & 0x1; } + void set_nic_compatible(bool v) { value |= (static_cast(v) << 1); } + bool nic_compatible() const { return value & (0x1 << 1); } + void set_gpu_compatible(bool v) { value |= (static_cast(v) << 2); } + bool gpu_compatible() const { return value & (0x1 << 2); } + void Merge(AllocatorAttributes other) { value |= other.value; } + // Returns true if the fields set in *this is a subset of or equal to + // those set in other. + bool IsEqualOrLessRestrictiveThan(const AllocatorAttributes& other) const { + return (value | other.value) == other.value; + } + + // NOTE: The upper 8 bits of the value are reserved for + // device-specific uses. Implementors of a device can interpret these + // upper 8 bits in device-specific ways, and ops implemented for those + // devices are responsible for setting those 8 bits appropriately. + uint32 value = 0; +}; + +// Returns a trivial implementation of Allocator which uses the system +// default malloc. The returned allocator is a process singleton. +Allocator* cpu_allocator(); + +// If 'enable' is true, the process-wide cpu allocator collects +// AllocatorStats. By default, it's disabled. +void EnableCPUAllocatorStats(bool enable); + +// If 'enable' is true, the process-wide cpu allocator collects full +// statistics. By default, it's disabled. +void EnableCPUAllocatorFullStats(bool enable); + +// Abstract interface of an object that does the underlying suballoc/free of +// memory for a higher-level allocator. +class SubAllocator { + public: + virtual ~SubAllocator() {} + virtual void* Alloc(size_t alignment, size_t num_bytes) = 0; + virtual void Free(void* ptr, size_t num_bytes) = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_ALLOCATOR_H_ diff --git a/allocator_registry.cc b/allocator_registry.cc new file mode 100644 index 0000000000000000000000000000000000000000..486be39ae31c487560efebc79e0fbab90ddca9db --- /dev/null +++ b/allocator_registry.cc @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/core/framework/allocator_registry.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// static +AllocatorRegistry* AllocatorRegistry::Global() { + static AllocatorRegistry* global_allocator_registry = new AllocatorRegistry; + return global_allocator_registry; +} + +Allocator* AllocatorRegistry::GetRegisteredAllocator(const string& name, + int priority) { + for (auto entry : allocators_) { + if (!name.compare(entry.name) && priority == entry.priority) { + return entry.allocator; + } + } + return nullptr; +} + +void AllocatorRegistry::Register(const string& name, int priority, + Allocator* allocator) { + CHECK(!name.empty()) << "Need a valid name for Allocator"; + CHECK_GE(priority, 0) << "Priority needs to be non-negative"; + + Allocator* existing = GetRegisteredAllocator(name, priority); + if (existing != nullptr) { + // A duplicate is if the registration name and priority match + // but the Allocator::Name()'s don't match. + CHECK_EQ(existing->Name(), allocator->Name()) + << "Allocator with name: [" << name << "], type [" << existing->Name() + << "], priority: [" << priority + << "] already registered. Choose a different name to register " + << "an allocator of type " << allocator->Name(); + + // The allocator names match, so we can just return. + // It should be safe to delete the allocator since the caller + // gives up ownership of it. + delete allocator; + return; + } + + AllocatorRegistryEntry tmp_entry; + tmp_entry.name = name; + tmp_entry.priority = priority; + tmp_entry.allocator = allocator; + + allocators_.push_back(tmp_entry); + int high_pri = -1; + for (auto entry : allocators_) { + if (high_pri < entry.priority) { + m_curr_allocator_ = entry.allocator; + high_pri = entry.priority; + } + } +} + +Allocator* AllocatorRegistry::GetAllocator() { + return CHECK_NOTNULL(m_curr_allocator_); +} + +} // namespace tensorflow diff --git a/allocator_registry.h b/allocator_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..b26e79ac3b01c7b3fe5099f626c4e35862586282 --- /dev/null +++ b/allocator_registry.h @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Classes to maintain a static registry of memory allocators +#ifndef TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_ +#define TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_ + +#include +#include + +#include "tensorflow/core/framework/allocator.h" + +namespace tensorflow { + +// A global AllocatorRegistry is used to hold allocators for CPU backends +class AllocatorRegistry { + public: + // Add an allocator to the registry. Caller releases ownership of + // 'allocator'. 
+ void Register(const string& name, int priority, Allocator* allocator); + + // Return allocator with highest priority + // If multiple allocators have the same high priority, return one of them + Allocator* GetAllocator(); + + // Returns the global registry of allocators. + static AllocatorRegistry* Global(); + + private: + typedef struct { + string name; + int priority; + Allocator* allocator; // not owned + } AllocatorRegistryEntry; + + // Returns the Allocator registered for 'name' and 'priority', + // or 'nullptr' if not found. + Allocator* GetRegisteredAllocator(const string& name, int priority); + + std::vector allocators_; + Allocator* m_curr_allocator_; // not owned +}; + +namespace allocator_registration { + +class AllocatorRegistration { + public: + AllocatorRegistration(const string& name, int priority, + Allocator* allocator) { + AllocatorRegistry::Global()->Register(name, priority, allocator); + } +}; + +} // namespace allocator_registration + +#define REGISTER_MEM_ALLOCATOR(name, priority, allocator) \ + REGISTER_MEM_ALLOCATOR_UNIQ_HELPER(__COUNTER__, name, priority, allocator) + +#define REGISTER_MEM_ALLOCATOR_UNIQ_HELPER(ctr, name, priority, allocator) \ + REGISTER_MEM_ALLOCATOR_UNIQ(ctr, name, priority, allocator) + +#define REGISTER_MEM_ALLOCATOR_UNIQ(ctr, name, priority, allocator) \ + static allocator_registration::AllocatorRegistration \ + register_allocator_##ctr(name, priority, new allocator) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_ diff --git a/allocator_test.cc b/allocator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..032aeec161bb6978cb942747d3e0f8cff12f8853 --- /dev/null +++ b/allocator_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/allocator.h" + +#include +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +static void CheckStats(Allocator* a, int64 num_allocs, int64 bytes_in_use, + int64 max_bytes_in_use, int64 max_alloc_size) { + AllocatorStats stats; + a->GetStats(&stats); + LOG(INFO) << "Alloc stats: \n" << stats.DebugString(); +#if defined(PLATFORM_GOOGLE) && defined(NDEBUG) + // NOTE: allocator stats expectation depends on the system malloc, + // and can vary as that changes. 
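+  // The byte counters are therefore compared against a +/- kSlop band
+  // around the expected values below rather than checked for equality.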
+ static const int64 kSlop = 5 * 1024; + EXPECT_GT(stats.bytes_in_use, bytes_in_use - kSlop); + EXPECT_LT(stats.bytes_in_use, bytes_in_use + kSlop); + EXPECT_GT(stats.max_bytes_in_use, max_bytes_in_use - kSlop); + EXPECT_LT(stats.max_bytes_in_use, max_bytes_in_use + kSlop); + EXPECT_EQ(stats.num_allocs, num_allocs); + EXPECT_EQ(stats.max_alloc_size, max_alloc_size); +#endif +} + +TEST(AllocatorAttributesTest, AllCombos) { + for (bool on_host : {false, true}) { + for (bool nic_compatible : {false, true}) { + for (bool gpu_compatible : {false, true}) { + AllocatorAttributes aa; + aa.set_on_host(on_host); + aa.set_nic_compatible(nic_compatible); + aa.set_gpu_compatible(gpu_compatible); + EXPECT_EQ(on_host, aa.on_host()); + EXPECT_EQ(nic_compatible, aa.nic_compatible()); + EXPECT_EQ(gpu_compatible, aa.gpu_compatible()); + } + } + } +} + +TEST(AllocatorAttributesTest, IsEqualOrLessRestrictiveThan) { + AllocatorAttributes a, b; + EXPECT_TRUE(a.IsEqualOrLessRestrictiveThan(b)); + EXPECT_TRUE(a.IsEqualOrLessRestrictiveThan(a)); + EXPECT_TRUE(b.IsEqualOrLessRestrictiveThan(b)); + + b.set_gpu_compatible(true); + // The set of flags in b is not a subset of those in a. + EXPECT_TRUE(a.IsEqualOrLessRestrictiveThan(b)); + EXPECT_FALSE(b.IsEqualOrLessRestrictiveThan(a)); + EXPECT_TRUE(a.IsEqualOrLessRestrictiveThan(a)); + EXPECT_TRUE(b.IsEqualOrLessRestrictiveThan(b)); + + a.set_nic_compatible(true); + // Neither a nor b is a subset of the other. + EXPECT_FALSE(a.IsEqualOrLessRestrictiveThan(b)); + EXPECT_FALSE(b.IsEqualOrLessRestrictiveThan(a)); + + a.set_gpu_compatible(true); + // The set of flags in b is a proper subset of those in a. + EXPECT_TRUE(b.IsEqualOrLessRestrictiveThan(a)); + EXPECT_FALSE(a.IsEqualOrLessRestrictiveThan(b)); +} + +TEST(CPUAllocatorTest, Simple) { + EnableCPUAllocatorStats(true); + Allocator* a = cpu_allocator(); + std::vector ptrs; + for (int s = 1; s < 1024; s++) { + void* raw = a->AllocateRaw(1, s); + ptrs.push_back(raw); + } + std::sort(ptrs.begin(), ptrs.end()); + CheckStats(a, 1023, 552640, 552640, 1024); + for (size_t i = 0; i < ptrs.size(); i++) { + if (i > 0) { + CHECK_NE(ptrs[i], ptrs[i - 1]); // No dups + } + a->DeallocateRaw(ptrs[i]); + } + CheckStats(a, 1023, 0, 552640, 1024); + float* t1 = a->Allocate(1024); + double* t2 = a->Allocate(1048576); + CheckStats(a, 1025, 1048576 * sizeof(double) + 1024 * sizeof(float), + 1048576 * sizeof(double) + 1024 * sizeof(float), + 1048576 * sizeof(double)); + + a->Deallocate(t1, 1024); + a->Deallocate(t2, 1048576); + + CheckStats(a, 1025, 0, 1048576 * sizeof(double) + 1024 * sizeof(float), + 1048576 * sizeof(double)); + EnableCPUAllocatorStats(false); +} + +// Define a struct that we will use to observe behavior in the unit tests +struct TestStruct { + int x; // not used just want to make sure sizeof(TestStruct) > 1 +}; + +TEST(CPUAllocatorTest, CheckStructSize) { CHECK_GT(sizeof(TestStruct), 1); } + +TEST(CPUAllocatorTest, AllocateOverflowMaxSizeT) { + Allocator* a = cpu_allocator(); + + // The maximum size_t value will definitely overflow. + size_t count_to_allocate = std::numeric_limits::max(); + TestStruct* const test_pointer = a->Allocate(count_to_allocate); + + CHECK_EQ(test_pointer, reinterpret_cast(NULL)); +} + +TEST(CPUAllocatorTest, AllocateOverflowSmallest) { + Allocator* a = cpu_allocator(); + + // count_to_allocate is the smallest count that will cause overflow. 
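+  // (Worked through: with M = the maximum size_t value, requesting
+  // M / sizeof(TestStruct) + 1 elements pushes the requested byte count
+  // past M, so Allocate is expected to detect the overflow and return
+  // NULL rather than allocate.)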
+ const size_t count_to_allocate = + (std::numeric_limits::max() / sizeof(TestStruct)) + 1; + TestStruct* const test_pointer = a->Allocate(count_to_allocate); + + CHECK_EQ(test_pointer, reinterpret_cast(NULL)); +} + +TEST(CPUAllocatorTest, Sizes) { + Allocator* a = cpu_allocator(); + + EXPECT_EQ(false, a->TracksAllocationSizes()); +} + +namespace { + +AllocatorAttributes DeviceAllocatorAttribute() { + AllocatorAttributes attr; + attr.value |= (0x1 << 24); + return attr; +} + +bool HasDeviceAllocatorAttribute(const AllocatorAttributes& attr) { + return attr.value & (0x1 << 24); +} + +} // namespace + +TEST(CustomAllocatorAttributes, TestSetterAndGetter) { + AllocatorAttributes attr = DeviceAllocatorAttribute(); + EXPECT_TRUE(HasDeviceAllocatorAttribute(attr)); + EXPECT_FALSE(HasDeviceAllocatorAttribute(AllocatorAttributes())); +} + +static void BM_Allocation(int iters, int arg) { + Allocator* a = cpu_allocator(); + // Exercise a few different allocation sizes + std::vector sizes = {256, 4096, 16384, 524288, 512, 1048576}; + int size_index = 0; + + if (arg) EnableCPUAllocatorStats(true); + while (--iters > 0) { + int bytes = sizes[size_index++ % sizes.size()]; + void* p = a->AllocateRaw(1, bytes); + a->DeallocateRaw(p); + } + if (arg) EnableCPUAllocatorStats(false); +} +BENCHMARK(BM_Allocation)->Arg(0)->Arg(1); + +} // namespace tensorflow diff --git a/api_def.proto b/api_def.proto new file mode 100644 index 0000000000000000000000000000000000000000..98c38efc0e9a8e2ca7caf6b666c8930eb7a32733 --- /dev/null +++ b/api_def.proto @@ -0,0 +1,120 @@ +// Defines the text format for including per-op API definition and +// overrides for client language op code generators. + +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ApiDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +import "tensorflow/core/framework/attr_value.proto"; + +// Used to specify and override the default API & behavior in the +// generated code for client languages, from what you would get from +// the OpDef alone. There will be a set of ApiDefs that are common +// to all client languages, and another set per client language. +// The per-client-language ApiDefs will inherit values from the +// common ApiDefs which it can either replace or modify. +// +// We separate the API definition from the OpDef so we can evolve the +// API while remaining backwards compatible when interpretting old +// graphs. Overrides go in an "api_def.pbtxt" file with a text-format +// ApiDefs message. +// +// WARNING: Be *very* careful changing the API for any existing op -- +// you can change the semantics of existing code. These changes may +// need to wait until a major release of TensorFlow to avoid breaking +// our compatibility promises. +message ApiDef { + // Name of the op (in the OpDef) to specify the API for. + string graph_op_name = 1; + + enum Visibility { + // Normally this is "VISIBLE" unless you are inheriting a + // different value from another ApiDef. + DEFAULT_VISIBILITY = 0; + // Publicly visible in the API. + VISIBLE = 1; + // Do not include this op in the generated API. If visibility is + // set to 'SKIP', other fields are ignored for this op. + SKIP = 2; + // Hide this op by putting it into an internal namespace (or whatever + // is appropriate in the target language). + HIDDEN = 3; + } + Visibility visibility = 2; + + // If you specify any endpoint, this will replace all of the + // inherited endpoints. 
The first endpoint should be the + // "canonical" endpoint, and should not be deprecated (unless all + // endpoints are deprecated). + message Endpoint { + // Name should be either like "CamelCaseName" or + // "Package.CamelCaseName". Client-language-specific ApiDefs may + // use a snake_case convention instead of CamelCase. + string name = 1; + + // First GraphDef version at which the op is disallowed. + int32 deprecation_version = 2; + } + repeated Endpoint endpoint = 3; + + message Arg { + string name = 1; + + // Change the name used to access this arg in the API from what + // is used in the GraphDef. Note that these names in `backticks` + // will also be replaced in the summary & description fields. + string rename_to = 2; + + // Note: this will replace any inherited arg doc. There is no + // current way of modifying arg descriptions (other than replacing + // them entirely) as can be done with op descriptions. + string description = 3; + } + repeated Arg in_arg = 4; + repeated Arg out_arg = 5; + // List of original in_arg names to specify new argument order. + // Length of arg_order should be either empty to keep current order + // or match size of in_arg. + repeated string arg_order = 11; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message Attr { + string name = 1; + + // Change the name used to access this attr in the API from what + // is used in the GraphDef. Note that these names in `backticks` + // will also be replaced in the summary & description fields. + string rename_to = 2; + + // Specify a new default value to use for this attr. This default + // will be used when creating new graphs, as opposed to the + // default in the OpDef, which will be used when interpreting old + // GraphDefs. + AttrValue default_value = 3; + + // Note: this will replace any inherited attr doc, there is no current + // way of modifying attr descriptions as can be done with op descriptions. + string description = 4; + } + repeated Attr attr = 6; + + // One-line human-readable description of what the Op does. + string summary = 7; + + // Additional, longer human-readable description of what the Op does. + string description = 8; + + // Modify an existing/inherited description by adding text to the beginning + // or end. + string description_prefix = 9; + string description_suffix = 10; +} + +message ApiDefs { + repeated ApiDef op = 1; +} diff --git a/attr_value.proto b/attr_value.proto new file mode 100644 index 0000000000000000000000000000000000000000..62f0a9050fb82cda066ca77608c0b278e171c57c --- /dev/null +++ b/attr_value.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. 
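+//
+// For example (illustrative text-format values, not taken from a real
+// graph): an "int" attr with value 4 is written as "i: 4", and a
+// "list(string)" attr as "list { s: 'a' s: 'b' }".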
+message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/attr_value_util.cc b/attr_value_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..5aba091840ed0cd32bf85980c7d12dc74e7f3fd9 --- /dev/null +++ b/attr_value_util.cc @@ -0,0 +1,551 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/attr_value_util.h" + +#include +#include + +#include "tensorflow/core/framework/attr_value.pb_text.h" +#include "tensorflow/core/framework/tensor.pb_text.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb_text.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace { + +string SummarizeString(const string& str) { + return strings::StrCat("\"", str_util::CEscape(str), "\""); +} + +string SummarizeTensor(const TensorProto& tensor_proto) { + Tensor t; + if (!t.FromProto(tensor_proto)) { + return strings::StrCat( + ""); + } + return t.DebugString(); +} + +string SummarizeFunc(const NameAttrList& func) { + std::vector entries; + for (auto p : func.attr()) { + entries.push_back( + strings::StrCat(p.first, "=", SummarizeAttrValue(p.second))); + } + std::sort(entries.begin(), entries.end()); + return strings::StrCat(func.name(), "[", str_util::Join(entries, ", "), "]"); +} + +} // namespace + +string SummarizeAttrValue(const AttrValue& attr_value) { + switch (attr_value.value_case()) { + case AttrValue::kS: + return SummarizeString(attr_value.s()); + case AttrValue::kI: + return strings::StrCat(attr_value.i()); + case AttrValue::kF: + return strings::StrCat(attr_value.f()); + case AttrValue::kB: + return attr_value.b() ? "true" : "false"; + case AttrValue::kType: + return EnumName_DataType(attr_value.type()); + case AttrValue::kShape: + return PartialTensorShape::DebugString(attr_value.shape()); + case AttrValue::kTensor: + return SummarizeTensor(attr_value.tensor()); + case AttrValue::kList: { + string ret = "["; + if (attr_value.list().s_size() > 0) { + for (int i = 0; i < attr_value.list().s_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, SummarizeString(attr_value.list().s(i))); + } + } else if (attr_value.list().i_size() > 0) { + for (int i = 0; i < attr_value.list().i_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, attr_value.list().i(i)); + } + } else if (attr_value.list().f_size() > 0) { + for (int i = 0; i < attr_value.list().f_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, attr_value.list().f(i)); + } + } else if (attr_value.list().b_size() > 0) { + for (int i = 0; i < attr_value.list().b_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, attr_value.list().b(i) ? 
"true" : "false"); + } + } else if (attr_value.list().type_size() > 0) { + for (int i = 0; i < attr_value.list().type_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, + EnumName_DataType(attr_value.list().type(i))); + } + } else if (attr_value.list().shape_size() > 0) { + for (int i = 0; i < attr_value.list().shape_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend( + &ret, TensorShape::DebugString(attr_value.list().shape(i))); + } + } else if (attr_value.list().tensor_size() > 0) { + for (int i = 0; i < attr_value.list().tensor_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, + SummarizeTensor(attr_value.list().tensor(i))); + } + } else if (attr_value.list().func_size() > 0) { + for (int i = 0; i < attr_value.list().func_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, SummarizeFunc(attr_value.list().func(i))); + } + } + + strings::StrAppend(&ret, "]"); + return ret; + } + case AttrValue::kFunc: { + return SummarizeFunc(attr_value.func()); + } + case AttrValue::kPlaceholder: + return strings::StrCat("$", attr_value.placeholder()); + case AttrValue::VALUE_NOT_SET: + return ""; + } + return ""; // Prevent missing return warning +} + +Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) { + int num_set = 0; + +#define VALIDATE_FIELD(name, type_string, oneof_case) \ + do { \ + if (attr_value.has_list()) { \ + if (attr_value.list().name##_size() > 0) { \ + if (type != "list(" type_string ")") { \ + return errors::InvalidArgument( \ + "AttrValue had value with type 'list(" type_string ")' when '", \ + type, "' expected"); \ + } \ + ++num_set; \ + } \ + } else if (attr_value.value_case() == AttrValue::oneof_case) { \ + if (type != type_string) { \ + return errors::InvalidArgument( \ + "AttrValue had value with type '" type_string "' when '", type, \ + "' expected"); \ + } \ + ++num_set; \ + } \ + } while (false) + + VALIDATE_FIELD(s, "string", kS); + VALIDATE_FIELD(i, "int", kI); + VALIDATE_FIELD(f, "float", kF); + VALIDATE_FIELD(b, "bool", kB); + VALIDATE_FIELD(type, "type", kType); + VALIDATE_FIELD(shape, "shape", kShape); + VALIDATE_FIELD(tensor, "tensor", kTensor); + VALIDATE_FIELD(func, "func", kFunc); + +#undef VALIDATE_FIELD + + if (attr_value.value_case() == AttrValue::kPlaceholder) { + return errors::InvalidArgument( + "AttrValue had value with unexpected type 'placeholder'"); + } + + // If the attr type is 'list', we expect attr_value.has_list() to be + // true. However, proto3's attr_value.has_list() can be false when + // set to an empty list for GraphDef versions <= 4. So we simply + // check if has_list is false and some other field in attr_value is + // set to flag the error. This test can be made more strict once + // support for GraphDef versions <= 4 is dropped. + if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) { + if (num_set) { + return errors::InvalidArgument( + "AttrValue missing value with expected type '", type, "'"); + } else { + // Indicate that we have a list, but an empty one. + ++num_set; + } + } + + // Okay to have an empty list, but not to be missing a non-list value. + if (num_set == 0 && !StringPiece(type).starts_with("list(")) { + return errors::InvalidArgument( + "AttrValue missing value with expected type '", type, "'"); + } + + // Ref types and DT_INVALID are illegal, and DataTypes must + // be a valid enum type. 
+ if (type == "type") { + if (!DataType_IsValid(attr_value.type())) { + return errors::InvalidArgument("AttrValue has invalid DataType enum: ", + attr_value.type()); + } + if (IsRefType(attr_value.type())) { + return errors::InvalidArgument( + "AttrValue must not have reference type value of ", + DataTypeString(attr_value.type())); + } + if (attr_value.type() == DT_INVALID) { + return errors::InvalidArgument("AttrValue has invalid DataType"); + } + } else if (type == "list(type)") { + for (auto as_int : attr_value.list().type()) { + const DataType dtype = static_cast(as_int); + if (!DataType_IsValid(dtype)) { + return errors::InvalidArgument("AttrValue has invalid DataType enum: ", + as_int); + } + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "AttrValue must not have reference type value of ", + DataTypeString(dtype)); + } + if (dtype == DT_INVALID) { + return errors::InvalidArgument("AttrValue contains invalid DataType"); + } + } + } + + return Status::OK(); +} + +bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { + // Parse type. + string field_name; + bool is_list = type.Consume("list("); + if (type.Consume("string")) { + field_name = "s"; + } else if (type.Consume("int")) { + field_name = "i"; + } else if (type.Consume("float")) { + field_name = "f"; + } else if (type.Consume("bool")) { + field_name = "b"; + } else if (type.Consume("type")) { + field_name = "type"; + } else if (type.Consume("shape")) { + field_name = "shape"; + } else if (type.Consume("tensor")) { + field_name = "tensor"; + } else if (type.Consume("func")) { + field_name = "func"; + } else if (type.Consume("placeholder")) { + field_name = "placeholder"; + } else { + return false; + } + if (is_list && !type.Consume(")")) { + return false; + } + + // Construct a valid text proto message to parse. + string to_parse; + if (is_list) { + // TextFormat parser considers "i: 7" to be the same as "i: [7]", + // but we only want to allow list values with []. + StringPiece cleaned = text; + str_util::RemoveLeadingWhitespace(&cleaned); + str_util::RemoveTrailingWhitespace(&cleaned); + if (cleaned.size() < 2 || cleaned[0] != '[' || + cleaned[cleaned.size() - 1] != ']') { + return false; + } + cleaned.remove_prefix(1); + str_util::RemoveLeadingWhitespace(&cleaned); + if (cleaned.size() == 1) { + // User wrote "[]", so return empty list without invoking the TextFormat + // parse which returns an error for "i: []". 
+ out->Clear(); + out->mutable_list(); + return true; + } + to_parse = strings::StrCat("list { ", field_name, ": ", text, " }"); + } else { + to_parse = strings::StrCat(field_name, ": ", text); + } + + return ProtoParseFromString(to_parse, out); +} + +void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; } + +#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ + void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); } + +#define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD) \ + void SetAttrValue(ARG_TYPE value, AttrValue* out) { \ + out->mutable_list()->Clear(); /* create list() even if value empty */ \ + for (const auto& v : value) { \ + out->mutable_list()->add_##FIELD(v); \ + } \ + } + +#define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \ + DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ + DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice, FIELD) + +DEFINE_SET_ATTR_VALUE_ONE(const string&, s) +DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice, s) +DEFINE_SET_ATTR_VALUE_BOTH(const char*, s) +DEFINE_SET_ATTR_VALUE_BOTH(int64, i) +DEFINE_SET_ATTR_VALUE_BOTH(int32, i) +DEFINE_SET_ATTR_VALUE_BOTH(float, f) +DEFINE_SET_ATTR_VALUE_BOTH(double, f) +DEFINE_SET_ATTR_VALUE_BOTH(bool, b) +DEFINE_SET_ATTR_VALUE_LIST(const std::vector&, b) +DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list, b) +DEFINE_SET_ATTR_VALUE_BOTH(DataType, type) + +void SetAttrValue(StringPiece value, AttrValue* out) { + out->set_s(value.data(), value.size()); +} + +void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { + out->mutable_list()->Clear(); // Create list() even if value empty. + for (const auto& v : value) { + out->mutable_list()->add_s(v.data(), v.size()); + } +} + +void SetAttrValue(const TensorShape& value, AttrValue* out) { + value.AsProto(out->mutable_shape()); +} + +void SetAttrValue(const TensorShapeProto& value, AttrValue* out) { + *out->mutable_shape() = value; +} + +void SetAttrValue(const PartialTensorShape& value, AttrValue* out) { + value.AsProto(out->mutable_shape()); +} + +void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { + out->mutable_list()->Clear(); // Create list() even if value empty. + for (const auto& v : value) { + v.AsProto(out->mutable_list()->add_shape()); + } +} + +void SetAttrValue(gtl::ArraySlice value, AttrValue* out) { + out->mutable_list()->Clear(); // Create list() even if value empty. + for (const auto& v : value) { + *out->mutable_list()->add_shape() = v; + } +} + +void SetAttrValue(const gtl::ArraySlice value, + AttrValue* out) { + out->mutable_list()->Clear(); // Create list() even if value empty. + for (const auto& v : value) { + v.AsProto(out->mutable_list()->add_shape()); + } +} + +void SetAttrValue(const Tensor& value, AttrValue* out) { + if (value.NumElements() > 1) { + value.AsProtoTensorContent(out->mutable_tensor()); + } else { + value.AsProtoField(out->mutable_tensor()); + } +} + +void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { + out->mutable_list()->Clear(); // Create list() even if value empty. + for (const auto& v : value) { + if (v.NumElements() > 1) { + v.AsProtoTensorContent(out->mutable_list()->add_tensor()); + } else { + v.AsProtoField(out->mutable_list()->add_tensor()); + } + } +} + +void SetAttrValue(const TensorProto& value, AttrValue* out) { + *out->mutable_tensor() = value; +} + +void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { + out->mutable_list()->Clear(); // Create list() even if value empty. 
+ for (const auto& v : value) { + *out->mutable_list()->add_tensor() = v; + } +} + +void SetAttrValue(const NameAttrList& value, AttrValue* out) { + *out->mutable_func() = value; +} + +void SetAttrValue(gtl::ArraySlice value, AttrValue* out) { + out->mutable_list()->Clear(); // Create list() even if value empty. + for (const auto& v : value) { + *out->mutable_list()->add_func() = v; + } +} + +bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) { + // There are multiple equivalent representations of attr values containing + // TensorProtos. Compare them by constructing Tensors and serializing them + // back. Comparing Tensor objects is pretty tricky. + if (a.has_tensor() != b.has_tensor()) { + return false; + } else if (a.has_tensor() && b.has_tensor()) { + Tensor at(a.tensor().dtype()); + bool success = at.FromProto(a.tensor()); + DCHECK(success); + + Tensor bt(b.tensor().dtype()); + success = bt.FromProto(b.tensor()); + DCHECK(success); + + TensorProto ap; + at.AsProtoTensorContent(&ap); + + TensorProto bp; + bt.AsProtoTensorContent(&bp); + + string a_str, b_str; + SerializeToStringDeterministic(ap, &a_str); + SerializeToStringDeterministic(bp, &b_str); + return a_str == b_str; + } + + // `func` field contains a nested AttrValue. Compare such AttrValues + // recursively. + if (a.has_func() != b.has_func()) { + return false; + } else if (a.has_func() && b.has_func()) { + const NameAttrList& af = a.func(); + const NameAttrList& bf = b.func(); + if (af.name() != bf.name()) return false; + std::unordered_map am(af.attr().begin(), + af.attr().end()); + for (const auto& bm_pair : bf.attr()) { + const auto& iter = am.find(bm_pair.first); + if (iter == am.end()) return false; + if (!AreAttrValuesEqual(iter->second, bm_pair.second)) return false; + am.erase(iter); + } + if (!am.empty()) return false; + return true; + } + + // All other fields in AttrValue have deterministic representations. + // It is safe to compare their serialized strings. + string a_str, b_str; + SerializeToStringDeterministic(a, &a_str); + SerializeToStringDeterministic(b, &b_str); + return a_str == b_str; +} + +uint64 AttrValueHash(const AttrValue& a) { + if (a.has_tensor()) { + // Deal with multiple representations by parsing TensorProto to + // Tensor and serializing it back. This is slow, but current use case + // don't need high efficiency. + Tensor tensor(a.tensor().dtype()); + bool success = tensor.FromProto(a.tensor()); + DCHECK(success); + TensorProto p; + tensor.AsProtoTensorContent(&p); + string s; + SerializeToStringDeterministic(p, &s); + return Hash64(s); + } + if (a.has_func()) { + const NameAttrList& func = a.func(); + uint64 h = Hash64(func.name()); + std::map map(func.attr().begin(), func.attr().end()); + for (const auto& pair : map) { + h = Hash64(pair.first.data(), pair.first.size(), h); + h = Hash64Combine(AttrValueHash(pair.second), h); + } + return h; + } + + // If `a` is not a tensor or func, get a hash of serialized string. 
+ string s; + SerializeToStringDeterministic(a, &s); + return Hash64(s); +} + +bool HasPlaceHolder(const AttrValue& val) { + switch (val.value_case()) { + case AttrValue::kList: { + for (const NameAttrList& func : val.list().func()) { + for (const auto& p : func.attr()) { + if (HasPlaceHolder(p.second)) { + return true; + } + } + } + break; + } + case AttrValue::kFunc: + for (const auto& p : val.func().attr()) { + if (HasPlaceHolder(p.second)) { + return true; + } + } + break; + case AttrValue::kPlaceholder: + return true; + default: + break; + } + return false; +} + +bool SubstitutePlaceholders(const SubstituteFunc& substitute, + AttrValue* value) { + switch (value->value_case()) { + case AttrValue::kList: { + for (NameAttrList& func : *value->mutable_list()->mutable_func()) { + for (auto& p : *func.mutable_attr()) { + if (!SubstitutePlaceholders(substitute, &p.second)) { + return false; + } + } + } + break; + } + case AttrValue::kFunc: + for (auto& p : *(value->mutable_func()->mutable_attr())) { + if (!SubstitutePlaceholders(substitute, &p.second)) { + return false; + } + } + break; + case AttrValue::kPlaceholder: + return substitute(value->placeholder(), value); + case AttrValue::VALUE_NOT_SET: + return false; + default: + break; + } + return true; +} + +} // namespace tensorflow diff --git a/attr_value_util.h b/attr_value_util.h new file mode 100644 index 0000000000000000000000000000000000000000..29e34c5090ea9116de054ff7d3a9faf4ed2c30fb --- /dev/null +++ b/attr_value_util.h @@ -0,0 +1,116 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Forward declare protos so their symbols can be removed from .so exports +class AttrValue; +class NameAttrList; + +// A human-readable rendering of attr_value, that is more concise than a +// text-format proto. +string SummarizeAttrValue(const AttrValue& attr_value); + +// Generates an error if attr_value doesn't have the indicated attr type. +Status AttrValueHasType(const AttrValue& attr_value, StringPiece type); + +// Converts a text proto value from "text" into the field of *out +// indicated by "type" (e.g. from the type field of an AttrDef). +// Examples: +// * If type:"int" and text:"-14", then *out is set to "i: -14" +// * If type:"list(string)" and text:"['foo', 'bar']", +// then *out is set to "list { s: ['foo', 'bar'] }" +// Returns true on success. 
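+//
+// Illustrative usage (a sketch; the literal values are made up for the
+// example):
+//   AttrValue value;
+//   if (ParseAttrValue("list(int)", "[1, 2, 3]", &value)) {
+//     // value.list().i() now contains 1, 2, 3.
+//   }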
+bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out); + +// Sets *out based on the type of value. +void SetAttrValue(const string& value, AttrValue* out); +void SetAttrValue(const char* value, AttrValue* out); +void SetAttrValue(StringPiece value, AttrValue* out); +void SetAttrValue(int64 value, AttrValue* out); +void SetAttrValue(int32 value, AttrValue* out); +void SetAttrValue(float value, AttrValue* out); +void SetAttrValue(double value, AttrValue* out); +void SetAttrValue(bool value, AttrValue* out); +void SetAttrValue(DataType value, AttrValue* out); +void SetAttrValue(const TensorShape& value, AttrValue* out); +void SetAttrValue(const TensorShapeProto& value, AttrValue* out); +void SetAttrValue(const PartialTensorShape& value, AttrValue* out); +void SetAttrValue(const Tensor& value, AttrValue* out); +void SetAttrValue(const TensorProto& value, AttrValue* out); +void SetAttrValue(const NameAttrList& value, AttrValue* out); + +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(const std::vector& value, AttrValue* out); +void SetAttrValue(std::initializer_list value, AttrValue* out); +void SetAttrValue(DataTypeSlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice value, AttrValue* out); + +void SetAttrValue(const AttrValue& value, AttrValue* out); + +// Returns true if a and b have the same value. +bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b); + +// Returns a hash of `a` that is consistent with AreAttrValuesEqual. In other +// words, if two AttrValues compare equal according to AreAttrValuesEqual, +// they will have the same hash value. +// Similarly to protobuf deterministic serialization, hash value is +// guaranteed to be stable only for a given binary. In particular, one should +// probably not persist the returned value. +uint64 AttrValueHash(const AttrValue& a); + +// Returns true if "val" has a placeholder. +bool HasPlaceHolder(const AttrValue& val); + +// SubstitutePlaceholders recursively replaces placeholders in 'value' +// with an attr value by calling SubstituteFunc. Returns true iff all +// placeholders in "value" are replaced with a value. +// +// SubstituteFunc is given a placeholder string. If the placeholder is +// unknown, SubstituteFunc returns false. Otherwise, overwrites the +// attr value and returns true. +using SubstituteFunc = std::function; +bool SubstitutePlaceholders(const SubstituteFunc& substitute, AttrValue* value); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_ diff --git a/attr_value_util_test.cc b/attr_value_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1c9a209f05bcab1a0b4304aaddb2d0421e4df45f --- /dev/null +++ b/attr_value_util_test.cc @@ -0,0 +1,195 @@ +/* Copyright 2015 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/attr_value_util.h" + +#include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +// A few helpers to construct AttrValue protos. +template +AttrValue V(T value) { + AttrValue ret; + SetAttrValue(value, &ret); + return ret; +} + +AttrValue P(const string& p) { + AttrValue ret; + ret.set_placeholder(p); + return ret; +} + +AttrValue F(const string& name, + std::vector> pairs) { + AttrValue ret; + ret.mutable_func()->set_name(name); + ret.mutable_func()->mutable_attr()->insert(pairs.begin(), pairs.end()); + return ret; +} + +AttrValue Fs( + std::vector>>> + funcs) { + AttrValue ret; + for (const auto& func : funcs) { + NameAttrList* entry = ret.mutable_list()->add_func(); + entry->set_name(func.first); + entry->mutable_attr()->insert(func.second.begin(), func.second.end()); + } + return ret; +} + +TEST(AttrValueUtil, HasType) { + // OK + EXPECT_TRUE(AttrValueHasType(V(123), "int").ok()); + EXPECT_TRUE(AttrValueHasType(V(1.2), "float").ok()); + EXPECT_TRUE(AttrValueHasType(V(DT_FLOAT), "type").ok()); + EXPECT_TRUE(AttrValueHasType(F("f", {}), "func").ok()); + EXPECT_TRUE(AttrValueHasType(Fs({{"f", {}}, {"g", {}}}), "list(func)").ok()); + + // not OK. 
+ EXPECT_FALSE(AttrValueHasType(V(123), "func").ok()); + EXPECT_FALSE(AttrValueHasType(V(1.2), "int").ok()); + EXPECT_FALSE(AttrValueHasType(V(DT_FLOAT), "shape").ok()); + EXPECT_FALSE(AttrValueHasType(F("f", {}), "string").ok()); + EXPECT_FALSE(AttrValueHasType(P("T"), "float").ok()); + EXPECT_FALSE(AttrValueHasType(V(static_cast(1000)), "type").ok()); + std::vector list_type({static_cast(1000)}); + EXPECT_FALSE(AttrValueHasType(V(list_type), "list(type)").ok()); +} + +SubstituteFunc ReplaceTWith(const AttrValue& val) { + return [val](const string& placeholder, AttrValue* target) { + if (placeholder == "T") { + *target = val; + return true; + } else { + return false; + } + }; +} + +TEST(AttrValueUtil, Basic) { + auto v = F("MatMul", {{"dtype", P("T")}, + {"transpose_a", V(false)}, + {"transpose_b", V(true)}, + {"use_cublas", V(true)}}); + TF_EXPECT_OK(AttrValueHasType(v, "func")); + EXPECT_TRUE(HasPlaceHolder(v)); + + EXPECT_EQ( + SummarizeAttrValue(v), + "MatMul[dtype=$T, transpose_a=false, transpose_b=true, use_cublas=true]"); + + SubstitutePlaceholders(ReplaceTWith(V(DT_FLOAT)), &v); + EXPECT_TRUE(!HasPlaceHolder(v)); + EXPECT_EQ(SummarizeAttrValue(v), + "MatMul[dtype=DT_FLOAT, transpose_a=false, transpose_b=true, " + "use_cublas=true]"); +} + +TEST(AttrValueUtil, Shaped) { + auto v = + F("OpRequiresShape", {{"shape_full", V(TensorShape({1, 0}))}, + {"shape_part", V(PartialTensorShape({-1, 1, 0}))}}); + TF_EXPECT_OK(AttrValueHasType(v, "func")); + EXPECT_FALSE(HasPlaceHolder(v)); + + EXPECT_EQ(SummarizeAttrValue(v), + "OpRequiresShape[shape_full=[1,0], shape_part=[?,1,0]]"); +} + +TEST(AttrValueUtil, DeepAttr) { + auto v = Fs({{"f", {{"T", P("T")}}}, {"g", {{"T", P("T")}}}}); + TF_EXPECT_OK(AttrValueHasType(v, "list(func)")); + EXPECT_TRUE(HasPlaceHolder(v)); + + for (int i = 0; i < 3; ++i) { + v = F("f", {{"T", P("T")}, {"F", v}}); + EXPECT_TRUE(HasPlaceHolder(v)); + } + EXPECT_EQ(SummarizeAttrValue(v), + "f[F=f[F=f[F=[f[T=$T], g[T=$T]], T=$T], T=$T], T=$T]"); + + SubstitutePlaceholders(ReplaceTWith(F("x", {})), &v); + EXPECT_TRUE(!HasPlaceHolder(v)); + EXPECT_EQ(SummarizeAttrValue(v), + "f[F=f[F=f[F=[f[T=x[]], g[T=x[]]], T=x[]], T=x[]], T=x[]]"); +} + +AttrValue FromText(const string& text) { + AttrValue attr; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &attr)); + return attr; +} + +void ExpectDifferent(const AttrValue& a1, const AttrValue& a2) { + EXPECT_FALSE(AreAttrValuesEqual(a1, a2)); + EXPECT_FALSE(AreAttrValuesEqual(a2, a1)); + EXPECT_NE(AttrValueHash(a1), AttrValueHash(a2)); +} + +TEST(AttrValueEquality, StringAndFuncTensors) { + AttrValue a = FromText(R"( + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: 'reader_dataset_ops_test/tmphtXHks/text_line.0.txt' + string_val: 'reader_dataset_ops_test/tmphtXHks/text_line.1.txt' + })"); + EXPECT_TRUE(AreAttrValuesEqual(a, a)); + EXPECT_EQ(AttrValueHash(a), AttrValueHash(a)); + + AttrValue b = a; + (*b.mutable_tensor()->mutable_string_val(0))[3] = '1'; + ExpectDifferent(a, b); + + AttrValue c1; + c1.mutable_func()->set_name("func_name"); + (*c1.mutable_func()->mutable_attr())["attr1"] = a; + (*c1.mutable_func()->mutable_attr())["attr2"] = b; + EXPECT_TRUE(AreAttrValuesEqual(c1, c1)); + EXPECT_EQ(AttrValueHash(c1), AttrValueHash(c1)); + + ExpectDifferent(c1, a); + + AttrValue c2 = c1; + c2.mutable_func()->set_name("func_name2"); + ExpectDifferent(c1, c2); + + c2 = c1; + (*c2.mutable_func()->mutable_attr())["attr3"] = b; + ExpectDifferent(c1, c2); + + c2 = c1; + 
(*c2.mutable_func()->mutable_attr())["attr2"] = a; + ExpectDifferent(c1, c2); + + c2 = c1; + c2.mutable_func()->mutable_attr()->erase("attr2"); + ExpectDifferent(c1, c2); +} + +} // namespace tensorflow diff --git a/bfloat16.cc b/bfloat16.cc new file mode 100644 index 0000000000000000000000000000000000000000..0efe43fde2dadd42aa03d3bf2968d2cbfb113e8d --- /dev/null +++ b/bfloat16.cc @@ -0,0 +1,50 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/bfloat16.h" + +namespace tensorflow { + +void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) { + const uint16_t* p = reinterpret_cast(src); + uint16_t* q = reinterpret_cast(dst); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + for (; size != 0; p += 2, q++, size--) { + *q = p[0]; + } +#else + for (; size != 0; p += 2, q++, size--) { + *q = p[1]; + } +#endif +} + +void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) { + const uint16_t* p = reinterpret_cast(src); + uint16_t* q = reinterpret_cast(dst); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + for (; size != 0; p++, q += 2, size--) { + q[0] = *p; + q[1] = 0; + } +#else + for (; size != 0; p++, q += 2, size--) { + q[0] = 0; + q[1] = *p; + } +#endif +} + +} // end namespace tensorflow diff --git a/bfloat16.h b/bfloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..968c18bdd2159fee4eb6982c62697951d79b706c --- /dev/null +++ b/bfloat16.h @@ -0,0 +1,62 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_BFLOAT16_H_ +#define TENSORFLOW_FRAMEWORK_BFLOAT16_H_ + +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/platform/types.h" + +#if defined(PLATFORM_WINDOWS) +#include "tensorflow/core/platform/windows/cpu_info.h" +#endif + +// Compact 16-bit encoding of floating point numbers. This representation uses +// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. It +// is assumed that floats are in IEEE 754 format so the representation is just +// bits 16-31 of a single precision float. +// +// NOTE: The IEEE floating point standard defines a float16 format that +// is different than this format (it has fewer bits of exponent and more +// bits of mantissa). 
We don't use that format here because conversion +// to/from 32-bit floats is more complex for that format, and the +// conversion for this format is very simple. +// +// Because of the existing IEEE float16 type, we do not name our representation +// "float16" but just use "uint16". +// +// <-----our 16bits float-------> +// s e e e e e e e e f f f f f f f f f f f f f f f f f f f f f f f +// <------------------------------float--------------------------> +// 3 3 2 2 1 1 0 +// 1 0 3 2 5 4 0 +// +// +// This type only supports conversion back and forth with float. +// +// This file must be compilable by nvcc. +// +// The type is defined in framework/numeric_types.h. + +namespace tensorflow { + +// Conversion routines between an array of float and bfloat16 of +// "size". +void FloatToBFloat16(const float* src, bfloat16* dst, int64 size); +void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_BFLOAT16_H_ diff --git a/bfloat16_test.cc b/bfloat16_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..17e6209f8e5ad5240dfc8ca1def75c178da45c27 --- /dev/null +++ b/bfloat16_test.cc @@ -0,0 +1,158 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/bfloat16.h" + +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +TEST(Bfloat16Test, Simple) { + bfloat16 a(12); + // Floating point representation of 12: 0x41400000 + EXPECT_EQ(0x4140, a.value); +} + +float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa, + uint32_t low_mantissa) { + return bit_cast((sign << 31) + (exponent << 23) + + (high_mantissa << 16) + low_mantissa); +} + +struct Bfloat16TestParam { + float input; + float expected; +}; + +class Bfloat16Test : public ::testing::Test, + public ::testing::WithParamInterface {}; + +TEST_P(Bfloat16Test, TruncateTest) { + bfloat16 a(GetParam().input); + if (std::isnan(GetParam().input)) { + EXPECT_TRUE(std::isnan(float(a)) || std::isinf(float(a))); + return; + } + EXPECT_EQ(GetParam().expected, float(a)); +} + +INSTANTIATE_TEST_CASE_P( + Bfloat16Test_Instantiation, Bfloat16Test, + ::testing::Values( + Bfloat16TestParam{ + BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011), + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(1, 0b10000000, 0b1001000, 0b1111010111000011), + BinaryToFloat(1, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000), + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000001), + BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b11111111, 0b1111111, 0b1111111111111111), + BinaryToFloat(0, 0b11111111, 0b1111111, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(1, 0b10000000, 0b1001000, 0b1100000000000000), + BinaryToFloat(1, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000), + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0100000000000000), + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000), + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000), + BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000), + BinaryToFloat(0, 0b00000000, 0b1111111, 0b0000000000000000)})); + +TEST(Bfloat16Test, Conversion) { + float a[100]; + for (int i = 0; i < 100; ++i) { + a[i] = i + 1.25; + } + bfloat16 b[100]; + float c[100]; + FloatToBFloat16(a, b, 100); + BFloat16ToFloat(b, c, 100); + for (int i = 0; i < 100; ++i) { + // The relative error should be less than 1/(2^7) since bfloat16 + // has 7 bits mantissa. 
+ EXPECT_LE(fabs(c[i] - a[i]) / a[i], 1.0 / 128); + } +} + +TEST(Bfloat16Test, Epsilon) { + EXPECT_LT(1.0f, static_cast(bfloat16::epsilon() + bfloat16(1.0f))); + EXPECT_EQ(1.0f, static_cast((bfloat16::epsilon() / bfloat16(2.0f)) + + bfloat16(1.0f))); +} + +TEST(Bfloat16Test, Negate) { + EXPECT_EQ(-3.0f, static_cast(-bfloat16(3.0f))); + EXPECT_EQ(4.5f, static_cast(-bfloat16(-4.5f))); +} + +static void BM_FloatToBFloat16(int iters) { + testing::StopTiming(); + static const int N = 32 << 20; + const int64 tot = static_cast(iters) * N; + testing::ItemsProcessed(tot); + testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16))); + + float* inp = new float[N]; + bfloat16* out = new bfloat16[N]; + + testing::StartTiming(); + while (iters--) { + FloatToBFloat16(inp, out, N); + } + delete[] inp; + delete[] out; +} +BENCHMARK(BM_FloatToBFloat16); + +static void BM_BFloat16ToFloat(int iters) { + testing::StopTiming(); + static const int N = 32 << 20; + const int64 tot = static_cast(iters) * N; + testing::ItemsProcessed(tot); + testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16))); + + bfloat16* inp = new bfloat16[N]; + float* out = new float[N]; + + testing::StartTiming(); + while (iters--) { + BFloat16ToFloat(inp, out, N); + } + delete[] inp; + delete[] out; +} +BENCHMARK(BM_BFloat16ToFloat); + +} // namespace +} // namespace tensorflow diff --git a/cancellation.cc b/cancellation.cc new file mode 100644 index 0000000000000000000000000000000000000000..9da4828bbad7b6333336dd1215441f5c5f62151a --- /dev/null +++ b/cancellation.cc @@ -0,0 +1,94 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/cancellation.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +const CancellationToken CancellationManager::kInvalidToken = -1; + +CancellationManager::CancellationManager() + : is_cancelling_(false), + is_cancelled_(false), + next_cancellation_token_(0) {} + +void CancellationManager::StartCancel() { + gtl::FlatMap callbacks_to_run; + { + mutex_lock l(mu_); + if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) { + return; + } + is_cancelling_ = true; + std::swap(callbacks_, callbacks_to_run); + } + // We call these callbacks without holding mu_, so that concurrent + // calls to DeregisterCallback, which can happen asynchronously, do + // not block. The callbacks remain valid because any concurrent call + // to DeregisterCallback will block until the + // cancelled_notification_ is notified. 
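+ // The registered callbacks were swapped into the local map while holding
+ // mu_, so each one runs exactly once here; a RegisterCallback call racing
+ // with cancellation either already placed its callback in the swapped-out
+ // map or observes is_cancelling_ and returns false.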
+ for (auto key_and_value : callbacks_to_run) { + key_and_value.second(); + } + { + mutex_lock l(mu_); + is_cancelling_ = false; + is_cancelled_.store(true, std::memory_order_release); + } + cancelled_notification_.Notify(); +} + +CancellationToken CancellationManager::get_cancellation_token() { + mutex_lock l(mu_); + return next_cancellation_token_++; +} + +bool CancellationManager::RegisterCallback(CancellationToken token, + CancelCallback callback) { + mutex_lock l(mu_); + CHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token"; + bool should_register = !is_cancelled_ && !is_cancelling_; + if (should_register) { + std::swap(callbacks_[token], callback); + } + return should_register; +} + +bool CancellationManager::DeregisterCallback(CancellationToken token) { + mu_.lock(); + if (is_cancelled_) { + mu_.unlock(); + return false; + } else if (is_cancelling_) { + mu_.unlock(); + // Wait for all of the cancellation callbacks to be called. This + // wait ensures that the caller of DeregisterCallback does not + // return immediately and free objects that may be used in the + // execution of any currently pending callbacks in StartCancel. + cancelled_notification_.WaitForNotification(); + return false; + } else { + callbacks_.erase(token); + mu_.unlock(); + return true; + } +} + +CancellationManager::~CancellationManager() { StartCancel(); } + +} // end namespace tensorflow diff --git a/cancellation.h b/cancellation.h new file mode 100644 index 0000000000000000000000000000000000000000..90074c87b229a82429a561c0a1cfe397c0e04f07 --- /dev/null +++ b/cancellation.h @@ -0,0 +1,137 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_ +#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_ + +#include +#include + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// A token that can be used to register and deregister a +// CancelCallback with a CancellationManager. +// +// CancellationToken values must be created by a call to +// CancellationManager::get_cancellation_token. +typedef int64 CancellationToken; + +// A callback that is invoked when a step is canceled. +// +// NOTE(mrry): See caveats about CancelCallback implementations in the +// comment for CancellationManager::RegisterCallback. +typedef std::function CancelCallback; + +class CancellationManager { + public: + // A value that won't be returned by get_cancellation_token(). + static const CancellationToken kInvalidToken; + + CancellationManager(); + ~CancellationManager(); + + // Run all callbacks associated with this manager. 
+ void StartCancel(); + + // Returns true iff StartCancel() has been called. + bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); } + + // Returns a token that must be used in calls to RegisterCallback + // and DeregisterCallback. + CancellationToken get_cancellation_token(); + + // Attempts to register the given callback to be invoked when this + // manager is cancelled. Returns true if the callback was + // registered; returns false if this manager was already cancelled, + // and the callback was not registered. + // + // If this method returns false, it is the caller's responsibility + // to perform any cancellation cleanup. + // + // This method is tricky to use correctly. The following usage pattern + // is recommended: + // + // class ObjectWithCancellableOperation { + // mutex mu_; + // void CancellableOperation(CancellationManager* cm, + // std::function callback) { + // bool already_cancelled; + // CancellationToken token = cm->get_cancellation_token(); + // { + // mutex_lock(mu_); + // already_cancelled = !cm->RegisterCallback( + // [this, token]() { Cancel(token); }); + // if (!already_cancelled) { + // // Issue asynchronous operation. Associate the pending operation + // // with `token` in some object state, or provide another way for + // // the Cancel method to look up the operation for cancellation. + // // Ensure that `cm->DeregisterCallback(token)` is called without + // // holding `mu_`, before `callback` is invoked. + // // ... + // } + // } + // if (already_cancelled) { + // callback(errors::Cancelled("Operation was cancelled")); + // } + // } + // + // void Cancel(CancellationToken token) { + // mutex_lock(mu_); + // // Take action to cancel the operation with the given cancellation + // // token. + // } + // + // NOTE(mrry): The caller should take care that (i) the calling code + // is robust to `callback` being invoked asynchronously (e.g. from + // another thread), (ii) `callback` is deregistered by a call to + // this->DeregisterCallback(token) when the operation completes + // successfully, and (iii) `callback` does not invoke any method + // on this cancellation manager. Furthermore, it is important that + // the eventual caller of the complementary DeregisterCallback does not + // hold any mutexes that are required by `callback`. + bool RegisterCallback(CancellationToken token, CancelCallback callback); + + // Deregister the callback that, when registered, was associated + // with the given cancellation token. Returns true iff the callback + // was deregistered and will not be invoked; otherwise returns false + // after the callback has been invoked, blocking if necessary. + // + // NOTE(mrry): This method may block if cancellation is in progress. + // The caller of this method must not hold any mutexes that are required + // to invoke any cancellation callback that has been registered with this + // cancellation manager. 
+ bool DeregisterCallback(CancellationToken token); + + private: + bool is_cancelling_; + std::atomic_bool is_cancelled_; + + mutex mu_; + Notification cancelled_notification_; + CancellationToken next_cancellation_token_ GUARDED_BY(mu_); + gtl::FlatMap callbacks_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_CANCELLATION_H_ diff --git a/cancellation_test.cc b/cancellation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e3f18240b588876beb59bc14bf9ecf1ff37c11a2 --- /dev/null +++ b/cancellation_test.cc @@ -0,0 +1,118 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/cancellation.h" + +#include +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(Cancellation, SimpleNoCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + bool deregistered = manager->DeregisterCallback(token); + EXPECT_TRUE(deregistered); + delete manager; + EXPECT_FALSE(is_cancelled); +} + +TEST(Cancellation, SimpleCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + manager->StartCancel(); + EXPECT_TRUE(is_cancelled); + delete manager; +} + +TEST(Cancellation, CancelBeforeRegister) { + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + manager->StartCancel(); + bool registered = manager->RegisterCallback(token, nullptr); + EXPECT_FALSE(registered); + delete manager; +} + +TEST(Cancellation, DeregisterAfterCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + manager->StartCancel(); + EXPECT_TRUE(is_cancelled); + bool deregistered = manager->DeregisterCallback(token); + EXPECT_FALSE(deregistered); + delete manager; +} + +TEST(Cancellation, CancelMultiple) { + bool is_cancelled_1 = false, is_cancelled_2 = false, is_cancelled_3 = false; + CancellationManager* manager = new CancellationManager(); + auto token_1 = manager->get_cancellation_token(); + bool registered_1 = manager->RegisterCallback( + token_1, [&is_cancelled_1]() { is_cancelled_1 = true; }); + EXPECT_TRUE(registered_1); + auto token_2 = manager->get_cancellation_token(); + bool registered_2 = 
manager->RegisterCallback( + token_2, [&is_cancelled_2]() { is_cancelled_2 = true; }); + EXPECT_TRUE(registered_2); + EXPECT_FALSE(is_cancelled_1); + EXPECT_FALSE(is_cancelled_2); + manager->StartCancel(); + EXPECT_TRUE(is_cancelled_1); + EXPECT_TRUE(is_cancelled_2); + EXPECT_FALSE(is_cancelled_3); + auto token_3 = manager->get_cancellation_token(); + bool registered_3 = manager->RegisterCallback( + token_3, [&is_cancelled_3]() { is_cancelled_3 = true; }); + EXPECT_FALSE(registered_3); + EXPECT_FALSE(is_cancelled_3); + delete manager; +} + +TEST(Cancellation, IsCancelled) { + CancellationManager* cm = new CancellationManager(); + thread::ThreadPool w(Env::Default(), "test", 4); + std::vector done(8); + for (size_t i = 0; i < done.size(); ++i) { + Notification* n = &done[i]; + w.Schedule([n, cm]() { + while (!cm->IsCancelled()) { + } + n->Notify(); + }); + } + Env::Default()->SleepForMicroseconds(1000000 /* 1 second */); + cm->StartCancel(); + for (size_t i = 0; i < done.size(); ++i) { + done[i].WaitForNotification(); + } + delete cm; +} + +} // namespace tensorflow diff --git a/common_shape_fns.cc b/common_shape_fns.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ab8e3ec188a223e35b47b6f9517abd9327b23f8 --- /dev/null +++ b/common_shape_fns.cc @@ -0,0 +1,1399 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +namespace tensorflow { + +Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size, + int64 dilation_rate, int64 stride, + Padding padding_type, int64* output_size, + int64* padding_before, + int64* padding_after) { + if (stride <= 0) { + return errors::InvalidArgument("Stride must be > 0, but got ", stride); + } + if (dilation_rate < 1) { + return errors::InvalidArgument("Dilation rate must be >= 1, but got ", + dilation_rate); + } + + // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2. + int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1; + switch (padding_type) { + case Padding::VALID: + *output_size = (input_size - effective_filter_size + stride) / stride; + *padding_before = *padding_after = 0; + break; + case Padding::SAME: + *output_size = (input_size + stride - 1) / stride; + const int64 padding_needed = + std::max(0LL, (*output_size - 1) * stride + effective_filter_size - + input_size); + // For odd values of total padding, add more padding at the 'right' + // side of the given dimension. 
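+ // Worked example: input_size = 10, filter_size = 3, dilation_rate = 1 and
+ // stride = 2 give output_size = (10 + 2 - 1) / 2 = 5 and padding_needed =
+ // max(0, (5 - 1) * 2 + 3 - 10) = 1, so padding_before is 0 and
+ // padding_after is 1.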
+ *padding_before = padding_needed / 2; + *padding_after = padding_needed - *padding_before; + break; + } + if (*output_size < 0) { + return errors::InvalidArgument("computed output size would be negative"); + } + return Status::OK(); +} + +Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size, + int64 stride, Padding padding_type, + int64* output_size, int64* padding_before, + int64* padding_after) { + return GetWindowedOutputSizeVerboseV2(input_size, filter_size, + /*dilation_rate=*/1, stride, + padding_type, output_size, + padding_before, padding_after); +} + +Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride, + Padding padding_type, int64* output_size, + int64* padding_size) { + int64 padding_after_unused; + return GetWindowedOutputSizeVerbose(input_size, filter_size, stride, + padding_type, output_size, padding_size, + &padding_after_unused); +} + +Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size, + int64 dilation_rate, int64 stride, + Padding padding_type, int64* output_size, + int64* padding_size) { + int64 padding_after_unused; + return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate, + stride, padding_type, output_size, + padding_size, &padding_after_unused); +} + +Status Get3dOutputSize(const std::array& input, + const std::array& window, + const std::array& strides, + Padding padding_type, std::array* output_ptr, + std::array* padding_ptr) { + for (size_t i = 0; i < input.size(); ++i) { + TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i], + padding_type, &(*output_ptr)[i], + &(*padding_ptr)[i])); + } + return Status::OK(); +} + +Status Get3dOutputSizeV2(const std::array& input, + const std::array& window, + const std::array& dilations, + const std::array& strides, + Padding padding_type, std::array* output_ptr, + std::array* padding_ptr) { + for (size_t i = 0; i < input.size(); ++i) { + TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2( + input[i], window[i], dilations[i], strides[i], padding_type, + &(*output_ptr)[i], &(*padding_ptr)[i])); + } + return Status::OK(); +} + +namespace shape_inference { + +// The V2 version computes windowed output size with arbitrary dilation_rate, +// while the original version only handles the cases where dilation_rates equal +// to 1. +Status GetWindowedOutputSizeFromDimsV2( + shape_inference::InferenceContext* c, + shape_inference::DimensionHandle input_size, + shape_inference::DimensionOrConstant filter_size, int64 dilation_rate, + int64 stride, Padding padding_type, + shape_inference::DimensionHandle* output_size) { + if (stride <= 0) { + return errors::InvalidArgument("Stride must be > 0, but got ", stride); + } + + if (dilation_rate < 1) { + return errors::InvalidArgument("Dilation rate must be >= 1, but got ", + dilation_rate); + } + + // See also the parallel implementation in GetWindowedOutputSizeVerbose. 
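+ // For example, with input_size = 10, filter_size = 3, dilation_rate = 2 and
+ // stride = 2 the effective filter size is (3 - 1) * 2 + 1 = 5, so VALID
+ // padding yields (10 - 5 + 2) / 2 = 3 output elements while SAME padding
+ // yields (10 + 2 - 1) / 2 = 5.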
+ switch (padding_type) { + case Padding::VALID: + if (dilation_rate > 1) { + DimensionHandle window_size; + TF_RETURN_IF_ERROR( + c->Subtract(c->MakeDim(filter_size), 1, &window_size)); + TF_RETURN_IF_ERROR( + c->Multiply(window_size, dilation_rate, &window_size)); + TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size)); + TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size)); + } else { + TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size)); + } + TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size)); + TF_RETURN_IF_ERROR(c->Divide(*output_size, stride, + /*evenly_divisible=*/false, output_size)); + break; + case Padding::SAME: + TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size)); + TF_RETURN_IF_ERROR(c->Divide(*output_size, stride, + /*evenly_divisible=*/false, output_size)); + break; + } + return Status::OK(); +} + +Status GetWindowedOutputSizeFromDims( + shape_inference::InferenceContext* c, + shape_inference::DimensionHandle input_size, + shape_inference::DimensionOrConstant filter_size, int64 stride, + Padding padding_type, shape_inference::DimensionHandle* output_size) { + return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size, + /*dilation_rate=*/1, stride, + padding_type, output_size); +} + +Status UnchangedShape(shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); +} + +Status MatMulShape(shape_inference::InferenceContext* c) { + ShapeHandle a; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a)); + + ShapeHandle b; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b)); + + bool transpose_a, transpose_b; + TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a)); + TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b)); + DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0); + DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1); + + // Validate that the inner shapes are compatible. + DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1); + DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0); + DimensionHandle merged; + TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged)); + + c->set_output(0, c->Matrix(output_rows, output_cols)); + return Status::OK(); +} + +Status BiasAddShape(shape_inference::InferenceContext* c) { + ShapeHandle input_shape; + + // Fetch the data_format attribute, which may not exist. + string data_format; + Status s = c->GetAttr("data_format", &data_format); + + if (s.ok() && data_format == "NCHW") { + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape)); + } else { + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); + } + + ShapeHandle bias_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape)); + DimensionHandle bias_dim = c->Dim(bias_shape, 0); + + // If rank unknown, return unknown shape. + if (!c->RankKnown(input_shape)) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + } + + // Output has the same shape as the input, and matches the length of + // the bias in its bias dimension. 
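+ // For example, an NCHW input of shape [N, C, H, W] with a length-C bias is
+ // split into [N], [C] and [H, W]; the bias dimension is merged with C and
+ // the pieces are concatenated back into [N, C, H, W]. In the default NHWC
+ // case only the trailing dimension is merged with the bias length.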
+ ShapeHandle output_shape; + if (s.ok() && data_format == "NCHW") { + // Merge the length of bias_shape into the third to last dimension + ShapeHandle first; + TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -3, &first)); + + ShapeHandle last; + TF_RETURN_IF_ERROR(c->Subshape(input_shape, -2, &last)); + + DimensionHandle input_bias_dim = c->Dim(input_shape, -3); + DimensionHandle merged_bias_dim; + TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim)); + ShapeHandle merged_bias = c->Vector(merged_bias_dim); + + ShapeHandle temp; + TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp)); + TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape)); + } else { + ShapeHandle all_but_bias; + TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias)); + + DimensionHandle input_bias_dim = c->Dim(input_shape, -1); + DimensionHandle merged_bias_dim; + TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim)); + + ShapeHandle merged_bias = c->Vector(merged_bias_dim); + TF_RETURN_IF_ERROR( + c->Concatenate(all_but_bias, merged_bias, &output_shape)); + } + + c->set_output(0, output_shape); + return Status::OK(); +} + +Status BiasAddGradShape(shape_inference::InferenceContext* c) { + ShapeHandle input_shape; + // Fetch the data_format attribute, which may not exist. + string data_format; + Status s = c->GetAttr("data_format", &data_format); + + if (s.ok() && data_format == "NCHW") { + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape)); + c->set_output(0, c->Vector(c->Dim(input_shape, -3))); + } else { + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); + c->set_output(0, c->Vector(c->Dim(input_shape, -1))); + } + + return Status::OK(); +} + +Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, + const ShapeHandle shape_handle, + const string& tensor_name, + shape_inference::InferenceContext* c) { + if (tensor_format == FORMAT_NCHW_VECT_C) { + // Check that the vect dim has size 4. + const int num_dims = c->Rank(shape_handle); + DimensionHandle vect_dim = c->Dim( + shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format)); + DimensionHandle unused_vect_dim; + TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim)); + } + + return Status::OK(); +} + +Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, + const std::vector& spatial, + DimensionOrConstant C, ShapeHandle* out, + shape_inference::InferenceContext* context) { + const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format); + std::vector dims_actual(num_dims); + dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N); + int outer_c_index = GetTensorFeatureDimIndex(num_dims, format); + dims_actual[outer_c_index] = context->MakeDim(C); + if (format == FORMAT_NCHW_VECT_C) { + dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] = + context->MakeDim(4); + } + for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) { + dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] = + context->MakeDim(spatial[spatial_dim]); + } + *out = context->MakeShape(dims_actual); + return Status::OK(); +} + +Status DimensionsFromShape(ShapeHandle shape, TensorFormat format, + DimensionHandle* batch_dim, + gtl::MutableArraySlice spatial_dims, + DimensionHandle* filter_dim, + InferenceContext* context) { + const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format); + // Batch. 
+ *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format)); + // Spatial. + for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size(); + ++spatial_dim_index) { + spatial_dims[spatial_dim_index] = context->Dim( + shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index)); + } + // Channel. + *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format)); + if (format == FORMAT_NCHW_VECT_C) { + TF_RETURN_IF_ERROR(context->Multiply( + *filter_dim, + context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)), + filter_dim)); + } + return Status::OK(); +} + +Status ShapeFromDimensions(DimensionHandle batch_dim, + gtl::ArraySlice spatial_dims, + DimensionHandle filter_dim, TensorFormat format, + InferenceContext* context, ShapeHandle* shape) { + const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format); + std::vector out_dims(rank); + + // Batch. + out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim; + // Spatial. + for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size(); + ++spatial_dim_index) { + out_dims[tensorflow::GetTensorSpatialDimIndex( + rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index]; + } + // Channel. + if (format == tensorflow::FORMAT_NCHW_VECT_C) { + // When format is NCHW_VECT_C, factor the feature map count + // into the outer feature count and the inner feature count (=4). + TF_RETURN_IF_ERROR(context->Divide( + filter_dim, 4, /*evenly_divisible=*/true, + &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)])); + out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4); + } else { + out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim; + } + + *shape = context->MakeShape(out_dims); + return tensorflow::Status::OK(); +} + +Status Conv2DShape(shape_inference::InferenceContext* c) { + string data_format_str, filter_format_str; + if (!c->GetAttr("data_format", &data_format_str).ok()) { + data_format_str = "NHWC"; + } + if (!c->GetAttr("filter_format", &filter_format_str).ok()) { + filter_format_str = "HWIO"; + } + + TensorFormat data_format; + if (!FormatFromString(data_format_str, &data_format)) { + return errors::InvalidArgument("Invalid data format string: ", + data_format_str); + } + FilterTensorFormat filter_format; + if (!FilterFormatFromString(filter_format_str, &filter_format)) { + return errors::InvalidArgument("Invalid filter format string: ", + filter_format_str); + } + + constexpr int num_spatial_dims = 2; + const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); + ShapeHandle conv_input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape)); + TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape( + data_format, conv_input_shape, "conv_input", c)); + + // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C). + ShapeHandle filter_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); + TF_RETURN_IF_ERROR( + CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c)); + + std::vector dilations; + TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); + + if (dilations.size() != 4) { + return errors::InvalidArgument( + "Conv2D requires the dilation attribute to contain 4 values, but got: ", + dilations.size()); + } + + std::vector strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + + // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C). 
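+ // In the NCHW_VECT_C layout the input tensor itself is rank 5 (commonly
+ // written [N, C/4, H, W, 4]), but strides and dilations are still given per
+ // logical NCHW dimension, so both attributes must hold exactly four values.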
+ if (strides.size() != 4) { + return errors::InvalidArgument("Conv2D on data format ", data_format_str, + " requires the stride attribute to contain" + " 4 values, but got: ", + strides.size()); + } + + const int32 stride_rows = GetTensorDim(strides, data_format, 'H'); + const int32 stride_cols = GetTensorDim(strides, data_format, 'W'); + const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H'); + const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W'); + + DimensionHandle batch_size_dim; + DimensionHandle input_depth_dim; + gtl::InlinedVector input_spatial_dims(2); + TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format, + &batch_size_dim, &input_spatial_dims, + &input_depth_dim, c)); + + DimensionHandle output_depth_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'O')); + DimensionHandle filter_rows_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'H')); + DimensionHandle filter_cols_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'W')); + DimensionHandle filter_input_depth_dim; + if (filter_format == FORMAT_OIHW_VECT_I) { + TF_RETURN_IF_ERROR(c->Multiply( + c->Dim(filter_shape, + GetFilterDimIndex(filter_format, 'I')), + c->Dim(filter_shape, + GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)), + &filter_input_depth_dim)); + } else { + filter_input_depth_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'I')); + } + + // Check that the input tensor and the filter tensor agree on the input + // channel count. + DimensionHandle unused; + TF_RETURN_IF_ERROR( + c->Merge(input_depth_dim, filter_input_depth_dim, &unused)); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + DimensionHandle output_rows, output_cols; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( + c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows, + padding, &output_rows)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( + c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols, + padding, &output_cols)); + + ShapeHandle output_shape; + TF_RETURN_IF_ERROR( + ShapeFromDimensions(batch_size_dim, {output_rows, output_cols}, + output_depth_dim, data_format, c, &output_shape)); + c->set_output(0, output_shape); + return Status::OK(); +} + +// TODO(mjanusz): Unify all conv/pooling shape functions. +Status Conv3DShape(shape_inference::InferenceContext* c) { + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); + ShapeHandle filter_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape)); + + string data_format; + Status s = c->GetAttr("data_format", &data_format); + + std::vector strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + if (strides.size() != 5) { + return errors::InvalidArgument( + "Conv3D requires the stride attribute to contain 5 values, but got: ", + strides.size()); + } + + int32 stride_planes, stride_rows, stride_cols; + if (s.ok() && data_format == "NCDHW") { + // Convert input_shape to NDHWC. 
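+ // GetTensorDimIndex<3>(FORMAT_NCHW, ...) interprets FORMAT_NCHW with three
+ // spatial dimensions as NCDHW, so the lambda below reads the NCDHW input
+ // and the MakeShape call rebuilds it in NDHWC order for the shared
+ // shape-inference code that follows.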
+ auto dim = [&](char dimension) { + return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension)); + }; + input_shape = + c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}}); + stride_planes = strides[2]; + stride_cols = strides[3]; + stride_rows = strides[4]; + } else { + stride_planes = strides[1]; + stride_rows = strides[2]; + stride_cols = strides[3]; + } + + DimensionHandle batch_size_dim = c->Dim(input_shape, 0); + DimensionHandle in_planes_dim = c->Dim(input_shape, 1); + DimensionHandle in_rows_dim = c->Dim(input_shape, 2); + DimensionHandle in_cols_dim = c->Dim(input_shape, 3); + + DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0); + DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1); + DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2); + DimensionHandle output_depth_dim = c->Dim(filter_shape, 4); + + DimensionHandle unused; + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused)); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + DimensionHandle output_planes, output_rows, output_cols; + + TF_RETURN_IF_ERROR( + GetWindowedOutputSizeFromDims(c, in_planes_dim, filter_planes_dim, + stride_planes, padding, &output_planes)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols)); + + ShapeHandle output_shape; + if (data_format == "NCDHW") { + output_shape = c->MakeShape({batch_size_dim, output_depth_dim, + output_planes, output_rows, output_cols}); + } else { + output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows, + output_cols, output_depth_dim}); + } + c->set_output(0, output_shape); + return Status::OK(); +} + +Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); + ShapeHandle filter_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape)); + + std::vector strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + + if (strides.size() != 4) { + return errors::InvalidArgument( + "DepthwiseConv2D requires the stride attribute to contain 4 values, " + "but got: ", + strides.size()); + } + + string data_format; + Status s = c->GetAttr("data_format", &data_format); + int32 stride_rows; + int32 stride_cols; + if (s.ok() && data_format == "NCHW") { + // Canonicalize input shape to NHWC so the shape inference code below can + // process it. + input_shape = + c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2), + c->Dim(input_shape, 3), c->Dim(input_shape, 1)}}); + stride_rows = strides[2]; + stride_cols = strides[3]; + } else { + stride_rows = strides[1]; + stride_cols = strides[2]; + } + + DimensionHandle batch_size_dim = c->Dim(input_shape, 0); + DimensionHandle in_rows_dim = c->Dim(input_shape, 1); + DimensionHandle in_cols_dim = c->Dim(input_shape, 2); + + DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0); + DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1); + DimensionHandle input_depth = c->Dim(filter_shape, 2); + DimensionHandle depth_multiplier = c->Dim(filter_shape, 3); + + // Check that the input depths are compatible. 
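+ // For a filter of shape [filter_rows, filter_cols, in_channels,
+ // depth_multiplier], the input's channel dimension must match in_channels
+ // and the output gets in_channels * depth_multiplier channels (e.g. 8 input
+ // channels with a depth multiplier of 3 produce 24 output channels).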
+ TF_RETURN_IF_ERROR( + c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth)); + + DimensionHandle output_depth; + TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth)); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + // TODO(mrry,shlens): Raise an error if the stride would cause + // information in the input to be ignored. This will require a change + // in the kernel implementation. + DimensionHandle output_rows, output_cols; + + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols)); + + ShapeHandle output_shape; + if (data_format == "NCHW") { + output_shape = + c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols}); + } else { + output_shape = + c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); + } + c->set_output(0, output_shape); + return Status::OK(); +} + +Status AvgPoolShape(shape_inference::InferenceContext* c) { + string data_format_str; + TensorFormat data_format; + Status s = c->GetAttr("data_format", &data_format_str); + if (s.ok()) { + FormatFromString(data_format_str, &data_format); + } else { + data_format = FORMAT_NHWC; + } + + const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); + + TF_RETURN_IF_ERROR( + CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); + + std::vector strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + if (strides.size() != 4) { + return errors::InvalidArgument( + "AvgPool requires the stride attribute to contain 4 values, but got: ", + strides.size()); + } + + std::vector kernel_sizes; + TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); + if (kernel_sizes.size() != 4) { + return errors::InvalidArgument( + "AvgPool requires the ksize attribute to contain 4 values, but got: ", + kernel_sizes.size()); + } + + int32 stride_rows = GetTensorDim(strides, data_format, 'H'); + int32 stride_cols = GetTensorDim(strides, data_format, 'W'); + int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); + int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); + + constexpr int num_spatial_dims = 2; + DimensionHandle batch_size_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'N')); + DimensionHandle in_rows_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'H')); + DimensionHandle in_cols_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'W')); + DimensionHandle depth_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'C')); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + // TODO(mrry,shlens): Raise an error if the stride would cause + // information in the input to be ignored. This will require a change + // in the kernel implementation. 
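+ // Pooling reuses the convolution window arithmetic with the kernel size in
+ // place of the filter size: e.g. a 4x4 input with a 2x2 kernel, stride 2
+ // and VALID padding produces a 2x2 output ((4 - 2 + 2) / 2 = 2 per spatial
+ // dimension).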
+ + DimensionHandle output_rows, output_cols; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); + + ShapeHandle output_shape; + TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, + {output_rows, output_cols}, depth_dim, + &output_shape, c)); + c->set_output(0, output_shape); + return Status::OK(); +} + +Status FusedBatchNormShape(shape_inference::InferenceContext* c) { + ShapeHandle x; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x)); + + bool is_training; + TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); + int number_inputs = (is_training) ? 3 : 5; + string data_format; + TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format)); + DimensionHandle channel_dim = + (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1); + + // covers scale, offset, and if is_training is false, mean, variance + for (int i = 1; i < number_inputs; ++i) { + ShapeHandle vec; + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); + TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); + } + + ShapeHandle y; + if (data_format == "NHWC") { + TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y)); + } else { + TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y)); + } + c->set_output(0, y); + ShapeHandle vector_shape = c->Vector(channel_dim); + c->set_output(1, vector_shape); + c->set_output(2, vector_shape); + c->set_output(3, vector_shape); + c->set_output(4, vector_shape); + return Status::OK(); +} + +Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { + ShapeHandle y_backprop; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop)); + ShapeHandle x; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x)); + + bool is_training; + string data_format; + TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); + TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format)); + DimensionHandle channel_dim = + (data_format == "NHWC") ? c->Dim(y_backprop, 3) : c->Dim(y_backprop, 1); + if (data_format == "NHWC") { + TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim)); + } else { + TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim)); + } + + // covers scale, mean (reserve_space_1), variance (reserve_space_2) + for (int i = 2; i < 5; ++i) { + ShapeHandle vec; + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); + TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); + } + + ShapeHandle x_backprop; + if (data_format == "NHWC") { + TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop)); + } else { + TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop)); + } + c->set_output(0, x_backprop); + c->set_output(1, c->Vector(channel_dim)); + c->set_output(2, c->Vector(channel_dim)); + // Set the correct shapes for reserve_spaces + // so that gradients can be performed when + // the op is in a symbolic condition. 
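+ // In training mode the two trailing reserve-space outputs are placeholders,
+ // so they are given empty (length-0) vector shapes; in inference mode they
+ // mirror the channel-sized reserve-space inputs instead.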
+ if (is_training) { + c->set_output(3, c->Vector(0)); + c->set_output(4, c->Vector(0)); + } else { + c->set_output(3, c->Vector(channel_dim)); + c->set_output(4, c->Vector(channel_dim)); + } + return Status::OK(); +} + +Status MaxPoolShape(shape_inference::InferenceContext* c) { + string data_format_str; + TensorFormat data_format; + Status s = c->GetAttr("data_format", &data_format_str); + if (s.ok()) { + FormatFromString(data_format_str, &data_format); + } else { + data_format = FORMAT_NHWC; + } + + const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); + + TF_RETURN_IF_ERROR( + CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); + + std::vector strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + if (strides.size() != 4) { + return errors::InvalidArgument( + "MaxPool requires the stride attribute to contain 4 values, but got: ", + strides.size()); + } + + std::vector kernel_sizes; + TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); + if (kernel_sizes.size() != 4) { + return errors::InvalidArgument( + "MaxPool requires the ksize attribute to contain 4 values, but got: ", + kernel_sizes.size()); + } + + int32 stride_depth = GetTensorDim(strides, data_format, 'C'); + int32 stride_rows = GetTensorDim(strides, data_format, 'H'); + int32 stride_cols = GetTensorDim(strides, data_format, 'W'); + int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); + int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); + int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); + + constexpr int num_spatial_dims = 2; + DimensionHandle batch_size_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'N')); + DimensionHandle in_rows_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'H')); + DimensionHandle in_cols_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'W')); + DimensionHandle in_depth_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'C')); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + ShapeHandle output_shape; + DimensionHandle output_rows, output_cols, output_depth; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); + + TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, + {output_rows, output_cols}, + output_depth, &output_shape, c)); + + c->set_output(0, output_shape); + return Status::OK(); +} + +Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { + string data_format_str; + TensorFormat data_format; + Status s = c->GetAttr("data_format", &data_format_str); + if (s.ok()) { + FormatFromString(data_format_str, &data_format); + } else { + data_format = FORMAT_NHWC; + } + + const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 
5 : 4; + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); + + TF_RETURN_IF_ERROR( + CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); + + std::vector kernel_sizes; + std::vector strides; + + if (c->num_inputs() + 2 == num_inputs) { + TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); + + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + } else { + // Verify shape of ksize and strides input. + ShapeHandle size; + DimensionHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused)); + + const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2); + if (kernel_sizes_tensor == nullptr) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + } + kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements()); + auto kernel_sizes_vec = kernel_sizes_tensor->flat(); + std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), + kernel_sizes.begin()); + + const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1); + if (strides_tensor == nullptr) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + } + strides.resize(strides_tensor->shape().num_elements()); + auto strides_vec = strides_tensor->flat(); + std::copy_n(&strides_vec(0), strides.size(), strides.begin()); + } + + if (strides.size() != 4) { + return errors::InvalidArgument( + "MaxPool requires the stride attribute to contain 4 values, but " + "got: ", + strides.size()); + } + if (kernel_sizes.size() != 4) { + return errors::InvalidArgument( + "MaxPool requires the ksize attribute to contain 4 values, but got: ", + kernel_sizes.size()); + } + + int32 stride_depth = GetTensorDim(strides, data_format, 'C'); + int32 stride_rows = GetTensorDim(strides, data_format, 'H'); + int32 stride_cols = GetTensorDim(strides, data_format, 'W'); + int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); + int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); + int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); + + constexpr int num_spatial_dims = 2; + DimensionHandle batch_size_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'N')); + DimensionHandle in_rows_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'H')); + DimensionHandle in_cols_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'W')); + DimensionHandle in_depth_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'C')); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + ShapeHandle output_shape; + DimensionHandle output_rows, output_cols, output_depth; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( + c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); + + TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, + {output_rows, output_cols}, + output_depth, &output_shape, c)); + + c->set_output(0, output_shape); + return Status::OK(); +} + +Status Pool3DShape(shape_inference::InferenceContext* c) { + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, 
&input_shape)); + + string data_format; + Status s = c->GetAttr("data_format", &data_format); + + std::vector strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + if (strides.size() != 5) { + return errors::InvalidArgument( + "Pool3D ops require the stride attribute to contain 5 values, but " + "got: ", + strides.size()); + } + + std::vector kernel_sizes; + TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); + if (kernel_sizes.size() != 5) { + return errors::InvalidArgument( + "Pool3D requires the ksize attribute to contain 5 values, but got: ", + kernel_sizes.size()); + } + + int32 stride_planes, stride_rows, stride_cols; + int32 kernel_planes, kernel_rows, kernel_cols; + + if (s.ok() && data_format == "NCDHW") { + // Convert input_shape to NDHWC. + auto dim = [&](char dimension) { + return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension)); + }; + input_shape = + c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}}); + stride_planes = strides[2]; + stride_rows = strides[3]; + stride_cols = strides[4]; + kernel_planes = kernel_sizes[2]; + kernel_rows = kernel_sizes[3]; + kernel_cols = kernel_sizes[4]; + } else { + stride_planes = strides[1]; + stride_rows = strides[2]; + stride_cols = strides[3]; + kernel_planes = kernel_sizes[1]; + kernel_rows = kernel_sizes[2]; + kernel_cols = kernel_sizes[3]; + } + + DimensionHandle batch_size_dim = c->Dim(input_shape, 0); + DimensionHandle in_planes_dim = c->Dim(input_shape, 1); + DimensionHandle in_rows_dim = c->Dim(input_shape, 2); + DimensionHandle in_cols_dim = c->Dim(input_shape, 3); + DimensionHandle output_depth_dim = c->Dim(input_shape, 4); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + // TODO(mrry,shlens): Raise an error if the stride would cause + // information in the input to be ignored. This will require a change + // in the kernel implementation. 
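+  // Informal example for the 3-D case: with VALID padding, 3x3x3 spatial dims,
+  // a 1x1x1 kernel and stride 2 in every spatial dimension, each output dim is
+  // ceil((3 - 1 + 1) / 2) = 2.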
+  DimensionHandle output_planes, output_rows, output_cols;
+  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+      c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
+  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
+  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
+      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
+
+  ShapeHandle output_shape;
+  if (data_format == "NCDHW") {
+    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
+                                 output_planes, output_rows, output_cols});
+  } else {
+    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
+                                 output_cols, output_depth_dim});
+  }
+
+  c->set_output(0, output_shape);
+  return Status::OK();
+}
+
+Status UnknownShape(shape_inference::InferenceContext* c) {
+  for (int i = 0; i < c->num_outputs(); ++i) {
+    c->set_output(i, c->UnknownShape());
+  }
+  return Status::OK();
+}
+
+template <typename T>
+Status ReductionShapeHelper(const Tensor* reduction_indices_t,
+                            const int32 input_rank,
+                            std::set<int64>& true_indices) {
+  auto reduction_indices = reduction_indices_t->flat<T>();
+  for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
+    const T reduction_index = reduction_indices(i);
+    if (reduction_index < -input_rank || reduction_index >= input_rank) {
+      return errors::InvalidArgument("Invalid reduction dimension ",
+                                     reduction_index, " for input with ",
+                                     input_rank, " dimensions.");
+    }
+
+    auto wrapped_index = reduction_index;
+    if (wrapped_index < 0) {
+      wrapped_index += input_rank;
+    }
+
+    true_indices.insert(wrapped_index);
+  }
+  return Status::OK();
+}
+
+Status ReductionShape(InferenceContext* c) {
+  ShapeHandle input = c->input(0);
+
+  ShapeHandle indices;
+  // Older versions of TensorFlow accidentally allowed higher rank tensors like
+  // [[1,2]] or [[1],[2]] to represent axis=[1,2].
+  if (c->graph_def_version() < 21) {
+    indices = c->input(1);
+  } else {
+    TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
+  }
+
+  bool keep_dims;
+  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
+
+  const Tensor* reduction_indices_t = c->input_tensor(1);
+  if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
+    // If we do not have the reduction values at runtime, or the
+    // rank of the input, we don't know the output shape.
+
+    if (keep_dims && c->RankKnown(input)) {
+      // output rank matches input rank if <keep_dims> is true.
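+      // For example, reducing a [2, 3, 5] input over axes that are not yet
+      // known at graph-construction time, with keep_dims=true, still yields a
+      // rank-3 output whose dimensions are all unknown.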
+      c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
+      return Status::OK();
+    } else {
+      return shape_inference::UnknownShape(c);
+    }
+  }
+
+  const int32 input_rank = c->Rank(input);
+  std::set<int64> true_indices;
+  if (reduction_indices_t->dtype() == DataType::DT_INT32) {
+    TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
+                                                   input_rank, true_indices));
+  } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
+    TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t,
+                                                   input_rank, true_indices));
+  } else {
+    return errors::InvalidArgument(
+        "reduction_indices can only be int32 or int64");
+  }
+
+  std::vector<DimensionHandle> dims;
+  for (int i = 0; i < input_rank; ++i) {
+    if (true_indices.count(i) > 0) {
+      if (keep_dims) {
+        dims.emplace_back(c->MakeDim(1));
+      }
+    } else {
+      dims.emplace_back(c->Dim(input, i));
+    }
+  }
+
+  c->set_output(0, c->MakeShape(dims));
+  return Status::OK();
+}
+
+Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
+                         int end_value_index, int dim_index) {
+  ShapeHandle unused;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
+  const Tensor* concat_dim_t = c->input_tensor(dim_index);
+  if (concat_dim_t == nullptr) {
+    // Return an unknown shape with same rank as inputs, or an unknown rank
+    // if no input's rank is known.
+
+    // Find rank.
+    int32 rank = InferenceContext::kUnknownRank;
+    for (int i = start_value_index; i < end_value_index; ++i) {
+      if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
+      if (rank != InferenceContext::kUnknownRank) {
+        break;
+      }
+    }
+    if (rank == InferenceContext::kUnknownRank) {
+      c->set_output(0, c->UnknownShape());
+      return Status::OK();
+    } else if (rank == 0) {
+      return errors::InvalidArgument(
+          "Can't concatenate scalars (use tf.stack instead)");
+    } else {
+      for (int i = start_value_index; i < end_value_index; ++i) {
+        // Check that all the inputs are of the correct rank.
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
+      }
+    }
+    // Build result of different unknown dims.
+    std::vector<DimensionHandle> dims;
+    dims.reserve(rank);
+    for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
+    c->set_output(0, c->MakeShape(dims));
+    return Status::OK();
+  }
+
+  // Merge all the non-concat dims, and sum the concat dim to make an output
+  // shape.
+  const int32 concat_dim = concat_dim_t->scalar<int32>()();
+
+  // Minimum required number of dimensions.
+  const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;
+
+  ShapeHandle output_before;
+  ShapeHandle output_after;
+
+  ShapeHandle input = c->input(end_value_index - 1);
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
+  TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
+  DimensionHandle output_middle = c->Dim(input, concat_dim);
+  if (concat_dim == -1) {
+    output_after = c->Scalar();  // no dimensions.
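+    // concat_dim == -1 addresses the last dimension, so nothing can follow it
+    // and the suffix shape is a scalar.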
+ } else { + TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after)); + } + + for (int i = end_value_index - 2; i >= start_value_index; --i) { + ShapeHandle before; + ShapeHandle after; + input = c->input(i); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input)); + TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before)); + DimensionHandle middle = c->Dim(input, concat_dim); + if (concat_dim == -1) { + after = c->Scalar(); + } else { + TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after)); + } + + TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before)); + TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle)); + TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after)); + } + + ShapeHandle s; + TF_RETURN_IF_ERROR( + c->Concatenate(output_before, c->Vector(output_middle), &s)); + TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s)); + c->set_output(0, s); + return Status::OK(); +} + +Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) { + return ConcatShapeHelper(c, 1 /* start_value_index */, + 1 + num_inputs_to_concat /* end_value_index */, + 0 /* dim_index */); +} + +Status ConcatV2Shape(InferenceContext* c) { + return ConcatShapeHelper(c, 0 /* start_value_index */, + c->num_inputs() - 1 /* end_value_index */, + c->num_inputs() - 1 /* dim_index */); +} + +Status BroadcastBinaryOpShapeFn(InferenceContext* c) { + ShapeHandle shape_x = c->input(0); + ShapeHandle shape_y = c->input(1); + if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + } + const int32 rank_x = c->Rank(shape_x); + const int32 rank_y = c->Rank(shape_y); + const int32 rank_out = std::max(rank_x, rank_y); + + // To compute the broadcast dimensions, we zip together shape_x and shape_y + // and + // pad with 1 to make them the same length. + std::vector dims; + DimensionHandle dim_one; + if (rank_x != rank_y) dim_one = c->MakeDim(1); + for (int i = 0; i < rank_out; ++i) { + const auto dim_x = i < (rank_out - rank_x) + ? dim_one + : c->Dim(shape_x, i - (rank_out - rank_x)); + const bool dim_y_is_one = (i < (rank_out - rank_y)); + const auto dim_y = + dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y)); + if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) { + // One or both dimensions is unknown. + // + // - If either dimension is greater than 1, we assume that the program is + // correct, and the other dimension will be broadcast to match it. + // TODO(cwhipkey): For shape inference, if we eliminate the shape checks + // in C++ op code, we must still assert that the unknown dim is either 1 + // or the same as the known dim. + // - If either dimension is 1, the other dimension is the output. + if (c->Value(dim_x) > 1) { + dims.push_back(dim_x); + } else if (c->Value(dim_y) > 1) { + dims.push_back(dim_y); + } else if (c->Value(dim_x) == 1) { + dims.push_back(dim_y); + } else if (c->Value(dim_y) == 1) { + dims.push_back(dim_x); + } else if (dim_y.SameHandle(dim_x)) { + dims.push_back(dim_x); + } else { + dims.push_back(c->UnknownDim()); + } + } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) { + if (c->Value(dim_x) == 1 && !dim_y_is_one) { + // We will broadcast dim_x to dim_y. + dims.push_back(dim_y); + } else { + DCHECK_EQ(c->Value(dim_y), 1); + // We will broadcast dim_y to dim_x. 
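+        // e.g. dim_x = 5 and dim_y = 1 broadcast to an output dimension of 5.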
+ dims.push_back(dim_x); + } + } else { + DimensionHandle dim; + TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim)); + dims.push_back(dim); + } + } + + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); +} + +Status RandomShape(shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); + c->set_output(0, out); + return Status::OK(); +} + +Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, + ShapeHandle values_shape, ShapeHandle shape_shape) { + // Validate ranks. + ShapeHandle unused_shape; + TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape)); + + // Number of elements in indices and values must match. + DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0); + if (c->ValueKnown(num_index_elements_dim)) { + DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0); + if (c->ValueKnown(num_values_elements_dim)) { + int64 num_index_elements = c->Value(num_index_elements_dim); + int64 num_values_elements = c->Value(num_values_elements_dim); + if (num_index_elements != num_values_elements) { + return errors::InvalidArgument("Number of elements in index (", + num_index_elements, ") and values (", + num_values_elements, ") do not match."); + } + } + } + + // Rank embedded in indices must match shape. + DimensionHandle index_rank_dim = c->Dim(indices_shape, 1); + if (c->ValueKnown(index_rank_dim)) { + DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0); + if (c->ValueKnown(shape_rank_dim)) { + int64 index_rank = c->Value(index_rank_dim); + int32 shape_rank = c->Value(shape_rank_dim); + if (index_rank != shape_rank) { + return errors::InvalidArgument("Index rank (", index_rank, + ") and shape rank (", shape_rank, + ") do not match."); + } + } + } + + return Status::OK(); +} + +Status ScatterNdUpdateShape(InferenceContext* c) { + ShapeHandle input_shape = c->input(0); + if (c->input_handle_shapes_and_types(0) != nullptr) { + input_shape = (*c->input_handle_shapes_and_types(0))[0].shape; + } + ShapeHandle indices_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); + ShapeHandle updates_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); + + if (c->Value(c->NumElements(input_shape)) == 0 && + (c->Value(c->NumElements(indices_shape)) > 0 || + c->Value(c->NumElements(updates_shape)) > 0)) { + return errors::InvalidArgument( + "Indices and updates specified for empty output shape"); + } + + if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { + const int64 num_outer_dims = c->Rank(indices_shape) - 1; + const DimensionHandle index_size = c->Dim(indices_shape, -1); + + // We can only do more validation if the last dimension of indices + // is a known value. 
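+    // Informal example: input.shape = [10, 5], indices.shape = [4, 1] and
+    // updates.shape = [4, 5] pass both checks below: the outer dims [4] agree
+    // and the updates suffix [5] matches the input suffix [5].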
+ if (c->ValueKnown(index_size)) { + const int64 ix = c->Value(index_size); + ShapeHandle unused; + ShapeHandle prefix_indices; + TF_RETURN_IF_ERROR( + c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices)); + ShapeHandle prefix_updates; + TF_RETURN_IF_ERROR( + c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates)); + + Status s = c->Merge(prefix_indices, prefix_updates, &unused); + if (!s.ok()) { + return errors::InvalidArgument( + "The outer ", num_outer_dims, " dimensions of indices.shape=", + c->DebugString(indices_shape), " must match the outer ", + num_outer_dims, " dimensions of updates.shape=", + c->DebugString(updates_shape), ": ", s.error_message()); + } + + ShapeHandle input_suffix; + TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix)); + ShapeHandle suffix_updates; + TF_RETURN_IF_ERROR( + c->Subshape(updates_shape, num_outer_dims, &suffix_updates)); + s = c->Merge(input_suffix, suffix_updates, &unused); + if (!s.ok()) { + return errors::InvalidArgument( + "The inner ", c->Rank(input_shape) - ix, + " dimensions of input.shape=", c->DebugString(input_shape), + " must match the inner ", c->Rank(updates_shape) - num_outer_dims, + " dimensions of updates.shape=", c->DebugString(updates_shape), + ": ", s.error_message()); + } + } + } + + if (c->input_handle_shapes_and_types(0) == nullptr) { + c->set_output(0, input_shape); + } + return Status::OK(); +} + +Status ExplicitShape(InferenceContext* c) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + ShapeHandle output_shape; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape)); + c->set_output(0, output_shape); + return Status::OK(); +} + +} // namespace shape_inference + +} // namespace tensorflow diff --git a/common_shape_fns.h b/common_shape_fns.h new file mode 100644 index 0000000000000000000000000000000000000000..c0deb473a25cf19b99ae79903c1a2014b6e378f7 --- /dev/null +++ b/common_shape_fns.h @@ -0,0 +1,290 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_ + +#include + +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +// GetWindowedOutputSize(): Given an input tensor, kernel, stride and padding +// type, the function computes the output and padding dimensions. +// +// For example, ignoring batches or multiple features, a 1D convolution +// takes as input a 1D tensor of shape (H), and convolves it with a filter of +// shape (K). +// +// It also takes in a few additional parameters: +// +// Stride (S): the stride with which we apply the filters. This is the offset +// between locations where we apply the filters. A larger stride +// means that the output will be spatially smaller. 
+// +// Padding (P): the padding we apply to the input tensor along each +// dimension. This is usually used to make sure that the spatial dimensions +// do not shrink when we progress with convolutions. Two types of padding are +// often used: +// SAME: the pad value is computed so that the output will have size H/S. +// VALID: no padding is carried out. +// The padded area is zero-filled. +// +// The output dimensions for convolution and many other operations, when given +// all the parameters above, are as follows: +// - When Padding = SAME: the output size is (H'), where +// H' = ceil(float(H) / float(S)) +// where ceil is the ceiling function. The number of padded cells +// is computed as: +// Pc = ((H' - 1) * S + K - H) / 2 +// When the stride is 1, the expression simplifies to +// H' = H, Pc = (K-1)/2. +// This is where SAME comes from - the output has the same size as the input +// has. +// +// - When Padding = VALID: the output size is computed as +// H' = ceil(float(H - K + 1) / float(S)) +// and the number of padded cells is always zero. +// When the stride is 1, the expression simplifies to +// H' = H-K+1. +// +// For convolution, mathematically, the output value at location (r') +// is the inner product of two vectors: the chunk of input at +// ((r'*S-Pr) : (r'*S-Pr+K)), +// and the filter. +// +// For 2D and 3D convolutions, the spatial dimensions are orthogonal, so the +// size and padding of each spatial dimension can be computed by calling +// GetWindowedOutputSize separately for each dimension. +// +Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride, + Padding padding_type, int64* output_size, + int64* padding_size); + +// The V2 version computes the same outputs with arbitrary dilation_rate. +// The output dimensions are computed as follows: +// - When adding dilation_rate (D), we compute an effective filter size (K'): +// K' = (K - 1) * D + 1 +// - When Padding = SAME: the output size is (H'), where +// H' = ceil(float(H) / float(S)) +// where ceil is the ceiling function. The number of padded cells +// is computed as: +// Pc = ((H' - 1) * S + K' - H) / 2 +// When the stride is 1, the expression simplifies to +// H' = H, Pc = (K'-1)/2. +// This is where SAME comes from - the output has the same size as the input +// has. +// +// - When Padding = VALID: the output size is computed as +// H' = ceil(float(H - K' + 1) / float(S)) +// and the number of padded cells is always zero. +// When the stride is 1, the expression simplifies to +// H' = H-K'+1. +// +// TODO(b/67112639): Merge V2 versions and the original versions eventually. +Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size, + int64 dilation_rate, int64 stride, + Padding padding_type, int64* output_size, + int64* padding_size); + +// Returns the same output dimensions as in GetWindowedOutputSize, but returns +// verbose padding dimensions (before/after). Any excess padding +// (caused by an odd padding size value) is added to the 'padding_after' +// dimension. +Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size, + int64 stride, Padding padding_type, + int64* output_size, int64* padding_before, + int64* padding_after); + +// The V2 version computes the same outputs with arbitrary dilation_rate. For +// detailed equations, refer to the comments for GetWindowedOutputSizeV2(). 
+Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
+                                      int64 dilation_rate, int64 stride,
+                                      Padding padding_type, int64* output_size,
+                                      int64* padding_before,
+                                      int64* padding_after);
+
+// Given an input tensor, kernel, stride and padding type, populates the 3D size
+// of the output tensor and padding to be applied to the input tensor at the
+// lower end of every dimension. Use for 3D convolutions, where the input data
+// is padded with zeros, as well as for 3D avg/max pooling, where the input data
+// is padded with invalid values that are not considered for pooling.
+Status Get3dOutputSize(const std::array<int64, 3>& input,
+                       const std::array<int64, 3>& window,
+                       const std::array<int64, 3>& strides,
+                       Padding padding_type, std::array<int64, 3>* output_ptr,
+                       std::array<int64, 3>* padding_ptr);
+
+// The V2 version computes the same outputs with arbitrary dilation_rate. For
+// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
+Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
+                         const std::array<int64, 3>& window,
+                         const std::array<int64, 3>& dilations,
+                         const std::array<int64, 3>& strides,
+                         Padding padding_type, std::array<int64, 3>* output_ptr,
+                         std::array<int64, 3>* padding_ptr);
+
+namespace shape_inference {
+
+// Like GetWindowedOutputSize, but deals with DimensionHandles.
+Status GetWindowedOutputSizeFromDims(InferenceContext* c,
+                                     DimensionHandle input_size,
+                                     DimensionOrConstant filter_size,
+                                     int64 stride, Padding padding_type,
+                                     DimensionHandle* output_size);
+
+// The V2 version computes the same outputs with arbitrary dilation_rate. For
+// detailed equations, refer to the comments for GetWindowedOutputSizeV2().
+Status GetWindowedOutputSizeFromDimsV2(InferenceContext* c,
+                                       DimensionHandle input_size,
+                                       DimensionOrConstant filter_size,
+                                       int64 dilation_rate, int64 stride,
+                                       Padding padding_type,
+                                       DimensionHandle* output_size);
+
+// Transfers shape of input(0) to output(0).
+Status UnchangedShape(shape_inference::InferenceContext* c);
+
+// Transfers shape of input(0) to output(0), after asserting its rank is <rank>.
+inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
+                                     int32 rank) {
+  ShapeHandle out;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out));
+  c->set_output(0, out);
+  return Status::OK();
+}
+
+// Transfers shape of input(0) to output(0), after asserting its rank >= <rank>.
+inline Status UnchangedShapeWithRankAtLeast(
+    shape_inference::InferenceContext* c, int32 rank) {
+  ShapeHandle out;
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out));
+  c->set_output(0, out);
+  return Status::OK();
+}
+
+// Transfers shape of input(0) to output(0), after asserting its rank <= <rank>.
+inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c,
+                                           int32 rank) {
+  ShapeHandle out;
+  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out));
+  c->set_output(0, out);
+  return Status::OK();
+}
+
+// Shape function for use with ops with no outputs.
+inline Status NoOutputs(shape_inference::InferenceContext* c) {
+  return Status::OK();
+}
+
+// Shape function for ops that output a single scalar value.
+inline Status ScalarShape(shape_inference::InferenceContext* c) {
+  c->set_output(0, c->Scalar());
+  return Status::OK();
+}
+
+// Shape function for binary ops where both inputs and the output match.
+inline Status MergeBothInputsShapeFn(InferenceContext* c) {
+  ShapeHandle out;
+  TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
+  c->set_output(0, out);
+  return Status::OK();
+}
+
+// Returns a new shape with the specified dims arranged in the specified
+// format.
The returned value is owned by this context. +// Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth. +Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, + const std::vector& spatial, + DimensionOrConstant C, ShapeHandle* out, + shape_inference::InferenceContext* context); + +// Shape function for MatMul-like operations. +Status MatMulShape(shape_inference::InferenceContext* c); + +// Shape function for BiasAdd-like operations. +Status BiasAddShape(shape_inference::InferenceContext* c); + +// Shape function for BiasAddGrad-like operations. +Status BiasAddGradShape(shape_inference::InferenceContext* c); + +// Shape function for Conv2D-like operations. +Status Conv2DShape(shape_inference::InferenceContext* c); + +// Shape function for Conv3D-like operations. +Status Conv3DShape(shape_inference::InferenceContext* c); + +// Shape function for DepthwiseConv2D-like operations. +Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c); + +// Shape function for AvgPool-like operations. +Status AvgPoolShape(shape_inference::InferenceContext* c); + +// Shape function for FusedBatchNorm and FusedBatchNormV2 operations. +Status FusedBatchNormShape(shape_inference::InferenceContext* c); + +// Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations. +Status FusedBatchNormGradShape(shape_inference::InferenceContext* c); + +// Shape function for MaxPool-like operations. +Status MaxPoolShape(shape_inference::InferenceContext* c); + +// Shape function for MaxPoolV2-like operations. +Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs); + +// Shape function for 3D Pooling operations. +Status Pool3DShape(shape_inference::InferenceContext* c); + +// Shape function for use with ops whose output shapes are unknown. +Status UnknownShape(shape_inference::InferenceContext* c); + +// Shape function for reduction operations. +Status ReductionShape(shape_inference::InferenceContext* c); + +// Shape function for concat operations. +// is the number of inputs to concatenate and are taken +// from inputs +// [1,num_inputs_to_concat] of the op. Input 0 is the concat_dim input. +Status ConcatShape(shape_inference::InferenceContext* c, + int num_inputs_to_concat); + +// Shape function for concat operations. +Status ConcatV2Shape(shape_inference::InferenceContext* c); + +// Shape function for binary operators that broadcast their inputs. +// Tested by ops/math_ops_test.cc. +Status BroadcastBinaryOpShapeFn(InferenceContext* c); + +// Shape function for random operations. +Status RandomShape(shape_inference::InferenceContext* c); + +// Validates the 3 component tensors of a sparse tensor have the proper +// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. +Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, + ShapeHandle values_shape, ShapeHandle shape_shape); + +// Shape function for ScatterNd update/add/sub/... operations. +Status ScatterNdUpdateShape(InferenceContext* c); + +// Shape function for ops with an explicit "shape" attribute. +Status ExplicitShape(InferenceContext* c); + +} // namespace shape_inference + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_ diff --git a/common_shape_fns_test.cc b/common_shape_fns_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f3e5ad45731750bfd73181c41cd029f23aab55f --- /dev/null +++ b/common_shape_fns_test.cc @@ -0,0 +1,1131 @@ +/* Copyright 2016 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace shape_inference { + +namespace { + +PartialTensorShape S(std::initializer_list dims) { + return PartialTensorShape(dims); +} + +PartialTensorShape Unknown() { return PartialTensorShape(); } + +OpDef MakeOpDef(int num_inputs, int num_outputs) { + OpRegistrationData op_reg_data; + OpDefBuilder b("dummy"); + for (int i = 0; i < num_inputs; ++i) { + b.Input(strings::StrCat("i", i, ": float")); + } + for (int i = 0; i < num_outputs; ++i) { + b.Output(strings::StrCat("o", i, ": float")); + } + CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok()); + return op_reg_data.op_def; +} + +} // namespace + +TEST(CommonShapeFnsTest, NoOutputShapeTest) { + OpRegistrationData op_reg_data; + TF_CHECK_OK(OpDefBuilder("Assert") + .Input("condition: bool") + .Input("data: float") + .Finalize(&op_reg_data)); + OpDef op_def = op_reg_data.op_def; + + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("test", "Assert") + .Input("condition", 0, DT_BOOL) + .Input({{"data", 0, DT_FLOAT}}) + .Finalize(&def)); + + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({}), S({10})}, {}, + {}, {}); + TF_EXPECT_OK(NoOutputs(&c)); + EXPECT_EQ(0, c.num_outputs()); +} + +TEST(CommonShapeFnsTest, ScalarShapeTest) { + OpRegistrationData op_reg_data; + TF_CHECK_OK(OpDefBuilder("L2Loss") + .Input("t: float") + .Output("t: float") + .Finalize(&op_reg_data)); + OpDef op_def = op_reg_data.op_def; + + NodeDef def; + TF_CHECK_OK( + NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def)); + + { + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {}); + TF_EXPECT_OK(ScalarShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ(0, c.Rank(output)); + } + + { + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({1, 23, 4, 4, 2})}, {}, {}, {}); + TF_EXPECT_OK(ScalarShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ(0, c.Rank(output)); + } +} + +TEST(CommonShapeFnsTest, MatMulShapeTest) { + OpRegistrationData op_reg_data; + TF_CHECK_OK(OpDefBuilder("MatMul") + .Input("a: float") + .Input("b: float") + .Output("c: float") + .Attr("transpose_a:bool=false") + .Attr("transpose_b:bool=false") + .Finalize(&op_reg_data)); + OpDef op_def = op_reg_data.op_def; + + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("test", "MatMul") + .Input("a", 0, DT_FLOAT) + .Input("b", 0, DT_FLOAT) + .Attr("transpose_a", false) + .Attr("transpose_b", false) + .Finalize(&def)); + + { + InferenceContext 
c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 3}), S({3, 4})}, {}, {}, {}); + TF_EXPECT_OK(MatMulShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ(2, c.Value(c.Dim(output, 0))); + EXPECT_EQ(4, c.Value(c.Dim(output, 1))); + } + + { + // Unknown inner dimension for one + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, -1}), S({3, 4})}, {}, {}, {}); + TF_EXPECT_OK(MatMulShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ(2, c.Value(c.Dim(output, 0))); + EXPECT_EQ(4, c.Value(c.Dim(output, 1))); + } + + { + // Invalid rank. + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2}), S({3, 4})}, + {}, {}, {}); + auto s = MatMulShape(&c); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE( + StringPiece(s.ToString()) + .contains("Invalid argument: Shape must be rank 2 but is rank 1")); + } + + { + // Unknown outer dimension + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 3}), S({3, -1})}, {}, {}, {}); + TF_EXPECT_OK(MatMulShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ(2, c.Value(c.Dim(output, 0))); + EXPECT_FALSE(c.ValueKnown(c.Dim(output, 1))); + } + + { + // Inner shapes not compatible + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 5}), S({3, 4})}, {}, {}, {}); + auto s = MatMulShape(&c); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE( + StringPiece(s.ToString()) + .contains( + "Invalid argument: Dimensions must be equal, but are 5 and 3")); + } + + { + // Inner shapes not compatible + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {}); + auto s = MatMulShape(&c); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE( + StringPiece(s.ToString()) + .contains("Invalid argument: Shape must be rank 2 but is rank 3")); + } + + { + // transpose_a + TF_CHECK_OK(NodeDefBuilder("test", "MatMul") + .Input("a", 0, DT_FLOAT) + .Input("b", 0, DT_FLOAT) + .Attr("transpose_a", true) + .Attr("transpose_b", false) + .Attr("type", DT_FLOAT) + .Finalize(&def)); + + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({3, 2}), S({3, 4})}, {}, {}, {}); + auto s = MatMulShape(&c); + ShapeHandle output = c.output(0); + EXPECT_EQ(2, c.Value(c.Dim(output, 0))); + EXPECT_EQ(4, c.Value(c.Dim(output, 1))); + } + + { + // transpose_b + TF_CHECK_OK(NodeDefBuilder("test", "MatMul") + .Input("a", 0, DT_FLOAT) + .Input("b", 0, DT_FLOAT) + .Attr("transpose_a", false) + .Attr("transpose_b", true) + .Attr("type", DT_FLOAT) + .Finalize(&def)); + + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 3}), S({4, 3})}, {}, {}, {}); + auto s = MatMulShape(&c); + ShapeHandle output = c.output(0); + EXPECT_EQ(2, c.Value(c.Dim(output, 0))); + EXPECT_EQ(4, c.Value(c.Dim(output, 1))); + } +} + +TEST(CommonShapeFnsTest, BiasAddShapeTest) { + OpRegistrationData op_reg_data; + TF_CHECK_OK(OpDefBuilder("BiasAdd") + .Input("a: float") + .Input("b: float") + .Output("c: float") + .Finalize(&op_reg_data)); + + OpDef op_def = op_reg_data.op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd") + .Input("a", 0, DT_FLOAT) + .Input("b", 0, DT_FLOAT) + .Finalize(&def)); + + { + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 10}), S({10})}, {}, {}, {}); + TF_EXPECT_OK(BiasAddShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ(2, c.Value(c.Dim(output, 0))); + EXPECT_EQ(10, c.Value(c.Dim(output, 1))); + } + + { + // Unknown ranks. 
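+    // Neither operand has a known rank, so BiasAdd should report an output of
+    // unknown rank rather than failing.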
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {Unknown(), Unknown()}, {}, {}, {}); + TF_EXPECT_OK(BiasAddShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_FALSE(c.RankKnown(output)); + } + + { + // Rank > 2 + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {}); + TF_EXPECT_OK(BiasAddShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output)); + } + + { + // NCHW format + TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd") + .Input("a", 0, DT_FLOAT) + .Input("b", 0, DT_FLOAT) + .Attr("data_format", "NCHW") + .Finalize(&def)); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 3, 4, 5}), S({3})}, {}, {}, {}); + TF_EXPECT_OK(BiasAddShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ("[2,3,4,5]", c.DebugString(output)); + } + + { + // NCHW format with high input rank + TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd") + .Input("a", 0, DT_FLOAT) + .Input("b", 0, DT_FLOAT) + .Attr("data_format", "NCHW") + .Finalize(&def)); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {}); + TF_EXPECT_OK(BiasAddShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output)); + } + + { + // NCHW format with input rank 3 + TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd") + .Input("a", 0, DT_FLOAT) + .Input("b", 0, DT_FLOAT) + .Attr("data_format", "NCHW") + .Finalize(&def)); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({10, 11, 12}), S({10})}, {}, {}, {}); + TF_EXPECT_OK(BiasAddShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ("[10,11,12]", c.DebugString(output)); + } + + { + // Input rank not high enough + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3}), S({3})}, {}, + {}, {}); + EXPECT_FALSE(BiasAddShape(&c).ok()); + } + + { + // NCHW rank not high enough + TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd") + .Input("a", 0, DT_FLOAT) + .Input("b", 0, DT_FLOAT) + .Attr("data_format", "NCHW") + .Finalize(&def)); + // NCHW format + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3}), S({3})}, + {}, {}, {}); + EXPECT_FALSE(BiasAddShape(&c).ok()); + } +} + +TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { + OpRegistrationData op_reg_data; + TF_CHECK_OK(OpDefBuilder("BiasAddGrad") + .Input("a: float") + .Output("b: float") + .Finalize(&op_reg_data)); + + OpDef op_def = op_reg_data.op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad") + .Input("a", 0, DT_FLOAT) + .Finalize(&def)); + + { + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 10})}, {}, {}, + {}); + TF_EXPECT_OK(BiasAddGradShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ(10, c.Value(c.Dim(output, 0))); + } + + { + // Rank > 2 + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({5, 7, 2, 10})}, + {}, {}, {}); + TF_EXPECT_OK(BiasAddGradShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ(10, c.Value(c.Dim(output, 0))); + } + + { + // NCHW format + TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad") + .Input("a", 0, DT_FLOAT) + .Attr("data_format", "NCHW") + .Finalize(&def)); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3, 4, 5})}, + {}, {}, {}); + TF_EXPECT_OK(BiasAddGradShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ(3, c.Value(c.Dim(output, 0))); + } + + { + // NCHW format with high input rank + TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad") + .Input("a", 0, DT_FLOAT) + .Attr("data_format", "NCHW") + 
.Finalize(&def)); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {}); + TF_EXPECT_OK(BiasAddGradShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ(3, c.Value(c.Dim(output, 0))); + } + + { + // NCHW format with input rank 3 + TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad") + .Input("a", 0, DT_FLOAT) + .Attr("data_format", "NCHW") + .Finalize(&def)); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({10, 11, 12})}, + {}, {}, {}); + TF_EXPECT_OK(BiasAddGradShape(&c)); + ShapeHandle output = c.output(0); + EXPECT_EQ(10, c.Value(c.Dim(output, 0))); + } + + { + // Input rank not high enough + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {}, + {}); + EXPECT_FALSE(BiasAddGradShape(&c).ok()); + } + + { + // NCHW rank not high enough + TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad") + .Input("a", 0, DT_FLOAT) + .Attr("data_format", "NCHW") + .Finalize(&def)); + // NCHW format + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3})}, {}, {}, + {}); + EXPECT_FALSE(BiasAddGradShape(&c).ok()); + } +} + +TEST(CommonShapeFnsTest, Conv2DShapeTest) { + ShapeInferenceTestOp op("Conv2D"); + auto set_op = [&op](const std::vector& strides, const string& padding, + const string& data_format, const string& filter_format) { + TF_CHECK_OK(NodeDefBuilder("test", "Conv2D") + .Input("input", 0, DT_FLOAT) + .Input("filter", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("padding", padding) + .Attr("data_format", data_format) + .Attr("filter_format", filter_format) + .Finalize(&op.node_def)); + }; + + // Invalid rank for input + INFER_ERROR("must be rank 4", op, "[4,4];[2,1,1,1]"); + // Invalid rank for filter + INFER_ERROR("must be rank 4", op, "[1,4,4,1];[2,1,1]"); + + // Invalid value for strides + set_op({{1, 1, 0, 1}}, "VALID", "NHWC", "HWIO"); + INFER_ERROR("must be > 0", op, "[1,2,2,1];[1,1,1,1]"); + + // 1x1 filter + set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO"); + INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]"); + + // 2x2 filter + set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO"); + INFER_OK(op, "[1,2,2,1];[2,2,1,1]", "[d0_0,1,1,d1_3]"); + + // 3x3 input, 1x1 filter, 2x2 stride + set_op({{1, 2, 2, 1}}, "VALID", "NHWC", "HWIO"); + INFER_OK(op, "[1,3,3,1];[1,1,1,1]", "[d0_0,2,2,d1_3]"); + + // 3x3 input, 1x1 filter, 2x1 stride + set_op({{1, 2, 1, 1}}, "VALID", "NHWC", "HWIO"); + INFER_OK(op, "[1,3,3,1];[1,1,1,1]", "[d0_0,2,3,d1_3]"); + + // 4x4 input, 2x1 filter, 1x2 stride + set_op({{1, 1, 2, 1}}, "VALID", "NHWC", "HWIO"); + INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]"); + + // Unknown dims in the critical fields lead to partial inference. + INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]"); + INFER_OK(op, "[1,?,4,1];[2,1,1,1]", "[d0_0,?,2,d1_3]"); + INFER_OK(op, "[1,4,?,1];[2,1,1,1]", "[d0_0,3,?,d1_3]"); + INFER_OK(op, "[1,4,4,?];[2,1,1,1]", "[d0_0,3,2,d1_3]"); + INFER_OK(op, "[1,4,4,1];[?,1,1,1]", "[d0_0,?,2,d1_3]"); + INFER_OK(op, "[1,4,4,1];[2,?,1,1]", "[d0_0,3,?,d1_3]"); + + // input depths must match. 
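+  // The filter's input-depth dimension must merge with the input's channel
+  // dimension; 10 vs. 10000 below cannot be reconciled, so inference fails.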
+ INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op, + "[1,2,2,10];[1,1,10000,20]"); + + // Tests for NCHW + // 1x1 filter + set_op({{1, 1, 1, 1}}, "VALID", "NCHW", "HWIO"); + INFER_OK(op, "[1,1,2,2];[1,1,1,1]", "[d0_0,d1_3,2,2]"); + + // 2x2 filter + set_op({{1, 1, 1, 1}}, "VALID", "NCHW", "HWIO"); + INFER_OK(op, "[1,1,2,2];[2,2,1,1]", "[d0_0,d1_3,1,1]"); + + // 3x3 input, 1x1 filter, 2x2 stride + set_op({{1, 1, 2, 2}}, "VALID", "NCHW", "HWIO"); + INFER_OK(op, "[1,1,3,3];[1,1,1,1]", "[d0_0,d1_3,2,2]"); + + // 3x3 input, 1x1 filter, 2x1 stride + set_op({{1, 1, 2, 1}}, "VALID", "NCHW", "HWIO"); + INFER_OK(op, "[1,1,3,3];[1,1,1,1]", "[d0_0,d1_3,2,3]"); + + // 4x4 input, 2x1 filter, 1x2 stride + set_op({{1, 1, 1, 2}}, "VALID", "NCHW", "HWIO"); + INFER_OK(op, "[1,1,4,4];[2,1,1,1]", "[d0_0,d1_3,3,2]"); + + // Tests for NCHW_VECT_C + // 1x1 filter + set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); + INFER_OK(op, "[1,1,2,2,4];[4,1,1,1,4]", "[d0_0,1,2,2,4]"); + + // 2x2 filter + set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); + INFER_OK(op, "[1,1,2,2,4];[4,1,2,2,4]", "[d0_0,1,1,1,4]"); + + // 3x3 input, 1x1 filter, 2x2 stride + set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); + INFER_OK(op, "[1,1,3,3,4];[8,1,1,1,4]", "[d0_0,2,2,2,4]"); + + // 3x3 input, 1x1 filter, 2x1 stride + set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); + INFER_OK(op, "[1,1,3,3,4];[4,1,1,1,4]", "[d0_0,1,2,3,4]"); + + // 4x4 input, 2x1 filter, 1x2 stride + set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); + INFER_OK(op, "[1,1,4,4,4];[4,1,2,1,4]", "[d0_0,1,3,2,4]"); + + // Some tests for "SAME" padding + + // 4x4 input, 1x1 filter, 1x1 stride + set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO"); + INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); + + // 3x3 input, 2x2 filter, 1x1 stride + set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO"); + INFER_OK(op, "[1,3,3,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); + + // 4x4 input, 2x2 filter, 2x2 stride + set_op({{1, 2, 2, 1}}, "SAME", "NHWC", "HWIO"); + INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,2,2,d1_3]"); + + // 4x4 input, 2x2 filter, 1x1 stride + set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO"); + INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); + + // With stride 1x1 and SAME, unknown dims don't matter - filter dims except + // for output channels are ignored for output, so all inputs are carried + // through to output. + set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO"); + INFER_OK(op, "[1,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); + INFER_OK(op, "[1,?,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); + INFER_OK(op, "[1,4,?,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); + INFER_OK(op, "[1,4,4,?];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); + INFER_OK(op, "[?,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); + + // With stride != 1, the input HW dims are divided to produce output dims. 
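+  // e.g. SAME padding with H = W = 4 and stride 2 in both spatial dimensions
+  // gives ceil(4 / 2) = 2 output rows and columns, whatever the filter size.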
+ set_op({{1, 2, 2, 1}}, "SAME", "NHWC", "HWIO"); + INFER_OK(op, "[?,4,4,1];[?,?,?,?]", "[d0_0,2,2,d1_3]"); + INFER_OK(op, "[1,?,4,1];[?,?,?,?]", "[d0_0,?,2,d1_3]"); + INFER_OK(op, "[1,4,?,1];[?,?,?,?]", "[d0_0,2,?,d1_3]"); + INFER_OK(op, "[1,4,4,?];[?,?,?,?]", "[d0_0,2,2,d1_3]"); +} + +TEST(CommonShapeFnsTest, Conv2DDilatedShapeTest) { + ShapeInferenceTestOp op("Conv2D"); + auto set_op = [&op](const std::vector& dilations, + const std::vector& strides, const string& padding, + const string& data_format) { + TF_CHECK_OK(NodeDefBuilder("test", "Conv2D") + .Input("input", 0, DT_FLOAT) + .Input("filter", 0, DT_FLOAT) + .Attr("dilations", dilations) + .Attr("strides", strides) + .Attr("padding", padding) + .Attr("data_format", data_format) + .Finalize(&op.node_def)); + }; + + // Invalid rank for dilation + set_op({{1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC"); + INFER_ERROR("contain 4 values", op, "[1,2,2,1];[1,1,1,1]"); + + // Invalid value for dilation + set_op({{1, 0, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC"); + INFER_ERROR("must be >= 1", op, "[1,2,2,1];[1,1,1,1]"); + + // Tests for NHWC + // 1x1 filter, 2x1 dilations, 1x1 strides + set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC"); + INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]"); + + // 1x1 filter, 2x1 dilations, 2x1 strides + set_op({{1, 2, 1, 1}}, {{1, 2, 1, 1}}, "VALID", "NHWC"); + INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,2,4,d1_3]"); + + // 1x1 filter, 2x1 dilations, 2x2 strides + set_op({{1, 2, 1, 1}}, {{1, 2, 2, 1}}, "VALID", "NHWC"); + INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,2,2,d1_3]"); + + // 3x3 filter, 2x1 dilations, 1x1 strides + set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC"); + INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,1,3,d1_3]"); + + // 3x3 filter, 2x1 dilations, 2x1 strides + set_op({{1, 2, 1, 1}}, {{1, 2, 1, 1}}, "VALID", "NHWC"); + INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,1,3,d1_3]"); + + // 3x3 filter, 1x2 dilations, 2x2 strides + set_op({{1, 1, 2, 1}}, {{1, 2, 2, 1}}, "VALID", "NHWC"); + INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,2,1,d1_3]"); + + // Tests for NCHW + // 1x1 filter, 2x1 dilations, 1x1 strides + set_op({{1, 1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,2,2];[1,1,1,1]", "[d0_0,d1_3,2,2]"); + + // 1x1 filter, 2x1 dilations, 2x1 strides + set_op({{1, 1, 2, 1}}, {{1, 1, 2, 1}}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,4,4];[1,1,1,1]", "[d0_0,d1_3,2,4]"); + + // 1x1 filter, 2x1 dilations, 2x2 strides + set_op({{1, 1, 2, 1}}, {{1, 1, 2, 2}}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,4,4];[1,1,1,1]", "[d0_0,d1_3,2,2]"); + + // 3x3 filter, 2x1 dilations, 1x1 strides + set_op({{1, 1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,1,3]"); + + // 3x3 filter, 2x1 dilations, 2x1 strides + set_op({{1, 1, 2, 1}}, {{1, 1, 2, 1}}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,1,3]"); + + // 3x3 filter, 1x2 dilations, 2x2 strides + set_op({{1, 1, 1, 2}}, {{1, 1, 2, 2}}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,2,1]"); + + // Some tests for "SAME" padding + + // 4x4 input, 1x1 filter, 2x1 dilations, 1x1 stride + set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC"); + INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); + + // 3x3 input, 2x2 filter, 2x2 dilations, 1x1 stride + set_op({{1, 2, 2, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC"); + INFER_OK(op, "[1,3,3,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); + + // 4x4 input, 2x2 filter, 1x2 dilations, 2x2 stride + set_op({{1, 1, 2, 1}}, 
{{1, 2, 2, 1}}, "SAME", "NHWC"); + INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,2,2,d1_3]"); + + // 4x4 input, 2x2 filter, 2x2 dilations, 1x1 stride + set_op({{1, 2, 2, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC"); + INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); +} + +TEST(CommonShapeFnsTest, Conv3DShapeTest) { + ShapeInferenceTestOp op("Conv3D"); + auto set_op = [&op](const std::vector& strides, + const string& padding) { + TF_CHECK_OK(NodeDefBuilder("test", "Conv3D") + .Input("input", 0, DT_FLOAT) + .Input("filter", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("padding", padding) + .Finalize(&op.node_def)); + }; + + // 1x1x1 filter + set_op({{1, 1, 1, 1, 1}}, "VALID"); + INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]"); + + // Invalid rank for input + INFER_ERROR("must be rank 5", op, "[4,4];[2,1,1,1]"); + // Invalid rank for filter + INFER_ERROR("must be rank 5", op, "[1,4,4,1];[2,1,1]"); + + // unknown dims in the critical fields give partial inference. + INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]"); + INFER_OK(op, "[1,?,2,2,1];[1,1,1,1,1]", "[d0_0,?,2,2,d1_4]"); + INFER_OK(op, "[1,2,?,2,1];[1,1,1,1,1]", "[d0_0,2,?,2,d1_4]"); + INFER_OK(op, "[1,2,2,?,1];[1,1,1,1,1]", "[d0_0,2,2,?,d1_4]"); + INFER_OK(op, "[1,2,2,2,1];[?,1,1,1,1]", "[d0_0,?,2,2,d1_4]"); + INFER_OK(op, "[1,2,2,2,1];[1,?,1,1,1]", "[d0_0,2,?,2,d1_4]"); + INFER_OK(op, "[1,2,2,2,1];[1,1,?,1,1]", "[d0_0,2,2,?,d1_4]"); + INFER_OK(op, "[1,2,2,2,1];[1,1,1,?,1]", "[d0_0,2,2,2,d1_4]"); + INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,?]", "[d0_0,2,2,2,d1_4]"); + + // input depths must match. + INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op, + "[1,2,2,2,10];[1,1,1,10000,20]"); + + // 2x2x2 filter + set_op({{1, 1, 1, 1, 1}}, "VALID"); + INFER_OK(op, "[1,2,2,2,1];[2,2,2,1,1]", "[d0_0,1,1,1,d1_4]"); + + // 3x3 input, 1x1 filter, 2x2 stride + set_op({{1, 2, 2, 2, 1}}, "VALID"); + INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]"); + + // 3x3 input, 1x1 filter, 2x1x1 stride + set_op({{1, 2, 1, 1, 1}}, "VALID"); + INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,3,3,d1_4]"); + + // 4x4 input, 2x2 filter, 1x1 stride + set_op({{1, 1, 1, 1, 1}}, "SAME"); + INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + + // with SAME, filter doesn't matter except for last dim. + set_op({{1, 1, 1, 1, 1}}, "SAME"); + INFER_OK(op, "[?,4,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + INFER_OK(op, "[1,?,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + INFER_OK(op, "[1,4,?,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + INFER_OK(op, "[1,4,4,?,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + INFER_OK(op, "[1,4,4,4,?];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + INFER_OK(op, "[1,4,4,4,1];[?,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + INFER_OK(op, "[1,4,4,4,1];[2,?,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + INFER_OK(op, "[1,4,4,4,1];[2,2,?,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + INFER_OK(op, "[1,4,4,4,1];[2,2,2,?,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,?]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + + // with SAME, and stride != 1, division happens to produce output. 
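+  // e.g. strides of (2, 3, 4) over 4x9x4 spatial dims give
+  // ceil(4/2) x ceil(9/3) x ceil(4/4) = 2x3x1, as the first check below shows.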
+ set_op({{1, 2, 3, 4, 1}}, "SAME"); + INFER_OK(op, "[1,4,9,4,1];[2,2,2,1,1]", "[d0_0,2,3,1,d1_4]"); + INFER_OK(op, "[?,4,9,4,1];[2,2,2,1,1]", "[d0_0,2,3,1,d1_4]"); + INFER_OK(op, "[1,?,9,4,1];[2,2,2,1,1]", "[d0_0,?,3,1,d1_4]"); + INFER_OK(op, "[1,4,?,4,1];[2,2,2,1,1]", "[d0_0,2,?,1,d1_4]"); + INFER_OK(op, "[1,4,9,?,1];[2,2,2,1,1]", "[d0_0,2,3,?,d1_4]"); + INFER_OK(op, "[1,4,9,4,?];[2,2,2,1,1]", "[d0_0,2,3,1,d1_4]"); + INFER_OK(op, "[1,4,9,4,1];[?,2,2,1,1]", "[d0_0,2,3,1,d1_4]"); + INFER_OK(op, "[1,4,9,4,1];[2,?,2,1,1]", "[d0_0,2,3,1,d1_4]"); + INFER_OK(op, "[1,4,9,4,1];[2,2,?,1,1]", "[d0_0,2,3,1,d1_4]"); + INFER_OK(op, "[1,4,9,4,1];[2,2,2,?,1]", "[d0_0,2,3,1,d1_4]"); + INFER_OK(op, "[1,4,9,4,1];[2,2,2,1,?]", "[d0_0,2,3,1,d1_4]"); +} + +TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) { + ShapeInferenceTestOp op("DepthwiseConv2dNative"); + std::vector strides = {{1, 1, 1, 1}}; + TF_CHECK_OK(NodeDefBuilder("test", "DepthwiseConv2dNative") + .Input("input", 0, DT_FLOAT) + .Input("filter", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("padding", "VALID") + .Attr("data_format", "NHWC") + .Finalize(&op.node_def)); + + // Most of DepthwiseConv2D is implicitly tested by Conv2D, so + // we test only the very-specific differences here. + + // 1x1 filter, depth multiplication + INFER_OK(op, "[1,2,2,3];[1,1,3,4]", "[d0_0,2,2,12]"); + + // Input depths not compatible + INFER_ERROR("Dimensions must be equal, but are 3 and 12", op, + "[1,2,2,3];[1,1,12,4]"); + + // No unknown dims in the critical fields. + INFER_OK(op, "[1,2,2,3];[1,1,3,4]", "[d0_0,2,2,12]"); + INFER_OK(op, "[1,?,2,3];[1,1,3,4]", "[d0_0,?,2,12]"); + INFER_OK(op, "[1,2,?,3];[1,1,3,4]", "[d0_0,2,?,12]"); + INFER_OK(op, "[1,2,2,3];[?,1,3,4]", "[d0_0,?,2,12]"); + INFER_OK(op, "[1,2,2,3];[1,?,3,4]", "[d0_0,2,?,12]"); + INFER_OK(op, "[1,2,2,3];[1,1,?,4]", "[d0_0,2,2,12]"); + INFER_OK(op, "[1,2,2,?];[1,1,?,4]", "[d0_0,2,2,?]"); + INFER_OK(op, "[1,2,2,3];[1,1,3,?]", "[d0_0,2,2,?]"); + + // Test for NCHW format. + TF_CHECK_OK(NodeDefBuilder("test", "DepthwiseConv2dNative") + .Input("input", 0, DT_FLOAT) + .Input("filter", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("padding", "VALID") + .Attr("data_format", "NCHW") + .Finalize(&op.node_def)); + + // 1x1 filter, depth multiplication + INFER_OK(op, "[1,3,2,2];[1,1,3,4]", "[d0_0,12,2,2]"); +} + +TEST(CommonShapeFnsTest, AvgPool2DShapeTest) { + ShapeInferenceTestOp op("AvgPool"); + auto set_op = [&op](const std::vector& strides, + const std::vector& ksizes, const string& padding, + const string& data_format) { + TF_CHECK_OK(NodeDefBuilder("test", "AvgPool") + .Input("input", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("ksize", ksizes) + .Attr("padding", padding) + .Attr("data_format", data_format) + .Finalize(&op.node_def)); + }; + + // Most of the functionality is tested by conv-like shapes, + // so we check the very-specific avgpooling features here. + + // 1x1 filter, 1x1 stride + set_op({1, 1, 1, 1}, {1, 1, 1, 1}, "VALID", "NHWC"); + INFER_OK(op, "[1,2,2,1]", "[d0_0,2,2,d0_3]"); + + // 4x4 input, 2x1 ksize, 1x2 stride + set_op({1, 1, 2, 1}, {1, 2, 1, 1}, "VALID", "NHWC"); + INFER_OK(op, "[1,4,4,1]", "[d0_0,3,2,d0_3]"); + + // 4x4 input, 2x1 ksize, 1x2 stride + // unknown dims in the critical fields lead to partial inference. + // Assumes NHWC format. 
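+  // (With VALID padding the output size works out to
+  // ceil((input - ksize + 1) / stride), so the dimensions that are still
+  // known resolve to 3 and 2 exactly as in the fully-known case above.)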
+ INFER_OK(op, "[1,?,4,1]", "[d0_0,?,2,d0_3]"); + INFER_OK(op, "[1,4,?,1]", "[d0_0,3,?,d0_3]"); + + // 4x4 input, 2x1 ksize, 1x2 stride, NCHW format + set_op({{1, 1, 1, 2}}, {1, 1, 2, 1}, "VALID", "NCHW"); + INFER_OK(op, "[1,1,4,4]", "[d0_0,d0_1,3,2]"); + + // 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C test + set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "VALID", "NCHW_VECT_C"); + INFER_OK(op, "[2,3,5,7,4]", "[d0_0,d0_1,4,6,4]"); + INFER_OK(op, "[5,7,?,?,4]", "[d0_0,d0_1,?,?,4]"); + INFER_OK(op, "[?,?,?,?,4]", "[d0_0,d0_1,?,?,4]"); + INFER_ERROR("Dimension must be 4 but is 3", op, "[2,5,7,11,3]"); + + // Invalid rank for input + INFER_ERROR("Shape must be rank", op, "[4,4]"); +} + +TEST(CommonShapeFnsTest, MaxPool2DShapeTest) { + ShapeInferenceTestOp op("MaxPool"); + auto set_op = [&op](const std::vector& strides, + const std::vector& ksizes, const string& padding, + const string& data_format) { + TF_CHECK_OK(NodeDefBuilder("test", "MaxPool") + .Input("input", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("ksize", ksizes) + .Attr("padding", padding) + .Attr("data_format", data_format) + .Finalize(&op.node_def)); + }; + + // Most of the functionality is tested by conv-like shapes, + // so we check the very-specific maxpooling features here, + // namely depthwise kernel and striding. + + // all 1 strides, depth 2 filter + set_op({1, 1, 1, 1}, {1, 1, 1, 2}, "VALID", "NHWC"); + INFER_OK(op, "[1,2,2,2]", "[d0_0,2,2,1]"); + + // depth 3 stride, 1x1x1 filter, NCHW + set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW"); + INFER_OK(op, "[1,7,5,5]", "[d0_0,3,5,5]"); + + // 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C tests + set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "SAME", "NCHW_VECT_C"); + INFER_OK(op, "[2,3,5,7,4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_OK(op, "[5,7,?,?,4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_OK(op, "[?,?,?,?,4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8]"); +} + +TEST(CommonShapeFnsTest, MaxPoolV22DShapeTest) { + ShapeInferenceTestOp op("MaxPoolV2"); + Tensor ksizes_tensor, strides_tensor; + auto set_op = [&op, &ksizes_tensor, &strides_tensor]( + const std::vector& strides, + const std::vector& ksizes, const string& padding, + const string& data_format) { + TF_CHECK_OK(NodeDefBuilder("test", "MaxPoolV2") + .Input("input", 0, DT_FLOAT) + .Input("ksize", 1, DT_INT32) + .Input("strides", 2, DT_INT32) + .Attr("padding", padding) + .Attr("data_format", data_format) + .Finalize(&op.node_def)); + ksizes_tensor = test::AsTensor(ksizes); + op.input_tensors.resize(3); + op.input_tensors[0] = nullptr; + op.input_tensors[1] = &ksizes_tensor; + strides_tensor = test::AsTensor(strides); + op.input_tensors[2] = &strides_tensor; + }; + + // Most of the functionality is tested by conv-like shapes, + // so we check the very-specific maxpooling features here, + // namely depthwise kernel and striding. 
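+  // Note that MaxPoolV2 takes ksize and strides as runtime input tensors
+  // (inputs 1 and 2) rather than as node attrs, which is why the lambda
+  // above wires them up through op.input_tensors.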
+ + // all 1 strides, depth 2 filter + set_op({1, 1, 1, 1}, {1, 1, 1, 2}, "VALID", "NHWC"); + INFER_OK(op, "[1,2,2,2];[4];[4]", "[d0_0,2,2,1]"); + + // depth 3 stride, 1x1x1 filter, NCHW + set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW"); + INFER_OK(op, "[1,7,5,5];[4];[4]", "[d0_0,3,5,5]"); + + // 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C tests + set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "SAME", "NCHW_VECT_C"); + INFER_OK(op, "[2,3,5,7,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_OK(op, "[5,7,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_OK(op, "[?,?,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8];[4];[4]"); +} + +TEST(CommonShapeFnsTest, Pool3DShapeTest) { + ShapeInferenceTestOp op("MaxPool3D"); + auto set_op = [&op](const std::vector& strides, + const std::vector& ksizes, const string& padding) { + TF_CHECK_OK(NodeDefBuilder("test", "MaxPool3D") + .Input("input", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("ksize", ksizes) + .Attr("padding", padding) + .Finalize(&op.node_def)); + }; + + // Most of the functionality is tested by conv-like shapes, + // so we check that we handle the extra dimension properly. + + // 2x3x4 stride, 1x1x1 filter. + set_op({1, 2, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID"); + INFER_OK(op, "[1,24,24,24,1]", "[d0_0,12,8,6,d0_4]"); + + // Test partially known dimensions + set_op({1, 1, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID"); + INFER_OK(op, "[1,?,24,24,1]", "[d0_0,?,8,6,d0_4]"); +} + +TEST(CommonShapeFnsTest, UnknownShapeTest) { + { + // Single output + ShapeInferenceTestOp op("QueueDequeue"); + TF_CHECK_OK(NodeDefBuilder("test", "QueueDequeue") + .Input("handle", 0, DT_STRING_REF) + .Attr("component_types", {DT_FLOAT}) + .Finalize(&op.node_def)); + INFER_OK(op, "[1]", "?"); + } + + { + // Multiple outputs + ShapeInferenceTestOp op("QueueDequeue"); + TF_CHECK_OK(NodeDefBuilder("test", "QueueDequeue") + .Input("handle", 0, DT_STRING_REF) + .Attr("component_types", {DT_FLOAT, DT_FLOAT, DT_STRING}) + .Finalize(&op.node_def)); + INFER_OK(op, "[1]", "?;?;?"); + } +} + +TEST(CommonShapeFnsTest, Reduce_ShapeFn) { + ShapeInferenceTestOp op("Sum"); + op.input_tensors.resize(2); + + TF_ASSERT_OK(NodeDefBuilder("test", "Sum") + .Input("input", 0, DT_FLOAT) + .Input("reduction_indices", 1, DT_INT32) + .Attr("keep_dims", false) + .Finalize(&op.node_def)); + + // Reduction indices not available, so output is unknown. 
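+  // (op.input_tensors[1] is still nullptr here, so the shape function cannot
+  // tell which dimensions get reduced away and must report an unknown shape.)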
+ INFER_OK(op, "[2,4,5];[2]", "?"); + INFER_OK(op, "?;[2]", "?"); + + Tensor indices = test::AsTensor({1, 2}); + op.input_tensors[1] = &indices; + + // Reduction indices available + INFER_OK(op, "[2,4,5];[2]", "[d0_0]"); + + // Wrapped indices + indices = test::AsTensor({-1, -2}); + op.input_tensors[1] = &indices; + INFER_OK(op, "[2,4,5];[2]", "[d0_0]"); + + // Scalar + indices = test::AsScalar(0); + op.input_tensors[1] = &indices; + INFER_OK(op, "[2,4,5];[]", "[d0_1,d0_2]"); + + indices = test::AsScalar(-4); + op.input_tensors[1] = &indices; + INFER_ERROR("Invalid reduction dimension", op, "[2,4,5];[]"); + + // Empty reduction indices + indices = test::AsTensor({}); + op.input_tensors[1] = &indices; + INFER_OK(op, "[2,4,5];[0]", "[d0_0,d0_1,d0_2]"); + + // Keep dims = true + TF_ASSERT_OK(NodeDefBuilder("test", "Sum") + .Input("input", 0, DT_FLOAT) + .Input("reduction_indices", 1, DT_INT32) + .Attr("keep_dims", true) + .Finalize(&op.node_def)); + indices = test::AsTensor({-1, -2}); + op.input_tensors[1] = &indices; + INFER_OK(op, "[2,4,5];[2]", "[d0_0, 1, 1]"); + + // input rank is known, but reduction indices are not (with keep_dim=true). + // The output rank matches input rank (because of keep_dims=true). + op.input_tensors[1] = nullptr; + INFER_OK(op, "[?,?,?];?", "[?,?,?]"); + INFER_OK(op, "[?,?,?];[2]", "[?,?,?]"); + + // Reduction indices with too many dimensions. + INFER_ERROR("must be at most rank 1 but is rank 2", op, "[?,?,?];[?,?]"); + // With older graph-def version, this is allowed. + op.graph_def_version = 20; + INFER_OK(op, "[?,?,?];[?,?]", "[?,?,?]"); + // And when the tensor is specified, it's still allowed. + op.input_tensors[1] = &indices; + indices = test::AsTensor({-1, -2}, TensorShape({2, 1})); + INFER_OK(op, "[2,4,5];[2,1]", "[d0_0, 1, 1]"); + indices = test::AsTensor({-1, -2}, TensorShape({1, 2})); + INFER_OK(op, "[2,4,5];[1,2]", "[d0_0, 1, 1]"); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) { + NodeDef def; + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {Unknown(), Unknown(), Unknown()}, {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) { + NodeDef def; + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({-1, -1}), S({-1}), S({-1})}, {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) { + NodeDef def; + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({-1}), S({-1}), S({-1})}, {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + EXPECT_EQ(error::INVALID_ARGUMENT, + ValidateSparseTensor(&c, indices, values, shape).code()); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) { + NodeDef def; + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({5, 3}), S({4}), S({3})}, {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + 
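+  // indices declares 5 entries but values holds only 4, so validation fails.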
EXPECT_EQ(error::INVALID_ARGUMENT, + ValidateSparseTensor(&c, indices, values, shape).code()); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) { + NodeDef def; + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({5, 3}), S({5}), S({4})}, {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + EXPECT_EQ(error::INVALID_ARGUMENT, + ValidateSparseTensor(&c, indices, values, shape).code()); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) { + NodeDef def; + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({-1, 3}), S({5}), S({3})}, {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) { + NodeDef def; + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({5, 3}), S({-1}), S({3})}, {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) { + NodeDef def; + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({5, -1}), S({5}), S({3})}, {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) { + NodeDef def; + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({5, 3}), S({5}), S({-1})}, {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor) { + NodeDef def; + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({5, 3}), S({5}), S({3})}, {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +} // namespace shape_inference +} // namespace tensorflow diff --git a/control_flow.h b/control_flow.h new file mode 100644 index 0000000000000000000000000000000000000000..4dad0b4fef2d13d6ba583ef55b08f14a12f72d11 --- /dev/null +++ b/control_flow.h @@ -0,0 +1,58 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_ +#define TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_ + +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +const uint64 kIllegalFrameId = ~0uLL; +const int64 kIllegalIterId = -1; + +// For the purpose of control flow, every tensor produced by TensorFlow is +// conceptually tagged by a 'FrameAndIter'. FrameAndIter consists of a +// 'frame_id' and an 'iter_id'. The tensor value it represents is produced +// in the frame with frame_id at the iteration of iter_id. +struct FrameAndIter { + uint64 frame_id = kIllegalFrameId; + int64 iter_id = kIllegalIterId; + + FrameAndIter() {} + + FrameAndIter(uint64 frame, int64 iter) { + frame_id = frame; + iter_id = iter; + } + + bool operator==(const FrameAndIter& other) const { + return (frame_id == other.frame_id && iter_id == other.iter_id); + } +}; + +struct FrameAndIterHash { + size_t operator()(const FrameAndIter& key) const { + // Make sure there are no padding bytes that we don't want + CHECK_EQ(sizeof(uint64) + sizeof(int64), sizeof(FrameAndIter)); + return Hash64(reinterpret_cast(&key), sizeof(FrameAndIter)); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_ diff --git a/cost_graph.proto b/cost_graph.proto new file mode 100644 index 0000000000000000000000000000000000000000..f4837fbfc55dc266bad01c9300e3a8b63c67f1e0 --- /dev/null +++ b/cost_graph.proto @@ -0,0 +1,72 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "CostGraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +message CostGraphDef { + message Node { + // The name of the node. Names are globally unique. + string name = 1; + + // The device of the node. Can be empty if the node is mapped to the + // default partition or partitioning hasn't been run yet. + string device = 2; + + // The id of the node. Node ids are only unique inside a partition. + int32 id = 3; + + // Inputs of this node. They must be executed before this node can be + // executed. An input is a particular output of another node, specified + // by the node id and the output index. + message InputInfo { + int32 preceding_node = 1; + int32 preceding_port = 2; + } + repeated InputInfo input_info = 4; + + // Outputs of this node. + message OutputInfo { + int64 size = 1; + // If >= 0, the output is an alias of an input. Note that an alias input + // may itself be an alias. The algorithm will therefore need to follow + // those pointers. + int64 alias_input_port = 2; + TensorShapeProto shape = 3; + DataType dtype = 4; + } + repeated OutputInfo output_info = 5; + + // Temporary memory used by this node. + int64 temporary_memory_size = 6; + + int64 host_temp_memory_size = 10; + int64 device_temp_memory_size = 11; + int64 host_persistent_memory_size = 12; + int64 device_persistent_memory_size = 16; + + // Estimate of the computational cost of this node, in microseconds. + int64 compute_cost = 9; + + // Analytical estimate of the computational cost of this node, in + // microseconds. + int64 compute_time = 14; + + // Analytical estimate of the memory access cost of this node, in + // microseconds. 
+ int64 memory_time = 15; + + // If true, the output is permanent: it can't be discarded, because this + // node is part of the "final output". Nodes may depend on final nodes. + bool is_final = 7; + + // Ids of the control inputs for this node. + repeated int32 control_input = 8; + } + repeated Node node = 1; +} diff --git a/device_attributes.proto b/device_attributes.proto new file mode 100644 index 0000000000000000000000000000000000000000..9983bcb6bec63602c2e624a183a111622f7f2ace --- /dev/null +++ b/device_attributes.proto @@ -0,0 +1,35 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "DeviceAttributesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +message DeviceLocality { + // Optional bus locality of device. Default value of 0 means + // no specific locality. Specific localities are indexed from 1. + int32 bus_id = 1; +}; + +message DeviceAttributes { + // Fully specified name of the device within a cluster. + string name = 1; + + // String representation of device_type. + string device_type = 2; + + // Memory capacity of device in bytes. + int64 memory_limit = 4; + + // Platform-specific data about device that may be useful + // for supporting efficient data transfers. + DeviceLocality locality = 5; + + // A device is assigned a global unique number each time it is + // initialized. "incarnation" should never be 0. + fixed64 incarnation = 6; + + // String representation of the physical device that this device maps to. + string physical_device_desc = 7; +} diff --git a/device_base.cc b/device_base.cc new file mode 100644 index 0000000000000000000000000000000000000000..e30ee84cc3f9378502d44ee8f1106f092b5d0605 --- /dev/null +++ b/device_base.cc @@ -0,0 +1,30 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/device_base.h" + +namespace tensorflow { + +DeviceBase::~DeviceBase() {} + +const DeviceAttributes& DeviceBase::attributes() const { + LOG(FATAL) << "Device does not implement attributes()"; +} + +const string& DeviceBase::name() const { + LOG(FATAL) << "Device does not implement name()"; +} + +} // namespace tensorflow diff --git a/device_base.h b/device_base.h new file mode 100644 index 0000000000000000000000000000000000000000..1838a8ad02d2bd5522ce3162fea53e3f5afc0309 --- /dev/null +++ b/device_base.h @@ -0,0 +1,243 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_ +#define TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" + +namespace Eigen { +struct ThreadPoolDevice; +#ifdef TENSORFLOW_USE_SYCL +struct SyclDevice; +#endif +} // end namespace Eigen + +namespace perftools { +namespace gputools { +class Stream; +} // namespace gputools +} // namespace perftools + +namespace tensorflow { + +class Device; +class DeviceAttributes; +class Env; +class EventMgr; +class OpKernelContext; +class ResourceMgr; +class TensorProto; + +namespace thread { +class ThreadPool; +} + +// A wrapper for an Eigen Gpu Device that includes per-op state. The +// class is defined even for non-GPU devices since the +// OpKernelContext::Params structure wants to fill it in. +class PerOpGpuDevice { + public: + virtual ~PerOpGpuDevice() {} + virtual const Eigen::GpuDevice& device() const = 0; +}; + +// A class that devices can subclass to pass around +// Device-specific context to OpKernels. +class DeviceContext : public core::RefCounted { + public: + ~DeviceContext() override {} + virtual perftools::gputools::Stream* stream() const { return nullptr; } + virtual void MaintainLifetimeOnStream( + const Tensor* t, perftools::gputools::Stream* stream) const {} + + // "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into + // "device_tensor" which is on a GPU device "device". "device_tensor" + // must be allocated to be of the same size as "cpu_tensor". + virtual void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, + StatusCallback done) const { + done(errors::Internal("Unrecognized device type in CPU-to-device Copy")); + } + + // "device_tensor" is a tensor on a non-CPU device. Copies + // device_tensor into "cpu_tensor". "cpu_tensor" must be allocated + // to be of the same size as "device_tensor". + virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor, + StringPiece tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) { + done(errors::Internal("Unrecognized device type in device-to-CPU Copy")); + } +}; + +// map[i] is the DeviceContext* for the node with id i, if i < map.size(). +typedef std::vector DeviceContextMap; + +class DeviceBase { + public: + explicit DeviceBase(Env* env) : env_(env) {} + virtual ~DeviceBase(); + + Env* env() const { return env_; } + + // Override this to return true for devices that require an Op's + // compute method to save references to the temporary tensors it + // allocates until the Op execution completes + virtual bool RequiresRecordingAccessedTensors() const { return false; } + + struct CpuWorkerThreads { + int num_threads = 0; + thread::ThreadPool* workers = nullptr; + }; + + // Does not take ownership. 
+ void set_tensorflow_cpu_worker_threads(CpuWorkerThreads* t) { + cpu_worker_threads_ = t; + } + + virtual const CpuWorkerThreads* tensorflow_cpu_worker_threads() const { + CHECK(cpu_worker_threads_ != nullptr); + return cpu_worker_threads_; + } + + // "stream" is used in special circumstances (such as the + // constructors of Ops) where there is no available OpKernelContext. + // "default_context" is used by OpKernelContext whenever a device does not + // supply a DeviceContext for an op in FillContextMap (e.g. when only + // using a single stream.) + // "event_mgr" is used to delay deallocation of temporary GPU buffers. + // TODO(pbar) Work out how to move this out of DeviceBase. + struct GpuDeviceInfo { + // Make sure all the defaults are NULL, so we can spot missing assignments. + perftools::gputools::Stream* stream = nullptr; + DeviceContext* default_context = nullptr; + EventMgr* event_mgr = nullptr; + int gpu_id = -1; + }; + + // Does not take ownership. + void set_tensorflow_gpu_device_info(GpuDeviceInfo* g) { + gpu_device_info_ = g; + } + + virtual const GpuDeviceInfo* tensorflow_gpu_device_info() const { + return gpu_device_info_; + } + + // The preferred thread pool for this device. If it is nullptr, the system + // automatically assigns a thread pool for execution. + virtual thread::ThreadPool* tensorflow_device_thread_pool() { + return device_thread_pool_; + } + + // Does not take ownership. + void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) { + eigen_cpu_device_ = d; + } + +#ifdef TENSORFLOW_USE_SYCL + void set_eigen_sycl_device(Eigen::SyclDevice* d) { eigen_sycl_device_ = d; } +#endif + + // Return the Allocator implementation to use based on the allocator + // attributes requested. See allocator.h for more details. + virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) { + LOG(FATAL) << "GetAllocator() is not implemented."; + return nullptr; + } + + // Return the Allocator implementation to use based on the allocator + // attributes requested and the supplied resource manager. By + // default this ignores the resource manager and calls the base + // implementation but devices can override if they want to consult + // the resource manager when choosing the allocator. + virtual Allocator* GetStepAllocator(AllocatorAttributes attr, + ResourceMgr* /*step_resource_manager*/) { + return GetAllocator(attr); + } + + virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() { + CHECK(eigen_cpu_device_ != nullptr); + return eigen_cpu_device_; + } + +#ifdef TENSORFLOW_USE_SYCL + virtual const Eigen::SyclDevice* eigen_sycl_device() const { + CHECK(eigen_sycl_device_ != nullptr); + return eigen_sycl_device_; + } +#endif + + // Caller owns the return value. The OpKernelContext calls this even + // for devices that do not implement an eigen_gpu_device. Overridden + // by GPU devices to return a derived type. + virtual PerOpGpuDevice* MakeGpuDevice() { return nullptr; } + + virtual DeviceBase* UnderlyingDevice() { return this; } + virtual const DeviceBase* UnderlyingDevice() const { return this; } + + // This is overridden by GPU devices to reinitialize the derived + // type returned by MakeGpuDevice. + virtual void ReinitializeGpuDevice(OpKernelContext* /*context*/, + PerOpGpuDevice* /*device*/, + DeviceContext* /*dc*/, + Allocator* /*allocator*/) {} + + // Unimplemented by default + virtual const DeviceAttributes& attributes() const; + virtual const string& name() const; + + // Materializes the given TensorProto into 'tensor' stored in Device + // memory. 
Most devices will want to override this. + // + // TODO(vrv): We should be able to put this function into + // OpKernelContext and handle the copies from device memory via send + // and receive nodes, instead of requiring that each device handle + // the copies here as well as in copy ops. + virtual Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) { + return errors::Internal("Device does not implement MakeTensorFromProto()"); + } + + protected: + // Does not take ownership. + void set_tensorflow_device_thread_pool(thread::ThreadPool* thread_pool) { + device_thread_pool_ = thread_pool; + } + + private: + Env* const env_; + CpuWorkerThreads* cpu_worker_threads_ = nullptr; + GpuDeviceInfo* gpu_device_info_ = nullptr; + thread::ThreadPool* device_thread_pool_ = nullptr; + Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr; +#ifdef TENSORFLOW_USE_SYCL + Eigen::SyclDevice* eigen_sycl_device_ = nullptr; +#endif +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_ diff --git a/fake_input.cc b/fake_input.cc new file mode 100644 index 0000000000000000000000000000000000000000..ad301a8aa4ba4be5b7031d00984d8e6febf1583e --- /dev/null +++ b/fake_input.cc @@ -0,0 +1,240 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/fake_input.h" + +#include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace { + +class FakeInputImpl { + public: + FakeInputImpl(const OpDef* op_def, int in_index, const NodeDef* node_def, + NodeDefBuilder* builder); + void SetN(int n); + void SetDataType(DataType dt); + void SetTypeList(DataTypeSlice dts); + Status AddInputToBuilder(); + + private: + static string FakeNodeName(int in_index); + Status GetN(int* n) const; + Status GetDataType(DataType* dt) const; + void NSources(int n, DataType dt) const; + void SourceList(DataTypeSlice dts) const; + + const OpDef* const op_def_; + const OpDef::ArgDef* const arg_; + const string in_node_; + const NodeDef* const node_def_; + NodeDefBuilder* const builder_; + + bool n_specified_; + int n_; + bool dt_specified_; + DataType dt_; + bool dts_specified_; + DataTypeSlice dts_; +}; + +FakeInputImpl::FakeInputImpl(const OpDef* op_def, int in_index, + const NodeDef* node_def, NodeDefBuilder* builder) + : op_def_(op_def), + arg_(&op_def->input_arg(in_index)), + in_node_(FakeNodeName(in_index)), + node_def_(node_def), + builder_(builder), + n_specified_(false), + dt_specified_(false), + dts_specified_(false) {} + +void FakeInputImpl::SetN(int n) { + n_specified_ = true; + n_ = n; +} + +void FakeInputImpl::SetDataType(DataType dt) { + dt_specified_ = true; + dt_ = dt; +} + +void FakeInputImpl::SetTypeList(DataTypeSlice dts) { + dts_specified_ = true; + dts_ = dts; +} + +Status FakeInputImpl::AddInputToBuilder() { + if (dts_specified_) { + SourceList(dts_); + + } else if (n_specified_ || !arg_->number_attr().empty()) { + int n; + TF_RETURN_IF_ERROR(GetN(&n)); + + DataType dt; + if (n > 0) { + TF_RETURN_IF_ERROR(GetDataType(&dt)); + } else { + dt = DT_FLOAT; + } + + NSources(n, dt); + } else { + if (!dt_specified_ && !arg_->type_list_attr().empty()) { + DataTypeVector dts; + Status status = GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts); + if (!status.ok()) { + return errors::InvalidArgument( + "Could not infer list of types for input '", arg_->name(), "': ", + status.error_message()); + } + SourceList(dts); + return Status::OK(); + } + + DataType dt; + TF_RETURN_IF_ERROR(GetDataType(&dt)); + builder_->Input(in_node_, 0, dt); + } + return Status::OK(); +} + +// static +string FakeInputImpl::FakeNodeName(int in_index) { + char c = 'a' + (in_index % 26); + return string(&c, 1); +} + +Status FakeInputImpl::GetN(int* n) const { + if (n_specified_) { + *n = n_; + } else { + Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n); + if (!status.ok()) { + return errors::InvalidArgument("Could not infer length of input '", + arg_->name(), "': ", + status.error_message()); + } + } + return Status::OK(); +} + +Status FakeInputImpl::GetDataType(DataType* dt) const { + if (dt_specified_) { + *dt = dt_; + return Status::OK(); // Ignore is_ref field of arg_. 
+ } else if (arg_->type() != DT_INVALID) { + *dt = arg_->type(); + } else if (!arg_->type_attr().empty()) { + Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt); + if (!status.ok()) { + // Check if the type attr has a default + const OpDef::AttrDef* attr = FindAttr(arg_->type_attr(), *op_def_); + if (attr && attr->has_default_value()) { + *dt = attr->default_value().type(); + } else { + return errors::InvalidArgument("Could not infer type for input '", + arg_->name(), "': ", + status.error_message()); + } + } + } else { + return errors::InvalidArgument("No type or type_attr field in arg '", + arg_->name(), "'"); + } + if (arg_->is_ref()) { + *dt = MakeRefType(*dt); + } + return Status::OK(); +} + +void FakeInputImpl::NSources(int n, DataType dt) const { + std::vector srcs; + srcs.reserve(n); + for (int i = 0; i < n; ++i) { + srcs.emplace_back(in_node_, i, dt); + } + builder_->Input(gtl::ArraySlice(srcs)); +} + +void FakeInputImpl::SourceList(DataTypeSlice dts) const { + std::vector srcs; + srcs.reserve(dts.size()); + for (size_t i = 0; i < dts.size(); ++i) { + srcs.emplace_back(in_node_, i, dts[i]); + } + builder_->Input(gtl::ArraySlice(srcs)); +} + +} // namespace + +// Public interface ------------------------------------------------------------ + +FakeInputFunctor FakeInput() { + return [](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + return impl.AddInputToBuilder(); + }; +} + +FakeInputFunctor FakeInput(DataType dt) { + return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + impl.SetDataType(dt); + return impl.AddInputToBuilder(); + }; +} + +FakeInputFunctor FakeInput(int n) { + return [n](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + impl.SetN(n); + return impl.AddInputToBuilder(); + }; +} + +FakeInputFunctor FakeInput(int n, DataType dt) { + return [n, dt](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + impl.SetN(n); + impl.SetDataType(dt); + return impl.AddInputToBuilder(); + }; +} + +FakeInputFunctor FakeInput(DataTypeSlice dts) { + // Make a copy to ensure the data will still be around when the lambda is + // called. + DataTypeVector dtv(dts.begin(), dts.end()); + return [dtv](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + impl.SetTypeList(dtv); + return impl.AddInputToBuilder(); + }; +} + +} // namespace tensorflow diff --git a/fake_input.h b/fake_input.h new file mode 100644 index 0000000000000000000000000000000000000000..103db47a9964637fcfb1253e8c60863a0ba7f4cc --- /dev/null +++ b/fake_input.h @@ -0,0 +1,40 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_ +#define TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_ + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +// These functions return values that may be passed to +// NodeDefBuilder::Input() to add an input for a test. Use them when +// you don't care about the node names/output indices providing the +// input. They also allow you to omit the input types and/or +// list length when they may be inferred. +FakeInputFunctor FakeInput(); // Infer everything +FakeInputFunctor FakeInput(DataType dt); +FakeInputFunctor FakeInput(int n); // List of length n +FakeInputFunctor FakeInput(int n, DataType dt); +FakeInputFunctor FakeInput(DataTypeSlice dts); +inline FakeInputFunctor FakeInput(std::initializer_list dts) { + return FakeInput(DataTypeSlice(dts)); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_ diff --git a/function.cc b/function.cc new file mode 100644 index 0000000000000000000000000000000000000000..d757e962e522f801243a35a362f0c6821814d948 --- /dev/null +++ b/function.cc @@ -0,0 +1,1322 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/function.h" + +#include +#include +#include +#include + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/function.pb_text.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/util/equal_graph_def.h" + +namespace tensorflow { + +// Extracts the actual type from "attr_values" based on its definition +// "arg_def". +// +// If "arg_def" is a N*T type, *is_type_list is set to false, and +// *dtypes is set to be a vector of size N and each element is T. +// +// If "arg_def" is a list(type), *is_type_list is set to true, and +// *dtypes is set to be a vector of types specified in attrs for +// arg_def. +// +// Otherwise (arg_def is a simple type T), *is_type_list is set to +// false, and *dtypes is set to a single element vector, whose only +// element is T. 
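+//
+// For example, for an arg_def declared as "x: N*T" instantiated with attrs
+// {N=3, T=DT_FLOAT}, *is_type_list is false and *dtypes becomes
+// {DT_FLOAT, DT_FLOAT, DT_FLOAT}.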
+Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, + bool* is_type_list, DataTypeVector* dtypes) { + dtypes->clear(); + if (!arg_def.type_list_attr().empty()) { + const AttrValue* v = attrs.Find(arg_def.type_list_attr()); + if (v == nullptr) { + return errors::NotFound("type attr not found: ", + arg_def.type_list_attr()); + } + *is_type_list = true; + for (int i = 0; i < v->list().type_size(); ++i) { + dtypes->push_back(v->list().type(i)); + } + return Status::OK(); + } + + *is_type_list = false; + int num = 1; + if (!arg_def.number_attr().empty()) { + const AttrValue* v = attrs.Find(arg_def.number_attr()); + if (v == nullptr) { + return errors::NotFound("type attr not found: ", arg_def.type_attr()); + } + num = v->i(); + } + + DataType dtype; + if (arg_def.type() != DT_INVALID) { + dtype = arg_def.type(); + } else if (arg_def.type_attr().empty()) { + dtype = DT_INVALID; + } else { + const AttrValue* v = attrs.Find(arg_def.type_attr()); + if (v == nullptr) { + return errors::NotFound("type attr not found: ", arg_def.type_attr()); + } + dtype = v->type(); + } + dtypes->resize(num, dtype); + return Status::OK(); +} + +namespace { + +template +void AddAttr(const string& name, const T& val, NodeDef* ndef) { + SetAttrValue(val, &((*ndef->mutable_attr())[name])); +} + +Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) { + // attr_values should specify all attrs defined in fdef. + for (const auto& a : sig.attr()) { + const AttrValue* v = attr_values.Find(a.name()); + if (!v) { + return errors::NotFound("Attr ", a.name(), " is not found from ", + SummarizeOpDef(sig)); + } + Status status = AttrValueHasType(*v, a.type()); + if (!status.ok()) { + errors::AppendToMessage(&status, "for attr '", a.name(), "'"); + return status; + } + } + +// TODO(josh11b): Enable this code once it works with function gradients. +// Right now the C++ function gradient code assumes it can pass +// all the attrs of the function to the gradient, and any attrs that +// the gradient doesn't care about will be ignored. +#if 0 + if (attr_values.size() != sig.attr_size()) { + for (const auto& a : attr_values) { + // TODO(josh11b): Possibly should ignore attrs that start with "_" here? + bool found = false; + for (const auto& s : sig.attr()) { + if (a.first == s.name()) { + found = true; + break; + } + } + if (!found) { + return errors::NotFound("Attr ", a.first, " is not found in ", + SummarizeOpDef(sig)); + } + } + } +#endif + + return Status::OK(); +} + +// A helper class for instantiating functions. This contains shared information +// like the resulting graph and node name index. +class FunctionInstantiationHelper { + public: + FunctionInstantiationHelper(GetFunctionSignature get_function, + InstantiationResult* result) + : get_function_(std ::move(get_function)), result_(*result) { + result_.nodes.clear(); + } + + // Builds index for nodes that can be used as node's input arguments. + Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, + AttrSlice attr_values) { + bool is_type_list; + DataTypeVector dtypes; + TF_RETURN_IF_ERROR( + ArgNumType(attr_values, arg_def, &is_type_list, &dtypes)); + CHECK_GE(dtypes.size(), size_t{1}); + int arg_index = result_.nodes.size(); + TF_RETURN_IF_ERROR( + AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes})); + // Creates dtypes.size() nodes in the graph. 
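+    // Each one is an "_Arg" op whose "index" attr records the argument's
+    // position in the flattened argument list.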
+ for (size_t i = 0; i < dtypes.size(); ++i) { + TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i), + {true, arg_index, 0, false, {dtypes[i]}})); + DCHECK_EQ(arg_index, result_.nodes.size()); + string name = arg_def.name(); + if (dtypes.size() > 1) { + strings::StrAppend(&name, "_", i); + } + NodeDef* gnode = AddNode(name); + gnode->set_op("_Arg"); + AddAttr("T", dtypes[i], gnode); + AddAttr("index", arg_index, gnode); + result_.arg_types.push_back(dtypes[i]); + ++arg_index; + } + return Status::OK(); + } + + Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs, + const int arg_index) { + const OpDef* node_sig = nullptr; + TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig)); + if (node_sig->output_arg_size() == 0) { + return AddItem(node.name(), {false, arg_index, 0, false, {}}); + } + const int num_retval = node_sig->output_arg_size(); + int start = 0; + bool is_type_list; + DataTypeVector dtypes; + for (int i = 0; i < num_retval; ++i) { + TF_RETURN_IF_ERROR( + ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes)); + // Note that we rely on the backwards-compatibility test enforcing + // that output_arg(*).name() doesn't change here. + const string base_name = + strings::StrCat(node.name(), ":", node_sig->output_arg(i).name()); + TF_RETURN_IF_ERROR( + AddItem(base_name, {false, arg_index, start, is_type_list, dtypes})); + for (int j = 0; j < static_cast(dtypes.size()); ++j) { + TF_RETURN_IF_ERROR( + AddItem(strings::StrCat(base_name, ":", j), + {false, arg_index, start + j, false, {dtypes[j]}})); + } + start += dtypes.size(); + } + return Status::OK(); + } + + Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) { + const OpDef* fnode_sig = nullptr; + TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig)); + NodeDef* gnode = AddNode(fnode.name()); + gnode->set_op(fnode.op()); + gnode->set_device(fnode.device()); + int gnode_idx = nodes_.size() - 1; + + // Input + const int num_args = fnode_sig->input_arg_size(); + bool is_type_list; // ignored + DataTypeVector dtypes; + int fnode_arg_index = 0; + for (int i = 0; i < num_args; ++i) { + TF_RETURN_IF_ERROR( + ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes)); + // Consume inputs (indexed by fnode_arg_index) until we have + // matched each element of dtypes (indexed by j). + for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) { + if (fnode_arg_index >= fnode.input_size()) { + // Should never happen if we computed dtypes correctly. + return errors::InvalidArgument( + "Attempt to access beyond input size: ", fnode_arg_index, + " >= ", fnode.input_size()); + } + // Look up the next input. + const string& input_name = fnode.input(fnode_arg_index); + const auto* item = GetItemOrNull(input_name); + if (item == nullptr) { + return errors::InvalidArgument( + "input ", input_name, " is not found: ", SummarizeNodeDef(fnode)); + } + if (item->dtypes.size() > dtypes.size() - j) { + return errors::InvalidArgument("Input ", input_name, " too long for ", + fnode_sig->input_arg(i).name()); + } + // Match up all the elements of this input (indexed by k) with + // elements of dtypes (advancing j). 
+ for (int k = 0; k < item->dtypes.size(); ++k, ++j) { + if (item->dtypes[k] != dtypes[j]) { + return errors::InvalidArgument( + "input ", fnode_sig->input_arg(i).name(), "[", j, + "] expected type ", DataTypeString(dtypes[j]), + " != ", DataTypeString(item->dtypes[k]), ", the type of ", + input_name, "[", k, "]"); + } + if (item->is_func_arg) { + AddInput(gnode_idx, item->nid + k, 0); + } else { + AddInput(gnode_idx, item->nid, item->idx + k); + } + } + } + } + + // Control deps. + for (int i = fnode_arg_index; i < fnode.input_size(); ++i) { + const string& input = fnode.input(i); + if (input.empty() || input[0] != '^') { + return errors::InvalidArgument("Expected input[", i, "] == '", input, + "' to be a control input."); + } + int nid = -1; + const string node_name = input.substr(1); + const string node_colon = node_name + ":"; + const string node_colon_bound = node_name + ";"; + // index_ is a map sorted lexicographically, so the key we are looking for + // must lie in the range [node_name, node_colon_bound). + auto it = index_.lower_bound(node_name); + while (it != index_.end() && it->first <= node_colon_bound) { + if (it->first == node_name || + tensorflow::StringPiece(it->first).starts_with(node_colon)) { + nid = it->second.nid; + break; + } + ++it; + } + if (nid == -1) { + return errors::InvalidArgument("input[", i, "] == '", input, + "', is not found."); + } + AddDep(gnode_idx, nid); + } + + // Attrs. + for (const auto& p : attrs) { + (*gnode->mutable_attr())[p.first] = p.second; + } + + return Status::OK(); + } + + Status AddReturnNode( + const OpDef::ArgDef& ret_def, AttrSlice attrs, + const ::tensorflow::protobuf::Map& ret_map, + int* ret_index) { + auto ret_iter = ret_map.find(ret_def.name()); + if (ret_iter == ret_map.end()) { + return errors::InvalidArgument("Return ", ret_def.name(), " missing."); + } + bool is_type_list; + DataTypeVector dtypes; + TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes)); + CHECK_GE(dtypes.size(), size_t{1}); + const auto* item = GetItemOrNull(ret_iter->second); + if (item == nullptr) { + return errors::InvalidArgument("Return ", ret_def.name(), " -> ", + ret_iter->second, " is not found."); + } + if (dtypes != item->dtypes) { + return errors::InvalidArgument("Invalid ret types ", ret_def.name(), + " : ", DataTypeVectorString(dtypes), + " vs. ", + DataTypeVectorString(item->dtypes)); + } + for (size_t i = 0; i < dtypes.size(); ++i) { + string name = strings::StrCat(ret_def.name(), "_RetVal"); + if (dtypes.size() > 1) { + strings::StrAppend(&name, "_", i); + } + NodeDef* gnode = AddNode(name); + gnode->set_op("_Retval"); + AddInput(nodes_.size() - 1, item->nid, item->idx + i); + AddAttr("T", dtypes[i], gnode); + AddAttr("index", (*ret_index)++, gnode); + result_.ret_types.push_back(dtypes[i]); + } + return Status::OK(); + } + + // Adds the actual node inputs to the result graph by converting indexes to + // the node names. + void AddNodeInputs() { + for (int i = 0; i < result_.nodes.size(); i++) { + NodeInfo& node_info = nodes_[i]; + for (const auto& p : node_info.data_inputs) { + result_.nodes[i].add_input(Name(p.first, p.second)); + } + for (int index : node_info.control_inputs) { + result_.nodes[i].add_input(Dep(index)); + } + } + } + + private: + // This is used to build a small index for all names that can be used as a + // node's input arguments. + // + // If is_func_arg is true, the name is a function's argument. In + // this case, the produced graph def has node[nid:nid + dtype.size()]. 
+ // + // Otherwise, the name is a function body's node return value. In + // this case, the produced graph def has one node node[nid] and + // the node's output index [idx ... idx + num) corresponds to the + // named outputs. + // + // In all cases, "dtype" specifies the data type. + struct NameInfoItem { + bool is_func_arg; + int nid; + int idx; + bool is_type_list; + DataTypeVector dtypes; + }; + + // Adds an item into the input name index. + Status AddItem(const string& name, const NameInfoItem& item) { + if (!index_.insert({name, item}).second) { + return errors::InvalidArgument( + strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret", + " name: "), + name); + } + return Status::OK(); + } + + const NameInfoItem* GetItemOrNull(const string& name) const { + return gtl::FindOrNull(index_, name); + } + + string Dep(int node_index) const { + return strings::StrCat("^", Name(node_index)); + } + + string Name(int node_index) const { + CHECK_LT(node_index, nodes_.size()); + return nodes_[node_index].name; + } + + string Name(int node_index, int output_index) const { + if (output_index == 0) { + return Name(node_index); + } else { + return strings::StrCat(Name(node_index), ":", output_index); + } + } + + NodeDef* AddNode(const string& name) { + result_.nodes.emplace_back(); + NodeDef* gnode = &result_.nodes.back(); + gnode->set_name(name); + nodes_.push_back({name, {}, {}}); + CHECK_EQ(result_.nodes.size(), nodes_.size()); + return gnode; + } + + void AddInput(int node_index, int output_node, int output_index) { + CHECK_LT(node_index, nodes_.size()); + nodes_[node_index].data_inputs.push_back( + std::make_pair(output_node, output_index)); + } + + void AddDep(int node_index, int dep_index) { + CHECK_LT(node_index, nodes_.size()); + nodes_[node_index].control_inputs.push_back(dep_index); + } + + GetFunctionSignature get_function_; + InstantiationResult& result_; + // A small index for all names that can be used as a node's input arguments. + std::map index_; + // This contains information about a node in the new graph including the node + // names and input nodes' indexes. + struct NodeInfo { + string name; + // Data inputs where means arg k of node n. + std::vector> data_inputs; + // Control inputs (dependencies). + std::vector control_inputs; + }; + // nodes_[i] is the information about result_.nodes[i]. + std::vector nodes_; +}; + +// Various helpers Print(proto) to print relevant protos to ascii. +string Print(const OpDef::ArgDef& arg) { + string out; + strings::StrAppend(&out, arg.name(), ":"); + if (arg.is_ref()) strings::StrAppend(&out, "Ref("); + if (!arg.number_attr().empty()) { + strings::StrAppend(&out, arg.number_attr(), "*"); + } + if (arg.type() != DT_INVALID) { + strings::StrAppend(&out, DataTypeString(arg.type())); + } else { + strings::StrAppend(&out, arg.type_attr()); + } + if (arg.is_ref()) strings::StrAppend(&out, ")"); + return out; +} + +// TODO(josh11b): Merge this with SummarizeAttrValue(). 
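+// Prints a type attr as e.g. "float", a list of types as "{float, int32}",
+// and a function attr as "name[attr=value, ...]"; anything else falls back
+// to SummarizeAttrValue().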
+string Print(const AttrValue& attr_value) { + if (attr_value.value_case() == AttrValue::kType) { + return DataTypeString(attr_value.type()); + } else if ((attr_value.value_case() == AttrValue::kList) && + (attr_value.list().type_size() > 0)) { + string ret = "{"; + for (int i = 0; i < attr_value.list().type_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i))); + } + strings::StrAppend(&ret, "}"); + return ret; + } else if (attr_value.value_case() == AttrValue::kFunc) { + if (attr_value.func().attr_size() == 0) { + return attr_value.func().name(); + } + std::vector entries; + for (auto p : attr_value.func().attr()) { + entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + } + std::sort(entries.begin(), entries.end()); + return strings::StrCat(attr_value.func().name(), "[", + str_util::Join(entries, ", "), "]"); + } + return SummarizeAttrValue(attr_value); +} + +// TODO(josh11b): Merge this with SummarizeNodeDef(). +string Print(const NodeDef& n) { + string out; + strings::StrAppend(&out, n.name(), " = ", n.op()); + if (n.attr_size() > 0) { + std::vector entries; + for (auto& a : n.attr()) { + entries.push_back(strings::StrCat(a.first, "=", Print(a.second))); + } + std::sort(entries.begin(), entries.end()); + strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]"); + } + strings::StrAppend(&out, "("); + std::vector dat; + std::vector dep; + for (StringPiece s : n.input()) { + if (s.Consume("^")) { + dep.push_back(s.ToString()); + } else { + dat.push_back(s); + } + } + strings::StrAppend(&out, str_util::Join(dat, ", "), ")"); + if (!dep.empty()) { + strings::StrAppend(&out, " @ ", str_util::Join(dep, ", ")); + } + return out; +} + +string Print(const FunctionDef& fdef) { + string out; + const OpDef& sig = fdef.signature(); + strings::StrAppend(&out, "\n", sig.name()); + if (sig.attr_size() > 0) { + strings::StrAppend(&out, "["); + for (int i = 0; i < sig.attr_size(); ++i) { + const auto& a = sig.attr(i); + if (i > 0) strings::StrAppend(&out, ", "); + if (a.type() == "type") { + strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values())); + } else { + strings::StrAppend(&out, a.name(), ":", a.type()); + } + } + strings::StrAppend(&out, "]"); + } + strings::StrAppend(&out, "("); + for (int i = 0; i < sig.input_arg_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, Print(sig.input_arg(i))); + } + strings::StrAppend(&out, ") -> ("); + for (int i = 0; i < sig.output_arg_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, Print(sig.output_arg(i))); + } + strings::StrAppend(&out, ") {\n"); + for (const auto& n : fdef.node_def()) { + strings::StrAppend(&out, " ", Print(n), "\n"); + } + for (const auto& r : fdef.ret()) { + strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n"); + } + strings::StrAppend(&out, "}\n"); + return out; +} + +string Print(gtl::ArraySlice nodes) { + std::vector arg; + std::vector ret; + std::vector body; + for (const NodeDef* n : nodes) { + if (n->op() == "_Arg") { + arg.push_back(n); + } else if (n->op() == "_Retval") { + ret.push_back(n); + } else { + body.push_back(n); + } + } + auto comp = [](const NodeDef* x, const NodeDef* y) { + int xi; + TF_CHECK_OK(GetNodeAttr(*x, "index", &xi)); + int yi; + TF_CHECK_OK(GetNodeAttr(*y, "index", &yi)); + return xi < yi; + }; + std::sort(arg.begin(), arg.end(), comp); + std::sort(ret.begin(), ret.end(), comp); + string out; + 
strings::StrAppend(&out, "\n("); + auto get_type = [](const NodeDef& n) { + DataType dt; + if (!GetNodeAttr(n, "T", &dt).ok()) { + dt = DT_INVALID; + } + return DataTypeString(dt); + }; + for (size_t i = 0; i < arg.size(); ++i) { + const NodeDef* n = arg[i]; + if (i > 0) strings::StrAppend(&out, ", "); + CHECK_GE(n->attr_size(), 2); + strings::StrAppend(&out, n->name(), ":", get_type(*n)); + } + strings::StrAppend(&out, ") -> ("); + for (size_t i = 0; i < ret.size(); ++i) { + const NodeDef* n = ret[i]; + if (i > 0) strings::StrAppend(&out, ", "); + CHECK_LE(2, n->attr_size()); + CHECK_EQ(1, n->input_size()); + strings::StrAppend(&out, n->input(0), ":", get_type(*n)); + } + strings::StrAppend(&out, ") {\n"); + for (size_t i = 0; i < body.size(); ++i) { + strings::StrAppend(&out, " ", Print(*body[i]), "\n"); + } + strings::StrAppend(&out, "}\n"); + return out; +} + +Status AddDefaultAttrs(const string& op, + const GetFunctionSignature& get_function, + AttrValueMap* attrs) { + const OpDef* op_def = nullptr; + TF_RETURN_IF_ERROR(get_function(op, &op_def)); + AttrSlice attr_slice(attrs); + for (const auto& attr_def : op_def->attr()) { + if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) { + if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) { + return errors::Internal("Somehow duplicated: ", attr_def.name()); + } + } + } + return Status::OK(); +} + +} // end namespace + +Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, + GetFunctionSignature get_function, + InstantiationResult* result) { + VLOG(3) << "Instantiation Function: " << Print(fdef); + + const OpDef& sig = fdef.signature(); + TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values)); + + FunctionInstantiationHelper helper(get_function, result); + Status s; + for (const OpDef::ArgDef& arg_def : sig.input_arg()) { + s = helper.BuildInputArgIndex(arg_def, attr_values); + if (!s.ok()) { + errors::AppendToMessage(&s, "In ", Print(arg_def)); + return s; + } + } + + auto substitute = [attr_values](StringPiece name, AttrValue* val) { + if (const AttrValue* v = attr_values.Find(name)) { + *val = *v; + return true; + } + return false; + }; + + // Makes a copy of all attrs in fdef and substitutes placeholders. + // After this step, every attr is bound to a concrete value. + std::vector node_attrs; + node_attrs.resize(fdef.node_def_size()); + for (int i = 0; i < fdef.node_def_size(); ++i) { + for (auto attr : fdef.node_def(i).attr()) { + if (!SubstitutePlaceholders(substitute, &attr.second)) { + return errors::InvalidArgument("Failed to bind all placeholders in ", + SummarizeAttrValue(attr.second)); + } + if (!node_attrs[i].insert(attr).second) { + return errors::Internal("Somehow duplicated: ", attr.first); + } + } + TF_RETURN_IF_ERROR( + AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i])); + } + + for (int i = 0; i < fdef.node_def_size(); ++i) { + s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]), + result->nodes.size() + i); + if (!s.ok()) { + errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i))); + return s; + } + } + // Emits one node for each fdef.node_def. + for (int i = 0; i < fdef.node_def_size(); ++i) { + s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i])); + if (!s.ok()) { + errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i))); + return s; + } + } + + // Emits nodes for the function's return values. 
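+  // Each return value is emitted as a "_Retval"-style node whose input is
+  // resolved through the output mapping in fdef.ret().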
+ int ret_index = 0; + for (const OpDef::ArgDef& ret_def : sig.output_arg()) { + s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), &ret_index); + if (!s.ok()) { + errors::AppendToMessage(&s, "In function output ", Print(ret_def)); + return s; + } + } + + // Adds the actual node inputs using the input indexes. + helper.AddNodeInputs(); + + return Status::OK(); +} + +string DebugString(const FunctionDef& func_def) { return Print(func_def); } + +string DebugString(const GraphDef& instantiated_func_def) { + std::vector ptrs; + for (const NodeDef& n : instantiated_func_def.node()) { + ptrs.push_back(&n); + } + return Print(ptrs); +} + +string DebugString(gtl::ArraySlice instantiated_func_nodes) { + std::vector ptrs; + for (const NodeDef& n : instantiated_func_nodes) { + ptrs.push_back(&n); + } + return Print(ptrs); +} + +string DebugStringWhole(const GraphDef& gdef) { + string ret; + for (const auto& fdef : gdef.library().function()) { + strings::StrAppend(&ret, Print(fdef)); + } + strings::StrAppend(&ret, "\n"); + for (const auto& ndef : gdef.node()) { + strings::StrAppend(&ret, Print(ndef), "\n"); + } + return ret; +} + +namespace { + +// Returns the name -> attr mapping of fdef's attrs that have a value set. In +// Python, it's possible to access unset attrs, which returns a default value +// and adds an unset attr to the map. +std::map GetSetAttrs(const FunctionDef& fdef) { + std::map set_attrs; + for (auto pair : fdef.attr()) { + if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) { + set_attrs[pair.first] = pair.second; + } + } + return set_attrs; +} + +} // end namespace + +bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) { + if (!OpDefEqual(f1.signature(), f2.signature())) return false; + + std::map f1_attrs = GetSetAttrs(f1); + std::map f2_attrs = GetSetAttrs(f2); + if (f1_attrs.size() != f2_attrs.size()) return false; + for (auto iter1 : f1_attrs) { + auto iter2 = f2_attrs.find(iter1.first); + if (iter2 == f2_attrs.end()) return false; + if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false; + } + + if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) { + return false; + } + + std::map ret1(f1.ret().begin(), f1.ret().end()); + std::map ret2(f2.ret().begin(), f2.ret().end()); + if (ret1 != ret2) return false; + + return true; +} + +uint64 FunctionDefHash(const FunctionDef& fdef) { + // signature + uint64 h = OpDefHash(fdef.signature()); + + // attrs + std::map attrs = GetSetAttrs(fdef); + for (const auto& p : attrs) { + h = Hash64(p.first.data(), p.first.size(), h); + h = Hash64Combine(AttrValueHash(p.second), h); + } + + // node defs + h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h); + + // output names + std::map ret(fdef.ret().begin(), fdef.ret().end()); + for (const auto& p : ret) { + h = Hash64(p.first.data(), p.first.size(), h); + h = Hash64(p.second.data(), p.second.size(), h); + } + + return h; +} + +string Canonicalize(const string& funcname, AttrSlice attrs) { + std::vector entries; + entries.reserve(attrs.size()); + for (auto p : attrs) { + entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + } + std::sort(entries.begin(), entries.end()); + return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]"); +} + +FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types, + DataTypeSlice ret_types) + : arg_types_(arg_types.begin(), arg_types.end()), + ret_types_(ret_types.begin(), ret_types.end()) { + args_.resize(arg_types_.size()); + rets_.resize(ret_types_.size()); +} + 
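+// FunctionCallFrame usage (illustrative): the caller provides inputs via
+// SetArgs(), the runtime reads them with GetArg() and produces results with
+// SetRetval(), and the caller finally collects them with GetRetvals() or
+// ConsumeRetvals().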
+FunctionCallFrame::~FunctionCallFrame() {} + +Status FunctionCallFrame::SetArgs(gtl::ArraySlice args) { + // Input type checks. + if (args.size() != arg_types_.size()) { + return errors::InvalidArgument("Expects ", arg_types_.size(), + " arguments, but ", args.size(), + " is provided"); + } + for (size_t i = 0; i < args.size(); ++i) { + if (arg_types_[i] != args[i].dtype()) { + return errors::InvalidArgument( + "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ", + DataTypeString(args[i].dtype()), " is provided"); + } + args_[i] = args[i]; + } + return Status::OK(); +} + +Status FunctionCallFrame::GetRetvals(std::vector* rets) const { + rets->clear(); + rets->reserve(rets_.size()); + for (size_t i = 0; i < rets_.size(); ++i) { + const auto& item = rets_[i]; + if (item.has_val) { + rets->push_back(item.val); + } else { + return errors::Internal("Retval[", i, "] does not have value"); + } + } + return Status::OK(); +} + +Status FunctionCallFrame::ConsumeRetvals(std::vector* rets) { + rets->clear(); + rets->reserve(rets_.size()); + for (size_t i = 0; i < rets_.size(); ++i) { + if (rets_[i].has_val) { + rets->emplace_back(std::move(rets_[i].val)); + } else { + return errors::Internal("Retval[", i, "] does not have value"); + } + } + return Status::OK(); +} + +Status FunctionCallFrame::GetArg(int index, Tensor* val) const { + if (index < 0 || static_cast(index) >= args_.size()) { + return errors::InvalidArgument("GetArg ", index, " is not within [0, ", + args_.size(), ")"); + } + *val = args_[index]; + return Status::OK(); +} + +Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { + if (index < 0 || static_cast(index) >= rets_.size()) { + return errors::InvalidArgument("SetRetval ", index, " is not within [0, ", + rets_.size(), ")"); + } + if (val.dtype() != ret_types_[index]) { + return errors::InvalidArgument( + "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]), + ", but ", DataTypeString(val.dtype()), " is provided."); + } + Retval* item = &rets_[index]; + if (!item->has_val) { + item->has_val = true; + item->val = val; + } else { + return errors::Internal("Retval[", index, "] has already been set."); + } + return Status::OK(); +} + +FunctionLibraryDefinition::FunctionDefAndOpRegistration:: + FunctionDefAndOpRegistration(const FunctionDef& fdef_in) + : fdef(fdef_in), + // Exact shape inference for functions is handled by ShapeRefiner. + // Here we pass a dummy shape inference function for legacy code paths. + op_registration_data(fdef.signature(), shape_inference::UnknownShape, + true /* is_function */) {} + +FunctionLibraryDefinition::FunctionLibraryDefinition( + const FunctionLibraryDefinition& other) + : default_registry_(other.default_registry_), func_grad_(other.func_grad_) { + for (const auto& it : other.function_defs_) { + TF_CHECK_OK(AddFunctionDef(it.second->fdef)); + } +} + +FunctionLibraryDefinition::FunctionLibraryDefinition( + const OpRegistryInterface* default_registry, + const FunctionDefLibrary& def_lib) + : default_registry_(default_registry), + function_defs_(def_lib.function_size()) { + for (const auto& fdef : def_lib.function()) { + // The latter function definition wins. 
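+    // operator[] creates the map entry on first use; reset() below replaces
+    // any FunctionDef previously registered under the same name.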
+ auto& ptr = function_defs_[fdef.signature().name()]; + ptr.reset(new FunctionDefAndOpRegistration(fdef)); + } + for (const auto& grad : def_lib.gradient()) { + func_grad_[grad.function_name()] = grad.gradient_func(); + } +} + +FunctionLibraryDefinition::~FunctionLibraryDefinition() {} + +const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const { + auto iter = function_defs_.find(name); + if (iter == function_defs_.end()) { + return nullptr; + } else { + return &iter->second->fdef; + } +} + +Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) { + bool added; + return AddFunctionDefHelper(fdef, &added); +} + +Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef, + bool* added) { + *added = false; + std::unique_ptr* entry = + &function_defs_[fdef.signature().name()]; + if (*entry != nullptr) { + if (!FunctionDefsEqual((*entry)->fdef, fdef)) { + return errors::InvalidArgument( + "Cannot add function '", fdef.signature().name(), + "' because a different function with the same name already " + "exists."); + } + // Ignore duplicate FunctionDefs + return Status::OK(); + } + const OpDef* op_def; + if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) { + return errors::InvalidArgument( + "Cannot add function '", fdef.signature().name(), + "' because an op with the same name already exists."); + } + entry->reset(new FunctionDefAndOpRegistration(fdef)); + *added = true; + return Status::OK(); +} + +Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) { + bool added; + return AddGradientDefHelper(grad, &added); +} + +Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad, + bool* added) { + *added = false; + string* entry = &func_grad_[grad.function_name()]; + if (!entry->empty()) { + if (*entry != grad.gradient_func()) { + return errors::InvalidArgument( + "Cannot assign gradient function '", grad.gradient_func(), "' to '", + grad.function_name(), "' because it already has gradient function ", + "'", *entry, "'"); + } + // Ignore duplicate GradientDefs + return Status::OK(); + } + *entry = grad.gradient_func(); + *added = true; + return Status::OK(); +} + +Status FunctionLibraryDefinition::AddLibrary( + const FunctionLibraryDefinition& other) { + // Remember the funcs and grads that we added successfully so that + // we can roll them back on error. + std::vector funcs; + std::vector funcs_with_grads; + Status s; + bool added; + for (auto iter : other.function_defs_) { + s = AddFunctionDefHelper(iter.second->fdef, &added); + if (!s.ok()) { + Remove(funcs, funcs_with_grads); + return s; + } + if (added) { + funcs.push_back(iter.second->fdef.signature().name()); + } + } + for (auto iter : other.func_grad_) { + GradientDef grad; + grad.set_function_name(iter.first); + grad.set_gradient_func(iter.second); + s = AddGradientDefHelper(grad, &added); + if (!s.ok()) { + Remove(funcs, funcs_with_grads); + return s; + } + if (added) { + funcs_with_grads.push_back(grad.function_name()); + } + } + return Status::OK(); +} + +Status FunctionLibraryDefinition::AddLibrary( + const FunctionDefLibrary& lib_def) { + // Remember the funcs and grads that we added successfully so that + // we can roll them back on error. 
+ std::vector funcs; + std::vector funcs_with_grads; + Status s; + bool added; + for (const FunctionDef& fdef : lib_def.function()) { + s = AddFunctionDefHelper(fdef, &added); + if (!s.ok()) { + Remove(funcs, funcs_with_grads); + return s; + } + if (added) { + funcs.push_back(fdef.signature().name()); + } + } + for (const GradientDef& grad : lib_def.gradient()) { + s = AddGradientDefHelper(grad, &added); + if (!s.ok()) { + Remove(funcs, funcs_with_grads); + return s; + } + if (added) { + funcs_with_grads.push_back(grad.function_name()); + } + } + return Status::OK(); +} + +void FunctionLibraryDefinition::RemoveFunction(const string& func) { + const auto& i = function_defs_.find(func); + DCHECK(i != function_defs_.end()); + function_defs_.erase(i); +} + +void FunctionLibraryDefinition::RemoveGradient(const string& func) { + const auto& i = func_grad_.find(func); + DCHECK(i != func_grad_.end()); + func_grad_.erase(i); +} + +void FunctionLibraryDefinition::Remove( + const std::vector& funcs, + const std::vector& funcs_with_grads) { + for (const string& f : funcs) { + RemoveFunction(f); + } + for (const string& f : funcs_with_grads) { + RemoveGradient(f); + } +} + +string FunctionLibraryDefinition::FindGradient(const string& func) const { + return gtl::FindWithDefault(func_grad_, func, ""); +} + +Status FunctionLibraryDefinition::LookUp( + const string& op, const OpRegistrationData** op_reg_data) const { + auto iter = function_defs_.find(op); + if (iter != function_defs_.end()) { + *op_reg_data = &iter->second->op_registration_data; + return Status::OK(); + } + return default_registry_->LookUp(op, op_reg_data); +} + +const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( + const NodeDef& ndef) const { + if (ndef.op() != kGradientOp) { + // If 'ndef' calls a function and the function's def has the attr, + // returns it. + return Find(ndef.op()); + } + + // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or + // Foo's attributes. + const NameAttrList* forward_func_attrs; + if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) { + return nullptr; + } + const string& func_name = forward_func_attrs->name(); + const string& grad_name = FindGradient(func_name); + // If 'func' has a user-defined gradient function, uses the grad + // function's attrs to see if noinline is specified. Otherwise, + // uses func's attrs. 
+ if (!grad_name.empty()) { + return Find(grad_name); + } + return Find(func_name); +} + +FunctionDefLibrary FunctionLibraryDefinition::ToProto() const { + FunctionDefLibrary lib; + for (const auto& f : function_defs_) { + *lib.add_function() = f.second->fdef; + } + for (const auto& g : func_grad_) { + GradientDef* gd = lib.add_gradient(); + gd->set_function_name(g.first); + gd->set_gradient_func(g.second); + } + return lib; +} + +template +Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, + const string& attr, T* value) const { + const FunctionDef* fdef = GetAttrImpl(ndef); + if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) { + return Status::OK(); + } + return errors::InvalidArgument("Attr ", attr, " is not defined."); +} + +template +Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr, + T* value) const { + return GetAttr(node.def(), attr, value); +} + +#define GET_ATTR(T) \ + template Status FunctionLibraryDefinition::GetAttr(const Node&, \ + const string&, T*) const; \ + template Status FunctionLibraryDefinition::GetAttr(const NodeDef&, \ + const string&, T*) const; +GET_ATTR(string) +GET_ATTR(bool) +#undef GET_ATTR + +void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) { + if (val.size() >= 2 && val[0] == '$') { + proto.set_placeholder(val.data() + 1, val.size() - 1); + } else { + SetAttrValue(val, &proto); + } +} + +FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef( + const string& name, + gtl::ArraySlice> attrs) { + AttrValueWrapper ret; + ret.proto.mutable_func()->set_name(name); + for (const auto& a : attrs) { + ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto}); + } + return ret; +} + +NodeDef FunctionDefHelper::Node::ToNodeDef() const { + NodeDef n; + n.set_op(this->op); + n.set_name(this->ret[0]); + for (const auto& a : this->attr) { + n.mutable_attr()->insert({a.first, a.second.proto}); + } + for (const string& a : this->arg) { + n.add_input(a); + } + for (const string& d : this->dep) { + n.add_input(strings::StrCat("^", d)); + } + return n; +} + +/* static */ +FunctionDef FunctionDefHelper::Create( + const string& function_name, gtl::ArraySlice in_def, + gtl::ArraySlice out_def, gtl::ArraySlice attr_def, + gtl::ArraySlice node_def, + gtl::ArraySlice> ret_def) { + FunctionDef fdef; + + // Signature + OpDefBuilder b(function_name); + for (const auto& i : in_def) b.Input(i); + for (const auto& o : out_def) b.Output(o); + for (const auto& a : attr_def) b.Attr(a); + + OpRegistrationData op_reg_data; + TF_CHECK_OK(b.Finalize(&op_reg_data)); + fdef.mutable_signature()->Swap(&op_reg_data.op_def); + + // Function body + for (const auto& n : node_def) { + *(fdef.add_node_def()) = n.ToNodeDef(); + } + + // Returns + for (const auto& r : ret_def) { + fdef.mutable_ret()->insert({r.first, r.second}); + } + return fdef; +} + +/* static */ +FunctionDef FunctionDefHelper::Define(const string& name, + gtl::ArraySlice arg_def, + gtl::ArraySlice ret_def, + gtl::ArraySlice attr_def, + gtl::ArraySlice node_def) { + FunctionDef fdef; + OpDefBuilder b(name); + for (const auto& a : arg_def) b.Input(a); + for (const auto& r : ret_def) b.Output(r); + for (const auto& a : attr_def) b.Attr(a); + + OpRegistrationData op_reg_data; + TF_CHECK_OK(b.Finalize(&op_reg_data)); + fdef.mutable_signature()->Swap(&op_reg_data.op_def); + + // Mapping from legacy output names to NodeDef outputs. 
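+  // For example, the legacy output name "y" of a node {{"y"}, "Mul", ...}
+  // ends up mapped to "y:z:0", i.e. element 0 of the Mul node's "z" output.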
+ std::unordered_map ret_index; + for (const auto& a : fdef.signature().input_arg()) { + ret_index[a.name()] = a.name(); + } + + // For looking up OpDefs + auto* op_def_registry = OpRegistry::Global(); + + // Function body + for (const auto& src : node_def) { + NodeDef* n = fdef.add_node_def(); + n->set_op(src.op); + n->set_name(src.ret[0]); + for (const auto& a : src.attr) { + n->mutable_attr()->insert({a.first, a.second.proto}); + } + for (const string& a : src.arg) { + const auto iter = ret_index.find(a); + CHECK(iter != ret_index.end()) << "Node input '" << a << "' in '" + << src.ret[0] << "' of " << name; + n->add_input(iter->second); + } + for (const string& d : src.dep) { + n->add_input(strings::StrCat("^", d)); + } + + // Add the outputs of this node to ret_index. + const OpDef* op_def = nullptr; + TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op(); + CHECK(op_def != nullptr) << n->op(); + NameRangeMap output_names; + TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names)); + for (const auto& o : output_names) { + CHECK_LE(o.second.second, src.ret.size()) + << "Missing ret for output '" << o.first << "' in '" << src.ret[0] + << "' of " << name; + for (int i = o.second.first; i < o.second.second; ++i) { + ret_index[src.ret[i]] = + strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first); + } + } + } + + // Returns + for (const auto& r : fdef.signature().output_arg()) { + const auto iter = ret_index.find(r.name()); + CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name; + fdef.mutable_ret()->insert({r.name(), iter->second}); + } + return fdef; +} + +FunctionDef FunctionDefHelper::Define(gtl::ArraySlice arg_def, + gtl::ArraySlice ret_def, + gtl::ArraySlice attr_def, + gtl::ArraySlice node_def) { + return Define("_", arg_def, ret_def, attr_def, node_def); +} + +namespace gradient { + +typedef std::unordered_map OpGradFactory; + +OpGradFactory* GetOpGradFactory() { + static OpGradFactory* factory = new OpGradFactory; + return factory; +} + +bool RegisterOp(const string& op, Creator func) { + CHECK(GetOpGradFactory()->insert({op, func}).second) + << "Duplicated gradient for " << op; + return true; +} + +Status GetOpGradientCreator(const string& op, Creator* creator) { + auto fac = GetOpGradFactory(); + auto iter = fac->find(op); + if (iter == fac->end()) { + return errors::NotFound("No gradient defined for op: ", op); + } + *creator = iter->second; + return Status::OK(); +} + +} // end namespace gradient + +} // end namespace tensorflow diff --git a/function.h b/function.h new file mode 100644 index 0000000000000000000000000000000000000000..1a579ab63125ff5abc2f76d06187482234a54b9c --- /dev/null +++ b/function.h @@ -0,0 +1,625 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_ +#define TENSORFLOW_FRAMEWORK_FUNCTION_H_ + +#include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/selective_registration.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +class CancellationManager; +class GraphDef; +class OpKernel; +class ResourceMgr; +class Rendezvous; +class ScopedStepContainer; +class StepStatsCollector; +class Node; + +// FunctionDefHelper::Create is a convenient helper to construct a +// FunctionDef proto. +// E.g., +// FunctionDef my_func = FunctionDefHelper::Create( +// "my_func_name", +// {"x:T", "y:T" /* one string per argument */}, +// {"z:T" /* one string per return value */}, +// {"T: {float, double}" /* one string per attribute */}, +// { +// {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}} +// /* one entry per function node */ +// }, +// /* Mapping between function returns and function node outputs. */ +// {{"z", "o:z"}}); +// +// For the old Function::Node approach, use FunctionDefHelper::Define() +// E.g., +// FunctionDef my_func = FunctionDefHelper::Define( +// "my_func_name", +// {"x:T", "y:T" /* one string per argument */}, +// {"z:T" /* one string per return value */}, +// {"T: {float, double}" /* one string per attribute */}, +// { +// {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}} +// /* one entry per function node */ +// }); +class FunctionDefHelper { + public: + // AttrValueWrapper has copy constructors for the type T so that + // it's easy to construct a simple AttrValue proto. + // + // If T is a string type (const char*, string, or StringPiece), and + // it starts with "$", we construct a AttrValue of "placeholder". + // + // E.g., + // std:: x = {"T", "$T"} + // is a named attr value placeholder. + struct AttrValueWrapper { + AttrValue proto; + + AttrValueWrapper() {} + + template + AttrValueWrapper(T val) { // NOLINT(runtime/explicit) + SetAttrValue(val, &proto); + } + + private: + void InitFromString(StringPiece val); + }; + + // Constructs an AttrValue.func given the "name" and "attrs". + static AttrValueWrapper FunctionRef( + const string& name, + gtl::ArraySlice> attrs); + static AttrValueWrapper FunctionRef(const string& name) { + return FunctionRef(name, {}); + } + + // Node is used to construct FunctionDef.Node using initialization + // lists. E.g., + // Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}; // z = x * y + struct Node { + // When constructing a NodeDef, the first entry in ret is used as + // the node name, the remaining values are ignored. + std::vector ret; + string op; + std::vector arg; + std::vector> attr; + std::vector dep; + + NodeDef ToNodeDef() const; + }; + + // The Create() function uses the new NodeDef field. `ret_def` + // holds a mapping from the function output names from `out_def` to + // the node outputs from `node_def`. 
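+  // For example, a `ret_def` entry {"z", "o:z:0"} binds the function output
+  // "z" to element 0 of output arg "z" of the node named "o" in `node_def`
+  // (see the output-reference format documented in function.proto).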
+ static FunctionDef Create(const string& function_name, + gtl::ArraySlice in_def, + gtl::ArraySlice out_def, + gtl::ArraySlice attr_def, + gtl::ArraySlice node_def, + gtl::ArraySlice> ret_def); + + // The two Define() functions use the old FunctionDef::Node field. + // TODO(josh11b): Get rid of these and transition to the one above. + static FunctionDef Define(const string& function_name, + gtl::ArraySlice arg_def, + gtl::ArraySlice ret_def, + gtl::ArraySlice attr_def, + gtl::ArraySlice node_def); + + // Defines an anonymous function. I.e., its name is not relevant. + static FunctionDef Define(gtl::ArraySlice arg_def, + gtl::ArraySlice ret_def, + gtl::ArraySlice attr_def, + gtl::ArraySlice node_def); + + // Helpers to construct a constant scalar. + template + static Node Const(const string& name, const T& val) { + Node n = {{name}, "Const"}; + const DataType dtype = DataTypeToEnum::value; + n.attr.push_back({"dtype", dtype}); + Tensor t(dtype, TensorShape({})); + t.scalar()() = val; + n.attr.push_back({"value", t}); + return n; + } + + template + static Node Const(const string& name, gtl::ArraySlice vals) { + Node n = {{name}, "Const"}; + const DataType dtype = DataTypeToEnum::value; + n.attr.push_back({"dtype", dtype}); + int64 num = vals.size(); + Tensor t(dtype, TensorShape({num})); + for (size_t i = 0; i < vals.size(); ++i) { + t.flat()(i) = vals[i]; + } + n.attr.push_back({"value", t}); + return n; + } +}; + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) { + InitFromString(val); +} + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper( + const string& val) { + InitFromString(val); +} + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) { + InitFromString(val); +} + +// Instantiate a function. +// +// "fdef" encodes a TF function with some attrs in fdef.signature.attr +// containing placeholders. InstantiateFunction binds these +// placeholders and produces an instantiated function encoded in +// "result.gdef". The value to substitute a placeholder is given by +// "attr_values", which is a map from a placeholder name to an attr +// value. +// +// InstantiateFunction calls "get_function" to find signatures of other +// functions and primitive ops. + +// GetFunctionSignature(func name, opdef) returns OK if the func name is found +// and opdef is filled with a pointer to the corresponding signature +// (a OpDef proto). Otherwise, returns an error. +typedef std::function + GetFunctionSignature; + +struct InstantiationResult { + DataTypeVector arg_types; + DataTypeVector ret_types; + std::vector nodes; +}; +Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, + GetFunctionSignature get_function, + InstantiationResult* result); + +// Returns a debug string for a function definition. +// +// The returned text is multiple-line. It is intended to be +// human-readable rather than being friendly to parsers. It is _NOT_ +// intended to be the canonical string representation of "func_def". +// Particularly, it may not include all information presented in +// "func_def" (e.g., comments, description of the function arguments, +// etc.) +string DebugString(const FunctionDef& func_def); +string DebugString(const GraphDef& instantiated_func_def); +string DebugString(gtl::ArraySlice instantiated_func_nodes); + +// Returns a debug string for a top level graph (the main program and +// its supporting functions defined in its library). 
+string DebugStringWhole(const GraphDef& gdef); + +// Returns true if f1 == f2. Compares all fields, including descriptions. Order +// of NodeDefs doesn't matter. +bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2); + +// Return a hash of `fdef` that is consistent with FunctionDefsEqual method. +// In other words, if two fdefs compare equal, their hash values will be the +// same. +uint64 FunctionDefHash(const FunctionDef& fdef); + +// Returns a canonicalized string for the instantiation of the +// function of the given "name" and attributes "attrs". +// +// The returned string is guaranteed to be stable within one address +// space. But it may be change as the implementation +// evolves. Therefore, it should not be persisted or compared across +// address spaces. +string Canonicalize(const string& funcname, AttrSlice attrs); + +class CallFrameInterface { + public: + virtual ~CallFrameInterface() {} + + virtual size_t num_args() const = 0; + virtual size_t num_retvals() const = 0; + + virtual Status GetArg(int index, Tensor* val) const = 0; + virtual Status SetRetval(int index, const Tensor& val) = 0; +}; + +// Represents a function call frame. I.e., the data structure used to +// pass arguments to a function and retrieve its results. +// +// Runtime must arrange accesses to one FunctionCallFrame s.t. +// 1. SetArgs() happens before any GetArg(); +// 2. GetRetvals happens after all SetRetval(); +class FunctionCallFrame : public CallFrameInterface { + public: + FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types); + ~FunctionCallFrame(); + + // Caller methods. + Status SetArgs(gtl::ArraySlice args); + Status GetRetvals(std::vector* rets) const; + Status ConsumeRetvals(std::vector* rets); + + size_t num_args() const override { return arg_types_.size(); } + size_t num_retvals() const override { return ret_types_.size(); } + + // Callee methods. + Status GetArg(int index, Tensor* val) const override; + Status SetRetval(int index, const Tensor& val) override; + + private: + DataTypeVector arg_types_; + DataTypeVector ret_types_; + gtl::InlinedVector args_; + struct Retval { + bool has_val = false; + Tensor val; + }; + gtl::InlinedVector rets_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame); +}; + +// Helper to maintain a map between function names in a given +// FunctionDefLibrary and function definitions. +class FunctionLibraryDefinition : public OpRegistryInterface { + public: + explicit FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def); + FunctionLibraryDefinition(const OpRegistryInterface* default_registry, + const FunctionDefLibrary& lib_def); + ~FunctionLibraryDefinition() override; + + FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) = + delete; + + // Returns nullptr if "func" is not defined in "lib_def". Otherwise, + // returns its definition proto. + const FunctionDef* Find(const string& func) const; + + // Adds function definition 'fdef' to this function library. + // Returns status 'ok' on success, or error otherwise. This is a no-op if + // 'fdef' already exists in this function library. + // If 'fdef' is successfully added to the library, it will be accessible + // from 'LookUp' and included in the proto returned by 'ToProto'. + // This operation is atomic. + Status AddFunctionDef(const FunctionDef& fdef); + + // Adds gradient definition 'grad' to this function library. + // This is a no-op if 'grad' already exists in this function library. 
+ // If 'grad' is successfully added, it will be accessible via 'FindGradient' + // and included in the proto returned by 'ToProto'. + // This operation is atomic. + Status AddGradientDef(const GradientDef& grad); + + // Adds the functions and gradients in 'other' to this function library. + // Duplicate functions and gradients are ignored. + // This operation is atomic. + Status AddLibrary(const FunctionLibraryDefinition& other); + + // Adds the functions and gradients in 'lib_def' to this function library. + // Duplicate functions and gradients are ignored. + // This operation is atomic. + Status AddLibrary(const FunctionDefLibrary& lib_def); + + // If the gradient function for 'func' is specified explicitly in + // the library, returns the gradient function name. Otherwise, + // returns an empty string. + string FindGradient(const string& func) const; + + // OpRegistryInterface method. Useful for constructing a Graph. + // + // If "op" is defined in the library, returns its signature. + // Otherwise, assume "op" is a primitive op and returns its op + // signature and shape inference function. + Status LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const override; + + static constexpr const char* const kGradientOp = "SymbolicGradient"; + static constexpr const char* const kFuncAttr = "f"; + + // Given a node def 'ndef', inspects attributes of the callee + // function to derive the attribute 'value' for 'attr'. Returns OK + // iff the attribute is given by the function's definition. + // TODO(irving): Remove; keep only the const Node& version. + template + Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const; + + // Given a node, inspects attributes of the callee function to derive the + // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the + // function's definition. + template + Status GetAttr(const Node& node, const string& attr, T* value) const; + + // Returns a proto representation of the state of this function library. + FunctionDefLibrary ToProto() const; + + const OpRegistryInterface* default_registry() const { + return default_registry_; + } + + private: + // Shape inference for functions is handled separately by ShapeRefiner. + + struct FunctionDefAndOpRegistration { + FunctionDefAndOpRegistration(const FunctionDef& fdef_in); + + FunctionDef fdef; + OpRegistrationData op_registration_data; + }; + + // Same as AddFunctionDef/AddGradientDef except these methods set + // `added` to true if the `fdef`/`grad` were actually added to this. + Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added); + Status AddGradientDefHelper(const GradientDef& grad, bool* added); + + const OpRegistryInterface* const default_registry_; + gtl::FlatMap> + function_defs_; + gtl::FlatMap func_grad_; + + // Helper function for GetAttr. Returns the FunctionDef* to get the + // attr from. + const FunctionDef* GetAttrImpl(const NodeDef& ndef) const; + + // Remove function `func` from the library. `func` must be in the library. + void RemoveFunction(const string& func); + + // Remove gradient of function `func` from the library. `func` must have + // a gradient. + void RemoveGradient(const string& func); + + // Remove all functions in `funcs` and all gradients of + // functions in `funcs_with_grads` from this library. + void Remove(const std::vector& funcs, + const std::vector& funcs_with_grads); +}; + +// Forward declare. Defined in common_runtime/function.h +struct FunctionBody; + +// Forward declare. 
Defined in common_runtime/device.h +class Device; + +class FunctionLibraryRuntime { + public: + virtual ~FunctionLibraryRuntime() {} + + // Instantiate a function with the given "attrs". + // + // Returns OK and fills in "handle" if the instantiation succeeds. + // Otherwise returns an error and "handle" is undefined. + typedef uint64 Handle; + virtual Status Instantiate(const string& function_name, AttrSlice attrs, + Handle* handle) = 0; + + // Releases state associated with the handle. + virtual Status ReleaseHandle(Handle handle) = 0; + + // Returns the function body for the instantiated function given its + // handle 'h'. Returns nullptr if "h" is not found. + // + // *this keeps the ownership of the returned object, which remains alive + // as long as *this. + virtual const FunctionBody* GetFunctionBody(Handle h) = 0; + + // Asynchronously invokes the instantiated function identified by + // "handle". + // + // If function execution succeeds, "done" is called with OK and + // "*rets" is filled with the function's return values. Otheriwse, + // "done" is called with an error status. + // + // Does not take ownership of "rets". + // In the cross-process scenario, runner isn't used for making the Async + // RPC calls. + struct Options { + // The id of the step that is calling this function. + int64 step_id = 0; + Rendezvous* rendezvous = nullptr; + CancellationManager* cancellation_manager = nullptr; + ScopedStepContainer* step_container = nullptr; + StepStatsCollector* stats_collector = nullptr; + + std::function)>* runner = nullptr; + + // Parameters for remote function execution. + bool remote_execution = false; + string source_device = ""; // Fully specified device name. + + // Allocator attributes specifying where the args are / rets should be put. + // These should either be {} or match the length of args / retvals. If {}, + // the default allocator attributes will be assumed for all args / retvals. + std::vector args_alloc_attrs; + std::vector rets_alloc_attrs; + + // If true, we create a new IntraProcessRendezvous, else use the existing + // one. + bool create_rendezvous = false; + }; + typedef std::function DoneCallback; + virtual void Run(const Options& opts, Handle handle, + gtl::ArraySlice args, std::vector* rets, + DoneCallback done) = 0; + virtual void Run(const Options& opts, Handle handle, + CallFrameInterface* call_frame, DoneCallback done) = 0; + + // Creates a "kernel" for the given node def "ndef". + // + // If succeeds, returns OK and the caller takes the ownership of the + // returned "*kernel". Otherwise, returns an error. + virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0; + + // Returns true iff 'function' is stateful. + virtual bool IsStateful(const string& function_name) = 0; + + // Returns the device on which the function executes. + virtual Device* device() = 0; + + // Returns the function library definition that backs this runtime. + virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition() + const = 0; + + // Returns the environment on which the function executes. + virtual Env* env() = 0; + + // Returns a debug string showing the definition of the function of + // 'handle'. + virtual string DebugString(Handle handle) = 0; + + // Returns the graph version number. 
+ virtual int graph_def_version() = 0; + + typedef uint64 LocalHandle; +}; + +const FunctionLibraryRuntime::Handle kInvalidHandle = -1; +const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1; +typedef std::function*)> + CustomKernelCreator; + +// Used to instantiate and run functions in a distributed system. +class DistributedFunctionLibraryRuntime { + public: + virtual ~DistributedFunctionLibraryRuntime() {} + + // The _target attr in attrs determines where the function is instantiated. + virtual Status Instantiate(const string& function_name, + const FunctionLibraryDefinition& lib_def, + AttrSlice attrs, + FunctionLibraryRuntime::LocalHandle* handle) = 0; + + // opts.runner isn't used for execution. + virtual void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, + gtl::ArraySlice args, std::vector* rets, + FunctionLibraryRuntime::DoneCallback done) = 0; +}; + +// Extracts the actual type from "attr_values" based on its definition +// "arg_def". +// +// If "arg_def" is a N*T type, *is_type_list is set to false, and +// *dtypes is set to be a vector of size N and each element is T. +// +// If "arg_def" is a list(type), *is_type_list is set to true, and +// *dtypes is set to be a vector of types specified in attrs for +// arg_def. +// +// Otherwise (arg_def is a simple type T), *is_type_list is set to +// false, and *dtypes is set to a single element vector, whose only +// element is T. +Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, + bool* is_type_list, DataTypeVector* dtypes); + +// To register a gradient function for a builtin op, one should use +// REGISTER_OP_GRADIENT(, ); +// +// Typically, the c++ grad factory is a plan function that can be +// converted into ::tensorflow::gradient::Creator, which is +// std::function. +// +// A ::tensorflow::gradient::Creator should populate in FunctionDef* with a +// definition of a brain function which compute the gradient for the +// when the is instantiated with the given attrs. +// +// E.g., +// +// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) { +// bool transpose_a; +// TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a)); +// bool transpose_b; +// TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b)); +// DataType dtype; +// TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype)); +// if (!transpose_a && !transpose_b) { +// *g = FunctionDefHelper::Define( +// "MatMulGrad", +// {"x:T ", "y:T", "dz:T"}, // Inputs to this function +// {"dx:T", "dy:T"}, // Outputs from this function +// {"T: {float, double}"}, // Attributes needed by this function +// { +// {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}}, +// {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}}, +// {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}}, +// {{"dy"}, "MatMul", {"x_", "dz"}, {{"T", "$T"}}}, +// }); +// } else { +// ... ... +// } +// return Status::OK(); +// } +// +// NOTE: $T is substituted with the type variable "T" when the +// gradient function MatMul is instantiated. +// +// TODO(zhifengc): Better documentation somewhere. + +// Macros to define a gradient function factory for a primitive +// operation. 
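+//
+// For example (illustrative), given the MatMulGrad factory sketched above:
+//
+//   REGISTER_OP_GRADIENT("MatMul", MatMulGrad);
+//
+// and to declare that an op intentionally has no gradient:
+//
+//   REGISTER_OP_NO_GRADIENT("Size");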
+#define REGISTER_OP_GRADIENT(name, fn) \ + REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn) + +#define REGISTER_OP_NO_GRADIENT(name) \ + REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr) + +#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \ + REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) + +#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \ + static bool unused_grad_##ctr = SHOULD_REGISTER_OP_GRADIENT && \ + ::tensorflow::gradient::RegisterOp(name, fn) + +namespace gradient { +// Register a gradient creator for the "op". +typedef std::function Creator; +bool RegisterOp(const string& op, Creator func); + +// Returns OK the gradient creator for the "op" is found (may be +// nullptr if REGISTER_OP_NO_GRADIENT is used. +Status GetOpGradientCreator(const string& op, Creator* creator); +}; + +// Declare explicit instantiations of GetAttr +#define GET_ATTR(T) \ + extern template Status FunctionLibraryDefinition::GetAttr( \ + const Node&, const string&, T*) const; \ + extern template Status FunctionLibraryDefinition::GetAttr( \ + const NodeDef&, const string&, T*) const; +GET_ATTR(string) +GET_ATTR(bool) +#undef GET_ATTR + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_FUNCTION_H_ diff --git a/function.proto b/function.proto new file mode 100644 index 0000000000000000000000000000000000000000..bd01e86da3a646b7e6331301a03e80e90d2ce6ee --- /dev/null +++ b/function.proto @@ -0,0 +1,101 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/attr_value.proto"; +import "tensorflow/core/framework/node_def.proto"; +import "tensorflow/core/framework/op_def.proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// +// TODO(zhifengc): +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21. + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). + // Unlike the NodeDefs in GraphDef, we need to be able to specify a + // list in some cases (instead of just single outputs). Also, we + // need to be able to deal with lists of unknown length (so the + // output index may not be known at function definition time). So + // we use the following format instead: + // * "fun_in" where "fun_in" is the name of a function input arg in + // the `signature` field above. This represents that input, whether + // it is a single tensor or a list. + // * "fun_in:0" gives the first element of a function input arg (a + // non-list input is considered a list of length 1 for these + // purposes). + // * "node:out" where "node" is the name of a node in `node_def` and + // "out" is the name one of its op's output arguments (the name + // comes from the OpDef of the node's op). 
This represents that + // node's output, whether it is a single tensor or a list. + // Note: We enforce that an op's output arguments are never + // renamed in the backwards-compatibility test. + // * "node:out:0" gives the first element of a node output arg (a + // non-list output is considered a list of length 1 for these + // purposes). + // + // NOT CURRENTLY SUPPORTED (but may be in the future): + // * "node:out:-1" gives last element in a node output list + // * "node:out:1:" gives a list with all but the first element in a + // node output list + // * "node:out::-1" gives a list with all but the last element in a + // node output list + + // The body of the function. Unlike the NodeDefs in a GraphDef, attrs + // may have values of type `placeholder` and the `input` field uses + // the "output" format above. + + // By convention, "op" in node_def is resolved by consulting with a + // user-defined library first. If not resolved, "func" is assumed to + // be a builtin op. + repeated NodeDef node_def = 3; + + // A mapping from the output arg names from `signature` to the + // outputs from `node_def` that should be returned by the function. + map ret = 4; +} + +// GradientDef defines the gradient function of a function defined in +// a function library. +// +// A gradient function g (specified by gradient_func) for a function f +// (specified by function_name) must follow the following: +// +// The function 'f' must be a numerical function which takes N inputs +// and produces M outputs. Its gradient function 'g', which is a +// function taking N + M inputs and produces N outputs. +// +// I.e. if we have +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// then, g is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the +// loss function). dL/dx_i is the partial derivative of L with respect +// to x_i. +message GradientDef { + string function_name = 1; // The function name. + string gradient_func = 2; // The gradient function's name. +} diff --git a/function_test.cc b/function_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..23685e9c536a67ca33fbabdab438e7192c8a47fc --- /dev/null +++ b/function_test.cc @@ -0,0 +1,1339 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/function.h" +#include +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +// A helper class to make AttrSlice from initializer lists +class Attrs { + public: + Attrs(const std::initializer_list< // NOLINT(runtime/explicit) + std::pair> + attrs) { + for (const auto& aval : attrs) { + map_.insert({aval.first, aval.second.proto}); + } + } + + operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) + + private: + AttrValueMap map_; +}; + +typedef FunctionDefHelper FDH; + +Status GetOpSig(const string& op, const OpDef** sig) { + return OpRegistry::Global()->LookUpOpDef(op, sig); +} + +REGISTER_OP("One") + .Output("y: T") + .Attr("T: {float, double, int32, int64}") + .Doc(R"doc( +Returns a tensor with a single element (1) of type T. + +y: A scalar in type T. + +)doc"); + +TEST(TFunc, SquarePlusOne) { + auto fdef = FDH::Create( + // Name + "SquarePlusOne", + // Inputs + {"x: T"}, + // Outputs + {"y: T"}, + // Attrs + {"T: {float, double, int32, int64}"}, + // Nodes + {// a = Square(x) + {{"a"}, "Square", {"x"}, {{"T", "$T"}}}, + // o = One() + // NOTE: We can also have a Cast(x) instead. 
+ {{"o"}, "One", {}, {{"T", "$T"}}}, + // y = Add(a, o) + {{"y"}, "Add", {"a:y", "o:y"}, {{"T", "$T"}}}}, + // Returns + {{"y", "y:z:0"}}); + + const char* e = R"P( +SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) { + a = Square[T=$T](x) + o = One[T=$T]() + y = Add[T=$T](a:y, o:y) + return y = y:z:0 +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + // Instantiate one with T=float + InstantiationResult result; + TF_ASSERT_OK( + InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result)); + const char* e2 = R"P( +(x:float) -> (y:float) { + a = Square[T=float](x) + o = One[T=float]() + y = Add[T=float](a, o) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.nodes), e2); +} + +TEST(TFunc, ControlDep) { + auto fdef = FDH::Create( + // Name + "ControlDep", + // Inputs + {"x: int32"}, + // Outputs + {"y: int32"}, + // Attrs + {}, + // Nodes + {// a = Identity(x) + {{"a"}, "Identity", {"x"}, {{"T", DT_INT32}}}, + // o = NoOp(^a) + {{"o"}, "NoOp", {"^a"}, {}}, + // y = Identity(a, ^o) + {{"y"}, "Identity", {"a:output:0", "^o"}, {{"T", DT_INT32}}}}, + // Returns + {{"y", "y:output:0"}}); + + const char* e = R"P( +ControlDep(x:int32) -> (y:int32) { + a = Identity[T=int32](x) + o = NoOp() @ a + y = Identity[T=int32](a:output:0) @ o + return y = y:output:0 +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + // Instantiate one with T=float + InstantiationResult result; + TF_ASSERT_OK( + InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result)); + const char* e2 = R"P( +(x:int32) -> (y:int32) { + a = Identity[T=int32](x) + o = NoOp() @ a + y = Identity[T=int32](a) @ o +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32})); + EXPECT_EQ(DebugString(result.nodes), e2); +} + +REGISTER_OP("HasDefaultType") + .Output("out: T") + .Attr("T: {float, double, int32, int64} = DT_FLOAT"); + +// This verifies that a function using an op before a type attr (with +// a default) is added, still works. This is important for backwards +// compatibility. +TEST(TFunc, MissingTypeAttr) { + auto fdef = FDH::Create( + // Name + "BackCompat", + // Args + {}, + // Return values + {"y: float"}, + // Attrs + {}, + // Nodes + {// y = HasDefaultType(x), T missing, defaults to float + {{"a"}, "HasDefaultType", {}, {}}}, + // Returns + {{"y", "a:out:0"}}); + + const char* e = R"P( +BackCompat() -> (y:float) { + a = HasDefaultType() + return y = a:out:0 +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + InstantiationResult result; + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); + // Should get T=float from Op's default. 
+ const char* e2 = R"P( +() -> (a:float) { + a = HasDefaultType[T=float]() +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector()); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.nodes), e2); +} + +TEST(TFunc, NTimesT) { + auto fdef = FDH::Create( + // Name + "NTimesT", + // Inputs + {"x: float", "y: float"}, + // Outputs + {"z: float"}, + // Attrs + {}, + // Nodes + {// a = AddN(x, y) + {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 2}}}}, + // Returns + {{"z", "a:sum:0"}}); + + const char* e = R"P( +NTimesT(x:float, y:float) -> (z:float) { + a = AddN[N=2, T=float](x, y) + return z = a:sum:0 +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + InstantiationResult result; + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); + const char* e2 = R"P( +(x:float, y:float) -> (a:float) { + a = AddN[N=2, T=float](x, y) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.nodes), e2); +} + +// NOTE: This is the simplest Map op. It takes a f:T->U. +REGISTER_OP("Map") + .Input("x: N * T") + .Output("y: N * U") + .Attr("T: type") + .Attr("U: type") + .Attr("N: int >= 1") + // .Attr("func: func_name_with_attr") + .Doc(R"doc( +Applies the 'func' on every input. I.e., + +y[i] = func<...>(x[i]) + +x: N tensors, each of type T; +y: N tensors, each of type U; + +)doc"); + +TEST(TFunc, AddSquared) { + auto fdef = FDH::Create( + // Name + "AddSquared", + // Args + {"x: N*T"}, + // Return values + {"y: T"}, + // Attrs + {"N:int", "T:{float, double, int32, int64}"}, + // Nodes + {// a = Map,T=$T,U=$T,N=$N>(x) + {{"a"}, + "Map", + {"x"}, + {{"func", FDH::FunctionRef("Square", {{"T", "$T"}})}, + {"T", "$T"}, + {"U", "$T"}, + {"N", "$N"}}}, + // y = AddN(a) + {{"y"}, "AddN", {"a:y"}, {{"N", "$N"}, {"T", "$T"}}}}, + {{"y", "y:sum"}}); + + const char* e = R"P( +AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) { + a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x) + y = AddN[N=$N, T=$T](a:y) + return y = y:sum +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + // Instantiate one with T=float + InstantiationResult result; + TF_ASSERT_OK(InstantiateFunction(fdef, Attrs({{"N", 3}, {"T", DT_FLOAT}}), + GetOpSig, &result)); + const char* e2 = R"P( +(x_0:float, x_1:float, x_2:float) -> (y:float) { + a = Map[N=3, T=float, U=float, func=Square[T=float]](x_0, x_1, x_2) + y = AddN[N=3, T=float](a, a:1, a:2) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.nodes), e2); +} + +TEST(TFunc, ControlDeps) { + auto fdef = FDH::Define( + // Name + "ControlDeps", + // Args + {"x: float"}, + // Return values + {}, + // Attrs + {}, + // Nodes + { + {{"a"}, "One", {}, {{"T", DT_FLOAT}}, {"x"}}, + {{"u"}, "NoOp", {}, {}, {"a"}}, + {{"b"}, "One", {}, {{"T", DT_FLOAT}}, {"u"}}, + {{"v"}, "NoOp", {}, {}, {"b"}}, + {{"c"}, "One", {}, {{"T", DT_FLOAT}}, {"a", "v"}}, + }); + const char* e = R"P( +ControlDeps(x:float) -> () { + a = One[T=float]() @ x + u = NoOp() @ a + b = One[T=float]() @ u + v = NoOp() @ b + c = One[T=float]() @ a, v +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + InstantiationResult result; + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); + const char* e2 = R"P( +(x:float) -> () { + a = One[T=float]() @ x + u = NoOp() @ a + b = One[T=float]() @ u + v = NoOp() @ b + c = One[T=float]() @ 
a, v +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({})); + EXPECT_EQ(DebugString(result.nodes), e2); +} + +TEST(TFunc, XTimesTwo) { + auto expect = R"P( +XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) { + two = Const[dtype=int64, value=Tensor]() + scale = Cast[DstT=$T, SrcT=int64](two:output:0) + y = Mul[T=$T](x, scale:y:0) + return y = y:z:0 +} +)P"; + EXPECT_EQ(expect, DebugString(test::function::XTimesTwo())); +} + +TEST(TFunc, WXPlusB) { + auto expect = R"P( +WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) { + mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x) + y = Add[T=$T](mm:product:0, b) + return y = y:z:0 +} +)P"; + EXPECT_EQ(expect, DebugString(test::function::WXPlusB())); +} + +TEST(TFunc, Body_TypeList) { + const Tensor kZero = test::AsScalar(0); + auto fdef = FDH::Create( + // Name + "Test", + // Args + {"i:float"}, + // Return values + {"o:float"}, + // Attrs + {}, + // Nodes + {{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}}, + {{"s"}, + "Split", + {"zero:output:0", "i"}, + {{"num_split", 4}, {"T", DT_FLOAT}}}, + {{"l"}, "Mul", {"s:output:0", "s:output:1"}, {{"T", DT_FLOAT}}}, + {{"r"}, "Mul", {"s:output:2", "s:output:3"}, {{"T", DT_FLOAT}}}, + {{"x"}, + "_ListToArray", + {"l:z", "r:z"}, + {{"N", 2}, + {"T", DT_FLOAT}, + {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}}, + {{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}}, + {{"o", "o:sum:0"}}); + + const char* e = R"P( +Test(i:float) -> (o:float) { + zero = Const[dtype=int32, value=Tensor]() + s = Split[T=float, num_split=4](zero:output:0, i) + l = Mul[T=float](s:output:0, s:output:1) + r = Mul[T=float](s:output:2, s:output:3) + x = _ListToArray[N=2, T=float, Tin={float, float}](l:z, r:z) + o = AddN[N=2, T=float](x:output) + return o = o:sum:0 +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + InstantiationResult result; + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); + const char* e2 = R"P( +(i:float) -> (o:float) { + zero = Const[dtype=int32, value=Tensor]() + s = Split[T=float, num_split=4](zero, i) + l = Mul[T=float](s, s:1) + r = Mul[T=float](s:2, s:3) + x = _ListToArray[N=2, T=float, Tin={float, float}](l, r) + o = AddN[N=2, T=float](x, x:1) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.nodes), e2); +} + +REGISTER_OP("Cond") + .Input("input: Tin") + .Output("output: out_types") + .Attr("Tin: list(type)") + .Attr("out_types: list(type)") + .Attr("cond: func") + .Attr("then_branch: func") + .Attr("else_branch: func") + .Doc(R"doc( +output = Cond(input) ? then_branch(input) : else_branch(input) + +cond: A function takes 'input' and returns a scalar. +then_branch: A function takes 'input' and returns 'output'. +else_branch: A function takes 'input' and returns 'output'. 
+)doc"); + +TEST(TFunc, Body_Array_List_Converter) { + auto fdef = FDH::Define( + // Name + "MySelect", + // Args + {"x:float"}, + // Return values + {"z:float"}, + // Attrs + {}, + // Nodes + { + {{"y"}, + "Cond", + {"x"}, + {{"Tin", DataTypeSlice{DT_FLOAT}}, + {"out_types", DataTypeSlice{DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond")}, + {"then_branch", FDH::FunctionRef("MyThen")}, + {"else_branch", FDH::FunctionRef("MyElse")}}}, + {{"z"}, + "Cond", + {"y", "y"}, + {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"out_types", DataTypeSlice{DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond2")}, + {"then_branch", FDH::FunctionRef("MyThen2")}, + {"else_branch", FDH::FunctionRef("MyElse2")}}}, + }); + + const char* e = R"P( +MySelect(x:float) -> (z:float) { + y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x) + z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y:output:0, y:output:0) + return z = z:output:0 +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + InstantiationResult result; + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); + const char* e2 = R"P( +(x:float) -> (z:float) { + y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x) + z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.nodes), e2); +} + +static void HasError(const Status& s, const string& substr) { + EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) + << ">>" << s << "<<, expected substring >>" << substr << "<<"; +} + +TEST(InstantiateErrors, Not_Sufficient_Attrs) { + auto fdef = + FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); + InstantiationResult result; + HasError( + InstantiateFunction(fdef, Attrs({{"U", DT_FLOAT}}), GetOpSig, &result), + "Attr T is not found from "); +} + +#if 0 // TODO(josh11b): Enable this test once having an extra attr is an error. 
+TEST(InstantiateErrors, Too_Many_Attrs) { + auto fdef = + FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, Attrs({{"T", DT_INT32}, {"U", DT_FLOAT}}), + GetOpSig, &result), + "Attr U is not found in "); +} +#endif + +TEST(InstantiateErrors, AttrValue_Value_Placeholder) { + auto fdef = + FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); + InstantiationResult result; + HasError( + InstantiateFunction(fdef, Attrs({{"T", "$bad"}}), GetOpSig, &result), + "AttrValue had value with unexpected type 'placeholder'\n\tfor attr 'T'"); +} + +TEST(InstantiateErrors, Unbounded_Attr) { + auto fdef = FDH::Define("test", {}, {}, {"T:{float, double, int32, int64}"}, + { + {{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}}, + }); + InstantiationResult result; + HasError( + InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result), + "Failed to bind all placeholders"); +} + +TEST(InstantiateErrors, DupArgs) { + auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "Duplicated arg name"); +} + +TEST(InstantiateErrors, Dup_Node_Names) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y"}, "One", {}, {{"T", DT_FLOAT}}}, + {{"y"}, "One", {}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "Duplicated ret name"); +} + +TEST(InstantiateErrors, Node_Arg_Notfound) { + auto fdef = FDH::Create("test", {"x:float"}, {}, {}, + { + {{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}}, + }, + {}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "input z is not found"); +} + +TEST(InstantiateErrors, Node_Arg_TypeMismatch) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "input x[0] expected type int32 != float, the type of x[0]"); +} + +TEST(InstantiateErrors, Node_Arg_ControlMissing) { + auto fdef = + FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "input[2] == '^z', is not found."); +} + +TEST(InstantiateErrors, FuncRet_Missing) { + auto fdef = FDH::Create("test", {}, {"y: float"}, {}, + { + {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, + }, + {}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "Return y missing"); +} + +TEST(InstantiateErrors, FuncRet_NotFound) { + auto fdef = FDH::Create("test", {}, {"y: float"}, {}, + { + {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, + }, + {{"y", "z"}}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "Return y -> z is not found"); +} + +TEST(InstantiateErrors, FuncRet_NameMismatch) { + auto fdef = FDH::Create("test", {}, {"y: float"}, {}, + { + {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, + }, + {{"z", "x:y:0"}}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "Return y missing"); +} + +// TODO(josh11b): Make this an error. 
+// TEST(InstantiateErrors, FuncRet_Extra) { +// auto fdef = FDH::Create("test", {}, {"y: float"}, {}, +// { +// {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, +// }, +// {{"y", "x:y:0"}, {"z", "x:y:0"}}); +// InstantiationResult result; +// HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), +// "ret is not found"); +// } + +TEST(InstantiateErrors, FuncRet_TypeMismatch) { + auto fdef = FDH::Define("test", {}, {"y: float"}, {}, + { + {{"y"}, "One", {}, {{"T", DT_DOUBLE}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "Invalid ret types y : float vs. double\n\tIn function output y"); +} + +TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) { + auto fdef = FDH::Create( + // Name + "MySelect", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attrs + {}, + // Nodes + { + {{"y"}, + "Cond", + {"x", "x"}, + {{"tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond2")}, + {"then_branch", FDH::FunctionRef("MyThen2")}, + {"else_branch", FDH::FunctionRef("MyElse2")}}}, + }, + {{"y", "y:output"}}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "type attr not found: out_types"); +} + +TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) { + auto fdef = FDH::Create( + // Name + "MySelect", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attrs + {}, + // Nodes + { + {{"y"}, + "Cond", + {"x", "x"}, + {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond2")}, + {"then_branch", FDH::FunctionRef("MyThen2")}, + {"else_branch", FDH::FunctionRef("MyElse2")}}}, + }, + {{"y", "y:output"}}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "Invalid ret types"); +} + +TEST(InstantiateErrors, TypeList_Missing_Arg) { + auto fdef = FDH::Create( + // Name + "MySelect", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attrs + {}, + // Nodes + { + {{"y"}, + "Cond", + {"x", "unknown"}, + {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"out_types", DataTypeSlice{DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond2")}, + {"then_branch", FDH::FunctionRef("MyThen2")}, + {"else_branch", FDH::FunctionRef("MyElse2")}}}, + }, + {{"y", "y:output"}}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "input unknown is not found"); +} + +TEST(InstantiateErrors, TooManyInputs) { + auto fdef = FDH::Create( + // Name + "TooManyInputs", + // Inputs + {"x: float", "y: float"}, + // Outputs + {"z: float"}, + // Attrs + {}, + // Nodes + {// a = AddN(x, y, x) + {{"a"}, "AddN", {"x", "y", "x"}, {{"T", DT_FLOAT}, {"N", 2}}}}, + // Returns + {{"z", "a:sum:0"}}); + + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "Expected input[2] == 'x' to be a control input."); +} + +TEST(InstantiateErrors, TooFewInputs) { + auto fdef = FDH::Create( + // Name + "TooFewInputs", + // Inputs + {"x: float", "y: float"}, + // Outputs + {"z: float"}, + // Attrs + {}, + // Nodes + {// a = AddN(x, y) + {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 3}}}}, + // Returns + {{"z", "a:sum:0"}}); + + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "Attempt to access beyond input size: 2 >= 2"); +} + +TEST(InstantiateErrors, TooManyInputsFromArray1) { + auto fdef = FDH::Create( + // 
Name + "TooManyInputsFromArray", + // Inputs + {"x: float", "y: float"}, + // Outputs + {"z: float"}, + // Attrs + {}, + // Nodes + {// a = _ListToArray(x,y) + {{"a"}, + "_ListToArray", + {"x", "y"}, + {{"N", 2}, + {"T", DT_FLOAT}, + {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}}, + // b = AddN(a, y) + {{"b"}, "AddN", {"a:output", "y"}, {{"T", DT_FLOAT}, {"N", 2}}}}, + // Returns + {{"z", "a:sum:0"}}); + + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "Expected input[1] == 'y' to be a control input."); +} + +TEST(InstantiateErrors, TooManyInputsFromArray2) { + auto fdef = FDH::Create( + // Name + "TooManyInputsFromArray", + // Inputs + {"x: float", "y: float"}, + // Outputs + {"z: float"}, + // Attrs + {}, + // Nodes + {// a = _ListToArray(x,y) + {{"a"}, + "_ListToArray", + {"x", "y"}, + {{"N", 2}, + {"T", DT_FLOAT}, + {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}}, + // b = AddN(x, a) + {{"b"}, "AddN", {"x", "a:output"}, {{"T", DT_FLOAT}, {"N", 2}}}}, + // Returns + {{"z", "a:sum:0"}}); + + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "Input a:output too long for inputs"); +} + +TEST(InstantiateErrors, TypeMismatch) { + auto fdef = FDH::Create( + // Name + "TypeMismatch", + // Inputs + {"x: float", "y: int32"}, + // Outputs + {"z: float"}, + // Attrs + {}, + // Nodes + {// a = AddN(x, y) + {{"a"}, "AddN", {"x", "y"}, {{"T", DT_FLOAT}, {"N", 3}}}}, + // Returns + {{"z", "a:sum:0"}}); + + InstantiationResult result; + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), + "input inputs[1] expected type float != int32, the type of y[0]"); +} + +TEST(FunctionCallFrame, Void_Void) { + FunctionCallFrame frame({}, {}); + TF_EXPECT_OK(frame.SetArgs({})); + auto a = test::AsTensor({100}); + HasError(frame.SetArgs({a}), "Invalid argument"); + Tensor v; + HasError(frame.GetArg(0, &v), "Invalid argument"); + HasError(frame.SetRetval(0, v), "Invalid argument"); + std::vector rets; + TF_EXPECT_OK(frame.GetRetvals(&rets)); + EXPECT_EQ(rets.size(), 0); +} + +TEST(FunctionCallFrame, Float_Float_Float) { + FunctionCallFrame frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT}); + HasError(frame.SetArgs({}), "Invalid argument: Expects 2 arguments"); + auto a = test::AsTensor({100}); + auto b = test::AsTensor({200}); + auto c = test::AsTensor({300}); + HasError(frame.SetArgs({a, c}), + "Invalid argument: Expects arg[1] to be float"); + TF_EXPECT_OK(frame.SetArgs({a, b})); + + Tensor v; + HasError(frame.GetArg(-1, &v), "Invalid argument"); + HasError(frame.GetArg(2, &v), "Invalid argument"); + TF_EXPECT_OK(frame.GetArg(0, &v)); + test::ExpectTensorEqual(a, v); + TF_EXPECT_OK(frame.GetArg(1, &v)); + test::ExpectTensorEqual(b, v); + + v = test::AsTensor({-100}); + HasError(frame.SetRetval(-1, v), "Invalid argument"); + HasError(frame.SetRetval(1, v), "Invalid argument"); + HasError(frame.SetRetval(0, test::AsTensor({-100})), + "Invalid argument: Expects ret[0] to be float"); + + std::vector rets; + HasError(frame.GetRetvals(&rets), "does not have value"); + TF_EXPECT_OK(frame.SetRetval(0, v)); + HasError(frame.SetRetval(0, v), "has already been set"); + + TF_EXPECT_OK(frame.GetRetvals(&rets)); + EXPECT_EQ(rets.size(), 1); + test::ExpectTensorEqual(rets[0], v); +} + +TEST(Canonicalize, Basic) { + EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT}, + {"transpose_a", false}, + {"transpose_b", false}})), + "MatMul[T=float,transpose_a=false,transpose_b=false]"); + EXPECT_EQ(Canonicalize("MatMul", 
Attrs({{"T", DT_FLOAT}, + {"transpose_b", false}, + {"transpose_a", false}})), + "MatMul[T=float,transpose_a=false,transpose_b=false]"); + EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_DOUBLE}, + {"transpose_b", true}, + {"transpose_a", false}})), + "MatMul[T=double,transpose_a=false,transpose_b=true]"); +} + +TEST(FunctionLibraryDefinitionTest, Find) { + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); + + EXPECT_EQ(lib_def.Find("XTimes16"), nullptr); + + auto expect = R"P( +XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) { + two = Const[dtype=int64, value=Tensor]() + scale = Cast[DstT=$T, SrcT=int64](two:output:0) + y = Mul[T=$T](x, scale:y:0) + return y = y:z:0 +} +)P"; + auto found = lib_def.Find("XTimesTwo"); + ASSERT_NE(found, nullptr); + EXPECT_EQ(expect, DebugString(*found)); +} + +TEST(FunctionLibraryDefinitionTest, LookUp) { + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); + + const OpDef* op_def; + EXPECT_TRUE(!lib_def.LookUpOpDef("XTimes16", &op_def).ok()); + + TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &op_def)); + ASSERT_NE(op_def, nullptr); + EXPECT_EQ(op_def->DebugString(), + test::function::XTimesTwo().signature().DebugString()); + + const OpRegistrationData* op_reg_data; + TF_EXPECT_OK(lib_def.LookUp("XTimesTwo", &op_reg_data)); + ASSERT_NE(op_reg_data, nullptr); + // Shape inference function is initialized to UnknownShape. + ASSERT_NE(op_reg_data->shape_inference_fn, nullptr); +} + +TEST(FunctionLibraryDefinitionTest, AddFunctionDef) { + // Add one function to the proto lib before constructing 'lib_def'. + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); + + // Add a new function def to the library. + TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB())); + + // Test lookup of first function. + const OpDef* first; + TF_EXPECT_OK(lib_def.LookUpOpDef("XTimesTwo", &first)); + ASSERT_NE(first, nullptr); + EXPECT_EQ(first->DebugString(), + test::function::XTimesTwo().signature().DebugString()); + + // Test lookup of second function. + const OpDef* second; + TF_EXPECT_OK(lib_def.LookUpOpDef("WXPlusB", &second)); + ASSERT_NE(second, nullptr); + EXPECT_EQ(second->DebugString(), + test::function::WXPlusB().signature().DebugString()); + + // Can't add function with same name as existing op + FunctionDef fdef = test::function::XTimesTwo(); + fdef.mutable_signature()->set_name("Add"); + Status s = lib_def.AddFunctionDef(fdef); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + "Cannot add function 'Add' because an op with the same name " + "already exists."); + + // Already-added functions don't produce error + TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::XTimesTwo())); + TF_EXPECT_OK(lib_def.AddFunctionDef(test::function::WXPlusB())); +} + +TEST(FunctionLibraryDefinitionTest, AddGradientDef) { + // AddGradientDef() doesn't check that functions referenced exist (yet?) 
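+  // It only records the function-name -> gradient-function-name mapping, so
+  // the gradient registered below does not have to be a mathematically valid
+  // gradient of the function.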
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary()); + + // Test adding a gradient (XTimesFour isn't a valid grad function for + // XTimesTwo but that's ok for now) + GradientDef grad; + grad.set_function_name(test::function::XTimesTwo().signature().name()); + grad.set_gradient_func(test::function::XTimesFour().signature().name()); + TF_EXPECT_OK(lib_def.AddGradientDef(grad)); + + // Already-added gradients don't produce error + TF_EXPECT_OK(lib_def.AddGradientDef(grad)); + + // Test that adding a duplicate gradient fails + grad.set_gradient_func(test::function::XTimes16().signature().name()); + Status s = lib_def.AddGradientDef(grad); + EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); + EXPECT_EQ(s.error_message(), + "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because " + "it already has gradient function 'XTimesFour'"); +} + +TEST(FunctionLibraryDefinitionTest, AddLibrary) { + // Create lib def with single function + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); + + // Add gradient + GradientDef grad; + grad.set_function_name(test::function::XTimesTwo().signature().name()); + grad.set_gradient_func(test::function::XTimesFour().signature().name()); + TF_EXPECT_OK(lib_def.AddGradientDef(grad)); + + // Error if you try to add conflicting function + proto.Clear(); + FunctionDef fdef = test::function::XTimesFour(); + fdef.mutable_signature()->set_name( + test::function::XTimesTwo().signature().name()); + *proto.add_function() = fdef; + FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto); + Status s = lib_def.AddLibrary(lib_def2); + EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); + EXPECT_EQ(s.error_message(), + "Cannot add function 'XTimesTwo' because a different function with " + "the same name already exists."); + + // Error if you try to add conflicting gradient + proto.Clear(); + grad.set_gradient_func(test::function::XTimes16().signature().name()); + *proto.add_gradient() = grad; + FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto); + s = lib_def.AddLibrary(lib_def3); + EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); + EXPECT_EQ(s.error_message(), + "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because " + "it already has gradient function 'XTimesFour'"); + + // No conflicting functions or gradients OK + proto.Clear(); + *proto.add_function() = test::function::XTimesFour(); + grad.set_function_name(test::function::XTimes16().signature().name()); + *proto.add_gradient() = grad; + FunctionLibraryDefinition lib_def4(OpRegistry::Global(), proto); + TF_EXPECT_OK(lib_def.AddLibrary(lib_def4)); + + // OK to add the same functions and gradients twice + TF_EXPECT_OK(lib_def.AddLibrary(lib_def)); +} + +GradientDef MakeGradDef(const string& f, const string& g) { + GradientDef grad; + grad.set_function_name(f); + grad.set_gradient_func(g); + return grad; +} + +TEST(FunctionLibraryDefinitionTest, AddLibrary_Atomic) { + // Create lib def containing two functions with equal names + FunctionDefLibrary proto; + const string x2_name = test::function::XTimesTwo().signature().name(); + const string x4_name = test::function::XTimesFour().signature().name(); + *proto.add_function() = test::function::XTimesTwo(); + FunctionDef fdef = test::function::XTimesFour(); + fdef.mutable_signature()->set_name(x2_name); + *proto.add_function() = fdef; + FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary()); 
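+  // AddLibrary() is expected to be atomic: if any function or gradient in
+  // 'proto' conflicts, nothing from 'proto' may be merged into 'lib_def'.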
+ + // Try adding the two functions to lib_def + Status s = lib_def.AddLibrary(proto); + EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); + EXPECT_EQ( + "Cannot add function 'XTimesTwo' because a different function with " + "the same name already exists.", + s.error_message()); + + // Verify that none of the functions are added + EXPECT_TRUE(lib_def.Find(x2_name) == nullptr); + + // Fix the name in proto but add two gradient names for it + proto.mutable_function(1)->mutable_signature()->set_name(x4_name); + *proto.add_gradient() = MakeGradDef(x2_name, x4_name); + *proto.add_gradient() = MakeGradDef(x2_name, "SecondGradName"); + + // Try adding the library and check that nothing was added + s = lib_def.AddLibrary(proto); + EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); + EXPECT_EQ(s.error_message(), + "Cannot assign gradient function 'SecondGradName' to 'XTimesTwo' " + "because it already has gradient function 'XTimesFour'"); + EXPECT_TRUE(lib_def.Find(x2_name) == nullptr); + EXPECT_EQ(0, lib_def.ToProto().function_size()); + EXPECT_EQ(0, lib_def.ToProto().gradient_size()); +} + +TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_FuncConflict) { + const string x2_name = test::function::XTimesTwo().signature().name(); + const string x4_name = test::function::XTimesFour().signature().name(); + const string wx_name = test::function::WXPlusB().signature().name(); + + // Create FunctionLibraryDefinition with + // (func = XTimesTwo, grad = XTimesFour) + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + *proto.add_gradient() = MakeGradDef(x2_name, x4_name); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); + EXPECT_EQ(1, lib_def.ToProto().function_size()); + EXPECT_EQ(1, lib_def.ToProto().gradient_size()); + + // Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo) + // and function (name = XTimesTwo, body = XTimeFour) + FunctionDefLibrary proto2; + *proto2.add_function() = test::function::WXPlusB(); + *proto2.add_gradient() = MakeGradDef(wx_name, x2_name); + *proto2.add_function() = test::function::XTimesFour(); + proto2.mutable_function(1)->mutable_signature()->set_name(x2_name); + FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2); + + // Verify that adding lib_def2 will fail because of function conflict + // and WXPlusB is not added. 
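+  // (On failure, lib_def must keep exactly its original single function and
+  // gradient.)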
+ Status s = lib_def.AddLibrary(lib_def2); + EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); + EXPECT_EQ( + "Cannot add function 'XTimesTwo' because a different function " + "with the same name already exists.", + s.error_message()); + EXPECT_TRUE(lib_def.Find(wx_name) == nullptr); + EXPECT_EQ(1, lib_def.ToProto().function_size()); + EXPECT_EQ(1, lib_def.ToProto().gradient_size()); +} + +TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_GradConflict) { + const string x2_name = test::function::XTimesTwo().signature().name(); + const string x4_name = test::function::XTimesFour().signature().name(); + const string wx_name = test::function::WXPlusB().signature().name(); + + // Create FunctionLibraryDefinition with + // (func = XTimesTwo, grad = XTimesFour) + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + *proto.add_gradient() = MakeGradDef(x2_name, x4_name); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); + EXPECT_EQ(1, lib_def.ToProto().function_size()); + EXPECT_EQ(1, lib_def.ToProto().gradient_size()); + + // Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo) + // and (func = XTimesTwo, grad = WXPlusB) + FunctionDefLibrary proto2; + *proto2.add_function() = test::function::WXPlusB(); + *proto2.add_gradient() = MakeGradDef(wx_name, x2_name); + *proto2.add_function() = test::function::XTimesTwo(); + *proto2.add_gradient() = MakeGradDef(x2_name, wx_name); + FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2); + + // Verify that adding lib_def2 will fail because of gradient conflict + // and WXPlusB is not added. + Status s = lib_def.AddLibrary(lib_def2); + EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); + EXPECT_EQ( + "Cannot assign gradient function 'WXPlusB' to 'XTimesTwo'" + " because it already has gradient function 'XTimesFour'", + s.error_message()); + EXPECT_TRUE(lib_def.Find(wx_name) == nullptr); + EXPECT_EQ(1, lib_def.ToProto().function_size()); + EXPECT_EQ(1, lib_def.ToProto().gradient_size()); +} + +TEST(FunctionLibraryDefinitionTest, ToProto) { + FunctionDefLibrary proto1; + *proto1.add_function() = test::function::XTimesTwo(); + *proto1.add_function() = test::function::WXPlusB(); + FunctionLibraryDefinition lib_def1(OpRegistry::Global(), proto1); + + // Call 'ToProto' and make sure both protos have the same function lib size. + FunctionDefLibrary proto2 = lib_def1.ToProto(); + EXPECT_EQ(proto1.function_size(), proto2.function_size()); + + // Initialize 'lib_def2' with proto returned by 'ToProto' call. + FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2); + + // Test that the first function exists in both libraries. + const OpDef *f1, *f2, *f3, *f4; + TF_EXPECT_OK(lib_def1.LookUpOpDef("XTimesTwo", &f1)); + TF_EXPECT_OK(lib_def2.LookUpOpDef("XTimesTwo", &f2)); + EXPECT_EQ(f1->DebugString(), f2->DebugString()); + + // Test that the second function exists in both libraries. + TF_EXPECT_OK(lib_def1.LookUpOpDef("WXPlusB", &f3)); + TF_EXPECT_OK(lib_def2.LookUpOpDef("WXPlusB", &f4)); + EXPECT_EQ(f3->DebugString(), f4->DebugString()); +} + +TEST(FunctionLibraryDefinitionTest, GetAttr_FuncNoAttr) { + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + FunctionLibraryDefinition lib(OpRegistry::Global(), proto); + + NodeDef ndef; + bool annotation; + + // Not a function. + ndef.set_op("Matmul"); + EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok()); + + // A function. No attr defined. 
+ ndef.set_op("XTimesTwo"); + EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok()); + + // ndef defines the attr. But we don't care. + AddNodeAttr("annotation", true, &ndef); + EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok()); +} + +template +void SetAttrValue(FunctionDef* fdef, const string& attr, const T& value) { + AttrValue attr_value; + SetAttrValue(value, &attr_value); + fdef->mutable_attr()->insert({attr, attr_value}); +} + +TEST(FunctionLibraryDefinitionTest, GetAttr_FuncWithAttr) { + FunctionDefLibrary proto; + auto fdef = proto.add_function(); + *fdef = test::function::XTimesTwo(); + SetAttrValue(fdef, "annotation", true); + SetAttrValue(fdef, "options", "some string data"); + FunctionLibraryDefinition lib(OpRegistry::Global(), proto); + + NodeDef ndef; + bool annotation; + + // A function. No attr defined in ndef. + ndef.set_op("XTimesTwo"); + TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation)); + EXPECT_EQ(annotation, true); + + string str; + TF_EXPECT_OK(lib.GetAttr(ndef, "options", &str)); + EXPECT_EQ(str, "some string data"); +} + +TEST(FunctionLibraryDefinitionTest, GetAttr_Gradient) { + FunctionDefLibrary proto; + auto fdef = proto.add_function(); + *fdef = test::function::XTimesTwo(); + SetAttrValue(fdef, "annotation", true); + *fdef = test::function::WXPlusB(); + SetAttrValue(fdef, "annotation", false); + auto func_grad = proto.add_gradient(); + func_grad->set_function_name("XTimesTwo"); + func_grad->set_gradient_func("WXPlusB"); + FunctionLibraryDefinition lib(OpRegistry::Global(), proto); + + NodeDef ndef; + ndef.set_op(FunctionLibraryDefinition::kGradientOp); + + bool annotation; + EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok()); + + NameAttrList nal; + nal.set_name("XTimesTwo"); + AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef); + TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation)); + EXPECT_EQ(annotation, false); // XTimesTwo's gradient is WXPlusB. + + nal.set_name("WXPlusB"); + ndef.clear_attr(); + AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef); + TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation)); + EXPECT_EQ(annotation, false); // WXPlusB has no custom gradient. 
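+
+  // Both lookups above resolve to WXPlusB's attrs: the first through the
+  // gradient registered for XTimesTwo, the second because WXPlusB has no
+  // registered gradient, so its own attrs are consulted.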
+} + +// TODO(skyewm): this could be more thorough +TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) { + // Equal functions + const FunctionDef fdef1 = test::function::XTimesTwo(); + FunctionDef fdef2 = test::function::XTimesTwo(); + uint64 hash1 = FunctionDefHash(fdef1); + EXPECT_TRUE(FunctionDefsEqual(fdef1, fdef2)); + EXPECT_EQ(hash1, FunctionDefHash(fdef2)); + + // Different functions + fdef2 = test::function::XTimesFour(); + EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); + EXPECT_NE(hash1, FunctionDefHash(fdef2)); + + // Different signatures + fdef2 = test::function::XTimesTwo(); + fdef2.mutable_signature()->mutable_input_arg(0)->set_name("foo"); + EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); + EXPECT_NE(hash1, FunctionDefHash(fdef2)); + + // Descriptions must be equal + fdef2 = test::function::XTimesTwo(); + fdef2.mutable_signature()->mutable_input_arg(0)->set_description("foo"); + EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); + EXPECT_NE(hash1, FunctionDefHash(fdef2)); + + // Different NodeDefs + fdef2 = test::function::XTimesTwo(); + NodeDef* ndef = fdef2.add_node_def(); + *ndef = fdef2.node_def(0); + ndef->set_name("new_name"); + EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); + EXPECT_NE(hash1, FunctionDefHash(fdef2)); + + // Different return values + fdef2 = test::function::XTimesTwo(); + (*fdef2.mutable_ret())["y"] = "y:z:1"; // originally is "y:z:0" + EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); + EXPECT_NE(hash1, FunctionDefHash(fdef2)); + + // Different attributes + fdef2 = test::function::XTimesTwo(); + SetAttrValue(&fdef2, "ExtraAttr", true); + EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); + EXPECT_NE(hash1, FunctionDefHash(fdef2)); + + // Multiple equivalent attributes; the two functions should be equal. + fdef2 = test::function::XTimesTwo(); + FunctionDef fdef3 = test::function::XTimesTwo(); + SetAttrValue(&fdef2, "Foo", true); + SetAttrValue(&fdef3, "Foo", true); + SetAttrValue(&fdef2, "Bar", 123); + SetAttrValue(&fdef3, "Bar", 123); + SetAttrValue(&fdef2, "Baz", "abc"); + SetAttrValue(&fdef3, "Baz", "abc"); + EXPECT_TRUE(FunctionDefsEqual(fdef2, fdef3)); + EXPECT_EQ(FunctionDefHash(fdef2), FunctionDefHash(fdef3)); +} + +} // end namespace +} // end namespace tensorflow diff --git a/function_testlib.cc b/function_testlib.cc new file mode 100644 index 0000000000000000000000000000000000000000..f8b456051b76241104febd29d55fe82a9146a239 --- /dev/null +++ b/function_testlib.cc @@ -0,0 +1,204 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/function_testlib.h" + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { +namespace test { +namespace function { + +typedef FunctionDefHelper FDH; + +GraphDef GDef(gtl::ArraySlice nodes, + gtl::ArraySlice funcs) { + GraphDef g; + VersionDef* versions = g.mutable_versions(); + versions->set_producer(TF_GRAPH_DEF_VERSION); + versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER); + for (const auto& n : nodes) { + *(g.add_node()) = n; + } + auto lib = g.mutable_library(); + for (const auto& f : funcs) { + *(lib->add_function()) = f; + } + return g; +} + +// Helper to construct a NodeDef. +NodeDef NDef(const string& name, const string& op, + gtl::ArraySlice inputs, + gtl::ArraySlice> attrs, + const string& device) { + NodeDef n; + n.set_name(name); + n.set_op(op); + for (const auto& in : inputs) n.add_input(in); + n.set_device(device); + for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto}); + return n; +} + +FunctionDef NonZero() { + return FDH::Define( + // Name + "NonZero", + // Args + {"x:T"}, + // Return values + {"y:T"}, + // Attr def + {"T:{float, double, int32, int64, string}"}, + // Nodes + { + {{"y"}, "Identity", {"x"}, {{"T", "$T"}}}, + }); +} + +FunctionDef XTimesTwo() { + const Tensor kTwo = test::AsScalar(2); + return FDH::Define( + // Name + "XTimesTwo", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}}, + {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, + {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}}, + }); +} + +FunctionDef XTimesTwoInt32() { + const Tensor kTwo = test::AsScalar(2); + return FDH::Define( + // Name + "XTimesTwoInt32", + // Args + {"x: int32"}, + // Return values + {"y: int32"}, {}, + // Nodes + { + {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}}, + {{"scale"}, + "Cast", + {"two"}, + {{"SrcT", DT_INT64}, {"DstT", DT_INT32}}}, + {{"y"}, "Mul", {"x", "scale"}, {{"T", DT_INT32}}}, + }); +} + +FunctionDef XTimesFour() { + return FDH::Create( + // Name + "XTimesFour", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}}, + {{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}}, + }, + {{"y", "y:y:0"}}); +} + +FunctionDef XTimes16() { + return FDH::Create( + // Name + "XTimes16", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}}, + {{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}}, + }, + {{"y", "y:y:0"}}); +} + +FunctionDef WXPlusB(){return FDH::Define( + // Name + "WXPlusB", + // Args + {"w: T", "x: T", "b: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double}"}, + // Nodes + { + {{"mm"}, + "MatMul", + {"w", "x"}, + { + {"T", "$T"}, {"transpose_a", false}, {"transpose_b", false}, +#ifdef INTEL_MKL + }}, +#else + {"_kernel", "eigen"}}}, +#endif + { + {"y"}, "Add", {"mm", "b"}, { + { "T", "$T" } + } + } + }); +} + +FunctionDef Swap() { + 
return FDH::Define( + // Name + "Swap", + // Args + {"i0: T", "i1: T"}, + // Return values + {"o0: T", "o1: T"}, + // Attr def + {"T: {float, double}"}, + // Nodes + {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}}, + {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}}); +} + +void FunctionTestSchedClosure(std::function fn) { + static thread::ThreadPool* w = + new thread::ThreadPool(Env::Default(), "Test", 8); + w->Schedule(std::move(fn)); +} + +} // end namespace function +} // end namespace test +} // end namespace tensorflow diff --git a/function_testlib.h b/function_testlib.h new file mode 100644 index 0000000000000000000000000000000000000000..fbf273fa015c9326e01f45d1c603d22ab239fe25 --- /dev/null +++ b/function_testlib.h @@ -0,0 +1,90 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_ +#define TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_ + +#include + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace test { +namespace function { + +// A helper class to make AttrSlice from initializer lists +class Attrs { + public: + Attrs(const std::initializer_list< // NOLINT(runtime/explicit) + std::pair>& attrs) { + for (const auto& aval : attrs) { + map_.insert({aval.first, aval.second.proto}); + } + } + + operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) + + private: + AttrValueMap map_; +}; + +// Helper to construct a NodeDef. +NodeDef NDef( + const string& name, const string& op, gtl::ArraySlice inputs, + gtl::ArraySlice> + attrs = {}, + const string& device = ""); + +// Helper to construct a GraphDef proto. +GraphDef GDef(gtl::ArraySlice nodes, + gtl::ArraySlice funcs = {}); + +// For testing convenience, we provide a few simple functions that can +// be easily executed and tested. + +// x:T -> x * 2. +FunctionDef XTimesTwo(); + +// x:T -> x * 2, where x is int32. +FunctionDef XTimesTwoInt32(); + +// x:T -> (x * 2) * 2. +FunctionDef XTimesFour(); + +// x:T -> ((x * 2) * 2) * 2. +FunctionDef XTimes16(); + +// w:T, x:T, b:T -> MatMul(w, x) + b +FunctionDef WXPlusB(); + +// x:T -> x:T, T is a type which we automatically converts to a bool. 
+FunctionDef NonZero(); + +// x:T, y:T -> y:T, x:T +FunctionDef Swap(); + +void FunctionTestSchedClosure(std::function fn); + +} // end namespace function +} // end namespace test +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_ diff --git a/graph.proto b/graph.proto new file mode 100644 index 0000000000000000000000000000000000000000..7d6e16d5c129a068775fabc474770af929d99620 --- /dev/null +++ b/graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/node_def.proto"; +import "tensorflow/core/framework/function.proto"; +import "tensorflow/core/framework/versions.proto"; + +// Represents the graph of operations +message GraphDef { + repeated NodeDef node = 1; + + // Compatibility versions of the graph. See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +}; diff --git a/graph_def_util.cc b/graph_def_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd018b7243897a5b45aa35d7fb94ca1ee1b12e75 --- /dev/null +++ b/graph_def_util.cc @@ -0,0 +1,218 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/graph_def_util.h" + +#include +#include +#include +#include + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/versions.pb_text.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +string SummarizeGraphDef(const GraphDef& graph_def) { + string ret; + strings::StrAppend(&ret, "versions = ", + ProtoShortDebugString(graph_def.versions()), ";\n"); + for (const NodeDef& node : graph_def.node()) { + strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n"); + } + return ret; +} + +Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) { + for (const NodeDef& node : graph_def.node()) { + TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node)); + } + return Status::OK(); +} + +Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, + const OpRegistryInterface& op_registry, + int node_offset) { + if (node_offset > graph_def->node_size()) { + return errors::InvalidArgument( + "Tried to add default attrs to GraphDef " + "starting at offset ", + node_offset, " with total nodes in graph: ", graph_def->node_size()); + } + + for (int i = node_offset; i < graph_def->node_size(); ++i) { + NodeDef* node_def = graph_def->mutable_node(i); + const OpDef* op_def; + TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(node_def->op(), &op_def)); + AddDefaultsToNodeDef(*op_def, node_def); + } + + return Status::OK(); +} + +static Status RemoveNewDefaultAttrsFromNodeDef( + NodeDef* node_def, const OpRegistryInterface& consumer_op_registry, + const OpRegistryInterface& producer_op_registry, + std::set>* op_attr_removed) { + const OpDef* producer_op_def; + const OpDef* consumer_op_def; + TF_RETURN_IF_ERROR( + producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def)); + TF_RETURN_IF_ERROR( + consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def)); + + std::vector to_remove; + for (const auto& attr : node_def->attr()) { + // If the attr is not in consumer_op_def and doesn't start with '_'... + if (!StringPiece(attr.first).starts_with("_") && + FindAttr(attr.first, *consumer_op_def) == nullptr) { + const OpDef::AttrDef* producer_attr_def = + FindAttr(attr.first, *producer_op_def); + if (producer_attr_def == nullptr) { + return errors::InvalidArgument( + "Attr '", attr.first, "' missing in producer's OpDef: ", + SummarizeOpDef(*producer_op_def), " but found in node: ", + SummarizeNodeDef(*node_def)); + } + // ...and it has the same value as the default in producer, + if (producer_attr_def->has_default_value() && + AreAttrValuesEqual(producer_attr_def->default_value(), attr.second)) { + // then we will remove it below. + to_remove.emplace_back(attr.first); + } + } + } + // We separate identifying which attrs should be removed from + // actually removing them to avoid invalidating the loop iterators + // above. 
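+  // Besides erasing each attr, record the (op, attr) pair for the caller when
+  // 'op_attr_removed' bookkeeping was requested.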
+ for (const string& attr_name : to_remove) { + node_def->mutable_attr()->erase(attr_name); + if (op_attr_removed != nullptr) { + op_attr_removed->insert(std::make_pair(node_def->op(), attr_name)); + } + } + + return Status::OK(); +} + +static bool IsFunction(const GraphDef& graph_def, const string& op_name) { + for (const auto& func_def : graph_def.library().function()) { + if (op_name == func_def.signature().name()) return true; + } + return false; +} + +Status RemoveNewDefaultAttrsFromGraphDef( + GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry, + const OpRegistryInterface& producer_op_registry, + std::set>* op_attr_removed) { + // TODO(joshL): Make IsFunction() faster by collecting the names of + // all functions as a preprocessing step. + for (int n = 0; n < graph_def->node_size(); ++n) { + NodeDef* node_def = graph_def->mutable_node(n); + if (!IsFunction(*graph_def, node_def->op())) { + TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef( + node_def, consumer_op_registry, producer_op_registry, + op_attr_removed)); + } + } + for (int f = 0; f < graph_def->library().function_size(); ++f) { + FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f); + for (int n = 0; n < func_def->node_def_size(); ++n) { + NodeDef* node_def = func_def->mutable_node_def(n); + if (!IsFunction(*graph_def, node_def->op())) { + // TODO(josh11b): Better handling of attrs with placeholder values. + TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef( + node_def, consumer_op_registry, producer_op_registry, + op_attr_removed)); + } + } + } + + return Status::OK(); +} + +void OpsUsedByGraph(const GraphDef& graph_def, + std::set* ops_used_in_graph) { + // Map function names to definitions. + std::unordered_map name_to_function; + for (const auto& function : graph_def.library().function()) { + name_to_function.insert( + std::make_pair(function.signature().name(), &function)); + } + + // Collect the sorted list of op names. Since functions can reference + // functions, we need a recursive traversal. + std::set used_ops; // Includes both primitive ops and functions + std::vector functions_to_process; // A subset of used_ops + // Collect the logic to mark an op in a lambda; it'll be used twice below. + const auto mark_op_as_used = [&used_ops, &functions_to_process, + &name_to_function](const string& op) { + if (used_ops.insert(op).second) { + // If it's a function, we'll need to process further + const auto it = name_to_function.find(op); + if (it != name_to_function.end()) { + functions_to_process.push_back(it->second); + } + } + }; + for (const auto& node : graph_def.node()) { + mark_op_as_used(node.op()); + } + while (!functions_to_process.empty()) { + const FunctionDef* fun = functions_to_process.back(); + functions_to_process.pop_back(); + for (const auto& node : fun->node_def()) { + mark_op_as_used(node.op()); + } + } + + // Filter out function names to produce output. + // TODO(josh11b): Change the above code to produce this directly. + ops_used_in_graph->clear(); + for (const string& op_name : used_ops) { + if (name_to_function.find(op_name) == name_to_function.end()) { + ops_used_in_graph->insert(op_name); + } + } +} + +Status StrippedOpListForGraph(const GraphDef& graph_def, + const OpRegistryInterface& op_registry, + OpList* stripped_op_list) { + std::set used_ops; + OpsUsedByGraph(graph_def, &used_ops); + + // Build the stripped op list in sorted order, ignoring functions. 
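+  // 'used_ops' is an ordered std::set, so iterating it yields op names in
+  // sorted order, and OpsUsedByGraph() has already dropped function names.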
+ stripped_op_list->clear_op(); + for (const string& op_name : used_ops) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def)); + OpDef* stripped_op = stripped_op_list->add_op(); + stripped_op->CopyFrom(*op_def); + RemoveDescriptionsFromOpDef(stripped_op); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/graph_def_util.h b/graph_def_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0c6542a9f258a28f42a9caa9821ea3faf8010342 --- /dev/null +++ b/graph_def_util.h @@ -0,0 +1,115 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_ + +#include +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Forward declare proto so that it's symbols can be removed from .so exports +class GraphDef; + +// Produce a human-readable version of a GraphDef that is more concise +// than a text-format proto. +string SummarizeGraphDef(const GraphDef& graph_def); + +// Validates the syntax of a GraphDef provided externally. +// +// The following is an EBNF-style syntax for GraphDef objects. Note that +// Node objects are actually specified as tensorflow::NodeDef protocol buffers, +// which contain many other fields that are not (currently) validated. +// +// Graph = Node * +// Node = NodeName, Inputs +// Inputs = ( DataInput * ), ( ControlInput * ) +// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ? +// ControlInput = "^", NodeName +// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] * +Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def); + +// Adds default attributes to NodeDefs in 'graph_def' starting +// from the 'node_offset' node in 'graph_def'. +// +// Default attributes are defined by 'op_registry'. +// +// Returns OK on success, an error if 'graph_def' has a NodeDef +// that cannot be found in 'op_registry'. +// +// REQUIRES: 'graph_def' and 'op_registry' are not nullptr. +Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, + const OpRegistryInterface& op_registry, + int node_offset); + +// Remove attrs from 'graph_def' that have the default value according +// to 'producer_op_registry', but don't exist according to +// 'consumer_op_registry'. This can allow 'graph_def' to run on the +// consumer even if consumer was built at an earlier CL (before an +// attr with a default was added). Note that this will not affect +// attrs with non-default values, so you must run a +// ValidateGraphDef...() function to see if the result is in fact +// compatible. If not nullptr, the op/attr pairs that were removed +// are added to '*op_attr_removed'. 
+// +// Expected usage, for a producer that wants to prepare a graph for +// a consumer: +// // For each consumer, update 'graph_def': +// OpListOpRegistry consumer_op_registry(consumer_server_op_list); +// std::unordered_set> op_attr_removed; +// TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromGraphDef( +// &graph_def, consumer_op_registry, *OpRegistry::Global(), +// &op_attr_removed)); +// // Validate that each consumer can understand the resulting 'graph_def' +// TF_RETURN_IF_ERROR(graph::ValidateGraphDefAgainstOpRegistry( +// graph_def, consumer_op_registry)); +// // Consumer can use 'graph_def', and 'op_attr_removed' summarizes +// // what changes had to be made to 'graph_def' for it to work. +// +// Expected usage, for a consumer that has a graph and a +// (optionally-stripped) op_list from a producer (say from a call to +// StrippedOpListForGraph(), or in the MetaGraphDef): +// OpListOpRegistry producer_op_registry(producer_stripped_op_list); +// TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromGraphDef( +// &graph_def, *OpRegistry::Global(), producer_op_registry, nullptr)); +Status RemoveNewDefaultAttrsFromGraphDef( + GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry, + const OpRegistryInterface& producer_op_registry, + std::set>* op_attr_removed); + +// Two functions that collect the ops used by a graph. +// +// This returns the ops used as a set of strings. +void OpsUsedByGraph(const GraphDef& graph_def, + std::set* ops_used_in_graph); + +// This function computes the stripped_op_list field of MetaGraphDef +// and similar protos. The op_registry should contain the ops used to +// produce graph_def. The resulting stripped_op_list can be +// communicated from the producer to the consumer, which can use +// RemoveNewDefaultAttrsFromGraphDef() to improve forwards compatibility +// (using an OpListOpRegistry as indicated in the example above). +// +// Most users will pass *OpRegistry::Global() for op_registry to strip against +// the list of ops registered in this process. +Status StrippedOpListForGraph(const GraphDef& graph_def, + const OpRegistryInterface& op_registry, + OpList* stripped_op_list); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_ diff --git a/graph_def_util_test.cc b/graph_def_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1ac322e48e2e6a9a572d8e85b01e166fc7e36f74 --- /dev/null +++ b/graph_def_util_test.cc @@ -0,0 +1,321 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/graph_def_util.h" + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" + +namespace tensorflow { +namespace { + +Status FinalizeOpDef(const OpDefBuilder& b, OpDef* op_def) { + OpRegistrationData op_reg_data; + const Status s = b.Finalize(&op_reg_data); + *op_def = op_reg_data.op_def; + return s; +} + +// Producer and consumer have default for an attr -> graph unchanged. +TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) { + OpList op_list; + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("NoChangeWithDefault").Attr("a: int = 12"), + op_list.add_op())); + OpListOpRegistry registry(&op_list); + + GraphDef graph_def; + TF_ASSERT_OK(NodeDefBuilder("ncwd", "NoChangeWithDefault", ®istry) + .Finalize(graph_def.add_node())); + GraphDef expected_graph_def = graph_def; + + std::set> op_attr_removed; + TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry, + &op_attr_removed)); + + TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def); + EXPECT_TRUE(op_attr_removed.empty()); +} + +// Producer and consumer both have an attr -> graph unchanged. +TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) { + OpList op_list; + TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("NoChangeNoDefault").Attr("a: int"), + op_list.add_op())); + OpListOpRegistry registry(&op_list); + + GraphDef graph_def; + TF_ASSERT_OK(NodeDefBuilder("ncnd", "NoChangeNoDefault", ®istry) + .Attr("a", 42) + .Finalize(graph_def.add_node())); + GraphDef expected_graph_def = graph_def; + + std::set> op_attr_removed; + TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry, + &op_attr_removed)); + + TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def); + EXPECT_TRUE(op_attr_removed.empty()); +} + +// Producer has default for an attr that the consumer does not know +// about, and the produced graph has the default value for the attr -> +// attr removed from graph (and so able to be consumed). 
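+// This is the forwards-compatibility scenario described in graph_def_util.h:
+// the consumer predates the attr, so stripping the defaulted attr lets the
+// older consumer accept the graph.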
+TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) { + OpList consumer_op_list; + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op())); + OpListOpRegistry consumer_registry(&consumer_op_list); + + OpList producer_op_list; + TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"), + producer_op_list.add_op())); + OpListOpRegistry producer_registry(&producer_op_list); + + GraphDef produced_graph_def; + TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &producer_registry) + .Finalize(produced_graph_def.add_node())); + + std::set> op_attr_removed; + TF_ASSERT_OK( + RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, + producer_registry, &op_attr_removed)); + + GraphDef expected_graph_def; + TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &consumer_registry) + .Finalize(expected_graph_def.add_node())); + TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def); + + std::set> expected_removed({{"UsesDefault", "a"}}); + EXPECT_EQ(expected_removed, op_attr_removed); +} + +// Producer has default for an attr that the consumer does not know +// about, graph sets the attr to a value different from the default -> +// graph unchanged (but not able to be consumed by consumer). +TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) { + OpList consumer_op_list; + TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"), + consumer_op_list.add_op())); + OpListOpRegistry consumer_registry(&consumer_op_list); + + OpList producer_op_list; + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"), + producer_op_list.add_op())); + OpListOpRegistry producer_registry(&producer_op_list); + + GraphDef produced_graph_def; + TF_ASSERT_OK(NodeDefBuilder("changed_from_default", "ChangedFromDefault", + &producer_registry) + .Attr("a", 9) + .Finalize(produced_graph_def.add_node())); + GraphDef expected_graph_def = produced_graph_def; + + std::set> op_attr_removed; + TF_ASSERT_OK( + RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, + producer_registry, &op_attr_removed)); + + TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def); + EXPECT_TRUE(op_attr_removed.empty()); +} + +// Attrs starting with underscores should not be removed. 
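+// (Attr names starting with "_" are reserved for values attached by the
+// runtime or by tooling rather than declared in an OpDef, for example the
+// "_kernel" label or the "_input_hostmem"/"_output_hostmem" hints used
+// elsewhere in this change, so they are preserved even when the consumer's
+// OpDef does not mention them.)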
+TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) { + OpList consumer_op_list; + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("Underscore"), consumer_op_list.add_op())); + OpListOpRegistry consumer_registry(&consumer_op_list); + + OpList producer_op_list; + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("Underscore"), producer_op_list.add_op())); + // Add the _underscore attr manually since OpDefBuilder would complain + OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr(); + attr->set_name("_underscore"); + attr->set_type("int"); + attr->mutable_default_value()->set_i(17); + OpListOpRegistry producer_registry(&producer_op_list); + + GraphDef produced_graph_def; + TF_ASSERT_OK(NodeDefBuilder("node", "Underscore", &producer_registry) + .Attr("_underscore", 17) + .Finalize(produced_graph_def.add_node())); + GraphDef expected_graph_def = produced_graph_def; + + std::set> op_attr_removed; + TF_ASSERT_OK( + RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, + producer_registry, &op_attr_removed)); + + TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def); + EXPECT_EQ(op_attr_removed.size(), 0); +} + +TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) { + OpList consumer_op_list; + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op())); + TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"), + consumer_op_list.add_op())); + OpListOpRegistry consumer_registry(&consumer_op_list); + + OpList producer_op_list; + TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"), + producer_op_list.add_op())); + TF_ASSERT_OK( + FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"), + producer_op_list.add_op())); + OpListOpRegistry producer_registry(&producer_op_list); + + GraphDef produced_graph_def; + *produced_graph_def.mutable_library()->add_function() = + FunctionDefHelper::Create( + "my_func", {}, {}, {}, + {{{"x"}, "UsesDefault", {}, {{"a", 17}}}, + {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}}, + {}); + OpList function_op_list; + *function_op_list.add_op() = + produced_graph_def.library().function(0).signature(); + OpListOpRegistry function_registry(&function_op_list); + TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry) + .Finalize(produced_graph_def.add_node())); + + std::set> op_attr_removed; + TF_ASSERT_OK( + RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, + producer_registry, &op_attr_removed)); + + GraphDef expected_graph_def; + *expected_graph_def.mutable_library()->add_function() = + FunctionDefHelper::Create( + "my_func", {}, {}, {}, + {{{"x"}, "UsesDefault", {}, {}}, + {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}}, + {}); + TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry) + .Finalize(expected_graph_def.add_node())); + TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def); + EXPECT_EQ(expected_graph_def.library().DebugString(), + produced_graph_def.library().DebugString()); + + std::set> expected_removed({{"UsesDefault", "a"}}); + EXPECT_EQ(expected_removed, op_attr_removed); +} + +TEST(StrippedOpListForGraphTest, FlatTest) { + // Make four ops + OpList op_list; + for (const string& op : {"A", "B", "C", "D"}) { + OpDef* op_def = op_list.add_op(); + op_def->set_name(op); + op_def->set_summary("summary"); + op_def->set_description("description"); + op_def->set_is_commutative(op == "B"); + } + + // Make a graph which uses two ops once and twice, respectively. 
+ // The result should be independent of the ordering. + const string graph_ops[4][3] = { + {"C", "B", "B"}, {"B", "C", "B"}, {"B", "B", "C"}, {"C", "C", "B"}}; + for (const bool use_function : {false, true}) { + for (int order = 0; order < 4; order++) { + GraphDef graph_def; + if (use_function) { + FunctionDef* function_def = graph_def.mutable_library()->add_function(); + function_def->mutable_signature()->set_name("F"); + for (const string& op : graph_ops[order]) { + function_def->add_node_def()->set_op(op); + } + graph_def.add_node()->set_op("F"); + } else { + for (const string& op : graph_ops[order]) { + string name = strings::StrCat("name", graph_def.node_size()); + NodeDef* node = graph_def.add_node(); + node->set_name(name); + node->set_op(op); + } + } + + // Strip the op list + OpList stripped_op_list; + TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list), + &stripped_op_list)); + + // We should have exactly two ops: B and C. + ASSERT_EQ(stripped_op_list.op_size(), 2); + for (int i = 0; i < 2; i++) { + const OpDef& op = stripped_op_list.op(i); + EXPECT_EQ(op.name(), i ? "C" : "B"); + EXPECT_EQ(op.summary(), ""); + EXPECT_EQ(op.description(), ""); + EXPECT_EQ(op.is_commutative(), !i); + } + + // Should get the same result using OpsUsedByGraph(). + std::set used_ops; + OpsUsedByGraph(graph_def, &used_ops); + ASSERT_EQ(std::set({"B", "C"}), used_ops); + } + } +} + +TEST(StrippedOpListForGraphTest, NestedFunctionTest) { + // Make a primitive op A. + OpList op_list; + op_list.add_op()->set_name("A"); + + for (const bool recursive : {false, true}) { + // Call A from function B, and B from function C. + GraphDef graph_def; + FunctionDef* b = graph_def.mutable_library()->add_function(); + FunctionDef* c = graph_def.mutable_library()->add_function(); + b->mutable_signature()->set_name("B"); + c->mutable_signature()->set_name("C"); + b->add_node_def()->set_op("A"); + c->add_node_def()->set_op("B"); + if (recursive) { + b->add_node_def()->set_op("B"); + c->add_node_def()->set_op("C"); + } + + // Use C in the graph. + graph_def.add_node()->set_op("C"); + + // The stripped op list should contain just A. + OpList stripped_op_list; + TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list), + &stripped_op_list)); + ASSERT_EQ(stripped_op_list.op_size(), 1); + ASSERT_EQ(stripped_op_list.op(0).name(), "A"); + + // Should get the same result using OpsUsedByGraph(). + std::set used_ops; + OpsUsedByGraph(graph_def, &used_ops); + ASSERT_EQ(std::set({"A"}), used_ops); + } +} + +} // namespace +} // namespace tensorflow diff --git a/graph_transfer_info.proto b/graph_transfer_info.proto new file mode 100644 index 0000000000000000000000000000000000000000..016259ddbf5254a96086d4813a312b27593f10d9 --- /dev/null +++ b/graph_transfer_info.proto @@ -0,0 +1,68 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphTransferInfoProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/types.proto"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. 
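+//
+// GraphTransferInfo itself appears to describe a graph that has been prepared
+// for execution on a remote processor (see the Destination enum below, e.g.
+// HEXAGON): its nodes, constant data, per-node inputs/outputs, and the
+// metadata of the graph's input and output nodes.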
+message GraphTransferInfo { + enum Destination { + NOP = 0; + HEXAGON = 1; + } + message NodeInput { + int32 node_id = 1; + int32 output_port = 2; + } + message NodeInfo { + string name = 1; + int32 node_id = 2; + string type_name = 3; + int32 soc_op_id = 4; + int32 padding_id = 5; + int32 input_count = 6; + int32 output_count = 7; + }; + message ConstNodeInfo { + string name = 1; + int32 node_id = 2; + repeated int64 shape = 3; + bytes data = 4; + DataType dtype = 5; + }; + message NodeInputInfo { + int32 node_id = 1; + repeated NodeInput node_input = 2; + }; + message NodeOutputInfo { + int32 node_id = 1; + repeated int32 max_byte_size = 2; + }; + message GraphInputNodeInfo { + string name = 1; + repeated int64 shape = 2; + DataType dtype = 3; + } + + message GraphOutputNodeInfo { + string name = 1; + repeated int64 shape = 2; + DataType dtype = 3; + } + + repeated NodeInfo node_info = 1; + repeated ConstNodeInfo const_node_info = 2; + repeated NodeInputInfo node_input_info = 3; + repeated NodeOutputInfo node_output_info = 4; + // Input Node parameters of transferred graph + repeated GraphInputNodeInfo graph_input_node_info = 5; + repeated GraphOutputNodeInfo graph_output_node_info = 6; + // Destination of graph transfer + Destination destination = 7; +}; diff --git a/iterator.proto b/iterator.proto new file mode 100644 index 0000000000000000000000000000000000000000..7e5f5ea2e0c2f976855813d2f5e53de0f190872e --- /dev/null +++ b/iterator.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "IteratorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.util"; + +// Protocol buffer representing the metadata for an iterator's state stored +// as a Variant tensor. +message IteratorStateMetadata { + // A user-specified version string. + string version = 1; + + // Keys for tensors in the VariantTensorDataProto. + repeated string keys = 2; +} diff --git a/kernel_def.proto b/kernel_def.proto new file mode 100644 index 0000000000000000000000000000000000000000..65e9ef04a06651a7b230008e34b5c4d15e0572ca --- /dev/null +++ b/kernel_def.proto @@ -0,0 +1,36 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "KernelDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/attr_value.proto"; + +message KernelDef { + // Must match the name of an Op. + string op = 1; + + // Type of device this kernel runs on. + string device_type = 2; + + message AttrConstraint { + // Name of an attr from the Op. + string name = 1; + + // A list of values that this kernel supports for this attr. + // Like OpDef.AttrDef.allowed_values, except for kernels instead of Ops. + AttrValue allowed_values = 2; + } + repeated AttrConstraint constraint = 3; + + // Names of the Op's input_/output_args that reside in host memory + // instead of device memory. + repeated string host_memory_arg = 4; + + // This allows experimental kernels to be registered for an op that + // won't be used unless the user specifies a "_kernel" attr with + // value matching this. + string label = 5; +} diff --git a/kernel_def_builder.cc b/kernel_def_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb86f18ff06c38860e0c24e60b42326317ddecfb --- /dev/null +++ b/kernel_def_builder.cc @@ -0,0 +1,75 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/kernel_def.pb_text.h" +#include "tensorflow/core/framework/kernel_def.pb.h" + +namespace tensorflow { + +KernelDefBuilder::KernelDefBuilder(const char* op_name) { + kernel_def_ = new KernelDef; + kernel_def_->set_op(op_name); +} + +KernelDefBuilder::~KernelDefBuilder() { + DCHECK(kernel_def_ == nullptr) << "Did not call Build()"; +} + +KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) { + kernel_def_->set_device_type(device_type); + return *this; +} + +KernelDefBuilder& KernelDefBuilder::TypeConstraint( + const char* attr_name, gtl::ArraySlice allowed) { + auto* constraint = kernel_def_->add_constraint(); + constraint->set_name(attr_name); + auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); + for (DataType dt : allowed) { + allowed_values->add_type(dt); + } + return *this; +} + +KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name, + DataType allowed) { + auto* constraint = kernel_def_->add_constraint(); + constraint->set_name(attr_name); + constraint->mutable_allowed_values()->mutable_list()->add_type(allowed); + return *this; +} + +KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) { + kernel_def_->add_host_memory_arg(arg_name); + return *this; +} + +KernelDefBuilder& KernelDefBuilder::Label(const char* label) { + CHECK_EQ(kernel_def_->label(), "") + << "Trying to set a kernel's label a second time: '" << label + << "' in: " << ProtoShortDebugString(*kernel_def_); + kernel_def_->set_label(label); + return *this; +} + +const KernelDef* KernelDefBuilder::Build() { + KernelDef* r = kernel_def_; + kernel_def_ = nullptr; + return r; +} + +} // namespace tensorflow diff --git a/kernel_def_builder.h b/kernel_def_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..2966aa58de45a93d1629096a4a54a53d75c80670 --- /dev/null +++ b/kernel_def_builder.h @@ -0,0 +1,87 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_ +#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_ + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Forward declare proto so that kernels don't need to depend on it +class KernelDef; + +// Builder class passed to the REGISTER_KERNEL_BUILDER() macro. +class KernelDefBuilder { + public: + // Starts with just the name field set. + // Caller MUST call Build() and take ownership of the result. + explicit KernelDefBuilder(const char* op_name); + ~KernelDefBuilder(); + + // Required: specify the type of device this kernel supports. + // Returns *this. + KernelDefBuilder& Device(const char* device_type); + // KernelDefBuilder& Device(DeviceType device_type); + + // Specify that this kernel supports a limited set of values for a + // particular type or list(type) attr (a further restriction than + // what the Op allows). + // Returns *this. + KernelDefBuilder& TypeConstraint(const char* attr_name, + gtl::ArraySlice allowed); + + // Like TypeConstraint but supports just a single type. + KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed); + + // Like TypeConstraint, but (a) gets the type from a template parameter + // and (b) only supports a constraint to a single type. + template + KernelDefBuilder& TypeConstraint(const char* attr_name); + // TODO(josh11b): Support other types of attr constraints as needed. + + // Specify that this kernel requires/provides an input/output arg + // in host memory (instead of the default, device memory). + // Returns *this. + KernelDefBuilder& HostMemory(const char* arg_name); + + // Specify that this kernel requires a particular value for the + // "_kernel" attr. May only be specified once. Returns *this. + KernelDefBuilder& Label(const char* label); + + // Returns a pointer to a KernelDef with fields set based on the + // above calls to this instance. + // Caller takes ownership of the result. + const KernelDef* Build(); + + private: + KernelDef* kernel_def_; + + TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder); +}; + +// IMPLEMENTATION + +template +KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name) { + return this->TypeConstraint(attr_name, DataTypeToEnum::v()); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_ diff --git a/kernel_def_builder_test.cc b/kernel_def_builder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..31656c918bb43ada2d5d75aebc321902dd97c8a3 --- /dev/null +++ b/kernel_def_builder_test.cc @@ -0,0 +1,91 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/kernel_def_builder.h" + +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(KernelDefBuilderTest, Basic) { + const KernelDef* def = KernelDefBuilder("A").Device(DEVICE_CPU).Build(); + KernelDef expected; + protobuf::TextFormat::ParseFromString("op: 'A' device_type: 'CPU'", + &expected); + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; +} + +TEST(KernelDefBuilderTest, TypeConstraint) { + const KernelDef* def = KernelDefBuilder("B") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .Build(); + KernelDef expected; + protobuf::TextFormat::ParseFromString(R"proto( + op: 'B' device_type: 'GPU' + constraint { name: 'T' allowed_values { list { type: DT_FLOAT } } } )proto", + &expected); + + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; + + def = KernelDefBuilder("C") + .Device(DEVICE_GPU) + .TypeConstraint("U") + .TypeConstraint("V") + .Build(); + + protobuf::TextFormat::ParseFromString(R"proto( + op: 'C' device_type: 'GPU' + constraint { name: 'U' allowed_values { list { type: DT_INT32 } } } + constraint { name: 'V' allowed_values { list { type: DT_BOOL } } } )proto", + &expected); + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; + + def = KernelDefBuilder("D") + .Device(DEVICE_CPU) + .TypeConstraint("W", {DT_DOUBLE, DT_STRING}) + .Build(); + protobuf::TextFormat::ParseFromString(R"proto( + op: 'D' device_type: 'CPU' + constraint { name: 'W' + allowed_values { list { type: [DT_DOUBLE, DT_STRING] } } } )proto", + &expected); + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; +} + +TEST(KernelDefBuilderTest, HostMemory) { + const KernelDef* def = KernelDefBuilder("E") + .Device(DEVICE_GPU) + .HostMemory("in") + .HostMemory("out") + .Build(); + KernelDef expected; + protobuf::TextFormat::ParseFromString( + "op: 'E' device_type: 'GPU' " + "host_memory_arg: ['in', 'out']", + &expected); + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; +} + +} // namespace +} // namespace tensorflow diff --git a/load_library.cc b/load_library.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9e33b148f71cd6b1856cf55436a7e73df9df059 --- /dev/null +++ b/load_library.cc @@ -0,0 +1,104 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mem.h" + +namespace tensorflow { + +namespace { + +struct Library { + void* handle = nullptr; + OpList op_list; +}; + +} // namespace + +// Load a dynamic library. +// On success, returns the handle to library in result, copies the serialized +// OpList of OpDefs registered in the library to *buf and the length to *len, +// and returns OK from the function. Otherwise return nullptr in result +// and an error status from the function, leaving buf and len untouched. +// +// If `library_filename` has already been loaded, we return a cached handle +// and OpList. Ops and kernels are registered as globals when a library is +// loaded for the first time. Without caching, every subsequent load would not +// perform initialization again, so the OpList would be empty. +Status LoadLibrary(const char* library_filename, void** result, + const void** buf, size_t* len) { + static mutex mu(LINKER_INITIALIZED); + static std::unordered_map loaded_libs; + Env* env = Env::Default(); + Library library; + std::unordered_set seen_op_names; + { + mutex_lock lock(mu); + if (loaded_libs.find(library_filename) != loaded_libs.end()) { + library = loaded_libs[library_filename]; + } else { + Status s = OpRegistry::Global()->ProcessRegistrations(); + if (!s.ok()) { + return s; + } + TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher( + [&library, &seen_op_names](const Status& s, + const OpDef& opdef) -> Status { + if (errors::IsAlreadyExists(s)) { + if (seen_op_names.find(opdef.name()) == seen_op_names.end()) { + // Over writing a registration of an op not in this custom op + // library. Treat this as not an error. + return Status::OK(); + } + } + if (s.ok()) { + *library.op_list.add_op() = opdef; + seen_op_names.insert(opdef.name()); + } + return s; + })); + OpRegistry::Global()->DeferRegistrations(); + s = env->LoadLibrary(library_filename, &library.handle); + if (s.ok()) { + s = OpRegistry::Global()->ProcessRegistrations(); + } + if (!s.ok()) { + OpRegistry::Global()->ClearDeferredRegistrations(); + TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr)); + return s; + } + TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr)); + + loaded_libs[library_filename] = library; + } + } + string str; + library.op_list.SerializeToString(&str); + char* str_buf = reinterpret_cast(port::Malloc(str.length())); + memcpy(str_buf, str.data(), str.length()); + *buf = str_buf; + *len = str.length(); + + *result = library.handle; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/log_memory.cc b/log_memory.cc new file mode 100644 index 0000000000000000000000000000000000000000..5b525412d570992464fa363e6971682615f48a79 --- /dev/null +++ b/log_memory.cc @@ -0,0 +1,102 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/log_memory.h" + +#include "tensorflow/core/framework/log_memory.pb_text.h" +#include "tensorflow/core/framework/log_memory.pb.h" + +namespace tensorflow { + +const string LogMemory::kLogMemoryLabel = "__LOG_MEMORY__"; + +bool LogMemory::IsEnabled() { return VLOG_IS_ON(1); } + +namespace { + +// Write the proto entry to LOG(INFO). +template +void OutputToLog(const T& proto) { + string type_name = proto.GetTypeName(); + const size_t index = type_name.find_last_of("."); + if (index != string::npos) type_name = type_name.substr(index + 1); + LOG(INFO) << LogMemory::kLogMemoryLabel << " " << type_name << " { " + << ProtoShortDebugString(proto) << " }"; +} + +} // namespace + +void LogMemory::RecordStep(const int64 step_id, const string& handle) { + MemoryLogStep step; + step.set_step_id(step_id); + step.set_handle(handle); + OutputToLog(step); +} + +void LogMemory::RecordTensorAllocation(const string& kernel_name, + const int64 step_id, + const Tensor& tensor) { + MemoryLogTensorAllocation allocation; + allocation.set_step_id(step_id); + allocation.set_kernel_name(kernel_name); + tensor.FillDescription(allocation.mutable_tensor()); + OutputToLog(allocation); +} + +void LogMemory::RecordTensorDeallocation(const int64 allocation_id, + const string& allocator_name) { + MemoryLogTensorDeallocation deallocation; + deallocation.set_allocation_id(allocation_id); + deallocation.set_allocator_name(allocator_name); + OutputToLog(deallocation); +} + +void LogMemory::RecordTensorOutput(const string& kernel_name, + const int64 step_id, const int index, + const Tensor& tensor) { + MemoryLogTensorOutput output; + output.set_step_id(step_id); + output.set_kernel_name(kernel_name); + output.set_index(index); + tensor.FillDescription(output.mutable_tensor()); + OutputToLog(output); +} + +void LogMemory::RecordRawAllocation(const string& operation, + const int64 step_id, size_t num_bytes, + void* ptr, Allocator* allocator) { + MemoryLogRawAllocation allocation; + allocation.set_step_id(step_id); + allocation.set_operation(operation); + allocation.set_num_bytes(static_cast(num_bytes)); + allocation.set_ptr(reinterpret_cast(ptr)); + allocation.set_allocation_id(allocator->AllocationId(ptr)); + allocation.set_allocator_name(allocator->Name()); + OutputToLog(allocation); +} + +void LogMemory::RecordRawDeallocation(const string& operation, + const int64 step_id, void* ptr, + Allocator* allocator, bool deferred) { + MemoryLogRawDeallocation deallocation; + deallocation.set_step_id(step_id); + deallocation.set_operation(operation); + deallocation.set_allocation_id(allocator->AllocationId(ptr)); + deallocation.set_allocator_name(allocator->Name()); + deallocation.set_deferred(deferred); + OutputToLog(deallocation); +} + +} // namespace tensorflow diff --git a/log_memory.h b/log_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..faef7b8e98dd78e75eb93bcf1aaa73d630fd3b33 --- /dev/null +++ b/log_memory.h @@ -0,0 +1,111 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_ +#define TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// LogMemory contains methods for recording memory allocations and +// frees, associating each allocation with a step identified by a +// process-wide id. For now, logging is enabled whenever VLOG_IS_ON(1) +// for the log_memory module. +// +// Limitations: We don't log memory allocations by Eigen on the CPU +// since that would require major changes to plumb through to the +// Eigen::{DefaultDevice,ThreadPoolDevice} allocate and deallocate +// methods. We do log Eigen allocations on GPU since the plumbing was +// already in place. +class LogMemory { + public: + // Allocations sometimes happen outside any computation step, and + // SpecialStepIds lists the ids used for those steps. + enum SpecialStepIds { + // Used when performing a just-in-time constant folding optimization. + CONSTANT_FOLDING_STEP_ID = -1, + // Used when constructing an Op kernel before executing a step. + OP_KERNEL_CONSTRUCTION_STEP_ID = -2, + // Used when allocating a tensor buffer from external code, e.g., + // the C API. + EXTERNAL_TENSOR_ALLOCATION_STEP_ID = -3, + // Used when allocating a buffer for network transfer. + NETWORK_BUFFER_STEP_ID = -4, + // Used when allocating a buffer to fill a Proto from the GPU. + PROTO_BUFFER_STEP_ID = -5, + // Used when allocating a Tensor where the caller has not indicated + // the step. + UNKNOWN_STEP_ID = -6, + }; + + static const string kLogMemoryLabel; + + // Test to see if memory logging is enabled. For now, logging is + // enabled whenever VLOG_IS_ON(1) for the log_memory module. + static bool IsEnabled(); + + // Log the beginning of a step. + static void RecordStep(int64 step_id, const string& handle); + + // Log a tensor buffer allocation. The name indicates which kernel + // made the allocation. If the allocation is made through an + // OpKernelContext the step_id indicates which step is executing, + // otherwise step_id is one of the SpecialStepIds defined in + // op_kernel.h, e.g. Op Kernel construction or an optimization pass + // such as constant folding. + static void RecordTensorAllocation(const string& kernel_name, int64 step_id, + const Tensor& tensor); + + // Log a tensor buffer deallocation. The deallocation is triggered + // when the buffer's refcount falls to zero, and the tracking + // mechanism does not associate it with a particular step or + // kernel. The allocation_id/allocator_name should match a + // corresponding tensor previously passed in to + // RecordTensorAllocation. + static void RecordTensorDeallocation(int64 allocation_id, + const string& allocator_name); + + // Log the use of a tensor as an output from a kernel. + static void RecordTensorOutput(const string& kernel_name, int64 step_id, + int index, const Tensor& tensor); + + // Log a "raw" allocation, which is just a buffer sized in + // bytes. 
The Eigen allocator, and memory copies, record their + // allocations this way, since they do not allocate TensorFlow + // tensors. The operation is set to the OpKernel name if this is + // called from within an Op execution, otherwise it indicates an + // operation such as memcpy. The step_id if >=0 indicates which step + // is executing, otherwise step_id is one of the SpecialStepIds + // defined in op_kernel.h, e.g. Op Kernel construction or an + // optimization pass such as constant folding. + static void RecordRawAllocation(const string& operation, int64 step_id, + size_t num_bytes, void* ptr, + Allocator* allocator); + + // Log a "raw" deallocation of a buffer. When deferred is true, the + // buffer won't be used again, but a GPU kernel may still be + // enqueued using the buffer. A deferred deallocation should always + // be followed by a matching non-deferred deallocation when the + // buffer is actually returned and can be reused. + static void RecordRawDeallocation(const string& operation, int64 step_id, + void* ptr, Allocator* allocator, + bool deferred); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_ diff --git a/log_memory.proto b/log_memory.proto new file mode 100644 index 0000000000000000000000000000000000000000..d1e126330d20b634c4ed482e704a08b7aaf3fa5f --- /dev/null +++ b/log_memory.proto @@ -0,0 +1,93 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "LogMemoryProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/tensor_description.proto"; + +message MemoryLogStep { + // Process-unique step id. + int64 step_id = 1; + + // Handle describing the feeds and fetches of the step. + string handle = 2; +}; + +message MemoryLogTensorAllocation { + // Process-unique step id. + int64 step_id = 1; + + // Name of the kernel making the allocation as set in GraphDef, + // e.g., "affine2/weights/Assign". + string kernel_name = 2; + + // Allocated tensor details. + TensorDescription tensor = 3; +}; + +message MemoryLogTensorDeallocation { + // Id of the tensor buffer being deallocated, used to match to a + // corresponding allocation. + int64 allocation_id = 1; + + // Name of the allocator used. + string allocator_name = 2; +}; + +message MemoryLogTensorOutput { + // Process-unique step id. + int64 step_id = 1; + + // Name of the kernel producing an output as set in GraphDef, e.g., + // "affine2/weights/Assign". + string kernel_name = 2; + + // Index of the output being set. + int32 index = 3; + + // Output tensor details. + TensorDescription tensor = 4; +} + +message MemoryLogRawAllocation { + // Process-unique step id. + int64 step_id = 1; + + // Name of the operation making the allocation. + string operation = 2; + + // Number of bytes in the allocation. + int64 num_bytes = 3; + + // Address of the allocation. + uint64 ptr = 4; + + // Id of the tensor buffer being allocated, used to match to a + // corresponding deallocation. + int64 allocation_id = 5; + + // Name of the allocator used. + string allocator_name = 6; +}; + +message MemoryLogRawDeallocation { + // Process-unique step id. + int64 step_id = 1; + + // Name of the operation making the deallocation. + string operation = 2; + + // Id of the tensor buffer being deallocated, used to match to a + // corresponding allocation. + int64 allocation_id = 3; + + // Name of the allocator used. 
+ string allocator_name = 4; + + // True if the deallocation is queued and will be performed later, + // e.g. for GPU lazy freeing of buffers. + bool deferred = 5; +}; diff --git a/lookup_interface.cc b/lookup_interface.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf3204ea6e283d90390be46d5753a830705619ee --- /dev/null +++ b/lookup_interface.cc @@ -0,0 +1,87 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/lookup_interface.h" + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace lookup { + +Status LookupInterface::CheckKeyShape(const TensorShape& shape) { + if (!TensorShapeUtils::EndsWith(shape, key_shape())) { + return errors::InvalidArgument("Input key shape ", shape.DebugString(), + " must end with the table's key shape ", + key_shape().DebugString()); + } + return Status::OK(); +} + +Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys, + const Tensor& values) { + if (keys.dtype() != key_dtype()) { + return errors::InvalidArgument("Key must be type ", key_dtype(), + " but got ", keys.dtype()); + } + if (values.dtype() != value_dtype()) { + return errors::InvalidArgument("Value must be type ", value_dtype(), + " but got ", values.dtype()); + } + return Status::OK(); +} + +Status LookupInterface::CheckKeyAndValueTensorsHelper(const Tensor& keys, + const Tensor& values) { + TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values)); + TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape())); + + TensorShape expected_value_shape = keys.shape(); + for (int i = 0; i < key_shape().dims(); ++i) { + expected_value_shape.RemoveDim(expected_value_shape.dims() - 1); + } + expected_value_shape.AppendShape(value_shape()); + if (values.shape() != expected_value_shape) { + return errors::InvalidArgument( + "Expected shape ", expected_value_shape.DebugString(), + " for value, got ", values.shape().DebugString()); + } + return Status::OK(); +} + +Status LookupInterface::CheckKeyAndValueTensorsForInsert(const Tensor& keys, + const Tensor& values) { + return CheckKeyAndValueTensorsHelper(keys, values); +} + +Status LookupInterface::CheckKeyAndValueTensorsForImport(const Tensor& keys, + const Tensor& values) { + return CheckKeyAndValueTensorsHelper(keys, values); +} + +Status LookupInterface::CheckFindArguments(const Tensor& key, + const Tensor& default_value) { + TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value)); + TF_RETURN_IF_ERROR(CheckKeyShape(key.shape())); + if (default_value.shape() != value_shape()) { + return errors::InvalidArgument( + "Expected shape ", value_shape().DebugString(), + " for default value, got ", default_value.shape().DebugString()); + } + return Status::OK(); +} + +} // namespace lookup +} // namespace tensorflow diff --git a/lookup_interface.h b/lookup_interface.h new file mode 100644 index 
0000000000000000000000000000000000000000..1381dd66a56c7eb5d2a0f0aab760608a50b9b1b0 --- /dev/null +++ b/lookup_interface.h @@ -0,0 +1,145 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_ +#define TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class OpKernelContext; + +namespace lookup { + +// Forward declaration so we can define GetInitializableLookupTable() in +// LookupInterface. +class InitializableLookupTable; + +// Lookup interface for batch lookups used by table lookup ops. +class LookupInterface : public ResourceBase { + public: + // Performs batch lookups, for every element in the key tensor, Find returns + // the corresponding value into the values tensor. + // If an element is not present in the table, the given default value is used. + + // For tables that require initialization, Find is available once the table + // is marked as initialized. + + // Returns the following statuses: + // - OK: when the find finishes successfully. + // - FailedPrecondition: if the table is not initialized. + // - InvalidArgument: if any of the preconditions on the lookup key or value + // fails. + // - In addition, other implementations may provide another non-OK status + // specific to their failure modes. + virtual Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values, + const Tensor& default_value) = 0; + + // Inserts elements into the table. Each element of the key tensor is + // associated with the corresponding element in the value tensor. + // This method is only implemented in mutable tables that can be updated over + // the execution of the graph. It returns Status::NotImplemented for read-only + // tables that are initialized once before they can be looked up. + + // Returns the following statuses: + // - OK: when the insert finishes successfully. + // - InvalidArgument: if any of the preconditions on the lookup key or value + // fails. + // - Unimplemented: if the table does not support insertions. + virtual Status Insert(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) = 0; + + // Returns the number of elements in the table. + virtual size_t size() const = 0; + + // Exports the values of the table to two tensors named keys and values. + // Note that the shape of the tensors is completely up to the implementation + // of the table and can be different than the tensors used for the Insert + // function above. + virtual Status ExportValues(OpKernelContext* ctx) = 0; + + // Imports previously exported keys and values. + // As mentioned above, the shape of the keys and values tensors are determined + // by the ExportValues function above and can be different than for the + // Insert function. 
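+  // A typical caller is a restore path that feeds back the tensors previously
+  // produced by ExportValues, for example when reloading a saved table.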
+ virtual Status ImportValues(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) = 0; + + // Returns the data type of the key. + virtual DataType key_dtype() const = 0; + + // Returns the data type of the value. + virtual DataType value_dtype() const = 0; + + // Returns the shape of a key in the table. + virtual TensorShape key_shape() const = 0; + + // Returns the shape of a value in the table. + virtual TensorShape value_shape() const = 0; + + // Check format of the key and value tensors for the Insert function. + // Returns OK if all the following requirements are satisfied, otherwise it + // returns InvalidArgument: + // - DataType of the tensor keys equals to the table key_dtype + // - DataType of the tensor values equals to the table value_dtype + // - the values tensor has the required shape given keys and the tables's + // value shape. + virtual Status CheckKeyAndValueTensorsForInsert(const Tensor& keys, + const Tensor& values); + + // Similar to the function above but instead checks eligibility for the Import + // function. + virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys, + const Tensor& values); + + // Check the arguments of a find operation. Returns OK if all the following + // requirements are satisfied, otherwise it returns InvalidArgument: + // - DataType of the tensor keys equals to the table key_dtype + // - DataType of the tensor default_value equals to the table value_dtype + // - the default_value tensor shape matches the table's value shape. + Status CheckFindArguments(const Tensor& keys, const Tensor& default_value); + + string DebugString() override { + return strings::StrCat("A lookup table of size: ", size()); + } + + // Returns an InitializableLookupTable, a subclass of LookupInterface, if the + // current object is an InitializableLookupTable. Otherwise, returns nullptr. + virtual InitializableLookupTable* GetInitializableLookupTable() { + return nullptr; + } + + protected: + virtual ~LookupInterface() = default; + + // Makes sure that the key and value tensor DataType's match the table + // key_dtype and value_dtype. + Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values); + + // Makes sure that the provided shape is consistent with the table keys shape. + Status CheckKeyShape(const TensorShape& shape); + + private: + Status CheckKeyAndValueTensorsHelper(const Tensor& keys, + const Tensor& values); +}; + +} // namespace lookup +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_ diff --git a/memory_types.cc b/memory_types.cc new file mode 100644 index 0000000000000000000000000000000000000000..270118bb678e110269be9aa67a3904e36c34c512 --- /dev/null +++ b/memory_types.cc @@ -0,0 +1,156 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/memory_types.h" + +#include + +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace { +// Returns the largest endpoint of anything in the name_map. +int GetTotal(const NameRangeMap& name_map) { + int total = 0; + for (const auto& item : name_map) { + total = std::max(total, item.second.second); + } + return total; +} + +// Fills memory_types for either input or output, setting everything +// to DEVICE_MEMORY except those args in host_memory_args. Removes +// elements of host_memory_args that were used. +void MemoryTypesHelper(const NameRangeMap& name_map, + std::vector* host_memory_args, + MemoryTypeVector* memory_types) { + // Update args that have been marked as in "HOST_MEMORY". + size_t keep = 0; + for (size_t i = 0; i < host_memory_args->size(); ++i) { + auto iter = name_map.find((*host_memory_args)[i]); + if (iter != name_map.end()) { + for (int j = iter->second.first; j < iter->second.second; ++j) { + (*memory_types)[j] = HOST_MEMORY; + } + } else { + // (*host_memory_args)[i] not found, save it for the next pass. + if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i]; + ++keep; + } + } + host_memory_args->resize(keep); +} + +MemoryType MTypeFromDType(const DataType dtype) { + return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY + : DEVICE_MEMORY; +} + +} // namespace + +Status MemoryTypesForNode(const OpRegistryInterface* op_registry, + const DeviceType& device_type, const NodeDef& ndef, + MemoryTypeVector* inp_mtypes, + MemoryTypeVector* out_mtypes) { + // Look up the Op registered for this op name. + const OpDef* op_def; + TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(ndef.op(), &op_def)); + + // Look up the Kernel registered for this node def. + const KernelDef* kdef = nullptr; + Status status = + FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */); + + DataTypeVector inp_dtypes; + DataTypeVector out_dtypes; + TF_RETURN_IF_ERROR( + InOutTypesForNode(ndef, *op_def, &inp_dtypes, &out_dtypes)); + + inp_mtypes->clear(); + out_mtypes->clear(); + + // For functions (which have no KernelDef) and their gradients, we can only + // best-effort derive the memory type from the data type. For now, we assume + // int32 is always on host memory and other types are always on device memory. + // TODO(zhifengc,phawkins): We should do type inference over function bodies + // to derive the correct input/output memory types. We should also split + // host-memory and non host-memory arguments into separate type lists. + if (!status.ok() || ndef.op() == "SymbolicGradient") { + for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t)); + for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t)); + return Status::OK(); + } + + // Gets the input/output names and their corresponding endpoint ranges. + NameRangeMap inp_names; + NameRangeMap out_names; + TF_RETURN_IF_ERROR(NameRangesForNode(ndef, *op_def, &inp_names, &out_names)); + + // Now that we know the size, fill with the default 'DEVICE_MEMORY'. 
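+  // (GetTotal() returns the largest endpoint across all name ranges, which
+  // equals the number of input or output tensors for this node.)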
+ inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY); + out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY); + + // Fills in host memory types based on the kernel def. + const auto& from_proto = kdef->host_memory_arg(); + std::vector host_memory_args(from_proto.begin(), from_proto.end()); + MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes); + MemoryTypesHelper(out_names, &host_memory_args, out_mtypes); + if (!host_memory_args.empty()) { + return errors::InvalidArgument( + "HostMemory args '", str_util::Join(host_memory_args, "', '"), + "' not found in OpDef: ", SummarizeOpDef(*op_def)); + } + CHECK_LE(inp_mtypes->size(), inp_dtypes.size()); + CHECK_LE(out_mtypes->size(), out_dtypes.size()); + + // Mark e.g. all resource and string types as host memory. + for (int i = 0; i < inp_mtypes->size(); ++i) { + if (DataTypeAlwaysOnHost(inp_dtypes[i])) { + (*inp_mtypes)[i] = HOST_MEMORY; + } + } + for (int i = 0; i < out_mtypes->size(); ++i) { + if (DataTypeAlwaysOnHost(out_dtypes[i])) { + (*out_mtypes)[i] = HOST_MEMORY; + } + } + + std::vector hostmem_attr; + if (GetNodeAttr(ndef, "_input_hostmem", &hostmem_attr).ok()) { + for (int32 i : hostmem_attr) { + if (0 <= i && i < inp_mtypes->size()) { + (*inp_mtypes)[i] = HOST_MEMORY; + } + } + } + if (GetNodeAttr(ndef, "_output_hostmem", &hostmem_attr).ok()) { + for (int32 i : hostmem_attr) { + if (0 <= i && i < out_mtypes->size()) { + (*out_mtypes)[i] = HOST_MEMORY; + } + } + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/memory_types.h b/memory_types.h new file mode 100644 index 0000000000000000000000000000000000000000..d3918513d36c09a1e1d4e7e46c49a70c2376c198 --- /dev/null +++ b/memory_types.h @@ -0,0 +1,38 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +class NodeDef; + +// Returns into *{input,output}_memory_types the memory type of each +// {input,output} tensor. +// +// REQUIRES: * '*_memory_types' is not nullptr. +// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). +Status MemoryTypesForNode(const OpRegistryInterface* op_registry, + const DeviceType& device_type, const NodeDef& ndef, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_ diff --git a/memory_types_test.cc b/memory_types_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3126ea8e5f8974cb11f88301de613eb5b920830f --- /dev/null +++ b/memory_types_test.cc @@ -0,0 +1,92 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/memory_types.h" + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class DummyKernel : public OpKernel { + public: + explicit DummyKernel(tensorflow::OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(tensorflow::OpKernelContext* context) override {} +}; + +REGISTER_OP("HostMemoryTest") + .Input("a: float") + .Input("b: T") + .Input("c: N * string") + .Input("d: Tlist") + .Input("e: Rlist") + .Output("o: N * T") + .Output("p: Tlist") + .Attr("T: type") + .Attr("N: int") + .Attr("Tlist: list(type)") + .Attr("Rlist: list(type)"); +REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("HostMemoryTest") + .Device(DEVICE_GPU) + .HostMemory("a") + .HostMemory("c") + .HostMemory("d") + .HostMemory("o"), + DummyKernel); + +TEST(MemoryTypesForNode, Simple) { + NodeDef node_def; + TF_ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest") + .Input(FakeInput()) + .Input(FakeInput(DT_BOOL)) + .Input(FakeInput(3)) + .Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32})) + .Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE})) + .Finalize(&node_def)); + MemoryTypeVector input, output; + + TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def, + &input, &output)); + // a:float, b:bool, c:3*string, d:(int32, float, int32), + // e:(resource, string, resource) + EXPECT_EQ( + MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, + HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, + DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}), + input); + // o:3*bool, p:(int32, float, int32) + EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, + DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}), + output); + + TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def, + &input, &output)); + EXPECT_EQ( + MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, + HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, + HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}), + input); + EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, + DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}), + output); +} + +} // namespace tensorflow diff --git a/node_def.proto b/node_def.proto new file mode 100644 index 0000000000000000000000000000000000000000..8fcee32e2986611a1406dfc6998c9e2810b01034 --- /dev/null +++ b/node_def.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/attr_value.proto"; + +message NodeDef { + // The name given to this operator. 
Used for naming inputs,
+  // logging, visualization, etc.  Unique within a single GraphDef.
+  // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
+  string name = 1;
+
+  // The operation name.  There may be custom parameters in attrs.
+  // Op names starting with an underscore are reserved for internal use.
+  string op = 2;
+
+  // Each input is "node:src_output" with "node" being a string name and
+  // "src_output" indicating which output tensor to use from "node". If
+  // "src_output" is 0 the ":0" suffix can be omitted.  Regular inputs
+  // may optionally be followed by control inputs that have the format
+  // "^node".
+  repeated string input = 3;
+
+  // A (possibly partial) specification for the device on which this
+  // node should be placed.
+  // The expected syntax for this string is as follows:
+  //
+  // DEVICE_SPEC ::= PARTIAL_SPEC
+  //
+  // PARTIAL_SPEC ::= ("/" CONSTRAINT) *
+  // CONSTRAINT ::= ("job:" JOB_NAME)
+  //              | ("replica:" [1-9][0-9]*)
+  //              | ("task:" [1-9][0-9]*)
+  //              | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
+  //
+  // Valid values for this string include:
+  // * "/job:worker/replica:0/task:1/device:GPU:3"  (full specification)
+  // * "/job:worker/device:GPU:3"                   (partial specification)
+  // * ""                                           (no specification)
+  //
+  // If the constraints do not resolve to a single device (or if this
+  // field is empty or not present), the runtime will attempt to
+  // choose a device automatically.
+  string device = 4;
+
+  // Operation-specific graph-construction-time configuration.
+  // Note that this should include all attrs defined in the
+  // corresponding OpDef, including those with a value matching
+  // the default -- this allows the default to change and makes
+  // NodeDefs easier to interpret on their own.  However, if
+  // an attr with a default is not specified in this list, the
+  // default will be used.
+  // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
+  // one of the names from the corresponding OpDef's attr field).
+  // The values must have a type matching the corresponding OpDef
+  // attr's type field.
+  // TODO(josh11b): Add some examples here showing best practices.
+  map<string, AttrValue> attr = 5;
+};
diff --git a/node_def_builder.cc b/node_def_builder.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f9cf6ce87359d6e4df03306629b53a73e5673181
--- /dev/null
+++ b/node_def_builder.cc
@@ -0,0 +1,298 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/core/framework/node_def_builder.h" + +#include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt) + : node(n.ToString()), index(i), data_type(dt) {} + +NodeDefBuilder::NodeOut::NodeOut() { + // uninitialized, call Reset() before use. +} + +void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) { + node = n.ToString(); + index = i; + data_type = dt; +} + +NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, + const OpRegistryInterface* op_registry) { + node_def_.set_name(name.ToString()); + const Status status = op_registry->LookUpOpDef(op_name.ToString(), &op_def_); + if (status.ok()) { + Initialize(); + } else { + errors_.push_back(status.error_message()); + inputs_specified_ = 0; + } +} + +NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def) + : op_def_(op_def) { + node_def_.set_name(name.ToString()); + Initialize(); +} + +void NodeDefBuilder::Initialize() { + inputs_specified_ = 0; + node_def_.set_op(op_def_->name()); +} + +const OpDef::ArgDef* NodeDefBuilder::NextArgDef() { + if (!NextArgAvailable()) return nullptr; + return &op_def_->input_arg(inputs_specified_++); +} + +bool NodeDefBuilder::NextArgAvailable() { + if (op_def_ == nullptr) { + return false; + } else if (inputs_specified_ >= op_def_->input_arg_size()) { + errors_.push_back(strings::StrCat("More Input() calls than the ", + op_def_->input_arg_size(), + " input_args")); + return false; + } + return true; +} + +NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) { + if (NextArgAvailable()) { + Status status = fake_input(*op_def_, inputs_specified_, node_def_, this); + if (!status.ok()) errors_.push_back(status.error_message()); + } + return *this; +} + +NodeDefBuilder& NodeDefBuilder::Input(StringPiece src_node, int src_index, + DataType dt) { + const OpDef::ArgDef* arg = NextArgDef(); + if (arg != nullptr) SingleInput(arg, src_node, src_index, dt); + return *this; +} + +NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) { + Input(src.node, src.index, src.data_type); + return *this; +} + +// For inputs that take a list of tensors. 
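+// Illustrative usage sketch (added for clarity, not part of the original
+// change; the op name "MyListOp" and its list argument are hypothetical):
+// for an input declared as "d: Tlist" with "Tlist: list(type)", a caller
+// might write
+//   NodeDefBuilder("n", "MyListOp")
+//       .Input({{"x", 0, DT_INT32}, {"y", 1, DT_INT32}})
+// which adds the inputs "x" and "y:1" and infers Tlist = [DT_INT32, DT_INT32].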
+NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice src_list) { + const OpDef::ArgDef* arg = NextArgDef(); + if (arg != nullptr) ListInput(arg, src_list); + return *this; +} + +void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg, + StringPiece src_node, int src_index, + DataType dt) { + AddInput(src_node, src_index); + + if (!input_arg->number_attr().empty() || + !input_arg->type_list_attr().empty()) { + errors_.push_back(strings::StrCat("Single tensor passed to '", + input_arg->name(), "', expected list")); + return; + } + + if (input_arg->type() != DT_INVALID) { + const DataType expected = MaybeAddRef(input_arg, input_arg->type()); + VerifyInputType(input_arg, expected, dt); + } else { + VerifyInputRef(input_arg, dt); + Attr(input_arg->type_attr(), BaseType(dt)); + } +} + +void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg, + gtl::ArraySlice src_list) { + for (const auto& node_out : src_list) { + AddInput(node_out.node, node_out.index); + } + + if (!input_arg->number_attr().empty()) { + Attr(input_arg->number_attr(), static_cast(src_list.size())); + if (input_arg->type() != DT_INVALID) { + const DataType expected = MaybeAddRef(input_arg, input_arg->type()); + for (const auto& node_out : src_list) { + VerifyInputType(input_arg, expected, node_out.data_type); + } + } else if (!src_list.empty()) { + const DataType base = BaseType(src_list[0].data_type); + Attr(input_arg->type_attr(), base); + const DataType expected = MaybeAddRef(input_arg, base); + for (const auto& node_out : src_list) { + VerifyInputType(input_arg, expected, node_out.data_type); + } + } + } else if (!input_arg->type_list_attr().empty()) { + DataTypeVector type_vec; + type_vec.reserve(src_list.size()); + for (const auto& node_out : src_list) { + const DataType dt = node_out.data_type; + VerifyInputRef(input_arg, dt); + type_vec.push_back(BaseType(dt)); + } + Attr(input_arg->type_list_attr(), type_vec); + } else { + errors_.push_back(strings::StrCat("List provided to input '", + input_arg->name(), + "' when single Tensor expected")); + } +} + +void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) { + if (src_node.empty()) { + errors_.push_back("Empty input node name"); + } else if (src_node[0] == '^') { + errors_.push_back( + strings::StrCat("Non-control input starting with ^: ", src_node)); + } else if (src_index > 0) { + node_def_.add_input(strings::StrCat(src_node, ":", src_index)); + } else { + node_def_.add_input(src_node.ToString()); + } +} + +void NodeDefBuilder::VerifyInputType(const OpDef::ArgDef* input_arg, + DataType expected, DataType dt) { + if (!TypesCompatible(expected, dt)) { + errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ", + DataTypeString(dt), " expected ", + DataTypeString(expected))); + } +} + +void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg, + DataType dt) { + if (input_arg->is_ref() && !IsRefType(dt)) { + errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ", + DataTypeString(dt), + " expected ref type")); + } +} + +NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) { + control_inputs_.push_back(src_node.ToString()); + return *this; +} + +NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) { + node_def_.set_device(device_spec.ToString()); + return *this; +} + +Status NodeDefBuilder::Finalize(NodeDef* node_def) const { + const std::vector* errors_ptr = &errors_; + std::vector errors_storage; + if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) { 
+ // Since this is a const method, to add an error, we have to make + // a copy of the existing errors. + errors_storage = errors_; + errors_storage.push_back( + strings::StrCat(inputs_specified_, " inputs specified of ", + op_def_->input_arg_size(), " inputs in Op")); + errors_ptr = &errors_storage; + } + + if (!errors_ptr->empty()) { + if (errors_ptr->size() == 1) { + if (op_def_ == nullptr) { + return errors::InvalidArgument((*errors_ptr)[0], + " while building NodeDef '", + node_def_.name(), "'"); + } + return errors::InvalidArgument( + (*errors_ptr)[0], " while building NodeDef '", node_def_.name(), + "' using ", SummarizeOpDef(*op_def_)); + } else { + return errors::InvalidArgument( + errors_ptr->size(), " errors while building NodeDef '", + node_def_.name(), "' using ", SummarizeOpDef(*op_def_), ":\n", + str_util::Join(*errors_ptr, "\n")); + } + } else { + NodeDef node_def_backup; + if (node_def == nullptr) node_def = &node_def_backup; + *node_def = node_def_; + + // Add control inputs after the regular inputs. + for (const auto& control_input : control_inputs_) { + node_def->add_input(strings::StrCat("^", control_input)); + } + + // Add default values for unspecified attrs. + AddDefaultsToNodeDef(*op_def_, node_def); + + return Status::OK(); + } +} + +NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) { + if (const AttrValue* found = AttrSlice(node_def_).Find(name)) { + if (!AreAttrValuesEqual(*found, value)) { + errors_.push_back(strings::StrCat("Inconsistent values for attr '", name, + "' ", SummarizeAttrValue(*found), + " vs. ", SummarizeAttrValue(value))); + } + } else { + AddNodeAttr(name, value, &node_def_); + } + return *this; +} + +#define ATTR(T) \ + NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \ + AttrValue attr_value; \ + SetAttrValue(value, &attr_value); \ + return Attr(name, attr_value); \ + } +ATTR(StringPiece) +ATTR(const char*) +ATTR(int32) +ATTR(int64) +ATTR(float) +ATTR(double) +ATTR(bool) +ATTR(DataType) +ATTR(const PartialTensorShape&) +ATTR(const Tensor&) +ATTR(const TensorProto&) +ATTR(const NameAttrList&) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(const std::vector&) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +#undef ATTR + +} // namespace tensorflow diff --git a/node_def_builder.h b/node_def_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..c138332bebc9877b74b16bf4576887db513acfc2 --- /dev/null +++ b/node_def_builder.h @@ -0,0 +1,178 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
+#define TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
+
+#include <functional>
+#include <vector>
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+class NodeDefBuilder;
+typedef std::function<Status(const OpDef&, int, const NodeDef&,
+                             NodeDefBuilder*)>
+    FakeInputFunctor;
+
+// This is a helper for creating a NodeDef.  Automatically sets attrs
+// that can be inferred from the inputs, and uses default values
+// (where they exist) for unspecified attrs.  Example usage:
+//
+//  NodeDef node_def;
+//  Status status = NodeDefBuilder(node_name, op_name)
+//                      .Input(...)
+//                      .Attr(...)
+//                      .Finalize(&node_def);
+//  if (!status.ok()) return status;
+//  // Use node_def here.
class NodeDefBuilder {
+ public:
+  // To specify an output to be consumed by one of the Input() methods below.
+  struct NodeOut {
+    NodeOut(StringPiece n, int i, DataType dt);
+    NodeOut();  // uninitialized, call Reset() before use.
+    void Reset(StringPiece n, int i, DataType dt);
+    string node;
+    int index;
+    DataType data_type;
+  };
+
+  // Specify the name and the Op (either via an OpDef or the name of
+  // the Op plus a registry) for the NodeDef.  Other fields are
+  // specified by calling the methods below.
+  // REQUIRES: The OpDef must satisfy ValidateOpDef().
+  NodeDefBuilder(StringPiece name, StringPiece op_name,
+                 const OpRegistryInterface* op_registry = OpRegistry::Global());
+  // REQUIRES: in addition, *op_def must outlive *this.
+  NodeDefBuilder(StringPiece name, const OpDef* op_def);
+
+  // You must call one Input() function per input_arg in the Op,
+  // *and in the same order as the input_args appear in the OpDef.*
+
+  // For inputs that take a single tensor.
+  NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt);
+  NodeDefBuilder& Input(const NodeOut& src);
+
+  // For inputs that take a list of tensors.
+  NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list);
+
+  // To create inputs in tests, see fake_input.h.
+  NodeDefBuilder& Input(FakeInputFunctor fake_input);
+
+  // Specify that this node must only run after src_node.
+  NodeDefBuilder& ControlInput(StringPiece src_node);
+
+  // Constrains what devices this node may be scheduled on.
+  NodeDefBuilder& Device(StringPiece device_spec);
+
+  // Sets the attr, if not already set.  If already set with a different
+  // value, an error will be returned from Finalize().
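+  // For example (an illustrative sketch, not from the original header; "MyOp"
+  // is a hypothetical op name):
+  //   NodeDefBuilder("n", "MyOp").Attr("T", DT_FLOAT).Attr("T", DT_INT32)
+  // records an "Inconsistent values for attr 'T'" error that is reported
+  // when Finalize() is called.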
+  NodeDefBuilder& Attr(StringPiece name, const AttrValue& value);
+  NodeDefBuilder& Attr(StringPiece name, StringPiece value);
+  NodeDefBuilder& Attr(StringPiece name, const char* value);
+  NodeDefBuilder& Attr(StringPiece name, int32 value);
+  NodeDefBuilder& Attr(StringPiece name, int64 value);
+  NodeDefBuilder& Attr(StringPiece name, float value);
+  NodeDefBuilder& Attr(StringPiece name, double value);
+  NodeDefBuilder& Attr(StringPiece name, bool value);
+  NodeDefBuilder& Attr(StringPiece name, DataType value);
+  NodeDefBuilder& Attr(StringPiece name, const PartialTensorShape& value);
+  NodeDefBuilder& Attr(StringPiece name, const Tensor& value);
+  NodeDefBuilder& Attr(StringPiece name, const TensorProto& value);
+  NodeDefBuilder& Attr(StringPiece name, const NameAttrList& value);
+  NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<StringPiece> value);
+  NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<const char*> value);
+  NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<string> value);
+  NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int32> value);
+  NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int64> value);
+  NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<float> value);
+  NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<bool> value);
+  NodeDefBuilder& Attr(StringPiece name, const std::vector<bool>& value);
+  NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<DataType> value);
+  NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<TensorShape> value);
+  NodeDefBuilder& Attr(StringPiece name,
+                       gtl::ArraySlice<PartialTensorShape> value);
+  NodeDefBuilder& Attr(StringPiece name,
+                       gtl::ArraySlice<TensorShapeProto> value);
+  NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<Tensor> value);
+  NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<NameAttrList> value);
+
+  template <class T>
+  NodeDefBuilder& Attr(StringPiece name, std::initializer_list<T> value) {
+    return Attr(name, gtl::ArraySlice<T>(value));
+  }
+
+  // Finish building the NodeDef, returning any errors or setting
+  // *node_def if none.
+  // WARNING: Not all problems are detected!  The resulting NodeDef may
+  // not be valid!  Call ValidateNodeDef() from node_def_utils to be sure.
+  Status Finalize(NodeDef* node_def) const;
+
+  // Accessors for the values set in the constructor.
+  const string& node_name() const { return node_def_.name(); }
+  const OpDef& op_def() const { return *op_def_; }
+
+ private:
+  // Called in the constructors.
+  void Initialize();
+
+  // Get the current ArgDef and advance to the next one. Returns nullptr
+  // if no more inputs are available.
+  const OpDef::ArgDef* NextArgDef();
+
+  // Returns true if there is still an input_arg available in *op_def_,
+  // otherwise adds to error_ and returns false.
+  bool NextArgAvailable();
+
+  // These do the main work of the Input() methods.
+  void SingleInput(const OpDef::ArgDef* input_arg, StringPiece src_node,
+                   int src_index, DataType dt);
+  void ListInput(const OpDef::ArgDef* input_arg,
+                 gtl::ArraySlice<NodeOut> src_list);
+
+  // Add "src_node:src_index" to the list of inputs in the node_def_.
+  void AddInput(StringPiece src_node, int src_index);
+
+  // Generate an error if you can't pass dt when expected is expected.
+  void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected,
+                       DataType dt);
+
+  // If input_arg->is_ref() is true, generate an error if dt is not a ref.
+  void VerifyInputRef(const OpDef::ArgDef* input_arg, DataType dt);
+
+  // Makes dt a ref type if that is what the input_arg specifies.
+  DataType MaybeAddRef(const OpDef::ArgDef* input_arg, DataType dt) {
+    return input_arg->is_ref() ?
MakeRefType(dt) : dt; + } + + const OpDef* op_def_; + NodeDef node_def_; + int inputs_specified_; + std::vector control_inputs_; + std::vector errors_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_ diff --git a/node_def_builder_test.cc b/node_def_builder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e836873f667a6971b2c12d44860e5436a04cb93c --- /dev/null +++ b/node_def_builder_test.cc @@ -0,0 +1,1059 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/node_def_builder.h" + +#include +#include +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class NodeDefBuilderTest : public ::testing::Test { + protected: + // Specify an OpDef via an OpDefBuilder. + void Op(const OpDefBuilder& op_def_builder) { + OpRegistrationData op_reg_data; + TF_EXPECT_OK(op_def_builder.Finalize(&op_reg_data)); + op_def_ = op_reg_data.op_def; + } + + // Resets builder_ with a new NodeDefBuilder using the Op from the last call + // to Op() above. + NodeDefBuilder& Builder() { + EXPECT_FALSE(op_def_.name().empty()) << "Must call Op() before Builder()"; + builder_.reset(new NodeDefBuilder("n", &op_def_)); + return *builder_; + } + + // Calls Finalize() and verifies it returns success and the result matches + // expectations. + void ExpectSuccess(const NodeDefBuilder& builder, + DataTypeSlice expected_in_types, + DataTypeSlice expected_out_types, StringPiece proto) { + NodeDef node_def; + Status status = builder.Finalize(&node_def); + TF_EXPECT_OK(status); + if (!status.ok()) return; + NodeDef expected; + protobuf::TextFormat::ParseFromString(strings::StrCat("name: 'n' ", proto), + &expected); + EXPECT_EQ(node_def.DebugString(), expected.DebugString()); + + DataTypeVector in_types, out_types; + status = + InOutTypesForNode(node_def, builder.op_def(), &in_types, &out_types); + TF_EXPECT_OK(status); + if (!status.ok()) return; + EXPECT_EQ(DataTypeSliceString(expected_in_types), + DataTypeVectorString(in_types)); + EXPECT_EQ(DataTypeSliceString(expected_out_types), + DataTypeVectorString(out_types)); + + status = ValidateNodeDef(node_def, op_def_); + TF_EXPECT_OK(status); + } + + // Calls Finalize() and verifies it returns an error. + // Each message must appear as a substring of the error. 
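+  // Usage sketch (added for clarity, not in the original test): with an op
+  // that declares one input, such as the "Simple" op used in the tests below,
+  // ExpectFailures(Builder(), {"0 inputs specified"}) passes only if
+  // Finalize() fails and its error message contains that substring.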
+  void ExpectFailures(const NodeDefBuilder& builder,
+                      const std::vector<string>& messages) {
+    NodeDef node_def;
+    Status status = builder.Finalize(&node_def);
+    EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
+    if (status.ok()) return;
+    for (const string& message : messages) {
+      EXPECT_TRUE(StringPiece(status.error_message()).contains(message))
+          << status << ", " << message;
+    }
+  }
+
+  // Calls Finalize() and verifies it returns an error.
+  // Message must appear as a substring of the error.
+  void ExpectFailure(const NodeDefBuilder& builder, const string& message) {
+    ExpectFailures(builder, {message});
+  }
+
+  // Like ExpectFailure(), except that the error can come from
+  // ValidateNodeDef().
+  void ExpectInvalid(const NodeDefBuilder& builder, const string& message) {
+    NodeDef node_def;
+    Status status = builder.Finalize(&node_def);
+    if (status.ok()) {
+      status = ValidateNodeDef(node_def, op_def_);
+    }
+    EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
+    if (status.ok()) return;
+    EXPECT_TRUE(StringPiece(status.error_message()).contains(message))
+        << "Actual error: " << status.error_message()
+        << "\nDoes not contain: " << message;
+  }
+
+  OpDef op_def_;
+  std::unique_ptr<NodeDefBuilder> builder_;
+};
+
+TEST_F(NodeDefBuilderTest, Simple) {
+  Op(OpDefBuilder("Simple").Input("a: int32").Output("out: float"));
+
+  ExpectSuccess(Builder().Input("x", 0, DT_INT32), {DT_INT32}, {DT_FLOAT},
+                R"proto( op: "Simple" input: "x" )proto");
+
+  // Port != 0
+  ExpectSuccess(Builder().Input("y", 2, DT_INT32), {DT_INT32}, {DT_FLOAT},
+                R"proto( op: "Simple" input: "y:2" )proto");
+
+  // FakeInput
+  ExpectSuccess(Builder().Input(FakeInput()), {DT_INT32}, {DT_FLOAT}, R"proto(
+    op: "Simple" input: "a" )proto");
+
+  ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_FLOAT},
+                R"proto( op: "Simple" input: "a" )proto");
+
+  // Ref input
+  ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32},
+                {DT_FLOAT}, R"proto( op: "Simple" input: "a" )proto");
+
+  // ControlInput
+  ExpectSuccess(
+      Builder().ControlInput("x").Input(FakeInput()).ControlInput("y"),
+      {DT_INT32}, {DT_FLOAT}, R"proto(
+    op: "Simple" input: ["a", "^x", "^y"] )proto");
+
+  // Device
+  ExpectSuccess(Builder().Input(FakeInput()).Device("ddd"), {DT_INT32},
+                {DT_FLOAT}, R"proto(
+    op: "Simple" input: "a" device: "ddd" )proto");
+
+  // Extra input
+  ExpectFailure(Builder().Input("x", 0, DT_INT32).Input("y", 0, DT_INT32),
+                "More Input() calls than the 1 input_args while building "
+                "NodeDef 'n' using Op<name=Simple; signature=a:int32 -> "
+                "out:float>");
+
+  // Missing input
+  ExpectFailure(Builder(), "0 inputs specified of 1 inputs in Op while");
+
+  {  // Finalize() twice.
+    NodeDefBuilder& builder = Builder();
+    // First call to Finalize()
+    TF_EXPECT_OK(builder.Input(FakeInput()).Finalize(nullptr));
+    // ExpectSuccess() also calls Finalize().
+    ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto(
+      op: "Simple" input: "a" )proto");
+  }
+
+  {  // Input() after Finalize()
+    NodeDefBuilder& builder = Builder();
+    // Calling Finalize() before enough inputs -> error.
+    ExpectFailure(builder, "0 inputs specified of 1 inputs in Op while");
+    builder.Input(FakeInput());
+    // Calling Finalize() with enough inputs -> success
+    ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto(
+      op: "Simple" input: "a" )proto");
+    // Calling Finalize() with too many inputs -> error.
+ builder.Input(FakeInput(DT_INT32)); + ExpectFailure(builder, "More Input() calls than the 1 input_args while"); + } + + // Wrong input type + ExpectFailure(Builder().Input("x", 0, DT_FLOAT), + "Input 'a' passed float expected int32 "); + + ExpectFailure(Builder().Input("x", 0, DT_FLOAT_REF), + "Input 'a' passed float_ref expected int32 "); + + // List input + ExpectFailure(Builder().Input(FakeInput(3, DT_FLOAT)), + "List provided to input 'a' when single Tensor expected while"); + + ExpectFailure(Builder().Input(FakeInput(3)), + "List provided to input 'a' when single Tensor expected while"); + + // Bad ControlInput + ExpectInvalid(Builder().Input(FakeInput()).ControlInput("z:2"), + "Control input '^z:2' must not have ':' in NodeDef:"); + + // Bad input name + ExpectFailure(Builder().Input("", 0, DT_INT32), + "Empty input node name while"); + + ExpectFailure(Builder().Input("^x", 0, DT_INT32), + "Non-control input starting with ^: ^x while"); +} + +TEST_F(NodeDefBuilderTest, OpDoesNotExist) { + NodeDefBuilder builder("n", "Op Does Not Exist"); + builder.Input(FakeInput()) + .Input(FakeInput(12)) + .ControlInput("y") + .Attr("foo", 12) + .Device("device"); + ExpectFailures(builder, {"Op type not registered 'Op Does Not Exist'", + "while building NodeDef 'n'"}); +} + +TEST_F(NodeDefBuilderTest, Polymorphic) { + Op(OpDefBuilder("Polymorphic") + .Input("v: T") + .Output("out: T") + .Attr("T: type")); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_INT32}, + R"proto( + op: "Polymorphic" input: "a" + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DT_FLOAT)), {DT_FLOAT}, {DT_FLOAT}, + R"proto( + op: "Polymorphic" input: "a" + attr { key: "T" value { type: DT_FLOAT } } )proto"); + + // Redundant Attr() + ExpectSuccess(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_BOOL), + {DT_BOOL}, {DT_BOOL}, R"proto( + op: "Polymorphic" input: "a" + attr { key: "T" value { type: DT_BOOL } } )proto"); + + // Conficting Attr() + ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_STRING), + "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while"); + + ExpectFailure(Builder().Attr("T", DT_STRING).Input(FakeInput(DT_BOOL)), + "Inconsistent values for attr 'T' DT_STRING vs. DT_BOOL while"); + + ExpectFailure(Builder().Attr("T", 12).Input(FakeInput(DT_BOOL)), + "Inconsistent values for attr 'T' 12 vs. DT_BOOL while"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicOut) { + Op(OpDefBuilder("PolymorphicOut").Output("out: T").Attr("T: type")); + + ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32}, R"proto( + op: "PolymorphicOut" + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Attr("T", DT_FLOAT), {}, {DT_FLOAT}, R"proto( + op: "PolymorphicOut" + attr { key: "T" value { type: DT_FLOAT } } )proto"); + + // Redundant attr + ExpectSuccess(Builder().Attr("T", DT_FLOAT).Attr("T", DT_FLOAT), {}, + {DT_FLOAT}, R"proto( + op: "PolymorphicOut" + attr { key: "T" value { type: DT_FLOAT } } )proto"); + + // Conflicting attr + ExpectFailure(Builder().Attr("T", DT_BOOL).Attr("T", DT_FLOAT), + "Inconsistent values for attr 'T' DT_BOOL vs. 
DT_FLOAT while"); + + // Missing attr + ExpectInvalid(Builder(), "NodeDef missing attr 'T' from"); + + // Attr has the wrong type + ExpectInvalid( + Builder().Attr("T", {DT_INT32, DT_BOOL}), + "AttrValue had value with type 'list(type)' when 'type' expected"); + + ExpectInvalid(Builder().Attr("T", 12), + "AttrValue had value with type 'int' when 'type' expected"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicDefaultOut) { + Op(OpDefBuilder("PolymorphicDefaultOut") + .Output("out: T") + .Attr("T: type = DT_STRING")); + + ExpectSuccess(Builder(), {}, {DT_STRING}, R"proto( + op: "PolymorphicDefaultOut" + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectSuccess(Builder().Attr("T", DT_BOOL), {}, {DT_BOOL}, R"proto( + op: "PolymorphicDefaultOut" + attr { key: "T" value { type: DT_BOOL } } )proto"); +} + +TEST_F(NodeDefBuilderTest, Binary) { + Op(OpDefBuilder("Binary").Input("a: T").Input("b: T").Output("out: T").Attr( + "T: type")); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32)).Input(FakeInput(DT_INT32)), + {DT_INT32, DT_INT32}, {DT_INT32}, R"proto( + op: "Binary" input: "a" input: "b" + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DT_STRING)).Input(FakeInput()), + {DT_STRING, DT_STRING}, {DT_STRING}, R"proto( + op: "Binary" input: "a" input: "b" + attr { key: "T" value { type: DT_STRING } } )proto"); + + // Type mismatch + ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Input(FakeInput(DT_STRING)), + "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while"); +} + +TEST_F(NodeDefBuilderTest, Restrict) { + Op(OpDefBuilder("Restrict") + .Input("a: T") + .Output("out: T") + .Attr("T: {string, bool}")); + ExpectSuccess(Builder().Input(FakeInput(DT_STRING)), {DT_STRING}, {DT_STRING}, + R"proto( + op: "Restrict" input: "a" + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectInvalid(Builder().Input(FakeInput(DT_INT32)), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, TypeList) { + Op(OpDefBuilder("TypeList").Input("a: T").Attr("T: list(type)")); + + ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_INT32})), + {DT_STRING, DT_INT32}, {}, R"proto( + op: "TypeList" input: ["a", "a:1"] + attr { key: "T" value { list { type: [DT_STRING, DT_INT32] } } } + )proto"); + + ExpectSuccess(Builder().Input(FakeInput(3, DT_BOOL)), + {DT_BOOL, DT_BOOL, DT_BOOL}, {}, R"proto( + op: "TypeList" input: ["a", "a:1", "a:2"] + attr { key: "T" value { list { type: [DT_BOOL, DT_BOOL, DT_BOOL] } } } + )proto"); + + ExpectInvalid(Builder().Input(FakeInput(0)), + "Length for attr 'T' of 0 must be at least minimum 1"); + + ExpectInvalid(Builder().Input(FakeInput({})), + "Length for attr 'T' of 0 must be at least minimum 1"); + + ExpectInvalid(Builder().Input(FakeInput(DT_BOOL)), + "Single tensor passed to 'a', expected list while"); + + ExpectFailures(Builder().Input(FakeInput()), + {"2 errors while building NodeDef", + "Could not infer list of types for input 'a': " + "No attr named 'T' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); +} + +TEST_F(NodeDefBuilderTest, TypeListNoMin) { + Op(OpDefBuilder("TypeListNoMin").Input("a: T").Attr("T: list(type) >= 0")); + + ExpectSuccess(Builder().Input(FakeInput(0)), {}, {}, R"proto( + op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DataTypeVector())), {}, {}, R"proto( + op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); + + 
ExpectSuccess(Builder().Input(FakeInput({})), {}, {}, R"proto( + op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput({DT_BOOL})), {DT_BOOL}, {}, R"proto( + op: "TypeListNoMin" input: "a" + attr { key: "T" value { list { type: DT_BOOL } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, TypeListTwice) { + Op(OpDefBuilder("TypeListTwice") + .Input("a: T") + .Input("b: T") + .Attr("T: list(type) >= 0")); + + ExpectSuccess(Builder() + .Input(FakeInput({DT_INT32, DT_BOOL})) + .Input(FakeInput({DT_INT32, DT_BOOL})), + {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto( + op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"] + attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto"); + + ExpectSuccess( + Builder().Input(FakeInput({DT_INT32, DT_BOOL})).Input(FakeInput()), + {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto( + op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"] + attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(0)), {}, {}, + R"proto( + op: "TypeListTwice" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {}, + R"proto( + op: "TypeListTwice" attr { key: "T" value { list { } } } )proto"); + + ExpectFailure(Builder() + .Input(FakeInput({DT_INT32, DT_BOOL})) + .Input(FakeInput({DT_INT32, DT_STRING})), + "Inconsistent values for attr 'T' [DT_INT32, DT_BOOL] vs. " + "[DT_INT32, DT_STRING] while"); +} + +TEST_F(NodeDefBuilderTest, OutTypeList) { + Op(OpDefBuilder("OutTypeList").Output("out: T").Attr("T: list(type) >= 0")); + + ExpectSuccess(Builder().Attr("T", {DT_FLOAT}), {}, {DT_FLOAT}, R"proto( + op: "OutTypeList" + attr { key: "T" value { list { type: DT_FLOAT } } } )proto"); + + ExpectSuccess(Builder().Attr("T", {DT_STRING, DT_BOOL}), {}, + {DT_STRING, DT_BOOL}, R"proto( + op: "OutTypeList" + attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto"); + + ExpectSuccess(Builder().Attr("T", DataTypeVector()), {}, {}, R"proto( + op: "OutTypeList" + attr { key: "T" value { list { } } } )proto"); + + ExpectInvalid( + Builder().Attr("T", DT_FLOAT), + "AttrValue had value with type 'type' when 'list(type)' expected"); +} + +TEST_F(NodeDefBuilderTest, TypeListRestrict) { + Op(OpDefBuilder("TypeListRestrict") + .Input("a: T") + .Attr("T: list({string, bool}) >= 0")); + + ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_BOOL})), + {DT_STRING, DT_BOOL}, {}, R"proto( + op: "TypeListRestrict" input: ["a", "a:1"] + attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto"); + + ExpectInvalid(Builder().Input(FakeInput({DT_STRING, DT_INT32})), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, OutTypeListRestrict) { + Op(OpDefBuilder("OutTypeListRestrict") + .Output("out: t") + .Attr("t: list({string, bool}) >= 0")); + + ExpectSuccess(Builder().Attr("t", {DT_BOOL, DT_STRING}), {}, + {DT_BOOL, DT_STRING}, R"proto( + op: "OutTypeListRestrict" + attr { key: "t" value { list { type: [DT_BOOL, DT_STRING] } } } )proto"); + + ExpectInvalid(Builder().Attr("t", {DT_STRING, DT_INT32}), + "Value for attr 't' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, Attr) { + Op(OpDefBuilder("Attr").Attr("a: int")); + + ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto( + op: "Attr" attr { key: "a" value { i: 12 } } )proto"); + + // 
Attr has wrong type + ExpectInvalid(Builder().Attr("a", "bad"), + "AttrValue had value with type 'string' when 'int' expected"); + + ExpectInvalid( + Builder().Attr("a", {12}), + "AttrValue had value with type 'list(int)' when 'int' expected"); + + // Missing attr + ExpectInvalid(Builder(), "NodeDef missing attr 'a' from Op<"); + + // Wrong attr + ExpectInvalid(Builder().Attr("b", 12), + "NodeDef mentions attr 'b' not in Op<"); + + // Extra attr + ExpectInvalid(Builder().Attr("a", 12).Attr("extra", 12), + "NodeDef mentions attr 'extra' not in Op<"); +} + +TEST_F(NodeDefBuilderTest, AttrFloat) { + Op(OpDefBuilder("AttrFloat").Attr("a: float")); + + ExpectSuccess(Builder().Attr("a", 1.2f /* float */), {}, {}, R"proto( + op: "AttrFloat" attr { key: "a" value { f: 1.2 } } + )proto"); + + ExpectSuccess(Builder().Attr("a", 1.2 /* double */), {}, {}, R"proto( + op: "AttrFloat" attr { key: "a" value { f: 1.2 } } + )proto"); + + // Won't automatically cast int to float + ExpectInvalid(Builder().Attr("a", 12), + "AttrValue had value with type 'int' when 'float' expected"); +} + +TEST_F(NodeDefBuilderTest, AttrBoolList) { + Op(OpDefBuilder("AttrBoolList").Attr("a: list(bool)")); + + ExpectSuccess(Builder().Attr("a", {true, false, true}), {}, {}, R"proto( + op: "AttrBoolList" + attr { key: "a" value { list { b: [true, false, true] } } } + )proto"); + + ExpectSuccess(Builder().Attr("a", std::vector()), {}, {}, R"proto( + op: "AttrBoolList" attr { key: "a" value { list { } } } + )proto"); + + // Won't cast int -> bool. + ExpectInvalid(Builder().Attr("a", {0}), + "AttrValue had value with type 'list(int)' when 'list(bool)' " + "expected"); +} + +TEST_F(NodeDefBuilderTest, AttrMin) { + Op(OpDefBuilder("AttrMin").Attr("a: int >= 5")); + + ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto( + op: "AttrMin" attr { key: "a" value { i: 12 } } )proto"); + + ExpectInvalid(Builder().Attr("a", 2), + "Value for attr 'a' of 2 must be at least minimum 5"); +} + +TEST_F(NodeDefBuilderTest, AttrListMin) { + Op(OpDefBuilder("AttrListMin").Attr("a: list(int) >= 2")); + + ExpectSuccess(Builder().Attr("a", {1, 2}), {}, {}, R"proto( + op: "AttrListMin" + attr { key: "a" value { list { i: [1, 2] } } } )proto"); + + ExpectInvalid(Builder().Attr("a", {17}), + "Length for attr 'a' of 1 must be at least minimum 2"); +} + +TEST_F(NodeDefBuilderTest, AttrEnum) { + Op(OpDefBuilder("AttrEnum").Attr("a: {'apples', 'oranges'}")); + + ExpectSuccess(Builder().Attr("a", "oranges"), {}, {}, R"proto( + op: "AttrEnum" + attr { key: "a" value { s: "oranges" } } )proto"); + + ExpectInvalid( + Builder().Attr("a", "invalid"), + "Value for attr 'a' of \"invalid\" is not in the list of allowed values: " + "\"apples\", \"oranges\""); +} + +TEST_F(NodeDefBuilderTest, AttrEnumList) { + Op(OpDefBuilder("AttrEnumList").Attr("a: list({'apples', 'oranges'})")); + + ExpectSuccess(Builder().Attr("a", {"oranges", "apples"}), {}, {}, R"proto( + op: "AttrEnumList" + attr { key: "a" value { list { s: ["oranges", "apples"] } } } )proto"); + + ExpectInvalid( + Builder().Attr("a", {"apples", "invalid", "oranges"}), + "Value for attr 'a' of \"invalid\" is not in the list of allowed values: " + "\"apples\", \"oranges\""); +} + +TEST_F(NodeDefBuilderTest, AttrShape) { + Op(OpDefBuilder("AttrShape").Attr("a: shape")); + + ExpectSuccess(Builder().Attr("a", TensorShape({5})), {}, {}, R"proto( + op: "AttrShape" + attr { key: "a" value { shape { dim { size: 5 } } } } )proto"); + + ExpectSuccess(Builder().Attr("a", TensorShape({4, 3, 2})), {}, {}, R"proto( + op: 
"AttrShape" + attr { key: "a" value { shape { + dim { size: 4 } dim { size: 3 } dim { size: 2 } } } } )proto"); + + ExpectSuccess(Builder().Attr("a", TensorShape({3, 2})), {}, {}, + R"proto( + op: "AttrShape" + attr { key: "a" value { shape { + dim { size: 3 } dim { size: 2 } } } } )proto"); + + ExpectSuccess(Builder().Attr("a", TensorShape()), {}, {}, R"proto( + op: "AttrShape" + attr { key: "a" value { shape { } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrDefault) { + Op(OpDefBuilder("AttrDefault").Attr("a: string = 'banana'")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrDefault" + attr { key: "a" value { s: "banana" } } )proto"); + + ExpectSuccess(Builder().Attr("a", "kiwi"), {}, {}, R"proto( + op: "AttrDefault" + attr { key: "a" value { s: "kiwi" } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrManyDefault) { + Op(OpDefBuilder("AttrManyDefault") + .Attr("a: string = 'banana'") + .Attr("b: string = 'kiwi'")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrManyDefault" + attr { key: "a" value { s: "banana" } } + attr { key: "b" value { s: "kiwi" } } )proto"); + + Op(OpDefBuilder("AttrManyDefaultWithMandatory") + .Attr("a: string = 'banana'") + .Attr("b: string = 'kiwi'") + .Attr("c: string")); + + ExpectSuccess(Builder().Attr("c", "strawberry"), {}, {}, R"proto( + op: "AttrManyDefaultWithMandatory" + attr { key: "c" value { s: "strawberry" } } + attr { key: "a" value { s: "banana" } } + attr { key: "b" value { s: "kiwi" } } )proto"); + + Op(OpDefBuilder("AttrManyDefaultAndInferred") + .Input("input: T") + .Attr("T: {float, double}") + .Attr("a: string") + .Attr("b: list(string) >= 1") + .Attr("c: bool = true") + .Attr("d: float = 0.3") + .Attr("e: string") + .Attr("f: float = 0.25")); + + ExpectSuccess(Builder() + .Input(FakeInput(DT_FLOAT)) + .Attr("a", "foo") + .Attr("e", "foo") + .Attr("b", std::vector({"bar", "baz"})) + .Attr("f", 1.0f), + {DT_FLOAT}, {}, R"proto( + op: "AttrManyDefaultAndInferred" + input: "a" + attr { key: "T" value { type: DT_FLOAT } } + attr { key: "a" value { s: "foo" } } + attr { key: "e" value { s: "foo" } } + attr { key: "b" value { list { s: "bar" s: "baz" } } } + attr { key: "f" value { f: 1.0 } } + attr { key: "c" value { b: true } } + attr { key: "d" value { f: 0.3 } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrListDefault) { + Op(OpDefBuilder("AttrListDefault").Attr("a: list(int) = [5, 15]")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrListDefault" + attr { key: "a" value { list { i: [5, 15] } } } )proto"); + + ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto( + op: "AttrListDefault" + attr { key: "a" value { list { i: 3 } } } )proto"); + + ExpectSuccess(Builder().Attr("a", std::vector()), {}, {}, R"proto( + op: "AttrListDefault" + attr { key: "a" value { list { } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrEmptyListDefault) { + Op(OpDefBuilder("AttrEmptyListDefault").Attr("a: list(int) = []")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrEmptyListDefault" + attr { key: "a" value { list { } } } )proto"); + + ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto( + op: "AttrEmptyListDefault" + attr { key: "a" value { list { i: 3 } } } )proto"); + + ExpectSuccess(Builder().Attr("a", std::vector()), {}, {}, R"proto( + op: "AttrEmptyListDefault" + attr { key: "a" value { list { } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, NIntsIn) { + Op(OpDefBuilder("NIntsIn").Input("a: N*int32").Attr("N: int >= 2")); + + ExpectSuccess(Builder().Input(FakeInput(2)), {DT_INT32, DT_INT32}, {}, 
+ R"proto( + op: "NIntsIn" input: ["a", "a:1"] + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(5, DT_INT32)), + {DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( + op: "NIntsIn" + input: ["a", "a:1", "a:2", "a:3", "a:4"] + attr { key: "N" value { i: 5 } } )proto"); + + ExpectFailures(Builder().Input(FakeInput(2, DT_STRING)), + {"2 errors while building NodeDef", + "Input 'a' passed string expected int32"}); + + ExpectInvalid(Builder().Input(FakeInput(1)), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectFailures( + Builder().Input(FakeInput(DT_INT32)), + {"2 errors while building NodeDef", + "Could not infer length of input 'a': No attr named 'N' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); + + ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}), + "Input 'a' passed string expected int32 while"); + + ExpectFailures( + Builder().Input(FakeInput()), + {"2 errors while building NodeDef", + "Could not infer length of input 'a': No attr named 'N' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicIn) { + Op(OpDefBuilder("NPolymorphicIn") + .Input("a: N*T") + .Attr("T: type") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)), {DT_INT32, DT_INT32}, + {}, R"proto( + op: "NPolymorphicIn" input: ["a", "a:1"] + attr { key: "N" value { i: 2 } } + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)), + {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto( + op: "NPolymorphicIn" + input: ["a", "a:1", "a:2"] + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectFailures( + Builder().Input(FakeInput(2)), + {"2 errors while building NodeDef", + "Could not infer type for input 'a': No attr named 'T' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); + + ExpectFailure(Builder().Input(FakeInput({DT_INT32, DT_STRING})), + "Input 'a' passed string expected int32 while"); + + ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}), + "Input 'a' passed string expected int32 while"); + + ExpectInvalid(Builder().Input(FakeInput(1, DT_INT32)), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectFailure(Builder().Input("in", 0, DT_INT32), + "Single tensor passed to 'a', expected list while"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicRestrictIn) { + Op(OpDefBuilder("NPolymorphicRestrictIn") + .Input("a: N*T") + .Attr("T: {string, bool}") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Input(FakeInput(2, DT_BOOL)), {DT_BOOL, DT_BOOL}, {}, + R"proto( + op: "NPolymorphicRestrictIn" input: ["a", "a:1"] + attr { key: "N" value { i: 2 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)), + {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto( + op: "NPolymorphicRestrictIn" + input: ["a", "a:1", "a:2"] + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectInvalid(Builder().Input(FakeInput(2, DT_INT32)), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, NInTwice) { + Op(OpDefBuilder("NInTwice") + .Input("a: N*int32") + .Input("b: N*string") + .Attr("N: int >= 0")); + + ExpectSuccess(Builder().Input(FakeInput(2)).Input(FakeInput(2)), + {DT_INT32, DT_INT32, DT_STRING, DT_STRING}, {}, R"proto( + op: "NInTwice" + 
input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {}, + R"proto( + op: "NInTwice" attr { key: "N" value { i: 0 } } )proto"); + + ExpectFailure(Builder().Input(FakeInput(3)).Input(FakeInput(1)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); +} + +TEST_F(NodeDefBuilderTest, NInPolymorphicTwice) { + Op(OpDefBuilder("NInPolymorphicTwice") + .Input("a: N*T") + .Input("b: N*T") + .Attr("T: type") + .Attr("N: int >= 0")); + + ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput()), + {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( + op: "NInPolymorphicTwice" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectFailure( + Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_INT32)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); + + ExpectFailure(Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); + + ExpectFailure( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)), + "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); + + ExpectFailure( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_STRING)), + "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); +} + +TEST_F(NodeDefBuilderTest, NInTwoTypeVariables) { + Op(OpDefBuilder("NInTwoTypeVariables") + .Input("a: N*S") + .Input("b: N*T") + .Attr("S: type") + .Attr("T: type") + .Attr("N: int >= 0")); + + ExpectSuccess( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_BOOL)), + {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto( + op: "NInTwoTypeVariables" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } + attr { key: "S" value { type: DT_INT32 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectSuccess( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_BOOL)), + {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto( + op: "NInTwoTypeVariables" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } + attr { key: "S" value { type: DT_INT32 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectFailure( + Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_STRING)), + "Inconsistent values for attr 'N' 3 vs. 
1 while"); +} + +TEST_F(NodeDefBuilderTest, InPolymorphicTwice) { + Op(OpDefBuilder("InPolymorphicTwice") + .Input("a: N*T") + .Input("b: M*T") + .Attr("T: type") + .Attr("N: int >= 0") + .Attr("M: int >= 0")); + + ExpectSuccess( + Builder().Input(FakeInput(1, DT_INT32)).Input(FakeInput(3, DT_INT32)), + {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( + op: "InPolymorphicTwice" + input: ["a", "b", "b:1", "b:2"] + attr { key: "N" value { i: 1 } } + attr { key: "T" value { type: DT_INT32 } } + attr { key: "M" value { i: 3 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(1, DT_BOOL)).Input(FakeInput(0)), + {DT_BOOL}, {}, R"proto( + op: "InPolymorphicTwice" input: "a" + attr { key: "N" value { i: 1 } } + attr { key: "T" value { type: DT_BOOL } } + attr { key: "M" value { i: 0 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(1, DT_BOOL)), + {DT_BOOL}, {}, R"proto( + op: "InPolymorphicTwice" input: "b" + attr { key: "N" value { i: 0 } } + attr { key: "M" value { i: 1 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectFailure( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)), + "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); +} + +TEST_F(NodeDefBuilderTest, NIntsOut) { + Op(OpDefBuilder("NIntsOut").Output("a: N*int32").Attr("N: int >= 2")); + + ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto( + op: "NIntsOut" + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3), {}, {DT_INT32, DT_INT32, DT_INT32}, + R"proto( + op: "NIntsOut" + attr { key: "N" value { i: 3 } } )proto"); + + ExpectInvalid(Builder().Attr("N", 1), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectInvalid( + Builder().Attr("N", {3}), + "AttrValue had value with type 'list(int)' when 'int' expected"); + + ExpectInvalid(Builder(), "NodeDef missing attr 'N' from"); +} + +TEST_F(NodeDefBuilderTest, NIntsOutDefault) { + Op(OpDefBuilder("NIntsOutDefault") + .Output("a: N*int32") + .Attr("N: int >= 2 = 3")); + + ExpectSuccess(Builder(), {}, {DT_INT32, DT_INT32, DT_INT32}, R"proto( + op: "NIntsOutDefault" + attr { key: "N" value { i: 3 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto( + op: "NIntsOutDefault" + attr { key: "N" value { i: 2 } } )proto"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicOut) { + Op(OpDefBuilder("NPolymorphicOut") + .Output("a: N*T") + .Attr("T: type") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Attr("T", DT_INT32).Attr("N", 2), {}, + {DT_INT32, DT_INT32}, R"proto( + op: "NPolymorphicOut" + attr { key: "T" value { type: DT_INT32 } } + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_STRING), {}, + {DT_STRING, DT_STRING, DT_STRING}, R"proto( + op: "NPolymorphicOut" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectInvalid(Builder().Attr("N", 1).Attr("T", DT_STRING), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectInvalid( + Builder().Attr("N", 3).Attr("T", {DT_STRING}), + "AttrValue had value with type 'list(type)' when 'type' expected"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicOutDefault) { + Op(OpDefBuilder("NPolymorphicOutDefault") + .Output("a: N*T") + .Attr("T: type = DT_BOOL") + .Attr("N: int >= 2 = 2")); + + ExpectSuccess(Builder(), {}, {DT_BOOL, DT_BOOL}, R"proto( + op: "NPolymorphicOutDefault" + attr { key: "T" value { type: DT_BOOL } } + attr { key: "N" value { i: 
2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3), {}, {DT_BOOL, DT_BOOL, DT_BOOL}, + R"proto( + op: "NPolymorphicOutDefault" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32, DT_INT32}, + R"proto( + op: "NPolymorphicOutDefault" + attr { key: "T" value { type: DT_INT32 } } + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_INT32), {}, + {DT_INT32, DT_INT32, DT_INT32}, R"proto( + op: "NPolymorphicOutDefault" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_INT32 } } )proto"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicRestrictOut) { + Op(OpDefBuilder("NPolymorphicRestrictOut") + .Output("a: N*T") + .Attr("T: {string, bool}") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_BOOL), {}, + {DT_BOOL, DT_BOOL, DT_BOOL}, R"proto( + op: "NPolymorphicRestrictOut" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectInvalid(Builder().Attr("N", 3).Attr("T", DT_INT32), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, RefIn) { + Op(OpDefBuilder("RefIn").Input("a: Ref(int32)")); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32_REF}, {}, + R"proto( + op: "RefIn" input: "a" )proto"); + + ExpectFailure(Builder().Input(FakeInput(DT_BOOL_REF)), + "Input 'a' passed bool_ref expected int32_ref while"); + + ExpectFailure(Builder().Input(FakeInput(DT_INT32)), + "Input 'a' passed int32 expected int32_ref while"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicRefIn) { + Op(OpDefBuilder("PolymorphicRefIn").Input("a: Ref(T)").Attr("T: type")); + + ExpectSuccess(Builder().Input(FakeInput(DT_BOOL_REF)), {DT_BOOL_REF}, {}, + R"proto( + op: "PolymorphicRefIn" input: "a" + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectFailure(Builder().Input(FakeInput(DT_BOOL)), + "Input 'a' passed bool expected ref type while"); +} + +TEST_F(NodeDefBuilderTest, RefOut) { + Op(OpDefBuilder("RefOut").Output("a: Ref(string)")); + + ExpectSuccess(Builder(), {}, {DT_STRING_REF}, R"proto( + op: "RefOut" )proto"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicRefOut) { + Op(OpDefBuilder("PolymorphicRefOut").Output("a: Ref(t)").Attr("t: type")); + + ExpectSuccess(Builder().Attr("t", DT_BOOL), {}, {DT_BOOL_REF}, R"proto( + op: "PolymorphicRefOut" + attr { key: "t" value { type: DT_BOOL } } )proto"); +} + +TEST_F(NodeDefBuilderTest, SpecifyDevice) { + Op(OpDefBuilder("SpecifyDevice")); + + ExpectSuccess(Builder().Device("ADevice"), {}, {}, R"proto( + op: "SpecifyDevice" device: "ADevice" )proto"); +} + +} // namespace +} // namespace tensorflow diff --git a/node_def_util.cc b/node_def_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..477184022df4bb7e4d329cc5ed09572f9dbe9585 --- /dev/null +++ b/node_def_util.cc @@ -0,0 +1,661 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/node_def_util.h" + +#include +#include +#include + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/graph.pb_text.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb_text.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/tensor.pb_text.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +const char* const kColocationAttrName = "_class"; +const char* const kColocationGroupPrefix = "loc:@"; + +AttrSlice::AttrSlice() : ndef_(nullptr) { + static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap; + attrs_ = kEmptyAttrValueMap; +} + +AttrSlice::AttrSlice(const NodeDef& node_def) + : ndef_(&node_def), attrs_(&ndef_->attr()) {} + +AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {} + +static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) { + string ret; + + // We sort the attrs so the output is deterministic. + std::vector attr_names; + attr_names.reserve(attrs.size()); + for (const auto& attr : attrs) { + attr_names.push_back(attr.first); + } + std::sort(attr_names.begin(), attr_names.end()); + bool first = true; + for (const string& attr_name : attr_names) { + if (!first) strings::StrAppend(&ret, ", "); + first = false; + strings::StrAppend(&ret, attr_name, "=", + SummarizeAttrValue(*attrs.Find(attr_name))); + } + + // Consider the device to be a final attr with name "_device". + if (!device.empty()) { + if (!first) strings::StrAppend(&ret, ", "); + first = false; + strings::StrAppend(&ret, "_device=\"", device, "\""); + } + return ret; +} + +string AttrSlice::SummarizeNode() const { + return ndef_ ? SummarizeNodeDef(*ndef_) + : strings::StrCat( + "[", SummarizeAttrsHelper(*this, StringPiece()), "]"); +} + +string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); } + +string SummarizeNodeDef(const NodeDef& node_def) { + string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "["); + strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device())); + strings::StrAppend(&ret, "]("); + + // Output inputs, including control inputs, verbatim. + bool first = true; + for (const string& input : node_def.input()) { + if (!first) strings::StrAppend(&ret, ", "); + first = false; + strings::StrAppend(&ret, input); + } + strings::StrAppend(&ret, ")"); + return ret; +} + +const AttrValue* AttrSlice::Find(StringPiece attr_name) const { + // Currently, the collection used for NodeDef::attr() (google::protobuf::Map) + // requires that the keys used for lookups have type 'const string&'. Because + // this method takes a StringPiece, it is necessary to allocate a temporary + // string, copy attr_name to it, and then use that temporary string for the + // lookup. This causes an excessive number of short-lived allocations, and for + // large graphs, this can be a significant cost. 
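+  // (Added illustrative note, not a measured figure: a typical node has only
+  // a handful of attrs, so comparing a few short keys directly tends to be
+  // cheaper than constructing a temporary std::string and hashing it for a
+  // map lookup.)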
+ // + // Because most nodes have a small number of attributes, a simple linear scan + // is generally more efficient than a hashed lookup. If google::protobuf::Map + // changes so that it supports efficient lookups using StringPiece instead of + // const string&, then this code could be changed to use attrs_->find() again. + + for (const auto& attr : *attrs_) { + if (attr.first == attr_name) { + return &attr.second; + } + } + return nullptr; +} + +Status AttrSlice::Find(StringPiece attr_name, + const AttrValue** attr_value) const { + *attr_value = Find(attr_name); + if (*attr_value != nullptr) { + return Status::OK(); + } + Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:"); + // Skip AttachDef for internal attrs since it is a little bit + // expensive and it is common for them to correctly not be included + // in a NodeDef. + if (!attr_name.starts_with("_") && ndef_ != nullptr) { + s = AttachDef(s, *ndef_); + } + return s; +} + +bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const { + if (size() != other.size()) return false; + + for (const auto& attr : *other.attrs_) { + auto iter = attrs_->find(attr.first); + if (iter == attrs_->end()) return false; + // TODO(irving): Comparing AttrValues by proto is slightly buggy, since + // TensorProto is a nonunique representation of Tensor. This bug will go + // away once AttrSlice switches over to NodeInfo. + iter->second.SerializeToString(&scratch->a); + attr.second.SerializeToString(&scratch->b); + if (scratch->a != scratch->b) return false; + } + return true; +} + +// The ... is to allow the caller to inject some value validation code. Use +// just ; if no additional validation code is needed. +#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \ + Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, \ + TYPE* value) { \ + const AttrValue* attr_value; \ + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \ + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE)); \ + const auto& v = attr_value->FIELD(); \ + __VA_ARGS__; \ + *value = CAST; \ + return Status::OK(); \ + } \ + Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, \ + std::vector* value) { \ + const AttrValue* attr_value; \ + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \ + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \ + for (const auto& v : attr_value->list().FIELD()) { \ + __VA_ARGS__; \ + value->APPEND_OP(CAST); \ + } \ + return Status::OK(); \ + } + +#define DEFINE_GET_ATTR_SIMPLE(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) 
\ + bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name, \ + TYPE* value) { \ + const AttrValue* attr_value = attrs.Find(attr_name); \ + if (attr_value == nullptr) { \ + return false; \ + } \ + Status s = AttrValueHasType(*attr_value, ATTR_TYPE); \ + if (!s.ok()) { \ + return false; \ + } \ + const auto& v = attr_value->FIELD(); \ + __VA_ARGS__; \ + *value = CAST; \ + return true; \ + } \ + bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name, \ + std::vector* value) { \ + const AttrValue* attr_value = attrs.Find(attr_name); \ + if (attr_value == nullptr) { \ + return false; \ + } \ + Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")"); \ + if (!s.ok()) { \ + return false; \ + } \ + for (const auto& v : attr_value->list().FIELD()) { \ + __VA_ARGS__; \ + value->APPEND_OP(CAST); \ + } \ + return true; \ + } + +DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;) +DEFINE_GET_ATTR_SIMPLE(string, s, "string", emplace_back, v, ;) +DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;) +DEFINE_GET_ATTR(int32, i, "int", emplace_back, static_cast(v), + if (static_cast(static_cast(v)) != v) { + return errors::InvalidArgument("Attr ", attr_name, + " has value ", v, + " out of range for an int32"); + }) +DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;) +// std::vector specialization does not have emplace_back until +// c++14, so we have to use push_back (see +// http://en.cppreference.com/w/cpp/container/vector/emplace_back) +DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;) +DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast(v), + ;) +DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;) +DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v), + TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));) +DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back, + PartialTensorShape(v), + TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));) +DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t; + if (!t.FromProto(v)) { + return errors::InvalidArgument( + "Attr ", attr_name, " has value ", + ProtoShortDebugString(v), + " that can't be converted to a Tensor"); + }) +DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;); +#undef DEFINE_GET_ATTR + +bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) { + return node_def.attr().find(attr_name.ToString()) != node_def.attr().end(); +} + +static const string& kEmptyString = *new string(); + +const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) { + const AttrValue* attr_value = attrs.Find(attr_name); + if (attr_value == nullptr) { + return kEmptyString; + } + Status s = AttrValueHasType(*attr_value, "string"); + if (!s.ok()) { + return kEmptyString; + } + return attr_value->s(); +} + +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + DataTypeVector* value) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)")); + for (const auto& v : attr_value->list().type()) { + value->push_back(static_cast(v)); + } + return Status::OK(); +} + +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + const TensorProto** value) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor")); + *value = &attr_value->tensor(); + return Status::OK(); +} + +Status GetNodeAttr(const AttrSlice& 
attrs, StringPiece attr_name, + const NameAttrList** value) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func")); + *value = &attr_value->func(); + return Status::OK(); +} + +namespace { // Helper for InOutTypesForNode(). + +Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def, + DataTypeVector* sig) { + const int original_size = sig->size(); + if (!arg_def.number_attr().empty()) { + // Same type repeated "repeats" times. + int32 repeats = -1; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.number_attr(), &repeats)); + if (repeats < 0) { + return errors::InvalidArgument("Value for number_attr() ", repeats, + " < 0"); + } + + if (!arg_def.type_attr().empty()) { + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.type_attr(), &dtype)); + for (int i = 0; i < repeats; ++i) { + sig->push_back(dtype); + } + } else if (arg_def.type() != DT_INVALID) { + for (int i = 0; i < repeats; ++i) { + sig->push_back(arg_def.type()); + } + } else { + return errors::InvalidArgument("Missing type or type_attr field in ", + ProtoShortDebugString(arg_def)); + } + } else if (!arg_def.type_attr().empty()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + AttrSlice(node_def).Find(arg_def.type_attr(), &attr_value)); + sig->push_back(attr_value->type()); + } else if (!arg_def.type_list_attr().empty()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value)); + for (int dtype : attr_value->list().type()) { + sig->push_back(static_cast(dtype)); + } + } else if (arg_def.type() != DT_INVALID) { + sig->push_back(arg_def.type()); + } else { + return errors::InvalidArgument("No type fields in ", + ProtoShortDebugString(arg_def)); + } + if (arg_def.is_ref()) { + // For all types that were added by this function call, make them refs. + for (size_t i = original_size; i < sig->size(); ++i) { + (*sig)[i] = MakeRefType((*sig)[i]); + } + } + return Status::OK(); +} + +} // namespace + +Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs, DataTypeVector* outputs) { + for (const auto& arg : op_def.input_arg()) { + TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); + } + for (const auto& arg : op_def.output_arg()) { + TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs)); + } + return Status::OK(); +} + +Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { + if (node_def.op() != op_def.name()) { + return errors::InvalidArgument("NodeDef op '", node_def.op(), + "' does not match ", SummarizeOpDef(op_def), + "; NodeDef: ", SummarizeNodeDef(node_def)); + } + + bool seen_control = false; + size_t num_inputs = 0; + // TODO(josh11b): Unify the input field validation. 
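+  // For example, an input list ["a", "b:1", "^c"] yields num_inputs == 2 and
+  // seen_control == true, while ["^c", "a"] (a data input after a control
+  // input) or ["^c:0"] (a control input naming a port) is rejected below.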
+ for (const string& input : node_def.input()) { + if (StringPiece(input).starts_with("^")) { + seen_control = true; + if (input.find(':') != string::npos) { + return errors::InvalidArgument( + "Control input '", input, + "' must not have ':' in NodeDef: ", SummarizeNodeDef(node_def)); + } + } else if (seen_control) { + return errors::InvalidArgument( + "Non-control input '", input, + "' after control input in NodeDef: ", SummarizeNodeDef(node_def)); + } else { + ++num_inputs; + } + } + + std::unordered_map op_attrs; + for (const auto& attr : op_def.attr()) { + if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) { + return errors::InvalidArgument("OpDef has duplicate attr name '", + attr.name(), + "': ", SummarizeOpDef(op_def)); + } + } + for (const auto& attr : node_def.attr()) { + // Allow internal optional attributes with names starting with "_". + if (StringPiece(attr.first).starts_with("_")) { + continue; + } + auto iter = op_attrs.find(attr.first); + if (iter == op_attrs.end()) { + // A common cause of this error is that TensorFlow has made a + // backwards-compatible change to the NodeDef (e.g., adding a + // new attr with a default value), but the binary consuming the + // NodeDef does not know about the new attribute; the solution + // in these cases is to ensure that the binary consuming the + // NodeDef is built with a version of TensorFlow no earlier than + // the binary producing it. + return errors::InvalidArgument( + "NodeDef mentions attr '", attr.first, "' not in ", + SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def), + ". (Check whether your GraphDef-interpreting binary is up to date " + "with your GraphDef-generating binary.)."); + } + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ValidateAttrValue(attr.second, *iter->second), + "; NodeDef: ", SummarizeNodeDef(node_def), "; ", + SummarizeOpDef(op_def)); + // Keep track of which attr names have (not) been found in the NodeDef. + op_attrs.erase(iter); + } + + // Were all attrs in the OpDef found in the NodeDef? + if (!op_attrs.empty()) { + string attrs; + for (const auto& attr_pair : op_attrs) { + if (!attrs.empty()) strings::StrAppend(&attrs, "', '"); + strings::StrAppend(&attrs, attr_pair.first); + } + return errors::InvalidArgument("NodeDef missing attr", + op_attrs.size() == 1 ? " '" : "s '", attrs, + "' from ", SummarizeOpDef(op_def), + "; NodeDef: ", SummarizeNodeDef(node_def)); + } + + // Validate the number of inputs. + DataTypeVector inputs, outputs; + TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs)); + + if (num_inputs != inputs.size()) { + return errors::InvalidArgument( + "NodeDef expected inputs '", DataTypeVectorString(inputs), + "' do not match ", num_inputs, " inputs specified; ", + SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def)); + } + + return Status::OK(); +} + +namespace { // Helpers for NameRangesForNode() + +Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def, + const OpDef& op_def, int* num) { + if (!arg_def.number_attr().empty()) { + // Same type repeated "num" times. 
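+    // e.g. an argument declared as "a: N * int32" with attr N = 3 makes
+    // *num == 3 here.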
+ return GetNodeAttr(node_def, arg_def.number_attr(), num); + } else if (!arg_def.type_list_attr().empty()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value)); + *num = attr_value->list().type_size(); + } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) { + *num = 1; + } else { + return errors::InvalidArgument( + "Argument '", arg_def.name(), + "' incorrectly specified in op definition: ", SummarizeOpDef(op_def)); + } + return Status::OK(); +} + +Status NameRangesHelper(const NodeDef& node_def, + const protobuf::RepeatedPtrField& args, + const OpDef& op_def, NameRangeMap* result) { + int start = 0; + int num; + for (const auto& arg : args) { + TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num)); + (*result)[arg.name()] = std::make_pair(start, start + num); + start += num; + } + return Status::OK(); +} + +} // namespace + +Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs) { + if (inputs != nullptr) { + TF_RETURN_IF_ERROR( + NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs)); + } + if (outputs != nullptr) { + return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs); + } + return Status::OK(); +} + +Status NameRangesForNode(const Node& node, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs) { + return NameRangesForNode(node.def(), op_def, inputs, outputs); +} + +void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) { + for (const auto& attr_def : op_def.attr()) { + AttrSlice attrs(*node_def); + if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) { + AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def); + } + } +} + +namespace { + +using ::tensorflow::strings::Scanner; + +bool IsValidOpName(StringPiece sp) { + return Scanner(sp) + .One(Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .Eos() + .GetResult(); +} + +bool IsValidDataInputName(StringPiece sp) { + // Data inputs are op_name, op_name:0, or op_name:12345. 
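+  // For example, "node", "node:0" and "node:42" are accepted, while "node:00"
+  // (leading zero) and "node:x" (non-numeric port) are not.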
+ Scanner scan(sp); + scan.One(Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); + if (scan.Peek() == ':') { + scan.OneLiteral(":"); + if (scan.Peek() == '0') { + scan.OneLiteral("0"); // :0 + } else { + scan.Many(Scanner::DIGIT); // :[1-9][0-9]* + } + } + scan.Eos(); + + return scan.GetResult(); +} + +bool IsValidControlInputName(StringPiece sp) { + return Scanner(sp) + .OneLiteral("^") + .One(Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .Eos() + .GetResult(); +} + +} // namespace + +Status ValidateOpInput(const string& input_name, bool* is_control_input) { + *is_control_input = false; + if (IsValidDataInputName(input_name)) { + return Status::OK(); + } else if (IsValidControlInputName(input_name)) { + *is_control_input = true; + return Status::OK(); + } else { + return errors::InvalidArgument("Illegal op input name '", input_name, "'"); + } +} + +Status ValidateOpName(const string& op_name) { + if (IsValidOpName(op_name)) { + return Status::OK(); + } else { + return errors::InvalidArgument("Illegal op name '", op_name, "'"); + } +} + +Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) { + Status s = ValidateOpName(node_def.name()); + if (!s.ok()) { + return AttachDef(s, node_def); + } + bool in_control_inputs = false; + for (const string& input_name : node_def.input()) { + bool is_control_input; + s = ValidateOpInput(input_name, &is_control_input); + if (!s.ok()) { + return AttachDef(s, node_def); + } + + if (in_control_inputs && !is_control_input) { + return AttachDef(errors::InvalidArgument( + "All control inputs must follow all data inputs"), + node_def); + } + in_control_inputs = is_control_input; + } + return Status::OK(); +} + +Status AttachDef(const Status& status, const NodeDef& node_def) { + Status ret = status; + errors::AppendToMessage( + &ret, strings::StrCat(" [[Node: ", SummarizeNodeDef(node_def), "]]")); + return ret; +} + +Status AttachDef(const Status& status, const Node& node) { + return AttachDef(status, node.def()); +} + +void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) { + node_def->mutable_attr()->insert( + AttrValueMap::value_type(name.ToString(), value)); +} + +#define ADD_NODE_ATTR(T) \ + void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \ + AttrValue attr_value; \ + SetAttrValue(value, &attr_value); \ + AddNodeAttr(name, attr_value, node_def); \ + } +ADD_NODE_ATTR(StringPiece) +ADD_NODE_ATTR(const char*) +ADD_NODE_ATTR(int32) +ADD_NODE_ATTR(int64) +ADD_NODE_ATTR(float) +ADD_NODE_ATTR(double) +ADD_NODE_ATTR(bool) +ADD_NODE_ATTR(DataType) +ADD_NODE_ATTR(const PartialTensorShape&) +ADD_NODE_ATTR(const Tensor&) +ADD_NODE_ATTR(const TensorProto&) +ADD_NODE_ATTR(const NameAttrList&) +ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(const std::vector&) +ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(gtl::ArraySlice) +#undef ADD_NODE_ATTR + +void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) { + map->insert(AttrValueMap::value_type(name.ToString(), value)); +} + +#define ADD_ATTR(T) \ + void AddAttr(StringPiece name, T value, AttrValueMap* map) { \ + AttrValue attr_value; \ + SetAttrValue(value, &attr_value); \ 
+ AddAttr(name, attr_value, map); \ + } +ADD_ATTR(bool) +#undef ADD_ATTR + +} // namespace tensorflow diff --git a/node_def_util.h b/node_def_util.h new file mode 100644 index 0000000000000000000000000000000000000000..f6f28aac4811d30b845191735536b389e41bf259 --- /dev/null +++ b/node_def_util.h @@ -0,0 +1,286 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +class Node; + +// We forward declare protos so that kernels don't need to depend on them +class NodeDef; +class OpDef; + +// Name of the attribute used to encode node colocation constraints. +// +// Nodes can be co-located on the same device. Desire for explicit co-location +// is described by list(string) attribute containing the name of colocation +// groups. +extern const char* const kColocationAttrName; + +// String prefix applied to the operation name for colocation constraints. +extern const char* const kColocationGroupPrefix; + +// Produce a human-readable version of a Node or NodeDef that is more concise +// than a text-format proto. +string SummarizeNode(const Node& node); +string SummarizeNodeDef(const NodeDef& node_def); + +typedef protobuf::Map AttrValueMap; + +// Adds an attr with name and value to *node_def. +// The type of the attr is based on the type of value. 
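+// For example (hypothetical node and attr names):
+//   NodeDef node_def;
+//   node_def.set_op("MyOp");
+//   AddNodeAttr("T", DT_FLOAT, &node_def);           // "type" attr
+//   AddNodeAttr("strides", {1, 2, 2, 1}, &node_def); // "list(int)" attr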
+void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, int64 value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, float value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, double value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, const PartialTensorShape& value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, const NameAttrList& value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, const std::vector& value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice value, + NodeDef* node_def); + +// Version to workaround C++'s "perfect" forwarding not being able to +// forward {...} initialization. +template +void AddNodeAttr(StringPiece name, std::initializer_list value, + NodeDef* node_def) { + AddNodeAttr(name, gtl::ArraySlice(value), node_def); +} + +// Adds an attr to an attr value map. +void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map); +void AddAttr(StringPiece name, bool value, AttrValueMap* map); + +class AttrSlice { + public: + AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit) + + AttrSlice(); // Empty + explicit AttrSlice(const AttrValueMap* a); + + int size() const { return attrs_->size(); } + + // Returns the attr with attr_name if found. Otherwise, returns + // nullptr. + const AttrValue* Find(StringPiece attr_name) const; + + // Returns the attr_value for attr_name if found. Otherwise, returns a + // NotFound status. + Status Find(StringPiece attr_name, const AttrValue** attr_value) const; + + // Helper class to avoid allocations in EqualAttrs. + // TODO(irving): Will go away once NodeInfo is used. + struct Scratch { + string a; + string b; + }; + + // Check if all attrs and attr values match. Does not take defaults into + // account. + // + // TODO(irving): There is a bug in this routine inherited from its + // OptimizerCSE::EqualAttrs precedecessor. 
The same tensor attr can be + // represented in more than one way as an AttrValue, since TensorProto is + // not 1-1. This bug will go away once I replace everything with NodeInfo, + // which stores a Tensor object directly. The Scratch object will also go + // away. + bool EqualAttrs(AttrSlice other, Scratch* scratch) const; + + // If this AttrSlice has an attached NodeDef, summarize it. This is for + // error messages only: we intentionally do not provide direct access to the + // NodeDef, since it is not always there. + string SummarizeNode() const; + + // Iteration over all attrs + AttrValueMap::const_iterator begin() const { return attrs_->begin(); } + AttrValueMap::const_iterator end() const { return attrs_->end(); } + + private: + const NodeDef* ndef_; + const AttrValueMap* attrs_; +}; + +// Return true if the attr with the name attr_name is defined in node_def. +bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name); + +// Look up the attr with name attr_name and set *value to its value. If no +// attr with attr_name is found in node_def, or the attr does not have +// a matching type, a non-ok status will be returned. +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + string* value); // type: "string" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + int64* value); // type: "int" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + int32* value); // type: "int" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + float* value); // type: "float" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + bool* value); // type: "bool" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + DataType* value); // type: "type" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + TensorShapeProto* value); // type: "shape" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + TensorShape* value); // type: "shape" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + PartialTensorShape* value); // type: "shape" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + Tensor* value); // type: "tensor" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(string)" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(int)" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(int)" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(float)" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(bool)" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(type)" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + DataTypeVector* value); // type "list(type)" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(shape)" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(shape)" +Status GetNodeAttr( + const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(shape)" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type: "list(tensor)" + +// This version avoids copying the TensorProto. +// REQUIRES: Must not use *value beyond the lifetime of node_def. 
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + const TensorProto** value); // type: "tensor" + +// This version avoids copying the NameAttrList. +// REQUIRES: Must not use *value beyond the lifetime of node_def. +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + const NameAttrList** value); // type: "func" + +// These versions copies the NameAttrList(s). +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + NameAttrList* value); // type: "func" +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type: "list(func)" + +// Look up the attr with name attr_name and set *value to its value. If no +// attr with attr_name is found in node_def, or the attr does not have +// a matching type, false is returned. +bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name, + string* value); // type: "string" +bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type: "string" + +// Look up the attr with name attr_name and return a reference to its value. +// If no attr with attr_name is found in node_def, or the attr does not have +// a matching type, a reference to an empty string is returned. +// REQUIRES: Must not use the returned value beyond the lifetime of node_def. +const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name); + +// Computes the input and output types for a specific node. +// REQUIRES: ValidateOpDef(op_def).ok() +Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs, DataTypeVector* outputs); + +// Validates that the NodeDef: +// * Defines all expected attrs from the OpDef. +// * All attrs satisfies constraints from the OpDef. +// * Has a signature matching SignatureForNode(). +// etc. +Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def); + +// Computes the mapping from input/output argument name to the +// corresponding input/output index range. For example, +// input "foo" corresponds to input indices +// [ (*inputs)["foo"].first, (*inputs)["foo"].second ). +// TODO(irving): Remove the NodeDef version; keep only the Node version. +typedef std::unordered_map> NameRangeMap; +Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs); +Status NameRangesForNode(const Node& node, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs); + +// Adds default values to *node_def for unspecified attrs from op_def. +void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def); + +// Validates the syntax of a NodeDef provided externally. +// +// The following is an EBNF-style syntax for NodeDef objects. Note that +// Node objects are actually specified as tensorflow::NodeDef protocol buffers, +// which contain many other fields that are not (currently) validated. +// +// Node = NodeName, Inputs +// Inputs = ( DataInput * ), ( ControlInput * ) +// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ? +// ControlInput = "^", NodeName +// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] * +Status ValidateExternalNodeDefSyntax(const NodeDef& node_def); + +// Returns "status" with kernel's NodeDef attached as additional text +// in the error message. 
+Status AttachDef(const Status& status, const NodeDef& node_def); +Status AttachDef(const Status& status, const Node& node); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_ diff --git a/node_def_util_test.cc b/node_def_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bfd598a97202e4bcbf1f869b2687f7cbca36b36b --- /dev/null +++ b/node_def_util_test.cc @@ -0,0 +1,497 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/node_def_util.h" + +#include "tensorflow/core/framework/attr_value.pb.h" // NOLINT +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +OpDef ToOpDef(const OpDefBuilder& builder) { + OpRegistrationData op_reg_data; + TF_EXPECT_OK(builder.Finalize(&op_reg_data)); + return op_reg_data.op_def; +} + +NodeDef ToNodeDef(const string& text) { + NodeDef node_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); + return node_def; +} + +NodeDef ToNodeDef(const NodeDefBuilder& builder) { + NodeDef node_def; + TF_EXPECT_OK(builder.Finalize(&node_def)); + return node_def; +} + +void ExpectSuccess(const NodeDef& good, const OpDef& op_def) { + EXPECT_EQ(Status::OK(), ValidateNodeDef(good, op_def)) + << "NodeDef: " << SummarizeNodeDef(good) + << "; OpDef: " << SummarizeOpDef(op_def); +} + +void ExpectFailure(const NodeDef& bad, const OpDef& op_def, + const string& message) { + Status status = ValidateNodeDef(bad, op_def); + + EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad) + << "; OpDef: " << SummarizeOpDef(op_def); + if (status.ok()) return; + + EXPECT_TRUE(errors::IsInvalidArgument(status)) + << status << "; NodeDef: " << SummarizeNodeDef(bad) + << "; OpDef: " << SummarizeOpDef(op_def); + + LOG(INFO) << "Message: " << status.error_message(); + EXPECT_TRUE(StringPiece(status.ToString()).contains(message)) + << "NodeDef: " << SummarizeNodeDef(bad) + << "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status + << "\nDoes not contain: " << message; +} + +TEST(NodeDefUtilTest, In) { + const OpDef op = ToOpDef(OpDefBuilder("In").Input("i: T").Attr("T: type")); + const NodeDef node_def = ToNodeDef(R"proto( + name:'n' op:'In' input:'a' attr { key:'T' value { type:DT_FLOAT } } + )proto"); + ExpectSuccess(node_def, op); + + EXPECT_EQ("n = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def)); + + // Mismatching Op names. 
+ NodeDef bad = node_def; + bad.set_op("Wrong"); + ExpectFailure(bad, op, "NodeDef op 'Wrong' does not match Op= 2") + .Attr("T: {float,double}")); + const NodeDef node_def = ToNodeDef(R"proto( + name:'n' op:'SameIn' input:'a' input:'b' + attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_DOUBLE } } + )proto"); + ExpectSuccess(node_def, op); + + EXPECT_EQ("n = SameIn[N=2, T=DT_DOUBLE](a, b)", SummarizeNodeDef(node_def)); + + // Illegal type + NodeDef bad = ToNodeDef(R"proto( + name:'n' op:'SameIn' input:'a' input:'b' + attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_STRING } } + )proto"); + ExpectFailure(bad, op, + "Value for attr 'T' of string is not in the list of allowed " + "values: float, double"); + + // Too few inputs + bad = ToNodeDef(R"proto( + name:'n' op:'SameIn' input:'a' input:'b' + attr { key:'N' value { i:1 } } attr { key:'T' value { type:DT_FLOAT } } + )proto"); + ExpectFailure(bad, op, "Value for attr 'N' of 1 must be at least minimum 2"); +} + +TEST(NodeDefUtilTest, AnyIn) { + const OpDef op = + ToOpDef(OpDefBuilder("AnyIn").Input("i: T").Attr("T: list(type) >= 1")); + + const NodeDef node_def = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectSuccess(node_def, op); + + EXPECT_EQ("n = AnyIn[T=[DT_INT32, DT_STRING]](a, b)", + SummarizeNodeDef(node_def)); + + const NodeDef bad = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' attr { key:'T' value { list { } } } + )proto"); + ExpectFailure(bad, op, "Length for attr 'T' of 0 must be at least minimum 1"); + + // With proto3 semantics, an empty value {} is indistinguishable from a value + // with an empty list in it. So we simply expect to get a message complaining + // about empty list for value {}. 
+ const NodeDef bad2 = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' attr { key:'T' value { } } + )proto"); + ExpectFailure(bad2, op, + "Length for attr 'T' of 0 must be at least minimum 1"); +} + +TEST(NodeDefUtilTest, Device) { + const OpDef op_def1 = ToOpDef(OpDefBuilder("None")); + const NodeDef node_def1 = + ToNodeDef(NodeDefBuilder("d", &op_def1).Device("/cpu:17")); + ExpectSuccess(node_def1, op_def1); + EXPECT_EQ("d = None[_device=\"/cpu:17\"]()", SummarizeNodeDef(node_def1)); + + const OpDef op_def2 = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int")); + const NodeDef node_def2 = + ToNodeDef(NodeDefBuilder("d", &op_def2).Attr("v", 7).Device("/cpu:5")); + ExpectSuccess(node_def2, op_def2); + EXPECT_EQ("d = WithAttr[v=7, _device=\"/cpu:5\"]()", + SummarizeNodeDef(node_def2)); +} + +void ExpectValidSyntax(const NodeDef& good) { + EXPECT_EQ(Status::OK(), ValidateExternalNodeDefSyntax(good)) + << "NodeDef: " << SummarizeNodeDef(good); +} + +void ExpectInvalidSyntax(const NodeDef& bad, const string& message) { + Status status = ValidateExternalNodeDefSyntax(bad); + + ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad); + + EXPECT_TRUE(errors::IsInvalidArgument(status)) + << status << "; NodeDef: " << SummarizeNodeDef(bad); + + EXPECT_TRUE(StringPiece(status.ToString()).contains(message)) + << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", " + << message; +} + +TEST(NodeDefUtilTest, ValidSyntax) { + const NodeDef node_def = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectValidSyntax(node_def); + + const NodeDef node_def_explicit_inputs = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a:0' input:'b:123' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectValidSyntax(node_def_explicit_inputs); + + EXPECT_EQ("n = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)", + SummarizeNodeDef(node_def_explicit_inputs)); + + const NodeDef node_def_partial_shape = ToNodeDef(R"proto( + name:'n' op:'AnyIn' + attr { key:'shp' value { shape { dim { size: -1 } dim { size: 0 } } } } + )proto"); + ExpectValidSyntax(node_def_partial_shape); + + const NodeDef node_def_control_input = ToNodeDef(R"proto( + name:'n-' op:'AnyIn' input:'a' input:'^b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectValidSyntax(node_def_control_input); + + const NodeDef node_def_invalid_name = ToNodeDef(R"proto( + name:'n:0' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_invalid_name, "Illegal op name 'n:0'"); + + const NodeDef node_def_internal_name = ToNodeDef(R"proto( + name:'_n' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_internal_name, "Illegal op name '_n'"); + + const NodeDef node_def_slash_in_name = ToNodeDef(R"proto( + name:'n\\' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_slash_in_name, "Illegal op name 'n\\'"); + + const NodeDef node_def_internal_input_name = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'_a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_internal_input_name, + "Illegal op input name '_a'"); + + const NodeDef node_def_input_name_slash = ToNodeDef(R"proto( + 
name:'n' op:'AnyIn' input:'a\\' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_input_name_slash, "Illegal op input name 'a\\'"); + + const NodeDef node_def_invalid_control_input_name = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' input:'^b:0' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_invalid_control_input_name, + "Illegal op input name '^b:0'"); + + const NodeDef node_def_control_input_name_slash = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' input:'^b\\' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_control_input_name_slash, + "Illegal op input name '^b\\'"); + + const NodeDef node_def_data_input_after_control = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'^a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_data_input_after_control, + "All control inputs must follow all data inputs"); + + const NodeDef node_def_data_input_invalid_port = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a:b' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_data_input_invalid_port, + "Illegal op input name 'a:b"); + + const NodeDef node_def_data_input_invalid_port2 = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a:00' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_data_input_invalid_port2, + "Illegal op input name 'a:00"); +} + +TEST(NameRangesForNodeTest, Simple) { + const OpDef op_def = ToOpDef(OpDefBuilder("Simple") + .Input("a: float") + .Input("b: int32") + .Output("c: string") + .Output("d: bool")); + NameRangeMap inputs, outputs; + const NodeDef node_def = ToNodeDef( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + TF_EXPECT_OK(NameRangesForNode(node_def, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 2}}}), outputs); + + EXPECT_EQ("simple = Simple[](a, b)", SummarizeNodeDef(node_def)); + + OpDef bad_op_def = op_def; + bad_op_def.mutable_input_arg(0)->clear_type(); + EXPECT_FALSE(NameRangesForNode(node_def, bad_op_def, &inputs, &outputs).ok()); +} + +TEST(NameRangesForNodeTest, Polymorphic) { + const OpDef op_def = ToOpDef(OpDefBuilder("Polymorphic") + .Input("a: T") + .Input("b: T") + .Output("c: T") + .Attr("T: type")); + NameRangeMap inputs, outputs; + const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("poly", &op_def) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32))); + TF_EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs); + EXPECT_EQ("poly = Polymorphic[T=DT_INT32](a, b)", + SummarizeNodeDef(node_def1)); + + const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("poly", &op_def) + .Input(FakeInput(DT_BOOL)) + .Input(FakeInput(DT_BOOL))); + TF_EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs); + EXPECT_EQ("poly = Polymorphic[T=DT_BOOL](a, b)", SummarizeNodeDef(node_def2)); +} + +TEST(NameRangesForNodeTest, NRepeats) { + const OpDef op_def = ToOpDef(OpDefBuilder("NRepeats") 
+ .Input("a: N * int32") + .Input("b: N * T") + .Output("c: T") + .Output("d: N * string") + .Output("e: M * bool") + .Attr("N: int") + .Attr("M: int") + .Attr("T: type")); + NameRangeMap inputs, outputs; + const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("nr", &op_def) + .Input(FakeInput(4, DT_INT32)) + .Input(FakeInput(4, DT_FLOAT)) + .Attr("M", 3)); + TF_EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}), + outputs); + EXPECT_EQ( + "nr = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, b:2, b:3)", + SummarizeNodeDef(node_def1)); + + const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("nr", &op_def) + .Input(FakeInput(2, DT_INT32)) + .Input(FakeInput(2, DT_DOUBLE)) + .Attr("M", 7)); + TF_EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}), + outputs); + EXPECT_EQ("nr = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)", + SummarizeNodeDef(node_def2)); + + NodeDef bad_node_def = node_def2; + bad_node_def.clear_attr(); + EXPECT_FALSE(NameRangesForNode(bad_node_def, op_def, &inputs, &outputs).ok()); +} + +TEST(NameRangesForNodeTest, TypeList) { + const OpDef op_def = ToOpDef(OpDefBuilder("TypeList") + .Input("a: T1") + .Input("b: T2") + .Output("c: T2") + .Output("d: T3") + .Output("e: T1") + .Attr("T1: list(type)") + .Attr("T2: list(type)") + .Attr("T3: list(type)")); + NameRangeMap inputs, outputs; + const NodeDef node_def1 = + ToNodeDef(NodeDefBuilder("tl", &op_def) + .Input(FakeInput({DT_BOOL, DT_FLOAT})) + .Input(FakeInput(4, DT_FLOAT)) + .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING})); + TF_EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}), + outputs); + EXPECT_EQ( + "tl = TypeList[T1=[DT_BOOL, DT_FLOAT]," + " T2=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT]," + " T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)", + SummarizeNodeDef(node_def1)); + + const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("tl", &op_def) + .Input(FakeInput(7, DT_INT32)) + .Input(FakeInput({DT_DOUBLE})) + .Attr("T3", {DT_DOUBLE, DT_STRING})); + TF_EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}), + outputs); + EXPECT_EQ( + "tl = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32," + " DT_INT32, DT_INT32], T2=[DT_DOUBLE], T3=[DT_DOUBLE, DT_STRING]]" + "(a, a:1, a:2, a:3, a:4, a:5, a:6, b)", + SummarizeNodeDef(node_def2)); + + NodeDef bad_node_def = node_def2; + bad_node_def.clear_attr(); + EXPECT_FALSE(NameRangesForNode(bad_node_def, op_def, &inputs, &outputs).ok()); +} + +} // namespace +} // namespace tensorflow diff --git a/numeric_op.h b/numeric_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4538ff053cd10b05a8874ff6db6b3c5e60d7622e --- /dev/null +++ b/numeric_op.h @@ -0,0 +1,113 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_ +#define TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// One input and one output, both the same type. +template +class UnaryOp : public OpKernel { + public: + explicit UnaryOp(OpKernelConstruction* context) : OpKernel(context) { + const DataType dt = DataTypeToEnum::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt})); + } +}; + +// Two inputs and one output, all the same type. +template +class BinaryOp : public OpKernel { + public: + explicit BinaryOp(OpKernelConstruction* context) : OpKernel(context) { + const DataType dt = DataTypeToEnum::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt})); + } +}; + +// For operations where the input and output are the same shape. +// +// For usage, see ../framework/elementwise_ops.cc. +template +class UnaryElementWiseOp : public UnaryOp { + public: + using UnaryOp::UnaryOp; + + void Compute(OpKernelContext* context) override { + // Output shape is the same as input shape. + const Tensor& input = context->input(0); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {0}, 0, input.shape(), &output)); + static_cast(this)->Operate(context, input, output); + } +}; + +// For binary elementwise operations. +template +class BinaryElementWiseOp : public BinaryOp { + public: + using BinaryOp::BinaryOp; + + void Compute(OpKernelContext* context) override { + const Tensor& a = context->input(0); + const Tensor& b = context->input(1); + + if (!context->ValidateInputsAreSameShape(this)) { + return; + } + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {0, 1}, 0, a.shape(), &output)); + + // Dispatch to the descendant's Operate() function. + switch (a.dims()) { +#define NDIM_CASE(NDIMS) \ + case NDIMS: { \ + static_cast(this)->template Operate(context, a, b, output); \ + break; \ + } + + NDIM_CASE(0); + NDIM_CASE(1); + NDIM_CASE(2); + NDIM_CASE(3); + NDIM_CASE(4); + NDIM_CASE(5); + NDIM_CASE(6); + NDIM_CASE(7); + NDIM_CASE(8); +#undef NDIM_CASE + + default: + context->SetStatus(errors::InvalidArgument( + "We only handle up to Tensor::dims() up to 8, not ", a.dims())); + break; + } + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_ diff --git a/numeric_types.h b/numeric_types.h new file mode 100644 index 0000000000000000000000000000000000000000..988a18da0e824c633b4a38d3ae10a4c93bb13110 --- /dev/null +++ b/numeric_types.h @@ -0,0 +1,90 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_ + +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +// Disable clang-format to prevent 'FixedPoint' header from being included +// before 'Tensor' header on which it depends. +// clang-format off +#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" +// clang-format on + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Single precision complex. +typedef std::complex complex64; +// Double precision complex. +typedef std::complex complex128; + +// We use Eigen's QInt implementations for our quantized int types. +typedef Eigen::QInt8 qint8; +typedef Eigen::QUInt8 quint8; +typedef Eigen::QInt32 qint32; +typedef Eigen::QInt16 qint16; +typedef Eigen::QUInt16 quint16; + +} // namespace tensorflow + +namespace Eigen { +// TOOD(xpan): We probably need to overwrite more methods to have correct eigen +// behavior. E.g. loest(), is_integer, etc. See NumTraits.h in eigen. +template <> +struct NumTraits + : GenericNumTraits {}; + +using ::tensorflow::operator==; +using ::tensorflow::operator!=; + +namespace numext { + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 log( + const tensorflow::bfloat16& x) { + return static_cast(::logf(static_cast(x))); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 exp( + const tensorflow::bfloat16& x) { + return static_cast(::expf(static_cast(x))); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE tensorflow::bfloat16 abs( + const tensorflow::bfloat16& x) { + return static_cast(::fabsf(static_cast(x))); +} + +} // namespace numext +} // namespace Eigen + +#ifdef COMPILER_MSVC +namespace std { +template <> +struct hash { + std::size_t operator()(const Eigen::half& a) const { + return static_cast(a.x); + } +}; +} // namespace std +#endif // COMPILER_MSVC + +#endif // TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_ diff --git a/op.cc b/op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fadb60d744217daa0c569601c437146a70f9b4d5 --- /dev/null +++ b/op.cc @@ -0,0 +1,260 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" + +#include +#include +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/host_info.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// OpRegistry ----------------------------------------------------------------- + +OpRegistryInterface::~OpRegistryInterface() {} + +Status OpRegistryInterface::LookUpOpDef(const string& op_type_name, + const OpDef** op_def) const { + *op_def = nullptr; + const OpRegistrationData* op_reg_data = nullptr; + TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data)); + *op_def = &op_reg_data->op_def; + return Status::OK(); +} + +OpRegistry::OpRegistry() : initialized_(false) {} + +OpRegistry::~OpRegistry() { + for (const auto& e : registry_) delete e.second; +} + +void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) { + mutex_lock lock(mu_); + if (initialized_) { + TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory)); + } else { + deferred_.push_back(op_data_factory); + } +} + +Status OpRegistry::LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const { + *op_reg_data = nullptr; + const OpRegistrationData* res = nullptr; + + bool first_call = false; + bool first_unregistered = false; + { // Scope for lock. + mutex_lock lock(mu_); + first_call = MustCallDeferred(); + res = gtl::FindWithDefault(registry_, op_type_name, nullptr); + + static bool unregistered_before = false; + first_unregistered = !unregistered_before && (res == nullptr); + if (first_unregistered) { + unregistered_before = true; + } + // Note: Can't hold mu_ while calling Export() below. + } + if (first_call) { + TF_QCHECK_OK(ValidateKernelRegistrations(*this)); + } + if (res == nullptr) { + if (first_unregistered) { + OpList op_list; + Export(true, &op_list); + if (VLOG_IS_ON(3)) { + LOG(INFO) << "All registered Ops:"; + for (const auto& op : op_list.op()) { + LOG(INFO) << SummarizeOpDef(op); + } + } + } + Status status = + errors::NotFound("Op type not registered '", op_type_name, + "' in binary running on ", port::Hostname(), ". 
", + "Make sure the Op and Kernel are registered in the " + "binary running in this process."); + VLOG(1) << status.ToString(); + return status; + } + *op_reg_data = res; + return Status::OK(); +} + +void OpRegistry::GetRegisteredOps(std::vector* op_defs) { + mutex_lock lock(mu_); + MustCallDeferred(); + for (const auto& p : registry_) { + op_defs->push_back(p.second->op_def); + } +} + +Status OpRegistry::SetWatcher(const Watcher& watcher) { + mutex_lock lock(mu_); + if (watcher_ && watcher) { + return errors::AlreadyExists( + "Cannot over-write a valid watcher with another."); + } + watcher_ = watcher; + return Status::OK(); +} + +void OpRegistry::Export(bool include_internal, OpList* ops) const { + mutex_lock lock(mu_); + MustCallDeferred(); + + std::vector> sorted( + registry_.begin(), registry_.end()); + std::sort(sorted.begin(), sorted.end()); + + auto out = ops->mutable_op(); + out->Clear(); + out->Reserve(sorted.size()); + + for (const auto& item : sorted) { + if (include_internal || !StringPiece(item.first).starts_with("_")) { + *out->Add() = item.second->op_def; + } + } +} + +void OpRegistry::DeferRegistrations() { + mutex_lock lock(mu_); + initialized_ = false; +} + +void OpRegistry::ClearDeferredRegistrations() { + mutex_lock lock(mu_); + deferred_.clear(); +} + +Status OpRegistry::ProcessRegistrations() const { + mutex_lock lock(mu_); + return CallDeferred(); +} + +string OpRegistry::DebugString(bool include_internal) const { + OpList op_list; + Export(include_internal, &op_list); + string ret; + for (const auto& op : op_list.op()) { + strings::StrAppend(&ret, SummarizeOpDef(op), "\n"); + } + return ret; +} + +bool OpRegistry::MustCallDeferred() const { + if (initialized_) return false; + initialized_ = true; + for (size_t i = 0; i < deferred_.size(); ++i) { + TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i])); + } + deferred_.clear(); + return true; +} + +Status OpRegistry::CallDeferred() const { + if (initialized_) return Status::OK(); + initialized_ = true; + for (size_t i = 0; i < deferred_.size(); ++i) { + Status s = RegisterAlreadyLocked(deferred_[i]); + if (!s.ok()) { + return s; + } + } + deferred_.clear(); + return Status::OK(); +} + +Status OpRegistry::RegisterAlreadyLocked( + const OpRegistrationDataFactory& op_data_factory) const { + std::unique_ptr op_reg_data(new OpRegistrationData); + Status s = op_data_factory(op_reg_data.get()); + if (s.ok()) { + s = ValidateOpDef(op_reg_data->op_def); + if (s.ok() && + !gtl::InsertIfNotPresent(®istry_, op_reg_data->op_def.name(), + op_reg_data.get())) { + s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name()); + } + } + Status watcher_status = s; + if (watcher_) { + watcher_status = watcher_(s, op_reg_data->op_def); + } + if (s.ok()) { + op_reg_data.release(); + } else { + op_reg_data.reset(); + } + return watcher_status; +} + +// static +OpRegistry* OpRegistry::Global() { + static OpRegistry* global_op_registry = new OpRegistry; + return global_op_registry; +} + +// OpListOpRegistry ----------------------------------------------------------- + +OpListOpRegistry::OpListOpRegistry(const OpList* op_list) { + for (const OpDef& op_def : op_list->op()) { + auto* op_reg_data = new OpRegistrationData(); + op_reg_data->op_def = op_def; + index_[op_def.name()] = op_reg_data; + } +} + +OpListOpRegistry::~OpListOpRegistry() { + for (const auto& e : index_) delete e.second; +} + +Status OpListOpRegistry::LookUp(const string& op_type_name, + const OpRegistrationData** op_reg_data) const { + auto iter = 
+  if (iter == index_.end()) {
+    *op_reg_data = nullptr;
+    return errors::NotFound("Op type not registered '", op_type_name,
+                            "' in binary running on ", port::Hostname(), ". ",
+                            "Make sure the Op and Kernel are registered in the "
+                            "binary running in this process.");
+  }
+  *op_reg_data = iter->second;
+  return Status::OK();
+}
+
+// Other registration ---------------------------------------------------------
+
+namespace register_op {
+OpDefBuilderReceiver::OpDefBuilderReceiver(
+    const OpDefBuilderWrapper<true>& wrapper) {
+  OpRegistry::Global()->Register(
+      [wrapper](OpRegistrationData* op_reg_data) -> Status {
+        return wrapper.builder().Finalize(op_reg_data);
+      });
+}
+}  // namespace register_op
+
+}  // namespace tensorflow
diff --git a/op.h b/op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f7f1ed2a886548c39fa38239d65aa2a73564c3c4
--- /dev/null
+++ b/op.h
@@ -0,0 +1,309 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_OP_H_
+#define TENSORFLOW_FRAMEWORK_OP_H_
+
+#include <functional>
+#include <unordered_map>
+
+#include <vector>
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/selective_registration.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Users that want to look up an OpDef by type name should take an
+// OpRegistryInterface. Functions accepting a
+// (const) OpRegistryInterface* may call LookUp() from multiple threads.
+class OpRegistryInterface {
+ public:
+  virtual ~OpRegistryInterface();
+
+  // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
+  // registered under that name, otherwise returns the registered OpDef.
+  // Caller must not delete the returned pointer.
+  virtual Status LookUp(const string& op_type_name,
+                        const OpRegistrationData** op_reg_data) const = 0;
+
+  // Shorthand for calling LookUp to get the OpDef.
+  Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
+};
+
+// The standard implementation of OpRegistryInterface, along with a
+// global singleton used for registering ops via the REGISTER
+// macros below. Thread-safe.
+//
+// Example registration:
+//   OpRegistry::Global()->Register(
+//     [](OpRegistrationData* op_reg_data)->Status {
+//       // Populate *op_reg_data here.
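+//       // e.g., by filling in the OpDef with OpDefBuilder (hypothetical op
+//       // name "MyExampleOp"):
+//       //   TF_RETURN_IF_ERROR(OpDefBuilder("MyExampleOp")
+//       //                          .Input("x: float")
+//       //                          .Output("y: float")
+//       //                          .Finalize(op_reg_data));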
+//       return Status::OK();
+//     });
+class OpRegistry : public OpRegistryInterface {
+ public:
+  typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
+
+  OpRegistry();
+  ~OpRegistry() override;
+
+  void Register(const OpRegistrationDataFactory& op_data_factory);
+
+  Status LookUp(const string& op_type_name,
+                const OpRegistrationData** op_reg_data) const override;
+
+  // Fills *ops with all registered OpDefs (except those with names
+  // starting with '_' if include_internal == false) sorted in
+  // ascending alphabetical order.
+  void Export(bool include_internal, OpList* ops) const;
+
+  // Returns ASCII-format OpList for all registered OpDefs (except
+  // those with names starting with '_' if include_internal == false).
+  string DebugString(bool include_internal) const;
+
+  // A singleton available at startup.
+  static OpRegistry* Global();
+
+  // Get all registered ops.
+  void GetRegisteredOps(std::vector<OpDef>* op_defs);
+
+  // Watcher, a function object.
+  // The watcher, if set by SetWatcher(), is called every time an op is
+  // registered via the Register function. The watcher is passed the Status
+  // obtained from building and adding the OpDef to the registry, and the OpDef
+  // itself if it was successfully built. A watcher returns a Status which is
+  // in turn returned as the final registration status.
+  typedef std::function<Status(const Status&, const OpDef&)> Watcher;
+
+  // An OpRegistry object has only one watcher. This interface is not thread
+  // safe, as different clients are free to set the watcher any time.
+  // Clients are expected to atomically perform the following sequence of
+  // operations:
+  //   SetWatcher(a_watcher);
+  //   Register some ops;
+  //   op_registry->ProcessRegistrations();
+  //   SetWatcher(nullptr);
+  // Returns a non-OK status if a non-null watcher is over-written by another
+  // non-null watcher.
+  Status SetWatcher(const Watcher& watcher);
+
+  // Process the current list of deferred registrations. Note that calls to
+  // Export, LookUp and DebugString would also implicitly process the deferred
+  // registrations. Returns the status of the first failed op registration or
+  // Status::OK() otherwise.
+  Status ProcessRegistrations() const;
+
+  // Defer the registrations until a later call to a function that processes
+  // deferred registrations is made. Normally, registrations that happen after
+  // calls to Export, LookUp, ProcessRegistrations and DebugString are
+  // processed immediately. Call this to defer future registrations.
+  void DeferRegistrations();
+
+  // Clear the registrations that have been deferred.
+  void ClearDeferredRegistrations();
+
+ private:
+  // Ensures that all the functions in deferred_ get called, their OpDefs
+  // registered, and returns with deferred_ empty. Returns true the first
+  // time it is called. Prints a fatal log if any op registration fails.
+  bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Calls the functions in deferred_ and registers their OpDefs.
+  // It returns the Status of the first failed op registration or Status::OK()
+  // otherwise.
+  Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Add 'def' to the registry with additional data 'data'. On failure, or if
+  // there is already an OpDef with that name registered, returns a non-okay
+  // status.
+  Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
+      const EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  mutable mutex mu_;
+  // Functions in deferred_ may only be called with mu_ held.
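+  // Ops registered while initialized_ is false (e.g. by REGISTER_OP during
+  // static initialization) are queued here and flushed into registry_ by
+  // MustCallDeferred() / CallDeferred() on the first LookUp, Export,
+  // DebugString or ProcessRegistrations call.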
+  mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
+  // Values are owned.
+  mutable std::unordered_map<string, const OpRegistrationData*> registry_
+      GUARDED_BY(mu_);
+  mutable bool initialized_ GUARDED_BY(mu_);
+
+  // Registry watcher.
+  mutable Watcher watcher_ GUARDED_BY(mu_);
+};
+
+// An adapter to allow an OpList to be used as an OpRegistryInterface.
+//
+// Note that shape inference functions are not passed in to OpListOpRegistry,
+// so it will return an unusable shape inference function for every op it
+// supports; therefore, it should only be used in contexts where this is okay.
+class OpListOpRegistry : public OpRegistryInterface {
+ public:
+  // Does not take ownership of op_list, *op_list must outlive *this.
+  OpListOpRegistry(const OpList* op_list);
+  ~OpListOpRegistry() override;
+  Status LookUp(const string& op_type_name,
+                const OpRegistrationData** op_reg_data) const override;
+
+ private:
+  // Values are owned.
+  std::unordered_map<string, const OpRegistrationData*> index_;
+};
+
+// Support for defining the OpDef (specifying the semantics of the Op and how
+// it should be created) and registering it in the OpRegistry::Global()
+// registry. Usage:
+//
+// REGISTER_OP("my_op_name")
+//     .Attr("<name>:<type>")
+//     .Attr("<name>:<type>=<default>")
+//     .Input("<name>:<type-expr>")
+//     .Input("<name>:Ref(<type-expr>)")
+//     .Output("<name>:<type-expr>")
+//     .Doc(R"(
+// <1-line summary>
+// <rest of the description>
+// <name-of-attr-input-or-output>: <description of name>
+// <name-of-attr-input-or-output>: <description of name>
+// )");
+//
+// Note: .Doc() should be last.
+// For details, see the OpDefBuilder class in op_def_builder.h.
+
+namespace register_op {
+
+// OpDefBuilderWrapper is a templated class that is used in the REGISTER_OP
+// calls. This allows the result of REGISTER_OP to be used in chaining, as in
+// REGISTER_OP(a).Attr("...").Input("...");, while still allowing selective
+// registration to turn the entire call-chain into a no-op.
+template <bool should_register>
+class OpDefBuilderWrapper;
+
+// Template specialization that forwards all calls to the contained builder.
+template <>
+class OpDefBuilderWrapper<true> {
+ public:
+  OpDefBuilderWrapper(const char name[]) : builder_(name) {}
+  OpDefBuilderWrapper<true>& Attr(StringPiece spec) {
+    builder_.Attr(spec);
+    return *this;
+  }
+  OpDefBuilderWrapper<true>& Input(StringPiece spec) {
+    builder_.Input(spec);
+    return *this;
+  }
+  OpDefBuilderWrapper<true>& Output(StringPiece spec) {
+    builder_.Output(spec);
+    return *this;
+  }
+  OpDefBuilderWrapper<true>& SetIsCommutative() {
+    builder_.SetIsCommutative();
+    return *this;
+  }
+  OpDefBuilderWrapper<true>& SetIsAggregate() {
+    builder_.SetIsAggregate();
+    return *this;
+  }
+  OpDefBuilderWrapper<true>& SetIsStateful() {
+    builder_.SetIsStateful();
+    return *this;
+  }
+  OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() {
+    builder_.SetAllowsUninitializedInput();
+    return *this;
+  }
+  OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
+    builder_.Deprecated(version, explanation);
+    return *this;
+  }
+  OpDefBuilderWrapper<true>& Doc(StringPiece text) {
+    builder_.Doc(text);
+    return *this;
+  }
+  OpDefBuilderWrapper<true>& SetShapeFn(
+      Status (*fn)(shape_inference::InferenceContext*)) {
+    builder_.SetShapeFn(fn);
+    return *this;
+  }
+  const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
+
+ private:
+  mutable ::tensorflow::OpDefBuilder builder_;
+};
+
+// Template specialization that turns all calls into no-ops.
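+// It is chosen when SHOULD_REGISTER_OP(name) evaluates to false under
+// selective registration, so the entire REGISTER_OP(...) chain reduces to an
+// unused constexpr object and the op is left out of the binary's registry.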
+template <>
+class OpDefBuilderWrapper<false> {
+ public:
+  constexpr OpDefBuilderWrapper(const char name[]) {}
+  OpDefBuilderWrapper<false>& Attr(StringPiece spec) { return *this; }
+  OpDefBuilderWrapper<false>& Input(StringPiece spec) { return *this; }
+  OpDefBuilderWrapper<false>& Output(StringPiece spec) { return *this; }
+  OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; }
+  OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
+  OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
+  OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
+  OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
+  OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
+  OpDefBuilderWrapper<false>& SetShapeFn(
+      Status (*fn)(shape_inference::InferenceContext*)) {
+    return *this;
+  }
+};
+
+struct OpDefBuilderReceiver {
+  // To call OpRegistry::Global()->Register(...), used by the
+  // REGISTER_OP macro below.
+  // Note: These are implicitly converting constructors.
+  OpDefBuilderReceiver(
+      const OpDefBuilderWrapper<true>& wrapper);  // NOLINT(runtime/explicit)
+  constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) {
+  }  // NOLINT(runtime/explicit)
+};
+}  // namespace register_op
+
+#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
+#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
+#define REGISTER_OP_UNIQ(ctr, name)                                          \
+  static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr    \
+      TF_ATTRIBUTE_UNUSED =                                                  \
+          ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \
+              name)>(name)
+
+// The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except
+// that the op is registered unconditionally even when selective
+// registration is used.
+#define REGISTER_SYSTEM_OP(name) \
+  REGISTER_SYSTEM_OP_UNIQ_HELPER(__COUNTER__, name)
+#define REGISTER_SYSTEM_OP_UNIQ_HELPER(ctr, name) \
+  REGISTER_SYSTEM_OP_UNIQ(ctr, name)
+#define REGISTER_SYSTEM_OP_UNIQ(ctr, name)                                 \
+  static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr  \
+      TF_ATTRIBUTE_UNUSED =                                                \
+          ::tensorflow::register_op::OpDefBuilderWrapper<true>(name)
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_OP_H_
diff --git a/op_compatibility_test.cc b/op_compatibility_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ae2fdae379a21289df2e0eb2dd5cbda0a6d5ed81
--- /dev/null
+++ b/op_compatibility_test.cc
@@ -0,0 +1,1062 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Test that verifies that various changes to an OpDef are
+// backwards-compatible.
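+//
+// Each test registers the *new* form of an op with REGISTER_OP, builds a
+// NodeDef against a hand-constructed *old* OpDef, and then uses
+// ExpectSuccess / ExpectInvalid / ExpectTypeMismatch to check whether that
+// old NodeDef still validates and resolves to the same signature under the
+// new OpDef.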
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+class TestKernel : public OpKernel {
+ public:
+  explicit TestKernel(OpKernelConstruction* context) : OpKernel(context) {}
+  void Compute(OpKernelContext* context) override {
+    Tensor* out_tensor = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output("ndef", TensorShape({}),
+                                                     &out_tensor));
+    out_tensor->scalar<string>()() = SummarizeNodeDef(def());
+  }
+};
+
+class OpCompatibilityTest : public OpsTestBase {
+ protected:
+  const OpDef* RegisteredOpDef() {
+    const OpDef* op_def;
+    TF_CHECK_OK(OpRegistry::Global()->LookUpOpDef(node_def()->op(), &op_def));
+    return op_def;
+  }
+
+  void ExpectSuccess(const OpDef& old_op_def) {
+    // Record the original signature before we change *node_def().
+    DataTypeVector old_in_types, old_out_types;
+    TF_ASSERT_OK(InOutTypesForNode(*node_def(), old_op_def, &old_in_types,
+                                   &old_out_types));
+
+    // This should be all that is needed to get compatibility.
+    const OpDef* new_op_def = RegisteredOpDef();
+    AddDefaultsToNodeDef(*new_op_def, node_def());
+
+    // Validate that it is indeed compatible.
+    TF_ASSERT_OK(ValidateNodeDef(*node_def(), *new_op_def));
+    DataTypeVector new_in_types, new_out_types;
+    TF_ASSERT_OK(InOutTypesForNode(*node_def(), *new_op_def, &new_in_types,
+                                   &new_out_types));
+    if (new_in_types.size() == old_in_types.size()) {
+      // Ref inputs are allowed to become non-ref inputs.
+      for (int i = 0; i < new_in_types.size(); ++i) {
+        if (IsRefType(old_in_types[i]) && !IsRefType(new_in_types[i])) {
+          old_in_types[i] = RemoveRefType(old_in_types[i]);
+        }
+      }
+    }
+    ASSERT_EQ(new_in_types, old_in_types);
+    if (new_out_types.size() == old_out_types.size()) {
+      // Non-ref outputs are allowed to become ref outputs.
+      for (int i = 0; i < new_out_types.size(); ++i) {
+        if (!IsRefType(old_out_types[i]) && IsRefType(new_out_types[i])) {
+          old_out_types[i] = MakeRefType(old_out_types[i]);
+        }
+      }
+    }
+    ASSERT_EQ(new_out_types, old_out_types);
+    TF_ASSERT_OK(OpDefCompatible(old_op_def, *new_op_def));
+
+    // Verify the Op actually runs. Result() will return the output.
+    TF_ASSERT_OK(InitOp());
+    TF_ASSERT_OK(RunOpKernel());
+  }
+
+  string Result() { return GetOutput(0)->scalar<string>()(); }
+
+  void ExpectIncompatible(const OpDef& old_op_def, const OpDef& new_op_def,
+                          const string& error) {
+    // Test OpDefCompatible gives the same answer without the node_def.
+    Status status = OpDefCompatible(old_op_def, new_op_def);
+    if (status.ok()) {
+      ADD_FAILURE() << SummarizeOpDef(old_op_def) << " vs. "
+                    << SummarizeOpDef(new_op_def);
+    } else {
+      EXPECT_TRUE(StringPiece(status.error_message()).contains(error))
+          << status << " does not contain " << error;
+    }
+  }
+
+  void ExpectInvalid(const OpDef& old_op_def, const string& validation_error,
+                     const string& compatibility_error) {
+    // Record the original signature before we change *node_def().
+    DataTypeVector old_in_types, old_out_types;
+    TF_ASSERT_OK(InOutTypesForNode(*node_def(), old_op_def, &old_in_types,
+                                   &old_out_types));
+
+    // This should be all that is needed to get compatibility.
+    const OpDef* new_op_def = RegisteredOpDef();
+    AddDefaultsToNodeDef(*new_op_def, node_def());
+
+    // Validate that it does not pass validation.
+ Status status = ValidateNodeDef(*node_def(), *new_op_def); + if (status.ok()) { + ADD_FAILURE() << SummarizeNodeDef(*node_def()); + } else { + EXPECT_TRUE( + StringPiece(status.error_message()).contains(validation_error)) + << status << " does not contain " << validation_error; + } + + ExpectIncompatible(old_op_def, *new_op_def, compatibility_error); + } + + void ExpectTypeMismatch(const OpDef& old_op_def, + const string& compatibility_error) { + // Record the original signature before we change *node_def(). + DataTypeVector old_in_types, old_out_types; + TF_ASSERT_OK(InOutTypesForNode(*node_def(), old_op_def, &old_in_types, + &old_out_types)); + + // This should be all that is needed to get compatibility. + const OpDef* new_op_def = RegisteredOpDef(); + AddDefaultsToNodeDef(*new_op_def, node_def()); + + // Validate that it is valid, but with incompatible types. + TF_ASSERT_OK(ValidateNodeDef(*node_def(), *new_op_def)); + + DataTypeVector new_in_types, new_out_types; + TF_ASSERT_OK(InOutTypesForNode(*node_def(), *new_op_def, &new_in_types, + &new_out_types)); + if (new_in_types == old_in_types && new_out_types == old_out_types) { + ADD_FAILURE() << SummarizeNodeDef(*node_def()) << "\n" + << DataTypeVectorString(new_in_types) << " -> " + << DataTypeVectorString(new_out_types); + } + + ExpectIncompatible(old_op_def, *new_op_def, compatibility_error); + } + + void ExpectRenameFailure(const OpDef& old_op_def, + const string& compatibility_error) { + // This should be all that is needed to get compatibility. + const OpDef* new_op_def = RegisteredOpDef(); + AddDefaultsToNodeDef(*new_op_def, node_def()); + + // Validate that the NodeDef is valid. This will ignore + // problems caused by output name changes for functions. + TF_ASSERT_OK(ValidateNodeDef(*node_def(), *new_op_def)); + + ExpectIncompatible(old_op_def, *new_op_def, compatibility_error); + } +}; + +// Should be compatible if the Op hasn't changed (sanity check). +REGISTER_OP("Same") + .Input("a: int32") + .Input("b: T") + .Input("c: N * int32") + .Input("d: N * T") + .Input("e: TList") + .Output("ndef: string") + .Attr("T: type") + .Attr("N: int") + .Attr("TList: list(type)"); +REGISTER_KERNEL_BUILDER(Name("Same").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, Same) { + TF_ASSERT_OK(NodeDefBuilder("same", "Same") + .Input(FakeInput()) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(3)) + .Input(FakeInput(3, DT_FLOAT)) + .Input(FakeInput(2, DT_BOOL)) + .Finalize(node_def())); + ExpectSuccess(*RegisteredOpDef()); + EXPECT_EQ( + "same = Same[N=3, T=DT_FLOAT, TList=[DT_BOOL, DT_BOOL]](a, b, c, c:1, " + "c:2, d, d:1, d:2, e, e:1)", + Result()); +} + +// Should be able to add an attr with a default. +REGISTER_OP("AddAttr").Output("ndef: string").Attr("a: int = 42"); +REGISTER_KERNEL_BUILDER(Name("AddAttr").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, AddAttr) { + OpRegistrationData old_op; + TF_ASSERT_OK( + OpDefBuilder("AddAttr").Output("ndef: string").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("add_attr", &old_op.op_def).Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("add_attr = AddAttr[a=42]()", Result()); +} + +// Should be able to make an attr restriction less strict. 
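+// (The new set of allowed values is a superset of the old one, so whatever
+// value an existing NodeDef carries for 'a' is still accepted.)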
+REGISTER_OP("LessStrict").Output("ndef: string").Attr("a: {'A', 'B', 'C'}"); +REGISTER_KERNEL_BUILDER(Name("LessStrict").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, LessStrict) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("LessStrict") + .Output("ndef: string") + .Attr("a: {'A', 'B'}") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("less_strict", &old_op.op_def) + .Attr("a", "B") + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("less_strict = LessStrict[a=\"B\"]()", Result()); +} + +// Should be able to remove an attr restriction. +REGISTER_OP("RemoveRestriction").Output("ndef: string").Attr("a: type"); +REGISTER_KERNEL_BUILDER(Name("RemoveRestriction").Device(DEVICE_CPU), + TestKernel); + +TEST_F(OpCompatibilityTest, RemoveRestriction) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("RemoveRestriction") + .Output("ndef: string") + .Attr("a: {int32, bool}") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("remove_restriction", &old_op.op_def) + .Attr("a", DT_INT32) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("remove_restriction = RemoveRestriction[a=DT_INT32]()", Result()); +} + +// Should be able to change the order of attrs. +REGISTER_OP("AttrOrder").Output("ndef: string").Attr("a: int").Attr("b: bool"); +REGISTER_KERNEL_BUILDER(Name("AttrOrder").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, AttrOrder) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AttrOrder") + .Output("ndef: string") + .Attr("b: bool") + .Attr("a: int") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("attr_order", &old_op.op_def) + .Attr("b", true) + .Attr("a", 7) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("attr_order = AttrOrder[a=7, b=true]()", Result()); +} + +// Should be able to add a default to an attr. +REGISTER_OP("AddDefault").Output("ndef: string").Attr("a: int = 1234"); +REGISTER_KERNEL_BUILDER(Name("AddDefault").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, AddDefault) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddDefault") + .Output("ndef: string") + .Attr("a: int") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op.op_def) + .Attr("a", 765) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("add_default = AddDefault[a=765]()", Result()); +} + +// Should be able to remove a default from an attr, *as long as that +// attr has always existed*. +REGISTER_OP("RemoveDefault").Output("ndef: string").Attr("a: int"); +REGISTER_KERNEL_BUILDER(Name("RemoveDefault").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, RemoveDefault) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("RemoveDefault") + .Output("ndef: string") + .Attr("a: int = 91") + .Finalize(&old_op)); + TF_ASSERT_OK( + NodeDefBuilder("remove_default", &old_op.op_def).Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("remove_default = RemoveDefault[a=91]()", Result()); +} + +// Should be able to make an input/output polymorphic. +// Changing from int32 -> T (where T: type = DT_INT32 by default). 
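+// (Old NodeDefs do not carry the new attr T; AddDefaultsToNodeDef fills in
+// T = DT_INT32, so the resolved input/output signature is unchanged.)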
+REGISTER_OP("TypePolymorphic") + .Input("a: T") + .Output("ndef: string") + .Attr("T: type = DT_INT32"); +REGISTER_KERNEL_BUILDER(Name("TypePolymorphic").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, TypePolymorphic) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("TypePolymorphic") + .Input("a: int32") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("type_polymorphic", &old_op.op_def) + .Input(FakeInput()) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("type_polymorphic = TypePolymorphic[T=DT_INT32](a)", Result()); +} + +// Should be able to make a single input/output into a list. +// Changing from int32 -> N * int32 (where N: int = 1 by default). +REGISTER_OP("MakeList") + .Input("a: N * int32") + .Output("ndef: string") + .Attr("N: int = 1"); +REGISTER_KERNEL_BUILDER(Name("MakeList").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, MakeList) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("MakeList") + .Input("a: int32") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op.op_def) + .Input(FakeInput()) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("make_list = MakeList[N=1](a)", Result()); +} + +// Should be able to make a single input/output into a polymorphic list. +// Changing from int32 -> N * T (where N: int = 1 by default and +// T: type = DT_INT32 by default). +REGISTER_OP("MakePolyList") + .Input("a: N * T") + .Output("ndef: string") + .Attr("N: int = 1") + .Attr("T: type = DT_INT32"); +REGISTER_KERNEL_BUILDER(Name("MakePolyList").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, MakePolyList) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("MakePolyList") + .Input("a: int32") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("make_poly_list", &old_op.op_def) + .Input(FakeInput()) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("make_poly_list = MakePolyList[N=1, T=DT_INT32](a)", Result()); +} + +// Should be able to make a single input/output into an arbitrary list. +// Changing from int32 -> T (where T: list(type) = [DT_INT32] by default). +REGISTER_OP("MakeAnyList") + .Input("a: T") + .Output("ndef: string") + .Attr("T: list(type) = [DT_INT32]"); +REGISTER_KERNEL_BUILDER(Name("MakeAnyList").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, MakeAnyList) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("MakeAnyList") + .Input("a: int32") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("make_any_list", &old_op.op_def) + .Input(FakeInput()) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("make_any_list = MakeAnyList[T=[DT_INT32]](a)", Result()); +} + +// Should be able to make a single polymorphic input/output into a list of +// the same type. Changing from T -> N * T (where N: int = 1 by default). 
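+// (N defaults to 1, so an old NodeDef's single input is reinterpreted as a
+// length-1 list of the type T it already records.)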
+REGISTER_OP("PolyIntoList") + .Input("a: N * T") + .Output("ndef: string") + .Attr("N: int = 1") + .Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("PolyIntoList").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, PolyIntoList) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("PolyIntoList") + .Input("a: T") + .Output("ndef: string") + .Attr("T: type") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("poly_into_list", &old_op.op_def) + .Input(FakeInput(DT_INT32)) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("poly_into_list = PolyIntoList[N=1, T=DT_INT32](a)", Result()); +} + +// Should be able to make a multiple inputs/outputs into a list with +// the default types matching the inputs/outputs being replaced. + +// Changing from int32, int32 -> N * int32 (where N: int = 2 by default). +REGISTER_OP("MakeMultipleSameList") + .Input("a: N * int32") + .Output("ndef: string") + .Attr("N: int = 2"); +REGISTER_KERNEL_BUILDER(Name("MakeMultipleSameList").Device(DEVICE_CPU), + TestKernel); + +TEST_F(OpCompatibilityTest, MakeMultipleSameList) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("MakeMultipleSameList") + .Input("a: int32") + .Input("b: int32") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op.op_def) + .Input(FakeInput()) + .Input(FakeInput()) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("make_list = MakeMultipleSameList[N=2](a, b)", Result()); +} + +// Changing from int32, float -> T +// (where T: list(type) = [int32, float] by default). +REGISTER_OP("MakeMultipleAnyList") + .Input("a: T") + .Output("ndef: string") + .Attr("T: list(type) = [DT_INT32, DT_FLOAT]"); +REGISTER_KERNEL_BUILDER(Name("MakeMultipleAnyList").Device(DEVICE_CPU), + TestKernel); + +TEST_F(OpCompatibilityTest, MakeMultipleAnyList) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("MakeMultipleAnyList") + .Input("a: int32") + .Input("b: float") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("make_list", &old_op.op_def) + .Input(FakeInput()) + .Input(FakeInput()) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("make_list = MakeMultipleAnyList[T=[DT_INT32, DT_FLOAT]](a, b)", + Result()); +} + +// Should be able to change the name of an input/output. +REGISTER_OP("ChangeName").Input("y: int32").Output("ndef: string"); +REGISTER_KERNEL_BUILDER(Name("ChangeName").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, ChangeName) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("ChangeName") + .Input("x: int32") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("change_name", &old_op.op_def) + .Input(FakeInput()) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("change_name = ChangeName[](a)", Result()); +} + +// Should be able to add an input/output of type +// N * int32 (where N: int = 0 by default). 
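+// (With N defaulting to 0, old NodeDefs resolve the new argument to an empty
+// list of tensors and remain valid without rewiring any inputs.)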
+REGISTER_OP("AddNInts") + .Input("a: N * int32") + .Output("ndef: string") + .Attr("N: int >= 0 = 0"); +REGISTER_KERNEL_BUILDER(Name("AddNInts").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, AddNInts) { + OpRegistrationData old_op; + TF_ASSERT_OK( + OpDefBuilder("AddNInts").Output("ndef: string").Finalize(&old_op)); + TF_ASSERT_OK( + NodeDefBuilder("add_n_ints", &old_op.op_def).Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("add_n_ints = AddNInts[N=0]()", Result()); +} + +// Should be able to add an input/output of type N * T +// (where N: int = 0 by default, and T: type = any valid default). +REGISTER_OP("AddNSame") + .Input("a: N * T") + .Output("ndef: string") + .Attr("N: int >= 0 = 0") + .Attr("T: type = DT_BOOL"); +REGISTER_KERNEL_BUILDER(Name("AddNSame").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, AddNSame) { + OpRegistrationData old_op; + TF_ASSERT_OK( + OpDefBuilder("AddNSame").Output("ndef: string").Finalize(&old_op)); + TF_ASSERT_OK( + NodeDefBuilder("add_n_same", &old_op.op_def).Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("add_n_same = AddNSame[N=0, T=DT_BOOL]()", Result()); +} + +// Should be able to add an input/output of type N * T +// (where N: int >= 0 = 0 by default, and T an existing type attr). +REGISTER_OP("AddNSameAsExisting") + .Input("a: T") + .Input("b: N * T") + .Output("ndef: string") + .Attr("N: int >= 0 = 0") + .Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("AddNSameAsExisting").Device(DEVICE_CPU), + TestKernel); + +TEST_F(OpCompatibilityTest, AddNSameAsExisting) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddNSameAsExisting") + .Input("a: T") + .Output("ndef: string") + .Attr("T: type") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("add_n_same_as_existing", &old_op.op_def) + .Input(FakeInput(DT_STRING)) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("add_n_same_as_existing = AddNSameAsExisting[N=0, T=DT_STRING](a)", + Result()); +} + +// Should be able to add an input/output of type T +// (where T: list(type) >= 0 = [] by default). +REGISTER_OP("AddAnyList") + .Input("a: T") + .Output("ndef: string") + .Attr("T: list(type) >= 0 = []"); +REGISTER_KERNEL_BUILDER(Name("AddAnyList").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, AddAnyList) { + OpRegistrationData old_op; + TF_ASSERT_OK( + OpDefBuilder("AddAnyList").Output("ndef: string").Finalize(&old_op)); + TF_ASSERT_OK( + NodeDefBuilder("add_any_list", &old_op.op_def).Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("add_any_list = AddAnyList[T=[]]()", Result()); +} + +// Should be able to allow shorter lists. 
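+// (Lowering the minimum length of a list input only loosens validation, so
+// NodeDefs built against the stricter old minimum still satisfy the new one.)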
+REGISTER_OP("ShorterAnyList") + .Input("a: T") + .Output("ndef: string") + .Attr("T: list(type) >= 1"); +REGISTER_KERNEL_BUILDER(Name("ShorterAnyList").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, ShorterAnyList) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("ShorterAnyList") + .Input("a: T") + .Output("ndef: string") + .Attr("T: list(type) >= 2") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("shorter_any_list", &old_op.op_def) + .Input(FakeInput(2, DT_BOOL)) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("shorter_any_list = ShorterAnyList[T=[DT_BOOL, DT_BOOL]](a, a:1)", + Result()); +} + +REGISTER_OP("ShorterSameList") + .Input("a: N * int32") + .Output("ndef: string") + .Attr("N: int >= 1"); +REGISTER_KERNEL_BUILDER(Name("ShorterSameList").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, ShorterSameList) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("ShorterSameList") + .Input("a: N * int32") + .Output("ndef: string") + .Attr("N: int >= 2") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("shorter_same_list", &old_op.op_def) + .Input(FakeInput(2)) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("shorter_same_list = ShorterSameList[N=2](a, a:1)", Result()); +} + +// Can remove a restriction to an attr + +REGISTER_OP("AttrRemoveRestriction").Attr("t: type").Output("ndef: string"); +REGISTER_KERNEL_BUILDER(Name("AttrRemoveRestriction").Device(DEVICE_CPU), + TestKernel); + +TEST_F(OpCompatibilityTest, AttrRemoveRestriction) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AttrRemoveRestriction") + .Attr("t: {int32,int64}") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("remove_restriction", &old_op.op_def) + .Attr("t", DT_INT32) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("remove_restriction = AttrRemoveRestriction[t=DT_INT32]()", + Result()); +} + +// Can make the restrictions on an attr less restrictive. + +REGISTER_OP("AttrLessRestrictive") + .Attr("t: {int32, int64, bool}") + .Output("ndef: string"); +REGISTER_KERNEL_BUILDER(Name("AttrLessRestrictive").Device(DEVICE_CPU), + TestKernel); + +TEST_F(OpCompatibilityTest, AttrLessRestrictive) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AttrLessRestrictive") + .Attr("t: {int32, int64}") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("less_restrictive", &old_op.op_def) + .Attr("t", DT_INT32) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("less_restrictive = AttrLessRestrictive[t=DT_INT32]()", Result()); +} + +// Can remove a minimum from an attr. + +REGISTER_OP("AttrRemoveMin").Attr("n: int").Output("ndef: string"); +REGISTER_KERNEL_BUILDER(Name("AttrRemoveMin").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, AttrRemoveMin) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AttrRemoveMin") + .Attr("n: int >= 3") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("remove_min", &old_op.op_def) + .Attr("n", 4) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("remove_min = AttrRemoveMin[n=4]()", Result()); +} + +// Can lower the minimum on an attr. 
+ +REGISTER_OP("AttrLowerMin").Attr("n: int >= 1").Output("ndef: string"); +REGISTER_KERNEL_BUILDER(Name("AttrLowerMin").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, AttrLowerMin) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AttrLowerMin") + .Attr("n: int >= 3") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("lower_min", &old_op.op_def) + .Attr("n", 4) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); + EXPECT_EQ("lower_min = AttrLowerMin[n=4]()", Result()); +} + +// Can make a ref input into a non-ref input. + +REGISTER_OP("InputRemoveRef").Input("i: int32").Output("ndef: string"); +REGISTER_KERNEL_BUILDER(Name("InputRemoveRef").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, InputRemoveRef) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("InputRemoveRef") + .Input("i: Ref(int32)") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("remove_input_ref", &old_op.op_def) + .Input(FakeInput()) + .Finalize(node_def())); + ExpectSuccess(old_op.op_def); +} + +// Can make a non-ref output into a ref output. + +REGISTER_OP("OutputAddRef").Output("o: Ref(int32)").Output("ndef: string"); +REGISTER_KERNEL_BUILDER(Name("OutputAddRef").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, OutputAddRef) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("OutputAddRef") + .Output("o: int32") + .Output("ndef: string") + .Finalize(&old_op)); + TF_ASSERT_OK( + NodeDefBuilder("add_output_ref", &old_op.op_def).Finalize(node_def())); + ExpectSuccess(old_op.op_def); +} + +// Negative tests ------------------------------------------------------------- + +// Can't remove an attr. +REGISTER_OP("RemoveAttr"); + +TEST_F(OpCompatibilityTest, RemoveAttrFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("RemoveAttr").Attr("a: int").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) + .Attr("a", 3) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, "NodeDef mentions attr 'a' not in", + "Attr 'a' removed"); +} + +// Can't add an attr without a default. +REGISTER_OP("AddAttrNoDefault").Attr("a: int"); + +TEST_F(OpCompatibilityTest, AddAttrNoDefaultFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddAttrNoDefault").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def())); + ExpectInvalid(old_op.op_def, "NodeDef missing attr 'a'", + "Attr 'a' added without default"); +} + +// Can't add a non-list input/output. +REGISTER_OP("AddSingleInput").Input("a: int32"); + +TEST_F(OpCompatibilityTest, AddSingleInputFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddSingleInput").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "expected inputs 'int32' do not match 0 inputs specified", + "Input signature mismatch '' vs. 'int32'"); +} + +// Can't add a list input/output without an empty default. 
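+// (A non-empty default would make old NodeDefs resolve to a non-empty input
+// signature even though they wire up no tensors for the new argument, as the
+// *BigDefault cases below demonstrate.)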
+ +REGISTER_OP("AddNIntsBigDefault").Input("a: N * int32").Attr("N: int = 1"); +REGISTER_OP("AddNSameBigDefault") + .Input("a: N * T") + .Attr("N: int = 1") + .Attr("T: type = DT_INT32"); +REGISTER_OP("AddListBigDefault") + .Input("a: T") + .Attr("T: list(type) = [DT_INT32]"); + +TEST_F(OpCompatibilityTest, AddNIntsBigDefaultFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddNIntsBigDefault").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "expected inputs 'int32' do not match 0 inputs specified", + "Input signature mismatch '' vs. 'int32'"); +} + +TEST_F(OpCompatibilityTest, AddNSameBigDefaultFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddNSameBigDefault").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "expected inputs 'int32' do not match 0 inputs specified", + "Input signature mismatch '' vs. 'int32'"); +} + +TEST_F(OpCompatibilityTest, AddListBigDefaultFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddListBigDefault").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def).Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "expected inputs 'int32' do not match 0 inputs specified", + "Input signature mismatch '' vs. 'int32'"); +} + +// Can't change the type of an input/output. + +REGISTER_OP("ChangeType").Input("a: float"); + +TEST_F(OpCompatibilityTest, ChangeTypeFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("ChangeType").Input("a: int32").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) + .Input(FakeInput()) + .Finalize(node_def())); + ExpectTypeMismatch(old_op.op_def, + "Input signature mismatch 'int32' vs. 'float'"); +} + +// Can't change the order of inputs/outputs. + +REGISTER_OP("ChangeOrder").Input("a: float").Input("b: int32"); + +TEST_F(OpCompatibilityTest, ChangeOrderFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("ChangeOrder") + .Input("b: int32") + .Input("a: float") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) + .Input(FakeInput()) + .Input(FakeInput()) + .Finalize(node_def())); + ExpectTypeMismatch( + old_op.op_def, + "Input signature mismatch 'int32, float' vs. 'float, int32'"); +} + +// Can't remove inputs/outputs. + +REGISTER_OP("RemoveInput"); + +TEST_F(OpCompatibilityTest, RemoveInputFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("RemoveInput").Input("a: float").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) + .Input(FakeInput()) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "expected inputs '' do not match 1 inputs specified", + "Input signature mismatch 'float' vs. ''"); +} + +// Can't change the type of an attr. + +REGISTER_OP("ChangeAttrType").Attr("a: int"); + +TEST_F(OpCompatibilityTest, ChangeAttrTypeFails) { + OpRegistrationData old_op; + TF_ASSERT_OK( + OpDefBuilder("ChangeAttrType").Attr("a: bool").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) + .Attr("a", true) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, "value with type 'bool' when 'int' expected", + "Attr 'a' changed type 'bool' -> 'int'"); +} + +// Can't change an attr from a list. 
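+// (The value recorded in old NodeDefs is a list AttrValue, which no longer
+// validates against the scalar attr type the new OpDef expects; AttrToList
+// below covers the reverse direction.)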
+ +REGISTER_OP("AttrFromList").Attr("a: int"); + +TEST_F(OpCompatibilityTest, AttrFromListFails) { + OpRegistrationData old_op; + TF_ASSERT_OK( + OpDefBuilder("AttrFromList").Attr("a: list(int)").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) + .Attr("a", {5}) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "value with type 'list(int)' when 'int' expected", + "Attr 'a' changed type 'list(int)' -> 'int'"); +} + +// Can't change an attr to a list. + +REGISTER_OP("AttrToList").Attr("a: list(int)"); + +TEST_F(OpCompatibilityTest, AttrToListFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AttrToList").Attr("a: int").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) + .Attr("a", 5) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "value with type 'int' when 'list(int)' expected", + "Attr 'a' changed type 'int' -> 'list(int)'"); +} + +// Can't change an input from polymorphic to a list of any type. + +REGISTER_OP("PolymorphicToAnyList").Input("a: T").Attr("T: list(type)"); + +TEST_F(OpCompatibilityTest, PolymorphicToAnyListFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("PolymorphicToAnyList") + .Input("a: T") + .Attr("T: type") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) + .Input(FakeInput(DT_INT32)) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "value with type 'type' when 'list(type)' expected", + "Attr 'T' changed type 'type' -> 'list(type)'"); +} + +// Can't change an input from a list of the same type to a list of any type. + +REGISTER_OP("SameToAnyList") + .Input("a: T") + .Attr("T: list(type)") + .Attr("N: int = 1"); + +TEST_F(OpCompatibilityTest, SameToAnyListFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("SameToAnyList") + .Input("a: N * T") + .Attr("T: type") + .Attr("N: int") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("fails", &old_op.op_def) + .Input(FakeInput(1, DT_INT32)) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "value with type 'type' when 'list(type)' expected", + "Attr 'T' changed type 'type' -> 'list(type)'"); +} + +// Can't add a restriction to an attr + +REGISTER_OP("AttrAddRestriction").Attr("t: {int32, int64}"); + +TEST_F(OpCompatibilityTest, AttrAddRestrictionFails) { + OpRegistrationData old_op; + TF_ASSERT_OK( + OpDefBuilder("AttrAddRestriction").Attr("t: type").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("add_restriction", &old_op.op_def) + .Attr("t", DT_BOOL) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "Value for attr 't' of bool is not in the list of allowed " + "values: int32, int64", + "Attr 't' has a stricter set of allowed values; from " + "no restriction to [DT_INT32, DT_INT64]"); +} + +// Can't make the restrictions on an attr more restrictive. + +REGISTER_OP("AttrMoreRestrictive").Attr("t: {int32, int64}"); + +TEST_F(OpCompatibilityTest, AttrMoreRestrictiveFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AttrMoreRestrictive") + .Attr("t: {int32, int64, bool}") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("more_restrictive", &old_op.op_def) + .Attr("t", DT_BOOL) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "Value for attr 't' of bool is not in the list of allowed " + "values: int32, int64", + "Attr 't' has a stricter set of allowed values; from " + "[DT_INT32, DT_INT64, DT_BOOL] to [DT_INT32, DT_INT64]"); +} + +// Can't add a minimum to an attr. 
+ +REGISTER_OP("AttrAddMin").Attr("n: int >= 3"); + +TEST_F(OpCompatibilityTest, AttrAddMinFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AttrAddMin").Attr("n: int").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("add_min", &old_op.op_def) + .Attr("n", 2) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "Value for attr 'n' of 2 must be at least minimum 3", + "Attr 'n' has a higher minimum; from no minimum to 3"); +} + +// Can't raise the minimum on an attr. + +REGISTER_OP("AttrRaiseMin").Attr("n: int >= 3"); + +TEST_F(OpCompatibilityTest, AttrRaiseMinFails) { + OpRegistrationData old_op; + TF_ASSERT_OK( + OpDefBuilder("AttrRaiseMin").Attr("n: int >= 1").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("raise_min", &old_op.op_def) + .Attr("n", 2) + .Finalize(node_def())); + ExpectInvalid(old_op.op_def, + "Value for attr 'n' of 2 must be at least minimum 3", + "Attr 'n' has a higher minimum; from 1 to 3"); +} + +// Can't make a non-ref input into a ref input. + +REGISTER_OP("InputAddRef").Input("i: Ref(int32)"); + +TEST_F(OpCompatibilityTest, InputAddRefFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("InputAddRef").Input("i: int32").Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("add_input_ref", &old_op.op_def) + .Input(FakeInput()) + .Finalize(node_def())); + ExpectTypeMismatch(old_op.op_def, "Input 0 changed from non-ref to ref"); +} + +// Can't make a ref output into a non-ref output. + +REGISTER_OP("OutputRemoveRef").Output("o: int32"); + +TEST_F(OpCompatibilityTest, OutputRemoveRefFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("OutputRemoveRef") + .Output("o: Ref(int32)") + .Finalize(&old_op)); + TF_ASSERT_OK( + NodeDefBuilder("remove_output_ref", &old_op.op_def).Finalize(node_def())); + ExpectTypeMismatch(old_op.op_def, "Output 0 changed from ref to non-ref"); +} + +// Can't rename an output, to avoid problems in FunctionDefs. + +REGISTER_OP("RenameOutput").Output("new: int32"); + +TEST_F(OpCompatibilityTest, RenameOutputFails) { + OpRegistrationData old_op; + TF_ASSERT_OK( + OpDefBuilder("RenameOutput").Output("old: int32").Finalize(&old_op)); + TF_ASSERT_OK( + NodeDefBuilder("rename_output", &old_op.op_def).Finalize(node_def())); + ExpectRenameFailure(old_op.op_def, + "Output signature mismatch 'old:int32' vs. 'new:int32'"); +} + +REGISTER_OP("RenameNOutputs").Output("new: N*int32").Attr("N: int"); + +TEST_F(OpCompatibilityTest, RenameNOutputsFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("RenameNOutputs") + .Output("old: N*int32") + .Attr("N: int") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("rename_n_outputs", &old_op.op_def) + .Attr("N", 2) + .Finalize(node_def())); + ExpectRenameFailure( + old_op.op_def, + "Output signature mismatch 'old:N * int32' vs. 'new:N * int32'"); +} + +REGISTER_OP("RenameOutputList").Output("new: T").Attr("T: list(type)"); + +TEST_F(OpCompatibilityTest, RenameOutputListFails) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("RenameOutputList") + .Output("old: T") + .Attr("T: list(type)") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("rename_output_list", &old_op.op_def) + .Attr("T", {DT_INT32, DT_FLOAT}) + .Finalize(node_def())); + ExpectRenameFailure(old_op.op_def, + "Output signature mismatch 'old:T' vs. 'new:T'"); +} + +// Changing an attr's default is not technically illegal, but should +// be forbidden if it the attr ever didn't exist since it likely +// affects semantics. 
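+// (Graphs serialized before the attr existed rely on the default being filled
+// in at load time, so changing the default would silently change what those
+// old graphs compute.)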
+ +} // namespace +} // namespace tensorflow diff --git a/op_def.proto b/op_def.proto new file mode 100644 index 0000000000000000000000000000000000000000..ba545a19949e5574086756dc2092033341be4b30 --- /dev/null +++ b/op_def.proto @@ -0,0 +1,160 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/attr_value.proto"; +import "tensorflow/core/framework/types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +// LINT.IfChange +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + // TODO(josh11b): bool is_optional? + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. 
Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + // TODO(josh11b): Implement that optimization. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. + string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/op_def_builder.cc b/op_def_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..962bc11ccbd2b9abdd4ce26dc3e75c45862cdc74 --- /dev/null +++ b/op_def_builder.cc @@ -0,0 +1,618 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_def_builder.h" + +#include +#include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +using ::tensorflow::strings::Scanner; + +namespace tensorflow { + +namespace { + +string AttrError(StringPiece orig, const string& op_name) { + return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name); +} + +bool ConsumeAttrName(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .OneLiteral(":") + .AnySpace() + .GetResult(sp, out); +} + +bool ConsumeListPrefix(StringPiece* sp) { + return Scanner(*sp) + .OneLiteral("list") + .AnySpace() + .OneLiteral("(") + .AnySpace() + .GetResult(sp); +} + +bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) { + const string quote_str(1, quote_ch); + return Scanner(*sp) + .OneLiteral(quote_str.c_str()) + .RestartCapture() + .ScanEscapedUntil(quote_ch) + .StopCapture() + .OneLiteral(quote_str.c_str()) + .AnySpace() + .GetResult(sp, out); +} + +bool ConsumeAttrType(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .Many(Scanner::LOWERLETTER_DIGIT) + .StopCapture() + .AnySpace() + .GetResult(sp, out); +} + +bool ConsumeAttrNumber(StringPiece* sp, int64* out) { + Scanner scan(*sp); + StringPiece match; + StringPiece remaining; + + scan.AnySpace().RestartCapture(); + if (scan.Peek() == '-') { + scan.OneLiteral("-"); + } + if (!scan.Many(Scanner::DIGIT) + .StopCapture() + .AnySpace() + .GetResult(&remaining, &match)) { + return false; + } + int64 value = 0; + if (!strings::safe_strto64(match, &value)) { + return false; + } + *out = value; + *sp = remaining; + return true; +} + +#define VERIFY(expr, ...) 
\ + do { \ + if (!(expr)) { \ + errors->push_back( \ + strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name()))); \ + return; \ + } \ + } while (false) + +bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) { + auto capture_begin = sp->begin(); + if (sp->Consume("numbertype") || sp->Consume("numerictype") || + sp->Consume("quantizedtype") || sp->Consume("realnumbertype") || + sp->Consume("realnumberictype")) { + *out = StringPiece(capture_begin, sp->begin() - capture_begin); + return true; + } + return false; +} + +bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) { + if (type_string == "numbertype" || type_string == "numerictype") { + for (DataType dt : NumberTypes()) { + allowed->mutable_list()->add_type(dt); + } + } else if (type_string == "quantizedtype") { + for (DataType dt : QuantizedTypes()) { + allowed->mutable_list()->add_type(dt); + } + } else if (type_string == "realnumbertype" || + type_string == "realnumerictype") { + for (DataType dt : RealNumberTypes()) { + allowed->mutable_list()->add_type(dt); + } + } else { + return false; + } + return true; +} + +void FinalizeAttr(StringPiece spec, OpDef* op_def, + std::vector* errors) { + OpDef::AttrDef* attr = op_def->add_attr(); + StringPiece orig(spec); + + // Parse ":" at the beginning. + StringPiece tmp_name; + VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing ':'"); + attr->set_name(tmp_name.data(), tmp_name.size()); + + // Read "" or "list()". + bool is_list = ConsumeListPrefix(&spec); + string type; + StringPiece type_string; // Used if type == "type" + if (spec.Consume("string")) { + type = "string"; + } else if (spec.Consume("int")) { + type = "int"; + } else if (spec.Consume("float")) { + type = "float"; + } else if (spec.Consume("bool")) { + type = "bool"; + } else if (spec.Consume("type")) { + type = "type"; + } else if (spec.Consume("shape")) { + type = "shape"; + } else if (spec.Consume("tensor")) { + type = "tensor"; + } else if (spec.Consume("func")) { + type = "func"; + } else if (ConsumeCompoundAttrType(&spec, &type_string)) { + type = "type"; + AttrValue* allowed = attr->mutable_allowed_values(); + VERIFY(ProcessCompoundType(type_string, allowed), + "Expected to see a compound type, saw: ", type_string); + } else if (spec.Consume("{")) { + // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }" + AttrValue* allowed = attr->mutable_allowed_values(); + str_util::RemoveLeadingWhitespace(&spec); + if (spec.starts_with("\"") || spec.starts_with("'")) { + type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }" + while (true) { + StringPiece escaped_string; + VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) || + ConsumeQuotedString('\'', &spec, &escaped_string), + "Trouble parsing allowed string at '", spec, "'"); + string unescaped; + string error; + VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error), + "Trouble unescaping \"", escaped_string, + "\", got error: ", error); + allowed->mutable_list()->add_s(unescaped); + if (spec.Consume(",")) { + str_util::RemoveLeadingWhitespace(&spec); + if (spec.Consume("}")) break; // Allow ending with ", }". + } else { + VERIFY(spec.Consume("}"), + "Expected , or } after strings in list, not: '", spec, "'"); + break; + } + } + } else { // "{ bool, numbertype, string }" + type = "type"; + while (true) { + VERIFY(ConsumeAttrType(&spec, &type_string), + "Trouble parsing type string at '", spec, "'"); + if (ProcessCompoundType(type_string, allowed)) { + // Processed a compound type. 
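+          // e.g. "numbertype" inside "{ bool, numbertype, string }" has just
+          // been expanded into the full set of numeric DataTypes in
+          // allowed_values; nothing more to do for this element.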
+ } else { + DataType dt; + VERIFY(DataTypeFromString(type_string, &dt), + "Unrecognized type string '", type_string, "'"); + allowed->mutable_list()->add_type(dt); + } + if (spec.Consume(",")) { + str_util::RemoveLeadingWhitespace(&spec); + if (spec.Consume("}")) break; // Allow ending with ", }". + } else { + VERIFY(spec.Consume("}"), + "Expected , or } after types in list, not: '", spec, "'"); + break; + } + } + } + } else { // if spec.Consume("{") + VERIFY(false, "Trouble parsing type string at '", spec, "'"); + } + str_util::RemoveLeadingWhitespace(&spec); + + // Write the type into *attr. + if (is_list) { + VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'"); + str_util::RemoveLeadingWhitespace(&spec); + attr->set_type(strings::StrCat("list(", type, ")")); + } else { + attr->set_type(type); + } + + // Read optional minimum constraint at the end. + if ((is_list || type == "int") && spec.Consume(">=")) { + int64 min_limit = -999; + VERIFY(ConsumeAttrNumber(&spec, &min_limit), + "Could not parse integer lower limit after '>=', found '", spec, + "' instead"); + attr->set_has_minimum(true); + attr->set_minimum(min_limit); + } + + // Parse default value, if present. + if (spec.Consume("=")) { + str_util::RemoveLeadingWhitespace(&spec); + VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()), + "Could not parse default value '", spec, "'"); + } else { + VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end"); + } +} + +#undef VERIFY + +string InOutError(bool is_output, StringPiece orig, const string& op_name) { + return strings::StrCat(" from ", is_output ? "Output" : "Input", "(\"", orig, + "\") for Op ", op_name); +} + +bool ConsumeInOutName(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .One(Scanner::LOWERLETTER) + .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .OneLiteral(":") + .AnySpace() + .GetResult(sp, out); +} + +bool ConsumeInOutRefOpen(StringPiece* sp) { + return Scanner(*sp) + .OneLiteral("Ref") + .AnySpace() + .OneLiteral("(") + .AnySpace() + .GetResult(sp); +} + +bool ConsumeInOutRefClose(StringPiece* sp) { + return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp); +} + +bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .GetResult(sp, out); +} + +bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .OneLiteral("*") + .AnySpace() + .RestartCapture() + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .GetResult(sp, out); +} + +#define VERIFY(expr, ...) \ + do { \ + if (!(expr)) { \ + errors->push_back(strings::StrCat( \ + __VA_ARGS__, InOutError(is_output, orig, op_def->name()))); \ + return; \ + } \ + } while (false) + +void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def, + std::vector* errors) { + OpDef::ArgDef* arg = + is_output ? op_def->add_output_arg() : op_def->add_input_arg(); + + StringPiece orig(spec); + + // Parse ":" at the beginning. + StringPiece tmp_name; + VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'"); + arg->set_name(tmp_name.data(), tmp_name.size()); + + // Detect "Ref(...)". + if (ConsumeInOutRefOpen(&spec)) { + arg->set_is_ref(true); + } + + { // Parse "" or "*". 
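+    // The type expression is either a single type or attr name (e.g. "int32"
+    // or "T"), or the repeated form "<number-attr> * <type-or-attr>"
+    // (e.g. "N * T").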
+ StringPiece first, second, type_or_attr; + VERIFY(ConsumeInOutNameOrType(&spec, &first), + "Trouble parsing either a type or an attr name at '", spec, "'"); + if (ConsumeInOutTimesType(&spec, &second)) { + arg->set_number_attr(first.data(), first.size()); + type_or_attr = second; + } else { + type_or_attr = first; + } + DataType dt; + if (DataTypeFromString(type_or_attr, &dt)) { + arg->set_type(dt); + } else { + const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def); + VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'"); + if (attr->type() == "type") { + arg->set_type_attr(type_or_attr.data(), type_or_attr.size()); + } else { + VERIFY(attr->type() == "list(type)", "Reference to attr '", + type_or_attr, "' with type ", attr->type(), + " that isn't type or list(type)"); + arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size()); + } + } + } + + // Closing ) for Ref(. + if (arg->is_ref()) { + VERIFY(ConsumeInOutRefClose(&spec), + "Did not find closing ')' for 'Ref(', instead found: '", spec, "'"); + } + + // Should not have anything else. + VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end"); + + // Int attrs that are the length of an input or output get a default + // minimum of 1. + if (!arg->number_attr().empty()) { + OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def); + if (attr != nullptr && !attr->has_minimum()) { + attr->set_has_minimum(true); + attr->set_minimum(1); + } + } else if (!arg->type_list_attr().empty()) { + // If an input or output has type specified by a list(type) attr, + // it gets a default minimum of 1 as well. + OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def); + if (attr != nullptr && attr->type() == "list(type)" && + !attr->has_minimum()) { + attr->set_has_minimum(true); + attr->set_minimum(1); + } + } + + // If the arg's dtype is resource we should mark the op as stateful as it + // likely touches a resource manager. This deliberately doesn't cover inputs / + // outputs which resolve to resource via Attrs as those mostly operate on + // resource handles as an opaque type (as opposed to ops which explicitly take + // / produce resources). + if (arg->type() == DT_RESOURCE) { + op_def->set_is_stateful(true); + } +} + +#undef VERIFY + +int num_leading_spaces(StringPiece s) { + size_t i = 0; + while (i < s.size() && s[i] == ' ') { + ++i; + } + return i; +} + +bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) { + return Scanner(*sp) + .One(Scanner::LETTER) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .StopCapture() + .AnySpace() + .OneLiteral(":") + .AnySpace() + .GetResult(sp, out); +} + +bool IsDocNameColon(StringPiece s) { + return ConsumeDocNameColon(&s, nullptr /* out */); +} + +void FinalizeDoc(const string& text, OpDef* op_def, + std::vector* errors) { + std::vector lines = str_util::Split(text, '\n'); + + // Remove trailing spaces. + for (string& line : lines) { + str_util::StripTrailingWhitespace(&line); + } + + // First non-blank line -> summary. + int l = 0; + while (static_cast(l) < lines.size() && lines[l].empty()) ++l; + if (static_cast(l) < lines.size()) { + op_def->set_summary(lines[l]); + ++l; + } + while (static_cast(l) < lines.size() && lines[l].empty()) ++l; + + // Lines until we see name: -> description. + int start_l = l; + while (static_cast(l) < lines.size() && !IsDocNameColon(lines[l])) { + ++l; + } + int end_l = l; + // Trim trailing blank lines from the description. 
+ while (start_l < end_l && lines[end_l - 1].empty()) --end_l; + string desc = str_util::Join( + gtl::ArraySlice(lines.data() + start_l, end_l - start_l), "\n"); + if (!desc.empty()) op_def->set_description(desc); + + // name: description + // possibly continued on the next line + // if so, we remove the minimum indent + StringPiece name; + std::vector description; + while (static_cast(l) < lines.size()) { + description.clear(); + description.push_back(lines[l]); + ConsumeDocNameColon(&description.back(), &name); + ++l; + while (static_cast(l) < lines.size() && !IsDocNameColon(lines[l])) { + description.push_back(lines[l]); + ++l; + } + // Remove any trailing blank lines. + while (!description.empty() && description.back().empty()) { + description.pop_back(); + } + // Compute the minimum indent of all lines after the first. + int min_indent = -1; + for (size_t i = 1; i < description.size(); ++i) { + if (!description[i].empty()) { + int indent = num_leading_spaces(description[i]); + if (min_indent < 0 || indent < min_indent) min_indent = indent; + } + } + // Remove min_indent spaces from all lines after the first. + for (size_t i = 1; i < description.size(); ++i) { + if (!description[i].empty()) description[i].remove_prefix(min_indent); + } + // Concatenate lines into a single string. + const string complete(str_util::Join(description, "\n")); + + // Find name. + bool found = false; + for (int i = 0; !found && i < op_def->input_arg_size(); ++i) { + if (op_def->input_arg(i).name() == name) { + op_def->mutable_input_arg(i)->set_description(complete); + found = true; + } + } + for (int i = 0; !found && i < op_def->output_arg_size(); ++i) { + if (op_def->output_arg(i).name() == name) { + op_def->mutable_output_arg(i)->set_description(complete); + found = true; + } + } + for (int i = 0; !found && i < op_def->attr_size(); ++i) { + if (op_def->attr(i).name() == name) { + op_def->mutable_attr(i)->set_description(complete); + found = true; + } + } + if (!found) { + errors->push_back( + strings::StrCat("No matching input/output/attr for name '", name, + "' from Doc() for Op ", op_def->name())); + return; + } + } +} + +} // namespace + +OpDefBuilder::OpDefBuilder(StringPiece op_name) { + op_def()->set_name(op_name.ToString()); // NOLINT +} + +OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) { + attrs_.emplace_back(spec.data(), spec.size()); + return *this; +} + +OpDefBuilder& OpDefBuilder::Input(StringPiece spec) { + inputs_.emplace_back(spec.data(), spec.size()); + return *this; +} + +OpDefBuilder& OpDefBuilder::Output(StringPiece spec) { + outputs_.emplace_back(spec.data(), spec.size()); + return *this; +} + +#ifndef TF_LEAN_BINARY +OpDefBuilder& OpDefBuilder::Doc(StringPiece text) { + if (!doc_.empty()) { + errors_.push_back( + strings::StrCat("Extra call to Doc() for Op ", op_def()->name())); + } else { + doc_.assign(text.data(), text.size()); + } + return *this; +} +#endif + +OpDefBuilder& OpDefBuilder::SetIsCommutative() { + op_def()->set_is_commutative(true); + return *this; +} + +OpDefBuilder& OpDefBuilder::SetIsAggregate() { + op_def()->set_is_aggregate(true); + return *this; +} + +OpDefBuilder& OpDefBuilder::SetIsStateful() { + op_def()->set_is_stateful(true); + return *this; +} + +OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() { + op_def()->set_allows_uninitialized_input(true); + return *this; +} + +OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) { + if (op_def()->has_deprecation()) { + errors_.push_back( + strings::StrCat("Deprecated called twice 
for Op ", op_def()->name())); + } else { + OpDeprecation* deprecation = op_def()->mutable_deprecation(); + deprecation->set_version(version); + deprecation->set_explanation(explanation.ToString()); + } + return *this; +} + +OpDefBuilder& OpDefBuilder::SetShapeFn( + Status (*fn)(shape_inference::InferenceContext*)) { + if (op_reg_data_.shape_inference_fn != nullptr) { + errors_.push_back( + strings::StrCat("SetShapeFn called twice for Op ", op_def()->name())); + } else { + op_reg_data_.shape_inference_fn = OpShapeInferenceFn(fn); + } + return *this; +} + +Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const { + std::vector errors = errors_; + *op_reg_data = op_reg_data_; + + OpDef* op_def = &op_reg_data->op_def; + for (StringPiece attr : attrs_) { + FinalizeAttr(attr, op_def, &errors); + } + for (StringPiece input : inputs_) { + FinalizeInputOrOutput(input, false, op_def, &errors); + } + for (StringPiece output : outputs_) { + FinalizeInputOrOutput(output, true, op_def, &errors); + } + FinalizeDoc(doc_, op_def, &errors); + + if (errors.empty()) return Status::OK(); + return errors::InvalidArgument(str_util::Join(errors, "\n")); +} + +} // namespace tensorflow diff --git a/op_def_builder.h b/op_def_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..fbfb4018aadb7d58a72ffa514b0d5be2384e08ea --- /dev/null +++ b/op_def_builder.h @@ -0,0 +1,165 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Class and associated machinery for specifying an Op's OpDef and shape +// inference function for Op registration. + +#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_ +#define TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_ + +#include +#include +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +namespace shape_inference { +class InferenceContext; +} +typedef std::function + OpShapeInferenceFn; + +struct OpRegistrationData { + public: + OpRegistrationData() {} + OpRegistrationData(const OpDef& def) : op_def(def) {} + OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn, + bool is_function = false) + : op_def(def), shape_inference_fn(fn), is_function_op(is_function) {} + + OpDef op_def; + OpShapeInferenceFn shape_inference_fn; + bool is_function_op = false; +}; + +// Builder class passed to the REGISTER_OP() macro. +class OpDefBuilder { + public: + // Constructs an OpDef with just the name field set. + explicit OpDefBuilder(StringPiece op_name); + + // Adds an attr to this OpDefBuilder (and returns *this). 
The spec has
+ // format "<name>:<type>" or "<name>:<type>=<default>"
+ // where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]*
+ // (by convention only using capital letters for attrs that can be inferred)
+ // <type> can be:
+ // "string", "int", "float", "bool", "type", "shape", or "tensor"
+ // "numbertype", "realnumbertype", "quantizedtype"
+ // (meaning "type" with a restriction on valid values)
+ // "{int32,int64}" or "{realnumbertype,quantizedtype,string}"
+ // (meaning "type" with a restriction containing unions of value types)
+ // "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
+ // (meaning "string" with a restriction on valid values)
+ // "list(string)", ..., "list(tensor)", "list(numbertype)", ...
+ // (meaning lists of the above types)
+ // "int >= 2" (meaning "int" with a restriction on valid values)
+ // "list(string) >= 2", "list(int) >= 2"
+ // (meaning "list(string)" / "list(int)" with length at least 2)
+ // <default>, if included, should use the Proto text format
+ // of <type>. For lists use [a, b, c] format.
+ //
+ // Note that any attr specifying the length of an input or output will
+ // get a default minimum of 1 unless the >= # syntax is used.
+ //
+ // TODO(josh11b): Perhaps support restrictions and defaults as optional
+ // extra arguments to Attr() instead of encoding them in the spec string.
+ // TODO(josh11b): Would like to have better dtype handling for tensor attrs:
+ // * Ability to say the type of an input/output matches the type of
+ // the tensor.
+ // * Ability to restrict the type of the tensor like the existing
+ // restrictions for type attrs.
+ // Perhaps by linking the type of the tensor to a type attr?
+ OpDefBuilder& Attr(StringPiece spec);
+
+ // Adds an input or output to this OpDefBuilder (and returns *this).
+ // The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
+ // where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
+ // * For a single tensor: <type>
+ // * For a sequence of tensors with the same type: <number> * <type>
+ // * For a sequence of tensors with different types: <type-list>
+ // Where:
+ // <type> is either one of "float", "int32", "string", ...
+ // or the name of an attr (see above) with type "type".
+ // <number> is the name of an attr with type "int".
+ // <type-list> is the name of an attr with type "list(type)".
+ // TODO(josh11b): Indicate Ref() via an optional argument instead of
+ // in the spec?
+ // TODO(josh11b): SparseInput() and SparseOutput() matching the Python
+ // handling?
+ OpDefBuilder& Input(StringPiece spec);
+ OpDefBuilder& Output(StringPiece spec);
+
+ // Turns on the indicated boolean flag in this OpDefBuilder (and
+ // returns *this).
+ OpDefBuilder& SetIsCommutative();
+ OpDefBuilder& SetIsAggregate();
+ OpDefBuilder& SetIsStateful();
+ OpDefBuilder& SetAllowsUninitializedInput();
+
+ // Deprecate the op at a certain GraphDef version.
+ OpDefBuilder& Deprecated(int version, StringPiece explanation);
+
+ // Adds docs to this OpDefBuilder (and returns *this).
+ // Docs have the format:
+ // <1-line summary>
+ // <rest of the description>
+ // <name>: <description of name>
+ // <name>: <description of name>
+ // <if long, indent the description on subsequent lines>
+ // Where <name> is the name of an attr, input, or output. Please
+ // wrap docs at 72 columns so that it may be indented in the
+ // generated output. For tensor inputs or outputs (not attrs), you
+ // may start the description with an "=" (like name:= <description>)
+ // to suppress the automatically-generated type documentation in
+ // generated output.
+#ifndef TF_LEAN_BINARY
+ OpDefBuilder& Doc(StringPiece text);
+#else
+ OpDefBuilder& Doc(StringPiece text) { return *this; }
+#endif
+
+ // Sets the shape function to be used for shape inference.
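+ //
+ // A typical registration looks roughly like the following (an illustrative
+ // sketch only; "MyExampleOp" and the identity shape function are
+ // hypothetical, and REGISTER_OP comes from op.h, not this header):
+ //   REGISTER_OP("MyExampleOp")
+ //       .Input("x: float")
+ //       .Output("y: float")
+ //       .SetShapeFn([](shape_inference::InferenceContext* c) {
+ //         c->set_output(0, c->input(0));  // output shape == input shape
+ //         return Status::OK();
+ //       });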
+ // + // Note that currently (October 2016), python code still requires a + // RegisterShape call to invoke this; see call_cpp_shape_fn in + // python/framework/common_shapes.py + OpDefBuilder& SetShapeFn(Status (*fn)(shape_inference::InferenceContext*)); + + // Sets op_reg_data->op_def to the requested OpDef and + // op_reg_data->shape_inference_fn to the requested shape inference function, + // or returns an error. + // Must be called after all of the above methods. + // + // Note that OpDefBuilder only reports parsing errors. You should also + // call ValidateOpDef() to detect other problems. + Status Finalize(OpRegistrationData* op_reg_data) const; + + private: + OpDef* op_def() { return &op_reg_data_.op_def; } + + OpRegistrationData op_reg_data_; + std::vector attrs_; + std::vector inputs_; + std::vector outputs_; + string doc_; + std::vector errors_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_ diff --git a/op_def_builder_test.cc b/op_def_builder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9b24e3aa00425321eda2e196b1e7b243a552c730 --- /dev/null +++ b/op_def_builder_test.cc @@ -0,0 +1,637 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op_def_builder.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +static void CanonicalizeAttrTypeListOrder(OpDef* def) { + for (int i = 0; i < def->attr_size(); i++) { + AttrValue* a = def->mutable_attr(i)->mutable_allowed_values(); + std::sort(a->mutable_list()->mutable_type()->begin(), + a->mutable_list()->mutable_type()->end()); + } +} + +class OpDefBuilderTest : public ::testing::Test { + protected: + OpDefBuilder b() { return OpDefBuilder("Test"); } + + void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto, + OpShapeInferenceFn* shape_fn_out = nullptr) { + OpRegistrationData op_reg_data; + Status status = builder.Finalize(&op_reg_data); + TF_EXPECT_OK(status); + OpDef& op_def = op_reg_data.op_def; + if (status.ok()) { + OpDef expected; + protobuf::TextFormat::ParseFromString( + strings::StrCat("name: 'Test' ", proto), &expected); + // Allow different orderings + CanonicalizeAttrTypeListOrder(&op_def); + CanonicalizeAttrTypeListOrder(&expected); + EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString()); + + if (shape_fn_out) { + *shape_fn_out = op_reg_data.shape_inference_fn; + } + } + } + + void ExpectOrdered(const OpDefBuilder& builder, StringPiece proto) { + OpRegistrationData op_reg_data; + Status status = builder.Finalize(&op_reg_data); + TF_EXPECT_OK(status); + OpDef& op_def = op_reg_data.op_def; + if (status.ok()) { + OpDef expected; + protobuf::TextFormat::ParseFromString( + strings::StrCat("name: 'Test' ", proto), &expected); + EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString()); + } + } + + void ExpectFailure(const OpDefBuilder& builder, const string& error) { + OpRegistrationData op_reg_data; + Status status = builder.Finalize(&op_reg_data); + EXPECT_FALSE(status.ok()); + if (!status.ok()) { + EXPECT_EQ(status.error_message(), error); + } + } +}; + +TEST_F(OpDefBuilderTest, Attr) { + ExpectSuccess(b().Attr("a:string"), "attr: { name: 'a' type: 'string' }"); + ExpectSuccess(b().Attr("A: int"), "attr: { name: 'A' type: 'int' }"); + ExpectSuccess(b().Attr("a1 :float"), "attr: { name: 'a1' type: 'float' }"); + ExpectSuccess(b().Attr("a_a : bool"), "attr: { name: 'a_a' type: 'bool' }"); + ExpectSuccess(b().Attr("aB : type"), "attr: { name: 'aB' type: 'type' }"); + ExpectSuccess(b().Attr("aB_3\t: shape"), + "attr: { name: 'aB_3' type: 'shape' }"); + ExpectSuccess(b().Attr("t: tensor"), "attr: { name: 't' type: 'tensor' }"); + ExpectSuccess(b().Attr("XYZ\t:\tlist(type)"), + "attr: { name: 'XYZ' type: 'list(type)' }"); + ExpectSuccess(b().Attr("f: func"), "attr { name: 'f' type: 'func'}"); +} + +TEST_F(OpDefBuilderTest, AttrFailure) { + ExpectFailure( + b().Attr("_:string"), + "Trouble parsing ':' from Attr(\"_:string\") for Op Test"); + ExpectFailure( + b().Attr("9:string"), + "Trouble parsing ':' from Attr(\"9:string\") for Op Test"); + ExpectFailure(b().Attr(":string"), + "Trouble parsing ':' from Attr(\":string\") for Op Test"); + ExpectFailure(b().Attr("string"), + "Trouble parsing ':' 
from Attr(\"string\") for Op Test"); + ExpectFailure(b().Attr("a:invalid"), + "Trouble parsing type string at 'invalid' from " + "Attr(\"a:invalid\") for Op Test"); + ExpectFailure( + b().Attr("b:"), + "Trouble parsing type string at '' from Attr(\"b:\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, AttrWithRestrictions) { + // Types with restrictions. + ExpectSuccess( + b().Attr("a:numbertype"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " + "DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, " + "DT_QINT32, DT_UINT32, DT_UINT64, DT_BFLOAT16] } } }"); + ExpectSuccess( + b().Attr("a:{numbertype, variant}"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " + "DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, " + "DT_QINT32, DT_UINT32, DT_UINT64, DT_BFLOAT16, DT_VARIANT] } } }"); + ExpectSuccess(b().Attr("a:realnumbertype"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, " + "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64, " + "DT_BFLOAT16] } } }"); + ExpectSuccess(b().Attr("a:{realnumbertype, variant , string, }"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, " + "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64, " + "DT_BFLOAT16, DT_VARIANT, DT_STRING] } } }"); + ExpectSuccess(b().Attr("a:quantizedtype"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }"); + ExpectSuccess(b().Attr("a:{quantizedtype ,string}"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16, " + "DT_STRING]} } }"); + ExpectSuccess(b().Attr("a:{string,int32}"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_STRING, DT_INT32] } } }"); + ExpectSuccess(b().Attr("a: { float , complex64 } "), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_FLOAT, DT_COMPLEX64] } } }"); + ExpectSuccess(b().Attr("a: {float, complex64,} "), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_FLOAT, DT_COMPLEX64] } }"); + ExpectSuccess(b().Attr(R"(a: { "X", "yz" })"), + "attr: { name: 'a' type: 'string' allowed_values { list { s: " + "['X', 'yz'] } } }"); + ExpectSuccess(b().Attr(R"(a: { "X", "yz", })"), + "attr: { name: 'a' type: 'string' allowed_values { list { s: " + "['X', 'yz'] } } }"); + ExpectSuccess( + b().Attr("i: int >= -5"), + "attr: { name: 'i' type: 'int' has_minimum: true minimum: -5 }"); + ExpectSuccess(b().Attr("i: int >= 9223372036854775807"), + ("attr: { name: 'i' type: 'int' has_minimum: true " + "minimum: 9223372036854775807 }")); + ExpectSuccess(b().Attr("i: int >= -9223372036854775808"), + ("attr: { name: 'i' type: 'int' has_minimum: true " + "minimum: -9223372036854775808 }")); +} + +TEST_F(OpDefBuilderTest, AttrRestrictionFailure) { + ExpectFailure( + b().Attr("a:{}"), + "Trouble parsing type string at '}' from Attr(\"a:{}\") for Op Test"); + ExpectFailure( + b().Attr("a:{,}"), + "Trouble parsing type string at ',}' from Attr(\"a:{,}\") for Op Test"); + ExpectFailure(b().Attr("a:{invalid}"), + "Unrecognized type string 'invalid' from Attr(\"a:{invalid}\") " + "for Op Test"); + 
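+  // Mixing quoted strings with bare type names inside one restriction is
+  // also rejected, as the following cases check: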
ExpectFailure(b().Attr("a:{\"str\", float}"), + "Trouble parsing allowed string at 'float}' from " + "Attr(\"a:{\"str\", float}\") for Op Test"); + ExpectFailure(b().Attr("a:{ float, \"str\" }"), + "Trouble parsing type string at '\"str\" }' from Attr(\"a:{ " + "float, \"str\" }\") for Op Test"); + ExpectFailure(b().Attr("a:{float,,string}"), + "Trouble parsing type string at ',string}' from " + "Attr(\"a:{float,,string}\") for Op Test"); + ExpectFailure(b().Attr("a:{float,,}"), + "Trouble parsing type string at ',}' from " + "Attr(\"a:{float,,}\") for Op Test"); + ExpectFailure(b().Attr("i: int >= a"), + "Could not parse integer lower limit after '>=', " + "found ' a' instead from Attr(\"i: int >= a\") for Op Test"); + ExpectFailure(b().Attr("i: int >= -a"), + "Could not parse integer lower limit after '>=', found ' -a' " + "instead from Attr(\"i: int >= -a\") for Op Test"); + ExpectFailure(b().Attr("i: int >= 9223372036854775808"), + "Could not parse integer lower limit after '>=', found " + "' 9223372036854775808' instead from " + "Attr(\"i: int >= 9223372036854775808\") for Op Test"); + ExpectFailure(b().Attr("i: int >= -9223372036854775809"), + "Could not parse integer lower limit after '>=', found " + "' -9223372036854775809' instead from " + "Attr(\"i: int >= -9223372036854775809\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, AttrListOfRestricted) { + ExpectSuccess( + b().Attr("a:list(realnumbertype)"), + "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " + "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " + "DT_UINT16, DT_INT8, DT_HALF, DT_BFLOAT16, DT_UINT32, DT_UINT64" + "] } } }"); + ExpectSuccess( + b().Attr("a:list({realnumbertype, variant})"), + "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " + "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " + "DT_UINT16, DT_INT8, DT_HALF, DT_BFLOAT16, DT_UINT32, DT_UINT64, " + "DT_VARIANT] } } }"); + ExpectSuccess( + b().Attr("a:list(quantizedtype)"), + "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " + "[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16] } } }"); + ExpectSuccess( + b().Attr("a: list({float, string, bool})"), + "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " + "[DT_FLOAT, DT_STRING, DT_BOOL] } } }"); + ExpectSuccess( + b().Attr(R"(a: list({ "one fish", "two fish" }))"), + "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: " + "['one fish', 'two fish'] } } }"); + ExpectSuccess( + b().Attr(R"(a: list({ 'red fish', 'blue fish' }))"), + "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: " + "['red fish', 'blue fish'] } } }"); + ExpectSuccess( + b().Attr(R"(a: list({ "single' ", 'double"' }))"), + "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: " + "[\"single' \", 'double\"'] } } }"); + ExpectSuccess( + b().Attr(R"(a: list({ 'escape\'\n', "from\\\"NY" }))"), + "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: " + "[\"escape'\\n\", 'from\\\\\"NY'] } } }"); +} + +TEST_F(OpDefBuilderTest, AttrListWithMinLength) { + ExpectSuccess( + b().Attr("i: list(bool) >= 4"), + "attr: { name: 'i' type: 'list(bool)' has_minimum: true minimum: 4 }"); +} + +TEST_F(OpDefBuilderTest, AttrWithDefaults) { + ExpectSuccess(b().Attr(R"(a:string="foo")"), + "attr: { name: 'a' type: 'string' default_value { s:'foo' } }"); + ExpectSuccess(b().Attr(R"(a:string='foo')"), + "attr: { name: 'a' type: 'string' default_value { s:'foo' } }"); + ExpectSuccess(b().Attr("a:float = 
1.25"), + "attr: { name: 'a' type: 'float' default_value { f: 1.25 } }"); + ExpectSuccess(b().Attr("a:tensor = { dtype: DT_INT32 int_val: 5 }"), + "attr: { name: 'a' type: 'tensor' default_value { tensor {" + " dtype: DT_INT32 int_val: 5 } } }"); + ExpectSuccess(b().Attr("a:shape = { dim { size: 3 } dim { size: 4 } }"), + "attr: { name: 'a' type: 'shape' default_value { shape {" + " dim { size: 3 } dim { size: 4 } } } }"); + ExpectSuccess(b().Attr("a:shape = { dim { size: -1 } dim { size: 4 } }"), + "attr: { name: 'a' type: 'shape' default_value { shape {" + " dim { size: -1 } dim { size: 4 } } } }"); +} + +TEST_F(OpDefBuilderTest, AttrFailedDefaults) { + ExpectFailure(b().Attr(R"(a:int="foo")"), + "Could not parse default value '\"foo\"' from " + "Attr(\"a:int=\"foo\"\") for Op Test"); + ExpectFailure(b().Attr("a:float = [1.25]"), + "Could not parse default value '[1.25]' from Attr(\"a:float = " + "[1.25]\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, AttrListWithDefaults) { + ExpectSuccess(b().Attr(R"(a:list(string)=["foo", "bar"])"), + "attr: { name: 'a' type: 'list(string)' " + "default_value { list { s: ['foo', 'bar'] } } }"); + ExpectSuccess(b().Attr("a:list(bool)=[true, false, true]"), + "attr: { name: 'a' type: 'list(bool)' " + "default_value { list { b: [true, false, true] } } }"); + ExpectSuccess(b().Attr(R"(a:list(int)=[0, -1, 2, -4, 8])"), + "attr: { name: 'a' type: 'list(int)' " + "default_value { list { i: [0, -1, 2, -4, 8] } } }"); + ExpectSuccess(b().Attr(R"(a:list(int)=[ ])"), + "attr: { name: 'a' type: 'list(int)' " + "default_value { list { i: [] } } }"); +} + +TEST_F(OpDefBuilderTest, AttrFailedListDefaults) { + ExpectFailure(b().Attr(R"(a:list(int)=["foo"])"), + "Could not parse default value '[\"foo\"]' from " + "Attr(\"a:list(int)=[\"foo\"]\") for Op Test"); + ExpectFailure(b().Attr(R"(a:list(int)=[7, "foo"])"), + "Could not parse default value '[7, \"foo\"]' from " + "Attr(\"a:list(int)=[7, \"foo\"]\") for Op Test"); + ExpectFailure(b().Attr("a:list(float) = [[1.25]]"), + "Could not parse default value '[[1.25]]' from " + "Attr(\"a:list(float) = [[1.25]]\") for Op Test"); + ExpectFailure(b().Attr("a:list(float) = 1.25"), + "Could not parse default value '1.25' from " + "Attr(\"a:list(float) = 1.25\") for Op Test"); + ExpectFailure(b().Attr(R"(a:list(string)='foo')"), + "Could not parse default value ''foo'' from " + "Attr(\"a:list(string)='foo'\") for Op Test"); + ExpectFailure(b().Attr("a:list(float) = ["), + "Could not parse default value '[' from " + "Attr(\"a:list(float) = [\") for Op Test"); + ExpectFailure(b().Attr("a:list(float) = "), + "Could not parse default value '' from " + "Attr(\"a:list(float) = \") for Op Test"); +} + +TEST_F(OpDefBuilderTest, InputOutput) { + ExpectSuccess(b().Input("a: int32"), + "input_arg: { name: 'a' type: DT_INT32 }"); + ExpectSuccess(b().Output("b: string"), + "output_arg: { name: 'b' type: DT_STRING }"); + ExpectSuccess(b().Input("c: float "), + "input_arg: { name: 'c' type: DT_FLOAT }"); + ExpectSuccess(b().Output("d: Ref ( bool ) "), + "output_arg: { name: 'd' type: DT_BOOL is_ref: true }"); + ExpectOrdered(b().Input("a: bool") + .Output("c: complex64") + .Input("b: int64") + .Output("d: string"), + "input_arg: { name: 'a' type: DT_BOOL } " + "input_arg: { name: 'b' type: DT_INT64 } " + "output_arg: { name: 'c' type: DT_COMPLEX64 } " + "output_arg: { name: 'd' type: DT_STRING }"); +} + +TEST_F(OpDefBuilderTest, PolymorphicInputOutput) { + ExpectSuccess(b().Input("a: foo").Attr("foo: type"), + "input_arg: { name: 'a' 
type_attr: 'foo' } " + "attr: { name: 'foo' type: 'type' }"); + ExpectSuccess(b().Output("a: foo").Attr("foo: { bool, int32 }"), + "output_arg: { name: 'a' type_attr: 'foo' } " + "attr: { name: 'foo' type: 'type' " + "allowed_values: { list { type: [DT_BOOL, DT_INT32] } } }"); +} + +TEST_F(OpDefBuilderTest, InputOutputListSameType) { + ExpectSuccess(b().Input("a: n * int32").Attr("n: int"), + "input_arg: { name: 'a' number_attr: 'n' type: DT_INT32 } " + "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 }"); + // Polymorphic case: + ExpectSuccess(b().Output("b: n * foo").Attr("n: int").Attr("foo: type"), + "output_arg: { name: 'b' number_attr: 'n' type_attr: 'foo' } " + "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 } " + "attr: { name: 'foo' type: 'type' }"); +} + +TEST_F(OpDefBuilderTest, InputOutputListAnyType) { + ExpectSuccess( + b().Input("c: foo").Attr("foo: list(type)"), + "input_arg: { name: 'c' type_list_attr: 'foo' } " + "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 }"); + ExpectSuccess( + b().Output("c: foo").Attr("foo: list({string, float})"), + "output_arg: { name: 'c' type_list_attr: 'foo' } " + "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 " + "allowed_values: { list { type: [DT_STRING, DT_FLOAT] } } }"); +} + +TEST_F(OpDefBuilderTest, InputOutputFailure) { + ExpectFailure(b().Input("9: int32"), + "Trouble parsing 'name:' from Input(\"9: int32\") for Op Test"); + ExpectFailure( + b().Output("_: int32"), + "Trouble parsing 'name:' from Output(\"_: int32\") for Op Test"); + ExpectFailure(b().Input(": int32"), + "Trouble parsing 'name:' from Input(\": int32\") for Op Test"); + ExpectFailure(b().Output("int32"), + "Trouble parsing 'name:' from Output(\"int32\") for Op Test"); + ExpectFailure( + b().Input("CAPS: int32"), + "Trouble parsing 'name:' from Input(\"CAPS: int32\") for Op Test"); + ExpectFailure( + b().Input("_underscore: int32"), + "Trouble parsing 'name:' from Input(\"_underscore: int32\") for Op Test"); + ExpectFailure( + b().Input("0digit: int32"), + "Trouble parsing 'name:' from Input(\"0digit: int32\") for Op Test"); + ExpectFailure(b().Input("a: _"), + "Trouble parsing either a type or an attr name at '_' from " + "Input(\"a: _\") for Op Test"); + ExpectFailure(b().Input("a: 9"), + "Trouble parsing either a type or an attr name at '9' from " + "Input(\"a: 9\") for Op Test"); + ExpectFailure(b().Input("a: 9 * int32"), + "Trouble parsing either a type or an attr name at '9 * int32' " + "from Input(\"a: 9 * int32\") for Op Test"); + ExpectFailure( + b().Input("a: x * _").Attr("x: type"), + "Extra '* _' unparsed at the end from Input(\"a: x * _\") for Op Test"); + ExpectFailure(b().Input("a: x * y extra").Attr("x: int").Attr("y: type"), + "Extra 'extra' unparsed at the end from Input(\"a: x * y " + "extra\") for Op Test"); + ExpectFailure(b().Input("a: Ref(int32"), + "Did not find closing ')' for 'Ref(', instead found: '' from " + "Input(\"a: Ref(int32\") for Op Test"); + ExpectFailure( + b().Input("a: Ref"), + "Reference to unknown attr 'Ref' from Input(\"a: Ref\") for Op Test"); + ExpectFailure(b().Input("a: Ref(x y").Attr("x: type"), + "Did not find closing ')' for 'Ref(', instead found: 'y' from " + "Input(\"a: Ref(x y\") for Op Test"); + ExpectFailure( + b().Input("a: x"), + "Reference to unknown attr 'x' from Input(\"a: x\") for Op Test"); + ExpectFailure( + b().Input("a: x * y").Attr("x: int"), + "Reference to unknown attr 'y' from Input(\"a: x * y\") for Op Test"); + 
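+  // An attr referenced as an input type must itself have type "type" or
+  // "list(type)":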
ExpectFailure(b().Input("a: x").Attr("x: int"), + "Reference to attr 'x' with type int that isn't type or " + "list(type) from Input(\"a: x\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, Set) { + ExpectSuccess(b().SetIsStateful(), "is_stateful: true"); + ExpectSuccess(b().SetIsCommutative().SetIsAggregate(), + "is_commutative: true is_aggregate: true"); +} + +TEST_F(OpDefBuilderTest, DocUnpackSparseFeatures) { + ExpectOrdered(b().Input("sf: string") + .Output("indices: int32") + .Output("ids: int64") + .Output("weights: float") + .Doc(R"doc( +Converts a vector of strings with dist_belief::SparseFeatures to tensors. + +Note that indices, ids and weights are vectors of the same size and have +one-to-one correspondence between their elements. ids and weights are each +obtained by sequentially concatenating sf[i].id and sf[i].weight, for i in +1...size(sf). Note that if sf[i].weight is not set, the default value for the +weight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were +extracted from sf[i], then index[j] is set to i. + +sf: vector of string, where each element is the string encoding of + SparseFeatures proto. +indices: vector of indices inside sf +ids: vector of id extracted from the SparseFeatures proto. +weights: vector of weight extracted from the SparseFeatures proto. +)doc"), + R"proto( +input_arg { + name: "sf" + description: "vector of string, where each element is the string encoding of\nSparseFeatures proto." + type: DT_STRING +} +output_arg { + name: "indices" + description: "vector of indices inside sf" + type: DT_INT32 +} +output_arg { + name: "ids" + description: "vector of id extracted from the SparseFeatures proto." + type: DT_INT64 +} +output_arg { + name: "weights" + description: "vector of weight extracted from the SparseFeatures proto." + type: DT_FLOAT +} +summary: "Converts a vector of strings with dist_belief::SparseFeatures to tensors." +description: "Note that indices, ids and weights are vectors of the same size and have\none-to-one correspondence between their elements. ids and weights are each\nobtained by sequentially concatenating sf[i].id and sf[i].weight, for i in\n1...size(sf). Note that if sf[i].weight is not set, the default value for the\nweight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were\nextracted from sf[i], then index[j] is set to i." +)proto"); +} + +TEST_F(OpDefBuilderTest, DocConcat) { + ExpectOrdered(b().Input("concat_dim: int32") + .Input("values: num_values * dtype") + .Output("output: dtype") + .Attr("dtype: type") + .Attr("num_values: int >= 2") + .Doc(R"doc( +Concatenate N Tensors along one dimension. + +concat_dim: The (scalar) dimension along which to concatenate. Must be + in the range [0, rank(values...)). +values: The N Tensors to concatenate. Their ranks and types must match, + and their sizes must match in all dimensions except concat_dim. +output: A Tensor with the concatenation of values stacked along the + concat_dim dimension. This Tensor's shape matches the Tensors in + values, except in concat_dim where it has the sum of the sizes. +)doc"), + R"proto( +input_arg { + name: "concat_dim" + description: "The (scalar) dimension along which to concatenate. Must be\nin the range [0, rank(values...))." + type: DT_INT32 +} +input_arg { + name: "values" + description: "The N Tensors to concatenate. Their ranks and types must match,\nand their sizes must match in all dimensions except concat_dim." 
+ type_attr: "dtype" + number_attr: "num_values" +} +output_arg { + name: "output" + description: "A Tensor with the concatenation of values stacked along the\nconcat_dim dimension. This Tensor\'s shape matches the Tensors in\nvalues, except in concat_dim where it has the sum of the sizes." + type_attr: "dtype" +} +summary: "Concatenate N Tensors along one dimension." +attr { + name: "dtype" + type: "type" +} +attr { + name: "num_values" + type: "int" + has_minimum: true + minimum: 2 +} +)proto"); +} + +TEST_F(OpDefBuilderTest, DocAttr) { + ExpectOrdered(b().Attr("i: int").Doc(R"doc( +Summary + +i: How much to operate. +)doc"), + R"proto( +summary: "Summary" +attr { + name: "i" + type: "int" + description: "How much to operate." +} +)proto"); +} + +TEST_F(OpDefBuilderTest, DocCalledTwiceFailure) { + ExpectFailure(b().Doc("What's").Doc("up, doc?"), + "Extra call to Doc() for Op Test"); +} + +TEST_F(OpDefBuilderTest, DocFailureMissingName) { + ExpectFailure( + b().Input("a: int32").Doc(R"doc( +Summary + +a: Something for a. +b: b is not defined. +)doc"), + "No matching input/output/attr for name 'b' from Doc() for Op Test"); + + ExpectFailure( + b().Input("a: int32").Doc(R"doc( +Summary + +b: b is not defined and by itself. +)doc"), + "No matching input/output/attr for name 'b' from Doc() for Op Test"); +} + +TEST_F(OpDefBuilderTest, DefaultMinimum) { + ExpectSuccess(b().Input("values: num_values * dtype") + .Output("output: anything") + .Attr("anything: list(type)") + .Attr("dtype: type") + .Attr("num_values: int"), + R"proto( +input_arg { + name: "values" + type_attr: "dtype" + number_attr: "num_values" +} +output_arg { + name: "output" + type_list_attr: "anything" +} +attr { + name: "anything" + type: "list(type)" + has_minimum: true + minimum: 1 +} +attr { + name: "dtype" + type: "type" +} +attr { + name: "num_values" + type: "int" + has_minimum: true + minimum: 1 +} +)proto"); +} + +TEST_F(OpDefBuilderTest, SetShapeFn) { + auto fn = [](shape_inference::InferenceContext* c) { + return errors::Unknown("ShapeFn was called"); + }; + OpShapeInferenceFn fn_out; + ExpectSuccess( + b().SetShapeFn(fn).Attr("dtype: type"), + "attr { name: \"dtype\" type: \"type\" allowed_values { list { } } }", + &fn_out); + ASSERT_TRUE(fn_out != nullptr); + EXPECT_EQ("ShapeFn was called", fn_out(nullptr).error_message()); +} + +TEST_F(OpDefBuilderTest, SetShapeFnCalledTwiceFailure) { + auto fn = [](shape_inference::InferenceContext* c) { + return errors::Unknown("ShapeFn was called"); + }; + ExpectFailure(b().SetShapeFn(fn).SetShapeFn(fn), + "SetShapeFn called twice for Op Test"); +} + +TEST_F(OpDefBuilderTest, ResourceIsStateful) { + OpRegistrationData op_reg_data; + TF_EXPECT_OK(b().Input("a: resource").Finalize(&op_reg_data)); + const OpDef& op_def = op_reg_data.op_def; + EXPECT_TRUE(op_def.is_stateful()); +} + +} // namespace +} // namespace tensorflow diff --git a/op_def_util.cc b/op_def_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..29feda499fd2646a00c1f5bc9fc7223e9f134af9 --- /dev/null +++ b/op_def_util.cc @@ -0,0 +1,814 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_def_util.h" + +#include +#include +#include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_def.pb_text.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { // ------ Helper functions ------ + +bool HasAttrStyleType(const OpDef::ArgDef& arg) { + return arg.type() != DT_INVALID || !arg.type_attr().empty() || + !arg.type_list_attr().empty(); +} + +Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { + const AttrValue& allowed_values(attr.allowed_values()); + for (auto allowed : allowed_values.list().type()) { + if (dt == allowed) { + return Status::OK(); + } + } + string allowed_str; + for (int i = 0; i < allowed_values.list().type_size(); ++i) { + if (!allowed_str.empty()) { + strings::StrAppend(&allowed_str, ", "); + } + strings::StrAppend(&allowed_str, + DataTypeString(allowed_values.list().type(i))); + } + return errors::InvalidArgument( + "Value for attr '", attr.name(), "' of ", DataTypeString(dt), + " is not in the list of allowed values: ", allowed_str); +} + +Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) { + const AttrValue& allowed_values(attr.allowed_values()); + for (const auto& allowed : allowed_values.list().s()) { + if (str == allowed) { + return Status::OK(); + } + } + string allowed_str; + for (const string& allowed : allowed_values.list().s()) { + if (!allowed_str.empty()) { + strings::StrAppend(&allowed_str, ", "); + } + strings::StrAppend(&allowed_str, "\"", allowed, "\""); + } + return errors::InvalidArgument( + "Value for attr '", attr.name(), "' of \"", str, + "\" is not in the list of allowed values: ", allowed_str); +} + +} // namespace + +// Requires: attr has already been validated. +Status ValidateAttrValue(const AttrValue& attr_value, + const OpDef::AttrDef& attr) { + // Is it a valid value? + TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()), + " for attr '", attr.name(), "'"); + + // Does the value satisfy the minimum constraint in the AttrDef? 
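+  // For "int" attrs the value itself is compared against the minimum; for
+  // list-typed attrs it is the list length that must meet the minimum.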
+ if (attr.has_minimum()) { + if (attr.type() == "int") { + if (attr_value.i() < attr.minimum()) { + return errors::InvalidArgument( + "Value for attr '", attr.name(), "' of ", attr_value.i(), + " must be at least minimum ", attr.minimum()); + } + } else { + int length = -1; + if (attr.type() == "list(string)") { + length = attr_value.list().s_size(); + } else if (attr.type() == "list(int)") { + length = attr_value.list().i_size(); + } else if (attr.type() == "list(float)") { + length = attr_value.list().f_size(); + } else if (attr.type() == "list(bool)") { + length = attr_value.list().b_size(); + } else if (attr.type() == "list(type)") { + length = attr_value.list().type_size(); + } else if (attr.type() == "list(shape)") { + length = attr_value.list().shape_size(); + } else if (attr.type() == "list(tensor)") { + length = attr_value.list().tensor_size(); + } + if (length < attr.minimum()) { + return errors::InvalidArgument( + "Length for attr '", attr.name(), "' of ", length, + " must be at least minimum ", attr.minimum()); + } + } + } + + // Does the value satisfy the allowed_value constraint in the AttrDef? + if (attr.has_allowed_values()) { + if (attr.type() == "type") { + TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr)); + } else if (attr.type() == "list(type)") { + for (int dt : attr_value.list().type()) { + TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast(dt), attr)); + } + } else if (attr.type() == "string") { + TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr)); + } else if (attr.type() == "list(string)") { + for (const string& str : attr_value.list().s()) { + TF_RETURN_IF_ERROR(AllowedStringValue(str, attr)); + } + } else { + return errors::Unimplemented( + "Support for allowed_values not implemented for type ", attr.type()); + } + } + return Status::OK(); +} + +const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) { + for (int i = 0; i < op_def.attr_size(); ++i) { + if (op_def.attr(i).name() == name) { + return &op_def.attr(i); + } + } + return nullptr; +} + +OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) { + for (int i = 0; i < op_def->attr_size(); ++i) { + if (op_def->attr(i).name() == name) { + return op_def->mutable_attr(i); + } + } + return nullptr; +} + +const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) { + for (int i = 0; i < op_def.input_arg_size(); ++i) { + if (op_def.input_arg(i).name() == name) { + return &op_def.input_arg(i); + } + } + return nullptr; +} + +#define VALIDATE(EXPR, ...) \ + do { \ + if (!(EXPR)) { \ + return errors::InvalidArgument(__VA_ARGS__, "; in OpDef: ", \ + ProtoShortDebugString(op_def)); \ + } \ + } while (false) + +static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, + bool output, std::set* names) { + const string suffix = strings::StrCat( + output ? 
" for output '" : " for input '", arg.name(), "'"); + VALIDATE(gtl::InsertIfNotPresent(names, arg.name()), "Duplicate name: ", + arg.name()); + VALIDATE(HasAttrStyleType(arg), "Missing type", suffix); + + if (!arg.number_attr().empty()) { + const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def); + VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'", + suffix); + VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length", + suffix, " has type ", attr->type(), " != int"); + VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length", + suffix, " must have minimum"); + VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length", + suffix, " must have minimum >= 0"); + VALIDATE(arg.type_list_attr().empty(), + "Can't have both number_attr and type_list_attr", suffix); + VALIDATE((arg.type() != DT_INVALID ? 1 : 0) + + (!arg.type_attr().empty() ? 1 : 0) == + 1, + "Exactly one of type, type_attr must be set", suffix); + } else { + const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) + + (!arg.type_attr().empty() ? 1 : 0) + + (!arg.type_list_attr().empty() ? 1 : 0); + VALIDATE(num_type_fields == 1, + "Exactly one of type, type_attr, type_list_attr must be set", + suffix); + } + + if (!arg.type_attr().empty()) { + const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def); + VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'", + suffix); + VALIDATE(attr->type() == "type", "Attr '", attr->name(), + "' used as type_attr", suffix, " has type ", attr->type(), + " != type"); + } else if (!arg.type_list_attr().empty()) { + const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def); + VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'", + suffix); + VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(), + "' used as type_list_attr", suffix, " has type ", attr->type(), + " != list(type)"); + } else { + // All argument types should be non-reference types at this point. + // ArgDef.is_ref is set to true for reference arguments. + VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '", + DataTypeString(arg.type()), "'. 
Use 'Ref(type)' instead", suffix); + } + + return Status::OK(); +} + +Status ValidateOpDef(const OpDef& op_def) { + using ::tensorflow::strings::Scanner; + + if (!StringPiece(op_def.name()).starts_with("_")) { + VALIDATE(Scanner(op_def.name()) + .One(Scanner::UPPERLETTER) + .Any(Scanner::LETTER_DIGIT) + .Eos() + .GetResult(), + "Invalid name: ", op_def.name(), " (Did you use CamelCase?)"); + } + + std::set names; // for detecting duplicate names + for (const auto& attr : op_def.attr()) { + // Validate name + VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()), "Duplicate name: ", + attr.name()); + DataType dt; + VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ", + attr.name(), " that matches a data type"); + + // Validate type + StringPiece type(attr.type()); + bool is_list = type.Consume("list("); + bool found = false; + for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape", + "tensor", "func"}) { + if (type.Consume(valid)) { + found = true; + break; + } + } + VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(), + "'"); + if (is_list) { + VALIDATE(type.Consume(")"), "'list(' is missing ')' in attr ", + attr.name(), "'s type ", attr.type()); + } + VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ", + attr.name(), "'s type ", attr.type()); + + // Validate minimum + if (attr.has_minimum()) { + VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(), + "' has minimum for unsupported type ", attr.type()); + if (is_list) { + VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(), + "' with list type must have a non-negative minimum, not ", + attr.minimum()); + } + } else { + VALIDATE(attr.minimum() == 0, "Attr '", attr.name(), + "' with has_minimum = false but minimum ", attr.minimum(), + " not equal to default of 0"); + } + + // Validate allowed_values + if (attr.has_allowed_values()) { + const string list_type = + is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")"); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + AttrValueHasType(attr.allowed_values(), list_type), " for attr '", + attr.name(), "' in Op '", op_def.name(), "'"); + } + + // Validate default_value (after we have validated the rest of the attr, + // so we can use ValidateAttrValue()). + if (attr.has_default_value()) { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ValidateAttrValue(attr.default_value(), attr), " in Op '", + op_def.name(), "'"); + } + } + + for (const auto& arg : op_def.input_arg()) { + TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names)); + } + + for (const auto& arg : op_def.output_arg()) { + TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names)); + } + + return Status::OK(); +} + +#undef VALIDATE + +Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) { + if (op_def.has_deprecation()) { + const OpDeprecation& dep = op_def.deprecation(); + if (graph_def_version >= dep.version()) { + return errors::Unimplemented( + "Op ", op_def.name(), " is not available in GraphDef version ", + graph_def_version, ". It has been removed in version ", dep.version(), + ". ", dep.explanation(), "."); + } else { + // Warn only once for each op name, and do it in a threadsafe manner. + static mutex mu(LINKER_INITIALIZED); + static std::unordered_set warned; + bool warn; + { + mutex_lock lock(mu); + warn = warned.insert(op_def.name()).second; + } + if (warn) { + LOG(WARNING) << "Op " << op_def.name() << " is deprecated." + << " It will cease to work in GraphDef version " + << dep.version() << ". 
" << dep.explanation() << "."; + } + } + } + return Status::OK(); +} + +namespace { + +string SummarizeArgs(const protobuf::RepeatedPtrField& args) { + string ret; + for (const OpDef::ArgDef& arg : args) { + if (!ret.empty()) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, arg.name(), ":"); + if (arg.is_ref()) strings::StrAppend(&ret, "Ref("); + if (!arg.number_attr().empty()) { + strings::StrAppend(&ret, arg.number_attr(), "*"); + } + if (arg.type() != DT_INVALID) { + strings::StrAppend(&ret, DataTypeString(arg.type())); + } else { + strings::StrAppend(&ret, arg.type_attr()); + } + if (arg.is_ref()) strings::StrAppend(&ret, ")"); + } + return ret; +} + +} // namespace + +string SummarizeOpDef(const OpDef& op_def) { + string ret = strings::StrCat("Op ", SummarizeArgs(op_def.output_arg())); + for (int i = 0; i < op_def.attr_size(); ++i) { + strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":", + op_def.attr(i).type()); + if (op_def.attr(i).has_default_value()) { + strings::StrAppend(&ret, ",default=", + SummarizeAttrValue(op_def.attr(i).default_value())); + } + if (op_def.attr(i).has_minimum()) { + strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum()); + } + if (op_def.attr(i).has_allowed_values()) { + strings::StrAppend(&ret, ",allowed=", + SummarizeAttrValue(op_def.attr(i).allowed_values())); + } + } + if (op_def.is_commutative()) { + strings::StrAppend(&ret, "; is_commutative=true"); + } + if (op_def.is_aggregate()) { + strings::StrAppend(&ret, "; is_aggregate=true"); + } + if (op_def.is_stateful()) { + strings::StrAppend(&ret, "; is_stateful=true"); + } + if (op_def.allows_uninitialized_input()) { + strings::StrAppend(&ret, "; allows_uninitialized_input=true"); + } + strings::StrAppend(&ret, ">"); + return ret; +} + +namespace { + +// Returns true if every element of `sub` is contained in `super`. +template +bool IsSubsetOf(const T& sub, const T& super) { + for (const auto& o : sub) { + bool found = false; + for (const auto& n : super) { + if (o == n) { + found = true; + break; + } + } + if (!found) return false; + } + return true; +} + +bool MoreRestrictive(const OpDef::AttrDef& old_attr, + const OpDef::AttrDef& new_attr) { + // Anything -> no restriction : not more restrictive. + if (!new_attr.has_allowed_values()) return false; + // No restriction -> restriction : more restrictive. + if (!old_attr.has_allowed_values()) return true; + // If anything that was previously allowed is no longer allowed: + // more restrictive. + if (!IsSubsetOf(old_attr.allowed_values().list().type(), + new_attr.allowed_values().list().type())) { + return true; + } + if (!IsSubsetOf(old_attr.allowed_values().list().s(), + new_attr.allowed_values().list().s())) { + return true; + } + return false; +} + +string AllowedStr(const OpDef::AttrDef& attr) { + if (!attr.has_allowed_values()) return "no restriction"; + return SummarizeAttrValue(attr.allowed_values()); +} + +bool HigherMinimum(const OpDef::AttrDef& old_attr, + const OpDef::AttrDef& new_attr) { + // Anything -> no restriction : not more restrictive. + if (!new_attr.has_minimum()) return false; + // No restriction -> restriction : more restrictive. + if (!old_attr.has_minimum()) return true; + // If anything that was previously allowed is no longer allowed: + // more restrictive. 
+ return new_attr.minimum() > old_attr.minimum(); +} + +string MinStr(const OpDef::AttrDef& attr) { + if (!attr.has_minimum()) return "no minimum"; + return strings::StrCat(attr.minimum()); +} + +typedef std::unordered_map AttrMap; +void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) { + for (const auto& attr : op_def.attr()) { + (*attr_map)[attr.name()] = &attr; + } +} + +// Add a comma to *s every call but the first (*add_comma should be +// initialized to false). +void AddComma(string* s, bool* add_comma) { + if (*add_comma) { + strings::StrAppend(s, ", "); + } else { + *add_comma = true; + } +} + +// Will add the `name` from arg if name is true. +void AddName(string* s, bool name, const OpDef::ArgDef& arg) { + if (name) { + strings::StrAppend(s, arg.name(), ":"); + } +} + +// Compute a signature for either inputs or outputs that will be the +// same for both the old and new OpDef if they are compatible. We +// assume that new_attrs is a superset of old_attrs, and that any attr +// in the difference has a default. Our strategy is to make a list of +// types, where the types are things like: +// * "int32", "float", etc., +// * "T" for some attr "T" in old_attrs, or +// * "N * type" for "N" either some attr in old_attrs. +// +// We get the types by either using the attrs in args if they are in +// old_attrs, or substituting the default value from new_attrs. +string ComputeArgSignature( + const protobuf::RepeatedPtrField& args, + const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector* ref, + bool names) { + string s; + bool add_comma = false; + for (const OpDef::ArgDef& arg : args) { + if (!arg.type_list_attr().empty()) { + const OpDef::AttrDef* old_attr = + gtl::FindPtrOrNull(old_attrs, arg.type_list_attr()); + if (old_attr) { + // Both old and new have the list(type) attr, so can use it directly. + AddComma(&s, &add_comma); + AddName(&s, names, arg); + strings::StrAppend(&s, arg.type_list_attr()); + ref->push_back(arg.is_ref()); + } else { + // Missing the list(type) attr in the old, so use the default + // value for the attr from new instead. + const OpDef::AttrDef* new_attr = + gtl::FindPtrOrNull(new_attrs, arg.type_list_attr()); + const auto& type_list = new_attr->default_value().list().type(); + if (type_list.empty()) continue; + for (int i = 0; i < type_list.size(); ++i) { + AddComma(&s, &add_comma); + AddName(&s, names, arg); + strings::StrAppend( + &s, DataTypeString(static_cast(type_list.Get(i)))); + ref->push_back(arg.is_ref()); + } + } + } else { + int num = 1; // How many input/outputs does this represent? + string type; // What is the type of this arg? + AddName(&type, names, arg); + if (!arg.number_attr().empty()) { + // N * type case. + const OpDef::AttrDef* old_attr = + gtl::FindPtrOrNull(old_attrs, arg.number_attr()); + if (old_attr) { + // Both old and new have the number attr, so can use it directly. + strings::StrAppend(&type, arg.number_attr(), " * "); + } else { + // Missing the number attr in the old, so use the default + // value for the attr from new instead. + const OpDef::AttrDef* new_attr = + gtl::FindPtrOrNull(new_attrs, arg.number_attr()); + num = new_attr->default_value().i(); + } + } + + if (arg.type() != DT_INVALID) { + // int32, float, etc. case + strings::StrAppend(&type, DataTypeString(arg.type())); + } else { + const OpDef::AttrDef* old_attr = + gtl::FindPtrOrNull(old_attrs, arg.type_attr()); + if (old_attr) { + // Both old and new have the type attr, so can use it directly. 
+ strings::StrAppend(&type, arg.type_attr()); + } else { + // Missing the type attr in the old, so use the default + // value for the attr from new instead. + const OpDef::AttrDef* new_attr = + gtl::FindPtrOrNull(new_attrs, arg.type_attr()); + strings::StrAppend(&type, + DataTypeString(new_attr->default_value().type())); + } + } + + // Record `num` * `type` in the signature. + for (int i = 0; i < num; ++i) { + AddComma(&s, &add_comma); + strings::StrAppend(&s, type); + ref->push_back(arg.is_ref()); + } + } + } + + return s; +} + +} // namespace + +Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { +#define VALIDATE(CONDITION, ...) \ + if (!(CONDITION)) { \ + return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \ + "; old: ", SummarizeOpDef(old_op), \ + "; new: ", SummarizeOpDef(new_op)); \ + } + + VALIDATE(old_op.name() == new_op.name(), "Name mismatch"); + + AttrMap new_attrs, old_attrs; + FillAttrMap(old_op, &old_attrs); + FillAttrMap(new_op, &new_attrs); + for (const auto& old_attr : old_op.attr()) { + const OpDef::AttrDef* new_attr = + gtl::FindPtrOrNull(new_attrs, old_attr.name()); + VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed"); + VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(), + "' changed type '", old_attr.type(), "' -> '", new_attr->type(), + "'"); + VALIDATE(!MoreRestrictive(old_attr, *new_attr), "Attr '", old_attr.name(), + "' has a stricter set of allowed values; from ", + AllowedStr(old_attr), " to ", AllowedStr(*new_attr)); + VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '", old_attr.name(), + "' has a higher minimum; from ", MinStr(old_attr), " to ", + MinStr(*new_attr)); + } + + for (const auto& new_attr : new_op.attr()) { + const OpDef::AttrDef* old_attr = + gtl::FindPtrOrNull(old_attrs, new_attr.name()); + VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '", + new_attr.name(), "' added without default"); + } + + std::vector old_in_ref, new_in_ref, old_out_ref, new_out_ref; + const string old_in_sig = ComputeArgSignature( + old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */); + const string new_in_sig = ComputeArgSignature( + new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */); + VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig, + "' vs. '", new_in_sig, "'"); + VALIDATE(old_in_ref.size() == new_in_ref.size(), // Should not happen + "Unexpected change in input ref lists."); + for (int i = 0; i < old_in_ref.size(); ++i) { + // Allowed to remove "ref" from an input (or leave it unchanged). + VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input ", i, + " changed from non-ref to ref"); + } + + const string old_out_sig = + ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs, + &old_out_ref, true /* names */); + const string new_out_sig = + ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs, + &new_out_ref, true /* names */); + VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '", + old_out_sig, "' vs. '", new_out_sig, "'"); + VALIDATE(old_out_ref.size() == new_out_ref.size(), // Should not happen + "Unexpected change in output ref lists"); + for (int i = 0; i < old_out_ref.size(); ++i) { + // Allowed to add "ref" to an output (or leave it unchanged). 
+ VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output ", i, + " changed from ref to non-ref"); + } + + return Status::OK(); +} + +Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, + const OpDef& penultimate_op, + const OpDef& new_op) { + AttrMap new_attrs, old_attrs; + FillAttrMap(old_op, &old_attrs); + FillAttrMap(new_op, &new_attrs); + + for (const auto& penultimate_attr : penultimate_op.attr()) { + const OpDef::AttrDef* old_attr = + gtl::FindPtrOrNull(old_attrs, penultimate_attr.name()); + if (old_attr != nullptr) continue; // attr wasn't added + const OpDef::AttrDef* new_attr = + gtl::FindPtrOrNull(new_attrs, penultimate_attr.name()); + + // These shouldn't happen if the op passed OpDefCompatible(). + if (new_attr == nullptr) { + return errors::InvalidArgument("Missing attr '", penultimate_attr.name(), + "' in op: ", SummarizeOpDef(new_op)); + } + if (!penultimate_attr.has_default_value() || + !new_attr->has_default_value()) { + return errors::InvalidArgument("Missing default for attr '", + penultimate_attr.name(), "' in op: ", + SummarizeOpDef(new_op)); + } + + // Actually test that the attr's default value hasn't changed. + if (!AreAttrValuesEqual(penultimate_attr.default_value(), + new_attr->default_value())) { + return errors::InvalidArgument( + "Can't change default value for attr '", penultimate_attr.name(), + "' from ", SummarizeAttrValue(penultimate_attr.default_value()), + " in op: ", SummarizeOpDef(new_op)); + } + } + + return Status::OK(); +} + +void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) { + for (int i = 0; i < op_def->input_arg_size(); ++i) { + op_def->mutable_input_arg(i)->clear_description(); + } + for (int i = 0; i < op_def->output_arg_size(); ++i) { + op_def->mutable_output_arg(i)->clear_description(); + } + for (int i = 0; i < op_def->attr_size(); ++i) { + op_def->mutable_attr(i)->clear_description(); + } + op_def->clear_summary(); + op_def->clear_description(); +} + +void RemoveDescriptionsFromOpDef(OpDef* op_def) { + RemoveNonDeprecationDescriptionsFromOpDef(op_def); + if (op_def->has_deprecation()) { + op_def->mutable_deprecation()->clear_explanation(); + } +} + +void RemoveDescriptionsFromOpList(OpList* op_list) { + for (int i = 0; i < op_list->op_size(); ++i) { + OpDef* op_def = op_list->mutable_op(i); + RemoveDescriptionsFromOpDef(op_def); + } +} + +bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) { +#ifndef TENSORFLOW_LITE_PROTOS + DCHECK_EQ(7, a1.GetDescriptor()->field_count()) + << "Please modify these equality and hash functions to reflect the " + "changes to the AttrDef protobuf"; +#endif // TENSORFLOW_LITE_PROTOS + + if (a1.name() != a2.name()) return false; + if (a1.type() != a2.type()) return false; + if (a1.description() != a2.description()) return false; + if (a1.has_minimum() != a2.has_minimum()) return false; + if (a1.has_minimum() && a1.minimum() != a2.minimum()) return false; + if (!AreAttrValuesEqual(a1.default_value(), a2.default_value())) return false; + if (!AreAttrValuesEqual(a1.allowed_values(), a2.allowed_values())) + return false; + return true; +} + +uint64 AttrDefHash(const OpDef::AttrDef& a) { + uint64 h = Hash64(a.name()); + h = Hash64(a.type().data(), a.type().size(), h); + h = Hash64Combine(AttrValueHash(a.default_value()), h); + h = Hash64(a.description().data(), a.description().size(), h); + h = Hash64Combine(static_cast(a.has_minimum()), h); + h = Hash64Combine(static_cast(a.minimum()), h); + h = Hash64Combine(AttrValueHash(a.allowed_values()), h); + return h; +} + +bool 
RepeatedAttrDefEqual( + const protobuf::RepeatedPtrField& a1, + const protobuf::RepeatedPtrField& a2) { + std::unordered_map a1_set; + for (const OpDef::AttrDef& def : a1) { + DCHECK(a1_set.find(def.name()) == a1_set.end()) + << "AttrDef names must be unique, but '" << def.name() + << "' appears more than once"; + a1_set[def.name()] = &def; + } + for (const OpDef::AttrDef& def : a2) { + auto iter = a1_set.find(def.name()); + if (iter == a1_set.end()) return false; + if (!AttrDefEqual(*iter->second, def)) return false; + a1_set.erase(iter); + } + if (!a1_set.empty()) return false; + return true; +} + +uint64 RepeatedAttrDefHash( + const protobuf::RepeatedPtrField& a) { + // Insert AttrDefs into map to deterministically sort by name + std::map a_set; + for (const OpDef::AttrDef& def : a) { + a_set[def.name()] = &def; + } + // Iterate and combines hashes of keys and values + uint64 h = 0xDECAFCAFFE; + for (const auto& pair : a_set) { + h = Hash64(pair.first.data(), pair.first.size(), h); + h = Hash64Combine(AttrDefHash(*pair.second), h); + } + return h; +} + +bool OpDefEqual(const OpDef& o1, const OpDef& o2) { + // attr order doesn't matter. + // Compare it separately here instead of serializing below. + if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false; + + // Clear attr field, serialize, and compare serialized strings + OpDef o1_copy = o1; + OpDef o2_copy = o2; + o1_copy.clear_attr(); + o2_copy.clear_attr(); + string s1, s2; + SerializeToStringDeterministic(o1_copy, &s1); + SerializeToStringDeterministic(o2_copy, &s2); + if (s1 != s2) return false; + return true; +} + +uint64 OpDefHash(const OpDef& o) { + uint64 h = RepeatedAttrDefHash(o.attr()); + OpDef o_copy = o; + o_copy.clear_attr(); + string s; + SerializeToStringDeterministic(o_copy, &s); + return Hash64(s.data(), s.size(), h); +} + +} // namespace tensorflow diff --git a/op_def_util.h b/op_def_util.h new file mode 100644 index 0000000000000000000000000000000000000000..f9661dceddc1a3de694024dddb9afce1cae8680c --- /dev/null +++ b/op_def_util.h @@ -0,0 +1,97 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TODO(josh11b): Probably not needed for OpKernel authors, so doesn't +// need to be as publicly accessible as other files in framework/. + +#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_ + +#include +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// Performs a consistency check across the fields of the op_def. +Status ValidateOpDef(const OpDef& op_def); + +// Check if an op is deprecated at the given GraphDef version. If the op is +// deprecated at a future version, a warning will be logged. 
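+//
+// Illustrative usage sketch (the `op_def` value and the version number 24 are
+// hypothetical, not taken from this change):
+//   Status s = CheckOpDeprecation(op_def, /*graph_def_version=*/24);
+//   if (!s.ok()) {
+//     // errors::Unimplemented: the op was removed at or before version 24.
+//   }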
+Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version);
+
+// Validates that attr_value satisfies the type and constraints from attr.
+// REQUIRES: attr has already been validated.
+Status ValidateAttrValue(const AttrValue& attr_value,
+                         const OpDef::AttrDef& attr);
+
+// The following search through op_def for an attr with the indicated name.
+// Returns nullptr if no such attr is found.
+const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def);
+OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def);
+
+// Searches op_def for input argument with the indicated name.
+// Returns nullptr if no such attr is found.
+const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def);
+
+// Produce a human-readable version of an op_def that is more concise
+// than a text-format proto.  Excludes descriptions.
+string SummarizeOpDef(const OpDef& op_def);
+
+// Returns an error if new_op is not backwards-compatible with (more
+// accepting than) old_op.
+// REQUIRES: old_op and new_op must pass validation.
+Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op);
+
+// Returns an error if any attr in penultimate_op that is not in old_op
+// has a different default value in new_op.  In general it is not safe
+// to change the default for an attr that has been added to an op.
+Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
+                                   const OpDef& penultimate_op,
+                                   const OpDef& new_op);
+
+// Remove all docs from *op_def / *op_list.
+void RemoveDescriptionsFromOpDef(OpDef* op_def);
+void RemoveDescriptionsFromOpList(OpList* op_list);
+
+// Remove docs from *op_def but leave explanations of deprecations.
+void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def);
+
+// Returns true if `a1` is equal to `a2`.
+// Equality includes all the fields.
+bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2);
+
+// Returns hash of `a` that is consistent with AttrDefEqual.
+uint64 AttrDefHash(const OpDef::AttrDef& a);
+
+// Returns true if all AttrDefs in `a1` equal corresponding AttrDefs in
+// `a2`. Correspondence is established by name.
+bool RepeatedAttrDefEqual(const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1,
+                          const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2);
+
+// Returns hash of `a` that is consistent with RepeatedAttrDefEqual.
+uint64 RepeatedAttrDefHash(const protobuf::RepeatedPtrField<OpDef::AttrDef>& a);
+
+// Returns true if `o1` is equal to `o2`.
+// Equality includes all the fields. OpDef.attr field is treated as a set.
+bool OpDefEqual(const OpDef& o1, const OpDef& o2);
+
+// Returns hash of `o` that is consistent with OpDefEqual.
+uint64 OpDefHash(const OpDef& o);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
diff --git a/op_def_util_test.cc b/op_def_util_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..28809c11c58704479c9c45b1de96dffef3d575bd
--- /dev/null
+++ b/op_def_util_test.cc
@@ -0,0 +1,512 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/core/framework/op_def_util.h" + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +OpDef FromText(const string& text) { + OpDef op_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &op_def)); + return op_def; +} + +OpDef::AttrDef ADef(const string& text) { + OpDef::AttrDef attr_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &attr_def)); + return attr_def; +} + +class ValidateOpDefTest : public ::testing::Test { + protected: + Status TestProto(const string& text) { return ValidateOpDef(FromText(text)); } + + Status TestBuilder(const OpDefBuilder& builder) { + OpRegistrationData op_reg_data; + Status status = builder.Finalize(&op_reg_data); + TF_EXPECT_OK(status); + if (!status.ok()) { + return status; + } else { + return ValidateOpDef(op_reg_data.op_def); + } + } + + void ExpectFailure(const Status& status, const string& message) { + EXPECT_FALSE(status.ok()) << "Did not see error with: " << message; + if (!status.ok()) { + LOG(INFO) << "message: " << status; + EXPECT_TRUE(StringPiece(status.ToString()).contains(message)) + << "Actual: " << status << "\nExpected to contain: " << message; + } + } +}; + +TEST_F(ValidateOpDefTest, OpDefValid) { + TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int"))); + TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Input("a: int32"))); + TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Output("a: bool"))); + TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("t: type").Input("a: t"))); + TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int = 3"))); + TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5"))); + TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5"))); + TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5 = 3"))); + TF_EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: numbertype"))); + TF_EXPECT_OK(TestBuilder(OpDefBuilder("Uppercase"))); +} + +TEST_F(ValidateOpDefTest, InvalidName) { + ExpectFailure(TestBuilder(OpDefBuilder("lower").Attr("a: int")), + "Invalid name"); + ExpectFailure(TestBuilder(OpDefBuilder("BadSuffix 7%")), "Invalid name"); +} + +TEST_F(ValidateOpDefTest, DuplicateName) { + ExpectFailure( + TestBuilder(OpDefBuilder("DupeName").Input("a: int32").Input("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder( + OpDefBuilder("DupeName").Input("a: int32").Output("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder( + OpDefBuilder("DupeName").Output("a: int32").Output("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder(OpDefBuilder("DupeName").Input("a: int32").Attr("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder(OpDefBuilder("DupeName").Output("a: int32").Attr("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder(OpDefBuilder("DupeName").Attr("a: int").Attr("a: float")), + "Duplicate name: a"); +} + +TEST_F(ValidateOpDefTest, BadAttrName) { + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude").Attr("int32: int")), + "Attr can't have name int32 that matches a data type"); + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude").Attr("float: string")), + "Attr 
can't have name float that matches a data type"); +} + +TEST_F(ValidateOpDefTest, BadAttrType) { + ExpectFailure( + TestProto("name: 'BadAttrType' attr { name: 'a' type: 'illegal' }"), + "Unrecognized type"); + ExpectFailure( + TestProto("name: 'BadAttrType' attr { name: 'a' type: 'list(illegal)' }"), + "Unrecognized type"); + ExpectFailure( + TestProto("name: 'BadAttrType' attr { name: 'a' type: 'int extra' }"), + "Extra ' extra' at the end"); + ExpectFailure( + TestProto( + "name: 'BadAttrType' attr { name: 'a' type: 'list(int extra)' }"), + "'list(' is missing ')' in attr"); + ExpectFailure( + TestProto( + "name: 'BadAttrType' attr { name: 'a' type: 'list(int) extra' }"), + "Extra ' extra' at the end"); +} + +TEST_F(ValidateOpDefTest, BadAttrDefault) { + ExpectFailure( + TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'int' default_value { s: 'x' } }"), + "AttrValue had value with type 'string' when 'int' expected\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'int' default_value { f: 0.5 } }"), + "AttrValue had value with type 'float' when 'int' expected\n" + "\t for attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure( + TestProto("name: 'BadAttrDef' attr { name: 'a' type: 'int' " + "default_value { i: 5 list { i: [2] } } }"), + "AttrValue had value with type 'list(int)' when 'int' expected\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure( + TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'list(int)' default_value { f: 0.5 } }"), + "AttrValue had value with type 'float' when 'list(int)' expected\n\t " + "for attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure( + TestProto("name: 'BadAttrDef' attr { name: 'a' type: 'list(int)' " + "default_value { list { i: [5] f: [0.5] } } }"), + "AttrValue had value with type 'list(float)' when 'list(int)' " + "expected\n\t for attr 'a'\n\t in Op 'BadAttrDef'"); + + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'type' default_value { } }"), + "AttrValue missing value with expected type 'type'\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'shape' default_value { } }"), + "AttrValue missing value with expected type 'shape'\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'tensor' default_value { } }"), + "AttrValue missing value with expected type 'tensor'\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + + // default_value {} is indistinguishable from default_value{ list{} } (one + // with an empty list) in proto3 semantics. 
+ TF_EXPECT_OK( + TestProto("name: 'GoodAttrDef' attr { name: 'a' " + "type: 'list(int)' default_value { } }")); + + // Empty lists are allowed: + TF_EXPECT_OK( + TestProto("name: 'GoodAttrDef' attr { name: 'a' " + "type: 'list(int)' default_value { list { } } }")); + // Builder should make the same proto: + TF_EXPECT_OK( + TestBuilder(OpDefBuilder("GoodAttrDef").Attr("a: list(int) = []"))); + + // Unless there is a minimum length specified: + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'list(int)' has_minimum: true minimum: 2 " + "default_value { list { } } }"), + "Length for attr 'a' of 0 must be at least minimum 2\n\t in Op " + "'BadAttrDef'"); + ExpectFailure( + TestBuilder(OpDefBuilder("GoodAttrDef").Attr("a: list(bool) >=2 = []")), + "Length for attr 'a' of 0 must be at least minimum 2\n\t in Op " + "'GoodAttrDef'"); + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' type: " + "'list(string)' has_minimum: true minimum: 2 " + "default_value { list { s: ['foo'] } } }"), + "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op " + "'BadAttrDef'"); + ExpectFailure(TestBuilder(OpDefBuilder("GoodAttrDef") + .Attr("a: list(type) >=2 = [DT_STRING]")), + "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op " + "'GoodAttrDef'"); +} + +TEST_F(ValidateOpDefTest, NoRefTypes) { + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrDef").Input("i: float_ref")), + "Illegal use of ref type 'float_ref'. " + "Use 'Ref(type)' instead for input 'i'"); + ExpectFailure( + TestBuilder(OpDefBuilder("BadAttrDef").Attr("T: type = DT_INT32_REF")), + "AttrValue must not have reference type value of int32_ref"); + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrDef") + .Attr("T: list(type) = [DT_STRING_REF]")), + "AttrValue must not have reference type value of string_ref"); +} + +TEST_F(ValidateOpDefTest, BadAttrMin) { + ExpectFailure(TestProto("name: 'BadAttrMin' attr { name: 'a' type: 'string' " + "has_minimum: true minimum: 0 }"), + "minimum for unsupported type string"); + ExpectFailure( + TestProto("name: 'BadAttrMin' attr { name: 'a' type: 'int' default_value " + "{ i: 2 } has_minimum: true minimum: 7 }"), + "Value for attr 'a' of 2 must be at least minimum 7\n\t in Op " + "'BadAttrMin'"); + ExpectFailure( + TestProto("name: 'BadAttrMin' attr { name: 'a' " + "type: 'list(string)' has_minimum: true minimum: -5 }"), + "list type must have a non-negative minimum, not -5"); + TF_EXPECT_OK( + TestProto("name: 'GoodAttrMin' attr { name: 'a' type: 'list(string)' " + "has_minimum: true minimum: 1 }")); + ExpectFailure(TestProto("name: 'NoHasMin' attr { name: 'a' " + "type: 'list(string)' minimum: 3 }"), + "Attr 'a' with has_minimum = false but minimum 3 not equal to " + "default of 0"); +} + +TEST_F(ValidateOpDefTest, BadAttrAllowed) { + // Is in list of allowed types. + TF_EXPECT_OK(TestBuilder( + OpDefBuilder("GoodAttrtude").Attr("x: numbertype = DT_INT32"))); + // Not in list of allowed types. + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: numbertype = DT_STRING")), + "attr 'x' of string is not in the list of allowed values"); + ExpectFailure( + TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: list(realnumbertype) = [DT_COMPLEX64]")), + "attr 'x' of complex64 is not in the list of allowed values"); + ExpectFailure( + TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: list(realnumbertype) = [DT_COMPLEX128]")), + "attr 'x' of complex128 is not in the list of allowed values"); + // Is in list of allowed strings. 
+ TF_EXPECT_OK(TestBuilder( + OpDefBuilder("GoodAttrtude").Attr("x: {'foo', 'bar'} = 'bar'"))); + // Not in list of allowed strings. + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: {'foo', 'bar'} = 'baz'")), + "attr 'x' of \"baz\" is not in the list of allowed values"); + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: list({'foo', 'bar'}) = ['baz']")), + "attr 'x' of \"baz\" is not in the list of allowed values"); + ExpectFailure(TestProto("name: 'BadAttrtude' attr { name: 'a' " + "type: 'string' allowed_values { s: 'not list' } }"), + "with type 'string' when 'list(string)' expected"); + ExpectFailure( + TestProto("name: 'BadAttrtude' attr { name: 'a' " + "type: 'string' allowed_values { list { i: [6] } } }"), + "with type 'list(int)' when 'list(string)' expected"); +} + +TEST_F(ValidateOpDefTest, BadArgType) { + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' " + "type: DT_INT32 } input_arg { name: 'b' }"), + "Missing type for input 'b'"); + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' " + "type: DT_INT32 } output_arg { name: 'b' }"), + "Missing type for output 'b'"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' type: " + "DT_INT32 type_attr: 'x' } attr { name: 'x' type: 'type' }"), + "Exactly one of type, type_attr, type_list_attr must be set for input " + "'a'"); + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_attr: 'x' } attr { name: 'x' type: 'int' }"), + "Attr 'x' used as type_attr for input 'a' has type int"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_attr: 'x' } attr { name: 'x' type: 'list(type)' }"), + "Attr 'x' used as type_attr for input 'a' has type list(type)"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_list_attr: 'x' } attr { name: 'x' type: 'int' }"), + "Attr 'x' used as type_list_attr for input 'a' has type int"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_list_attr: 'x' } attr { name: 'x' type: 'type' }"), + "Attr 'x' used as type_list_attr for input 'a' has type type"); + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_attr: 'x' }"), + "No attr with name 'x' for input 'a'"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' number_attr: 'n' " + "type_attr: 'x' } attr { name: 'x' type: 'list(type)' } " + "attr { name: 'n' type: 'int' has_minimum: true minimum: 1 }"), + "Attr 'x' used as type_attr for input 'a' has type list(type)"); + // But list(type) is fine as the type of an arg without a number_attr: + TF_EXPECT_OK(TestProto( + "name: 'Arg' input_arg { name: 'a' type_list_attr: 'x' } " + "attr { name: 'x' type: 'list(type)' } attr { name: 'n' type: 'int' " + "has_minimum: true minimum: 1 }")); + + // number_attr + TF_EXPECT_OK(TestProto( + "name: 'Arg' input_arg { name: 'a' type: DT_INT32 number_attr: 'n' } " + "attr { name: 'n' type: 'int' has_minimum: true minimum: 0 }")); + + ExpectFailure(TestProto("name: 'Arg' input_arg { name: 'a' type: DT_INT32 " + "number_attr: 'n' }"), + "No attr with name 'n'"); + ExpectFailure( + TestProto( + "name: 'Arg' input_arg { name: 'a' type: " + "DT_INT32 number_attr: 'n' } attr { name: 'n' type: 'string' }"), + "Attr 'n' used as length for input 'a' has type string"); + ExpectFailure( + TestProto("name: 'Arg' input_arg { name: 'a' type: " + "DT_INT32 number_attr: 'n' } attr { name: 'n' type: 'int' }"), + "Attr 'n' used as length for input 'a' must have minimum;"); + 
ExpectFailure( + TestProto("name: 'Arg' input_arg { name: 'a' type: DT_INT32 number_attr: " + "'n' } attr { name: 'n' type: 'int' has_minimum: true minimum: " + "-5 }"), + "Attr 'n' used as length for input 'a' must have minimum >= 0;"); + ExpectFailure( + TestProto("name: 'Arg' input_arg { name: 'a' number_attr: 'n' } attr { " + "name: 'n' type: 'int' has_minimum: true minimum: 2 }"), + "Missing type for input 'a'; in OpDef:"); + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' number_attr: " + "'n' type_list_attr: 'x' } attr { name: 'n' type: " + "'int' has_minimum: true minimum: 1 } attr { name: " + "'x' type: 'list(type)' }"), + "Can't have both number_attr and type_list_attr for input 'a'"); +} + +void ExpectDifferent(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) { + EXPECT_FALSE(AttrDefEqual(a1, a2)); + EXPECT_FALSE(AttrDefEqual(a2, a1)); + EXPECT_NE(AttrDefHash(a1), AttrDefHash(a2)); +} + +TEST(AttrDefUtilTest, EqualAndHash) { + OpDef::AttrDef a = ADef( + "name: 'foo' type: 'string' description: 'cool' has_minimum: true " + "minimum: 2 default_value { i: 2 } allowed_values { i: 5 }"); + + EXPECT_TRUE(AttrDefEqual(a, a)); + EXPECT_EQ(AttrDefHash(a), AttrDefHash(a)); + + ExpectDifferent( + a, + ADef("name: 'FOO' type: 'string' description: 'cool' has_minimum: true " + "minimum: 2 default_value { i: 2 } allowed_values { i: 5 }")); + ExpectDifferent( + a, + ADef("name: 'foo' type: 'int32' description: 'cool' has_minimum: true " + "minimum: 2 default_value { i: 2 } allowed_values { i: 5 }")); + ExpectDifferent( + a, + ADef("name: 'foo' type: 'string' description: 'COOL' has_minimum: true " + "minimum: 2 default_value { i: 2 } allowed_values { i: 5 }")); + ExpectDifferent( + a, + ADef("name: 'foo' type: 'string' description: 'cool' has_minimum: false " + "minimum: 2 default_value { i: 2 } allowed_values { i: 5 }")); + ExpectDifferent( + a, + ADef("name: 'foo' type: 'string' description: 'cool' has_minimum: true " + "minimum: 3 default_value { i: 2 } allowed_values { i: 5 }")); + ExpectDifferent( + a, + ADef("name: 'foo' type: 'string' description: 'cool' has_minimum: true " + "minimum: 2 default_value { i: 3 } allowed_values { i: 5 }")); + ExpectDifferent( + a, + ADef("name: 'foo' type: 'string' description: 'cool' has_minimum: true " + "minimum: 2 default_value { i: 2 } allowed_values { i: 6 }")); + + // Same cases but where default_value and allowed_values are not set + a = ADef( + "name: 'foo' type: 'string' description: 'cool' has_minimum: true " + "minimum: 2"); + EXPECT_TRUE(AttrDefEqual(a, a)); + EXPECT_EQ(AttrDefHash(a), AttrDefHash(a)); + + ExpectDifferent( + a, + ADef("name: 'FOO' type: 'string' description: 'cool' has_minimum: true " + "minimum: 2")); + ExpectDifferent( + a, + ADef("name: 'foo' type: 'int32' description: 'cool' has_minimum: true " + "minimum: 2")); + ExpectDifferent( + a, + ADef("name: 'foo' type: 'string' description: 'COOL' has_minimum: true " + "minimum: 2")); + ExpectDifferent( + a, + ADef("name: 'foo' type: 'string' description: 'cool' has_minimum: false " + "minimum: 2")); + ExpectDifferent( + a, + ADef("name: 'foo' type: 'string' description: 'cool' has_minimum: true " + "minimum: 3")); +} + +protobuf::RepeatedPtrField Rep( + const std::vector& defs) { + protobuf::RepeatedPtrField rep; + for (const OpDef::AttrDef& def : defs) { + rep.Add()->MergeFrom(def); + } + return rep; +} + +void ExpectEqual(const protobuf::RepeatedPtrField& a1, + const protobuf::RepeatedPtrField& a2) { + EXPECT_TRUE(RepeatedAttrDefEqual(a1, a2)); + 
EXPECT_TRUE(RepeatedAttrDefEqual(a2, a1)); + EXPECT_EQ(RepeatedAttrDefHash(a1), RepeatedAttrDefHash(a2)); +} + +void ExpectDifferent(const protobuf::RepeatedPtrField& a1, + const protobuf::RepeatedPtrField& a2) { + EXPECT_FALSE(RepeatedAttrDefEqual(a1, a2)); + EXPECT_FALSE(RepeatedAttrDefEqual(a2, a1)); + EXPECT_NE(RepeatedAttrDefHash(a1), RepeatedAttrDefHash(a2)); +} + +TEST(AttrDefUtilTest, EqualAndHash_Repeated) { + OpDef::AttrDef a1 = ADef( + "name: 'foo1' type: 'string' description: 'cool' has_minimum: true " + "minimum: 2 default_value { i: 2 } allowed_values { i: 5 }"); + + // Different from a1 in name only. + // name is special because AttrDefs are matched by name. + OpDef::AttrDef a2 = ADef( + "name: 'foo2' type: 'string' description: 'cool' has_minimum: true " + "minimum: 2 default_value { i: 2 } allowed_values { i: 5 }"); + + // Different from a1 in "body" only. + OpDef::AttrDef a3 = ADef( + "name: 'foo1' type: 'string' description: 'cool' has_minimum: true " + "minimum: 3 default_value { i: 2 } allowed_values { i: 5 }"); + + // Different in name and "body". + OpDef::AttrDef a4 = ADef( + "name: 'foo3' type: 'string' description: 'cool' has_minimum: true " + "minimum: 3 default_value { i: 2 } allowed_values { i: 5 }"); + + ExpectEqual(Rep({}), Rep({})); + ExpectEqual(Rep({a1}), Rep({a1})); + ExpectEqual(Rep({a1, a2}), Rep({a1, a2})); + ExpectEqual(Rep({a1, a2}), Rep({a2, a1})); + ExpectEqual(Rep({a1, a4}), Rep({a4, a1})); + + ExpectDifferent(Rep({a1}), Rep({})); + ExpectDifferent(Rep({a1}), Rep({a2})); + ExpectDifferent(Rep({a1}), Rep({a3})); + ExpectDifferent(Rep({a1}), Rep({a4})); + ExpectDifferent(Rep({a1}), Rep({a1, a2})); + ExpectDifferent(Rep({a1, a2}), Rep({a1, a4})); + ExpectDifferent(Rep({a1, a2}), Rep({a1, a2, a4})); +} + +void ExpectEqual(const OpDef& o1, const OpDef& o2) { + EXPECT_TRUE(OpDefEqual(o1, o2)); + EXPECT_TRUE(OpDefEqual(o2, o1)); + EXPECT_EQ(OpDefHash(o1), OpDefHash(o2)); +} + +void ExpectDifferent(const OpDef& o1, const OpDef& o2) { + EXPECT_FALSE(OpDefEqual(o1, o2)); + EXPECT_FALSE(OpDefEqual(o2, o1)); + EXPECT_NE(OpDefHash(o1), OpDefHash(o2)); +} + +TEST(OpDefEqualityTest, EqualAndHash) { + string a1 = "attr { name: 'a' type: 'string' } "; + string a2 = "attr { name: 'b' type: 'string' } "; + string a3 = "attr { name: 'c' type: 'int32' } "; + OpDef o1 = FromText(strings::StrCat("name: 'MatMul' ", a1)); + OpDef o2 = FromText(strings::StrCat("name: 'MatMul' ", a2)); + OpDef o3 = FromText(strings::StrCat("name: 'MatMul' ", a1, a2)); + OpDef o4 = FromText(strings::StrCat("name: 'MatMul' ", a2, a1)); + + ExpectEqual(o1, o1); + ExpectEqual(o3, o4); + + ExpectDifferent(o1, o2); + ExpectDifferent(o1, o3); +} + +} // namespace +} // namespace tensorflow diff --git a/op_gen_lib.cc b/op_gen_lib.cc new file mode 100644 index 0000000000000000000000000000000000000000..acff74070da92cc7f298560b7bb81a812924cb0f --- /dev/null +++ b/op_gen_lib.cc @@ -0,0 +1,671 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op_gen_lib.h" + +#include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_gen_overrides.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +string WordWrap(StringPiece prefix, StringPiece str, int width) { + const string indent_next_line = "\n" + Spaces(prefix.size()); + width -= prefix.size(); + string result; + strings::StrAppend(&result, prefix); + + while (!str.empty()) { + if (static_cast(str.size()) <= width) { + // Remaining text fits on one line. + strings::StrAppend(&result, str); + break; + } + auto space = str.rfind(' ', width); + if (space == StringPiece::npos) { + // Rather make a too-long line and break at a space. + space = str.find(' '); + if (space == StringPiece::npos) { + strings::StrAppend(&result, str); + break; + } + } + // Breaking at character at position . + StringPiece to_append = str.substr(0, space); + str.remove_prefix(space + 1); + // Remove spaces at break. + while (to_append.ends_with(" ")) { + to_append.remove_suffix(1); + } + while (str.Consume(" ")) { + } + + // Go on to the next line. + strings::StrAppend(&result, to_append); + if (!str.empty()) strings::StrAppend(&result, indent_next_line); + } + + return result; +} + +bool ConsumeEquals(StringPiece* description) { + if (description->Consume("=")) { + while (description->Consume(" ")) { // Also remove spaces after "=". + } + return true; + } + return false; +} + +// Split `*orig` into two pieces at the first occurrence of `split_ch`. +// Returns whether `split_ch` was found. Afterwards, `*before_split` +// contains the maximum prefix of the input `*orig` that doesn't +// contain `split_ch`, and `*orig` contains everything after the +// first `split_ch`. +static bool SplitAt(char split_ch, StringPiece* orig, + StringPiece* before_split) { + auto pos = orig->find(split_ch); + if (pos == StringPiece::npos) { + *before_split = *orig; + *orig = StringPiece(); + return false; + } else { + *before_split = orig->substr(0, pos); + orig->remove_prefix(pos + 1); + return true; + } +} + +// Does this line start with ":" where "" is +// in multi_line_fields? Sets *colon_pos to the position of the colon. +static bool StartsWithFieldName(StringPiece line, + const std::vector& multi_line_fields) { + StringPiece up_to_colon; + if (!SplitAt(':', &line, &up_to_colon)) return false; + while (up_to_colon.Consume(" ")) + ; // Remove leading spaces. + for (const auto& field : multi_line_fields) { + if (up_to_colon == field) { + return true; + } + } + return false; +} + +static bool ConvertLine(StringPiece line, + const std::vector& multi_line_fields, + string* ml) { + // Is this a field we should convert? + if (!StartsWithFieldName(line, multi_line_fields)) { + return false; + } + // Has a matching field name, so look for "..." after the colon. + StringPiece up_to_colon; + StringPiece after_colon = line; + SplitAt(':', &after_colon, &up_to_colon); + while (after_colon.Consume(" ")) + ; // Remove leading spaces. + if (!after_colon.Consume("\"")) { + // We only convert string fields, so don't convert this line. 
+ return false; + } + auto last_quote = after_colon.rfind('\"'); + if (last_quote == StringPiece::npos) { + // Error: we don't see the expected matching quote, abort the conversion. + return false; + } + StringPiece escaped = after_colon.substr(0, last_quote); + StringPiece suffix = after_colon.substr(last_quote + 1); + // We've now parsed line into ': ""' + + string unescaped; + if (!str_util::CUnescape(escaped, &unescaped, nullptr)) { + // Error unescaping, abort the conversion. + return false; + } + // No more errors possible at this point. + + // Find a string to mark the end that isn't in unescaped. + string end = "END"; + for (int s = 0; unescaped.find(end) != string::npos; ++s) { + end = strings::StrCat("END", s); + } + + // Actually start writing the converted output. + strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end); + if (!suffix.empty()) { + // Output suffix, in case there was a trailing comment in the source. + strings::StrAppend(ml, suffix); + } + strings::StrAppend(ml, "\n"); + return true; +} + +string PBTxtToMultiline(StringPiece pbtxt, + const std::vector& multi_line_fields) { + string ml; + // Probably big enough, since the input and output are about the + // same size, but just a guess. + ml.reserve(pbtxt.size() * (17. / 16)); + StringPiece line; + while (!pbtxt.empty()) { + // Split pbtxt into its first line and everything after. + SplitAt('\n', &pbtxt, &line); + // Convert line or output it unchanged + if (!ConvertLine(line, multi_line_fields, &ml)) { + strings::StrAppend(&ml, line, "\n"); + } + } + return ml; +} + +// Given a single line of text `line` with first : at `colon`, determine if +// there is an "< v = str_util::Split(filenames, ","); + for (const string& f : v) { + TF_RETURN_IF_ERROR(LoadFile(env, f)); + } + return Status::OK(); +} + +Status OpGenOverrideMap::LoadFile(Env* env, const string& filename) { + if (filename.empty()) return Status::OK(); + string contents; + TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents)); + OpGenOverrides all; + protobuf::TextFormat::ParseFromString(contents, &all); + for (const auto& one : all.op()) { + map_[one.name()].reset(new OpGenOverride(one)); + } + return Status::OK(); +} + +static void StringReplace(const string& from, const string& to, string* s) { + // Split *s into pieces delimited by `from`. + std::vector split; + string::size_type pos = 0; + while (pos < s->size()) { + auto found = s->find(from, pos); + if (found == string::npos) { + split.push_back(s->substr(pos)); + break; + } else { + split.push_back(s->substr(pos, found - pos)); + pos = found + from.size(); + if (pos == s->size()) { // handle case where `from` is at the very end. + split.push_back(""); + } + } + } + // Join the pieces back together with a new delimiter. 
+ *s = str_util::Join(split, to.c_str()); +} + +static void RenameInDocs(const string& from, const string& to, OpDef* op_def) { + const string from_quoted = strings::StrCat("`", from, "`"); + const string to_quoted = strings::StrCat("`", to, "`"); + for (int i = 0; i < op_def->input_arg_size(); ++i) { + if (!op_def->input_arg(i).description().empty()) { + StringReplace(from_quoted, to_quoted, + op_def->mutable_input_arg(i)->mutable_description()); + } + } + for (int i = 0; i < op_def->output_arg_size(); ++i) { + if (!op_def->output_arg(i).description().empty()) { + StringReplace(from_quoted, to_quoted, + op_def->mutable_output_arg(i)->mutable_description()); + } + } + for (int i = 0; i < op_def->attr_size(); ++i) { + if (!op_def->attr(i).description().empty()) { + StringReplace(from_quoted, to_quoted, + op_def->mutable_attr(i)->mutable_description()); + } + } + if (!op_def->summary().empty()) { + StringReplace(from_quoted, to_quoted, op_def->mutable_summary()); + } + if (!op_def->description().empty()) { + StringReplace(from_quoted, to_quoted, op_def->mutable_description()); + } +} + +static void RenameInDocs(const string& from, const string& to, + ApiDef* api_def) { + const string from_quoted = strings::StrCat("`", from, "`"); + const string to_quoted = strings::StrCat("`", to, "`"); + for (int i = 0; i < api_def->in_arg_size(); ++i) { + if (!api_def->in_arg(i).description().empty()) { + StringReplace(from_quoted, to_quoted, + api_def->mutable_in_arg(i)->mutable_description()); + } + } + for (int i = 0; i < api_def->out_arg_size(); ++i) { + if (!api_def->out_arg(i).description().empty()) { + StringReplace(from_quoted, to_quoted, + api_def->mutable_out_arg(i)->mutable_description()); + } + } + for (int i = 0; i < api_def->attr_size(); ++i) { + if (!api_def->attr(i).description().empty()) { + StringReplace(from_quoted, to_quoted, + api_def->mutable_attr(i)->mutable_description()); + } + } + if (!api_def->summary().empty()) { + StringReplace(from_quoted, to_quoted, api_def->mutable_summary()); + } + if (!api_def->description().empty()) { + StringReplace(from_quoted, to_quoted, api_def->mutable_description()); + } +} + +const OpGenOverride* OpGenOverrideMap::ApplyOverride(OpDef* op_def) const { + // Look up + const auto iter = map_.find(op_def->name()); + if (iter == map_.end()) return nullptr; + const OpGenOverride& proto = *iter->second; + + // Apply overrides from `proto`. 
+ if (!proto.rename_to().empty()) { + op_def->set_name(proto.rename_to()); + RenameInDocs(proto.name(), proto.rename_to(), op_def); + } + for (const auto& attr_default : proto.attr_default()) { + bool found = false; + for (int i = 0; i < op_def->attr_size(); ++i) { + if (op_def->attr(i).name() == attr_default.name()) { + *op_def->mutable_attr(i)->mutable_default_value() = + attr_default.value(); + found = true; + break; + } + } + if (!found) { + LOG(WARNING) << proto.name() << " can't find attr " << attr_default.name() + << " to override default"; + } + } + for (const auto& attr_rename : proto.attr_rename()) { + bool found = false; + for (int i = 0; i < op_def->attr_size(); ++i) { + if (op_def->attr(i).name() == attr_rename.from()) { + *op_def->mutable_attr(i)->mutable_name() = attr_rename.to(); + found = true; + break; + } + } + if (found) { + RenameInDocs(attr_rename.from(), attr_rename.to(), op_def); + } else { + LOG(WARNING) << proto.name() << " can't find attr " << attr_rename.from() + << " to rename"; + } + } + for (const auto& input_rename : proto.input_rename()) { + bool found = false; + for (int i = 0; i < op_def->input_arg_size(); ++i) { + if (op_def->input_arg(i).name() == input_rename.from()) { + *op_def->mutable_input_arg(i)->mutable_name() = input_rename.to(); + found = true; + break; + } + } + if (found) { + RenameInDocs(input_rename.from(), input_rename.to(), op_def); + } else { + LOG(WARNING) << proto.name() << " can't find input " + << input_rename.from() << " to rename"; + } + } + for (const auto& output_rename : proto.output_rename()) { + bool found = false; + for (int i = 0; i < op_def->output_arg_size(); ++i) { + if (op_def->output_arg(i).name() == output_rename.from()) { + *op_def->mutable_output_arg(i)->mutable_name() = output_rename.to(); + found = true; + break; + } + } + if (found) { + RenameInDocs(output_rename.from(), output_rename.to(), op_def); + } else { + LOG(WARNING) << proto.name() << " can't find output " + << output_rename.from() << " to rename"; + } + } + + return &proto; +} + +namespace { + +// Initializes given ApiDef with data in OpDef. +void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) { + api_def->set_graph_op_name(op_def.name()); + api_def->set_visibility(ApiDef::VISIBLE); + + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(op_def.name()); + if (op_def.has_deprecation()) { + endpoint->set_deprecation_version(op_def.deprecation().version()); + } + + for (const auto& op_in_arg : op_def.input_arg()) { + auto* api_in_arg = api_def->add_in_arg(); + api_in_arg->set_name(op_in_arg.name()); + api_in_arg->set_rename_to(op_in_arg.name()); + api_in_arg->set_description(op_in_arg.description()); + + *api_def->add_arg_order() = op_in_arg.name(); + } + for (const auto& op_out_arg : op_def.output_arg()) { + auto* api_out_arg = api_def->add_out_arg(); + api_out_arg->set_name(op_out_arg.name()); + api_out_arg->set_rename_to(op_out_arg.name()); + api_out_arg->set_description(op_out_arg.description()); + } + for (const auto& op_attr : op_def.attr()) { + auto* api_attr = api_def->add_attr(); + api_attr->set_name(op_attr.name()); + api_attr->set_rename_to(op_attr.name()); + if (op_attr.has_default_value()) { + *api_attr->mutable_default_value() = op_attr.default_value(); + } + api_attr->set_description(op_attr.description()); + } + api_def->set_summary(op_def.summary()); + api_def->set_description(op_def.description()); +} + +// Updates base_arg based on overrides in new_arg. 
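+// Only non-empty fields in new_arg overwrite base_arg; for example
+// (hypothetical values), a new_arg that carries just rename_to: "y" updates
+// the rename while leaving the base description untouched.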
+void MergeArg(ApiDef::Arg* base_arg, const ApiDef::Arg& new_arg) { + if (!new_arg.rename_to().empty()) { + base_arg->set_rename_to(new_arg.rename_to()); + } + if (!new_arg.description().empty()) { + base_arg->set_description(new_arg.description()); + } +} + +// Updates base_attr based on overrides in new_attr. +void MergeAttr(ApiDef::Attr* base_attr, const ApiDef::Attr& new_attr) { + if (!new_attr.rename_to().empty()) { + base_attr->set_rename_to(new_attr.rename_to()); + } + if (new_attr.has_default_value()) { + *base_attr->mutable_default_value() = new_attr.default_value(); + } + if (!new_attr.description().empty()) { + base_attr->set_description(new_attr.description()); + } +} + +// Updates base_api_def based on overrides in new_api_def. +Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) { + // Merge visibility + if (new_api_def.visibility() != ApiDef::DEFAULT_VISIBILITY) { + base_api_def->set_visibility(new_api_def.visibility()); + } + // Merge endpoints + if (new_api_def.endpoint_size() > 0) { + base_api_def->clear_endpoint(); + std::copy( + new_api_def.endpoint().begin(), new_api_def.endpoint().end(), + protobuf::RepeatedFieldBackInserter(base_api_def->mutable_endpoint())); + } + // Merge args + for (const auto& new_arg : new_api_def.in_arg()) { + bool found_base_arg = false; + for (int i = 0; i < base_api_def->in_arg_size(); ++i) { + auto* base_arg = base_api_def->mutable_in_arg(i); + if (base_arg->name() == new_arg.name()) { + MergeArg(base_arg, new_arg); + found_base_arg = true; + break; + } + } + if (!found_base_arg) { + return errors::FailedPrecondition("Argument ", new_arg.name(), + " not defined in base api for ", + base_api_def->graph_op_name()); + } + } + for (const auto& new_arg : new_api_def.out_arg()) { + bool found_base_arg = false; + for (int i = 0; i < base_api_def->out_arg_size(); ++i) { + auto* base_arg = base_api_def->mutable_out_arg(i); + if (base_arg->name() == new_arg.name()) { + MergeArg(base_arg, new_arg); + found_base_arg = true; + break; + } + } + if (!found_base_arg) { + return errors::FailedPrecondition("Argument ", new_arg.name(), + " not defined in base api for ", + base_api_def->graph_op_name()); + } + } + // Merge arg order + if (new_api_def.arg_order_size() > 0) { + // Validate that new arg_order is correct. + if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) { + return errors::FailedPrecondition( + "Invalid number of arguments ", new_api_def.arg_order_size(), " for ", + base_api_def->graph_op_name(), + ". Expected: ", base_api_def->arg_order_size()); + } + if (!std::is_permutation(new_api_def.arg_order().begin(), + new_api_def.arg_order().end(), + base_api_def->arg_order().begin())) { + return errors::FailedPrecondition( + "Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "), + " for ", base_api_def->graph_op_name(), + ". 
All elements in arg_order override must match base arg_order: ", + str_util::Join(base_api_def->arg_order(), ", ")); + } + + base_api_def->clear_arg_order(); + std::copy( + new_api_def.arg_order().begin(), new_api_def.arg_order().end(), + protobuf::RepeatedFieldBackInserter(base_api_def->mutable_arg_order())); + } + // Merge attributes + for (const auto& new_attr : new_api_def.attr()) { + bool found_base_attr = false; + for (int i = 0; i < base_api_def->attr_size(); ++i) { + auto* base_attr = base_api_def->mutable_attr(i); + if (base_attr->name() == new_attr.name()) { + MergeAttr(base_attr, new_attr); + found_base_attr = true; + break; + } + } + if (!found_base_attr) { + return errors::FailedPrecondition("Attribute ", new_attr.name(), + " not defined in base api for ", + base_api_def->graph_op_name()); + } + } + // Merge summary + if (!new_api_def.summary().empty()) { + base_api_def->set_summary(new_api_def.summary()); + } + // Merge description + auto description = new_api_def.description().empty() + ? base_api_def->description() + : new_api_def.description(); + + if (!new_api_def.description_prefix().empty()) { + description = + strings::StrCat(new_api_def.description_prefix(), "\n", description); + } + if (!new_api_def.description_suffix().empty()) { + description = + strings::StrCat(description, "\n", new_api_def.description_suffix()); + } + base_api_def->set_description(description); + return Status::OK(); +} +} // namespace + +ApiDefMap::ApiDefMap(const OpList& op_list) { + for (const auto& op : op_list.op()) { + ApiDef api_def; + InitApiDefFromOpDef(op, &api_def); + map_[op.name()] = api_def; + } +} + +ApiDefMap::~ApiDefMap() {} + +Status ApiDefMap::LoadFileList(Env* env, const std::vector& filenames) { + for (const auto& filename : filenames) { + TF_RETURN_IF_ERROR(LoadFile(env, filename)); + } + return Status::OK(); +} + +Status ApiDefMap::LoadFile(Env* env, const string& filename) { + if (filename.empty()) return Status::OK(); + string contents; + TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents)); + TF_RETURN_IF_ERROR(LoadApiDef(contents)); + return Status::OK(); +} + +Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) { + const string contents = PBTxtFromMultiline(api_def_file_contents); + ApiDefs api_defs; + protobuf::TextFormat::ParseFromString(contents, &api_defs); + for (const auto& api_def : api_defs.op()) { + // Check if the op definition is loaded. If op definition is not + // loaded, then we just skip this ApiDef. + if (map_.find(api_def.graph_op_name()) != map_.end()) { + // Overwrite current api def with data in api_def. 
+      TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def));
+    }
+  }
+  return Status::OK();
+}
+
+void ApiDefMap::UpdateDocs() {
+  for (auto& name_and_api_def : map_) {
+    auto& api_def = name_and_api_def.second;
+    CHECK_GT(api_def.endpoint_size(), 0);
+    const string canonical_name = api_def.endpoint(0).name();
+    if (api_def.graph_op_name() != canonical_name) {
+      RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def);
+    }
+    for (const auto& in_arg : api_def.in_arg()) {
+      if (in_arg.name() != in_arg.rename_to()) {
+        RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def);
+      }
+    }
+    for (const auto& out_arg : api_def.out_arg()) {
+      if (out_arg.name() != out_arg.rename_to()) {
+        RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def);
+      }
+    }
+    for (const auto& attr : api_def.attr()) {
+      if (attr.name() != attr.rename_to()) {
+        RenameInDocs(attr.name(), attr.rename_to(), &api_def);
+      }
+    }
+  }
+}
+
+const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const {
+  return gtl::FindOrNull(map_, name);
+}
+}  // namespace tensorflow
diff --git a/op_gen_lib.h b/op_gen_lib.h
new file mode 100644
index 0000000000000000000000000000000000000000..1ede3af8d7cf8f591ba3927f7fc99d646629109d
--- /dev/null
+++ b/op_gen_lib.h
@@ -0,0 +1,129 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
+#define TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "tensorflow/core/framework/api_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+// Forward declare protos so their symbols can be removed from .so exports
+class OpDef;
+class OpGenOverride;
+
+inline string Spaces(int n) { return string(n, ' '); }
+
+// Wrap prefix + str to be at most width characters, indenting every line
+// after the first by prefix.size() spaces. Intended use case is something
+// like prefix = " Foo(" and str is a list of arguments (terminated by a ")").
+// TODO(josh11b): Option to wrap on ", " instead of " " when possible.
+string WordWrap(StringPiece prefix, StringPiece str, int width);
+
+// Looks for an "=" at the beginning of *description. If found, strips it off
+// (and any following spaces) from *description and return true. Otherwise
+// returns false.
+bool ConsumeEquals(StringPiece* description);
+
+// Convert text-serialized protobufs to/from multiline format.
+string PBTxtToMultiline(StringPiece pbtxt,
+                        const std::vector<string>& multi_line_fields);
+string PBTxtFromMultiline(StringPiece multiline_pbtxt);
+
+// Takes a list of files with OpGenOverrides text protos, and allows you to
+// look up the specific override for any given op.
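+//
+// A rough usage sketch (error handling omitted; file names and the `env` and
+// `op_def` variables below are illustrative only):
+//
+//   OpGenOverrideMap overrides;
+//   TF_CHECK_OK(overrides.LoadFileList(env, "overrides1.pbtxt,overrides2.pbtxt"));
+//   const OpGenOverride* o = overrides.ApplyOverride(&op_def);
+//   if (o != nullptr && o->skip()) { /* leave this op out of the generated API */ }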
+class OpGenOverrideMap {
+ public:
+  OpGenOverrideMap();
+  ~OpGenOverrideMap();
+
+  // `filenames` is a comma-separated list of file names. If an op
+  // is mentioned in more than one file, the last one takes priority.
+  Status LoadFileList(Env* env, const string& filenames);
+
+  // Load a single file. If more than one file is loaded, later ones
+  // take priority for any ops in common.
+  Status LoadFile(Env* env, const string& filename);
+
+  // Look up the override for `*op_def` from the loaded files, and
+  // mutate `*op_def` to reflect the requested changes. Does not apply
+  // 'skip', 'hide', or 'alias' overrides. Caller has to deal with
+  // those since they can't be simulated by mutating `*op_def`.
+  // Returns nullptr if op is not in any loaded file. Otherwise, the
+  // pointer must not be referenced beyond the lifetime of *this or
+  // the next file load.
+  const OpGenOverride* ApplyOverride(OpDef* op_def) const;
+
+ private:
+  std::unordered_map<string, std::unique_ptr<OpGenOverride>> map_;
+};
+
+// Takes a list of files with ApiDefs text protos, and allows you to
+// look up the specific ApiDef for any given op.
+class ApiDefMap {
+ public:
+  // OpList must be a superset of ops of any subsequently loaded
+  // ApiDef.
+  explicit ApiDefMap(const OpList& op_list);
+  ~ApiDefMap();
+
+  // You can call this method multiple times to load multiple
+  // sets of files. Api definitions are merged if the same
+  // op definition is loaded multiple times. Later-loaded
+  // definitions take precedence.
+  // ApiDefs loaded from files must contain a subset of ops defined
+  // in the OpList passed to the constructor.
+  Status LoadFileList(Env* env, const std::vector<string>& filenames);
+
+  // Load a single file. Api definitions are merged if the same
+  // op definition is loaded multiple times. Later-loaded
+  // definitions take precedence.
+  // ApiDefs loaded from file must contain a subset of ops defined
+  // in the OpList passed to the constructor.
+  Status LoadFile(Env* env, const string& filename);
+
+  // Load ApiDefs from string containing ApiDefs text proto.
+  // api_def_file_contents is expected to be in "multiline format".
+  // ApiDefs must contain a subset of ops defined in the OpList
+  // passed to the constructor.
+  Status LoadApiDef(const string& api_def_file_contents);
+
+  // Updates ApiDef docs. For example, if ApiDef renames an argument
+  // or attribute, applies these renames to descriptions as well.
+  // UpdateDocs should only be called once after all ApiDefs are loaded
+  // since it replaces original op names.
+  void UpdateDocs();
+
+  // Look up ApiDef proto based on the given graph op name.
+  // If graph op name is not in this ApiDefMap, returns nullptr.
+  //
+  // Note: Returned ApiDef pointer should stay valid even after calling
+  // Load* functions defined above. Subsequent calls to Load* might modify
+  // returned ApiDef contents, but should never remove the ApiDef itself.
+  const ApiDef* GetApiDef(const string& name) const;
+
+ private:
+  std::unordered_map<string, ApiDef> map_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
diff --git a/op_gen_lib_test.cc b/op_gen_lib_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..857b1c8dbcac66899f98bb4f2ef87f65f7442f6b
--- /dev/null
+++ b/op_gen_lib_test.cc
@@ -0,0 +1,516 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_gen_lib.h" + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +constexpr char kTestOpList[] = R"(op { + name: "testop" + input_arg { + name: "arg_a" + } + input_arg { + name: "arg_b" + } + output_arg { + name: "arg_c" + } + attr { + name: "attr_a" + } + deprecation { + version: 123 + explanation: "foo" + } +)"; + +constexpr char kTestApiDef[] = R"(op { + graph_op_name: "testop" + visibility: VISIBLE + endpoint { + name: "testop1" + } + in_arg { + name: "arg_a" + } + in_arg { + name: "arg_b" + } + out_arg { + name: "arg_c" + } + attr { + name: "attr_a" + } + summary: "Mock op for testing." + description: <DebugString()); +} + +TEST(OpGenLibTest, ApiDefLoadSingleApiDef) { + const string expected_api_def = R"(op { + graph_op_name: "testop" + visibility: VISIBLE + endpoint { + name: "testop1" + } + in_arg { + name: "arg_a" + rename_to: "arg_a" + } + in_arg { + name: "arg_b" + rename_to: "arg_b" + } + out_arg { + name: "arg_c" + rename_to: "arg_c" + } + attr { + name: "attr_a" + rename_to: "attr_a" + } + summary: "Mock op for testing." + description: "Description for the\ntestop." + arg_order: "arg_a" + arg_order: "arg_b" +} +)"; + OpList op_list; + protobuf::TextFormat::ParseFromString(kTestOpList, &op_list); // NOLINT + + ApiDefMap api_map(op_list); + TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef)); + const auto* api_def = api_map.GetApiDef("testop"); + EXPECT_EQ(1, api_def->endpoint_size()); + EXPECT_EQ("testop1", api_def->endpoint(0).name()); + + ApiDefs api_defs; + *api_defs.add_op() = *api_def; + EXPECT_EQ(expected_api_def, api_defs.DebugString()); +} + +TEST(OpGenLibTest, ApiDefOverrideVisibility) { + const string api_def1 = R"( +op { + graph_op_name: "testop" + endpoint { + name: "testop2" + } +} +)"; + const string api_def2 = R"( +op { + graph_op_name: "testop" + visibility: HIDDEN + endpoint { + name: "testop2" + } +} +)"; + OpList op_list; + protobuf::TextFormat::ParseFromString(kTestOpList, &op_list); // NOLINT + + ApiDefMap api_map(op_list); + TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef)); + auto* api_def = api_map.GetApiDef("testop"); + EXPECT_EQ(ApiDef::VISIBLE, api_def->visibility()); + + // Loading ApiDef with default visibility should + // keep current visibility. + TF_CHECK_OK(api_map.LoadApiDef(api_def1)); + EXPECT_EQ(ApiDef::VISIBLE, api_def->visibility()); + + // Loading ApiDef with non-default visibility, + // should update visibility. 
+ TF_CHECK_OK(api_map.LoadApiDef(api_def2)); + EXPECT_EQ(ApiDef::HIDDEN, api_def->visibility()); +} + +TEST(OpGenLibTest, ApiDefOverrideEndpoints) { + const string api_def1 = R"( +op { + graph_op_name: "testop" + endpoint { + name: "testop2" + } +} +)"; + OpList op_list; + protobuf::TextFormat::ParseFromString(kTestOpList, &op_list); // NOLINT + + ApiDefMap api_map(op_list); + TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef)); + auto* api_def = api_map.GetApiDef("testop"); + ASSERT_EQ(1, api_def->endpoint_size()); + EXPECT_EQ("testop1", api_def->endpoint(0).name()); + + TF_CHECK_OK(api_map.LoadApiDef(api_def1)); + ASSERT_EQ(1, api_def->endpoint_size()); + EXPECT_EQ("testop2", api_def->endpoint(0).name()); +} + +TEST(OpGenLibTest, ApiDefOverrideArgs) { + const string api_def1 = R"( +op { + graph_op_name: "testop" + in_arg { + name: "arg_a" + rename_to: "arg_aa" + } + out_arg { + name: "arg_c" + rename_to: "arg_cc" + } + arg_order: "arg_b" + arg_order: "arg_a" +} +)"; + OpList op_list; + protobuf::TextFormat::ParseFromString(kTestOpList, &op_list); // NOLINT + + ApiDefMap api_map(op_list); + TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef)); + TF_CHECK_OK(api_map.LoadApiDef(api_def1)); + const auto* api_def = api_map.GetApiDef("testop"); + ASSERT_EQ(2, api_def->in_arg_size()); + EXPECT_EQ("arg_aa", api_def->in_arg(0).rename_to()); + // 2nd in_arg is not renamed + EXPECT_EQ("arg_b", api_def->in_arg(1).rename_to()); + + ASSERT_EQ(1, api_def->out_arg_size()); + EXPECT_EQ("arg_cc", api_def->out_arg(0).rename_to()); + + ASSERT_EQ(2, api_def->arg_order_size()); + EXPECT_EQ("arg_b", api_def->arg_order(0)); + EXPECT_EQ("arg_a", api_def->arg_order(1)); +} + +TEST(OpGenLibTest, ApiDefOverrideDescriptions) { + const string api_def1 = R"( +op { + graph_op_name: "testop" + summary: "New summary" + description: <summary()); + EXPECT_EQ("A\nNew description\nZ", api_def->description()); + EXPECT_EQ("", api_def->description_prefix()); + EXPECT_EQ("", api_def->description_suffix()); + + TF_CHECK_OK(api_map.LoadApiDef(api_def2)); + EXPECT_EQ("B\nA\nNew description\nZ\nY", api_def->description()); + EXPECT_EQ("", api_def->description_prefix()); + EXPECT_EQ("", api_def->description_suffix()); +} + +TEST(OpGenLibTest, ApiDefInvalidOpInOverride) { + const string api_def1 = R"( +op { + graph_op_name: "different_testop" + endpoint { + name: "testop2" + } +} +)"; + OpList op_list; + protobuf::TextFormat::ParseFromString(kTestOpList, &op_list); // NOLINT + + ApiDefMap api_map(op_list); + TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef)); + TF_CHECK_OK(api_map.LoadApiDef(api_def1)); + ASSERT_EQ(nullptr, api_map.GetApiDef("different_testop")); +} + +TEST(OpGenLibTest, ApiDefInvalidArgOrder) { + const string api_def1 = R"( +op { + graph_op_name: "testop" + arg_order: "arg_a" + arg_order: "unexpected_arg" +} +)"; + + const string api_def2 = R"( +op { + graph_op_name: "testop" + arg_order: "arg_a" +} +)"; + + const string api_def3 = R"( +op { + graph_op_name: "testop" + arg_order: "arg_a" + arg_order: "arg_a" +} +)"; + + OpList op_list; + protobuf::TextFormat::ParseFromString(kTestOpList, &op_list); // NOLINT + ApiDefMap api_map(op_list); + TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef)); + + // Loading with incorrect arg name in arg_order should fail. + auto status = api_map.LoadApiDef(api_def1); + ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); + + // Loading with incorrect number of args in arg_order should fail. 
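+  // (api_def2 above lists only one of testop's two arguments.)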
+ status = api_map.LoadApiDef(api_def2); + ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); + + // Loading with the same argument twice in arg_order should fail. + status = api_map.LoadApiDef(api_def3); + ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); +} + +TEST(OpGenLibTest, ApiDefUpdateDocs) { + const string op_list1 = R"(op { + name: "testop" + input_arg { + name: "arg_a" + description: "`arg_a`, `arg_c`, `attr_a`, `testop`" + } + output_arg { + name: "arg_c" + description: "`arg_a`, `arg_c`, `attr_a`, `testop`" + } + attr { + name: "attr_a" + description: "`arg_a`, `arg_c`, `attr_a`, `testop`" + } + description: "`arg_a`, `arg_c`, `attr_a`, `testop`" +} +)"; + + const string api_def1 = R"( +op { + graph_op_name: "testop" + endpoint { + name: "testop2" + } + in_arg { + name: "arg_a" + rename_to: "arg_aa" + } + out_arg { + name: "arg_c" + rename_to: "arg_cc" + description: "New description: `arg_a`, `arg_c`, `attr_a`, `testop`" + } + attr { + name: "attr_a" + rename_to: "attr_aa" + } +} +)"; + OpList op_list; + protobuf::TextFormat::ParseFromString(op_list1, &op_list); // NOLINT + ApiDefMap api_map(op_list); + TF_CHECK_OK(api_map.LoadApiDef(api_def1)); + api_map.UpdateDocs(); + + const string expected_description = + "`arg_aa`, `arg_cc`, `attr_aa`, `testop2`"; + EXPECT_EQ(expected_description, api_map.GetApiDef("testop")->description()); + EXPECT_EQ(expected_description, + api_map.GetApiDef("testop")->in_arg(0).description()); + EXPECT_EQ("New description: " + expected_description, + api_map.GetApiDef("testop")->out_arg(0).description()); + EXPECT_EQ(expected_description, + api_map.GetApiDef("testop")->attr(0).description()); +} +} // namespace +} // namespace tensorflow diff --git a/op_gen_overrides.proto b/op_gen_overrides.proto new file mode 100644 index 0000000000000000000000000000000000000000..8e66d39a7c7f4a9ff05c91f46a11446e18bc1aed --- /dev/null +++ b/op_gen_overrides.proto @@ -0,0 +1,67 @@ +// Defines the text format for adding per-op overrides for client +// language op code generators. + +syntax = "proto3"; + +package tensorflow; +import "tensorflow/core/framework/attr_value.proto"; + +// Used to override the default API & behavior in the generated code +// for client languages, from what you would get from the OpDef alone. +// This is so we can evolve the API while remaining backwards +// compatible when interpretting old graphs. Overrides go in an +// "op_gen_overrides.pbtxt" file with a text-format OpGenOverrides +// message. Right now these only apply to the C++ API. +// TODO(josh11b): In the future there will be a common set of overrides +// and per-client-language overrides. +// +// WARNING: Be *very* careful using these features -- these overrides +// can change the semantics of existing code. These changes may need +// to wait until a major release of TensorFlow to avoid breaking our +// compatibility promises. +message OpGenOverride { + // Name of the op to apply overrides to. + string name = 1; + + // Do not include this op in the generated API. + // If `skip` is true, all other overrides are ignored for this op. + bool skip = 2; + + // Hide this op by putting it into an internal namespace (or whatever + // is appropriate in the target language). + bool hide = 3; + + // Use a different name in the API than the op's name. Note that + // the op's name in `backticks` will also be replaced in the docs. 
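+  // For example (hypothetical op name):
+  //   op { name: "FooBar" rename_to: "foo_bar" }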
+ string rename_to = 4; + + // Create *additional* API endpoints with different names (contrast + // with rename_to, which affects the original name). + repeated string alias = 5; + + // Map the name of an attr to a new default value to use. This + // default will be used when creating new graphs, as opposed to the + // default in the OpDef, which will be used when interpreting old + // GraphDefs. If this attr is also renamed (using attr_rename + // below), use the original name of the attr. + message AttrDefault { + string name = 1; + AttrValue value = 2; + } + repeated AttrDefault attr_default = 6; + + // Change the name used to access attrs/inputs/outputs in the API + // from what is used in the GraphDef. Note that these names in + // `backticks` will also be replaced in the docs. + message Rename { + string from = 1; + string to = 2; + } + repeated Rename attr_rename = 7; + repeated Rename input_rename = 8; + repeated Rename output_rename = 9; +} + +message OpGenOverrides { + repeated OpGenOverride op = 1; +} diff --git a/op_kernel.cc b/op_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..4d410809e77bd6ba7cd24f78c0ef2f97fa54e588 --- /dev/null +++ b/op_kernel.cc @@ -0,0 +1,1202 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include +#include + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/graph.pb_text.h" +#include "tensorflow/core/framework/kernel_def.pb_text.h" +#include "tensorflow/core/framework/log_memory.h" +#include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace { + +Status MatchSignatureHelper(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs, + const DataTypeSlice inputs, + const DataTypeSlice outputs) { + bool signature_mismatch = false; + + if (inputs.size() != expected_inputs.size()) signature_mismatch = true; + for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) { + if (!TypesCompatible(expected_inputs[i], inputs[i])) { + signature_mismatch = true; + } + } + + if (outputs.size() != expected_outputs.size()) signature_mismatch = true; + for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) { + if (!TypesCompatible(expected_outputs[i], outputs[i])) { + signature_mismatch = true; + } + } + + if (signature_mismatch) { + return errors::InvalidArgument( + "Signature mismatch, have: ", DataTypeSliceString(inputs), "->", + DataTypeSliceString(outputs), + " expected: ", DataTypeSliceString(expected_inputs), "->", + DataTypeSliceString(expected_outputs)); + } + return Status::OK(); +} + +} // namespace + +// OpKernel ------------------------------------------------------------------ + +OpKernel::OpKernel(OpKernelConstruction* context) + : def_(new NodeDef(context->def())), + input_types_(context->input_types().begin(), + context->input_types().end()), + input_memory_types_(context->input_memory_types().begin(), + context->input_memory_types().end()), + output_types_(context->output_types().begin(), + context->output_types().end()), + output_memory_types_(context->output_memory_types().begin(), + context->output_memory_types().end()), + graph_def_version_(context->graph_def_version()), + is_internal_(StringPiece(type_string()).starts_with("_")), + input_name_map_(context->num_inputs()), + output_name_map_(context->num_outputs()) { + OP_REQUIRES_OK(context, + NameRangesForNode(*def_, *context->op_def_, &input_name_map_, + &output_name_map_)); + OP_REQUIRES_OK(context, CheckOpDeprecation(*context->op_def_, + context->graph_def_version())); + + // Kernels executing on GPU/SYCL tie very few resources on the CPU where the + // scheduler runs: we consider them as inexpensive. 
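+  // (The runtime reads this flag back through OpKernel::IsExpensive().)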
+ expensive_ = context->device_type() != DeviceType(DEVICE_GPU) && context->device_type() != DeviceType(DEVICE_SYCL); +} + +OpKernel::~OpKernel() {} + +const string& OpKernel::name() const { return def_->name(); } +const string& OpKernel::type_string() const { return def_->op(); } +const string& OpKernel::requested_device() const { return def_->device(); } +const string& OpKernel::requested_input(int i) const { return def_->input(i); } + +Status OpKernel::InputRange(StringPiece input_name, int* start, + int* stop) const { + const auto result = input_name_map_.find(input_name.ToString()); + if (result == input_name_map_.end()) { + return errors::InvalidArgument("Unknown input name: ", input_name); + } else { + *start = result->second.first; + *stop = result->second.second; + return Status::OK(); + } +} + +Status OpKernel::OutputRange(StringPiece output_name, int* start, + int* stop) const { + const auto result = output_name_map_.find(output_name.ToString()); + if (result == output_name_map_.end()) { + return errors::InvalidArgument("Unknown output name: ", output_name); + } else { + *start = result->second.first; + *stop = result->second.second; + return Status::OK(); + } +} + +Status OpKernel::MakeShape(const Tensor& shape, TensorShape* out) const { + if (!IsLegacyVector(shape.shape())) { + return errors::InvalidArgument( + "shape must be a vector of {int32,int64}, got shape ", + shape.shape().DebugString()); + } + if (shape.dtype() == DataType::DT_INT32) { + auto vec = shape.flat(); + return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out); + } else if (shape.dtype() == DataType::DT_INT64) { + auto vec = shape.flat(); + return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out); + } else { + return errors::InvalidArgument("shape must be a vector of {int32,int64}."); + } +} + +void AsyncOpKernel::Compute(OpKernelContext* context) { + Notification n; + ComputeAsync(context, [&n]() { n.Notify(); }); + n.WaitForNotification(); +} + +// PersistentTensor ---------------------------------------------------------- + +Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) { + // the caller has to have a valid context + CHECK(context); + return &tensor_; +} + +Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) { + context->NotifyUseOfPersistentTensor(tensor_); + return &tensor_; +} + +// OpKernelConstruction ------------------------------------------------------ + +OpKernelConstruction::OpKernelConstruction( + DeviceType device_type, DeviceBase* device, Allocator* allocator, + const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib, + const DataTypeSlice& input_types, const MemoryTypeSlice& input_memory_types, + const DataTypeSlice& output_types, + const MemoryTypeSlice& output_memory_types, int graph_def_version, + Status* status) + : device_type_(std::move(device_type)), + device_(device), + allocator_(allocator), + def_(node_def), + op_def_(op_def), + flib_(flib), + input_types_(input_types), + input_memory_types_(input_memory_types), + output_types_(output_types), + output_memory_types_(output_memory_types), + graph_def_version_(graph_def_version), + status_(status) {} + +bool OpKernelConstruction::HasAttr(StringPiece attr_name) const { + return HasNodeAttr(def(), attr_name); +} + +void OpKernelConstruction::SetStatus(const Status& status) { + status_->Update(status); +} + +Status OpKernelConstruction::MatchSignature( + const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) { + return 
MatchSignatureHelper(expected_inputs, expected_outputs, input_types_, + output_types_); +} + +Status OpKernelConstruction::allocate_temp(DataType type, + const TensorShape& shape, + Tensor* out_temp) { + AllocationAttributes attr; + attr.allocation_will_be_logged = true; + Tensor new_temp(allocator_, type, shape, attr); + + if (!new_temp.IsInitialized()) { + return errors::ResourceExhausted( + "OOM when allocating temporary tensor with shape", shape.DebugString()); + } + if (LogMemory::IsEnabled()) { + LogMemory::RecordTensorAllocation( + def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); + } + *out_temp = new_temp; + return Status::OK(); +} + +Status OpKernelConstruction::allocate_persistent( + DataType type, const TensorShape& shape, PersistentTensor* out_persistent, + Tensor** out_tensor) { + // for now just do the same thing as allocate_temp + // TODO(misard) add specific memory tracking for persistent tensors + Tensor persistent; + Status s = allocate_temp(type, shape, &persistent); + if (!s.ok()) { + return s; + } + *out_persistent = PersistentTensor(persistent); + Tensor* allocated = out_persistent->AccessTensor(this); + if (out_tensor) { + *out_tensor = allocated; + } + return s; +} + +// OpKernelContext ----------------------------------------------------------- + +OpKernelContext::OpKernelContext(Params* params) + : OpKernelContext( + params, static_cast(params->op_kernel->output_types().size())) {} + +OpKernelContext::OpKernelContext(Params* params, int num_outputs) + : params_(params), + outputs_(num_outputs), + host_temp_memory_size_(0), + device_temp_memory_size_(0), + host_persistent_memory_allocated_(0), + device_persistent_memory_allocated_(0) { + Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); + params_->ensure_eigen_gpu_device(); + params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device, + params_->op_device_context, + eigen_gpu_allocator); + if (params_->record_tensor_accesses) { + referenced_tensors_.Init(); + } +} + +OpKernelContext::~OpKernelContext() { + for (TensorValue& value : outputs_) { + if (!value.is_ref()) { + delete value.tensor; + } + } + if (params_->record_tensor_accesses) referenced_tensors_.Destroy(); +} + +Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) { + Allocator* allocator = + params_->device->GetStepAllocator(attr, resource_manager()); + if (track_allocations()) { + mutex_lock lock(mu_); + for (const auto& wrapped : wrapped_allocators_) { + if (wrapped.first == allocator) { + return wrapped.second; + } + } + TrackingAllocator* wrapped_allocator = + new TrackingAllocator(allocator, params_->track_allocations); + wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator)); + return wrapped_allocator; + } else { + return allocator; + } +} + +void OpKernelContext::SetStatus(const Status& status) { + status_.Update(status); +} + +void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) { + mutex_lock l(mu_); + // Keep a reference to the underlying memory around. 
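+  // referenced_tensors_ is only initialized (and later destroyed) when
+  // params_->record_tensor_accesses is set; see the constructor and
+  // destructor above.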
+ referenced_tensors_->Add(tensor); +} + +Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was " + "expected"); + } + if (input_is_ref(start)) { + return errors::InvalidArgument("OpKernel used ref input name '", name, + "' when non-ref input was expected"); + } + *tensor = (*params_->inputs)[start].tensor; + record_tensor_reference(**tensor); + return Status::OK(); +} + +Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was " + "expected"); + } + const TensorValue& value((*params_->inputs)[start]); + if (value.is_ref()) { + *dtype = MakeRefType(value->dtype()); + } else { + *dtype = value->dtype(); + } + return Status::OK(); +} + +Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was expected"); + } + *out_mutex = input_ref_mutex(start); + return Status::OK(); +} + +const Tensor& OpKernelContext::input(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + DCHECK(!input_is_ref(index)); + const Tensor& tensor = *((*params_->inputs)[index].tensor); + record_tensor_reference(tensor); + return tensor; +} + +Tensor OpKernelContext::mutable_input(int index, bool lock_held) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + DCHECK(input_is_ref(index)); + // return a copy of the Ref acquired while holding the mutex + if (lock_held) { + Tensor& tensor = *((*params_->inputs)[index].tensor); + record_tensor_reference(tensor); + return tensor; + } else { + mutex_lock l(*input_ref_mutex(index)); + Tensor& tensor = *((*params_->inputs)[index].tensor); + record_tensor_reference(tensor); + return tensor; + } +} + +void OpKernelContext::replace_ref_input(int index, const Tensor& tensor, + bool lock_held) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + DCHECK(input_is_ref(index)); + // should only modify the tensor while holding the mutex + if (lock_held) { + *(*params_->inputs)[index].tensor = tensor; + } else { + mutex_lock l(*input_ref_mutex(index)); + *(*params_->inputs)[index].tensor = tensor; + } + record_tensor_reference(tensor); +} + +void OpKernelContext::forward_ref_input_to_ref_output(int input_index, + int output_index) { + DCHECK_GE(input_index, 0); + DCHECK_LT(input_index, num_inputs()); + DCHECK(input_is_ref(input_index)); + set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref, + (*params_->inputs)[input_index].tensor); +} + +bool OpKernelContext::forward_input_to_output_with_shape( + int input_index, int output_index, const TensorShape& output_shape, + Tensor** output) { + const auto output_attr = params_->output_attr_array == nullptr + ? 
AllocatorAttributes()
+                                 : output_alloc_attr(output_index);
+  std::unique_ptr<Tensor> new_tensor = forward_input(
+      input_index, expected_output_dtype(output_index), output_shape,
+      output_memory_type(output_index), output_attr);
+  if (new_tensor != nullptr) {
+    // Transfer ownership to the output slot in OpKernelContext.
+    outputs_[output_index] = TensorValue(new_tensor.release());
+    *output = outputs_[output_index].tensor;
+    return true;
+  } else {
+    return false;
+  }
+}
+
+Status OpKernelContext::forward_input_to_output_with_shape(
+    StringPiece input_name, StringPiece output_name,
+    const TensorShape& output_shape, Tensor** output) {
+  int input_index, output_index, stop;
+  TF_RETURN_IF_ERROR(
+      params_->op_kernel->InputRange(input_name, &input_index, &stop));
+  if (stop != input_index + 1) {
+    return errors::InvalidArgument("OpKernel used list-valued input name '",
+                                   input_name,
+                                   "' when single-valued input was "
+                                   "expected");
+  }
+  TF_RETURN_IF_ERROR(
+      params_->op_kernel->OutputRange(output_name, &output_index, &stop));
+  if (stop != output_index + 1) {
+    return errors::InvalidArgument("OpKernel used list-valued output name '",
+                                   output_name,
+                                   "' when single-valued output was "
+                                   "expected");
+  }
+  if (!forward_input_to_output_with_shape(input_index, output_index,
+                                          output_shape, output)) {
+    return errors::FailedPrecondition("OpKernel could not forward input '",
+                                      input_name, "' to output '", output_name);
+  }
+  return Status::OK();
+}
+
+std::unique_ptr<Tensor> OpKernelContext::forward_input(
+    int input_index, DataType output_dtype, const TensorShape& output_shape,
+    MemoryType output_memory_type, const AllocatorAttributes& output_attr) {
+  DCHECK_GE(input_index, 0);
+  DCHECK_LT(input_index, num_inputs());
+  const TensorValue& input = (*params_->inputs)[input_index];
+  // Check that input tensor exists, is not a ref, and has no other consumers.
+  if (input.tensor == nullptr || input.is_ref() || !input->RefCountIsOne()) {
+    return nullptr;
+  }
+  // Check that input type matches.
+  if (input_dtype(input_index) != output_dtype) {
+    return nullptr;
+  }
+  // Check that the input and output sizes are compatible.
+  if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
+    return nullptr;
+  }
+  // Check that input and output memory types match, i.e.
+  // that they either both live in host or both live in device memory.
+  if (input_memory_type(input_index) != output_memory_type) {
+    return nullptr;
+  }
+  // Check that output allocator attributes are not more restrictive than
+  // input allocator attributes.
+  const auto input_attr = params_->input_alloc_attrs == nullptr
+                              ? AllocatorAttributes()
+                              : input_alloc_attr(input_index);
+  if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
+    return nullptr;
+  }
+  // TODO(rmlarsen): Use MakeUnique here. There is already a copy in
+  // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of
+  // general cleanup of ownership in this code.
+ std::unique_ptr output_tensor(new Tensor()); + CHECK(output_tensor->CopyFrom(*input.tensor, output_shape)); + return output_tensor; +} + +Status OpKernelContext::forward_input_or_allocate_temp( + gtl::ArraySlice candidate_input_indices, DataType type, + const TensorShape& shape, const AllocatorAttributes& allocator_attr, + Tensor* out_temp) { + for (int input_index : candidate_input_indices) { + std::unique_ptr new_tensor = + forward_input(input_index, type, shape, DEVICE_MEMORY, allocator_attr); + if (new_tensor != nullptr) { + *out_temp = std::move(*new_tensor); + return Status::OK(); + } + } + return allocate_temp(type, shape, out_temp, allocator_attr); +} + +void OpKernelContext::delete_ref_input(int index, bool lock_held) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + DCHECK(input_is_ref(index)); + // should only modify the tensor while holding the mutex + if (lock_held) { + delete (*params_->inputs)[index].tensor; + } else { + mutex_lock l(*input_ref_mutex(index)); + delete (*params_->inputs)[index].tensor; + } +} + +Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor, + bool lock_held) { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was expected"); + } + if (!input_is_ref(start)) { + return errors::InvalidArgument("OpKernel used non-ref input name '", name, + "' when ref input was expected"); + } + // return a copy of the Ref acquired while holding the mutex + if (lock_held) { + *tensor = *(*params_->inputs)[start].tensor; + } else { + mutex_lock l(*input_ref_mutex(start)); + *tensor = *(*params_->inputs)[start].tensor; + } + record_tensor_reference(*tensor); + return Status::OK(); +} + +Status OpKernelContext::replace_ref_input(StringPiece name, + const Tensor& tensor, + bool lock_held) { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was expected"); + } + if (!input_is_ref(start)) { + return errors::InvalidArgument("OpKernel used immutable input name '", name, + "' when ref input was expected"); + } + replace_ref_input(start, tensor, lock_held); + return Status::OK(); +} + +Status OpKernelContext::input_list(StringPiece name, OpInputList* list) { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); + *list = OpInputList(this, start, stop); + return Status::OK(); +} + +Status OpKernelContext::mutable_input_list(StringPiece name, + OpMutableInputList* list) { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); + *list = OpMutableInputList(this, start, stop); + return Status::OK(); +} + +Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); + *list = OpOutputList(this, start, stop); + return Status::OK(); +} + +Status OpKernelContext::allocate_output(int index, const TensorShape& shape, + Tensor** output) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + AllocatorAttributes attr = output_alloc_attr(index); + return allocate_output(index, shape, output, attr); +} + +Status OpKernelContext::allocate_output(StringPiece name, + const TensorShape& shape, + Tensor** tensor) { + 
int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + return allocate_output(start, shape, tensor); +} + +Status OpKernelContext::allocate_output(StringPiece name, + const TensorShape& shape, + Tensor** tensor, + AllocatorAttributes attr) { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + return allocate_output(start, shape, tensor, attr); +} + +Status OpKernelContext::allocate_tensor( + DataType type, const TensorShape& shape, Tensor* out_tensor, + AllocatorAttributes attr, const AllocationAttributes& allocation_attr) { + Allocator* a = get_allocator(attr); + AllocationAttributes logged_attr(allocation_attr); + logged_attr.allocation_will_be_logged = true; + Tensor new_tensor(a, type, shape, logged_attr); + + if (!new_tensor.IsInitialized()) { + return errors::ResourceExhausted( + "OOM when allocating tensor with shape", shape.DebugString(), + " and type ", DataTypeString(type), " on ", params_->device->name(), + " by allocator ", a->Name()); + } + if (params_->log_memory) { + LogMemory::RecordTensorAllocation(params_->op_kernel->name(), + params_->step_id, new_tensor); + } + record_tensor_reference(new_tensor); + *out_tensor = std::move(new_tensor); + return Status::OK(); +} + +Status OpKernelContext::allocate_output(int index, const TensorShape& shape, + Tensor** output, + AllocatorAttributes attr) { + DCHECK_GE(index, 0); + DCHECK_LT(index, outputs_.size()); + const DataType type = params_->op_kernel->output_type(index); + DCHECK(!IsRefType(type)); + DCHECK(mutable_output(index) == nullptr); + Tensor* output_tensor = new Tensor(); + Status s = allocate_tensor(type, shape, output_tensor, attr); + if (s.ok()) { + outputs_[index] = TensorValue(output_tensor); + *output = outputs_[index].tensor; + } + return s; +} + +Status OpKernelContext::allocate_temp( + DataType type, const TensorShape& shape, Tensor* out_temp, + AllocatorAttributes allocator_attr, + const AllocationAttributes& allocation_attr) { + Status s = + allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr); + if (track_allocations() && out_temp->TotalBytes() > 0) { + Allocator* a = get_allocator(allocator_attr); + if (a->TracksAllocationSizes()) { + int64 alloc_size = + a->AllocatedSize(const_cast(out_temp->tensor_data().data())); + if (allocate_on_host(allocator_attr)) { + record_host_temp_memory_size(alloc_size); + } else { + record_device_temp_memory_size(alloc_size); + } + } + } + return s; +} + +Status OpKernelContext::allocate_persistent(DataType type, + const TensorShape& shape, + PersistentTensor* out_persistent, + Tensor** out_tensor, + AllocatorAttributes attr) { + Tensor persistent; + Status s = allocate_tensor(type, shape, &persistent, attr); + if (s.ok()) { + *out_persistent = PersistentTensor(persistent); + if (out_tensor) { + *out_tensor = out_persistent->AccessTensor(this); + } + } + return s; +} + +Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + 
"' when single-valued output was " + "expected"); + } + set_output(start, tensor); + return Status::OK(); +} + +void OpKernelContext::set_output(int index, const Tensor& tensor) { + DCHECK_GE(index, 0); + DCHECK_LT(index, outputs_.size()); + DCHECK(!IsRefType(params_->op_kernel->output_type(index))); + DCHECK_EQ(mutable_output(index), nullptr); + record_tensor_reference(tensor); + outputs_[index] = TensorValue(new Tensor(tensor)); +} + +void OpKernelContext::set_output_ref(int index, mutex* mu, + Tensor* tensor_for_ref) { + DCHECK_GE(index, 0); + DCHECK_LT(index, outputs_.size()); + DCHECK(IsRefType(params_->op_kernel->output_type(index))); + record_tensor_reference(*tensor_for_ref); + outputs_[index] = TensorValue(mu, tensor_for_ref); +} + +Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu, + Tensor* tensor_for_ref) { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + set_output_ref(start, mu, tensor_for_ref); + return Status::OK(); +} + +Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + *tensor = mutable_output(start); + return Status::OK(); +} + +Status OpKernelContext::release_output(StringPiece name, TensorValue* value) { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + *value = release_output(start); + return Status::OK(); +} + +bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { + const auto& inputs = *params_->inputs; + for (size_t i = 1; i < inputs.size(); ++i) { + if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) { + SetStatus(errors::InvalidArgument( + "Inputs to operation ", op->name(), " of type ", op->type_string(), + " must have the same size and shape. Input 0: ", + inputs[0]->shape().DebugString(), " != input ", i, ": ", + inputs[i]->shape().DebugString())); + return false; + } + } + return true; +} + +Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs) { + DataTypeVector inputs; + for (const TensorValue& t : *params_->inputs) { + inputs.push_back(t.is_ref() ? 
MakeRefType(t->dtype()) : t->dtype());
+  }
+  DataTypeVector outputs = params_->op_kernel->output_types();
+  return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
+                              outputs);
+}
+
+bool OpKernelContext::allocate_on_host(AllocatorAttributes alloc_attr) const {
+  return alloc_attr.on_host() || device()->attributes().device_type() == "CPU";
+}
+
+void OpKernelContext::record_host_persistent_memory_allocation(int64 size,
+                                                               int64 alloc_id) {
+  host_persistent_memory_allocated_ += size;
+  host_persistent_alloc_ids_.push_back(alloc_id);
+}
+
+void OpKernelContext::record_device_persistent_memory_allocation(
+    int64 size, int64 alloc_id) {
+  device_persistent_memory_allocated_ += size;
+  device_persistent_alloc_ids_.push_back(alloc_id);
+}
+
+std::vector<int64> OpKernelContext::host_persistent_alloc_ids() const {
+  return std::vector<int64>(host_persistent_alloc_ids_.begin(),
+                            host_persistent_alloc_ids_.end());
+}
+
+std::vector<int64> OpKernelContext::device_persistent_alloc_ids() const {
+  return std::vector<int64>(device_persistent_alloc_ids_.begin(),
+                            device_persistent_alloc_ids_.end());
+}
+
+// OpKernel registration ------------------------------------------------------
+
+struct KernelRegistration {
+  KernelRegistration(const KernelDef& d, StringPiece c,
+                     kernel_factory::OpKernelRegistrar::Factory f)
+      : def(d), kernel_class_name(c.ToString()), factory(f) {}
+  const KernelDef def;
+  const string kernel_class_name;
+  const kernel_factory::OpKernelRegistrar::Factory factory;
+};
+
+// This maps from 'op_type' + DeviceType to the set of KernelDefs and
+// factory functions for instantiating the OpKernel that matches the
+// KernelDef.
+typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;
+
+void* GlobalKernelRegistry() {
+  static KernelRegistry* global_kernel_registry = new KernelRegistry;
+  return global_kernel_registry;
+}
+
+static KernelRegistry* GlobalKernelRegistryTyped() {
+  return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
+}
+
+static string Key(StringPiece op_type, const DeviceType& device_type,
+                  StringPiece label) {
+  return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
+                         label);
+}
+
+namespace kernel_factory {
+
+void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
+                                     StringPiece kernel_class_name,
+                                     Factory factory) {
+  // See comments in register_kernel::Name in header for info on _no_register.
+  if (kernel_def->op() != "_no_register") {
+    const string key =
+        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
+            kernel_def->label());
+    GlobalKernelRegistryTyped()->insert(std::make_pair(
+        key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
+  }
+  delete kernel_def;
+}
+
+}  // namespace kernel_factory
+
+namespace {
+
+// Helper for AttrsMatch().
+bool InTypeList(DataType dt, const AttrValue& type_list) {
+  for (int in_list : type_list.list().type()) {
+    if (dt == in_list) return true;
+  }
+  return false;
+}
+
+// Returns whether the attrs satisfy the constraints in the kernel_def. Returns
+// an error if attrs in kernel_def are not found, or have a mismatching type.
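+// For example, a kernel registered with .TypeConstraint<float>("T") carries a
+// constraint roughly of the form
+//   constraint { name: "T" allowed_values { list { type: DT_FLOAT } } }
+// and only matches NodeDefs whose "T" attr (or every entry of a list(type)
+// attr) is in the allowed list.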
+Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) { + *match = false; + for (const auto& constraint : kernel_def.constraint()) { + if (constraint.allowed_values().list().type_size() == 0) { + return errors::Unimplemented( + "KernelDef '", ProtoShortDebugString(kernel_def), + " has constraint on attr '", constraint.name(), + "' with unsupported type: ", + SummarizeAttrValue(constraint.allowed_values())); + } + + const AttrValue* found = attrs.Find(constraint.name()); + if (found) { + if (found->type() != DT_INVALID) { + if (!InTypeList(found->type(), constraint.allowed_values())) { + return Status::OK(); + } + } else { + if (!AttrValueHasType(*found, "list(type)").ok()) { + return errors::InvalidArgument( + "KernelDef '", ProtoShortDebugString(kernel_def), + "' has constraint on attr '", constraint.name(), + "' that has value '", SummarizeAttrValue(*found), + "' that does not have type 'type' or 'list(type)' in NodeDef " + "'", + attrs.SummarizeNode(), "'"); + } + + for (int t : found->list().type()) { + if (!InTypeList(static_cast(t), + constraint.allowed_values())) { + return Status::OK(); + } + } + } + } else { + return errors::InvalidArgument( + "OpKernel '", kernel_def.op(), "' has constraint on attr '", + constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(), + "', KernelDef: '", ProtoShortDebugString(kernel_def), "'"); + } + } + *match = true; + return Status::OK(); +} + +static const StringPiece kKernelAttr("_kernel"); + +// TODO(irving): Replace with const Node& version below. +Status FindKernelRegistration(const DeviceType& device_type, + const NodeDef& node_def, + const KernelRegistration** reg, + bool* was_attr_mismatch) { + *reg = nullptr; + *was_attr_mismatch = false; + // Label defaults to empty if not found in NodeDef. + const string& label = GetNodeAttrString(node_def, kKernelAttr); + + const string key = Key(node_def.op(), device_type, label); + auto regs = GlobalKernelRegistryTyped()->equal_range(key); + for (auto iter = regs.first; iter != regs.second; ++iter) { + // If there is a kernel registered for the op and device_type, + // check that the attrs match. 
+ bool match; + TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match)); + if (match) { + if (*reg != nullptr) { + return errors::InvalidArgument( + "Multiple OpKernel registrations match NodeDef '", + SummarizeNodeDef(node_def), "': '", + ProtoShortDebugString((*reg)->def), "' and '", + ProtoShortDebugString(iter->second.def), "'"); + } + *reg = &iter->second; + } else { + *was_attr_mismatch = true; + } + } + return Status::OK(); +} + +Status FindKernelRegistration(const DeviceType& device_type, const Node& node, + const KernelRegistration** reg, + bool* was_attr_mismatch) { + return FindKernelRegistration(device_type, node.def(), reg, + was_attr_mismatch); +} + +} // namespace + +// TODO(irving): Change const NodeDef& to const Node& +Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, + const KernelDef** def, string* kernel_class_name) { + const KernelRegistration* reg = nullptr; + bool was_attr_mismatch; + TF_RETURN_IF_ERROR( + FindKernelRegistration(device_type, node_def, ®, &was_attr_mismatch)); + if (reg == nullptr) { + Status s = errors::NotFound( + "No registered '", node_def.op(), "' OpKernel for ", + DeviceTypeString(device_type), " devices compatible with node ", + SummarizeNodeDef(node_def)); + if (was_attr_mismatch) { + errors::AppendToMessage( + &s, " (OpKernel was found, but attributes didn't match)"); + } + errors::AppendToMessage( + &s, ". Registered:", KernelsRegisteredForOp(node_def.op())); + return s; + } + if (def != nullptr) *def = ®->def; + if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name; + return Status::OK(); +} + +Status SupportedDeviceTypesForNode( + const std::vector& prioritized_types, const NodeDef& def, + DeviceTypeVector* device_types) { + // TODO(zhifengc): Changes the callers (SimplePlacer and + // DynamicPlacer) to consider the possibility that 'def' is call to + // a user-defined function and only calls this + // SupportedDeviceTypesForNode for primitive ops. + const OpRegistrationData* op_reg_data; + const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data); + if (s.ok()) { + for (const DeviceType& device_type : prioritized_types) { + const KernelRegistration* reg = nullptr; + bool was_attr_mismatch; + TF_RETURN_IF_ERROR( + FindKernelRegistration(device_type, def, ®, &was_attr_mismatch)); + if (reg != nullptr) device_types->push_back(device_type); + } + } else { + // Assumes that all device types support this node. 
+ for (const DeviceType& device_type : prioritized_types) { + device_types->push_back(device_type); + } + } + return Status::OK(); +} + +void LogAllRegisteredKernels() { + for (const auto& key_registration : *GlobalKernelRegistryTyped()) { + const KernelDef& kernel_def(key_registration.second.def); + LOG(INFO) << "OpKernel ('" << ProtoShortDebugString(kernel_def) << "')"; + } +} + +string KernelsRegisteredForOp(StringPiece op_name) { + string ret; + for (const auto& key_registration : *GlobalKernelRegistryTyped()) { + const KernelDef& kernel_def(key_registration.second.def); + if (kernel_def.op() == op_name) { + strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'"); + if (!kernel_def.label().empty()) { + strings::StrAppend(&ret, "; label='", kernel_def.label(), "'"); + } + for (int i = 0; i < kernel_def.constraint_size(); ++i) { + strings::StrAppend( + &ret, "; ", kernel_def.constraint(i).name(), " in ", + SummarizeAttrValue(kernel_def.constraint(i).allowed_values())); + } + strings::StrAppend(&ret, "\n"); + } + } + if (ret.empty()) return " \n"; + return ret; +} + +std::unique_ptr CreateOpKernel( + DeviceType device_type, DeviceBase* device, Allocator* allocator, + const NodeDef& node_def, int graph_def_version, Status* status) { + OpKernel* kernel = nullptr; + *status = CreateOpKernel(std::move(device_type), device, allocator, nullptr, + node_def, graph_def_version, &kernel); + return std::unique_ptr(kernel); +} + +Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + const NodeDef& node_def, int graph_def_version, + OpKernel** kernel) { + VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def); + + // Look up the Op registered for this op name. + const OpDef* op_def = nullptr; + Status s = OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def); + if (!s.ok()) return s; + + // Validate node_def against OpDef. + s = ValidateNodeDef(node_def, *op_def); + if (!s.ok()) return s; + + // Look up kernel registration. + const KernelRegistration* registration; + bool was_attr_mismatch; + s = FindKernelRegistration(device_type, node_def, ®istration, + &was_attr_mismatch); + if (!s.ok()) { + errors::AppendToMessage(&s, " when instantiating ", node_def.op()); + return s; + } + if (registration == nullptr) { + s.Update(errors::NotFound("No registered '", node_def.op(), + "' OpKernel for ", DeviceTypeString(device_type), + " devices compatible with node ", + SummarizeNodeDef(node_def))); + if (was_attr_mismatch) { + errors::AppendToMessage( + &s, " (OpKernel was found, but attributes didn't match)"); + } + errors::AppendToMessage( + &s, ". Registered:", KernelsRegisteredForOp(node_def.op())); + return s; + } + + // Get signature from the OpDef & NodeDef + DataTypeVector inputs; + DataTypeVector outputs; + s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs)); + if (!s.ok()) { + errors::AppendToMessage(&s, " for node: ", SummarizeNodeDef(node_def)); + return s; + } + + // We are creating a kernel for an op registered in + // OpRegistry::Global(), we consult the kernel registry to decide + // the kernel's input and output memory types. + MemoryTypeVector input_memory_types; + MemoryTypeVector output_memory_types; + TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type, + node_def, &input_memory_types, + &output_memory_types)); + + // Everything needed for OpKernel construction. 
+ OpKernelConstruction context( + device_type, device, allocator, &node_def, op_def, flib, inputs, + input_memory_types, outputs, output_memory_types, graph_def_version, &s); + *kernel = (*registration->factory)(&context); + if (!s.ok()) { + delete *kernel; + *kernel = nullptr; + } + return s; +} + +namespace { + +bool FindArgInOp(StringPiece arg_name, + const protobuf::RepeatedPtrField& args) { + for (const auto& arg : args) { + if (arg_name == arg.name()) { + return true; + } + } + return false; +} + +} // namespace + +Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) { + for (const auto& key_registration : *GlobalKernelRegistryTyped()) { + const KernelDef& kernel_def(key_registration.second.def); + const OpRegistrationData* op_reg_data; + const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data); + if (!status.ok()) { + // TODO(josh11b): Make this a hard error. + LOG(ERROR) << "OpKernel ('" << ProtoShortDebugString(kernel_def) + << "') for unknown op: " << kernel_def.op(); + continue; + } + const OpDef& op_def = op_reg_data->op_def; + for (const auto& host_memory_arg : kernel_def.host_memory_arg()) { + if (!FindArgInOp(host_memory_arg, op_def.input_arg()) && + !FindArgInOp(host_memory_arg, op_def.output_arg())) { + return errors::InvalidArgument( + "HostMemory arg '", host_memory_arg, + "' not found in OpDef: ", SummarizeOpDef(op_def)); + } + } + } + return Status::OK(); +} + +template <> +const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const { + return eigen_cpu_device(); +} + +template <> +const Eigen::GpuDevice& OpKernelContext::eigen_device() const { + return eigen_gpu_device(); +} + +#ifdef TENSORFLOW_USE_SYCL +template <> +const Eigen::SyclDevice& OpKernelContext::eigen_device() const { + return eigen_sycl_device(); +} +#endif + +void OpKernelConstruction::CtxFailure(Status s) { + VLOG(1) << s; + SetStatus(s); +} + +void OpKernelConstruction::CtxFailureWithWarning(Status s) { + LOG(WARNING) << s; + SetStatus(s); +} + +void OpKernelContext::CtxFailure(Status s) { + VLOG(1) << s; + SetStatus(s); +} + +void OpKernelContext::CtxFailureWithWarning(Status s) { + LOG(WARNING) << s; + SetStatus(s); +} + +} // namespace tensorflow diff --git a/op_kernel.h b/op_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3a9a6121c05b02e0f7724dc77adbddca22f0ff19 --- /dev/null +++ b/op_kernel.h @@ -0,0 +1,1532 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ +#define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ + +#include + +#include +#include +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/selective_registration.h" +#include "tensorflow/core/framework/session_state.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove +#include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/unique_tensor_references.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/manual_constructor.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace Eigen { +struct ThreadPoolDevice; +struct GpuDevice; +struct SyclDevice; +} // end namespace Eigen + +namespace tensorflow { + +namespace checkpoint { +class TensorSliceReaderCacheWrapper; +} // namespace checkpoint + +class AsyncOpKernel; +class CallFrameInterface; +class FunctionLibraryRuntime; +class OpKernelConstruction; // declared below +class OpKernelContext; // declared below +class OpRegistryInterface; +class ResourceMgr; +class ScopedStepContainer; +class StepStatsCollector; + +class OpKernel { + public: + // OpKernel won't be instantiated by the scheduler, so you may perform + // expensive initialization in the descendant's constructor. + explicit OpKernel(OpKernelConstruction* context); + virtual ~OpKernel(); + + // An OpKernel's computation can be either synchronous or + // asynchronous. All OpKernel Compute() methods must be thread-safe as they + // may be called concurrently (e.g. by multiple executions of the same graph + // concurrently). + // + // Most OpKernels should compute synchronously. They should + // subclass OpKernel and override the Compute() method and have it + // return after completing the supplied work. + // + // A few special kernels might need to be asynchronous to bound the + // number of threads (e.g., network receive operations). These + // kernels must subclass AsyncOpKernel and override + // AsyncOpKernel::ComputeAsync(). + // + // In both cases, implementations of Compute() and ComputeAsync() + // get inputs and write outputs through the given OpKernelContext + // and returns a status via context->SetStatus(). They must be + // thread-safe. + + // Synchronous compute. + // + // "context" is guaranteed to be alive until Compute() returns. + virtual void Compute(OpKernelContext* context) = 0; + + // Returns nullptr iff this op kernel is synchronous. 
+ virtual AsyncOpKernel* AsAsync() { return nullptr; } + + // Returns true iff this op kernel is considered "expensive". The + // runtime may use this flag to optimize graph execution for example + // to "inline" inexpensive kernels. + virtual bool IsExpensive() { return expensive_; } + + // Accessors. + const NodeDef& def() const { return *def_; } + const string& name() const; // Same as def().name() + const string& type_string() const; // Same as def().op() + const string& requested_device() const; // Same as def().device() + bool is_internal() const { return is_internal_; } + + int num_inputs() const { return input_types_.size(); } + DataType input_type(int i) const { return input_types_[i]; } + const DataTypeVector& input_types() const { return input_types_; } + const MemoryTypeVector& input_memory_types() const { + return input_memory_types_; + } + const string& requested_input(int i) const; // Same as def().input(i) + + int num_outputs() const { return output_types_.size(); } + DataType output_type(int o) const { return output_types_[o]; } + const DataTypeVector& output_types() const { return output_types_; } + const MemoryTypeVector& output_memory_types() const { + return output_memory_types_; + } + + Status InputRange(StringPiece input_name, int* start, int* stop) const; + Status OutputRange(StringPiece output_name, int* start, int* stop) const; + + // We allow legacy scalars within Google up until GraphDef version 6. + // TODO(irving): Remove when we can drop support for GraphDef version 5. + bool allow_legacy_scalars() const { +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) + return graph_def_version_ < 6; +#else + return false; +#endif + } + + // Allow either scalars or (if allowing legacy scalars) shape (1,). + bool IsLegacyScalar(const TensorShape& shape) const { + return shape.dims() == 0 || (allow_legacy_scalars() && shape.dims() == 1 && + shape.dim_size(0) == 1); + } + + // Allow rank 1 or (if allowing legacy scalars) rank 0. + bool IsLegacyVector(const TensorShape& shape) const { + return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0); + } + + // Turn a shape Tensor into a TensorShape + // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars + Status MakeShape(const Tensor& shape, TensorShape* out) const; + + private: + const std::unique_ptr def_; + const DataTypeVector input_types_; + const MemoryTypeVector input_memory_types_; + const DataTypeVector output_types_; + const MemoryTypeVector output_memory_types_; + const int graph_def_version_; + const bool is_internal_; // True if this is an internal operation + NameRangeMap input_name_map_; + NameRangeMap output_name_map_; + bool expensive_; + + TF_DISALLOW_COPY_AND_ASSIGN(OpKernel); +}; + +class AsyncOpKernel : public OpKernel { + public: + using OpKernel::OpKernel; // Lift OpKernel constructors. + + // Asynchronous compute. + // + // Implementations of ComputeAsync() must run "done" to signal the + // completion of the computation. "context" is guaranteed to be + // alive until the "done" callback starts. + typedef std::function DoneCallback; + virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0; + + AsyncOpKernel* AsAsync() final { return this; } + + void Compute(OpKernelContext* context) final; + + bool IsExpensive() override { return true; } +}; + +// Wraps a tensor that is held by an Op across calls to Compute(). 
For +// memory safety when using asynchronous devices like GPUs, the system +// must be notified when a Tensor is used inside an Op execution. The +// wrapper ensures that all uses of the Tensor are tracked, because in +// order to retrieve the Tensor the caller must use AccessTensor which +// notifies the context. +class PersistentTensor { + public: + PersistentTensor() {} + explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {} + + // Caller does not own the returned Tensor*. + Tensor* AccessTensor(OpKernelConstruction* context); + // Caller does not own the returned Tensor*. + Tensor* AccessTensor(OpKernelContext* context); + + // The check for initialization does not need to access the + // underlying tensor buffer. + bool IsInitialized() const { return tensor_.IsInitialized(); } + + int64 NumElements() const { return tensor_.NumElements(); } + + int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); } + + private: + Tensor tensor_; +}; + +class OpKernelConstruction { + public: + OpKernelConstruction(DeviceType device_type, DeviceBase* device, + Allocator* allocator, const NodeDef* node_def, + const OpDef* op_def, FunctionLibraryRuntime* flib, + const DataTypeSlice& input_types, + const MemoryTypeSlice& input_memory_types, + const DataTypeSlice& output_types, + const MemoryTypeSlice& output_memory_types, + int graph_def_version, Status* status); + + Env* env() const { return device_->env(); } + + // Allocation of tensors during kernel construction: + // + // It is legal to temporarily allocate scratch tensor storage during + // Op kernel construction. Scratch tensors should be allocated using + // allocate_temp below. Some kernels need to keep tensors in between + // invocations. If such a Tensor is allocated during kernel + // construction this must be done using allocate_persistent, and the + // Op may only store the returned PersistentTensor object. When the + // Tensor is needed in a subsequent invocation, it can be retrieved + // from the PersistentTensor using the AccessTensor method. This + // ensures that the system is made aware of any use of the tensor's + // allocated memory, which is needed for correctness on asynchronous + // devices such as GPUs. + + // Allocates a temporary Tensor of the specified type and shape. The + // Tensor must not be used after kernel construction is + // complete. See comment above. + Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp); + + // Allocates a Tensor of the specified type and shape which the Op + // plans to maintain as persistent state. out_persistent holds the + // PersistentTensor which is the object the caller should store. For + // convenience, if out_tensor is non-null then it will be filled in + // with a Tensor* pointing to the newly-allocated tensor which the + // caller can use instead of calling + // out_persistent->AccessTensor. The caller does not own out_tensor + // and should not keep a copy of it. See comment above. + Status allocate_persistent(DataType type, const TensorShape& shape, + PersistentTensor* out_persistent, + Tensor** out_tensor); + + // User-supplied configuration of this operation. + const NodeDef& def() const { return *def_; } + + // For inspecting the inputs to this operation. 
+ int num_inputs() const { return input_types_.size(); } + DataType input_type(int i) const { return input_types_[i]; } + const DataTypeSlice& input_types() const { return input_types_; } + const MemoryTypeSlice& input_memory_types() const { + return input_memory_types_; + } + + // For inspecting the outputs expected from this operation. + int num_outputs() const { return output_types_.size(); } + DataType output_type(int i) const { return output_types_[i]; } + const DataTypeSlice& output_types() const { return output_types_; } + const MemoryTypeSlice& output_memory_types() const { + return output_memory_types_; + } + + // If expected_inputs == inputs() and expected_outputs == output_types(), + // returns OK, else returns INVALID_ARGUMENT with an error message. + // Recommended for Ops with dynamic signatures. + Status MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs); + + // For recording configuration errors during construction. + void SetStatus(const Status& status); + const Status& status() const { return *status_; } + + // Look up the attr with name attr_name and set *value to its value. If no + // attr with attr_name is found in def(), or the attr does not have + // a matching type, a non-ok status will be returned. + template + Status GetAttr(StringPiece attr_name, T* value) const; + + // Return true if the attr_name is defined in def(). + bool HasAttr(StringPiece attr_name) const; + + // Return the device type. + const DeviceType& device_type() const { return device_type_; } + + // If not nullptr, the kernel can instantiate functions defined in + // the library. E.g., + // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...). + FunctionLibraryRuntime* function_library() const { return flib_; } + + // The GraphDef version whose behavior we should follow. + int graph_def_version() const { return graph_def_version_; } + + // Helper routines for the OP_REQUIRES macros + void CtxFailure(Status s); + void CtxFailureWithWarning(Status s); + + // Unrecommended functions: these are functions that have some + // current uses but are not recommended for use, and may go away at + // some future major version release. + + // May be used, e.g., to get GPU handles, etc. + // + // Currently only used to call MakeTensorFromProto() for + // implementing ConstantOp for every device. See comments + // on Device::MakeTensorFromProto for longer-term replacement + // ideas. + DeviceBase* device() const { return device_; } + + private: + const DeviceType device_type_; + DeviceBase* const device_; + Allocator* allocator_; + const NodeDef* def_; + const OpDef* op_def_; + FunctionLibraryRuntime* flib_; + DataTypeSlice input_types_; + MemoryTypeSlice input_memory_types_; + DataTypeSlice output_types_; + MemoryTypeSlice output_memory_types_; + const int graph_def_version_; + Status* status_; + + // Allow op_def_ across from OpKernel, but not from subclasses. + // TODO(irving): Remove protos from this header entirely. + friend class OpKernel; + + TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction); +}; + +// TODO(mrry): Consider converting to a random_access_iterator, and upgrading +// tensorflow::gtl::iterator_range to make the below container classes +// unnecessary. 
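+// A rough usage sketch (the argument names "values" and "outputs" are
+// hypothetical names taken from an op's OpDef, not part of this API): a
+// kernel whose op has list-valued arguments typically walks the container
+// classes below from inside Compute(), e.g.
+//
+//   OpInputList values;
+//   OP_REQUIRES_OK(context, context->input_list("values", &values));
+//   OpOutputList outputs;
+//   OP_REQUIRES_OK(context, context->output_list("outputs", &outputs));
+//   for (int i = 0; i < values.size(); ++i) {
+//     Tensor* out = nullptr;
+//     OP_REQUIRES_OK(context, outputs.allocate(i, values[i].shape(), &out));
+//   }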
+template <typename ListType, typename ElementType>
+class OpArgIterator {
+ public:
+  typedef OpArgIterator<ListType, ElementType> ME;
+  OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
+  bool operator==(const ME& rhs) {
+    DCHECK(list_ == rhs.list_);
+    return i_ == rhs.i_;
+  }
+  bool operator!=(const ME& rhs) {
+    DCHECK(list_ == rhs.list_);
+    return i_ != rhs.i_;
+  }
+  void operator++() { ++i_; }
+  ElementType& operator*() { return (*list_)[i_]; }
+
+ private:
+  const ListType* const list_;
+  int i_;
+};
+
+// Utility class for representing a list of immutable input tensors
+// that are passed to the op as a single named argument.
+class OpInputList {
+ public:
+  typedef OpArgIterator<OpInputList, const Tensor> Iterator;
+  OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
+  OpInputList(OpKernelContext* ctx, int start, int stop)
+      : ctx_(ctx), start_(start), stop_(stop) {}
+  OpInputList& operator=(const OpInputList& other) = default;
+  const Tensor& operator[](int i) const;
+  int size() const { return stop_ - start_; }
+  Iterator begin() const { return Iterator(this, 0); }
+  Iterator end() const { return Iterator(this, size()); }
+
+ private:
+  OpKernelContext* ctx_;  // not owned
+  int start_;
+  int stop_;
+};
+
+// Utility class for representing a list of mutable ("ref") input tensors
+// that are passed to the op as a single named argument.
+class OpMutableInputList {
+ public:
+  typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
+  OpMutableInputList(OpKernelContext* ctx, int start, int stop)
+      : ctx_(ctx), start_(start), stop_(stop) {}
+  OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
+  OpMutableInputList& operator=(const OpMutableInputList& other) = default;
+  Tensor at(int i, bool lock_held);
+  mutex* ref_mutex(int i);
+  int size() const { return stop_ - start_; }
+  Iterator begin() const { return Iterator(this, 0); }
+  Iterator end() const { return Iterator(this, size()); }
+
+ private:
+  OpKernelContext* ctx_;  // not owned
+  int start_;
+  int stop_;
+};
+
+// Utility class for representing a list of output tensors that are
+// grouped as a single named output.
+class OpOutputList {
+ public:
+  typedef OpArgIterator<OpOutputList, Tensor*> Iterator;
+  OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
+  OpOutputList(OpKernelContext* ctx, int start, int stop)
+      : ctx_(ctx), start_(start), stop_(stop) {}
+  OpOutputList& operator=(const OpOutputList& other) = default;
+  Tensor* operator[](int i);
+  bool required(int i) const;
+  DataType expected_output_dtype(int i) const;
+  Status allocate(int i, const TensorShape& shape, Tensor** output);
+  void set(int i, const Tensor& tensor);
+  void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
+  int size() const { return stop_ - start_; }
+  Iterator begin() const { return Iterator(this, 0); }
+  Iterator end() const { return Iterator(this, size()); }
+
+ private:
+  OpKernelContext* ctx_;  // not owned
+  int start_;
+  int stop_;
+};
+
+// Holds a tensor or tensor reference. For tensor references, we need
+// a mutex to prevent concurrent access to the tensor.
+struct TensorValue { + TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {} + TensorValue(Tensor* t) // NOLINT(runtime/explicit) + : mutex_if_ref(nullptr), tensor(t) {} + TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {} + Tensor* operator->() const { return tensor; } + bool is_ref() const { return mutex_if_ref != nullptr; } + + mutex* mutex_if_ref; // nullptr if not a ref, != nullptr if a ref + Tensor* tensor; +}; + +class OpKernelContext { + public: + // The first element of a WrappedAllocator is a "base" Allocator and + // the second element is that Allocator wrapped by a + // TrackingAllocator + typedef std::pair WrappedAllocator; + + // TODO(zhifengc): Do some cleanup of Params. + // The Params struct is passed in to initialize an OpKernelContext, + // and must outlive the OpKernelContext. + struct Params { + ~Params() { delete eigen_gpu_device; } + + // The step being executed. + int64 step_id = 0; + + // The op kernel being computed. + OpKernel* op_kernel = nullptr; + + // The device on which the kernel is running. + DeviceBase* device = nullptr; + + // The Eigen GPU device wrapper, which may include a per-op + // wrapped allocator. The concrete type of this object depends on + // the type of this->device, so eigen_gpu_device can't be an + // inline member and must be heap allocated. However, we don't + // want to allocate a new eigen_gpu_device for every Op that is + // executed. Instead this member is allocated on first use using + // ensure_eigen_gpu_device, and then if the Params structure is + // re-used for subsequent Ops, the eigen_gpu_device is + // ReInitialized in the OpKernelContext constructor. Unlike the + // other pointers in Params, this one is owned by Params. + PerOpGpuDevice* eigen_gpu_device = nullptr; + + inline void ensure_eigen_gpu_device() { + DCHECK(device); + if (nullptr == eigen_gpu_device) { + // Surprisingly, MakeGpuDevice will return nullptr if the + // device is not a GPU device. This is ok, since those devices + // will never use eigen_gpu_device. It seems better to have + // ensure_eigen_gpu_device fall through and regenerate the + // nullptr every time an OpKernelContext is instantiated, than + // to do an unnecessary allocation of a dummy eigen GPU + // device for CPU device Ops. + eigen_gpu_device = device->MakeGpuDevice(); + } + } + + bool track_allocations = false; + bool log_memory = false; + bool record_tensor_accesses = false; + + // Array indexed by output number for this node + const AllocatorAttributes* output_attr_array = nullptr; + + // Shared resources accessible by this op kernel invocation. + ResourceMgr* resource_manager = nullptr; + + // Per-step resources accessible by this op kernel invocation should be + // stored in this container.. + ScopedStepContainer* step_container = nullptr; + + // Mechanism used by this op kernel invocation to communicate with + // computations running on other devices. + Rendezvous* rendezvous = nullptr; + + // The session state for this op. + SessionState* session_state = nullptr; + + // The tensor store for this op. + TensorStore* tensor_store = nullptr; + + // Mechanism used by this op kernel invocation to register a callback + // for its cancellation. + CancellationManager* cancellation_manager = nullptr; + + // Inputs to this op kernel. + const gtl::InlinedVector* inputs = nullptr; + bool is_input_dead = false; + + const gtl::InlinedVector* input_alloc_attrs = + nullptr; + + // Device contexts. 
+ const gtl::InlinedVector* input_device_contexts = + nullptr; + DeviceContext* op_device_context = nullptr; + + // Control-flow op supports. + FrameAndIter frame_iter; + + // Function call supports. + CallFrameInterface* call_frame = nullptr; + FunctionLibraryRuntime* function_library = nullptr; + std::function)>* runner = nullptr; + StepStatsCollector* stats_collector = nullptr; + + // TensorSliceReaderCache support. + checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr; + }; + + // params must outlive the OpKernelContext. + explicit OpKernelContext(Params* params); + OpKernelContext(Params* params, int noutputs); + ~OpKernelContext(); + + Env* env() const { return params_->device->env(); } + + int64 step_id() const { return params_->step_id; } + + const OpKernel& op_kernel() const { return *params_->op_kernel; } + + // Input/output signature. + + int num_inputs() const { return params_->inputs->size(); } + DataType input_dtype(int index) const; + Status input_dtype(StringPiece name, DataType* dtype) const; + MemoryType input_memory_type(int index) const; + + int num_outputs() const { return outputs_.size(); } + DataType expected_output_dtype(int index) const; + MemoryType output_memory_type(int index) const; + + // Input + + // Returns an immutable input tensor. May only be used for non-Ref + // inputs. For Ref inputs use mutable_input below. + // REQUIRES: !IsRefType(input_dtype(index)) + // TODO(mrry): Convert this to return Status. + const Tensor& input(int index); + + // Returns the named immutable input tensor in "tensor", as defined + // in the OpDef. May only be used for non-Ref inputs. For Ref inputs + // use mutable_input below. + // REQUIRES: !IsRefType(input_dtype(index)) + // REQUIRES: the named input must not be a list. + Status input(StringPiece name, const Tensor** tensor); + + // Returns the named list-valued immutable input in "list", as + // defined in the OpDef. If the named output is not list-valued, + // returns a one-element list. May only be used for non-Ref + // inputs. For Ref inputs use mutable_input below. + // REQUIRES: !IsRefType(input_dtype(index)) + Status input_list(StringPiece name, OpInputList* list); + + // For mutable inputs, use the following together to make sure there + // is no concurrent access to mutable_input(), e.g.: + // { + // Tensor& t = context->mutable_input(index); + // mutex_lock lock(*context->input_ref_mutex(index)); + // // modify the values in t + // } + // REQUIRES: IsRefType(input_dtype(index)) + Status input_ref_mutex(StringPiece name, mutex** out_mutex); + + // Returns a mutable input tensor. Must be used to access Ref + // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may + // modify the values stored in the Tensor buffer, and modifications + // will be visible to other Ops reading the same ref tensor. If + // !lock_held the input mutex will be acquired before returning the + // Tensor. + // TODO(mrry): Convert this to return Status. + Tensor mutable_input(int index, bool lock_held); + + // Returns the named mutable input tensor in "tensor", as defined in + // the OpDef. Must be used to access Ref inputs. The values stored + // in the Tensor buffer may be modified, and modifications will be + // visible to other Ops reading the same ref tensor. If !lock_held + // the input mutex will be acquired before returning the Tensor. + // REQUIRES: the named input must not be a list. + // REQUIRES: the named input must be a ref tensor. 
+ Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held); + + // Returns the named list-valued mutable input in "list", as defined + // in the OpDef. If the named input is not list-valued, returns a + // one-element list. Must be used to access Ref inputs. The values + // stored in the Tensor buffer may be modified, and modifications + // will be visible to other Ops reading the same ref tensor. + // REQUIRES: the named input must be a ref tensor. + Status mutable_input_list(StringPiece name, OpMutableInputList* list); + + // Replace the corresponding Ref Input to use the storage buffer + // used by tensor. If !lock_held the input mutex will be acquired + // before returning the Tensor. + // REQUIRES: IsRefType(input_dtype(index)). + void replace_ref_input(int index, const Tensor& tensor, bool lock_held); + + // Replace the corresponding named Ref Input to use the storage + // buffer used by tensor. If !lock_held the input mutex will be + // acquired before returning the Tensor. + // REQUIRES: IsRefType(input_dtype(index)). + Status replace_ref_input(StringPiece name, const Tensor& tensor, + bool lock_held); + + // Deletes the Tensor object used as the Ref Input at + // input_index. This is not usually necessary and should be used + // with caution. If !lock_held the input mutex will be acquired + // before returning the Tensor. + // REQUIRES: IsRefType(input_dtype(input_index)). + void delete_ref_input(int input_index, bool lock_held); + + // Return true if there is input at the given index. An operator has no + // input at index if its tensor is null. This is primarily used by the + // merge operator. + // TODO(mrry): Convert this to return Status. + bool has_input(int index) const; + + // Returns true if all inputs are the same shape, otherwise sets the + // status to a non-OK value and returns false. + // Usage: if (!context->ValidateInputsAreSameShape(this)) return; + bool ValidateInputsAreSameShape(OpKernel* op); + + // Input to output forwarding. + + // Set the output Ref Tensor at output_index to be an alias of the + // input Ref Tensor at input_index. + // REQUIRES: IsRefType(input_dtype(input_index)). + // REQUIRES: IsRefType(output_dtype(output_index)). + void forward_ref_input_to_ref_output(int input_index, int output_index); + + // Returns true when an alias to input[input_index], reshaped to output_shape, + // which is safe to use for in-place computation was written to *output. + // Returns false if input[input_index] has a refcount greater than one, or if + // its type does not match the expected output type of output[output_index], + // or the number of elements in input[input_index] does not equal the number + // of elements in output_shape. + bool forward_input_to_output_with_shape(int input_index, int output_index, + const TensorShape& output_shape, + Tensor** output) TF_MUST_USE_RESULT; + Status forward_input_to_output_with_shape(StringPiece input_name, + StringPiece output_name, + const TensorShape& output_shape, + Tensor** output) TF_MUST_USE_RESULT; + + // Returns a pointer to a Tensor aliasing the underlying buffer backing + // input[input_index] iff + // * input[input_index] is not a ref, + // * the data type, shape, memory type, and allocator attributes of + // input[input_index] are compatible with those given in dtype, shape, + // memory_type, and attr, + // * refcount on the underlying buffer is one. + // Otherwise returns nullptr. 
+ // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic, + // forwarding is only safe if there are no reads via __ldg() after writes + // to the same address. + std::unique_ptr forward_input( + int input_index, DataType dtype, const TensorShape& shape, + MemoryType memory_type, + const AllocatorAttributes& attr) TF_MUST_USE_RESULT; + + // Tries to forward one of the inputs given in input_indices to + // output[output_index]. If none of the given inputs can be forwarded, calls + // allocate_output() to allocate a new output buffer. + Status forward_input_or_allocate_output( + gtl::ArraySlice candidate_input_indices, int output_index, + const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT; + Status forward_input_or_allocate_output( + gtl::ArraySlice candidate_input_names, + StringPiece output_name, const TensorShape& output_shape, + Tensor** output) TF_MUST_USE_RESULT; + + // Tries to reuse one of the inputs given in input_indices as a temporary. + // If none of the given inputs can be forwarded, calls + // allocate_temp() to allocate a new temporary buffer. + Status forward_input_or_allocate_temp( + gtl::ArraySlice candidate_input_indices, DataType type, + const TensorShape& shape, const AllocatorAttributes& allocator_attr, + Tensor* out_temp) TF_MUST_USE_RESULT; + + Status forward_input_or_allocate_temp( + gtl::ArraySlice candidate_input_indices, DataType type, + const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT { + return forward_input_or_allocate_temp(candidate_input_indices, type, shape, + AllocatorAttributes(), out_temp); + } + + // Output + + // Returns the named list-valued output in "list", as defined in the OpDef. + // If the named output is not list-valued, returns a one-element list. + Status output_list(StringPiece name, OpOutputList* list); + + // If output_required(index) returns true, the OpKernel's Compute() method + // should call allocate_output(index, ...), set_output(index, ...), + // set_output_ref(index, ...), or set the status to a non-ok value. + // If it returns false, it may output, but is not required to do so. + // TODO(mrry): Convert this to return Status, and implement a string + // name version. + bool output_required(int index) const { + return true; // TODO(josh11b): implement + } + + // Allocation of tensors during kernel execution inside the Compute + // method: + // + // There are three methods to allocate Tensors when an Op kernel + // executes. + // + // 1) allocate_persistent. This is only needed for Tensors that will + // be stored by the Op between invocations, and it *must* be used + // for those Tensors. The call returns a PersistentTensor, and that + // is the only object the Op is allowed to hold on to between + // invocations. When the Tensor is needed in a subsequent + // invocation, it can be retrieved from the PersistentTensor using + // the AccessTensor method. This ensures that the system is made + // aware of any use of the tensor's allocated memory, which is + // needed for correctness on asynchronous devices such as GPUs. + // + // 2) allocate_output. This should be used to allocate any tensor + // that is going to be used as an output from the Op at the end of + // the current execution. The caller indicates which output the + // Tensor will be assigned to, and the call returns the + // newly-allocated Tensor. The Tensor can subsequently be assigned + // to during kernel execution, and will be used as the designated + // output when the kernel execution completes. + // + // 3) allocate_temp. 
This should be used to allocate any scratch + // storage that is needed while the kernel is executing, and will + // not be retained by the Op. + // + // In some cases a Tensor needs to be used as an output even though + // it was previously allocated elsewhere. The Tensor may have been + // passed as an input, or stored in a PersistentTensor during a + // previous kernel execution, or allocated earlier in the kernel + // execution at a time when it was not known which output it would + // be assigned to. In this case the kernel can use set_output or + // set_output_ref to indicate that the tensor should be used as the + // designated output. It is legal to use any previously-allocated + // Tensor as an argument to set_output or set_output_ref, including + // Tensors allocated via allocate_temp. There may be a performance + // penalty to using a Tensor that was not allocated using + // allocate_output. This is because allocate_output uses the + // AllocatorAttributes stored in output_attr_array for the + // designated output. In some cases, using the wrong attributes may + // cause an extra copy of the Tensor's buffer. + + // Allocates output for the specified output index with shape. + // OpKernelContext retains ownership of the returned pointer. See + // comment above. + // + // If memory allocation fails, returns an error status. + // + // REQUIRES: !IsRefType(expected_output_dtype(index)) + Status allocate_output(int index, const TensorShape& shape, + Tensor** tensor) TF_MUST_USE_RESULT; + Status allocate_output(StringPiece name, const TensorShape& shape, + Tensor** tensor) TF_MUST_USE_RESULT; + // The following methods use the supplied attributes instead of + // those in output_attr_array. The caller is responsible for + // ensuring that the attributes are "compatible" with the + // output_attr_array, e.g. the tensor is allocated on the correct + // device. See comment above. + Status allocate_output(int index, const TensorShape& shape, Tensor** tensor, + AllocatorAttributes attr) TF_MUST_USE_RESULT; + Status allocate_output(StringPiece name, const TensorShape& shape, + Tensor** tensor, + AllocatorAttributes attr) TF_MUST_USE_RESULT; + + // Allocates a temporary Tensor of the specified type and + // shape. Devices such as GPUs that enqueue Ops for lazy execution + // may retain references to the temporary tensors after the Op's + // Compute method has run. See comment above. + Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp, AllocatorAttributes allocator_attr, + const AllocationAttributes& allocation_attr); + Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp, AllocatorAttributes allocator_attr) { + return allocate_temp(type, shape, out_temp, allocator_attr, + AllocationAttributes()); + } + Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp) { + return allocate_temp(type, shape, out_temp, AllocatorAttributes()); + } + + // Allocates a Tensor of the specified type and shape which the Op + // plans to maintain as persistent state. out_persistent holds the + // PersistentTensor which is the object the caller should store. For + // convenience, if out_tensor is non-null then it will be filled in + // with a Tensor* pointing to the newly-allocated tensor which the + // caller can use instead of calling + // out_persistent->AccessTensor. The caller does not own out_tensor + // and should not keep a copy of it. See comment above. 
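+  // A rough sketch of the intended pattern (the members state_ and
+  // state_initialized_ and the shape state_shape are hypothetical): a
+  // stateful kernel might lazily create such a tensor and then reuse it
+  // across invocations from Compute(), e.g.
+  //
+  //   if (!state_initialized_) {
+  //     Tensor* unused = nullptr;
+  //     OP_REQUIRES_OK(context,
+  //                    context->allocate_persistent(DT_FLOAT, state_shape,
+  //                                                 &state_, &unused));
+  //     state_initialized_ = true;
+  //   }
+  //   Tensor* state = state_.AccessTensor(context);  // notifies the runtime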
+ Status allocate_persistent(DataType type, const TensorShape& shape, + PersistentTensor* out_persistent, + Tensor** out_tensor, AllocatorAttributes attr); + Status allocate_persistent(DataType type, const TensorShape& shape, + PersistentTensor* out_persistent, + Tensor** out_tensor) { + return allocate_persistent(type, shape, out_persistent, out_tensor, + AllocatorAttributes()); + } + + // Copies a tensor (allocated by the caller) to the specified output + // index. REQUIRES: !IsRefType(expected_output_dtype(index)) + // REQUIRES: 'tensor' must have the same MemoryType as + // output_memory_types[index]. See comment above. + Status set_output(StringPiece name, const Tensor& tensor); + + // To output a reference. Caller retains ownership of mu and tensor_for_ref, + // and they must outlive all uses within the step. See comment above. + // REQUIRES: IsRefType(expected_output_dtype(index)) + Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref); + + // Returns nullptr if allocate_output() or set_output() have not been called. + Status mutable_output(StringPiece name, Tensor** tensor); + + // Transfers ownership of an output tensor to the caller. + // NOTE: For non-reference outputs, the caller takes responsibility + // for deletion. For reference outputs, the caller does NOT take + // responsibility for deletion. + Status release_output(StringPiece name, TensorValue* value); + + // Records device specific state about how the input tensors were + // computed. + // + // If using the templated function, the type must be a subclass + // of DeviceContext. + // + // Get the DeviceContext used for the index input. Returns nullptr + // if no DeviceContext was provided. + template + T* input_device_context(int index); + DeviceContext* input_device_context(int index); + + // Return the DeviceContext that should be used for this Op. + // + // If using the templated function, the type must be a subclass + // of DeviceContext. + // + // Returns nullptr if the device did not provide one. + template + T* op_device_context(); + DeviceContext* op_device_context() { + DeviceContext* ret = params_->op_device_context; + if (ret == nullptr) { + auto* dev_info = device()->tensorflow_gpu_device_info(); + if (dev_info) ret = dev_info->default_context; + } + return ret; + } + + AllocatorAttributes input_alloc_attr(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_->input_alloc_attrs->size()); + return (*params_->input_alloc_attrs)[index]; + } + + AllocatorAttributes output_alloc_attr(int index) const { + return params_->output_attr_array[index]; + } + + gtl::InlinedVector wrapped_allocators() const { + mutex_lock lock(mu_); + gtl::InlinedVector retrieved = wrapped_allocators_; + return retrieved; + } + + // Communication. + // + // An op kernel communicates with outside environment through + // Rendezvous Send() and Recv(). + Rendezvous* rendezvous() const { return params_->rendezvous; } + + // An op kernel can access the session state it belongs to. + SessionState* session_state() const { return params_->session_state; } + + // An op kernel can access the tensor store of the run it belongs to. + TensorStore* tensor_store() const { return params_->tensor_store; } + + // Function call support. + // + // If this kernel invocation is within a function execution, + // call_frame() returns the call frame for the function call. + CallFrameInterface* call_frame() const { return params_->call_frame; } + + // If not nullptr, the kernel invoke functions defined in the + // library. 
E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...). + FunctionLibraryRuntime* function_library() const { + return params_->function_library; + } + + std::function)>* runner() const { + return params_->runner; + } + StepStatsCollector* stats_collector() const { + return params_->stats_collector; + } + + // Shared resources accessible to this kernel. + ResourceMgr* resource_manager() const { return params_->resource_manager; } + + checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const { + return params_->slice_reader_cache; + } + + // Execution. + // + // OpKernels can use these eigen devices to carry out their + // numerical computation. + const Eigen::ThreadPoolDevice& eigen_cpu_device() const { + return *device()->eigen_cpu_device(); + } + const Eigen::GpuDevice& eigen_gpu_device() const { + return params_->eigen_gpu_device->device(); + } +#ifdef TENSORFLOW_USE_SYCL + const Eigen::SyclDevice& eigen_sycl_device() const { + return *device()->eigen_sycl_device(); + } +#endif + template + const EigenDeviceType& eigen_device() const; + + // Error handling. + + // If expected_inputs == inputs() and expected_outputs == output_types(), + // returns OK, else returns INVALID_ARGUMENT with an error message. + // Recommended for Ops with dynamic signatures, where validation can only + // be performed at runtime. + Status MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs); + + // An OpKernel should call SetStatus() if Compute() encounters an + // error. + void SetStatus(const Status& status); + const Status& status() const { return status_; } + + // Cancellation. + // + // EXPERIMENTAL. See the implementation in tensorflow::TensorQueue for an + // example of how to use this API. + CancellationManager* cancellation_manager() const { + return params_->cancellation_manager; + } + + // Other accessors. + + // For control flow. + FrameAndIter frame_iter() const { return params_->frame_iter; } + bool is_input_dead() const { return params_->is_input_dead; } + bool* is_output_dead() { return &is_output_dead_; } + + // May be used, e.g., to get GPU handles, etc. + // TODO(tucker): Add example usage. + DeviceBase* device() const { return params_->device; } + + // Retrieve list of referenced tensors in out_vector. Once this is + // called, it is not legal to reference any more tensors. Should + // not be called from Op kernels. + void retrieve_accessed_tensors(TensorReferenceVector* out_vector); + + // Per-step container for use by white-listed internal ops. + ScopedStepContainer* step_container() const { + return params_->step_container; + } + + // Helper routines for the OP_REQUIRES macros + void CtxFailure(Status s); + void CtxFailureWithWarning(Status s); + + // Unrecommended functions: these are functions that have some + // current uses but are not recommended for use, and may go away at + // some future major version release. + // + // The following functions all have versions that return Status + // to capture error conditions, and are strongly preferred. + Tensor* mutable_output(int index); + void set_output(int index, const Tensor& tensor); + mutex* input_ref_mutex(int index); + void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref); + TensorValue release_output(int index); + + bool track_allocations() const { return params_->track_allocations; } + bool allocate_on_host(AllocatorAttributes alloc_attr) const; + + // Records temporary memory sizes. 
+ void record_host_temp_memory_size(int64 size) { + host_temp_memory_size_ += size; + } + void record_device_temp_memory_size(int64 size) { + device_temp_memory_size_ += size; + } + + // Returns recorded size of temporary memory; + int64 host_temp_memory_size() const { return host_temp_memory_size_; } + int64 device_temp_memory_size() const { return device_temp_memory_size_; } + + // Records persistent memory allocation, size can be negative indicating + // deallocation. + void record_host_persistent_memory_allocation(int64 size, + int64 alloc_id = -1); + void record_device_persistent_memory_allocation(int64 size, + int64 alloc_id = -1); + + // Returns recorded size and ids of persistent memory. + int64 host_persistent_memory_allocated() const { + return host_persistent_memory_allocated_; + } + int64 device_persistent_memory_allocated() const { + return device_persistent_memory_allocated_; + } + std::vector host_persistent_alloc_ids() const; + std::vector device_persistent_alloc_ids() const; + + bool input_is_ref(int index) const; + + private: + Allocator* get_allocator(AllocatorAttributes attr); + + // Internal method to add a tensor's buffer to the list of buffers + // referenced during the execution of the Op, so that GPUs may + // accurately track the memory that may not be reused until the Op + // execution completes. + void record_tensor_reference(const Tensor& tensor); + void really_record_tensor_reference(const Tensor& tensor); + + // Internal common method used when allocating tensor memory + Status allocate_tensor(DataType type, const TensorShape& shape, + Tensor* out_tensor, + AllocatorAttributes allocator_attr) { + return allocate_tensor(type, shape, out_tensor, allocator_attr, + AllocationAttributes()); + } + + Status allocate_tensor(DataType type, const TensorShape& shape, + Tensor* out_tensor, AllocatorAttributes allocator_attr, + const AllocationAttributes& allocation_attr); + + // This is called by PersistentTensor::AccessTensor whenever the + // wrapped tensor is retrieved, to ensure the runtime knows that the + // Tensor is being accessed within an Op. This is necessary for + // memory safety of devices like GPUs that queue Ops for + // asynchronous execution after the Compute() method completes. + friend class PersistentTensor; + void NotifyUseOfPersistentTensor(const Tensor& tensor); + + Status status_; + Params* params_; // not owned + mutable mutex mu_; // mutable so const accessors can acquire the lock + gtl::InlinedVector wrapped_allocators_ GUARDED_BY(mu_); + gtl::InlinedVector outputs_; + + // Constructed only if record_tensor_accesses>. + ManualConstructor referenced_tensors_ GUARDED_BY(mu_); + + bool is_output_dead_ = false; + + int64 host_temp_memory_size_; + int64 device_temp_memory_size_; + gtl::InlinedVector host_persistent_alloc_ids_; + gtl::InlinedVector device_persistent_alloc_ids_; + int64 host_persistent_memory_allocated_; + int64 device_persistent_memory_allocated_; + + TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext); +}; + +// Register your OpKernel by specifying the Op's name, the device the +// kernel runs on, any type attr constraints for this kernel, any +// host-memory args, and the class to instantiate. Examples: +// +// // A kernel that supports all types. +// REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp); +// +// // The following are equivalent ways of specifying that the kernel only +// // works if the "T" type attr is set to DT_FLOAT. 
+// REGISTER_KERNEL_BUILDER( +// Name("Sub").Device(DEVICE_CPU).TypeConstraint("T"), +// SubOp); +// // (You would then repeat this for every type supported by "Sub".) +// +// // This form allows you to specify a list of types as the constraint. +// REGISTER_KERNEL_BUILDER(Name("Sub") +// .Device(DEVICE_CPU) +// .TypeConstraint("T", {DT_FLOAT}), +// SubOp); +// +// // A kernel that expects one of the input tensors in host memory. +// REGISTER_KERNEL_BUILDER( +// Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp); +// +// See kernel_def_builder for details. + +// Instantiate an OpKernel that has been registered. Returns nullptr +// if no operation for that type of device / input signature combination +// (and a NOT_FOUND *status), or there is an error in construction (and +// an INVALID_ARGUMENT *status). Otherwise, the caller takes ownership +// of the returned pointer. +// EXPECTED USAGE: unique_ptr op = CreateOpKernel(...); +// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()). +std::unique_ptr CreateOpKernel(DeviceType device_type, + DeviceBase* device, + Allocator* allocator, + const NodeDef& def, + int graph_def_version, Status* status); +Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + const NodeDef& def, int graph_def_version, + OpKernel** kernel); + +// Returns into 'device_types' the subset of prioritized_types that this +// binary has registered for the given NodeDef. +// +// REQUIRES: * 'device_types' is not nullptr. +// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). +Status SupportedDeviceTypesForNode( + const std::vector& prioritized_types, const NodeDef& def, + DeviceTypeVector* device_types); + +// Returns a message with a description of the kernels registered for op +// `op_name`. +string KernelsRegisteredForOp(StringPiece op_name); + +// Call once after Op registration has completed. +Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry); + +// ----------------------------------------------------------------------------- +// OpKernel registration implementation follows, please ignore. + +// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax. +namespace register_kernel { + +class Name : public KernelDefBuilder { + public: + // With selective registration, kernels whose implementation class is not used + // by any kernel are disabled with the SHOULD_REGISTER_OP_KERNEL call in + // REGISTER_KERNEL_BUILDER_UNIQ. However, an unused kernel that shares an + // implementation class with a used kernel would get through that mechanism. + // + // This mechanism stops that registration by changing the name of the kernel + // for the unused op to one that is ignored by + // OpKernelRegistrar::InitInternal. Note that this method alone is + // not sufficient - the compiler can't evaluate the entire KernelDefBuilder at + // compilation time, so this method doesn't actually reduce code size. + explicit Name(const char* op) + : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {} +}; + +namespace system { + +class Name : public KernelDefBuilder { + public: + // For system kernels, we ignore selective registration and + // unconditionally register the kernel. + explicit Name(const char* op) : KernelDefBuilder(op) {} +}; + +} // namespace system + +} // namespace register_kernel + +#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) 
\ + REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__) + +#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \ + REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__) + +#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \ + constexpr bool should_register_##ctr##__flag = \ + SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__); \ + static ::tensorflow::kernel_factory::OpKernelRegistrar \ + registrar__body__##ctr##__object( \ + should_register_##ctr##__flag \ + ? ::tensorflow::register_kernel::kernel_builder.Build() \ + : nullptr, \ + #__VA_ARGS__, \ + [](::tensorflow::OpKernelConstruction* context) \ + -> ::tensorflow::OpKernel* { \ + return new __VA_ARGS__(context); \ + }); + +// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as +// `REGISTER_KERNEL_BUILDER()` except that the kernel is registered +// unconditionally even when selective registration is used. +#define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \ + REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, \ + __VA_ARGS__) + +#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \ + REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__) + +#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \ + static ::tensorflow::kernel_factory::OpKernelRegistrar \ + registrar__body__##ctr##__object( \ + ::tensorflow::register_kernel::system::kernel_builder.Build(), \ + #__VA_ARGS__, \ + [](::tensorflow::OpKernelConstruction* context) \ + -> ::tensorflow::OpKernel* { \ + return new __VA_ARGS__(context); \ + }); + +void* GlobalKernelRegistry(); + +// If node_def has a corresponding kernel registered on device_type, +// returns OK and fill in the kernel def and kernel_class_name. and +// may be null. +Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, + const KernelDef** def, string* kernel_class_name); + +// Writes a list of all registered kernels to LOG(INFO), to help users debug +// missing kernel errors. +void LogAllRegisteredKernels(); + +namespace kernel_factory { + +class OpKernelRegistrar { + public: + typedef OpKernel* (*Factory)(OpKernelConstruction*); + + OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name, + Factory factory) { + // Perform the check in the header to allow compile-time optimization + // to a no-op, allowing the linker to remove the kernel symbols. 
+ if (kernel_def != nullptr) { + InitInternal(kernel_def, kernel_class_name, factory); + } + } + + private: + void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name, + Factory factory); +}; + +} // namespace kernel_factory + +// ----------------------------------------------------------------------------- +// Template and inline method implementations, please ignore + +template +Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const { + return GetNodeAttr(def(), attr_name, value); +} + +inline DataType OpKernelContext::input_dtype(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + const TensorValue& value((*params_->inputs)[index]); + if (value.is_ref()) { + return MakeRefType(value->dtype()); + } else { + return value->dtype(); + } +} + +inline MemoryType OpKernelContext::input_memory_type(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + return op_kernel().input_memory_types()[index]; +} + +inline DataType OpKernelContext::expected_output_dtype(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + return params_->op_kernel->output_type(index); +} + +inline MemoryType OpKernelContext::output_memory_type(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + return op_kernel().output_memory_types()[index]; +} + +inline bool OpKernelContext::input_is_ref(int index) const { + const TensorValue& value((*params_->inputs)[index]); + return value.is_ref(); +} + +inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) { + DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(), + params_->record_tensor_accesses); + if (params_->record_tensor_accesses) { + really_record_tensor_reference(tensor); + } +} + +inline void OpKernelContext::retrieve_accessed_tensors( + TensorReferenceVector* out_vector) { + if (params_->record_tensor_accesses) { + mutex_lock l(mu_); + referenced_tensors_->FreezeAndReturnReferences(out_vector); + } +} + +// no input if tensor == nullptr. +inline bool OpKernelContext::has_input(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + return (*params_->inputs)[index].tensor != nullptr; +} + +inline mutex* OpKernelContext::input_ref_mutex(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + DCHECK(input_is_ref(index)); + return (*params_->inputs)[index].mutex_if_ref; +} + +inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) { + if (t.IsInitialized()) { + record_tensor_reference(t); + } +} + +inline Tensor* OpKernelContext::mutable_output(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + // No need to record_tensor_reference since the output must already + // have been set by a call that did so. 
+ return outputs_[index].tensor; +} + +inline TensorValue OpKernelContext::release_output(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + TensorValue value = outputs_[index]; + outputs_[index] = TensorValue(); + return value; +} + +inline Status OpKernelContext::forward_input_or_allocate_output( + gtl::ArraySlice candidate_input_indices, int output_index, + const TensorShape& output_shape, Tensor** output) { + for (int input_index : candidate_input_indices) { + if (forward_input_to_output_with_shape(input_index, output_index, + output_shape, output)) { + return Status::OK(); + } + } + return allocate_output(output_index, output_shape, output); +} + +inline Status OpKernelContext::forward_input_or_allocate_output( + gtl::ArraySlice candidate_input_names, StringPiece output_name, + const TensorShape& output_shape, Tensor** output) { + for (const StringPiece& input_name : candidate_input_names) { + if (forward_input_to_output_with_shape(input_name, output_name, + output_shape, output) + .ok()) { + return Status::OK(); + } + } + return allocate_output(output_name, output_shape, output); +} + +template +T* OpKernelContext::op_device_context() { + static_assert(std::is_base_of::value, + "T is not a subclass of DeviceContext"); + return static_cast(op_device_context()); +} + +template +T* OpKernelContext::input_device_context(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_->input_device_contexts->size()); + static_assert(std::is_base_of::value, + "T is not a subclass of DeviceContext"); + return static_cast((*params_->input_device_contexts)[index]); +} + +inline DeviceContext* OpKernelContext::input_device_context(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_->input_device_contexts->size()); + return (*params_->input_device_contexts)[index]; +} + +inline const Tensor& OpInputList::operator[](int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->input(start_ + i); +} + +inline mutex* OpMutableInputList::ref_mutex(int i) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->input_ref_mutex(start_ + i); +} + +inline Tensor OpMutableInputList::at(int i, bool lock_held) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->mutable_input(start_ + i, lock_held); +} + +inline Tensor* OpOutputList::operator[](int i) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->mutable_output(start_ + i); +} + +inline bool OpOutputList::required(int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->output_required(start_ + i); +} + +inline DataType OpOutputList::expected_output_dtype(int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->expected_output_dtype(start_ + i); +} + +inline Status OpOutputList::allocate(int i, const TensorShape& shape, + Tensor** output) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->allocate_output(start_ + i, shape, output); +} + +inline void OpOutputList::set(int i, const Tensor& tensor) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + ctx_->set_output(start_ + i, tensor); +} + +inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + ctx_->set_output_ref(i, mu, tensor_for_ref); +} + +// Convenience macros for asserting and handling exceptional conditions. +// Analogous to the CHECK* macros provided by logging.h. 
+// +// Example use: +// void Compute(OperationContext* context) { +// OP_REQUIRES(context, context->num_inputs() == 2, +// errors::InvalidArgument("FooOp requires 2 arguments")); +// ... +// Status status = SomeUncertainMethod(); +// OP_REQUIRES_OK(context, status); +// ... +// } + +#define OP_REQUIRES(CTX, EXP, STATUS) \ + do { \ + if (!TF_PREDICT_TRUE(EXP)) { \ + (CTX)->CtxFailure((STATUS)); \ + return; \ + } \ + } while (0) + +#define OP_REQUIRES_OK(CTX, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(_s); \ + return; \ + } \ + } while (0) + +#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \ + do { \ + if (!TF_PREDICT_TRUE(EXP)) { \ + (CTX)->CtxFailure((STATUS)); \ + (CALLBACK)(); \ + return; \ + } \ + } while (0) + +#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \ + do { \ + ::tensorflow::Status _s(STATUS); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(_s); \ + (CALLBACK)(); \ + return; \ + } \ + } while (0) + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ diff --git a/op_kernel_test.cc b/op_kernel_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..47523358bed40898cf82c531dc1a89fea0de88a3 --- /dev/null +++ b/op_kernel_test.cc @@ -0,0 +1,902 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include +#include +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" + +class DummyKernel : public tensorflow::OpKernel { + public: + explicit DummyKernel(tensorflow::OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(tensorflow::OpKernelContext* context) override {} +}; + +// Test that registration works outside a namespace. +REGISTER_OP("Test1").Input("a: float").Input("b: int32").Output("o: uint8"); +REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU), + DummyKernel); + +namespace foo { +bool match_signature_ = false; + +// Test that registration works inside a different namespace. 
+class TestOp2 : public ::tensorflow::OpKernel { + public: + explicit TestOp2(::tensorflow::OpKernelConstruction* context) + : OpKernel(context) { + ::tensorflow::Status status = context->MatchSignature( + {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32}); + match_signature_ = status.ok(); + context->SetStatus(status); + } + void Compute(::tensorflow::OpKernelContext* context) override {} +}; + +REGISTER_OP("Test2").Input("i: T").Output("o: T").Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("Test2") + .Device(::tensorflow::DEVICE_GPU) + .HostMemory("i") + .HostMemory("o"), + TestOp2); +} // namespace foo + +namespace tensorflow { + +// Two operations with the same name but different devices. +REGISTER_OP("Test3").Input("a: T").Input("b: T").Attr("T: type"); + +class TestOp3Cpu : public tensorflow::OpKernel { + public: + explicit TestOp3Cpu(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +REGISTER_KERNEL_BUILDER( + Name("Test3").Device(DEVICE_CPU).TypeConstraint("T"), TestOp3Cpu); + +namespace { + +class TestOp3Gpu : public tensorflow::OpKernel { + public: + explicit TestOp3Gpu(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +REGISTER_KERNEL_BUILDER( + Name("Test3").Device(DEVICE_GPU).TypeConstraint("T"), TestOp3Cpu); + +// An Op registered for both +REGISTER_OP("Test4").Input("i: float").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel); + +static std::vector DeviceTypes() { + return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}; +} + +class OpKernelTest : public ::testing::Test { + public: + OpKernelTest() : device_(Env::Default()) {} + + protected: + NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs) { + NodeDefBuilder builder(op_type + "-op", op_type); + for (DataType dt : inputs) { + builder.Input(FakeInput(dt)); + } + NodeDef node_def; + TF_CHECK_OK(builder.Finalize(&node_def)); + return node_def; + } + + void ExpectEqual(const string& what, const DataTypeVector& expected, + const DataTypeVector& observed) { + EXPECT_EQ(expected.size(), observed.size()) << what; + const size_t size = std::min(expected.size(), observed.size()); + for (size_t i = 0; i < size; ++i) { + bool match = TypesCompatible(expected[i], observed[i]); + EXPECT_TRUE(match) << what << " i:" << i << ", expected: " << expected[i] + << ", observed: " << observed[i]; + } + } + + void ExpectSuccess(const string& op_type, DeviceType device_type, + const DataTypeVector& inputs, + const DataTypeVector& outputs) { + Status status; + std::unique_ptr op(CreateOpKernel( + std::move(device_type), &device_, cpu_allocator(), + CreateNodeDef(op_type, inputs), TF_GRAPH_DEF_VERSION, &status)); + EXPECT_TRUE(status.ok()) << status; + EXPECT_TRUE(op != nullptr); + if (op != nullptr) { + ExpectEqual("inputs", op->input_types(), inputs); + ExpectEqual("outputs", op->output_types(), outputs); + } + } + + void ExpectFailure(const string& ascii_node_def, DeviceType device_type, + error::Code code) { + NodeDef node_def; + protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def); + Status status; + std::unique_ptr op( + CreateOpKernel(std::move(device_type), &device_, cpu_allocator(), + node_def, TF_GRAPH_DEF_VERSION, &status)); + EXPECT_TRUE(op == nullptr); + EXPECT_FALSE(status.ok()); + if (!status.ok()) { + LOG(INFO) << "Status message: " << 
status.error_message(); + EXPECT_EQ(code, status.code()); + } + } + + private: + DeviceBase device_; +}; + +TEST_F(OpKernelTest, SuccessCpu) { + ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT, DT_INT32}, {DT_UINT8}); + ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT_REF, DT_INT32}, {DT_UINT8}); +} + +TEST_F(OpKernelTest, SuccessGpu) { + foo::match_signature_ = false; + ExpectSuccess("Test2", DEVICE_GPU, {DT_INT32}, {DT_INT32}); + EXPECT_TRUE(foo::match_signature_); +} + +TEST_F(OpKernelTest, SuccessBothCpuAndGpu) { + ExpectSuccess("Test3", DEVICE_CPU, {DT_INT8, DT_INT8}, {}); + ExpectSuccess("Test3", DEVICE_GPU, {DT_FLOAT, DT_FLOAT}, {}); +} + +TEST_F(OpKernelTest, CpuTypeRegistered) { + NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); + DeviceTypeVector devs; + TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(1, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]); +} + +TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) { + { + // Try a node def of an op that is registered for a specific type + // only on CPU. + NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8}); + DeviceTypeVector devs; + TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(1, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]); + } + { + // Try a node def of an op that is registered for a specific type + // only on GPU. + NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}); + DeviceTypeVector devs; + TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(1, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]); + } + { + // Try a node def of an op that is only registered for other types. + NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING}); + DeviceTypeVector devs; + TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(0, devs.size()); + } + + { + // Try a node def of an op that is registered for both. + NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT}); + DeviceTypeVector devs; + TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(2, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1]); + } +} + +TEST_F(OpKernelTest, NotFound) { + const auto not_found = error::NOT_FOUND; + // Something with that op type name exists, but only with a + // different DeviceType. + ExpectFailure(CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}).DebugString(), + DEVICE_GPU, not_found); + ExpectFailure(CreateNodeDef("Test3", {DT_INT8, DT_INT8}).DebugString(), + DEVICE_GPU, not_found); + ExpectFailure(CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}).DebugString(), + DEVICE_CPU, not_found); + + // No kernel with that signature registered. + ExpectFailure(CreateNodeDef("Test3", {DT_INT32, DT_INT32}).DebugString(), + DEVICE_GPU, not_found); + + // Nothing with that op type name exists. 
+ ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_CPU, not_found); + ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_GPU, not_found); +} + +TEST_F(OpKernelTest, TooFewInputs) { + const auto invalid = error::INVALID_ARGUMENT; + NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); + node_def.clear_input(); + ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid); + node_def.add_input("a"); + ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid); +} + +TEST_F(OpKernelTest, TooManyInputs) { + const auto invalid = error::INVALID_ARGUMENT; + NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); + node_def.add_input("c"); + ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid); +} + +TEST_F(OpKernelTest, MatchSignatureFailes) { + const auto invalid = error::INVALID_ARGUMENT; + foo::match_signature_ = true; + ExpectFailure(CreateNodeDef("Test2", {DT_FLOAT}).DebugString(), DEVICE_GPU, + invalid); + EXPECT_FALSE(foo::match_signature_); +} + +class DummyDevice : public DeviceBase { + public: + DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} + bool RequiresRecordingAccessedTensors() const override { return save_; } + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } + + private: + bool save_; +}; + +TEST_F(OpKernelTest, SaveTempFalse) { + Env* env = Env::Default(); + OpKernelContext::Params params; + params.record_tensor_accesses = false; + params.device = new DummyDevice(env, params.record_tensor_accesses); + Status status; + std::unique_ptr op( + CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), + CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), + TF_GRAPH_DEF_VERSION, &status)); + EXPECT_TRUE(status.ok()); + params.op_kernel = op.get(); + OpKernelContext* ctx = new OpKernelContext(¶ms); + + Tensor t; + TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t)); + + TensorReferenceVector referenced_tensors; + ctx->retrieve_accessed_tensors(&referenced_tensors); + EXPECT_EQ(0, referenced_tensors.size()); + + delete ctx; + delete params.device; +} + +TEST_F(OpKernelTest, SaveTempTrue) { + Env* env = Env::Default(); + OpKernelContext::Params params; + params.record_tensor_accesses = true; + params.device = new DummyDevice(env, params.record_tensor_accesses); + Status status; + std::unique_ptr op( + CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), + CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), + TF_GRAPH_DEF_VERSION, &status)); + EXPECT_TRUE(status.ok()); + params.op_kernel = op.get(); + OpKernelContext* ctx = new OpKernelContext(¶ms); + + Tensor t; + TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t)); + + TensorReferenceVector referenced_tensors; + ctx->retrieve_accessed_tensors(&referenced_tensors); + EXPECT_EQ(1, referenced_tensors.size()); + for (auto& ref : referenced_tensors) { + ref.Unref(); + } + + delete ctx; + delete params.device; +} + +TEST_F(OpKernelTest, InputDtype) { + Env* env = Env::Default(); + OpKernelContext::Params params; + params.record_tensor_accesses = false; + params.device = new DummyDevice(env, params.record_tensor_accesses); + Status status; + std::unique_ptr op( + CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), + CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), + TF_GRAPH_DEF_VERSION, &status)); + EXPECT_TRUE(status.ok()); + params.op_kernel = op.get(); + Tensor a(DT_FLOAT, TensorShape({})); + Tensor b(DT_INT32, TensorShape({})); + Tensor c(DT_UINT8, TensorShape({})); + gtl::InlinedVector inputs{TensorValue(&a), 
TensorValue(&b), + TensorValue(&c)}; + params.inputs = &inputs; + OpKernelContext* ctx = new OpKernelContext(¶ms); + + DataType dtype; + EXPECT_FALSE(ctx->input_dtype("non_existent_input", &dtype).ok()); + ASSERT_TRUE(ctx->input_dtype("a", &dtype).ok()); + EXPECT_EQ(dtype, DT_FLOAT); + ASSERT_TRUE(ctx->input_dtype("b", &dtype).ok()); + EXPECT_EQ(dtype, DT_INT32); + delete ctx; + delete params.device; +} + +class OpKernelBuilderTest : public ::testing::Test { + protected: + // Each attr is described by a "name|type|value". + NodeDef CreateNodeDef(const string& op_type, + const std::vector& attrs) { + NodeDef node_def; + node_def.set_name(op_type + "-op"); + node_def.set_op(op_type); + for (const string& attr_desc : attrs) { + std::vector parts = str_util::Split(attr_desc, '|'); + CHECK_EQ(parts.size(), 3); + AttrValue attr_value; + CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc; + node_def.mutable_attr()->insert( + AttrValueMap::value_type(parts[0], attr_value)); + } + return node_def; + } + + std::unique_ptr ExpectSuccess(const string& op_type, + const DeviceType& device_type, + const std::vector& attrs, + DataTypeSlice input_types = {}) { + Status status; + NodeDef def = CreateNodeDef(op_type, attrs); + for (size_t i = 0; i < input_types.size(); ++i) { + def.add_input("a:0"); + } + + Env* env = Env::Default(); + DeviceBase device(env); + + // Test CreateOpKernel() + std::unique_ptr op(CreateOpKernel(device_type, &device, + cpu_allocator(), def, + TF_GRAPH_DEF_VERSION, &status)); + EXPECT_TRUE(status.ok()) << status; + EXPECT_TRUE(op != nullptr); + if (op != nullptr) { + EXPECT_EQ(input_types.size(), op->num_inputs()); + EXPECT_EQ(0, op->num_outputs()); + } + + // Test SupportedDeviceTypesForNode() + DeviceTypeVector devices; + TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); + bool found = false; + for (const DeviceType& dt : devices) { + if (dt == device_type) { + found = true; + } + } + EXPECT_TRUE(found) << "Missing " << device_type << " from " + << devices.size() << " devices."; + + // In case the caller wants to use the OpKernel + return op; + } + + void ExpectFailure(const string& op_type, const DeviceType& device_type, + const std::vector& attrs, error::Code code) { + Status status; + const NodeDef def = CreateNodeDef(op_type, attrs); + Env* env = Env::Default(); + DeviceBase device(env); + + // Test CreateOpKernel(). + std::unique_ptr op(CreateOpKernel(device_type, &device, + cpu_allocator(), def, + TF_GRAPH_DEF_VERSION, &status)); + EXPECT_TRUE(op == nullptr); + EXPECT_FALSE(status.ok()); + if (!status.ok()) { + LOG(INFO) << "Status message: " << status.error_message(); + EXPECT_EQ(code, status.code()); + + // Test SupportedDeviceTypesForNode(). 
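+      // For a NOT_FOUND failure the op may still have kernels registered for
+      // other device types, so the requested device type must not appear in
+      // the supported list.  Any other failure should be reported with the
+      // same error code by SupportedDeviceTypesForNode().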
+ DeviceTypeVector devices; + if (errors::IsNotFound(status)) { + TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); + for (const DeviceType& dt : devices) { + EXPECT_NE(dt, device_type); + } + } else { + Status status2 = + SupportedDeviceTypesForNode(DeviceTypes(), def, &devices); + EXPECT_EQ(status.code(), status2.code()); + } + } + } + + string GetKernelClassName(const string& op_type, + const DeviceType& device_type, + const std::vector& attrs, + DataTypeSlice input_types = {}) { + NodeDef def = CreateNodeDef(op_type, attrs); + for (size_t i = 0; i < input_types.size(); ++i) { + def.add_input("a:0"); + } + + const KernelDef* kernel_def = nullptr; + string kernel_class_name; + const Status status = + FindKernelDef(device_type, def, &kernel_def, &kernel_class_name); + if (status.ok()) { + return kernel_class_name; + } else if (errors::IsNotFound(status)) { + return "not found"; + } else { + return status.ToString(); + } + } +}; + +REGISTER_OP("BuildCPU"); +REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderCPU) { + ExpectSuccess("BuildCPU", DEVICE_CPU, {}); + EXPECT_EQ("DummyKernel", GetKernelClassName("BuildCPU", DEVICE_CPU, {})); + ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND); + EXPECT_EQ("not found", GetKernelClassName("BuildCPU", DEVICE_GPU, {})); +} + +REGISTER_OP("BuildGPU"); +REGISTER_KERNEL_BUILDER(Name("BuildGPU").Device(DEVICE_GPU), DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderGPU) { + ExpectFailure("BuildGPU", DEVICE_CPU, {}, error::NOT_FOUND); + ExpectSuccess("BuildGPU", DEVICE_GPU, {}); +} + +REGISTER_OP("BuildBoth"); +REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_GPU), DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderBoth) { + ExpectSuccess("BuildBoth", DEVICE_CPU, {}); + ExpectSuccess("BuildBoth", DEVICE_GPU, {}); +} + +REGISTER_OP("BuildTypeAttr").Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("BuildTypeAttr") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderTypeAttr) { + ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"}); + ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_BOOL"}, + error::NOT_FOUND); + ExpectFailure("BuildTypeAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT); + ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|int|7"}, + error::INVALID_ARGUMENT); +} + +REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)"); +REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) { + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"}); + EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, + {"T|list(type)|[]"})); + + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"}); + EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, + {"T|list(type)|[]"})); + + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, + {"T|list(type)|[DT_BOOL, DT_BOOL]"}); + + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"}, + error::NOT_FOUND); + EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, + {"T|list(type)|[DT_FLOAT]"})); + + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT); + EXPECT_TRUE( + StringPiece(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {})) + 
.contains("Invalid argument: ")); + + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"}, + error::INVALID_ARGUMENT); +} + +REGISTER_OP("DuplicateKernel"); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU), + DummyKernel); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU), + DummyKernel); + +TEST_F(OpKernelBuilderTest, DuplicateKernel) { + const NodeDef ndef = CreateNodeDef("DuplicateKernel", {}); + DeviceTypeVector devs; + Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Multiple OpKernel registrations match NodeDef")); + + ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT); +} + +REGISTER_OP("DuplicateKernelForT").Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + DummyKernel); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, DuplicateKernelForT) { + const NodeDef ndef = + CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"}); + DeviceTypeVector devs; + Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Multiple OpKernel registrations match NodeDef")); + + ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"}, + error::INVALID_ARGUMENT); + ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_BOOL"}, + error::NOT_FOUND); +} + +REGISTER_OP("BadConstraint").Attr("dtype: type"); +REGISTER_KERNEL_BUILDER(Name("BadConstraint") + .Device(DEVICE_CPU) + // Mistake: "T" should be "dtype". 
+ .TypeConstraint("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, BadConstraint) { + const NodeDef ndef = CreateNodeDef("BadConstraint", {}); + DeviceTypeVector devs; + Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("OpKernel 'BadConstraint' has constraint on attr " + "'T' not in NodeDef")); + + ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"}, + error::INVALID_ARGUMENT); +} + +REGISTER_OP("ListOut").Output("a: int32").Output("b: T").Attr("T: list(type)"); +REGISTER_KERNEL_BUILDER(Name("ListOut").Device(tensorflow::DEVICE_CPU), + DummyKernel); + +TEST_F(OpKernelBuilderTest, OpOutputList) { + Env* env = Env::Default(); + OpKernelContext::Params params; + params.record_tensor_accesses = false; + std::unique_ptr device( + new DummyDevice(env, params.record_tensor_accesses)); + params.device = device.get(); + Status status; + std::unique_ptr op(CreateOpKernel( + DEVICE_CPU, params.device, cpu_allocator(), + CreateNodeDef("ListOut", {"T|list(type)|[DT_FLOAT, DT_INT32]"}), + TF_GRAPH_DEF_VERSION, &status)); + EXPECT_TRUE(status.ok()) << status.ToString(); + params.op_kernel = op.get(); + gtl::InlinedVector inputs{}; + params.inputs = &inputs; + std::unique_ptr ctx(new OpKernelContext(¶ms)); + + EXPECT_EQ(DT_INT32, ctx->expected_output_dtype(0)); + OpOutputList out_list; + EXPECT_FALSE(ctx->output_list("non_existent_output", &out_list).ok()); + ASSERT_TRUE(ctx->output_list("b", &out_list).ok()); + EXPECT_EQ(DT_FLOAT, out_list.expected_output_dtype(0)); + EXPECT_EQ(DT_INT32, out_list.expected_output_dtype(1)); +} + +class GetAttrKernel : public ::tensorflow::OpKernel { + public: + explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) { + string attr_name; + OP_REQUIRES_OK(context, context->GetAttr("attr_name", &attr_name)); + + status.emplace_back("s", context->GetAttr(attr_name, &s)); + status.emplace_back("s_list", context->GetAttr(attr_name, &s_list)); + status.emplace_back("i", context->GetAttr(attr_name, &i)); + status.emplace_back("i_list", context->GetAttr(attr_name, &i_list)); + status.emplace_back("i32", context->GetAttr(attr_name, &i32)); + status.emplace_back("i32_list", context->GetAttr(attr_name, &i32_list)); + status.emplace_back("f", context->GetAttr(attr_name, &f)); + status.emplace_back("f_list", context->GetAttr(attr_name, &f_list)); + status.emplace_back("b", context->GetAttr(attr_name, &b)); + status.emplace_back("b_list", context->GetAttr(attr_name, &b_list)); + status.emplace_back("type", context->GetAttr(attr_name, &type)); + status.emplace_back("type_list", context->GetAttr(attr_name, &type_list)); + status.emplace_back("type_vector", + context->GetAttr(attr_name, &type_vector)); + status.emplace_back("shape_proto", + context->GetAttr(attr_name, &shape_proto)); + status.emplace_back("shape_proto_list", + context->GetAttr(attr_name, &shape_proto_list)); + status.emplace_back("shape", context->GetAttr(attr_name, &shape)); + status.emplace_back("shape_list", context->GetAttr(attr_name, &shape_list)); + } + void Compute(::tensorflow::OpKernelContext* context) override {} + + void ExpectOk(std::initializer_list keys) { + for (const auto& key_status : status) { + // Only the status for keys in "keys" should be ok(). 
+ bool in_keys = false; + for (const string& key : keys) { + if (key_status.first == key) { + in_keys = true; + } + } + EXPECT_EQ(in_keys, key_status.second.ok()) + << "key_status: " << key_status.first << ", " << key_status.second; + } + } + + string s; + std::vector s_list; + int64 i; + std::vector i_list; + int32 i32; + std::vector i32_list; + float f; + std::vector f_list; + bool b; + std::vector b_list; + DataType type; + std::vector type_list; + DataTypeVector type_vector; + TensorShapeProto shape_proto; + std::vector shape_proto_list; + TensorShape shape; + std::vector shape_list; + std::vector> status; +}; + +class GetAttrTest : public OpKernelBuilderTest {}; + +REGISTER_OP("GetAttrStringList") + .Attr("attr_name: string") + .Attr("a: list(string)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrStringList").Device(DEVICE_CPU), + GetAttrKernel); + +TEST_F(GetAttrTest, StringList) { + std::unique_ptr op_kernel = + ExpectSuccess("GetAttrStringList", DEVICE_CPU, + {"attr_name|string|'a'", "a|list(string)|['foo', 'bar']"}); + auto* get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"s_list"}); + EXPECT_EQ(std::vector({"foo", "bar"}), get_attr_kernel->s_list); + + op_kernel = ExpectSuccess("GetAttrStringList", DEVICE_CPU, + {"attr_name|string|'b'", "a|list(string)|['baz']"}); + get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({}); + EXPECT_TRUE(get_attr_kernel->s_list.empty()); +} + +REGISTER_OP("GetAttrInt") + .Attr("attr_name: string") + .Attr("a: int") + .Attr("b: list(int)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrInt").Device(DEVICE_CPU), GetAttrKernel); + +TEST_F(GetAttrTest, Int) { + std::unique_ptr op_kernel = ExpectSuccess( + "GetAttrInt", DEVICE_CPU, + {"attr_name|string|'a'", "a|int|35", "b|list(int)|[-1, 2, -4]"}); + auto* get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"i", "i32"}); + EXPECT_EQ(35, get_attr_kernel->i); + EXPECT_EQ(35, get_attr_kernel->i32); + + op_kernel = ExpectSuccess( + "GetAttrInt", DEVICE_CPU, + {"attr_name|string|'b'", "a|int|35", "b|list(int)|[-1, 2, -4]"}); + get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"i_list", "i32_list"}); + EXPECT_EQ(std::vector({-1, 2, -4}), get_attr_kernel->i_list); + EXPECT_EQ(std::vector({-1, 2, -4}), get_attr_kernel->i32_list); + + // 8589934592 == 2^33, too big to fit in an int32 + op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU, + {"attr_name|string|'a'", "a|int|8589934592", + "b|list(int)|[-8589934592]"}); + get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"i"}); // no i32 + EXPECT_EQ(8589934592ll, get_attr_kernel->i); + for (const auto& key_status : get_attr_kernel->status) { + if (key_status.first == "i32") { + EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code()); + EXPECT_EQ("Attr a has value 8589934592 out of range for an int32", + key_status.second.error_message()); + } + } + + op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU, + {"attr_name|string|'b'", "a|int|8589934592", + "b|list(int)|[-8589934592]"}); + get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"i_list"}); // no i32_list + EXPECT_EQ(std::vector({-8589934592ll}), get_attr_kernel->i_list); + for (const auto& key_status : get_attr_kernel->status) { + if (key_status.first == "i32_list") { + EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code()); + EXPECT_EQ("Attr b has value -8589934592 out of range for an int32", + key_status.second.error_message()); + } + } +} + 
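+// For reference, the pattern these GetAttr tests exercise usually looks like
+// the following in a real kernel (illustrative sketch only; "MyAttrOpKernel"
+// and its "num_splits" attr are hypothetical and not registered in this file):
+//
+//   class MyAttrOpKernel : public OpKernel {
+//    public:
+//     explicit MyAttrOpKernel(OpKernelConstruction* context)
+//         : OpKernel(context) {
+//       // Construction fails if the attr is missing or has the wrong type.
+//       OP_REQUIRES_OK(context, context->GetAttr("num_splits", &num_splits_));
+//     }
+//     void Compute(OpKernelContext* context) override {}
+//
+//    private:
+//     int64 num_splits_;
+//   };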
+REGISTER_OP("GetAttrShape") + .Attr("attr_name: string") + .Attr("a: shape") + .Attr("b: list(shape)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrShape").Device(DEVICE_CPU), GetAttrKernel); + +TEST_F(GetAttrTest, Shape) { + std::unique_ptr op_kernel = ExpectSuccess( + "GetAttrShape", DEVICE_CPU, + {"attr_name|string|'a'", "a|shape|{ dim { size: 3 } }", + "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"}); + auto* get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"shape", "shape_proto"}); + EXPECT_EQ(get_attr_kernel->shape_proto.ShortDebugString(), "dim { size: 3 }"); + EXPECT_EQ("[3]", get_attr_kernel->shape.DebugString()); + + op_kernel = ExpectSuccess( + "GetAttrShape", DEVICE_CPU, + {"attr_name|string|'b'", "a|shape|{ dim { size: 3 } }", + "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"}); + get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"shape_list", "shape_proto_list"}); + ASSERT_EQ(2, get_attr_kernel->shape_proto_list.size()); + EXPECT_EQ(get_attr_kernel->shape_proto_list[0].ShortDebugString(), + "dim { size: 2 }"); + EXPECT_EQ(get_attr_kernel->shape_proto_list[1].ShortDebugString(), + "dim { size: 4 }"); + ASSERT_EQ(2, get_attr_kernel->shape_list.size()); + EXPECT_EQ("[2]", get_attr_kernel->shape_list[0].DebugString()); + EXPECT_EQ("[4]", get_attr_kernel->shape_list[1].DebugString()); +} + +REGISTER_OP("GetAttrType").Attr("attr_name: string").Attr("a: type"); +REGISTER_KERNEL_BUILDER(Name("GetAttrType").Device(DEVICE_CPU), GetAttrKernel); + +TEST_F(GetAttrTest, Type) { + std::unique_ptr op_kernel = ExpectSuccess( + "GetAttrType", DEVICE_CPU, {"attr_name|string|'a'", "a|type|DT_FLOAT"}); + auto* get_attr_kernel = static_cast(op_kernel.get()); + get_attr_kernel->ExpectOk({"type"}); + EXPECT_EQ(DT_FLOAT, get_attr_kernel->type); +} + +REGISTER_OP("GetAttrTypeList").Attr("attr_name: string").Attr("a: list(type)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrTypeList").Device(DEVICE_CPU), + GetAttrKernel); + +TEST_F(GetAttrTest, TypeList) { + std::unique_ptr op_kernel = ExpectSuccess( + "GetAttrTypeList", DEVICE_CPU, + {"attr_name|string|'a'", "a|list(type)|[DT_INT32, DT_BOOL]"}); + auto* get_attr_kernel = static_cast(op_kernel.get()); + + get_attr_kernel->ExpectOk({"type_list", "type_vector"}); + ASSERT_EQ(2, get_attr_kernel->type_list.size()); + EXPECT_EQ(DT_INT32, get_attr_kernel->type_list[0]); + EXPECT_EQ(DT_BOOL, get_attr_kernel->type_list[1]); + ASSERT_EQ(2, get_attr_kernel->type_vector.size()); + EXPECT_EQ(DT_INT32, get_attr_kernel->type_vector[0]); + EXPECT_EQ(DT_BOOL, get_attr_kernel->type_vector[1]); +} + +class BaseKernel : public ::tensorflow::OpKernel { + public: + explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(::tensorflow::OpKernelContext* context) override {} + virtual int Which() const = 0; +}; + +template +class LabeledKernel : public BaseKernel { + public: + using BaseKernel::BaseKernel; + int Which() const override { return WHICH; } +}; + +class LabelTest : public OpKernelBuilderTest {}; + +REGISTER_OP("LabeledKernel"); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU), + LabeledKernel<0>); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("one"), + LabeledKernel<1>); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"), + LabeledKernel<2>); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"), + LabeledKernel<3>); + +TEST_F(LabelTest, Default) { + std::unique_ptr 
op_kernel = + ExpectSuccess("LabeledKernel", DEVICE_CPU, {}); + auto* get_labeled_kernel = static_cast(op_kernel.get()); + EXPECT_EQ(0, get_labeled_kernel->Which()); + + EXPECT_EQ("LabeledKernel<0>", + GetKernelClassName("LabeledKernel", DEVICE_CPU, {})); +} + +TEST_F(LabelTest, Specified) { + std::unique_ptr op_kernel = + ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"}); + auto* get_labeled_kernel = static_cast(op_kernel.get()); + EXPECT_EQ(1, get_labeled_kernel->Which()); + EXPECT_EQ("LabeledKernel<1>", GetKernelClassName("LabeledKernel", DEVICE_CPU, + {"_kernel|string|'one'"})); +} + +TEST_F(LabelTest, Duplicate) { + ExpectFailure("LabeledKernel", DEVICE_CPU, {"_kernel|string|'dupe'"}, + error::INVALID_ARGUMENT); +} + +} // namespace +} // namespace tensorflow diff --git a/op_registration_test.cc b/op_registration_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..665b1bf33c78996b38f76f799ba63b32943efc19 --- /dev/null +++ b/op_registration_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/framework/op.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +void Register(const string& op_name, OpRegistry* registry) { + registry->Register([op_name](OpRegistrationData* op_reg_data) -> Status { + op_reg_data->op_def.set_name(op_name); + return Status::OK(); + }); +} + +} // namespace + +TEST(OpRegistrationTest, TestBasic) { + std::unique_ptr registry(new OpRegistry); + Register("Foo", registry.get()); + OpList op_list; + registry->Export(true, &op_list); + EXPECT_EQ(op_list.op().size(), 1); + EXPECT_EQ(op_list.op(0).name(), "Foo"); +} + +TEST(OpRegistrationTest, TestDuplicate) { + std::unique_ptr registry(new OpRegistry); + Register("Foo", registry.get()); + Status s = registry->ProcessRegistrations(); + EXPECT_TRUE(s.ok()); + + TF_EXPECT_OK( + registry->SetWatcher([](const Status& s, const OpDef& op_def) -> Status { + EXPECT_TRUE(errors::IsAlreadyExists(s)); + return Status::OK(); + })); + Register("Foo", registry.get()); + s = registry->ProcessRegistrations(); + EXPECT_TRUE(s.ok()); +} + +} // namespace tensorflow diff --git a/op_segment.cc b/op_segment.cc new file mode 100644 index 0000000000000000000000000000000000000000..dfc5aa7747d7fcaef3d4dcf4c8caab3ea40c7a67 --- /dev/null +++ b/op_segment.cc @@ -0,0 +1,102 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_segment.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +OpSegment::Item::~Item() { + for (auto kv : name_kernel) delete kv.second; +} + +OpSegment::OpSegment() {} + +OpSegment::~OpSegment() { + for (auto kv : sessions_) delete kv.second; +} + +Status OpSegment::FindOrCreate(const string& session_handle, + const string& node_name, OpKernel** kernel, + CreateKernelFn create_fn) { + { + mutex_lock l(mu_); + auto item = gtl::FindPtrOrNull(sessions_, session_handle); + if (item == nullptr) { + return errors::NotFound("Session ", session_handle, " is not found."); + } + *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name); + if (*kernel != nullptr) { + return Status::OK(); + } + } + Status s = create_fn(kernel); + if (!s.ok()) { + LOG(ERROR) << "Create kernel failed: " << s; + return s; + } + { + mutex_lock l(mu_); + auto item = gtl::FindPtrOrNull(sessions_, session_handle); + if (item == nullptr) { + return errors::NotFound("Session ", session_handle, " is not found."); + } + OpKernel** p_kernel = &(item->name_kernel[node_name]); + if (*p_kernel == nullptr) { + *p_kernel = *kernel; // Inserts 'kernel' in the map. + } else { + delete *kernel; + *kernel = *p_kernel; + } + } + return Status::OK(); +} + +void OpSegment::AddHold(const string& session_handle) { + mutex_lock l(mu_); + Item** item = &sessions_[session_handle]; + if (*item == nullptr) { + *item = new Item; // num_holds == 1 + } else { + ++((*item)->num_holds); + } +} + +void OpSegment::RemoveHold(const string& session_handle) { + Item* item = nullptr; + { + mutex_lock l(mu_); + auto siter = sessions_.find(session_handle); + if (siter == sessions_.end()) { + VLOG(1) << "Session " << session_handle << " is not found."; + return; + } + item = siter->second; + if (--(item->num_holds) > 0) { + return; + } else { + sessions_.erase(siter); + } + } + delete item; +} + +} // end namespace tensorflow diff --git a/op_segment.h b/op_segment.h new file mode 100644 index 0000000000000000000000000000000000000000..4433a2554f21a193c3cec75393049f5d1407062a --- /dev/null +++ b/op_segment.h @@ -0,0 +1,84 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_ +#define TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// OpSegment keeps track of OpKernels registered for sessions running +// on a device. +// +// The implementation maintains a two-level map. The 1st level maps +// session handle to the map of registered OpKernels. The 2nd level +// map maps node names to instantiated OpKernel objects. +// +// Each 2-nd level map is reference-counted and the caller can call +// AddHold to obtain a reference on all kernels of a session and +// ensure these kernels are alive until a corresponding RemoveHold is +// called on the same session. +class OpSegment { + public: + OpSegment(); + ~OpSegment(); + + // A hold can be placed on a session, preventing all its kernels + // from being deleted. + void AddHold(const string& session_handle); + void RemoveHold(const string& session_handle); + + // If the kernel for "node_name" has been created in the + // "session_handle", returns the existing op kernel in "*kernel". + // Otherwise, creates the kernel by calling create_fn(), cache it, + // and returns it in "*kernel". If create_fn() fails, returns the + // error. + // + // OpSegment keeps the ownership of the returned "*kernel". + typedef std::function CreateKernelFn; + Status FindOrCreate(const string& session_handle, const string& node_name, + OpKernel** kernel, CreateKernelFn create_fn); + + private: + // op name -> OpKernel + typedef std::unordered_map KernelMap; + struct Item { + int num_holds = 1; // Num of holds put on the session. + KernelMap name_kernel; // op name -> kernel. + ~Item(); + }; + + // session handle -> item. + // Session handles are produced by strings::FpToString() + typedef std::unordered_map SessionMap; + + mutable mutex mu_; + SessionMap sessions_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(OpSegment); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_ diff --git a/op_segment_test.cc b/op_segment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..af16e9f7ef96c468cdec684d23bb932ddf75d99d --- /dev/null +++ b/op_segment_test.cc @@ -0,0 +1,159 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op_segment.h" + +#include +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +class OpSegmentTest : public ::testing::Test { + protected: + DeviceBase device_; + std::vector int32_nodedefs_; + std::vector float_nodedefs_; + + OpSegmentTest() : device_(Env::Default()) { + for (int i = 0; i < 10; ++i) { + NodeDef def; + TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul") + .Input("x", 0, DT_INT32) + .Input("y", 0, DT_INT32) + .Finalize(&def)); + int32_nodedefs_.push_back(def); + TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul") + .Input("x", 0, DT_FLOAT) + .Input("y", 0, DT_FLOAT) + .Finalize(&def)); + float_nodedefs_.push_back(def); + } + } + + void ValidateOpAndTypes(OpKernel* op, const NodeDef& expected, DataType dt) { + ASSERT_NE(op, nullptr); + EXPECT_EQ(expected.DebugString(), op->def().DebugString()); + EXPECT_EQ(2, op->num_inputs()); + EXPECT_EQ(dt, op->input_type(0)); + EXPECT_EQ(dt, op->input_type(1)); + EXPECT_EQ(1, op->num_outputs()); + EXPECT_EQ(dt, op->output_type(0)); + } + + OpSegment::CreateKernelFn GetFn(const NodeDef* ndef) { + return [this, ndef](OpKernel** kernel) { + Status s; + auto created = CreateOpKernel(DEVICE_CPU, &device_, cpu_allocator(), + *ndef, TF_GRAPH_DEF_VERSION, &s); + if (s.ok()) { + *kernel = created.release(); + } + return s; + }; + } +}; + +TEST_F(OpSegmentTest, Basic) { + OpSegment opseg; + OpKernel* op; + + opseg.AddHold("A"); + opseg.AddHold("B"); + for (int i = 0; i < 10; ++i) { + // Register in session A. + auto* ndef = &float_nodedefs_[i]; + TF_EXPECT_OK(opseg.FindOrCreate("A", ndef->name(), &op, GetFn(ndef))); + ValidateOpAndTypes(op, *ndef, DT_FLOAT); + + // Register in session B. + ndef = &int32_nodedefs_[i]; + TF_EXPECT_OK(opseg.FindOrCreate("B", ndef->name(), &op, GetFn(ndef))); + ValidateOpAndTypes(op, *ndef, DT_INT32); + } + + auto reterr = [](OpKernel** kernel) { + return errors::Internal("Should not be called"); + }; + for (int i = 0; i < 10; ++i) { + // Lookup op in session A. + TF_EXPECT_OK( + opseg.FindOrCreate("A", strings::StrCat("op", i), &op, reterr)); + ValidateOpAndTypes(op, float_nodedefs_[i], DT_FLOAT); + + // Lookup op in session B. 
+ TF_EXPECT_OK( + opseg.FindOrCreate("B", strings::StrCat("op", i), &op, reterr)); + ValidateOpAndTypes(op, int32_nodedefs_[i], DT_INT32); + } + + opseg.RemoveHold("A"); + opseg.RemoveHold("B"); +} + +TEST_F(OpSegmentTest, SessionNotFound) { + OpSegment opseg; + OpKernel* op; + NodeDef def = float_nodedefs_[0]; + Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); + EXPECT_TRUE(errors::IsNotFound(s)) << s; +} + +TEST_F(OpSegmentTest, CreateFailure) { + OpSegment opseg; + OpKernel* op; + NodeDef def = float_nodedefs_[0]; + def.set_op("nonexistop"); + opseg.AddHold("A"); + Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); + EXPECT_TRUE(errors::IsNotFound(s)) << s; + opseg.RemoveHold("A"); +} + +TEST_F(OpSegmentTest, AddRemoveHolds) { + OpSegment opseg; + OpKernel* op; + const auto& ndef = int32_nodedefs_[0]; + + // No op. + opseg.RemoveHold("null"); + + // Thread1 register the op and wants to ensure it alive. + opseg.AddHold("foo"); + TF_EXPECT_OK(opseg.FindOrCreate("foo", ndef.name(), &op, GetFn(&ndef))); + + // Thread2 starts some execution needs "op" to be alive. + opseg.AddHold("foo"); + + // Thread1 clears session "foo". E.g., a master sends CleanupGraph + // before an execution finishes. + opseg.RemoveHold("foo"); + + // Thread2 should still be able to access "op". + ValidateOpAndTypes(op, ndef, DT_INT32); + + // Thread2 then remove its hold on "foo". + opseg.RemoveHold("foo"); +} + +} // namespace tensorflow diff --git a/partial_tensor_shape.h b/partial_tensor_shape.h new file mode 100644 index 0000000000000000000000000000000000000000..fa1ce07dd43a462fbe856aca1032a88259248b21 --- /dev/null +++ b/partial_tensor_shape.h @@ -0,0 +1,22 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_PARTIAL_TENSOR_SHAPE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_PARTIAL_TENSOR_SHAPE_H_ + +// TODO(irving): Remove this forwarding header +#include "tensorflow/core/framework/tensor_shape.h" + +#endif // TENSORFLOW_CORE_FRAMEWORK_PARTIAL_TENSOR_SHAPE_H_ diff --git a/partial_tensor_shape_test.cc b/partial_tensor_shape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..54ae019f9b48128aab86b4d6a6154ec5dd60366a --- /dev/null +++ b/partial_tensor_shape_test.cc @@ -0,0 +1,276 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/partial_tensor_shape.h" + +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(PartialTensorShapeTest, Default) { + // The default PartialTensorShape constructor constructs a shape + // with unknown rank. + const PartialTensorShape s; + EXPECT_EQ(s.dims(), -1); + EXPECT_TRUE(s.unknown_rank()); +} + +TEST(PartialTensorShapeTest, Concatenate) { + const PartialTensorShape s({10, 5}); + ASSERT_EQ(2, s.dims()); + EXPECT_EQ(10, s.dim_size(0)); + EXPECT_EQ(5, s.dim_size(1)); + EXPECT_EQ(50, s.num_elements()); + + const auto s1 = s.Concatenate(s); + ASSERT_EQ(4, s1.dims()); + EXPECT_EQ(10, s1.dim_size(0)); + EXPECT_EQ(5, s1.dim_size(1)); + EXPECT_EQ(10, s1.dim_size(2)); + EXPECT_EQ(5, s1.dim_size(3)); + EXPECT_EQ(50 * 50, s1.num_elements()); + + const auto s2 = s.Concatenate(-1); + const auto s3 = s2.Concatenate(0); + ASSERT_EQ(3, s2.dims()); + ASSERT_EQ(4, s3.dims()); + EXPECT_EQ(10, s2.dim_size(0)); + EXPECT_EQ(10, s3.dim_size(0)); + EXPECT_EQ(5, s2.dim_size(1)); + EXPECT_EQ(5, s3.dim_size(1)); + EXPECT_EQ(-1, s2.dim_size(2)); + EXPECT_EQ(-1, s3.dim_size(2)); + EXPECT_EQ(0, s3.dim_size(3)); + EXPECT_EQ(-1, s2.num_elements()); + EXPECT_EQ(-1, s3.num_elements()); + + const auto s4 = s.Concatenate(PartialTensorShape()); + EXPECT_EQ(-1, s4.dims()); + EXPECT_EQ(-1, s4.num_elements()); +} + +TEST(PartialTensorShapeTest, InvalidShapeProto) { + TensorShapeProto proto; + EXPECT_TRUE(PartialTensorShape::IsValid(proto)); + + proto.add_dim()->set_size(357); + proto.add_dim()->set_size(982); + EXPECT_TRUE(PartialTensorShape::IsValid(proto)); + + proto.Clear(); + proto.add_dim()->set_size(0); + proto.add_dim()->set_size(-1); + EXPECT_TRUE(PartialTensorShape::IsValid(proto)); + + proto.Clear(); + proto.set_unknown_rank(true); + EXPECT_TRUE(PartialTensorShape::IsValid(proto)); + + proto.add_dim()->set_size(1); + EXPECT_FALSE(PartialTensorShape::IsValid(proto)); + + proto.Clear(); + proto.add_dim()->set_size(-2); + EXPECT_FALSE(PartialTensorShape::IsValid(proto)); +} + +TEST(PartialTensorShapeTest, PartialShapeFullyDefined) { + const PartialTensorShape a({-1, 0, 1}); + const PartialTensorShape b({1, 0, 1}); + const PartialTensorShape c({-1, -1, 1}); + const PartialTensorShape d({1, 0}); + const PartialTensorShape e({}); + const PartialTensorShape f; + EXPECT_FALSE(a.IsFullyDefined()); + EXPECT_FALSE(c.IsFullyDefined()); + EXPECT_TRUE(b.IsFullyDefined()); + EXPECT_TRUE(d.IsFullyDefined()); + EXPECT_TRUE(e.IsFullyDefined()); + EXPECT_FALSE(f.IsFullyDefined()); +} + +TEST(PartialTensorShapeTest, ToTensorShape) { + const PartialTensorShape a({}); + const PartialTensorShape b({1, 0}); + const PartialTensorShape c({-1, 0}); + const PartialTensorShape d; + TensorShape full; + EXPECT_TRUE(a.AsTensorShape(&full)); + EXPECT_EQ(full.dims(), 0); + EXPECT_TRUE(b.AsTensorShape(&full)); + EXPECT_EQ(full.dims(), 2); + EXPECT_EQ(full.dim_size(0), 1); + EXPECT_EQ(full.dim_size(1), 0); + EXPECT_FALSE(c.AsTensorShape(&full)); + EXPECT_FALSE(d.AsTensorShape(&full)); +} + +TEST(PartialTensorShapeTest, PartialShapeIdenticalTo) { + const PartialTensorShape a({-1, 0, 1}); + const PartialTensorShape b({1, 0, 1}); + const PartialTensorShape c({-1, -1, 1}); + const PartialTensorShape d({1, 0}); + const PartialTensorShape e({-1, 
0, 2}); + const PartialTensorShape f({}); + const PartialTensorShape g; + std::vector shapes = {a, b, c, d, e, f, g}; + for (int i = 0; i < shapes.size(); ++i) { + for (int j = 0; j < i; ++j) { + if (i == j) { + EXPECT_TRUE(shapes[i].IsIdenticalTo(shapes[j])); + } else { + EXPECT_FALSE(shapes[i].IsIdenticalTo(shapes[j])); + } + } + } +} + +TEST(PartialTensorShapeTest, PartialShapeCompatibleWith) { + const PartialTensorShape a({-1, 0, 1}); + const PartialTensorShape b({1, 0, 1}); + const PartialTensorShape c({-1, -1, 1}); + const PartialTensorShape d({1, 0}); + const PartialTensorShape e({-1, 0, 2}); + const PartialTensorShape f({}); + const PartialTensorShape g; + + EXPECT_TRUE(f.IsCompatibleWith(f)); + EXPECT_TRUE(a.IsCompatibleWith(b)); + EXPECT_TRUE(a.IsCompatibleWith(a)); + EXPECT_TRUE(b.IsCompatibleWith(b)); + EXPECT_TRUE(a.IsCompatibleWith(c)); + EXPECT_TRUE(b.IsCompatibleWith(c)); + EXPECT_FALSE(a.IsCompatibleWith(d)); + EXPECT_FALSE(b.IsCompatibleWith(d)); + EXPECT_FALSE(c.IsCompatibleWith(d)); + EXPECT_FALSE(a.IsCompatibleWith(e)); + EXPECT_FALSE(b.IsCompatibleWith(e)); + EXPECT_FALSE(c.IsCompatibleWith(e)); + EXPECT_FALSE(a.IsCompatibleWith(f)); + EXPECT_FALSE(b.IsCompatibleWith(f)); + EXPECT_FALSE(c.IsCompatibleWith(f)); + EXPECT_TRUE(a.IsCompatibleWith(g)); + EXPECT_TRUE(g.IsCompatibleWith(a)); + EXPECT_TRUE(g.IsCompatibleWith(g)); +} + +TEST(PartialTensorShapeTest, ShapeCompatibleWith) { + const PartialTensorShape a({-1, 0, 1}); + const PartialTensorShape unknown; + TensorShape b({0, 1}); + TensorShape c({0, 0, 1}); + TensorShape d({1, 0, 1}); + TensorShape e({1, 1, 1}); + + EXPECT_FALSE(a.IsCompatibleWith(b)); + EXPECT_TRUE(a.IsCompatibleWith(c)); + EXPECT_TRUE(a.IsCompatibleWith(d)); + EXPECT_FALSE(a.IsCompatibleWith(e)); + + EXPECT_TRUE(unknown.IsCompatibleWith(b)); + EXPECT_TRUE(unknown.IsCompatibleWith(c)); + EXPECT_TRUE(unknown.IsCompatibleWith(d)); + EXPECT_TRUE(unknown.IsCompatibleWith(e)); +} + +TEST(PartialTensorShapeTest, PartialShapeMergeWith) { + const PartialTensorShape a({-1, 0, 1}); + const PartialTensorShape b({1, 0, 1}); + const PartialTensorShape c({-1, -1, 1}); + const PartialTensorShape d({1, 0}); + const PartialTensorShape e({-1, 0, 2}); + const PartialTensorShape f({}); + const PartialTensorShape g; + + PartialTensorShape test; + EXPECT_EQ(Status::OK(), a.MergeWith(a, &test)); + EXPECT_EQ(test.dims(), 3); + EXPECT_EQ(test.dim_size(0), -1); + EXPECT_EQ(test.dim_size(1), 0); + EXPECT_EQ(test.dim_size(2), 1); + + test = PartialTensorShape(); + EXPECT_EQ(Status::OK(), a.MergeWith(b, &test)); + EXPECT_EQ(test.dims(), 3); + EXPECT_EQ(test.dim_size(0), 1); + EXPECT_EQ(test.dim_size(1), 0); + EXPECT_EQ(test.dim_size(2), 1); + + test = PartialTensorShape(); + EXPECT_TRUE(errors::IsInvalidArgument(a.MergeWith(d, &test))); + + test = PartialTensorShape(); + EXPECT_EQ(Status::OK(), a.MergeWith(c, &test)); + EXPECT_EQ(test.dims(), 3); + EXPECT_EQ(test.dim_size(0), -1); + EXPECT_EQ(test.dim_size(1), 0); + EXPECT_EQ(test.dim_size(2), 1); + + test = PartialTensorShape(); + EXPECT_EQ(Status::OK(), c.MergeWith(a, &test)); + EXPECT_EQ(test.dims(), 3); + EXPECT_EQ(test.dim_size(0), -1); + EXPECT_EQ(test.dim_size(1), 0); + EXPECT_EQ(test.dim_size(2), 1); + + test = PartialTensorShape(); + EXPECT_EQ(Status::OK(), a.MergeWith(g, &test)); + EXPECT_EQ(test.dims(), 3); + EXPECT_EQ(test.dim_size(0), -1); + EXPECT_EQ(test.dim_size(1), 0); + EXPECT_EQ(test.dim_size(2), 1); + + test = PartialTensorShape(); + EXPECT_EQ(Status::OK(), g.MergeWith(a, &test)); + EXPECT_EQ(test.dims(), 3); 
+ EXPECT_EQ(test.dim_size(0), -1); + EXPECT_EQ(test.dim_size(1), 0); + EXPECT_EQ(test.dim_size(2), 1); +} + +TEST(PartialTensorShapeTest, MakePartialShapeEmpty) { + // Empty made partial shapes should still be fully defined + const int64 dims[1] = {}; + PartialTensorShape shape; + EXPECT_FALSE(shape.IsFullyDefined()); + TF_ASSERT_OK(PartialTensorShape::MakePartialShape(dims, 0, &shape)); + EXPECT_TRUE(shape.IsFullyDefined()); +} + +TEST(PartialTensorShapeTest, MakePartialShapeFull) { + // Check that arrays are copied through correctly + const int64 dims[3] = {7, -1, 2}; + PartialTensorShape shape; + TF_ASSERT_OK(PartialTensorShape::MakePartialShape(dims, 3, &shape)); + ASSERT_EQ(shape.dims(), 3); + for (int i = 0; i < 3; i++) { + EXPECT_EQ(shape.dim_size(i), dims[i]); + } +} + +TEST(PartialTensorShapeTest, MakePartialShapeInvalid) { + // Check that arrays are copied through correctly + const int64 dims[3] = {7, -2, 2}; + PartialTensorShape shape; + EXPECT_EQ(error::INVALID_ARGUMENT, + PartialTensorShape::MakePartialShape(dims, 3, &shape).code()); +} + +} // namespace +} // namespace tensorflow diff --git a/queue_interface.h b/queue_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..4aeaab3d9b00a46752279a296f13e67370776357 --- /dev/null +++ b/queue_interface.h @@ -0,0 +1,102 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_ +#define TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// All implementations must be thread-safe. +class QueueInterface : public ResourceBase { + public: + typedef std::vector Tuple; + typedef AsyncOpKernel::DoneCallback DoneCallback; + typedef std::function CallbackWithTuple; + + virtual Status ValidateTuple(const Tuple& tuple) = 0; + virtual Status ValidateManyTuple(const Tuple& tuple) = 0; + + // Stashes a function object for future execution, that will eventually + // enqueue the tuple of tensors into the queue, and returns immediately. The + // function object is guaranteed to call 'callback'. + virtual void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) = 0; + + // Same as above, but the component tensors are sliced along the 0th dimension + // to make multiple queue-element components. + virtual void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) = 0; + + // Stashes a function object for future execution, that will eventually + // dequeue an element from the queue and call 'callback' with that tuple + // element as argument. 
+ virtual void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) = 0; + + // Same as above, but the stashed function object will attempt to dequeue + // num_elements items. If allow_small_batch is true, and the Queue is + // closed but at least 1 element is available, there is no blocking + // and between 1 and num_elements items are immediately returned. + // If the queue does not support the allow_small_batch flag will + // return an Unimplemented error. + virtual void TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, + CallbackWithTuple callback) = 0; + + // Signals that no more elements will be enqueued, and optionally + // cancels pending Enqueue(Many) operations. + // + // After calling this function, subsequent calls to Enqueue(Many) + // will fail. If `cancel_pending_enqueues` is true, all pending + // calls to Enqueue(Many) will fail as well. + // + // After calling this function, all current and subsequent calls to + // Dequeue(Many) will fail instead of blocking (though they may + // succeed if they can be satisfied by the elements in the queue at + // the time it was closed). + virtual void Close(OpKernelContext* ctx, bool cancel_pending_enqueues, + DoneCallback callback) = 0; + + // Returns true if a given queue is closed and false if it is open. + virtual bool is_closed() const = 0; + + // Assuming *this represents a shared queue, verify that it matches + // another instantiation indicated by node_def. + virtual Status MatchesNodeDef(const NodeDef& node_def) = 0; + + // Returns the number of elements in the queue. + virtual int32 size() = 0; + + virtual const DataTypeVector& component_dtypes() const = 0; + + string DebugString() override { + return strings::StrCat("A Queue of size: ", size()); + } + + protected: + virtual ~QueueInterface() {} +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_ diff --git a/reader_base.cc b/reader_base.cc new file mode 100644 index 0000000000000000000000000000000000000000..b8c771a0a1955b29f78478f60972b22d804351b2 --- /dev/null +++ b/reader_base.cc @@ -0,0 +1,266 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/reader_base.h" + +#include "tensorflow/core/framework/reader_base.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +// ReaderBase ------------------------------------------------------ + +ReaderBase::ReaderBase(const string& name) : name_(name) {} + +int64 ReaderBase::NumRecordsProduced() { + mutex_lock lock(mu_); + return num_records_produced_; +} + +int64 ReaderBase::NumWorkUnitsCompleted() { + mutex_lock lock(mu_); + return work_finished_; +} + +Status ReaderBase::Reset() { + mutex_lock lock(mu_); + return ResetLocked(); +} + +Status ReaderBase::ResetLocked() { + work_started_ = 0; + work_finished_ = 0; + num_records_produced_ = 0; + work_.clear(); + return Status::OK(); +} + +Status ReaderBase::SerializeState(string* state) { + mutex_lock lock(mu_); + return SerializeStateLocked(state); +} + +Status ReaderBase::SerializeStateLocked(string* state) { + return errors::Unimplemented("Reader SerializeState"); +} + +Status ReaderBase::RestoreState(const string& state) { + mutex_lock lock(mu_); + Status status = RestoreStateLocked(state); + if (!status.ok()) { + ResetLocked().IgnoreError(); + } + return status; +} + +Status ReaderBase::RestoreStateLocked(const string& state) { + return errors::Unimplemented("Reader RestoreState"); +} + +int64 ReaderBase::ReadUpTo(const int64 num_records, QueueInterface* queue, + std::vector* keys, + std::vector* values, + OpKernelContext* context) { + mutex_lock lock(mu_); + int64 records_produced_this_call = 0; + while (true) { + // Records produced by this iteration of the ReadUpToLocked call. + int64 num_records_produced = 0; + int64 remaining = num_records - records_produced_this_call; + if (remaining == 0) { + return records_produced_this_call; + } + if (!work_in_progress()) { + work_ = GetNextWorkLocked(queue, context); + if (!context->status().ok()) { + return records_produced_this_call; + } + Status status = OnWorkStartedLocked(); + if (status.ok()) { + work_started_++; + } else { + context->SetStatus(status); + return records_produced_this_call; + } + } + bool at_end = false; + + Status status = + ReadUpToLocked(remaining, keys, values, &num_records_produced, &at_end); + // This call so far. + records_produced_this_call += num_records_produced; + + // In total, over the lifetime of the ReaderBase. + num_records_produced_ += num_records_produced; + + if (!at_end && status.ok() && num_records_produced == 0) { + status = errors::Internal( + "ReadManyLocked() for ", name(), + " must set *at_end=true, *num_produced > 0 or return an error."); + context->SetStatus(status); + return records_produced_this_call; + } + if (status.ok() && at_end) { + status = OnWorkFinishedLocked(); + work_finished_ = work_started_; + if (records_produced_this_call > 0) { + return records_produced_this_call; + } + } + if (!status.ok()) { + context->SetStatus(status); + return records_produced_this_call; + } + } +} + +// Default implementation just reads one record at a time. 
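Before the default ReadUpToLocked() implementation that follows, it may help to see the single-record contract from a descendant's point of view. The class below is purely illustrative (its name, record count, and payloads are invented for this sketch, and it is not part of this change): it sets *produced when it emits a record and *at_end once the current work item is exhausted.

// Sketch only: a toy ReaderBase descendant producing a fixed number of
// synthetic records per work item.
class ToyReader : public ReaderBase {
 public:
  explicit ToyReader(const string& node_name) : ReaderBase(node_name) {}

  Status OnWorkStartedLocked() override {
    next_ = 0;  // Reset the per-work-item cursor.
    return Status::OK();
  }

  Status ReadLocked(string* key, string* value, bool* produced,
                    bool* at_end) override {
    if (next_ >= kRecordsPerWork) {
      *at_end = true;  // Nothing more for this work item.
      return Status::OK();
    }
    *key = KeyName(strings::StrCat(next_));
    *value = strings::StrCat("record-", next_);
    *produced = true;  // Exactly one record was produced by this call.
    ++next_;
    return Status::OK();
  }

 private:
  static constexpr int64 kRecordsPerWork = 3;
  int64 next_ = 0;
};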
+Status ReaderBase::ReadUpToLocked(int64 num_records, std::vector* keys, + std::vector* values, int64* num_read, + bool* at_end) { + bool produced = false; + string key; + string value; + Status status = ReadLocked(&key, &value, &produced, at_end); + if (produced) { + keys->emplace_back(key); + values->emplace_back(value); + *num_read = 1; + } else { + *num_read = 0; + } + return status; +} + +void ReaderBase::Read(QueueInterface* queue, string* key, string* value, + OpKernelContext* context) { + mutex_lock lock(mu_); + while (true) { + if (!work_in_progress()) { + work_ = GetNextWorkLocked(queue, context); + if (!context->status().ok()) { + return; + } + Status status = OnWorkStartedLocked(); + if (status.ok()) { + work_started_++; + } else { + context->SetStatus(status); + return; + } + } + + bool produced = false; + bool at_end = false; + Status status = ReadLocked(key, value, &produced, &at_end); + + if (!at_end && status.ok() && !produced) { + status = errors::Internal( + "ReadLocked() for ", name(), + " must set *at_end=true, *produced=true, or return an error."); + } + if (!status.ok() && produced) { + status = errors::Internal("ReadLocked() for ", name(), + " set *produced=true *and* returned an error: ", + status.ToString()); + } + if (status.ok() && at_end) { + status = OnWorkFinishedLocked(); + work_finished_ = work_started_; + } + if (!status.ok()) { + context->SetStatus(status); + return; + } + if (produced) { + ++num_records_produced_; + return; + } + } +} + +string ReaderBase::GetNextWorkLocked(QueueInterface* queue, + OpKernelContext* context) const { + string work; + Notification n; + queue->TryDequeue( + context, [this, context, &n, &work](const QueueInterface::Tuple& tuple) { + if (context->status().ok()) { + if (tuple.size() != 1) { + context->SetStatus( + errors::InvalidArgument("Expected single component queue")); + } else if (tuple[0].dtype() != DT_STRING) { + context->SetStatus(errors::InvalidArgument( + "Expected queue with single string component")); + } else if (tuple[0].NumElements() != 1) { + context->SetStatus(errors::InvalidArgument( + "Expected to dequeue a one-element string tensor")); + } else { + work = tuple[0].flat()(0); + } + } + n.Notify(); + }); + n.WaitForNotification(); + return work; +} + +void ReaderBase::SaveBaseState(ReaderBaseState* state) const { + state->Clear(); + state->set_work_started(work_started_); + state->set_work_finished(work_finished_); + state->set_num_records_produced(num_records_produced_); + state->set_current_work(work_); +} + +string ReaderBase::KeyName(const string& key) const { + return strings::StrCat(current_work(), ":", key); +} + +Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) { + work_started_ = state.work_started(); + work_finished_ = state.work_finished(); + num_records_produced_ = state.num_records_produced(); + work_ = state.current_work(); + if (work_started_ < 0 || work_finished_ < 0 || num_records_produced_ < 0) { +#ifdef __ANDROID__ + const string debug_string = ""; +#else + const string debug_string = state.DebugString(); +#endif + return errors::InvalidArgument( + "Unexpected negative value when restoring in ", name(), ": ", + debug_string); + } + if (work_started_ > work_finished_) { +#ifdef __ANDROID__ + const string debug_string = ""; +#else + const string debug_string = state.DebugString(); +#endif + return errors::InvalidArgument( + "Inconsistent work started vs. 
finished when restoring in ", name(), + ": ", debug_string); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/reader_base.h b/reader_base.h new file mode 100644 index 0000000000000000000000000000000000000000..cb44be4dee8d0b39e0c0073221cb7bb70388a508 --- /dev/null +++ b/reader_base.h @@ -0,0 +1,138 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_READER_BASE_H_ +#define TENSORFLOW_FRAMEWORK_READER_BASE_H_ + +#include +#include +#include "tensorflow/core/framework/queue_interface.h" +#include "tensorflow/core/framework/reader_interface.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +class ReaderBaseState; + +// Default implementation of ReaderInterface. +class ReaderBase : public ReaderInterface { + public: + // name: For use in error messages, should mention both the name of + // the op and the node. + explicit ReaderBase(const string& name); + + // Note that methods with names ending in "Locked" are called while + // the ReaderBase's mutex is held. + + // Implement this function in descendants ----------------------------------- + + // Produce the next key/value pair from the current work item. + // This is called "Locked" since it is executed under a mutex + // that serializes all Reader calls. + // Usage: + // a) If a record was successfully produced, set *produced = true, + // and fill in *key and *value. + // b) If no more records will be produced for this work item, set + // *at_end = true. + // c) If a record was produced, but no more will be produced, you + // may either do both (a) and (b), or do (a) in this call and do (b) in + // the next call to ReadLocked(). + // d) If there was an error producing (e.g. an error reading the file, + // data corruption), return a non-OK() status. ReadLocked may be + // called again if the user reruns this part of the graph. + virtual Status ReadLocked(string* key, string* value, bool* produced, + bool* at_end) = 0; + + // Descendants may optionally implement these ------------------------------- + + // Produce up to num_records next key/value pairs from the current + // work item, in the same manner of ReadLocked. + virtual Status ReadUpToLocked(int64 num_records, std::vector* keys, + std::vector* values, int64* num_read, + bool* at_end); + + // Called when work starts / finishes. + virtual Status OnWorkStartedLocked() { return Status::OK(); } + virtual Status OnWorkFinishedLocked() { return Status::OK(); } + + // Called to reset the Reader to a newly constructed state. + virtual Status ResetLocked(); + + // Default implementation generates an Unimplemented error. + // See the protected helper methods below. 
+ virtual Status SerializeStateLocked(string* state); + virtual Status RestoreStateLocked(const string& state); + + // Accessors ---------------------------------------------------------------- + + // Always true during a call to ReadLocked(). + bool work_in_progress() const { return work_finished_ < work_started_; } + + // Returns the name of the current work item (valid if + // work_in_progress() returns true). May change between calls to + // ReadLocked(). + const string& current_work() const { return work_; } + + // What was passed to the constructor. + const string& name() const { return name_; } + + // Produce the key name (from current_work and the actual key). + string KeyName(const string& key) const; + + protected: + // For descendants wishing to implement serialize & restore state. + + // Writes ReaderBase state to *state. + void SaveBaseState(ReaderBaseState* state) const; + + // Restores ReaderBase state from state. Assumes state was filled + // using SaveBaseState() above. + Status RestoreBaseState(const ReaderBaseState& state); + + private: + // For descendants that wish to obtain the next work item in a different way. + // For implementing Read(). Dequeues the next work item from + // *queue, and if successful returns "work" (a string). May block. + virtual string GetNextWorkLocked(QueueInterface* queue, + OpKernelContext* context) const; + + // Implementations of ReaderInterface methods. These ensure thread-safety + // and call the methods above to do the work. + void Read(QueueInterface* queue, string* key, string* value, + OpKernelContext* context) override; + + // Produces up to num_records. + // In this implementation all the records come from the same work unit. + int64 ReadUpTo(const int64 num_records, QueueInterface* queue, + std::vector* keys, std::vector* value, + OpKernelContext* context) override; + + Status Reset() override; + int64 NumRecordsProduced() override; + int64 NumWorkUnitsCompleted() override; + Status SerializeState(string* state) override; + Status RestoreState(const string& state) override; + + mutable mutex mu_; + const string name_; + int64 work_started_ = 0; + int64 work_finished_ = 0; + int64 num_records_produced_ = 0; + string work_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_READER_BASE_H_ diff --git a/reader_base.proto b/reader_base.proto new file mode 100644 index 0000000000000000000000000000000000000000..1b8b965ee105fbdc3c399e3875e80048d27682db --- /dev/null +++ b/reader_base.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ReaderBaseProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// For serializing and restoring the state of ReaderBase, see +// reader_base.h for details. +message ReaderBaseState { + int64 work_started = 1; + int64 work_finished = 2; + int64 num_records_produced = 3; + bytes current_work = 4; +}; diff --git a/reader_interface.h b/reader_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..dac6056b5abf3d03cf56088db8debccce99adc14 --- /dev/null +++ b/reader_interface.h @@ -0,0 +1,87 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_ +#define TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_ + +#include +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class QueueInterface; +class ReaderInterface; + +// Readers are the mechanism for reading records from files in +// TensorFlow graphs. Each supported file format has a corresponding +// ReaderInterface descendant and a corresponding Op & OpKernel +// (implemented using ReaderOpKernel from reader_op_kernel.h). +// +// To use a Reader, you first encode "work" (some string, typically a +// filename) in the Reader's "work queue". It then processes the +// "work" (reading records from the file), to produce key/value +// strings. The methods of this class are called by ReaderFoo ops, +// so see ../ops/io_ops.cc for detailed descriptions. +// +// All descendants of this class must be thread-safe. +class ReaderInterface : public ResourceBase { + public: + // Read a single record into *key / *value. May get more work from + // *queue if the current work is complete. Sets the status on + // *context with an OutOfRange Status if the current work is + // complete and the queue is done (closed and empty). + // This method may block. + virtual void Read(QueueInterface* queue, string* key, string* value, + OpKernelContext* context) = 0; + + // Read up to num_records records into keys / values. May get more work from + // *queue if the current work is complete. Sets the status on + // *context with an OutOfRange Status if the current work is + // complete and the queue is done (closed and empty). + // This method may block. + // The std::vector keys/value pointers are assumed to point to empty + // structures (that have most likely been reserve(num_records)). + // Returns how many records were actually read. + virtual int64 ReadUpTo(const int64 num_records, QueueInterface* queue, + std::vector* keys, std::vector* value, + OpKernelContext* context) = 0; + + // Restore this reader to its newly-constructed state. + virtual Status Reset() = 0; + + // Accessors + virtual int64 NumRecordsProduced() = 0; + virtual int64 NumWorkUnitsCompleted() = 0; + + // -- Serialization/Restoration support -- + // Not all readers will support saving and restoring state. + virtual Status SerializeState(string* state) = 0; + // Note: Must Reset on error. + virtual Status RestoreState(const string& state) = 0; + + string DebugString() override { return "a reader"; } + + protected: + virtual ~ReaderInterface() {} +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_ diff --git a/reader_op_kernel.h b/reader_op_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ffd6a1a18486cc0b015c75775b40c3a1118109c0 --- /dev/null +++ b/reader_op_kernel.h @@ -0,0 +1,88 @@ +/* Copyright 2015 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_ +#define TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/reader_interface.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// NOTE: This is now a very thin layer over ResourceOpKernel. +// TODO(sjhwang): Remove dependencies to this class, then delete this. + +// Implementation for ops providing a Reader. +class ReaderOpKernel : public ResourceOpKernel { + public: + using ResourceOpKernel::ResourceOpKernel; + + // Must be called by descendants before the first call to Compute() + // (typically called during construction). factory must return a + // ReaderInterface descendant allocated with new that ReaderOpKernel + // will take ownership of. + void SetReaderFactory(std::function factory) + LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + DCHECK(resource_ == nullptr); + factory_ = factory; + } + + void Compute(OpKernelContext* context) override { + if (!IsCancellable()) { + ResourceOpKernel::Compute(context); + } else { + // Install cancellation + CancellationManager* cm = context->cancellation_manager(); + CancellationToken token = cm->get_cancellation_token(); + bool already_cancelled = + !cm->RegisterCallback(token, [this]() { this->Cancel(); }); + + if (!already_cancelled) { + ResourceOpKernel::Compute(context); + } else { + context->SetStatus(errors::Cancelled("read operation was cancelled")); + } + } + } + + private: + virtual bool IsCancellable() const { return false; } + virtual void Cancel() {} + + Status CreateResource(ReaderInterface** reader) + EXCLUSIVE_LOCKS_REQUIRED(mu_) override { + *reader = factory_(); + if (*reader == nullptr) { + return errors::ResourceExhausted("Failed to allocate reader"); + } + std::function temp = nullptr; + factory_.swap(temp); + return Status::OK(); + } + + std::function factory_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_ diff --git a/register_types.h b/register_types.h new file mode 100644 index 0000000000000000000000000000000000000000..320531f03ad1dc19d5bcc388a393026f4ba6ca11 --- /dev/null +++ b/register_types.h @@ -0,0 +1,223 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_ +// This file is used by cuda code and must remain compilable by nvcc. + +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/platform/types.h" + +// Two sets of macros: +// - TF_CALL_float, TF_CALL_double, etc. which call the given macro with +// the type name as the only parameter - except on platforms for which +// the type should not be included. +// - Macros to apply another macro to lists of supported types. These also call +// into TF_CALL_float, TF_CALL_double, etc. so they filter by target platform +// as well. +// If you change the lists of types, please also update the list in types.cc. +// +// See example uses of these macros in core/ops. +// +// +// Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple +// times by passing each invocation a data type supported by TensorFlow. +// +// The different variations pass different subsets of the types. +// TF_CALL_ALL_TYPES(m) applied "m" to all types supported by TensorFlow. +// The set of types depends on the compilation platform. +//. +// This can be used to register a different template instantiation of +// an OpKernel for different signatures, e.g.: +/* + #define REGISTER_PARTITION(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Partition").Device(DEVICE_CPU).TypeConstraint("T"), \ + PartitionOp); + TF_CALL_ALL_TYPES(REGISTER_PARTITION) + #undef REGISTER_PARTITION +*/ + +#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) || defined(NVIDIA_TEGRA) + +// All types are supported, so all macros are invoked. +// +// Note: macros are defined in same order as types in types.proto, for +// readability. +#define TF_CALL_float(m) m(float) +#define TF_CALL_double(m) m(double) +#define TF_CALL_int32(m) m(::tensorflow::int32) +#define TF_CALL_uint32(m) m(::tensorflow::uint32) +#define TF_CALL_uint8(m) m(::tensorflow::uint8) +#define TF_CALL_int16(m) m(::tensorflow::int16) + +#define TF_CALL_int8(m) m(::tensorflow::int8) +#define TF_CALL_string(m) m(string) +#define TF_CALL_resource(m) m(::tensorflow::ResourceHandle) +#define TF_CALL_variant(m) m(::tensorflow::Variant) +#define TF_CALL_complex64(m) m(::tensorflow::complex64) +#define TF_CALL_int64(m) m(::tensorflow::int64) +#define TF_CALL_uint64(m) m(::tensorflow::uint64) +#define TF_CALL_bool(m) m(bool) + +#define TF_CALL_qint8(m) m(::tensorflow::qint8) +#define TF_CALL_quint8(m) m(::tensorflow::quint8) +#define TF_CALL_qint32(m) m(::tensorflow::qint32) +#define TF_CALL_bfloat16(m) m(::tensorflow::bfloat16) +#define TF_CALL_qint16(m) m(::tensorflow::qint16) + +#define TF_CALL_quint16(m) m(::tensorflow::quint16) +#define TF_CALL_uint16(m) m(::tensorflow::uint16) +#define TF_CALL_complex128(m) m(::tensorflow::complex128) +#define TF_CALL_half(m) m(Eigen::half) + +#elif defined(__ANDROID_TYPES_FULL__) + +// Only string, half, float, int32, int64, bool, and quantized types +// supported. 
+#define TF_CALL_float(m) m(float) +#define TF_CALL_double(m) +#define TF_CALL_int32(m) m(::tensorflow::int32) +#define TF_CALL_uint32(m) +#define TF_CALL_uint8(m) +#define TF_CALL_int16(m) + +#define TF_CALL_int8(m) +#define TF_CALL_string(m) m(string) +#define TF_CALL_resource(m) +#define TF_CALL_variant(m) +#define TF_CALL_complex64(m) +#define TF_CALL_int64(m) m(::tensorflow::int64) +#define TF_CALL_uint64(m) +#define TF_CALL_bool(m) m(bool) + +#define TF_CALL_qint8(m) m(::tensorflow::qint8) +#define TF_CALL_quint8(m) m(::tensorflow::quint8) +#define TF_CALL_qint32(m) m(::tensorflow::qint32) +#define TF_CALL_bfloat16(m) +#define TF_CALL_qint16(m) m(::tensorflow::qint16) + +#define TF_CALL_quint16(m) m(::tensorflow::quint16) +#define TF_CALL_uint16(m) +#define TF_CALL_complex128(m) +#define TF_CALL_half(m) m(Eigen::half) + +#else // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__) + +// Only float, int32, and bool are supported. +#define TF_CALL_float(m) m(float) +#define TF_CALL_double(m) +#define TF_CALL_int32(m) m(::tensorflow::int32) +#define TF_CALL_uint32(m) +#define TF_CALL_uint8(m) +#define TF_CALL_int16(m) + +#define TF_CALL_int8(m) +#define TF_CALL_string(m) +#define TF_CALL_resource(m) +#define TF_CALL_variant(m) +#define TF_CALL_complex64(m) +#define TF_CALL_int64(m) +#define TF_CALL_uint64(m) +#define TF_CALL_bool(m) m(bool) + +#define TF_CALL_qint8(m) +#define TF_CALL_quint8(m) +#define TF_CALL_qint32(m) +#define TF_CALL_bfloat16(m) +#define TF_CALL_qint16(m) + +#define TF_CALL_quint16(m) +#define TF_CALL_uint16(m) +#define TF_CALL_complex128(m) +#define TF_CALL_half(m) + +#endif // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines + +// Defines for sets of types. + +#define TF_CALL_INTEGRAL_TYPES(m) \ + TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \ + TF_CALL_uint8(m) TF_CALL_int8(m) + +#define TF_CALL_REAL_NUMBER_TYPES(m) \ + TF_CALL_INTEGRAL_TYPES(m) \ + TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) + +#define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \ + TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) + +#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ + TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \ + TF_CALL_int64(m) TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) \ + TF_CALL_int8(m) + +// Call "m" for all number types, including complex64 and complex128. +#define TF_CALL_NUMBER_TYPES(m) \ + TF_CALL_REAL_NUMBER_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m) + +#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \ + TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ + TF_CALL_complex64(m) TF_CALL_complex128(m) + +#define TF_CALL_POD_TYPES(m) TF_CALL_NUMBER_TYPES(m) TF_CALL_bool(m) + +// Call "m" on all types. +#define TF_CALL_ALL_TYPES(m) \ + TF_CALL_POD_TYPES(m) TF_CALL_string(m) TF_CALL_resource(m) + +// Call "m" on POD and string types. +#define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_string(m) + +// Call "m" on all number types supported on GPU. +#define TF_CALL_GPU_NUMBER_TYPES(m) \ + TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) + +// Call "m" on all types supported on GPU. +#define TF_CALL_GPU_ALL_TYPES(m) \ + TF_CALL_GPU_NUMBER_TYPES(m) \ + TF_CALL_bool(m) TF_CALL_complex64(m) TF_CALL_complex128(m) + +#define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) TF_CALL_float(m) TF_CALL_double(m) + +// Call "m" on all quantized types. 
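For illustration, these set macros simply chain the per-type TF_CALL_* macros defined earlier, so a registration block like the hypothetical sketch below expands to one REGISTER_KERNEL_BUILDER per supported type, and to nothing for types that are compiled out on mobile. MyOp, REGISTER_GPU_KERNEL, and the GPUDevice alias are placeholder names assumed for this sketch, not part of this change.

// Sketch only: register a hypothetical "MyOp" kernel for half, float and double.
#define REGISTER_GPU_KERNEL(T)                                  \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("MyOp").Device(DEVICE_GPU).TypeConstraint<T>("T"),   \
      MyOp<GPUDevice, T>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL)
#undef REGISTER_GPU_KERNEL
// On the non-mobile build above this expands to
//   REGISTER_GPU_KERNEL(Eigen::half) REGISTER_GPU_KERNEL(float) REGISTER_GPU_KERNEL(double)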
+// TODO(cwhipkey): include TF_CALL_qint16(m) TF_CALL_quint16(m)
+#define TF_CALL_QUANTIZED_TYPES(m) \
+  TF_CALL_qint8(m) TF_CALL_quint8(m) TF_CALL_qint32(m)
+
+// Types used for save and restore ops.
+#define TF_CALL_SAVE_RESTORE_TYPES(m)                                      \
+  TF_CALL_INTEGRAL_TYPES(m)                                                \
+  TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_complex64(m) \
+      TF_CALL_complex128(m) TF_CALL_bool(m) TF_CALL_string(m)             \
+          TF_CALL_QUANTIZED_TYPES(m)
+
+#ifdef TENSORFLOW_SYCL_NO_DOUBLE
+#define TF_CALL_SYCL_double(m)
+#else  // TENSORFLOW_SYCL_NO_DOUBLE
+#define TF_CALL_SYCL_double(m) TF_CALL_double(m)
+#endif  // TENSORFLOW_SYCL_NO_DOUBLE
+
+#ifdef __ANDROID_TYPES_SLIM__
+#define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m)
+#else  // __ANDROID_TYPES_SLIM__
+#define TF_CALL_SYCL_NUMBER_TYPES(m) \
+  TF_CALL_float(m)                   \
+  TF_CALL_SYCL_double(m)
+#endif  // __ANDROID_TYPES_SLIM__
+
+#endif  // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
diff --git a/register_types_traits.h b/register_types_traits.h
new file mode 100644
index 0000000000000000000000000000000000000000..c1fe5517c6986838a07f67c0f2fa5474f89ffa33
--- /dev/null
+++ b/register_types_traits.h
@@ -0,0 +1,105 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
+#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
+// This file is used by cuda code and must remain compilable by nvcc.
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif  // TENSORFLOW_USE_SYCL
+
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Remap POD types by size to equivalent proxy types. This works
+// since all we are doing is copying data around.
+struct UnusableProxyType;
+template <typename Device, int dsize>
+struct proxy_type_pod {
+  typedef UnusableProxyType type;
+};
+template <>
+struct proxy_type_pod<CPUDevice, 16> {
+  typedef ::tensorflow::complex128 type;
+};
+template <>
+struct proxy_type_pod<CPUDevice, 8> {
+  typedef ::tensorflow::int64 type;
+};
+template <>
+struct proxy_type_pod<CPUDevice, 4> {
+  typedef ::tensorflow::int32 type;
+};
+template <>
+struct proxy_type_pod<CPUDevice, 2> {
+  typedef ::tensorflow::int16 type;
+};
+template <>
+struct proxy_type_pod<CPUDevice, 1> {
+  typedef ::tensorflow::int8 type;
+};
+template <>
+struct proxy_type_pod<GPUDevice, 8> {
+  typedef double type;
+};
+template <>
+struct proxy_type_pod<GPUDevice, 4> {
+  typedef float type;
+};
+template <>
+struct proxy_type_pod<GPUDevice, 2> {
+  typedef Eigen::half type;
+};
+
+#ifdef TENSORFLOW_USE_SYCL
+template <>
+struct proxy_type_pod<SYCLDevice, 8> {
+  typedef double type;
+};
+template <>
+struct proxy_type_pod<SYCLDevice, 4> {
+  typedef float type;
+};
+#endif  // TENSORFLOW_USE_SYCL
+
+/// If POD we use proxy_type_pod, otherwise this maps to identity.
+template +struct proxy_type { + typedef typename std::conditional< + std::is_arithmetic::value, + typename proxy_type_pod::type, T>::type type; + static_assert(sizeof(type) == sizeof(T), "proxy_type_pod is not valid"); +}; + +/// The active proxy types +#define TF_CALL_CPU_PROXY_TYPES(m) \ + TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \ + TF_CALL_int8(m) TF_CALL_complex128(m) +#define TF_CALL_GPU_PROXY_TYPES(m) \ + TF_CALL_double(m) TF_CALL_float(m) TF_CALL_half(m) TF_CALL_int32(m) +#ifdef TENSORFLOW_USE_SYCL +#define TF_CALL_SYCL_PROXY_TYPES(m) \ + TF_CALL_double(m) TF_CALL_float(m) TF_CALL_int32(m) +#endif // TENSORFLOW_USE_SYCL +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_ diff --git a/remote_fused_graph_execute_info.proto b/remote_fused_graph_execute_info.proto new file mode 100644 index 0000000000000000000000000000000000000000..389a08ac2f3f23080bfb80555da2777eaba24404 --- /dev/null +++ b/remote_fused_graph_execute_info.proto @@ -0,0 +1,55 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "RemoteFusedGraphExecuteInfoProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message RemoteFusedGraphExecuteInfo { + enum NodeType { + UNUSED = 0; + GRAPH_INPUT = 1; + GRAPH_OUTPUT = 2; + FUSED_NODE = 3; + BORDER_INPUT = 4; + BORDER_OUTPUT = 5; + } + + message TensorShapeTypeProto { + DataType dtype = 1; + TensorShapeProto shape = 2; + } + + // Definition of remote graph + GraphDef remote_graph = 1; + + // Remote fused graph input node name + repeated string graph_input_node_name = 2; + + // Remote fused graph output node name + repeated string graph_output_node_name = 3; + + // Executor's name + string executor_name = 4; + + // Optional: Parameters given to the executor + bytes serialized_executor_parameters = 5; + + // Optional: Default graph input tensor shape used to allocate memory + // before executing op + repeated TensorShapeTypeProto default_graph_input_tensor_shape = 6; + + // Optional: Default graph input tensor shape used to allocate memory + // before executing op + // TODO(satok): Remote output tensor shape once shape information is stored + // in NodeDef + repeated TensorShapeTypeProto default_graph_output_tensor_shape = 7; +}; diff --git a/rendezvous.cc b/rendezvous.cc new file mode 100644 index 0000000000000000000000000000000000000000..90756a4f2fceb366f2ec0eb991adc31dcf884d99 --- /dev/null +++ b/rendezvous.cc @@ -0,0 +1,307 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/rendezvous.h" + +#include +#include +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +Rendezvous::ParsedKey& Rendezvous::ParsedKey::operator=(const ParsedKey& b) { + const char* b_base = b.buf_.data(); + buf_ = b.buf_; + src_device = StringPiece(buf_.data() + (b.src_device.data() - b_base), + b.src_device.size()); + src = b.src; + src_incarnation = b.src_incarnation; + dst_device = StringPiece(buf_.data() + (b.dst_device.data() - b_base), + b.dst_device.size()); + dst = b.dst; + edge_name = StringPiece(buf_.data() + (b.edge_name.data() - b_base), + b.edge_name.size()); + return *this; +} + +/* static */ +string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation, + const string& dst_device, const string& name, + const FrameAndIter& frame_iter) { + // NOTE: ';' is not used in the device name's job name. + // + // We include both sender and receiver in the key to facilitate + // debugging. For correctness, we only need to encode the receiver. + // + // "src_incarnation" is used to distinguish a worker when it + // restarts. + char buf[strings::kFastToBufferSize]; + return strings::StrCat( + src_device, ";", strings::Uint64ToHexString(src_incarnation, buf), ";", + dst_device, ";", name, ";", frame_iter.frame_id, ":", frame_iter.iter_id); +} + +// Return the prefix of "*s" up to the next occurrence of "delim", or +// the whole remaining string if "delim" is not found. "*s" is advanced +// past the string returned plus the delimiter (if found). +static StringPiece ConsumeNextPart(StringPiece* s, char delim) { + for (size_t offset = 0; offset < s->size(); offset++) { + if ((*s)[offset] == delim) { + StringPiece result(s->data(), offset); + s->remove_prefix(offset + 1); // +1: remove delim, as well + return result; + } + } + // No delimiter found: return rest of string + StringPiece result(s->data(), s->size()); + s->remove_prefix(s->size()); + return result; +} + +/* static */ +Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { + if (key.data() == out->buf_.data()) { + // Caller used our buf_ string directly, so we don't need to copy. (The + // SendOp and RecvOp implementations do this, for example). + DCHECK_EQ(key.size(), out->buf_.size()); + } else { + // Make a copy that our StringPieces can point at a copy that will persist + // for the lifetime of the ParsedKey object. 
+ out->buf_.assign(key.data(), key.size()); + } + StringPiece s(out->buf_); + StringPiece parts[5]; + for (int i = 0; i < 5; i++) { + parts[i] = ConsumeNextPart(&s, ';'); + } + if (s.empty() && // Consumed the whole string + !parts[4].empty() && // Exactly five parts + DeviceNameUtils::ParseFullName(parts[0], &out->src) && + strings::HexStringToUint64(parts[1], &out->src_incarnation) && + DeviceNameUtils::ParseFullName(parts[2], &out->dst) && + !parts[3].empty()) { + out->src_device = StringPiece(parts[0].data(), parts[0].size()); + out->dst_device = StringPiece(parts[2].data(), parts[2].size()); + out->edge_name = StringPiece(parts[3].data(), parts[3].size()); + return Status::OK(); + } + return errors::InvalidArgument("Invalid rendezvous key: ", key); +} + +Rendezvous::~Rendezvous() {} + +Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args, + Tensor* val, bool* is_dead, int64 timeout_ms) { + Status ret; + Notification n; + RecvAsync(key, recv_args, + [&ret, &n, val, is_dead](const Status& s, const Args& send_args, + const Args& recv_args, const Tensor& v, + const bool dead) { + ret = s; + *val = v; + *is_dead = dead; + n.Notify(); + }); + if (timeout_ms > 0) { + int64 timeout_us = timeout_ms * 1000; + bool notified = WaitForNotificationWithTimeout(&n, timeout_us); + if (!notified) { + return Status(error::DEADLINE_EXCEEDED, + "Timed out waiting for notification"); + } + } else { + n.WaitForNotification(); + } + return ret; +} + +Status Rendezvous::Recv(const ParsedKey& key, const Args& args, Tensor* val, + bool* is_dead) { + const int64 no_timeout = 0; + return Recv(key, args, val, is_dead, no_timeout); +} + +class LocalRendezvousImpl : public Rendezvous { + public: + explicit LocalRendezvousImpl() {} + + Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val, + const bool is_dead) override { + uint64 key_hash = KeyHash(key.FullKey()); + VLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey(); + + mu_.lock(); + if (!status_.ok()) { + // Rendezvous has been aborted. + Status s = status_; + mu_.unlock(); + return s; + } + + ItemQueue* queue = &table_[key_hash]; + if (queue->empty() || queue->front()->IsSendValue()) { + // There is no waiter for this message. Append the message + // into the queue. The waiter will pick it up when arrives. + // Only send-related fields need to be filled. + Item* item = new Item; + item->value = val; + item->is_dead = is_dead; + item->send_args = send_args; + if (item->send_args.device_context) { + item->send_args.device_context->Ref(); + } + queue->push_back(item); + mu_.unlock(); + return Status::OK(); + } + + // There is an earliest waiter to consume this message. + Item* item = queue->front(); + queue->pop_front(); + mu_.unlock(); + + // Notify the waiter by invoking its done closure, outside the + // lock. + DCHECK(!item->IsSendValue()); + item->waiter(Status::OK(), send_args, item->recv_args, val, is_dead); + delete item; + return Status::OK(); + } + + void RecvAsync(const ParsedKey& key, const Args& recv_args, + DoneCallback done) override { + uint64 key_hash = KeyHash(key.FullKey()); + VLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey(); + + mu_.lock(); + if (!status_.ok()) { + // Rendezvous has been aborted. + Status s = status_; + mu_.unlock(); + done(s, Args(), recv_args, Tensor(), false); + return; + } + + ItemQueue* queue = &table_[key_hash]; + if (queue->empty() || !queue->front()->IsSendValue()) { + // There is no message to pick up. 
+ // Only recv-related fields need to be filled. + Item* item = new Item; + item->waiter = std::move(done); + item->recv_args = recv_args; + if (item->recv_args.device_context) { + item->recv_args.device_context->Ref(); + } + queue->push_back(item); + mu_.unlock(); + return; + } + + // A message has already arrived and is queued in the table under + // this key. Consumes the message and invokes the done closure. + Item* item = queue->front(); + queue->pop_front(); + mu_.unlock(); + + // Invokes the done() by invoking its done closure, outside scope + // of the table lock. + DCHECK(item->IsSendValue()); + done(Status::OK(), item->send_args, recv_args, item->value, item->is_dead); + delete item; + } + + void StartAbort(const Status& status) override { + CHECK(!status.ok()); + Table table; + { + mutex_lock l(mu_); + status_.Update(status); + table_.swap(table); + } + for (auto& p : table) { + for (Item* item : p.second) { + if (!item->IsSendValue()) { + item->waiter(status, Args(), Args(), Tensor(), false); + } + delete item; + } + } + } + + private: + typedef LocalRendezvousImpl ME; + + struct Item { + DoneCallback waiter = nullptr; + Tensor value; + bool is_dead = false; + Args send_args; + Args recv_args; + + ~Item() { + if (send_args.device_context) { + send_args.device_context->Unref(); + } + if (recv_args.device_context) { + recv_args.device_context->Unref(); + } + } + + // Returns true iff this item represents a value being sent. + bool IsSendValue() const { return this->waiter == nullptr; } + }; + + // We key the hash table by KeyHash of the Rendezvous::CreateKey string + static uint64 KeyHash(const StringPiece& k) { + return Hash64(k.data(), k.size()); + } + + // By invariant, the item queue under each key is of the form + // [item.IsSendValue()]* meaning each item is a sent message. + // or + // [!item.IsSendValue()]* meaning each item is a waiter. + // + // TODO(zhifengc): consider a better queue impl than std::deque. + typedef std::deque ItemQueue; + typedef gtl::FlatMap Table; + + // TODO(zhifengc): shard table_. + mutex mu_; + Table table_ GUARDED_BY(mu_); + Status status_ GUARDED_BY(mu_); + + ~LocalRendezvousImpl() override { + StartAbort(errors::Cancelled("LocalRendezvousImpl deleted")); + } + + TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl); +}; + +Rendezvous* NewLocalRendezvous() { return new LocalRendezvousImpl(); } + +} // end namespace tensorflow diff --git a/rendezvous.h b/rendezvous.h new file mode 100644 index 0000000000000000000000000000000000000000..01e43e44e3f71503015557e3f6b30457d89a54b8 --- /dev/null +++ b/rendezvous.h @@ -0,0 +1,134 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
+#define TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+// A Rendezvous is an abstraction for passing tensors from producers
+// to consumers. A rendezvous is a table of channels. Each channel is
+// keyed by a rendezvous key. The key encodes a pair of <producer,
+// consumer>, where the producer and the consumer are tensorflow
+// devices.
+//
+// The producer calls the Send() method to send one tensor over one
+// named channel. The consumer calls the Recv() method to receive one
+// tensor from a named channel. A sequence of tensors can be passed
+// from the producer to the consumer. The consumer receives them in
+// the order in which the producer sends them.
+//
+// A consumer may safely request the tensor before or after it has
+// been produced. A consumer has the choice of making a blocking call
+// or providing a callback: in either case, the consumer receives the
+// Tensor as soon as it is available. A producer never blocks.
+class Rendezvous : public core::RefCounted {
+ public:
+  struct Args {
+    DeviceContext* device_context = nullptr;
+    AllocatorAttributes alloc_attrs;
+  };
+
+  // Constructs a rendezvous key for the tensor of "name" sent from
+  // "src_device" to "dst_device". The tensor is generated in the frame
+  // and iteration specified by "frame_iter".
+  static string CreateKey(const string& src_device, uint64 src_incarnation,
+                          const string& dst_device, const string& name,
+                          const FrameAndIter& frame_iter);
+
+  // Parses the key constructed by CreateKey and parses the src/dst device
+  // names into structures, respectively.
+  struct ParsedKey {
+    StringPiece src_device;
+    DeviceNameUtils::ParsedName src;
+    uint64 src_incarnation = 0;
+    StringPiece dst_device;
+    DeviceNameUtils::ParsedName dst;
+    StringPiece edge_name;
+
+    ParsedKey() {}
+    ParsedKey(const ParsedKey& b) { *this = b; }
+
+    ParsedKey& operator=(const ParsedKey& b);
+    StringPiece FullKey() const { return buf_; }
+
+   private:
+    friend class Rendezvous;
+    friend class SendOp;
+    friend class RecvOp;
+    string buf_;
+  };
+  static Status ParseKey(StringPiece key, ParsedKey* out);
+
+  // The caller is a tensor producer and it sends a message (a tensor
+  // "val" and a bool "is_dead") under the given "key".
+  //
+  // {val, is_dead} is bundled as a message sent and received.
+  // Typically, is_dead is set by some control flow nodes
+  // (e.g., a not-taken branch). args is passed by Send to the
+  // Recv function to communicate any information that the Recv
+  // function might need. This is typically only necessary for
+  // Send/Recv on the same worker.
+  //
+  // Send() never blocks.
+  virtual Status Send(const ParsedKey& key, const Args& args,
+                      const Tensor& val, const bool is_dead) = 0;
+
+  // Callback provided by a tensor consumer waiting on the rendezvous.
+  // It will be invoked when the tensor is available, or when a non-OK
+  // status arises in the production of that tensor. It also gets
+  // two Rendezvous::Args, one provided by the sender, the other by the
+  // receiver, which may be needed when a non-CPU device is in use
+  // by either side.
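As a usage illustration (not part of this header), a consumer that does not want to block can pass a lambda matching the DoneCallback typedef declared just below; `rendez`, `parsed_key`, and the Notification `done` are assumed to exist in the caller, mirroring the pattern used in rendezvous_test.cc.

  // Sketch only: asynchronous receive with a completion callback.
  rendez->RecvAsync(
      parsed_key, Rendezvous::Args(),
      [&done](const Status& s, const Rendezvous::Args& send_args,
              const Rendezvous::Args& recv_args, const Tensor& val,
              bool is_dead) {
        if (s.ok() && !is_dead) {
          // Consume `val`; send_args/recv_args carry any device contexts.
        }
        done.Notify();
      });
  done.WaitForNotification();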
+ typedef std::function + DoneCallback; + + virtual void RecvAsync(const ParsedKey& key, const Args& args, + DoneCallback done) = 0; + + // Synchronous wrapper for RecvAsync. + Status Recv(const ParsedKey& key, const Args& args, Tensor* val, + bool* is_dead, int64 timeout_ms); + Status Recv(const ParsedKey& key, const Args& args, Tensor* val, + bool* is_dead); + + // Aborts all pending and future Send/Recv with the given "status". + // + // StartAbort() does not wait for ongoing calls to finish. + // REQUIRES: !status.ok() + virtual void StartAbort(const Status& status) = 0; + + protected: + ~Rendezvous() override; +}; + +// Returns a Rendezvous instance that is limited to use only by +// producers and consumers in the local process. The caller assumes +// ownership of one Ref() on the returned object. +Rendezvous* NewLocalRendezvous(); + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_ diff --git a/rendezvous_test.cc b/rendezvous_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..32b8ad784d5228a40a073d166f33972def380280 --- /dev/null +++ b/rendezvous_test.cc @@ -0,0 +1,364 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/rendezvous.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +TEST(RendezvousTest, Key) { + const string key = Rendezvous::CreateKey( + "/job:mnist/replica:1/task:2/CPU:0", 7890, + "/job:mnist/replica:1/task:2/device:GPU:0", "var0", FrameAndIter(0, 0)); + EXPECT_EQ(key, + "/job:mnist/replica:1/task:2/CPU:0;" + "0000000000001ed2;" // 7890 = 0x1ed2 + "/job:mnist/replica:1/task:2/device:GPU:0;" + "var0;" + "0:0"); + Rendezvous::ParsedKey parsed; + TF_EXPECT_OK(Rendezvous::ParseKey(key, &parsed)); + EXPECT_EQ(parsed.src_device, "/job:mnist/replica:1/task:2/CPU:0"); + EXPECT_EQ(parsed.src_incarnation, 7890); + EXPECT_EQ(parsed.src.type, "CPU"); + EXPECT_EQ(parsed.dst_device, "/job:mnist/replica:1/task:2/device:GPU:0"); + EXPECT_EQ(parsed.dst.type, "GPU"); + + EXPECT_FALSE(Rendezvous::ParseKey("foo;bar;baz", &parsed).ok()); + 
EXPECT_FALSE(Rendezvous::ParseKey("/job:mnist/replica:1/task:2/CPU:0;"
+                                    "/job:mnist/replica:1/task:2/device:GPU:0;",
+                                    &parsed)
+                   .ok());
+  EXPECT_FALSE(
+      Rendezvous::ParseKey(strings::StrCat(key, ";", key), &parsed).ok());
+}
+
+class LocalRendezvousTest : public ::testing::Test {
+ public:
+  LocalRendezvousTest() : threads_(Env::Default(), "test", 16) {
+    rendez_ = NewLocalRendezvous();
+  }
+
+  ~LocalRendezvousTest() override {
+    rendez_->Unref();
+  }
+
+  void SchedClosure(std::function<void()> fn) {
+    threads_.Schedule(std::move(fn));
+  }
+
+  Rendezvous* rendez_;
+
+ private:
+  thread::ThreadPool threads_;
+};
+
+// string -> Tensor
+Tensor V(const string& content) {
+  Tensor tensor(DT_STRING, TensorShape({}));
+  tensor.scalar<string>()() = content;
+  return tensor;
+}
+
+// Tensor -> string
+string V(const Tensor& tensor) {
+  CHECK_EQ(tensor.dtype(), DT_STRING);
+  CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
+  return tensor.scalar<string>()();
+}
+
+Rendezvous::ParsedKey MakeKey(const string& name) {
+  string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
+                                   "/job:mnist/replica:1/task:2/device:GPU:0", name,
+                                   FrameAndIter(0, 0));
+  Rendezvous::ParsedKey k;
+  TF_EXPECT_OK(Rendezvous::ParseKey(s, &k));
+  return k;
+}
+
+const Rendezvous::ParsedKey& KeyFoo() {
+  static auto key = MakeKey("foo");
+  return key;
+}
+
+const Rendezvous::ParsedKey& KeyBar() {
+  static auto key = MakeKey("bar");
+  return key;
+}
+
+TEST_F(LocalRendezvousTest, SendRecv) {
+  Rendezvous::Args args;
+  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
+  Tensor val(DT_STRING);
+  bool is_dead = false;
+  TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
+  EXPECT_EQ("hello", V(val));
+}
+
+TEST_F(LocalRendezvousTest, RecvSend) {
+  SchedClosure([this]() {
+    Env::Default()->SleepForMicroseconds(10000);
+    Rendezvous::Args args;
+    TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
+  });
+  Tensor val(DT_STRING);
+  bool is_dead = false;
+  Rendezvous::Args args;
+  TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
+  EXPECT_EQ("hello", V(val));
+}
+
+TEST_F(LocalRendezvousTest, PingPong) {
+  SchedClosure([this]() {
+    Tensor t(DT_STRING);
+    bool is_dead = false;
+    Rendezvous::Args args;
+    TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &t, &is_dead));
+    TF_ASSERT_OK(rendez_->Send(KeyBar(), args, t, is_dead));
+  });
+  Env::Default()->SleepForMicroseconds(1000000);
+  Tensor val(DT_STRING);
+  bool val_dead = false;
+  Rendezvous::Args args;
+  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), val_dead));
+  TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &val, &val_dead));
+  EXPECT_EQ("secret msg", V(val));
+}
+
+// A simple structure that behaves a bit like a blocking counter. The
+// user that decrements counter to 0 does done.Notify(), and the main
+// thread waits for done to be notified.
+struct BlockingState {
+  mutex lock;
+  int counter = 0;
+  Notification done;
+};
+
+TEST_F(LocalRendezvousTest, RandomSendRecv) {
+  // We are scheduling 2*N closures in this->threads_, which is
+  // configured with only 16 threads. Furthermore, because the
+  // threadpool may execute the closures in an arbitrary order, we
+  // must use RecvAsync below. Otherwise, a blocking Recv() may run
+  // before all the Send() calls and deadlock.
+ static const int N = 100; + random::PhiloxRandom philox(testing::RandomSeed(), 17); + random::SimplePhilox rnd(&philox); + BlockingState state; + state.counter = N; + for (int i = 0; i < N; ++i) { + int micros = 100 + rnd.Uniform(1000); + SchedClosure([this, i, micros]() { + Env::Default()->SleepForMicroseconds(micros); + Rendezvous::Args args; + TF_ASSERT_OK(rendez_->Send(MakeKey(strings::StrCat(i)), args, + V(strings::StrCat(i)), false)); + }); + auto recv_done = [this, &state, i](const Status& status, + const Rendezvous::Args& sender_args, + const Rendezvous::Args& recver_args, + const Tensor& val, const bool val_dead) { + EXPECT_EQ(strings::StrCat(i), V(val)); + bool done = false; + { + mutex_lock l(state.lock); + state.counter--; + if (state.counter == 0) { + done = true; + } + } + if (done) { + state.done.Notify(); + } + }; + micros = 100 + rnd.Uniform(1000); + SchedClosure([this, i, micros, recv_done]() { + Env::Default()->SleepForMicroseconds(micros); + rendez_->RecvAsync(MakeKey(strings::StrCat(i)), Rendezvous::Args(), + recv_done); + }); + } + + state.done.WaitForNotification(); +} + +void RandomSleep() { + if (std::rand() % 10 == 0) { + Env::Default()->SleepForMicroseconds(1000); + } +} + +TEST_F(LocalRendezvousTest, MultiSends) { + static const int N = 100; + const auto& key_foo = KeyFoo(); + Rendezvous::Args args; + SchedClosure([=]() { + for (int i = 0; i < N; ++i) { + TF_ASSERT_OK(rendez_->Send(key_foo, args, V(strings::StrCat(i)), false)); + RandomSleep(); + } + }); + Tensor val; + bool val_dead; + for (int i = 0; i < N; ++i) { + TF_ASSERT_OK(rendez_->Recv(key_foo, args, &val, &val_dead)); + RandomSleep(); + } +} + +TEST_F(LocalRendezvousTest, RecvAbort) { + rendez_->Ref(); + SchedClosure([this]() { + rendez_->StartAbort(errors::Aborted("")); // abort + rendez_->Unref(); + }); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead); + EXPECT_TRUE(errors::IsAborted(status)); +} + +// Similar to RecvAbort. But this test case ensures the main thread +// Recv() call happens after StartAbort(). 
+TEST_F(LocalRendezvousTest, RecvSleepAbort) { + rendez_->Ref(); + SchedClosure([this]() { + Env::Default()->SleepForMicroseconds(1000000); + rendez_->StartAbort(errors::Aborted("")); // abort + rendez_->Unref(); + }); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead); + EXPECT_TRUE(errors::IsAborted(status)); +} + +TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) { + rendez_->StartAbort(errors::Aborted("")); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + EXPECT_TRUE(errors::IsAborted(rendez_->Send(KeyFoo(), args, val, val_dead))); + EXPECT_TRUE( + errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead))); +} + +class DummyDeviceContext : public DeviceContext { + public: + explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {} + ~DummyDeviceContext() override {} + int stream_id() const { return stream_id_; } + + private: + const int stream_id_; +}; + +TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) { + Rendezvous::Args args; + args.device_context = new DummyDeviceContext(123); + + TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false)); + + Notification n; + Rendezvous::Args args1; + args1.device_context = new DummyDeviceContext(1); + rendez_->RecvAsync( + KeyFoo(), args1, + [&n](const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) { + CHECK_EQ(123, dynamic_cast( + send_args.device_context) + ->stream_id()); + n.Notify(); + }); + + n.WaitForNotification(); + args.device_context->Unref(); + args1.device_context->Unref(); +} + +void BM_SendRecv(int iters) { + Rendezvous* rendez = NewLocalRendezvous(); + Tensor orig = V("val"); + Tensor val(DT_STRING, TensorShape({})); + bool is_dead = false; + Rendezvous::Args args; + Status s; + if (iters > 0) { + while (iters--) { + TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead)); + TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &val, &is_dead)); + } + CHECK_EQ(V(val), V(orig)); + } + rendez->Unref(); +} +BENCHMARK(BM_SendRecv); + +void BM_PingPong(int iters) { + CHECK_GT(iters, 0); + thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1); + + // The main thread sends "foo" for iters times and receives "bar" + // for iters times. The other thread sends "bar" for iters times + // and receives "foo" for iters times. + Rendezvous* rendez = NewLocalRendezvous(); + pool->Schedule([rendez, iters]() { + Tensor bar = V("bar"); + Tensor foo(DT_STRING, TensorShape({})); + bool is_dead = false; + Rendezvous::Args args; + Status s; + for (int i = 0; i < iters; ++i) { + TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &foo, &is_dead)); + TF_CHECK_OK(rendez->Send(KeyBar(), args, bar, is_dead)); + } + CHECK_EQ("foo", V(foo)); + }); + Tensor foo = V("foo"); + Tensor bar(DT_STRING, TensorShape({})); + bool is_dead = false; + Rendezvous::Args args; + Status s; + for (int i = 0; i < iters; ++i) { + TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead)); + TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead)); + } + CHECK_EQ("bar", V(bar)); + delete pool; +} +BENCHMARK(BM_PingPong); + +} // namespace +} // namespace tensorflow diff --git a/resource_handle.cc b/resource_handle.cc new file mode 100644 index 0000000000000000000000000000000000000000..39ef82765f5deda70d68c4ccab77ac7258a3a4e6 --- /dev/null +++ b/resource_handle.cc @@ -0,0 +1,69 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/resource_handle.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +ResourceHandle::ResourceHandle() {} + +ResourceHandle::ResourceHandle(const ResourceHandleProto& proto) { + FromProto(proto); +} + +ResourceHandle::~ResourceHandle() {} + +void ResourceHandle::AsProto(ResourceHandleProto* proto) const { + proto->set_device(device()); + proto->set_container(container()); + proto->set_name(name()); + proto->set_hash_code(hash_code()); + proto->set_maybe_type_name(maybe_type_name()); +} + +void ResourceHandle::FromProto(const ResourceHandleProto& proto) { + set_device(proto.device()); + set_container(proto.container()); + set_name(proto.name()); + set_hash_code(proto.hash_code()); + set_maybe_type_name(proto.maybe_type_name()); +} + +string ResourceHandle::SerializeAsString() const { + ResourceHandleProto proto; + AsProto(&proto); + return proto.SerializeAsString(); +} + +bool ResourceHandle::ParseFromString(const string& s) { + ResourceHandleProto proto; + const bool status = proto.ParseFromString(s); + if (status) FromProto(proto); + return status; +} + +string ResourceHandle::DebugString() const { + return strings::StrCat("device: ", device(), " container: ", container(), + " name: ", name(), " hash_code: ", hash_code(), + " maybe_type_name: ", maybe_type_name()); +} + +string ProtoDebugString(const ResourceHandle& handle) { + return handle.DebugString(); +} + +} // namespace tensorflow diff --git a/resource_handle.h b/resource_handle.h new file mode 100644 index 0000000000000000000000000000000000000000..06df1b9046da2c99ef30b1a69763f4e18e404f94 --- /dev/null +++ b/resource_handle.h @@ -0,0 +1,82 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_HANDLE_H_ +#define TENSORFLOW_FRAMEWORK_RESOURCE_HANDLE_H_ + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class ResourceHandleProto; + +// Class representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +// +// This is the native C++ class equivalent of ResourceHandleProto. 
They are +// separate so that kernels do not need to depend on protos. +class ResourceHandle { + public: + ResourceHandle(); + ResourceHandle(const ResourceHandleProto& proto); + ~ResourceHandle(); + + // Unique name for the device containing the resource. + const string& device() const { return device_; } + void set_device(const string& device) { device_ = device; } + + // Container in which this resource is placed. + const string& container() const { return container_; } + void set_container(const string& container) { container_ = container; } + + // Unique name of this resource. + const string& name() const { return name_; } + void set_name(const string& name) { name_ = name; } + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code() const { return hash_code_; } + void set_hash_code(uint64 hash_code) { hash_code_ = hash_code; } + + // For debug-only, the name of the type pointed to by this handle, if + // available. + const string& maybe_type_name() const { return maybe_type_name_; } + void set_maybe_type_name(const string& value) { maybe_type_name_ = value; } + + // Conversion to and from ResourceHandleProto + void AsProto(ResourceHandleProto* proto) const; + void FromProto(const ResourceHandleProto& proto); + + // Serialization via ResourceHandleProto + string SerializeAsString() const; + bool ParseFromString(const string& s); + + string DebugString() const; + + public: + string device_; + string container_; + string name_; + uint64 hash_code_ = 0; + string maybe_type_name_; +}; + +// For backwards compatibility for when this was a proto +string ProtoDebugString(const ResourceHandle& handle); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_RESOURCE_HANDLE_H_ diff --git a/resource_handle.proto b/resource_handle.proto new file mode 100644 index 0000000000000000000000000000000000000000..b1921337f5fd0b434e256ae85c6baffe95df286a --- /dev/null +++ b/resource_handle.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; +}; diff --git a/resource_mgr.cc b/resource_mgr.cc new file mode 100644 index 0000000000000000000000000000000000000000..78574bc0b133fd63944eac17707fbf937b8af59a --- /dev/null +++ b/resource_mgr.cc @@ -0,0 +1,295 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/resource_mgr.h" + +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/demangle.h" + +namespace tensorflow { +ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container, + const string& name, + const TypeIndex& type_index) { + ResourceHandle result; + result.set_device(ctx->device()->attributes().name()); + string actual_container; + if (!container.empty()) { + actual_container = container; + } else { + actual_container = ctx->resource_manager()->default_container(); + } + result.set_container(actual_container); + result.set_name(name); + result.set_hash_code(type_index.hash_code()); + result.set_maybe_type_name(type_index.name()); + return result; +} + +Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, + const string& container, const string& name, + const TypeIndex& type_index) { + Tensor* handle; + TF_RETURN_IF_ERROR( + context->allocate_output(output_index, TensorShape({}), &handle)); + handle->scalar()() = + MakeResourceHandle(context, container, name, type_index); + return Status::OK(); +} + +namespace internal { + +Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) { + if (ctx->device()->attributes().name() != p.device()) { + return errors::InvalidArgument( + "Trying to access resource located in device ", p.device(), + " from device ", ctx->device()->attributes().name()); + } + return Status::OK(); +} + +} // end namespace internal + +Status ResourceMgr::InsertDebugTypeName(uint64 hash_code, + const string& type_name) { + auto iter = debug_type_names_.emplace(hash_code, type_name); + if (iter.first->second != type_name) { + return errors::AlreadyExists("Duplicate hash code found for type ", + type_name); + } + return Status::OK(); +} + +const char* ResourceMgr::DebugTypeName(uint64 hash_code) const { + auto type_name_iter = debug_type_names_.find(hash_code); + if (type_name_iter == debug_type_names_.end()) { + return ""; + } else { + return type_name_iter->second.c_str(); + } +} + +ResourceMgr::ResourceMgr() : default_container_("localhost") {} + +ResourceMgr::ResourceMgr(const string& default_container) + : default_container_(default_container) {} + +ResourceMgr::~ResourceMgr() { Clear(); } + +void ResourceMgr::Clear() { + mutex_lock l(mu_); + for (const auto& p : containers_) { + for (const auto& q : *p.second) { + q.second->Unref(); + } + delete p.second; + } + containers_.clear(); +} + +string ResourceMgr::DebugString() const { + mutex_lock l(mu_); + struct Line { + const string* container; + const string type; + const string* resource; + const string detail; + }; + std::vector lines; + for (const auto& p : 
containers_) { + const string& container = p.first; + for (const auto& q : *p.second) { + const Key& key = q.first; + const char* type = DebugTypeName(key.first); + const string& resource = key.second; + Line l{&container, port::Demangle(type), &resource, + q.second->DebugString()}; + lines.push_back(l); + } + } + std::vector text; + text.reserve(lines.size()); + for (const Line& line : lines) { + text.push_back(strings::Printf( + "%-20s | %-40s | %-40s | %-s", line.container->c_str(), + line.type.c_str(), line.resource->c_str(), line.detail.c_str())); + } + std::sort(text.begin(), text.end()); + return str_util::Join(text, "\n"); +} + +Status ResourceMgr::DoCreate(const string& container, TypeIndex type, + const string& name, ResourceBase* resource) { + { + mutex_lock l(mu_); + Container** b = &containers_[container]; + if (*b == nullptr) { + *b = new Container; + } + if ((*b)->insert({{type.hash_code(), name}, resource}).second) { + TF_RETURN_IF_ERROR(InsertDebugTypeName(type.hash_code(), type.name())); + return Status::OK(); + } + } + resource->Unref(); + return errors::AlreadyExists("Resource ", container, "/", name, "/", + type.name()); +} + +Status ResourceMgr::DoLookup(const string& container, TypeIndex type, + const string& name, + ResourceBase** resource) const { + tf_shared_lock l(mu_); + const Container* b = gtl::FindPtrOrNull(containers_, container); + if (b == nullptr) { + return errors::NotFound("Container ", container, + " does not exist. (Could not find resource: ", + container, "/", name, ")"); + } + auto r = gtl::FindPtrOrNull(*b, {type.hash_code(), name}); + if (r == nullptr) { + return errors::NotFound("Resource ", container, "/", name, "/", type.name(), + " does not exist."); + } + *resource = const_cast(r); + (*resource)->Ref(); + return Status::OK(); +} + +Status ResourceMgr::DoDelete(const string& container, uint64 type_hash_code, + const string& resource_name, + const string& type_name) { + ResourceBase* base = nullptr; + { + mutex_lock l(mu_); + Container* b = gtl::FindPtrOrNull(containers_, container); + if (b == nullptr) { + return errors::NotFound("Container ", container, " does not exist."); + } + auto iter = b->find({type_hash_code, resource_name}); + if (iter == b->end()) { + return errors::NotFound("Resource ", container, "/", resource_name, "/", + type_name, " does not exist."); + } + base = iter->second; + b->erase(iter); + } + CHECK(base != nullptr); + base->Unref(); + return Status::OK(); +} + +Status ResourceMgr::DoDelete(const string& container, TypeIndex type, + const string& resource_name) { + return DoDelete(container, type.hash_code(), resource_name, type.name()); +} + +Status ResourceMgr::Delete(const ResourceHandle& handle) { + return DoDelete(handle.container(), handle.hash_code(), handle.name(), + ""); +} + +Status ResourceMgr::Cleanup(const string& container) { + Container* b = nullptr; + { + mutex_lock l(mu_); + auto iter = containers_.find(container); + if (iter == containers_.end()) { + // Nothing to cleanup, it's OK. 
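+      // (This also makes a second Cleanup() on the same container a no-op.)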
+ return Status::OK(); + } + b = iter->second; + containers_.erase(iter); + } + CHECK(b != nullptr); + for (const auto& p : *b) { + p.second->Unref(); + } + delete b; + return Status::OK(); +} + +static bool IsValidContainerName(StringPiece s) { + using ::tensorflow::strings::Scanner; + return Scanner(s) + .One(Scanner::LETTER_DIGIT_DOT) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH) + .Eos() + .GetResult(); +} + +Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef, + bool use_node_name_as_default) { + CHECK(rmgr); + rmgr_ = rmgr; + string attr_container; + TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container", &attr_container)); + if (!attr_container.empty() && !IsValidContainerName(attr_container)) { + return errors::InvalidArgument("container contains invalid characters: ", + attr_container); + } + string attr_shared_name; + TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &attr_shared_name)); + if (!attr_shared_name.empty() && (attr_shared_name[0] == '_')) { + return errors::InvalidArgument("shared_name cannot start with '_':", + attr_shared_name); + } + if (!attr_container.empty()) { + container_ = attr_container; + } else { + container_ = rmgr_->default_container(); + } + if (!attr_shared_name.empty()) { + name_ = attr_shared_name; + } else if (use_node_name_as_default) { + name_ = ndef.name(); + } else { + resource_is_private_to_kernel_ = true; + static std::atomic counter(0); + name_ = strings::StrCat("_", counter.fetch_add(1), "_", ndef.name()); + } + return Status::OK(); +} + +string ContainerInfo::DebugString() const { + return strings::StrCat("[", container(), ",", name(), ",", + resource_is_private_to_kernel() ? "private" : "public", + "]"); +} + +ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) { + return ctx->input(input).flat()(0); +} + +Status HandleFromInput(OpKernelContext* ctx, StringPiece input, + ResourceHandle* handle) { + const Tensor* tensor; + TF_RETURN_IF_ERROR(ctx->input(input, &tensor)); + *handle = tensor->flat()(0); + return Status::OK(); +} + +Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { + TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); + return ctx->resource_manager()->Delete(p); +} + +} // end namespace tensorflow diff --git a/resource_mgr.h b/resource_mgr.h new file mode 100644 index 0000000000000000000000000000000000000000..9a458431e7c1b038c3177b2aa58e21dfa3e4e837 --- /dev/null +++ b/resource_mgr.h @@ -0,0 +1,523 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_ +#define TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// A ResourceMgr instance keeps track of named and typed resources +// grouped into containers. +// +// Each resource must be represented as a sub-class of ResourceBase, +// which is reference counted explicitly. Each named resource is +// registered with ResourceMgr under a named "container" name. At any +// time, there is at most one instance of a resource given the container +// name, the resource type and the resource name. +// +// All resources for a given container can be dropped by one call of +// Cleanup(). +// +// E.g., +// struct MyVar : public ResourceBase { +// mutex mu; +// Tensor val; +// } +// +// ResourceMgr rm; +// +// // Create a var. +// MyVar* my_var = new MyVar; +// my_var.val = Tensor(DT_FLOAT, my_shape); +// my_var.val.flat().setZeros(); // 0 initialized. +// ctx->SetStatus(rm.Create("my_container", "my_name", my_var)); +// +// // += a variable. +// MyVar* my_var = nullptr; +// Status s = rm.Lookup("my_container", "my_name", &my_var); +// if (s.ok()) { +// my_var->val.flat() += grad; +// } +// my_var->Unref(); // Or use ScopedUnref(). +// ctx->SetStatus(s); +class ResourceBase : public core::RefCounted { + public: + // Returns a debug string for *this. + virtual string DebugString() = 0; + + // Returns memory used by this resource. + virtual int64 MemoryUsed() const { return 0; }; +}; + +// Container used for per-step resources. +class ScopedStepContainer { + public: + // step_id: the unique ID of this step. Doesn't have to be sequential, just + // has to be unique. + // cleanup: callback to delete a container of this name. + ScopedStepContainer(const int64 step_id, + std::function cleanup) + : name_(strings::StrCat("__per_step_", step_id)), cleanup_(cleanup) {} + ~ScopedStepContainer() { cleanup_(name_); } + + const string& name() const { return name_; } + + private: + const string name_; + const std::function cleanup_; +}; + +class ResourceMgr { + public: + ResourceMgr(); + explicit ResourceMgr(const string& default_container); + ~ResourceMgr(); + + // Returns the default container name for *this. + const string& default_container() const { return default_container_; } + + // Creates a resource "name" in the "container". The caller transfers + // the ownership of one ref on "resource" to *this + // + // REQUIRES: std::is_base_of + // REQUIRES: resource != nullptr. 
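+  //
+  // On success the manager owns the transferred ref; on failure (e.g. a
+  // resource with the same name already exists) that ref is released before
+  // returning, so the caller must not use "resource" afterwards.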
+  template <typename T>
+  Status Create(const string& container, const string& name,
+                T* resource) TF_MUST_USE_RESULT;
+
+  // If "container" has a resource "name", returns it in "*resource" and
+  // the caller takes the ownership of one ref on "*resource".
+  //
+  // REQUIRES: std::is_base_of<ResourceBase, T>
+  // REQUIRES: resource != nullptr
+  template <typename T>
+  Status Lookup(const string& container, const string& name,
+                T** resource) const TF_MUST_USE_RESULT;
+
+  // If "container" has a resource "name", returns it in
+  // "*resource". Otherwise, invokes creator() to create the resource.
+  // The caller takes the ownership of one ref on "*resource".
+  //
+  // REQUIRES: std::is_base_of<ResourceBase, T>
+  // REQUIRES: resource != nullptr
+  template <typename T>
+  Status LookupOrCreate(const string& container, const string& name,
+                        T** resource,
+                        std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
+
+  // Deletes the resource "name" from the "container".
+  //
+  // REQUIRES: std::is_base_of<ResourceBase, T>
+  template <typename T>
+  Status Delete(const string& container, const string& name) TF_MUST_USE_RESULT;
+
+  // Deletes the resource pointed by "handle".
+  Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;
+
+  // Deletes all resources from the "container" and removes the container.
+  Status Cleanup(const string& container) TF_MUST_USE_RESULT;
+
+  // Deletes all resources in all containers.
+  void Clear();
+
+  // Returns a text description for all resources.
+  string DebugString() const;
+
+ private:
+  typedef std::pair<uint64, string> Key;
+  struct KeyHash {
+    std::size_t operator()(const Key& k) const {
+      return Hash64(k.second.data(), k.second.size(), k.first);
+    }
+  };
+  struct KeyEqual {
+    bool operator()(const Key& x, const Key& y) const {
+      return (x.second == y.second) && (x.first == y.first);
+    }
+  };
+  typedef std::unordered_map<Key, ResourceBase*, KeyHash, KeyEqual> Container;
+
+  const string default_container_;
+  mutable mutex mu_;
+  std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
+
+  Status DoCreate(const string& container, TypeIndex type, const string& name,
+                  ResourceBase* resource) TF_MUST_USE_RESULT;
+  Status DoLookup(const string& container, TypeIndex type, const string& name,
+                  ResourceBase** resource) const TF_MUST_USE_RESULT;
+  Status DoDelete(const string& container, uint64 type_hash_code,
+                  const string& resource_name,
+                  const string& type_name) TF_MUST_USE_RESULT;
+  Status DoDelete(const string& container, TypeIndex type,
+                  const string& resource_name) TF_MUST_USE_RESULT;
+
+  // Inserts the type name for 'hash_code' into the hash_code to type name map.
+  Status InsertDebugTypeName(uint64 hash_code, const string& type_name)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
+
+  // Returns the type name for the 'hash_code'.
+  // Returns "" if a resource with such a type was never inserted into
+  // the container.
+  const char* DebugTypeName(uint64 hash_code) const
+      EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Map from type hash_code to type name.
+  std::unordered_map<uint64, string> debug_type_names_ GUARDED_BY(mu_);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
+};
+
+// Makes a resource handle with the specified type for a given container /
+// name.
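+// For example, an op kernel that exposes a resource as output 0 might do
+// something like (illustrative sketch; "MyResource" is a hypothetical
+// ResourceBase subclass):
+//
+//   TF_RETURN_IF_ERROR(MakeResourceHandleToOutput(
+//       ctx, 0, cinfo.container(), cinfo.name(), MakeTypeIndex<MyResource>()));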
+ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container, + const string& name, + const TypeIndex& type_index); + +template +ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container, + const string& name) { + return MakeResourceHandle(ctx, container, name, MakeTypeIndex()); +} + +Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, + const string& container, const string& name, + const TypeIndex& type_index); + +template +ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx, + const string& name); + +// Returns a resource handle from a numbered op input. +ResourceHandle HandleFromInput(OpKernelContext* ctx, int input); +Status HandleFromInput(OpKernelContext* ctx, StringPiece input, + ResourceHandle* handle); + +// Create a resource pointed by a given resource handle. +template +Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value); + +// Looks up a resource pointed by a given resource handle. +template +Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value); + +// Looks up or creates a resource. +template +Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, + T** value, std::function creator); + +// Destroys a resource pointed by a given resource handle. +template +Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); + +// Same as above, but uses the hash code of the type directly. +// The type name information will be missing in the debug output when the +// resource is not present in the container. +Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); + +// Policy helper to decide which container/shared_name to use for a +// stateful kernel that accesses shared resource. +class ContainerInfo { + public: + // Analyze the node attribute of 'ndef' and decides the container and + // resource name the kernel should use for accessing the shared + // resource. + // + // 'ndef' is expected to have node attribute "container" and + // "shared_name". Returns non-OK if they are not provided or they are + // invalid. + // + // The policy is as following: + // * If the attribute "container" is non-empty, it is used as is. + // Otherwise, uses the resource manager's default container. + // * If the attribute "shared_name" is non-empty, it is used as is. + // Otherwise, if "use_node_name_as_default" is true, the kernel's + // node name is used as the resource name. Otherwise, a string + // unique to this process is used. + Status Init(ResourceMgr* rmgr, const NodeDef& ndef, + bool use_node_name_as_default); + Status Init(ResourceMgr* rmgr, const NodeDef& ndef) { + return Init(rmgr, ndef, false); + } + + // The policy decides that the kernel should access the resource in + // resource_manager(), the resource is in the container() and its + // name is name(). If resource_is_private_to_kernel() is true, the + // kernel should delete the resource when the kernel is deleted. + ResourceMgr* resource_manager() const { return rmgr_; } + const string& container() const { return container_; } + const string& name() const { return name_; } + bool resource_is_private_to_kernel() const { + return resource_is_private_to_kernel_; + } + + // Returns a readable string for *this. + string DebugString() const; + + private: + ResourceMgr* rmgr_ = nullptr; + string container_; + string name_; + bool resource_is_private_to_kernel_ = false; +}; + +// Helper for kernels to obtain 'resource' from the +// ctx->resource_manager(). 
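+// It accepts either a DT_RESOURCE handle input or a legacy string ref input.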
+//
+// "input_name" specifies the kernel's ref input which gives a string
+// tensor with two elements, which specifies the container and
+// resource name.
+//
+// Returns OK if the resource is found and transfers one ref of
+// *resource to the caller. Otherwise, returns an error.
+template <typename T>
+Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
+                              T** resource);
+
+// Utility op kernel to check if a handle to resource type T is initialized.
+template <typename T>
+class IsResourceInitialized : public OpKernel {
+ public:
+  explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {}
+
+  void Compute(OpKernelContext* ctx) override;
+};
+
+// Registers an op which produces just a resource handle to a resource of the
+// specified type. The type will be a part of the generated op name.
+// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
+// through constant folding so this doesn't have to be marked as stateful.
+#define REGISTER_RESOURCE_HANDLE_OP(Type)                      \
+  REGISTER_OP(#Type "HandleOp")                                \
+      .Attr("container: string = ''")                          \
+      .Attr("shared_name: string = ''")                        \
+      .Output("resource: resource")                            \
+      .SetIsStateful()                                         \
+      .SetShapeFn(tensorflow::shape_inference::ScalarShape)    \
+      .Doc("Creates a handle to a " #Type)
+
+// Utility op kernel to produce a handle to a resource of type T.
+template <typename T>
+class ResourceHandleOp : public OpKernel {
+ public:
+  explicit ResourceHandleOp(OpKernelConstruction* context);
+
+  void Compute(OpKernelContext* ctx) override;
+
+ private:
+  string container_;
+  string name_;
+};
+
+// Registers a kernel for an op which produces a handle to a resource of the
+// specified type.
+#define REGISTER_RESOURCE_HANDLE_KERNEL(Type)                        \
+  REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU), \
+                          ResourceHandleOp<Type>)
+
+// Implementation details below.
+
+template <typename T>
+void CheckDeriveFromResourceBase() {
+  static_assert(std::is_base_of<ResourceBase, T>::value,
+                "T must derive from ResourceBase");
+}
+
+template <typename T>
+Status ResourceMgr::Create(const string& container, const string& name,
+                           T* resource) {
+  CheckDeriveFromResourceBase<T>();
+  CHECK(resource != nullptr);
+  return DoCreate(container, MakeTypeIndex<T>(), name, resource);
+}
+
+template <typename T>
+Status ResourceMgr::Lookup(const string& container, const string& name,
+                           T** resource) const {
+  CheckDeriveFromResourceBase<T>();
+  ResourceBase* found = nullptr;
+  Status s = DoLookup(container, MakeTypeIndex<T>(), name, &found);
+  if (s.ok()) {
+    // It's safe to down cast 'found' to T* since
+    // typeid(T).hash_code() is part of the map key.
+    *resource = static_cast<T*>(found);
+  }
+  return s;
+}
+
+template <typename T>
+Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
+                                   T** resource,
+                                   std::function<Status(T**)> creator) {
+  Status s;
+  *resource = nullptr;
+  while (*resource == nullptr) {
+    s = Lookup(container, name, resource);
+    if (s.ok()) break;
+    s = creator(resource);
+    if (!s.ok()) break;
+    s = Create(container, name, *resource);
+    if (s.ok()) {
+      (*resource)->Ref();
+      break;
+    }
+    // Rare event. Concurrent racy creation. Redo the lookup.
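+    // Create() failed because another thread registered the same name first;
+    // DoCreate() has already dropped the ref on our candidate, so clear the
+    // pointer and loop around to Lookup() again.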
+ *resource = nullptr; + } + return s; +} + +template +Status ResourceMgr::Delete(const string& container, const string& name) { + CheckDeriveFromResourceBase(); + return DoDelete(container, MakeTypeIndex(), name); +} + +template +Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name, + T** resource) { + DataType dtype; + TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype)); + if (dtype == DT_RESOURCE) { + const Tensor* handle; + TF_RETURN_IF_ERROR(ctx->input(input_name, &handle)); + return LookupResource(ctx, handle->scalar()(), resource); + } + string container; + string shared_name; + { + mutex* mu; + TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu)); + mutex_lock l(*mu); + Tensor tensor; + TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true)); + if (tensor.NumElements() != 2) { + return errors::InvalidArgument( + "Resource handle must have 2 elements, but had shape: ", + tensor.shape().DebugString()); + } + container = tensor.flat()(0); + shared_name = tensor.flat()(1); + } + return ctx->resource_manager()->Lookup(container, shared_name, resource); +} + +template +ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx, + const string& name) { + return MakeResourceHandle(ctx, ctx->step_container()->name(), name); +} + +namespace internal { + +Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p); + +template +Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) { + TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); + auto type_index = MakeTypeIndex(); + if (type_index.hash_code() != p.hash_code()) { + return errors::InvalidArgument( + "Trying to access resource using the wrong type. Expected ", + p.maybe_type_name(), " got ", type_index.name()); + } + return Status::OK(); +} + +} // namespace internal + +template +Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) { + TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, p)); + return ctx->resource_manager()->Create(p.container(), p.name(), value); +} + +template +Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + T** value) { + TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, p)); + return ctx->resource_manager()->Lookup(p.container(), p.name(), value); +} + +template +Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, + T** value, std::function creator) { + TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, p)); + return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value, + creator); +} + +template +Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { + TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, p)); + return ctx->resource_manager()->Delete(p.container(), p.name()); +} + +Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); + +template +void IsResourceInitialized::Compute(OpKernelContext* ctx) { + Tensor* output; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output)); + T* object; + bool found; + if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) { + found = true; + object->Unref(); + } else { + found = false; + } + + output->flat()(0) = found; +} + +template +ResourceHandleOp::ResourceHandleOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("container", &container_)); + OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_)); +} + +template +void ResourceHandleOp::Compute(OpKernelContext* ctx) { + Tensor* output 
= nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); + output->scalar()() = + MakeResourceHandle(ctx, container_, name_); +} + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_ diff --git a/resource_mgr_test.cc b/resource_mgr_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..07272e2374cbf4fb46c5b8da5df73ef4d6858c62 --- /dev/null +++ b/resource_mgr_test.cc @@ -0,0 +1,324 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/resource_mgr.h" + +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class Resource : public ResourceBase { + public: + explicit Resource(const string& label) : label_(label) {} + ~Resource() override {} + + string DebugString() override { return strings::StrCat("R/", label_); } + + private: + string label_; +}; + +class Other : public ResourceBase { + public: + explicit Other(const string& label) : label_(label) {} + ~Other() override {} + + string DebugString() override { return strings::StrCat("O/", label_); } + + private: + string label_; +}; + +template +string Find(const ResourceMgr& rm, const string& container, + const string& name) { + T* r; + TF_CHECK_OK(rm.Lookup(container, name, &r)); + const string ret = r->DebugString(); + r->Unref(); + return ret; +} + +template +string LookupOrCreate(ResourceMgr* rm, const string& container, + const string& name, const string& label) { + T* r; + TF_CHECK_OK(rm->LookupOrCreate(container, name, &r, [&label](T** ret) { + *ret = new T(label); + return Status::OK(); + })); + const string ret = r->DebugString(); + r->Unref(); + return ret; +} + +static void HasError(const Status& s, const string& substr) { + EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) + << s << ", expected substring " << substr; +} + +template +Status FindErr(const ResourceMgr& rm, const string& container, + const string& name) { + T* r; + Status s = rm.Lookup(container, name, &r); + CHECK(!s.ok()); + return s; +} + +TEST(ResourceMgrTest, Basic) { + ResourceMgr rm; + TF_CHECK_OK(rm.Create("foo", "bar", new Resource("cat"))); + TF_CHECK_OK(rm.Create("foo", "baz", new Resource("dog"))); + TF_CHECK_OK(rm.Create("foo", "bar", new Other("tiger"))); + + // Expected to fail. + HasError(rm.Create("foo", "bar", new Resource("kitty")), + "Already exists: Resource foo/bar"); + + // Expected to be found. + EXPECT_EQ("R/cat", Find(rm, "foo", "bar")); + EXPECT_EQ("R/dog", Find(rm, "foo", "baz")); + EXPECT_EQ("O/tiger", Find(rm, "foo", "bar")); + + // Expected to be not found. 
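+  // ("foo"/"baz" exists only as a Resource; looking it up as an Other fails
+  // because the type's hash code is part of the lookup key.)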
+ HasError(FindErr(rm, "bar", "foo"), "Not found: Container bar"); + HasError(FindErr(rm, "foo", "xxx"), "Not found: Resource foo/xxx"); + HasError(FindErr(rm, "foo", "baz"), "Not found: Resource foo/baz"); + + // Delete foo/bar/Resource. + TF_CHECK_OK(rm.Delete("foo", "bar")); + HasError(FindErr(rm, "foo", "bar"), "Not found: Resource foo/bar"); + + TF_CHECK_OK(rm.Create("foo", "bar", new Resource("kitty"))); + EXPECT_EQ("R/kitty", Find(rm, "foo", "bar")); + + // Drop the whole container foo. + TF_CHECK_OK(rm.Cleanup("foo")); + HasError(FindErr(rm, "foo", "bar"), "Not found: Container foo"); + + // Dropping it a second time is OK. + TF_CHECK_OK(rm.Cleanup("foo")); + HasError(FindErr(rm, "foo", "bar"), "Not found: Container foo"); + + // Dropping a non-existent container is also ok. + TF_CHECK_OK(rm.Cleanup("bar")); +} + +TEST(ResourceMgr, CreateOrLookup) { + ResourceMgr rm; + EXPECT_EQ("R/cat", LookupOrCreate(&rm, "foo", "bar", "cat")); + EXPECT_EQ("R/cat", LookupOrCreate(&rm, "foo", "bar", "dog")); + EXPECT_EQ("R/cat", Find(rm, "foo", "bar")); + + EXPECT_EQ("O/tiger", LookupOrCreate(&rm, "foo", "bar", "tiger")); + EXPECT_EQ("O/tiger", LookupOrCreate(&rm, "foo", "bar", "lion")); + TF_CHECK_OK(rm.Delete("foo", "bar")); + HasError(FindErr(rm, "foo", "bar"), "Not found: Resource foo/bar"); +} + +Status ComputePolicy(const string& attr_container, + const string& attr_shared_name, + bool use_node_name_as_default, string* result) { + ContainerInfo cinfo; + ResourceMgr rmgr; + NodeDef ndef; + ndef.set_name("foo"); + if (attr_container != "none") { + AddNodeAttr("container", attr_container, &ndef); + } + if (attr_shared_name != "none") { + AddNodeAttr("shared_name", attr_shared_name, &ndef); + } + TF_RETURN_IF_ERROR(cinfo.Init(&rmgr, ndef, use_node_name_as_default)); + *result = cinfo.DebugString(); + return Status::OK(); +} + +string Policy(const string& attr_container, const string& attr_shared_name, + bool use_node_name_as_default) { + string ret; + TF_CHECK_OK(ComputePolicy(attr_container, attr_shared_name, + use_node_name_as_default, &ret)); + return ret; +} + +TEST(ContainerInfo, Basic) { + // Correct cases. + EXPECT_EQ(Policy("", "", false), "[localhost,_0_foo,private]"); + EXPECT_EQ(Policy("", "", true), "[localhost,foo,public]"); + EXPECT_EQ(Policy("", "bar", false), "[localhost,bar,public]"); + EXPECT_EQ(Policy("", "bar", true), "[localhost,bar,public]"); + EXPECT_EQ(Policy("cat", "", false), "[cat,_1_foo,private]"); + EXPECT_EQ(Policy("cat", "", true), "[cat,foo,public]"); + EXPECT_EQ(Policy("cat", "bar", false), "[cat,bar,public]"); + EXPECT_EQ(Policy("cat", "bar", true), "[cat,bar,public]"); + EXPECT_EQ(Policy("cat.0-dog", "bar", true), "[cat.0-dog,bar,public]"); + EXPECT_EQ(Policy(".cat", "bar", true), "[.cat,bar,public]"); +} + +Status WrongPolicy(const string& attr_container, const string& attr_shared_name, + bool use_node_name_as_default) { + string dbg; + auto s = ComputePolicy(attr_container, attr_shared_name, + use_node_name_as_default, &dbg); + CHECK(!s.ok()); + return s; +} + +TEST(ContainerInfo, Error) { + // Missing attribute. + HasError(WrongPolicy("none", "", false), "No attr"); + HasError(WrongPolicy("", "none", false), "No attr"); + HasError(WrongPolicy("none", "none", false), "No attr"); + + // Invalid container. + HasError(WrongPolicy("12$%", "", false), "container contains invalid char"); + HasError(WrongPolicy("-cat", "", false), "container contains invalid char"); + + // Invalid shared name. 
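+  // (Names starting with '_' are reserved for the private per-kernel
+  // resources generated by ContainerInfo::Init().)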
+ HasError(WrongPolicy("", "_foo", false), "shared_name cannot start with '_'"); +} + +// Stub DeviceBase subclass which only sets a device name, for testing resource +// handles. +class StubDevice : public DeviceBase { + public: + explicit StubDevice(const string& name) : DeviceBase(nullptr) { + attr_.set_name(name); + } + + Allocator* GetAllocator(AllocatorAttributes) override { + return cpu_allocator(); + } + + const DeviceAttributes& attributes() const override { return attr_; } + + private: + DeviceAttributes attr_; +}; + +// Empty stub resource for testing resource handles. +class StubResource : public ResourceBase { + public: + string DebugString() override { return ""; } + int value_{0}; +}; + +TEST(ResourceHandleTest, CRUD) { + ResourceMgr resource_mgr(""); + OpKernelContext::Params params; + params.resource_manager = &resource_mgr; + StubDevice device("device_name"); + params.device = &device; + OpKernelContext ctx(¶ms, 0); + + ResourceHandle p = + MakeResourceHandle(&ctx, "container", "name"); + + { + auto* r = new StubResource(); + r->value_ = 42; + TF_EXPECT_OK(CreateResource(&ctx, p, r)); + } + { + StubResource* r = nullptr; + TF_ASSERT_OK(LookupResource(&ctx, p, &r)); + ASSERT_TRUE(r != nullptr); + EXPECT_EQ(r->value_, 42); + r->Unref(); + } + { + TF_EXPECT_OK(DeleteResource(&ctx, p)); + StubResource* unused = nullptr; + EXPECT_FALSE(LookupResource(&ctx, p, &unused).ok()); + } +} + +TEST(ResourceHandleTest, DifferentDevice) { + ResourceMgr resource_mgr(""); + OpKernelContext::Params params; + params.resource_manager = &resource_mgr; + StubDevice device("device_name"); + params.device = &device; + OpKernelContext ctx(¶ms, 0); + + ResourceHandle p = + MakeResourceHandle(&ctx, "container", "name"); + + ResourceMgr other_resource_mgr(""); + OpKernelContext::Params other_params; + other_params.resource_manager = &other_resource_mgr; + StubDevice other_device("other_device_name"); + other_params.device = &other_device; + OpKernelContext other_ctx(&other_params, 0); + + auto* r = new StubResource(); + ASSERT_FALSE(CreateResource(&other_ctx, p, r).ok()); + r->Unref(); +} + +// Other stub resource to test type-checking of resource handles. 
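+// A handle made for StubResource records StubResource's type hash code, so
+// creating a resource of a different type through that handle must fail.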
+class OtherStubResource : public ResourceBase { + public: + string DebugString() override { return ""; } +}; + +TEST(ResourceHandleTest, DifferentType) { + ResourceMgr resource_mgr(""); + OpKernelContext::Params params; + params.resource_manager = &resource_mgr; + StubDevice device("device_name"); + params.device = &device; + OpKernelContext ctx(¶ms, 0); + + ResourceHandle p = + MakeResourceHandle(&ctx, "container", "name"); + + auto* r = new OtherStubResource; + ASSERT_FALSE(CreateResource(&ctx, p, r).ok()); + r->Unref(); +} + +TEST(ResourceHandleTest, DeleteUsingResourceHandle) { + ResourceMgr resource_mgr(""); + OpKernelContext::Params params; + params.resource_manager = &resource_mgr; + StubDevice device("device_name"); + params.device = &device; + OpKernelContext ctx(¶ms, 0); + + ResourceHandle p = + MakeResourceHandle(&ctx, "container", "name"); + + StubResource* r = new StubResource; + TF_EXPECT_OK(CreateResource(&ctx, p, r)); + + StubResource* lookup_r = nullptr; + TF_EXPECT_OK(LookupResource(&ctx, p, &lookup_r)); + EXPECT_EQ(lookup_r, r); + + TF_EXPECT_OK(DeleteResource(&ctx, p)); + EXPECT_NE(LookupResource(&ctx, p, &lookup_r).ok(), true); + r->Unref(); +} + +} // end namespace tensorflow diff --git a/resource_op_kernel.h b/resource_op_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..813ec6eed58e975ec1dda0e1a61f01a37414a56f --- /dev/null +++ b/resource_op_kernel.h @@ -0,0 +1,128 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_ +#define TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// ResourceOpKernel is a virtual base class for resource op implementing +// interface type T. The inherited op looks up the resource name (determined by +// ContainerInfo), and creates a new resource if necessary. +// +// Requirements: +// - Op must be marked as stateful. +// - Op must have `container` and `shared_name` attributes. Empty `container` +// means using the default container. Empty `shared_name` means private +// resource. +// - Subclass must override CreateResource(). +// - Subclass is encouraged to override VerifyResource(). +template +class ResourceOpKernel : public OpKernel { + public: + explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->allocate_persistent(DT_STRING, TensorShape({2}), + &handle_, nullptr)); + } + + // The resource is deleted from the resource manager only when it is private + // to kernel. 
Ideally the resource should be deleted when it is no longer held + // by anyone, but it would break backward compatibility. + ~ResourceOpKernel() override { + if (resource_ != nullptr) { + resource_->Unref(); + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->template Delete(cinfo_.container(), cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + } + + void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (resource_ == nullptr) { + ResourceMgr* mgr = context->resource_manager(); + OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); + + T* resource; + OP_REQUIRES_OK( + context, + mgr->LookupOrCreate(cinfo_.container(), cinfo_.name(), &resource, + [this](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + Status s = CreateResource(ret); + if (!s.ok() && *ret != nullptr) { + CHECK((*ret)->Unref()); + } + return s; + })); + + Status s = VerifyResource(resource); + if (TF_PREDICT_FALSE(!s.ok())) { + resource->Unref(); + context->SetStatus(s); + return; + } + + auto h = handle_.AccessTensor(context)->template flat(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + resource_ = resource; + } + if (context->expected_output_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK(context, MakeResourceHandleToOutput( + context, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } else { + context->set_output_ref(0, &mu_, handle_.AccessTensor(context)); + } + } + + protected: + // Variables accessible from subclasses. + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + T* resource_ GUARDED_BY(mu_) = nullptr; + + private: + // Must return a T descendant allocated with new that ResourceOpKernel will + // take ownership of. + virtual Status CreateResource(T** resource) EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; + + // During the first Compute(), resource is either created or looked up using + // shared_name. In the latter case, the resource found should be verified if + // it is compatible with this op's configuration. The verification may fail in + // cases such as two graphs asking queues of the same shared name to have + // inconsistent capacities. + virtual Status VerifyResource(T* resource) { return Status::OK(); } + + PersistentTensor handle_ GUARDED_BY(mu_); +}; +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_ diff --git a/resource_op_kernel_test.cc b/resource_op_kernel_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c1e503dc57643d2023d89f317a6c5ff643a3c60b --- /dev/null +++ b/resource_op_kernel_test.cc @@ -0,0 +1,202 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/resource_op_kernel.h" + +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { +namespace { + +// Stub DeviceBase subclass which only returns allocators. +class StubDevice : public DeviceBase { + public: + StubDevice() : DeviceBase(nullptr) {} + + Allocator* GetAllocator(AllocatorAttributes) override { + return cpu_allocator(); + } +}; + +// Stub resource for testing resource op kernel. +class StubResource : public ResourceBase { + public: + string DebugString() override { return ""; } + int code; +}; + +class StubResourceOpKernel : public ResourceOpKernel { + public: + using ResourceOpKernel::ResourceOpKernel; + + StubResource* resource() LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + return resource_; + } + + private: + Status CreateResource(StubResource** resource) override { + *resource = CHECK_NOTNULL(new StubResource); + return GetNodeAttr(def(), "code", &(*resource)->code); + } + + Status VerifyResource(StubResource* resource) override { + int code; + TF_RETURN_IF_ERROR(GetNodeAttr(def(), "code", &code)); + if (code != resource->code) { + return errors::InvalidArgument("stub has code ", resource->code, + " but requested code ", code); + } + return Status::OK(); + } +}; + +REGISTER_OP("StubResourceOp") + .Attr("code: int") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("output: Ref(string)"); + +REGISTER_KERNEL_BUILDER(Name("StubResourceOp").Device(DEVICE_CPU), + StubResourceOpKernel); + +class ResourceOpKernelTest : public ::testing::Test { + protected: + std::unique_ptr CreateOp(int code, + const string& shared_name) { + NodeDef node_def; + TF_CHECK_OK( + NodeDefBuilder(strings::StrCat("test-node", count_++), "StubResourceOp") + .Attr("code", code) + .Attr("shared_name", shared_name) + .Finalize(&node_def)); + Status status; + std::unique_ptr op(CreateOpKernel( + DEVICE_CPU, &device_, device_.GetAllocator(AllocatorAttributes()), + node_def, TF_GRAPH_DEF_VERSION, &status)); + TF_EXPECT_OK(status) << status; + EXPECT_TRUE(op != nullptr); + + // Downcast to StubResourceOpKernel to call resource() later. + std::unique_ptr resource_op( + dynamic_cast(op.get())); + EXPECT_TRUE(resource_op != nullptr); + if (resource_op != nullptr) { + op.release(); + } + return resource_op; + } + + Status RunOpKernel(OpKernel* op) { + OpKernelContext::Params params; + + params.device = &device_; + params.resource_manager = &mgr_; + params.op_kernel = op; + + OpKernelContext context(¶ms); + op->Compute(&context); + return context.status(); + } + + StubDevice device_; + ResourceMgr mgr_; + int count_ = 0; +}; + +TEST_F(ResourceOpKernelTest, PrivateResource) { + // Empty shared_name means private resource. 
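+  // ContainerInfo then generates a process-unique name of the form
+  // "_<counter>_<node name>", which is what the lookup key below relies on.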
+ const int code = -100; + auto op = CreateOp(code, ""); + ASSERT_TRUE(op != nullptr); + TF_EXPECT_OK(RunOpKernel(op.get())); + + // Default non-shared name provided from ContainerInfo. + const string key = "_0_" + op->name(); + + StubResource* resource; + TF_ASSERT_OK( + mgr_.Lookup(mgr_.default_container(), key, &resource)); + EXPECT_EQ(op->resource(), resource); // Check resource identity. + EXPECT_EQ(code, resource->code); // Check resource stored information. + resource->Unref(); + + // Destroy the op kernel. Expect the resource to be released. + op = nullptr; + Status s = + mgr_.Lookup(mgr_.default_container(), key, &resource); + + EXPECT_FALSE(s.ok()); +} + +TEST_F(ResourceOpKernelTest, SharedResource) { + const string shared_name = "shared_stub"; + const int code = -201; + auto op = CreateOp(code, shared_name); + ASSERT_TRUE(op != nullptr); + TF_EXPECT_OK(RunOpKernel(op.get())); + + StubResource* resource; + TF_ASSERT_OK(mgr_.Lookup(mgr_.default_container(), shared_name, + &resource)); + EXPECT_EQ(op->resource(), resource); // Check resource identity. + EXPECT_EQ(code, resource->code); // Check resource stored information. + resource->Unref(); + + // Destroy the op kernel. Expect the resource not to be released. + op = nullptr; + TF_ASSERT_OK(mgr_.Lookup(mgr_.default_container(), shared_name, + &resource)); + resource->Unref(); +} + +TEST_F(ResourceOpKernelTest, LookupShared) { + auto op1 = CreateOp(-333, "shared_stub"); + auto op2 = CreateOp(-333, "shared_stub"); + ASSERT_TRUE(op1 != nullptr); + ASSERT_TRUE(op2 != nullptr); + + TF_EXPECT_OK(RunOpKernel(op1.get())); + TF_EXPECT_OK(RunOpKernel(op2.get())); + EXPECT_EQ(op1->resource(), op2->resource()); +} + +TEST_F(ResourceOpKernelTest, VerifyResource) { + auto op1 = CreateOp(-444, "shared_stub"); + auto op2 = CreateOp(0, "shared_stub"); // Different resource code. + ASSERT_TRUE(op1 != nullptr); + ASSERT_TRUE(op2 != nullptr); + + TF_EXPECT_OK(RunOpKernel(op1.get())); + EXPECT_FALSE(RunOpKernel(op2.get()).ok()); + EXPECT_TRUE(op1->resource() != nullptr); + EXPECT_TRUE(op2->resource() == nullptr); +} + +} // namespace +} // namespace tensorflow diff --git a/selective_registration.h b/selective_registration.h new file mode 100644 index 0000000000000000000000000000000000000000..503947969d3fd330fcbfcedd605abf193922fb54 --- /dev/null +++ b/selective_registration.h @@ -0,0 +1,58 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_ +#define TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_ + +#include + +#ifdef SELECTIVE_REGISTRATION + +// Experimental selective registration support to reduce binary size. +// +// To use selective registration, when building: +// 1. define SELECTIVE_REGISTRATION, e.g. in gcc by passing +// -DSELECTIVE_REGISTRATION to compilation. +// 2. Provide ops_to_register.h. 
This file is not included in the repo and must
+// be placed by the user or a tool where the compiler can find it. It must
+// define the constants and functions used in the macros below. The
+// functions should be defined as valid constexpr functions, so that they are
+// evaluated at compile time: this is needed to make symbols referenced by
+// un-registered objects unused, and therefore allow the linker to strip them
+// out. See python/tools/print_selective_registration_header.py for a tool
+// that can be used to generate ops_to_register.h.
+//
+// ops_to_register.h should define macros for:
+//   // Ops for which this is false will not be registered.
+//   SHOULD_REGISTER_OP(op)
+//   // If this is false, then no gradient ops are registered.
+//   SHOULD_REGISTER_OP_GRADIENT
+//   // Op kernel classes where this is false won't be registered.
+//   SHOULD_REGISTER_OP_KERNEL(clz)
+// The macros should be defined using constexprs.
+
+#include "ops_to_register.h"
+
+#if (!defined(SHOULD_REGISTER_OP) || !defined(SHOULD_REGISTER_OP_GRADIENT) || \
+     !defined(SHOULD_REGISTER_OP_KERNEL))
+static_assert(false, "ops_to_register.h must define SHOULD_REGISTER macros");
+#endif
+#else
+#define SHOULD_REGISTER_OP(op) true
+#define SHOULD_REGISTER_OP_GRADIENT true
+#define SHOULD_REGISTER_OP_KERNEL(clz) true
+#endif
+
+#endif  // TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_
diff --git a/session_state.h b/session_state.h
new file mode 100644
index 0000000000000000000000000000000000000000..8fbe940f6aefe70cad3016ecc99528bae1b761dd
--- /dev/null
+++ b/session_state.h
@@ -0,0 +1,87 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_SESSION_STATE_H_
+#define TENSORFLOW_FRAMEWORK_SESSION_STATE_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// The session state remembers the tensors we choose to keep across
+// multiple run calls.
+class SessionState {
+ public:
+  // Get a tensor from the session state.
+  Status GetTensor(const string& handle, Tensor* tensor);
+
+  // Store a tensor in the session state.
+  Status AddTensor(const string& handle, const Tensor& tensor);
+
+  // Delete a tensor from the session state.
+  Status DeleteTensor(const string& handle);
+
+  int64 GetNewId();
+
+  static const char* kTensorHandleResourceTypeName;
+
+ private:
+  mutex state_lock_;
+
+  // For generating unique ids for tensors stored in the session.
+  int64 tensor_id_ = 0;
+
+  // The live tensors in the session. A map from tensor handle to tensor.
+  std::unordered_map<string, Tensor> tensors_;
+};
+
+// The tensor store remembers the tensors we choose to keep for the
+// current run call. It is available to every op kernel.
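+//
+// For example, a kernel that wants to hand a tensor back to the client might
+// do something like (illustrative sketch; the TensorStore pointer, id and
+// device name come from the calling context):
+//
+//   TensorStore::TensorAndKey tk;
+//   tk.tensor = val;
+//   tk.id = id;
+//   tk.device_name = device_name;
+//   TF_RETURN_IF_ERROR(tensor_store->AddTensor(name, tk));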
+class TensorStore { + public: + struct TensorAndKey { + Tensor tensor; + int64 id; + string device_name; + + string GetHandle(const string& tensor_name) { + return strings::StrCat(tensor_name, ";", id, ";", device_name); + } + }; + + // Add the named tensor to the tensor store for this run. + Status AddTensor(const string& name, const TensorAndKey& tk); + + // Save the tensors in the tensor store of this run to the session. + Status SaveTensors(const std::vector& output_names, + SessionState* session_state); + + private: + mutex lock_; + + // The tensors that will be saved to session state when this run completes. + // A map from tensor string name to tensor. + std::unordered_map tensors_ GUARDED_BY(lock_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_SESSION_STATE_H_ diff --git a/shape_inference.cc b/shape_inference.cc new file mode 100644 index 0000000000000000000000000000000000000000..c13f13a126f148fa6d23dcb80c2fae8e8ecbcf3c --- /dev/null +++ b/shape_inference.cc @@ -0,0 +1,1203 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/shape_inference.h" + +#include "tensorflow/core/framework/node_def.pb_text.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { +namespace shape_inference { + +constexpr int32 InferenceContext::kUnknownRank; +constexpr int64 InferenceContext::kUnknownDim; + +InferenceContext::InferenceContext( + int graph_def_version, const NodeDef* node_def, const OpDef& op_def, + const std::vector& input_shapes, + const std::vector& input_tensors, + const std::vector& input_tensors_as_shapes, + const std::vector< + std::unique_ptr>>>& + input_handle_shapes_and_types) + : graph_def_version_(graph_def_version), + node_def_(CHECK_NOTNULL(node_def)) { + std::vector input_tensors_as_shape_handles; + for (const TensorShapeProto& p : input_tensors_as_shapes) { + ShapeHandle shape; + construction_status_.Update(MakeShapeFromShapeProto(p, &shape)); + if (!construction_status_.ok()) { + return; + } + input_tensors_as_shape_handles.push_back(shape); + } + PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles); + if (!construction_status_.ok()) return; + for (const TensorShapeProto& p : input_shapes) { + ShapeHandle shape; + construction_status_.Update(MakeShapeFromShapeProto(p, &shape)); + if (!construction_status_.ok()) { + return; + } + inputs_.push_back(shape); + } + + std::vector>> handle_data( + input_shapes.size()); + for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) { + const auto& v = input_handle_shapes_and_types[i]; + if (v == nullptr) { + 
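+      // No shape/type information was provided for this input's resource
+      // handle, so leave its handle_data entry empty.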
continue; + } + handle_data[i].reset(new std::vector(v->size())); + auto& new_v = *handle_data[i]; + for (int j = 0; j < v->size(); ++j) { + const auto& p = (*v)[j]; + construction_status_.Update( + MakeShapeFromShapeProto(p.first, &new_v[j].shape)); + if (!construction_status_.ok()) { + return; + } + new_v[j].dtype = p.second; + } + } + PostInputInit(std::move(handle_data)); +} + +// Same as above, but with PartialTensorShape instead of TensorShapeProto +InferenceContext::InferenceContext( + int graph_def_version, const NodeDef* node_def, const OpDef& op_def, + const std::vector& input_shapes, + const std::vector& input_tensors, + const std::vector& input_tensors_as_shapes, + const std::vector< + std::unique_ptr>>>& + input_handle_shapes_and_types) + : graph_def_version_(graph_def_version), + node_def_(CHECK_NOTNULL(node_def)) { + std::vector input_tensors_as_shape_handles; + for (const PartialTensorShape& p : input_tensors_as_shapes) { + ShapeHandle shape; + construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape)); + if (!construction_status_.ok()) { + return; + } + input_tensors_as_shape_handles.push_back(shape); + } + PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles); + if (!construction_status_.ok()) return; + for (const PartialTensorShape& p : input_shapes) { + ShapeHandle shape; + construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape)); + if (!construction_status_.ok()) { + return; + } + inputs_.push_back(shape); + } + std::vector>> handle_data( + input_shapes.size()); + for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) { + const auto& v = input_handle_shapes_and_types[i]; + if (v == nullptr) { + continue; + } + handle_data[i].reset(new std::vector(v->size())); + auto& new_v = *handle_data[i]; + for (int j = 0; j < v->size(); ++j) { + const auto& p = (*v)[j]; + construction_status_.Update( + MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape)); + if (!construction_status_.ok()) { + return; + } + new_v[j].dtype = p.second; + } + } + PostInputInit(std::move(handle_data)); +} + +InferenceContext::InferenceContext( + int graph_def_version, const NodeDef* node_def, const OpDef& op_def, + const std::vector& input_shapes, + const std::vector& input_tensors, + const std::vector& input_tensors_as_shapes, + std::vector>> + input_handle_shapes_and_types) + : graph_def_version_(graph_def_version), + node_def_(CHECK_NOTNULL(node_def)) { + PreInputInit(op_def, input_tensors, input_tensors_as_shapes); + if (!construction_status_.ok()) return; + inputs_ = input_shapes; + + PostInputInit(std::move(input_handle_shapes_and_types)); +} + +InferenceContext::~InferenceContext() {} + +Status InferenceContext::Run( + const std::function& fn) { + Status s = fn(this); + if (!s.ok()) { + return AttachContext(s); + } +#ifndef NDEBUG + for (int i = 0; i < num_outputs(); ++i) { + DCHECK(output(i).IsSet()) + << i << " for " << node_def_->name() << " of type " << node_def_->op(); + } +#endif // NDEBUG + return s; +} + +Status InferenceContext::set_output(StringPiece output_name, + const std::vector& shapes) { + const auto result = output_name_map_.find(output_name.ToString()); + if (result == output_name_map_.end()) { + return errors::InvalidArgument("Unknown output name: ", output_name); + } else { + const int start = result->second.first; + const int size = result->second.second - start; + if (size != shapes.size()) { + return errors::InvalidArgument("Must have exactly ", shapes.size(), + " shapes."); + } + for (int i = 0; i < size; ++i) { + 
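+      // Write each provided shape into the slice of outputs_ reserved for
+      // this output name.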
outputs_[i + start] = shapes[i]; + } + } + return Status::OK(); +} + +Status InferenceContext::input(StringPiece input_name, + std::vector* output) const { + const auto result = input_name_map_.find(input_name.ToString()); + if (result == input_name_map_.end()) { + return errors::InvalidArgument("Unknown input name: ", input_name); + } else { + output->clear(); + for (int i = result->second.first; i < result->second.second; ++i) { + output->push_back(inputs_[i]); + } + } + return Status::OK(); +} + +Status InferenceContext::output(StringPiece output_name, + std::vector* output) const { + const auto result = output_name_map_.find(output_name.ToString()); + if (result == output_name_map_.end()) { + return errors::InvalidArgument("Unknown output name: ", output_name); + } else { + output->clear(); + for (int i = result->second.first; i < result->second.second; ++i) { + output->push_back(outputs_[i]); + } + } + return Status::OK(); +} + +string InferenceContext::op() const { return node_def_->op(); } + +void InferenceContext::PreInputInit( + const OpDef& op_def, const std::vector& input_tensors, + const std::vector& input_tensors_as_shapes) { + input_tensors_ = input_tensors; + input_tensors_as_shapes_ = input_tensors_as_shapes; + + construction_status_ = NameRangesForNode(*node_def_, op_def, &input_name_map_, + &output_name_map_); + if (!construction_status_.ok()) return; + + int num_outputs = 0; + for (const auto& e : output_name_map_) { + num_outputs = std::max(num_outputs, e.second.second); + } + for (int i = 0; i < num_outputs; ++i) { + outputs_.push_back(nullptr); + } + output_handle_shapes_and_types_.resize(num_outputs); +} + +void InferenceContext::PostInputInit( + std::vector>> input_handle_data) { + int num_inputs_from_node_def = 0; + for (const auto& e : input_name_map_) { + num_inputs_from_node_def = + std::max(num_inputs_from_node_def, e.second.second); + } + + // Allow passing empty shapes/dtypes to avoid changing every single test. 
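+  // (An empty vector here means that no handle data was supplied for any
+  // input.)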
+ if (input_handle_data.empty()) { + input_handle_shapes_and_types_.resize(inputs_.size()); + } else { + if (input_handle_data.size() != inputs_.size()) { + construction_status_ = errors::InvalidArgument( + "Wrong number of handle shapes passed; expected ", inputs_.size(), + " got ", input_handle_data.size()); + return; + } + input_handle_shapes_and_types_ = std::move(input_handle_data); + } + + if (inputs_.size() != num_inputs_from_node_def) { + construction_status_ = errors::InvalidArgument( + "Wrong number of inputs passed: ", inputs_.size(), " while ", + num_inputs_from_node_def, " expected based on NodeDef"); + return; + } + + CHECK_LE(input_tensors_.size(), inputs_.size()); + input_tensors_.resize(inputs_.size()); + requested_input_tensor_.resize(inputs_.size()); + requested_input_tensor_as_partial_shape_.resize(inputs_.size()); +} + +void InferenceContext::ShapeHandleToProto(ShapeHandle handle, + TensorShapeProto* proto) { + if (!RankKnown(handle)) { + proto->set_unknown_rank(true); + return; + } + + for (int32 i = 0; i < Rank(handle); ++i) { + DimensionHandle dim = Dim(handle, i); + auto* dim_shape = proto->add_dim(); + if (ValueKnown(dim)) { + dim_shape->set_size(Value(dim)); + } else { + dim_shape->set_size(-1); + } + } +} + +bool InferenceContext::FullyDefined(ShapeHandle s) { + if (!RankKnown(s)) return false; + for (int i = 0; i < Rank(s); ++i) { + if (!ValueKnown(Dim(s, i))) return false; + } + return true; +} + +DimensionHandle InferenceContext::NumElements(ShapeHandle s) { + const auto rank = Rank(s); + if (rank == kUnknownRank) return UnknownDim(); + int64 size = 1; + for (int i = 0; i < rank; ++i) { + int64 dim_val = Value(Dim(s, i)); + if (dim_val == kUnknownDim) return UnknownDim(); + size *= dim_val; + } + return MakeDim(size); +} + +string InferenceContext::DebugString(ShapeHandle s) { + if (RankKnown(s)) { + std::vector vals; + for (auto d : s->dims_) vals.push_back(DebugString(d)); + return strings::StrCat("[", str_util::Join(vals, ","), "]"); + } else { + return "?"; + } +} + +string InferenceContext::DebugString(DimensionHandle d) { + return ValueKnown(d) ? 
strings::StrCat(Value(d)) : "?"; +} + +string InferenceContext::DebugString() const { + return strings::StrCat("InferenceContext for node: ", + ProtoDebugString(*node_def_)); +} + +Status InferenceContext::WithRank(ShapeHandle shape, int64 rank, + ShapeHandle* out) { + if (rank > kint32max) { + return errors::InvalidArgument("Rank cannot exceed kint32max"); + } + const int32 existing = Rank(shape); + if (existing == rank) { + *out = shape; + return Status::OK(); + } + if (existing == kUnknownRank) { + std::vector dims; + dims.reserve(rank); + for (int i = 0; i < rank; ++i) { + dims.push_back(UnknownDim()); + } + ShapeHandle shp = shape_manager_.MakeShape(dims); + return Merge(shape, shp, out); + } + *out = nullptr; + + return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ", + existing); +} + +Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank, + ShapeHandle* out) { + if (rank > kint32max) { + return errors::InvalidArgument("Rank cannot exceed kint32max"); + } + const int32 existing = Rank(shape); + if (existing >= rank || existing == kUnknownRank) { + *out = shape; + return Status::OK(); + } + *out = nullptr; + return errors::InvalidArgument("Shape must be at least rank ", rank, + " but is rank ", existing); +} + +Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank, + ShapeHandle* out) { + if (rank > kint32max) { + return errors::InvalidArgument("Rank cannot exceed kint32max"); + } + const int32 existing = Rank(shape); + if (existing <= rank || existing == kUnknownRank) { + *out = shape; + return Status::OK(); + } + *out = nullptr; + return errors::InvalidArgument("Shape must be at most rank ", rank, + " but is rank ", existing); +} + +Status InferenceContext::WithValue(DimensionHandle dim, int64 value, + DimensionHandle* out) { + const int64 existing = Value(dim); + if (existing == value) { + *out = dim; + return Status::OK(); + } + if (existing == kUnknownDim) { + DimensionHandle d = MakeDim(value); + return Merge(dim, d, out); + } + *out = nullptr; + return errors::InvalidArgument("Dimension must be ", value, " but is ", + existing); +} + +void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new, + DimensionHandle* out) { + if (d_old.SameHandle(d_new)) { + *out = d_old; + } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) { + // The node will be fed by the dimension d_new instead of d_old: any + // equality assertion between d_old and other input dimension on this node + // may not be true anymore, so forget them all. + ForgetMerges(); + // Return the new shape handle to force the relaxation to propagate to the + // fanout of the context. + *out = d_new; + } else if (!ValueKnown(d_new)) { + ForgetMerges(); + *out = d_new; + } else if (Value(d_old) == Value(d_new)) { + // Return the old shape handle. This will stop the relaxation in the fanout + // of the context. + *out = d_old; + } else { + // Return a new handle that encodes a different unknown dim. 
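+    // (For example, relaxing known values 2 and 3 must produce "?", and any
+    // previously recorded equalities involving the old handles may no longer
+    // hold.)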
+ ForgetMerges(); + *out = UnknownDim(); + } +} + +Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1, + DimensionHandle* out) { + if (d0.SameHandle(d1)) { + *out = d0; + return Status::OK(); + } else if (!ValueKnown(d1)) { + *out = d0; + merged_dims_.emplace_back(d0, d1); + return Status::OK(); + } else if (!ValueKnown(d0)) { + *out = d1; + merged_dims_.emplace_back(d0, d1); + return Status::OK(); + } else if (Value(d0) == Value(d1)) { + *out = d0; + return Status::OK(); + } else { + *out = nullptr; + return errors::InvalidArgument("Dimensions must be equal, but are ", + Value(d0), " and ", Value(d1)); + } +} + +Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix, + ShapeHandle* s_out, + ShapeHandle* prefix_out) { + *s_out = *prefix_out = nullptr; + if (!RankKnown(prefix) || !RankKnown(s)) { + *s_out = s; + *prefix_out = prefix; + return Status::OK(); + } + const int32 rank = Rank(prefix); + TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s)); + + // Merge the prefix dims and create the new output shapes. + std::vector dims; + dims.resize(rank); + for (int i = 0; i < rank; ++i) { + TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i])); + } + *prefix_out = MakeShape(dims); + for (int i = rank; i < Rank(s); ++i) dims.push_back(Dim(s, i)); + *s_out = MakeShape(dims); + return Status::OK(); +} + +void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new, + ShapeHandle* out) { + if (s_old.SameHandle(s_new)) { + *out = s_old; + return; + } else if (!RankKnown(s_new) || !s_old.IsSet()) { + ForgetMerges(); + *out = s_new; + return; + } + + const int32 rank = Rank(s_old); + if (rank != Rank(s_new)) { + ForgetMerges(); + *out = UnknownShape(); + return; + } + + bool return_s_old = true; + for (int i = 0; i < rank; ++i) { + auto d0 = Dim(s_old, i); + auto d1 = Dim(s_new, i); + if (d0.SameHandle(d1)) continue; + + auto v0 = Value(d0); + auto v1 = Value(d1); + if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) { + return_s_old = false; + break; + } + } + if (return_s_old) { + *out = s_old; + return; + } + + // Relax dims. + std::vector dims(rank); + for (int i = 0; i < rank; ++i) { + Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]); + } + ForgetMerges(); + *out = MakeShape(dims); +} + +Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, + ShapeHandle* out) { + if (s0.SameHandle(s1)) { + *out = s0; + return Status::OK(); + } else if (!RankKnown(s1)) { + *out = s0; + merged_shapes_.emplace_back(s0, s1); + return Status::OK(); + } else if (!RankKnown(s0)) { + *out = s1; + merged_shapes_.emplace_back(s0, s1); + return Status::OK(); + } + + const int32 rank = Rank(s0); + if (rank != Rank(s1)) { + *out = nullptr; + return errors::InvalidArgument("Shapes must be equal rank, but are ", rank, + " and ", Rank(s1)); + } + + bool return_s0 = true; + bool return_s1 = true; + for (int i = 0; i < rank; ++i) { + auto d0 = Dim(s0, i); + auto d1 = Dim(s1, i); + if (d0.SameHandle(d1)) continue; + + auto v0 = Value(d0); + auto v1 = Value(d1); + if (v0 == kUnknownDim) { + if (v1 != kUnknownDim) { + return_s0 = false; + } + } else if (v1 == kUnknownDim) { + return_s1 = false; + } else if (v0 != v1) { + *out = nullptr; + return errors::InvalidArgument( + "Dimension ", i, " in both shapes must be equal, but are ", Value(d0), + " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ", + DebugString(s1), "."); + } + } + + merged_shapes_.emplace_back(s0, s1); + + if (return_s0 || return_s1) { + *out = return_s0 ? 
s0 : s1; + return Status::OK(); + } + + // Merge dims. + std::vector dims(rank, nullptr); + for (int i = 0; i < rank; ++i) { + // Invariant for merge was checked earlier, so CHECK is ok. + TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i])); + } + + Status s = ReturnCreatedShape(dims, out); + if (s.ok()) { + // Merge the new shape with s0. Since s0 and s1 are merged, this implies + // that s1 and out are also merged. + merged_shapes_.emplace_back(s0, *out); + } + return s; +} + +Status InferenceContext::Subshape(ShapeHandle s, int64 start, + ShapeHandle* out) { + return Subshape(s, start, std::numeric_limits::max() /* end */, out); +} + +Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in, + ShapeHandle* out) { + int64 start = start_in; + int64 end = end_in; + const int32 rank = Rank(s); + if (start == 0 && ((RankKnown(s) && end >= rank) || + end == std::numeric_limits::max())) { + *out = s; + return Status::OK(); + } + if (!RankKnown(s)) { + return ReturnUnknownShape(out); + } + + if (start > rank) start = rank; + if (end > rank) end = rank; + if (start < 0) { + start = rank + start; + if (start < 0) { + *out = nullptr; + return errors::InvalidArgument("Subshape start out of bounds: ", start_in, + ", for shape with rank ", rank); + } + } + + if (end < 0) { + end = rank + end; + if (end < 0) { + *out = nullptr; + return errors::InvalidArgument("Subshape end out of bounds: ", end_in, + ", for shape with rank ", rank); + } + } + if (start > end) { + *out = nullptr; + return errors::InvalidArgument( + "Subshape must have computed start <= end, but is ", start, " and ", + end, " (computed from start ", start_in, " and end ", end_in, + " over shape with rank ", rank, ")"); + } + std::vector dims; + dims.reserve(end - start); + for (int i = start; i < end; ++i) { + dims.push_back(Dim(s, i)); + } + return ReturnCreatedShape(dims, out); +} + +Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2, + ShapeHandle* out) { + if (!RankKnown(s1) || !RankKnown(s2)) { + return ReturnUnknownShape(out); + } + const int32 s1_rank = Rank(s1); + const int32 s2_rank = Rank(s2); + const int32 rank = s1_rank + s2_rank; + std::vector dims; + dims.reserve(rank); + for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i)); + for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i)); + return ReturnCreatedShape(dims, out); +} + +Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in, + DimensionHandle new_dim, ShapeHandle* out) { + if (!RankKnown(s)) { + return ReturnUnknownShape(out); + } + int64 dim_index = dim_index_in; + if (dim_index < 0) { + dim_index = s->dims_.size() + dim_index; + } + if (!FastBoundsCheck(dim_index, s->dims_.size())) { + *out = nullptr; + return errors::InvalidArgument("Out of range dim_index ", dim_index_in, + " for shape with ", s->dims_.size(), + " dimensions"); + } + std::vector dims(s->dims_); + dims[dim_index] = new_dim; + return ReturnCreatedShape(dims, out); +} + +ShapeHandle InferenceContext::MakeShape( + const std::vector& dims) { + return shape_manager_.MakeShape(dims); +} + +ShapeHandle InferenceContext::MakeShape( + std::initializer_list dims) { + std::vector dims_actual; + dims_actual.reserve(dims.size()); + for (const DimensionOrConstant& d : dims) { + dims_actual.push_back(MakeDim(d)); + } + + return shape_manager_.MakeShape(dims_actual); +} + +ShapeHandle InferenceContext::UnknownShape() { + return shape_manager_.UnknownShape(); +} + +ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) { + CHECK_LE(rank, 
kint32max) << "rank must be less than kint32max"; + if (rank == kUnknownRank) { + return UnknownShape(); + } + CHECK_GE(rank, 0) << "rank must not be negative"; + std::vector dims(rank); + for (int32 i = 0; i < rank; ++i) { + dims[i] = UnknownDim(); + } + return MakeShape(dims); +} + +ShapeHandle InferenceContext::Scalar() { return MakeShape({}); } + +ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) { + return MakeShape({dim}); +} + +ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1, + DimensionOrConstant dim2) { + return MakeShape({dim1, dim2}); +} + +Status InferenceContext::MakeShapeFromShapeTensor(int input_idx, + ShapeHandle* out) { + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape)); + + requested_input_tensor_as_partial_shape_[input_idx] = true; + if (input_idx < input_tensors_as_shapes_.size() && + input_tensors_as_shapes_[input_idx].IsSet() && + RankKnown(input_tensors_as_shapes_[input_idx])) { + *out = input_tensors_as_shapes_[input_idx]; + return Status::OK(); + } + + return MakeShapeFromTensor(input_tensor(input_idx), input_shape, out); +} + +Status InferenceContext::MakeShapeFromTensor(const Tensor* t, + ShapeHandle tensor_shape, + ShapeHandle* out) { + if (t == nullptr) { + // Shape tensor is not known, but if the shape of the shape tensor is then + // the right number of unknown dims can be created. + DimensionHandle shape_dim = Dim(tensor_shape, 0); + if (!ValueKnown(shape_dim)) { + return ReturnUnknownShape(out); + } + const auto num_dims = Value(shape_dim); + std::vector dims; + dims.reserve(num_dims); + for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim()); + return ReturnCreatedShape(dims, out); + } + + if (t->shape().dims() != 1) { + *out = nullptr; + return errors::InvalidArgument("Input tensor must be rank 1, but was rank ", + t->shape().dims()); + } + std::vector dims; + if (t->dtype() == DataType::DT_INT32) { + auto flat_t = t->flat(); + for (int i = 0; i < flat_t.size(); ++i) { + const int32 val = flat_t(i); + if (val < -1) { + return errors::InvalidArgument( + "Invalid value in tensor used for shape: ", val); + } + // -1 will become an unknown dim. + dims.push_back(MakeDim(val)); + } + } else if (t->dtype() == DataType::DT_INT64) { + auto flat_t = t->flat(); + for (int i = 0; i < flat_t.size(); ++i) { + const int64 val = flat_t(i); + if (val < -1) { + return errors::InvalidArgument( + "Invalid value in tensor used for shape: ", val); + } + // -1 will become an unknown dim. + dims.push_back(MakeDim(val)); + } + } else { + *out = nullptr; + return errors::InvalidArgument( + "Input tensor must be int32 or int64, but was ", + DataTypeString(t->dtype())); + } + + return ReturnCreatedShape(dims, out); +} + +Status InferenceContext::MakeShapeFromPartialTensorShape( + const PartialTensorShape& partial_shape, ShapeHandle* out) { + *out = nullptr; + if (partial_shape.dims() == -1) { + return ReturnUnknownShape(out); + } + const int num_dims = partial_shape.dims(); + std::vector dims(num_dims); + for (int i = 0; i < num_dims; ++i) { + // -1 is unknown in PartialTensorShape and in InferenceContext, so this size + // can be passed directly to MakeDim. 
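+    // (For example, a PartialTensorShape of [3, -1] becomes {3, kUnknownDim}.)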
+ dims[i] = MakeDim(partial_shape.dim_size(i)); + } + return ReturnCreatedShape(dims, out); +} + +Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape, + ShapeHandle* out) { + return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()), + out); +} + +Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto, + ShapeHandle* out) { + *out = nullptr; + TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto)); + PartialTensorShape partial_shape(proto); + return MakeShapeFromPartialTensorShape(partial_shape, out); +} + +Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) { + // Caller must ensure that is not NULL. + const int rank = t->dims(); + if (rank != 0) { + return errors::InvalidArgument("Input must be scalar but has rank ", rank); + } + + if (t->dtype() == DT_INT32) { + *val = t->scalar()(); + return Status::OK(); + } else if (t->dtype() == DT_INT64) { + *val = t->scalar()(); + return Status::OK(); + } else { + return errors::InvalidArgument( + "Scalar input for dim size must be int32 or int64"); + } +} + +// Returns a new dimension whose value is given by a scalar input tensor. +Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { + int64 val; + const Tensor* t = input_tensor(idx); + if (t == nullptr) { + *out = UnknownDim(); + return Status::OK(); + } + TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val)); + if (val < 0) { + return errors::InvalidArgument("Dimension size, given by scalar input ", + idx, ", must be non-negative but is ", val); + } + *out = MakeDim(val); + return Status::OK(); +} + +Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing( + int idx, int input_rank, DimensionHandle* out) { + int64 val; + const Tensor* t = input_tensor(idx); + if (t == nullptr) { + *out = UnknownDim(); + return Status::OK(); + } + TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val)); + if (val < 0) { + if (input_rank < 0) { + *out = UnknownDim(); + return Status::OK(); + } else if (val + input_rank < 0) { + return errors::InvalidArgument("Dimension size, given by scalar input ", + val, " must be in range [-", input_rank, + ", ", input_rank, ")"); + } else { + val += input_rank; + } + } else if (input_rank >= 0 && val >= input_rank) { + return errors::InvalidArgument("Dimension size, given by scalar input ", + val, " must be in range [-", input_rank, + ", ", input_rank, ")"); + } + *out = MakeDim(val); + return Status::OK(); +} + +Status InferenceContext::Divide(DimensionHandle dividend, + DimensionOrConstant divisor, + bool evenly_divisible, DimensionHandle* out) { + const int64 divisor_value = Value(divisor); + if (divisor_value == 1) { + *out = dividend; + } else if (!ValueKnown(dividend) || + (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) { + *out = UnknownDim(); + } else { + const int64 v = Value(dividend); + if (divisor_value <= 0) { + return errors::InvalidArgument("Divisor must be positive but is ", + divisor_value); + } + if (evenly_divisible && (v % divisor_value) != 0) { + return errors::InvalidArgument( + "Dimension size must be evenly divisible by ", divisor_value, + " but is ", v); + } + *out = MakeDim(v / divisor_value); + } + return Status::OK(); +} + +Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out) { + const int64 first_value = Value(first); + const int64 second_value = Value(second); + // Special cases. 
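+  // (For example, adding 0 leaves the other operand unchanged, and an unknown
+  // operand makes the result unknown.)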
+ if (first_value == 0) { + *out = MakeDim(second); + } else if (second_value == 0) { + *out = first; + } else if (first_value == kUnknownDim || second_value == kUnknownDim) { + *out = UnknownDim(); + } else { + // Invariant: Both values are known and positive. Still in run-time we can + // get pair of values which cannot be store in output. Check below will + // report error. We still need to avoid undefined behavior of signed + // overflow and use unsigned addition. + const int64 sum = static_cast(first_value) + second_value; + if (sum < 0) { + return errors::InvalidArgument("Dimension size overflow from adding ", + first_value, " and ", second_value); + } + *out = MakeDim(sum); + } + return Status::OK(); +} + +Status InferenceContext::Subtract(DimensionHandle first, + DimensionOrConstant second, + DimensionHandle* out) { + const int64 first_value = Value(first); + const int64 second_value = Value(second); + // Special cases. + if (second_value == 0) { + *out = first; + } else if (first_value == kUnknownDim || second_value == kUnknownDim) { + *out = UnknownDim(); + } else { + // Invariant: Both values are known, first_value is non-negative, and + // second_value is positive. + if (first_value < second_value) { + return errors::InvalidArgument( + "Negative dimension size caused by subtracting ", second_value, + " from ", first_value); + } + *out = MakeDim(first_value - second_value); + } + return Status::OK(); +} + +Status InferenceContext::Multiply(DimensionHandle first, + DimensionOrConstant second, + DimensionHandle* out) { + const int64 first_value = Value(first); + const int64 second_value = Value(second); + // Special cases. + if (first_value == 0) { + *out = first; + } else if (second_value == 0) { + *out = MakeDim(second); + } else if (first_value == 1) { + *out = MakeDim(second); + } else if (second_value == 1) { + *out = first; + } else if (first_value == kUnknownDim || second_value == kUnknownDim) { + *out = UnknownDim(); + } else { + // Invariant: Both values are known and greater than 1. + const int64 product = first_value * second_value; + if (product < 0) { + return errors::InvalidArgument( + "Negative dimension size caused by overflow when multiplying ", + first_value, " and ", second_value); + } + *out = MakeDim(product); + } + return Status::OK(); +} + +Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out) { + const int64 first_value = Value(first); + const int64 second_value = Value(second); + if (first_value == 0) { + *out = first; + } else if (second_value == 0) { + *out = MakeDim(second); + } else if (first_value == kUnknownDim || second_value == kUnknownDim) { + *out = UnknownDim(); + } else { + if (first_value <= second_value) { + *out = first; + } else { + *out = MakeDim(second); + } + } + return Status::OK(); +} + +Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out) { + const int64 first_value = Value(first); + const int64 second_value = Value(second); + if (first_value == kUnknownDim || second_value == kUnknownDim) { + *out = UnknownDim(); + } else { + if (first_value >= second_value) { + *out = first; + } else { + *out = MakeDim(second); + } + } + return Status::OK(); +} + +Status InferenceContext::AttachContext(const Status& status) { + std::vector input_shapes; + for (const ShapeHandle& input_shape : inputs_) { + input_shapes.emplace_back(DebugString(input_shape)); + } + + // Add information about the input tensors and partial tensor shapes used. 
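+  // The resulting suffix then looks roughly like (illustrative values only):
+  //   " for 'my_reshape' (op: 'Reshape') with input shapes: [2,8], [2] and
+  //   with computed input tensors: input[1] = <4 4>."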
+ std::vector input_from_tensors_str; + std::vector input_from_tensors_as_shape_str; + for (int i = 0; i < inputs_.size(); ++i) { + if (requested_input_tensor_as_partial_shape_[i] && + i < input_tensors_as_shapes_.size() && + input_tensors_as_shapes_[i].IsSet() && + RankKnown(input_tensors_as_shapes_[i])) { + input_from_tensors_as_shape_str.push_back(strings::StrCat( + "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i]))); + } else if (requested_input_tensor_[i] && i < input_tensors_.size() && + input_tensors_[i] != nullptr) { + input_from_tensors_str.push_back(strings::StrCat( + "input[", i, "] = <", + input_tensors_[i]->SummarizeValue(256 /* max_values */), ">")); + } + } + + string error_context = strings::StrCat( + " for '", node_def_->name(), "' (op: '", node_def_->op(), + "') with input shapes: ", str_util::Join(input_shapes, ", ")); + if (!input_from_tensors_str.empty()) { + strings::StrAppend(&error_context, " and with computed input tensors: ", + str_util::Join(input_from_tensors_str, ", ")); + } + if (!input_from_tensors_as_shape_str.empty()) { + strings::StrAppend(&error_context, + " and with input tensors computed as partial shapes: ", + str_util::Join(input_from_tensors_as_shape_str, ",")); + } + + strings::StrAppend(&error_context, "."); + return Status(status.code(), + strings::StrCat(status.error_message(), error_context)); +} + +bool InferenceContext::MergeHandleShapesAndTypes( + const std::vector& shapes_and_types, + std::vector* to_update) { + if (shapes_and_types.size() != to_update->size()) { + return false; + } + std::vector new_values(shapes_and_types.size()); + bool refined = false; + for (int i = 0; i < shapes_and_types.size(); ++i) { + const ShapeAndType& existing = (*to_update)[i]; + if (shapes_and_types[i].dtype == existing.dtype) { + new_values[i].dtype = existing.dtype; + } else { + if (existing.dtype != DT_INVALID) { + return false; + } else { + new_values[i].dtype = shapes_and_types[i].dtype; + refined = true; + } + } + if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape) + .ok()) { + // merge failed, ignore the new value. 
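+      // Keep the previously stored shape for this handle instead.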
+ new_values[i].shape = existing.shape; + } + if (!existing.shape.SameHandle(new_values[i].shape)) { + refined = true; + } + } + if (!refined) { + return false; + } + for (int i = 0; i < new_values.size(); ++i) { + (*to_update)[i] = new_values[i]; + } + return true; +} + +bool InferenceContext::MergeOutputHandleShapesAndTypes( + int idx, const std::vector& shapes_and_types) { + if (output_handle_shapes_and_types_[idx] == nullptr) { + output_handle_shapes_and_types_[idx].reset( + new std::vector(shapes_and_types)); + return true; + } + return MergeHandleShapesAndTypes(shapes_and_types, + output_handle_shapes_and_types_[idx].get()); +} + +bool InferenceContext::MergeInputHandleShapesAndTypes( + int idx, const std::vector& shapes_and_types) { + if (input_handle_shapes_and_types_[idx] == nullptr) { + input_handle_shapes_and_types_[idx].reset( + new std::vector(shapes_and_types)); + return true; + } + return MergeHandleShapesAndTypes(shapes_and_types, + input_handle_shapes_and_types_[idx].get()); +} + +bool InferenceContext::RelaxHandleShapesAndMergeTypes( + const std::vector& shapes_and_types, + std::vector* to_update) { + if (shapes_and_types.size() != to_update->size()) { + return false; + } + std::vector new_values(shapes_and_types.size()); + bool refined = false; + for (int i = 0; i < shapes_and_types.size(); ++i) { + const ShapeAndType& existing = (*to_update)[i]; + if (shapes_and_types[i].dtype == existing.dtype) { + new_values[i].dtype = existing.dtype; + } else { + if (existing.dtype != DT_INVALID) { + return false; + } else { + new_values[i].dtype = shapes_and_types[i].dtype; + refined = true; + } + } + Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape); + if (!existing.shape.SameHandle(new_values[i].shape)) { + refined = true; + } + } + if (!refined) { + return false; + } + for (int i = 0; i < new_values.size(); ++i) { + (*to_update)[i] = new_values[i]; + } + return true; +} + +bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes( + int idx, const std::vector& shapes_and_types) { + if (output_handle_shapes_and_types_[idx] == nullptr) { + output_handle_shapes_and_types_[idx].reset( + new std::vector(shapes_and_types)); + return true; + } + return RelaxHandleShapesAndMergeTypes( + shapes_and_types, output_handle_shapes_and_types_[idx].get()); +} + +bool InferenceContext::RelaxInputHandleShapesAndMergeTypes( + int idx, const std::vector& shapes_and_types) { + if (input_handle_shapes_and_types_[idx] == nullptr) { + input_handle_shapes_and_types_[idx].reset( + new std::vector(shapes_and_types)); + return true; + } + return RelaxHandleShapesAndMergeTypes( + shapes_and_types, input_handle_shapes_and_types_[idx].get()); +} + +// ----------------------------------------------------------------------------- +// ShapeManager +// ----------------------------------------------------------------------------- +InferenceContext::ShapeManager::ShapeManager() {} +InferenceContext::ShapeManager::~ShapeManager() { + for (auto* s : all_shapes_) delete s; + for (auto* d : all_dims_) delete d; +} + +ShapeHandle InferenceContext::ShapeManager::MakeShape( + const std::vector& dims) { + all_shapes_.push_back(new Shape(dims)); + return all_shapes_.back(); +} + +ShapeHandle InferenceContext::ShapeManager::UnknownShape() { + all_shapes_.push_back(new Shape()); + return all_shapes_.back(); +} + +} // namespace shape_inference +} // namespace tensorflow diff --git a/shape_inference.h b/shape_inference.h new file mode 100644 index 
0000000000000000000000000000000000000000..4a4ef12635f867fccb594d50a2c9e8f3059ce337 --- /dev/null +++ b/shape_inference.h @@ -0,0 +1,790 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ + +#include + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +class ShapeRefiner; +class ShapeRefinerTest; + +namespace grappler { +class GraphProperties; +class SymbolicShapeManager; +} + +namespace shape_inference { + +struct DimensionOrConstant; +class InferenceContext; + +// Dimension values are accessed through InferenceContext. +class Dimension { + private: + Dimension(); + Dimension(int64 value); + ~Dimension() {} + + const int64 value_; + + friend class InferenceContext; + friend class ShapeManager; + TF_DISALLOW_COPY_AND_ASSIGN(Dimension); +}; + +class DimensionHandle { + public: + DimensionHandle() {} + bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; } + std::size_t Handle() const { return reinterpret_cast(ptr_); } + + private: + DimensionHandle(const Dimension* dim) { ptr_ = dim; } + + const Dimension* operator->() const { return ptr_; } + bool IsSet() const { return ptr_ != nullptr; } + + const Dimension* ptr_ = nullptr; + + friend struct DimensionOrConstant; + friend class InferenceContext; + friend class ShapeInferenceTest; + friend class ShapeInferenceTestutil; + friend class ::tensorflow::ShapeRefinerTest; + friend class ShapeManager; + friend class ::tensorflow::grappler::GraphProperties; + friend class ::tensorflow::grappler::SymbolicShapeManager; + + // Intentionally copyable. +}; + +// Shape rank and dimensions are accessed through InferenceContext. 
+class Shape { + private: + Shape(); + Shape(const std::vector& dims); + ~Shape() {} + + const int32 rank_; + const std::vector dims_; + + friend class InferenceContext; + friend class ShapeManager; + friend class ::tensorflow::grappler::SymbolicShapeManager; + + TF_DISALLOW_COPY_AND_ASSIGN(Shape); +}; + +class ShapeHandle { + public: + ShapeHandle() {} + bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; } + std::size_t Handle() const { return reinterpret_cast(ptr_); } + + private: + ShapeHandle(const Shape* shape) { ptr_ = shape; } + const Shape* operator->() const { return ptr_; } + bool IsSet() const { return ptr_ != nullptr; } + + const Shape* ptr_ = nullptr; + + friend class InferenceContext; + friend class ShapeInferenceTest; + friend class ShapeInferenceTestutil; + friend class ::tensorflow::ShapeRefinerTest; + friend class ShapeManager; + friend class ::tensorflow::grappler::SymbolicShapeManager; + + // Intentionally copyable. +}; + +// Struct used to allow functions to take DimensionHandle or a dimension value. +// Not meant to be constructed directly. +struct DimensionOrConstant { + public: + // Intentionally not explicit. + DimensionOrConstant(DimensionHandle dim); + + // val must be non-negative or InferenceContext::kUnknownDim. + DimensionOrConstant(int64 val); + + // dim takes precedence. If dim != nullptr, val is ignored. + DimensionHandle dim; + int64 val; + + private: + DimensionOrConstant(); +}; + +struct ShapeAndType { + ShapeAndType() {} + ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {} + + ShapeHandle shape; + DataType dtype = DT_INVALID; +}; + +// Shape inference functions registered on ops in REGISTER_OP implement +// their shape functions in terms of this InferenceContext. An InferenceContext +// is created by the framework and passed to a shape inference function. The +// shape inference function calls functions on the context, and should call +// set_output() to set the shape on all outputs. +// +// To infer shapes for user-defined functions see ShapeRefiner. +// +// All Shape* and Dimension* returned by functions of InferenceContext are owned +// by the InferenceContext. +class InferenceContext { + public: + static constexpr int64 kUnknownDim = -1; + static constexpr int32 kUnknownRank = -1; + + // is NULL-padded to be the same size as . + // + // Elements of are used for when a shape function + // makes a call to MakeShapeFromShapeTensor; in particular, when the + // input_tensors[i] is nullptr but the shape represented by it is partially + // known from analysis of the graph. + // can have fewer elements than . + // Values of do not need to outlive the context. + // + // REQUIRES: is not NULL, and must outlive the InferenceContext. + InferenceContext(int graph_def_version, const NodeDef* node_def, + const OpDef& op_def, + const std::vector& input_shapes, + const std::vector& input_tensors, + const std::vector& input_tensors_as_shapes, + std::vector>> + input_handle_shapes_and_types); + + // is NULL-padded to be the same size as . + // + // Elements of are used for when a shape + // function makes a call to MakeShapeFromShapeTensor; in particular, when + // the input_tensors[i] is nullptr but the shape represented by it is + // partially known from analysis of the graph. + // can have fewer elements than . Values of + // do not need to outlive the context. + // + // REQUIRES: is not NULL, and must outlive the + // InferenceContext. 
+ InferenceContext( + int graph_def_version, const NodeDef* node_def, const OpDef& op_def, + const std::vector& input_shapes, + const std::vector& input_tensors, + const std::vector& input_tensors_as_shapes, + const std::vector< + std::unique_ptr>>>& + input_handle_shapes_and_types); + + // is NULL-padded to be the same size as . + // + // Elements of are used for when a shape + // function makes a call to MakeShapeFromShapeTensor; in particular, when + // the input_tensors[i] is nullptr but the shape represented by it is + // partially known from analysis of the graph. + // can have fewer elements than . Values of + // do not need to outlive the context. + // + // REQUIRES: is not NULL, and must outlive the + // InferenceContext. + InferenceContext( + int graph_def_version, const NodeDef* node_def, const OpDef& op_def, + const std::vector& input_shapes, + const std::vector& input_tensors, + const std::vector& input_tensors_as_shapes, + const std::vector>>>& + input_handle_shapes_and_types); + + ~InferenceContext(); + + // Runs the shape inference function 'fn' with 'this' as the + // argument, returns the status of the inference. + // + // On error, additional context is provided in the error message. + Status Run( + const std::function& fn); + + // Merge the stored shape of the input in position idx with according + // to the following rules: + // + // - If the ShapeHandles are the same or is unknown, there will be no + // change. Otherwise if the stored shape is unknown, the new shape will be + // . + // - If both shapes are known, then they must have the same rank. + // - For any one dimension, if the values for that dimension in both shapes + // are known, then the values must match. + // - If one shape has equal or more information than the other shape in every + // dimension, the new shape will become the shape with more information. + // - Example: merging [2,?] and [?,2] results in [2,2] + // - Example: [2,2] cannot be merged with [1,2] + // + // This requires idx to be in the [0, num_inputs) range. If the merge is + // successful, return true. Return false otherwise. + bool MergeInput(int idx, ShapeHandle shape) { + ShapeHandle new_shape; + if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false; + inputs_[idx] = new_shape; + return true; + } + + // Relax the stored shape of the input in position idx with according + // to the following rules: + // + // - If the ShapeHandles are the same then the stored shape will be returned. + // - If either of the ShapeHandles are unknown, then a new UnknownShape will + // be returned. A new shape must be returned because we cannot claim that + // the resulting shape is necessarily the same as either of the input + // shapes. + // - If the shapes both have known ranks but their ranks are different, a new + // UnknownShape will be returned. + // - For any one dimension, if the value for that dimension in either of the + // shapes is unknown, a new shape will be returned with a new UnknownDim in + // that dimension. + // - For any one dimension, if the values for that dimension in both shapes + // are known but do not match, a new shape will be returned with a new + // UnknownDim in that dimension. + // - If both shapes have the same known rank and match in every dimension, + // the stored shape will be returned. + // - Example: relaxing [2,?] and [?,2] results in [?,?] + // - Example: relaxing [2,2] and [3,2] results in [?,2] + // - Example: relaxing [2,2] with [1,2,3] results in ? 
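+  // - Example: relaxing [2,?] and [2,2] results in [2,?]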
+ // + // This requires idx to be in the [0, num_inputs) range. If the relax is + // successful and the new shape differs from the old one, store the new + // shape and return true. Return false otherwise. + bool RelaxInput(int idx, ShapeHandle shape) { + ShapeHandle new_shape; + Relax(inputs_[idx], shape, &new_shape); + if (inputs_[idx].SameHandle(new_shape)) { + return false; + } + inputs_[idx] = new_shape; + return true; + } + + ShapeHandle input(int64 idx) const { return inputs_[idx]; } + Status input(StringPiece input_name, std::vector* output) const; + int num_inputs() const { return inputs_.size(); } + + // Returns the input tensor at index , or nullptr if the input tensor is + // not available at the time of shape inference. + const Tensor* input_tensor(int idx) { + // Mark that this idx was requested. + requested_input_tensor_[idx] = true; + return input_tensors_[idx]; + } + + // Returns true iff input_tensor(idx) was called by the shape function. + bool requested_input_tensor(int idx) const { + return requested_input_tensor_[idx]; + } + + // Returns true if MakeShapeFromInputTensor was called but the constant + // input_tensor was not present. + bool requested_input_tensor_as_partial_shape(int idx) const { + return requested_input_tensor_as_partial_shape_[idx]; + } + + void set_input_tensors(const std::vector& input_tensors) { + input_tensors_ = input_tensors; + } + + void set_input_tensors_as_shapes( + const std::vector& input_tensors_as_shapes) { + input_tensors_as_shapes_ = input_tensors_as_shapes; + } + + void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; } + Status set_output(StringPiece output_name, + const std::vector& shapes); + + int num_outputs() const { return outputs_.size(); } + ShapeHandle output(int idx) const { return outputs_[idx]; } + Status output(StringPiece output_name, + std::vector* output) const; + + AttrSlice attrs() const { return AttrSlice(*node_def_); } + + string op() const; + + // idx can be negative for an offset from end of dimensions. + // idx must be in the range [-1 * s.rank, s.rank). + DimensionHandle Dim(ShapeHandle s, int64 idx) { + if (s->rank_ == kUnknownRank) { + return UnknownDim(); + } + return DimKnownRank(s, idx); + } + // As above, but asserts that the rank of the shape is known. + static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) { + CHECK_NE(s->rank_, kUnknownRank); + if (idx < 0) { + return s->dims_[s->dims_.size() + idx]; + } + return s->dims_[idx]; + } + + static int32 Rank(ShapeHandle s) { + DCHECK(s.IsSet()); + return s.IsSet() ? s->rank_ : kUnknownRank; + } + static bool RankKnown(ShapeHandle s) { + return (s.IsSet() && (Rank(s) != kUnknownRank)); + } + static inline int64 Value(DimensionOrConstant d) { + return d.dim.IsSet() ? d.dim->value_ : d.val; + } + static inline bool ValueKnown(DimensionOrConstant d) { + return Value(d) != kUnknownDim; + } + + // Fills the output proto with the shape defined by the handle. + // "proto" is expected to be empty prior to the call. + void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto); + + // Returns true if the rank and all dimensions of the Shape are known. + bool FullyDefined(ShapeHandle s); + + // Returns the total number of elements, or an unknown dimension for an + // incomplete shape. + DimensionHandle NumElements(ShapeHandle s); + + string DebugString(ShapeHandle s); + string DebugString(DimensionHandle d); + + // Describes the whole context, for debugging purposes. 
+ string DebugString() const; + + // If has rank , or its rank is unknown, return OK and return + // the shape with asserted rank in <*out>. Otherwise return an error. + // + // Note that <*out> may be set to . + Status WithRank(ShapeHandle shape, int64 rank, + ShapeHandle* out) TF_MUST_USE_RESULT; + Status WithRankAtLeast(ShapeHandle shape, int64 rank, + ShapeHandle* out) TF_MUST_USE_RESULT; + Status WithRankAtMost(ShapeHandle shape, int64 rank, + ShapeHandle* out) TF_MUST_USE_RESULT; + + // If has value , or its value is unknown, returns OK and returns + // the dimension with asserted value in <*out>. Otherwise returns an error. + // + // Note that <*out> may be set to . + Status WithValue(DimensionHandle dim, int64 value, + DimensionHandle* out) TF_MUST_USE_RESULT; + + // Merges and and returns the merged shape in <*out>. See + // 'MergeInput' function for full details and examples. + Status Merge(ShapeHandle s0, ShapeHandle s1, + ShapeHandle* out) TF_MUST_USE_RESULT; + + // Asserts that 's rank >= 's rank, and the first + // dimensions of are compatible with the dimensions of + // . + // Returns the merged results in <*s_out> and <*prefix_out>. + Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out, + ShapeHandle* prefix_out) TF_MUST_USE_RESULT; + + // Merges and and returns the merged dimension in <*out>. If + // and have incompatible values, returns an error. + // + // Note that <*out> may be set to or . + Status Merge(DimensionHandle d0, DimensionHandle d1, + DimensionHandle* out) TF_MUST_USE_RESULT; + + // Returns in <*out> a sub-shape of with dimensions [start:]. + // can be negative to index from the end of the shape. If > + // rank of , then an empty subshape is returned. + Status Subshape(ShapeHandle s, int64 start, + ShapeHandle* out) TF_MUST_USE_RESULT; + + // Returns in <*out> a sub-shape of , with dimensions [start:end]. + // and can be negative, to index from the end of the shape. + // and are set to the rank of if > rank of . + Status Subshape(ShapeHandle s, int64 start, int64 end, + ShapeHandle* out) TF_MUST_USE_RESULT; + + // Returns in <*out> the result of appending the dimensions of to those + // of . + Status Concatenate(ShapeHandle s1, ShapeHandle s2, + ShapeHandle* out) TF_MUST_USE_RESULT; + + // Returns in the shape from replacing with + // . + Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim, + ShapeHandle* out) TF_MUST_USE_RESULT; + + // Returns a new shape with the given dims. The returned value is owned by + // this context. + ShapeHandle MakeShape(const std::vector& dims); + ShapeHandle MakeShape(std::initializer_list dims); + + // Returns a new unknown shape. + ShapeHandle UnknownShape(); + + // Returns a shape with specified rank but unknown dims. + ShapeHandle UnknownShapeOfRank(int64 rank); + + // Returns a new shape of zero dimensions. + ShapeHandle Scalar(); + + // Returns a new shape of one dimension. + ShapeHandle Vector(DimensionOrConstant dim); + + // Returns a new shape of two dimensions. + ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2); + + // Returns in a new shape whose dimension sizes come from input tensor + // . The tensor must be a 1-dimensional int32 or int64 tensor. If + // the input tensor is NULL, then an unknown shape is returned. + Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out); + + // Returns in a new shape corresponding to . 
+ Status MakeShapeFromShapeProto(const TensorShapeProto& proto, + ShapeHandle* out); + + // Returns in a new shape corresponding to . + Status MakeShapeFromPartialTensorShape( + const PartialTensorShape& partial_shape, ShapeHandle* out); + + // Returns in a new shape corresponding to . + Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out); + + // Returns a new dimension of the given size. The returned value is owned by + // this context. + inline DimensionHandle MakeDim(DimensionOrConstant d) { + return shape_manager_.MakeDim(d); + } + + inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); } + + // Returns in a scalar value from an input tensor . The input tensor + // must be a 1-dimensional int32 or int64 tensor. Caller must ensure that the + // input tensor is not NULL. + Status GetScalarFromTensor(const Tensor* t, int64* val); + + // Returns a new dimension whose value is given by a scalar input tensor. + // The input tensor must be in host memory, since it is dereferenced to get + // the value. + Status MakeDimForScalarInput(int idx, DimensionHandle* out); + + // Returns a new dimension whose value is given by a scalar input tensor. + // This allows for a negative input dimension given the rank of a separate + // tensor. This rank can be negative if unknown. + // The input tensor must be in host memory, since it is dereferenced to get + // the value. + Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank, + DimensionHandle* out); + + // Look up the attr for the NodeDef being evaluated with name attr_name and + // set *value to its value. If no attr with attr_name is found in def(), or + // the attr does not have a matching type, a non-ok status will be returned. + template + Status GetAttr(StringPiece attr_name, T* value) const; + + // Returns in the result of dividing by . + // Returns an error if is not positive or if + // and does not evenly divide . + Status Divide(DimensionHandle dividend, DimensionOrConstant divisor, + bool evenly_divisible, DimensionHandle* out); + + // Returns in the sum of and . + Status Add(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); + + // Returns in the dimension that is minus . + Status Subtract(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); + + // Returns in the product of and . + Status Multiply(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); + + // Returns in the minimum of and . If either or + // is zero the results is zero. Otherwise, if either or + // is unknown the results is unknown. + Status Min(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); + + // Returns in the maximum of and . If either or + // is unknown the results is unknown. + Status Max(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); + + Status construction_status() const { return construction_status_; } + + // Methods to propagate shape and dtype on edges of handles. Handles are the + // dtype DT_RESOURCE which can be used to access state stored in a + // ResourceManager. When ops (such as variables) consume these handles to + // produce tensors they might need to know side-information about the shapes + // and dtypes of tensors which can be accessed via the handle. These methods + // propagate that information. Output handle dtypes and shapes are ignored if + // the output tensor is not of type DT_RESOURCE. 
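+  //
+  // For illustration only (the op and sizes are made up): a shape function
+  // for a variable-like op whose resource handle refers to a float matrix
+  // could record that side-information as
+  //   c->set_output_handle_shapes_and_types(0, {{c->Matrix(128, 64), DT_FLOAT}});
+  // and a shape function reading through that handle could later consult
+  // c->input_handle_shapes_and_types(0) to recover the shape and dtype.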
+
+  // Merge the stored shapes and types corresponding to the input handle in
+  // position idx with the specified shapes and types. This requires idx to be
+  // in the [0, num_inputs) range.
+  //
+  // If the merge is successful and any of the new shapes differs from the old
+  // one, or any of the old dtypes was DT_INVALID, store the new shapes and
+  // return true. Return false otherwise.
+  //
+  // See 'MergeInput' function for full details and examples.
+  bool MergeInputHandleShapesAndTypes(
+      int idx,
+      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
+
+  // As MergeInputHandleShapesAndTypes, but for an output.
+  bool MergeOutputHandleShapesAndTypes(
+      int idx,
+      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
+
+  // Relaxes the stored shapes and types corresponding to the input handle in
+  // position idx with the specified shapes and types. This requires idx to be
+  // in the [0, num_inputs) range.
+  //
+  // If the relax is successful and any of the new shapes differs from the old
+  // one, or any of the old dtypes was DT_INVALID, store the new shapes and
+  // return true. Return false otherwise.
+  //
+  // See 'RelaxInput' function for full details and examples.
+  bool RelaxInputHandleShapesAndMergeTypes(
+      int idx,
+      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
+
+  // As RelaxInputHandleShapesAndTypes, but for an output.
+  bool RelaxOutputHandleShapesAndMergeTypes(
+      int idx,
+      const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
+
+  // Returns the output handle shapes and types, for the resource tensor
+  // output at index <idx>. Returns NULL if the shapes and types were never
+  // set.
+  const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
+    return output_handle_shapes_and_types_[idx].get();
+  }
+
+  // Returns the input handle shapes and types, for the resource tensor input
+  // at index <idx>. Returns NULL if the shapes and types were not available.
+  const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
+    return input_handle_shapes_and_types_[idx].get();
+  }
+
+  void set_output_handle_shapes_and_types(
+      int idx, const std::vector<ShapeAndType>& shapes_and_types) {
+    output_handle_shapes_and_types_[idx].reset(
+        new std::vector<ShapeAndType>(shapes_and_types));
+  }
+
+  // Note that shape functions should usually call MakeShapeFromShapeTensor,
+  // as it does more analysis to provide partial shapes.
+  //
+  // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
+  // The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is NULL,
+  // then an unknown shape is returned.
+  Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
+                             ShapeHandle* out);
+
+  int graph_def_version() const { return graph_def_version_; }
+
+  const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
+    return merged_shapes_;
+  }
+  const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
+      const {
+    return merged_dims_;
+  }
+
+ private:
+  // Creates and stores shapes for use in InferenceContext.
+  class ShapeManager {
+   public:
+    ShapeManager();
+    ~ShapeManager();
+
+    // Returns a new shape with the given dims. The returned value is owned by
+    // this class.
+    ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
+
+    // Returns a new unknown shape.
+    ShapeHandle UnknownShape();
+
+    // Returns a new dimension of the given size. The returned value
+    // is owned by this class.
+    inline DimensionHandle MakeDim(DimensionOrConstant d) {
+      if (d.dim.IsSet()) {
+        return d.dim;
+      } else {
+        all_dims_.push_back(new Dimension(d.val));
+        return all_dims_.back();
+      }
+    }
+
+   private:
+    std::vector<Shape*> all_shapes_;  // values are owned.
+    std::vector<Dimension*> all_dims_;  // values are owned.
+  };
+
+  friend class ::tensorflow::grappler::GraphProperties;
+
+  // Friend for user-defined function shape inference purposes.
+  friend class ::tensorflow::ShapeRefiner;
+
+  friend class ShapeInferenceTest;      // For testing Relax functions.
+  friend class ShapeInferenceTestutil;  // For testing shapes.
+
+  // Shared initialization across the two constructors. Remove
+  // once we get rid of one of them.
+  void PreInputInit(const OpDef& op_def,
+                    const std::vector<const Tensor*>& input_tensors,
+                    const std::vector<ShapeHandle>& input_tensors_as_shapes);
+  void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
+                         input_handle_data);
+
+  DimensionHandle GetDimension(const DimensionOrConstant& d);
+
+  Status ReturnUnknownShape(ShapeHandle* out) {
+    *out = UnknownShape();
+    return Status::OK();
+  }
+  Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
+                            ShapeHandle* out) {
+    *out = MakeShape(dims);
+    return Status::OK();
+  }
+
+  // Adds additional context to the given status.
+  Status AttachContext(const Status& status);
+
+  // Relaxes an existing value <d_old> with a new value <d_new> and returns
+  // the relaxed dimension in <*out>. If <d_old> and <d_new> have incompatible
+  // values, <*out> is set to a new unknown dimension.
+  //
+  // Note that <*out> may be set to <d_old> or <d_new>.
+  void Relax(DimensionHandle d_old, DimensionHandle d_new,
+             DimensionHandle* out);
+  // Relaxes an existing shape <s_old> with a new shape <s_new> and returns
+  // the relaxed shape in <*out>. See 'RelaxInput' function for full details
+  // and examples.
+  void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);
+
+  // Used to implement MergeInputHandleShapesAndTypes and
+  // MergeOutputHandleShapesAndTypes.
+  bool MergeHandleShapesAndTypes(
+      const std::vector<ShapeAndType>& shapes_and_types,
+      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
+  // Used to implement RelaxInputHandleShapesAndMergeTypes and
+  // RelaxOutputHandleShapesAndMergeTypes.
+  bool RelaxHandleShapesAndMergeTypes(
+      const std::vector<ShapeAndType>& shapes_and_types,
+      std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
+
+  // Forget all the previous merged shapes and dims.
+  void ForgetMerges() {
+    merged_shapes_.clear();
+    merged_dims_.clear();
+  }
+
+  ShapeManager shape_manager_;
+
+  // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
+  // `shape_manager_`.
+  std::vector<ShapeHandle> inputs_;
+  std::vector<const Tensor*> input_tensors_;
+  std::vector<bool> requested_input_tensor_;
+  std::vector<ShapeHandle> outputs_;
+  // Can have fewer elements than inputs_.
+  std::vector<ShapeHandle> input_tensors_as_shapes_;
+  std::vector<bool> requested_input_tensor_as_partial_shape_;
+
+  // input_handle_shapes_and_types_[i] is the list of shape/type pairs
+  // available through the resource handle passed along input i of the node.
+  //
+  // Values may be NULL.
+  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
+      input_handle_shapes_and_types_;
+
+  // output_handle_shapes_and_types_[i] is the list of shape/type pairs
+  // available through the resource handle passed along output i of the node.
+  //
+  // Values may be NULL.
+  std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
+      output_handle_shapes_and_types_;
+
+  const int graph_def_version_;
+  const NodeDef* node_def_;
+  NameRangeMap input_name_map_;
+  NameRangeMap output_name_map_;
+
+  // An error set during construction. TODO(cwhipkey): remove when test
+  // constructor is removed.
+  Status construction_status_;
+
+  // Pairs of shape or dim handles that are equivalent, i.e. that represent
+  // the same underlying shape or dimension. Note that for each pair at least
+  // one of the handles must contain an unknown shape, since we don't keep
+  // track of known shapes or dims here.
+  std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
+  std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
+};
+
+// -----------------------------------------------------------------------------
+// Template and inline method implementations, please ignore
+
+inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
+inline Dimension::Dimension(int64 value) : value_(value) {
+  DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
+      << "Dimension must be non-negative or equal to "
+         "InferenceContext::kUnknownDim but got "
+      << value;
+}
+
+inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
+inline Shape::Shape(const std::vector<DimensionHandle>& dims)
+    : rank_(dims.size()), dims_(dims) {}
+
+inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
+    : dim(dim) {
+  DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
+}
+
+inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
+  DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
+      << "Dimension must be non-negative or equal to "
+         "InferenceContext::kUnknownDim but got "
+      << val;
+}
+
+template <class T>
+Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
+  return GetNodeAttr(*node_def_, attr_name, value);
+}
+
+}  // namespace shape_inference
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
diff --git a/shape_inference_test.cc b/shape_inference_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a9b63ca60e4574bb0d59c4b939ac157e62f317e8
--- /dev/null
+++ b/shape_inference_test.cc
@@ -0,0 +1,1819 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/shape_inference.h"
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace shape_inference {
+namespace {
+
+OpDef MakeOpDefWithLists() {
+  OpRegistrationData op_reg_data;
+  OpDefBuilder b("dummy");
+  b.Input(strings::StrCat("input: N * float"));
+  b.Output(strings::StrCat("output: N * float"));
+  CHECK(b.Attr("N:int >= 1").Finalize(&op_reg_data).ok());
+  return op_reg_data.op_def;
+}
+
+PartialTensorShape S(std::initializer_list<int64> dims) {
+  return PartialTensorShape(dims);
+}
+
+PartialTensorShape Unknown() { return PartialTensorShape(); }
+
+}  // namespace
+
+class ShapeInferenceTest : public ::testing::Test {
+ protected:
+  // These give access to private functions of DimensionHandle and ShapeHandle.
+ bool SameHandle(DimensionHandle a, DimensionHandle b) { + return a.SameHandle(b); + } + bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); } + bool IsSet(DimensionHandle d) { return d.IsSet(); } + bool IsSet(ShapeHandle s) { return s.IsSet(); } + void Relax(InferenceContext* c, DimensionHandle d0, DimensionHandle d1, + DimensionHandle* out) { + c->Relax(d0, d1, out); + } + void Relax(InferenceContext* c, ShapeHandle s0, ShapeHandle s1, + ShapeHandle* out) { + c->Relax(s0, s1, out); + } + void TestMergeHandles(bool input_not_output); + void TestRelaxHandles(bool input_not_output); + + static const int kVersion = 0; // used for graph-def version. +}; + +TEST_F(ShapeInferenceTest, InputOutputByName) { + // Setup test to contain an input tensor list of size 3. + OpDef op_def = MakeOpDefWithLists(); + NodeDef def; + auto s = NodeDefBuilder("dummy", &op_def) + .Attr("N", 3) + .Input(FakeInput(DT_FLOAT)) + .Finalize(&def); + InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, + {}, {}, {}); + + EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0)))); + EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1)))); + EXPECT_EQ("3", c.DebugString(c.NumElements(c.input(2)))); + // Test getters. + std::vector shapes; + EXPECT_FALSE(c.input("nonexistent", &shapes).ok()); + TF_EXPECT_OK(c.input("input", &shapes)); + EXPECT_EQ("[1,5]", c.DebugString(shapes[0])); + EXPECT_EQ("[2,5]", c.DebugString(shapes[1])); + EXPECT_EQ("[1,3]", c.DebugString(shapes[2])); + + // Test setters. + EXPECT_FALSE(c.set_output("nonexistent", shapes).ok()); + TF_EXPECT_OK(c.set_output("output", shapes)); + EXPECT_EQ("5", c.DebugString(c.NumElements(c.output(0)))); + EXPECT_EQ("10", c.DebugString(c.NumElements(c.output(1)))); + EXPECT_EQ("3", c.DebugString(c.NumElements(c.output(2)))); +} + +static OpDef MakeOpDef(int num_inputs, int num_outputs) { + OpRegistrationData op_reg_data; + OpDefBuilder b("dummy"); + for (int i = 0; i < num_inputs; ++i) { + b.Input(strings::StrCat("i", i, ": float")); + } + for (int i = 0; i < num_outputs; ++i) { + b.Output(strings::StrCat("o", i, ": float")); + } + CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok()); + return op_reg_data.op_def; +} + +TEST_F(ShapeInferenceTest, DimensionOrConstant) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}); + EXPECT_EQ(InferenceContext::kUnknownDim, + c.Value(InferenceContext::kUnknownDim)); + EXPECT_EQ(1, c.Value(1)); + +#ifndef NDEBUG + // Only run death test if DCHECKS are enabled. + EXPECT_DEATH(c.Value(-7), "Dimension must be non\\-negative or equal to"); +#endif +} + +TEST_F(ShapeInferenceTest, Run) { + NodeDef def; + def.set_name("foo"); + def.set_op("foo_op"); + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}); + TF_ASSERT_OK(c.construction_status()); + + { + auto fn = [](InferenceContext* c) { + ShapeHandle h; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 6, &h)); + c->set_output(0, c->input(0)); + c->set_output(1, c->input(0)); + return Status::OK(); + }; + TF_ASSERT_OK(c.Run(fn)); + } + + { + auto fn = [](InferenceContext* c) { + ShapeHandle h; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); + c->set_output(0, c->input(0)); + c->set_output(1, c->input(0)); + return Status::OK(); + }; + Status s = c.Run(fn); + // Extra error message is attached when Run fails. 
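+    // (Run wraps the shape function's status via AttachContext, so the node
+    // name and op type appear in the message checked below.)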
+ EXPECT_TRUE(StringPiece(s.ToString()) + .contains("Shape must be at most rank 0 but " + "is rank 1 for 'foo' (op: " + "'foo_op')")) + << s; + } +} + +// Tests different context data added when Run returns error. +TEST_F(ShapeInferenceTest, AttachContext) { + NodeDef def; + def.set_name("foo"); + def.set_op("foo_op"); + // Error when no constant tensors were requested. + { + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {}, + {}); + TF_ASSERT_OK(c.construction_status()); + auto fn = [](InferenceContext* c) { + ShapeHandle h; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); + c->set_output(0, c->input(0)); + return Status::OK(); + }; + EXPECT_EQ( + "Invalid argument: Shape must be at most rank 0 but is rank 3 for " + "'foo' (op: 'foo_op') with input shapes: [1,2,3].", + c.Run(fn).ToString()); + } + + // Error when a constant tensor value was requested. + { + Tensor input_t = + ::tensorflow::test::AsTensor({1.1, 2.2, 3.3, 4.4, 5.5}); + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), + {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {}); + TF_ASSERT_OK(c.construction_status()); + auto fn = [](InferenceContext* c) { + c->input_tensor(0); // get this one, but it's null - won't be in error. + c->input_tensor(1); // get this one, will now be in error. + ShapeHandle h; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); + c->set_output(0, c->input(0)); + return Status::OK(); + }; + EXPECT_EQ( + "Invalid argument: Shape must be at most rank 0 but is rank 3 for " + "'foo' (op: 'foo_op') with input shapes: [1,2,3], [4,5] and with " + "computed input tensors: input[1] = <1.1 2.2 3.3 4.4 5.5>.", + c.Run(fn).ToString()); + } + + // Error when a constant tensor value as shape was requested, but no partial + // shapes provided. + { + Tensor input_t = ::tensorflow::test::AsTensor({1, 2, 3, 4, 5}); + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})}, + {nullptr, &input_t}, {}, {}); + TF_ASSERT_OK(c.construction_status()); + auto fn = [](InferenceContext* c) { + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); + ShapeHandle h; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); + c->set_output(0, c->input(0)); + return Status::OK(); + }; + EXPECT_EQ( + "Invalid argument: Shape must be at most rank 0 but is rank 1 for " + "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed " + "input tensors: input[1] = <1 2 3 4 5>.", + c.Run(fn).ToString()); + } + + // Error when a constant tensor value as shape was requested, and a partial + // shape was provided. 
+ { + Tensor input_t = ::tensorflow::test::AsTensor({1, 2, 3, 4, 5}); + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})}, + {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {}); + TF_ASSERT_OK(c.construction_status()); + auto fn = [](InferenceContext* c) { + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); + ShapeHandle h; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); + c->set_output(0, c->input(0)); + return Status::OK(); + }; + EXPECT_EQ( + "Invalid argument: Shape must be at most rank 0 but is rank 1 for " + "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed " + "input tensors: input[1] = <1 2 3 4 5> and with input tensors computed " + "as partial shapes: input[0] = [10,?,5].", + c.Run(fn).ToString()); + } +} + +TEST_F(ShapeInferenceTest, RankAndDimInspection) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(3, 2), + {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(2, c.num_outputs()); + + auto in0 = c.input(0); + EXPECT_EQ("?", c.DebugString(in0)); + EXPECT_FALSE(c.RankKnown(in0)); + EXPECT_EQ(InferenceContext::kUnknownRank, c.Rank(in0)); + EXPECT_EQ("?", c.DebugString(c.Dim(in0, 0))); + EXPECT_EQ("?", c.DebugString(c.Dim(in0, -1))); + EXPECT_EQ("?", c.DebugString(c.Dim(in0, 1000))); + + auto in1 = c.input(1); + EXPECT_EQ("[1,?,3]", c.DebugString(in1)); + EXPECT_TRUE(c.RankKnown(in1)); + EXPECT_EQ(3, c.Rank(in1)); + auto d = c.Dim(in1, 0); + EXPECT_EQ(1, c.Value(d)); + EXPECT_TRUE(SameHandle(d, c.Dim(in1, -3))); + EXPECT_TRUE(c.ValueKnown(d)); + EXPECT_EQ("1", c.DebugString(d)); + d = c.Dim(in1, 1); + EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(d)); + EXPECT_FALSE(c.ValueKnown(d)); + EXPECT_TRUE(SameHandle(d, c.Dim(in1, -2))); + EXPECT_EQ("?", c.DebugString(d)); + d = c.Dim(in1, 2); + EXPECT_EQ(3, c.Value(d)); + EXPECT_TRUE(SameHandle(d, c.Dim(in1, -1))); + EXPECT_TRUE(c.ValueKnown(d)); + EXPECT_EQ("3", c.DebugString(d)); + + auto in2 = c.input(2); + EXPECT_EQ("[]", c.DebugString(in2)); + EXPECT_TRUE(c.RankKnown(in2)); + EXPECT_EQ(0, c.Rank(in2)); +} + +TEST_F(ShapeInferenceTest, NumElements) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(3, 2), + {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {}); + + EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0)))); + EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1)))); + + // Different handles (not the same unknown value). + EXPECT_FALSE(SameHandle(c.Dim(c.input(1), 1), c.NumElements(c.input(1)))); + + EXPECT_EQ("120", c.DebugString(c.NumElements(c.input(2)))); +} + +TEST_F(ShapeInferenceTest, WithRank) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), + {Unknown(), S({1, -1, 3})}, {}, {}, {}); + + auto in0 = c.input(0); + auto in1 = c.input(1); + ShapeHandle s1; + ShapeHandle s2; + + // WithRank on a shape with unknown dimensionality always succeeds. + EXPECT_TRUE(c.WithRank(in0, 1, &s1).ok()); + EXPECT_EQ("[?]", c.DebugString(s1)); + + EXPECT_TRUE(c.WithRank(in0, 2, &s2).ok()); + EXPECT_EQ("[?,?]", c.DebugString(s2)); + EXPECT_FALSE(SameHandle(s1, s2)); + EXPECT_FALSE(SameHandle(c.Dim(s2, 0), c.Dim(s2, 1))); + + EXPECT_TRUE(c.WithRank(in0, 1, &s2).ok()); + EXPECT_EQ("[?]", c.DebugString(s2)); + EXPECT_FALSE(SameHandle(s1, s2)); + + EXPECT_TRUE(c.WithRank(in0, 0, &s1).ok()); + EXPECT_EQ("[]", c.DebugString(s1)); + + // WithRank on shape with known dimensionality. 
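+  // (A mismatched rank on a fully-ranked shape yields an error and leaves the
+  // output handle unset; a matching rank returns the input handle unchanged.)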
+ s1 = in1; + EXPECT_EQ("Invalid argument: Shape must be rank 2 but is rank 3", + c.WithRank(in1, 2, &s1).ToString()); + EXPECT_FALSE(IsSet(s1)); + EXPECT_TRUE(c.WithRank(in1, 3, &s1).ok()); + EXPECT_TRUE(SameHandle(s1, in1)); + + // Inputs are unchanged. + EXPECT_EQ("?", c.DebugString(in0)); + EXPECT_EQ("[1,?,3]", c.DebugString(in1)); +} + +TEST_F(ShapeInferenceTest, WithRankAtMost) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), + {Unknown(), S({1, -1, 3})}, {}, {}, {}); + + auto in0 = c.input(0); + auto in1 = c.input(1); + ShapeHandle s1; + ShapeHandle s2; + + // WithRankAtMost on a shape with unknown dimensionality always succeeds. + EXPECT_TRUE(c.WithRankAtMost(in0, 1, &s1).ok()); + EXPECT_EQ("?", c.DebugString(s1)); + EXPECT_TRUE(SameHandle(in0, s1)); + + EXPECT_TRUE(c.WithRankAtMost(in0, 2, &s2).ok()); + EXPECT_EQ("?", c.DebugString(s2)); + EXPECT_TRUE(SameHandle(s1, s2)); + + // WithRankAtMost on shape with known dimensionality. + s1 = in1; + EXPECT_TRUE( + StringPiece(c.WithRankAtMost(in1, 2, &s1).ToString()) + .contains( + "Invalid argument: Shape must be at most rank 2 but is rank 3")); + + EXPECT_FALSE(IsSet(s1)); + EXPECT_TRUE(c.WithRankAtMost(in1, 3, &s1).ok()); + EXPECT_TRUE(SameHandle(s1, in1)); + EXPECT_TRUE(c.WithRankAtMost(in1, 4, &s1).ok()); + EXPECT_TRUE(SameHandle(s1, in1)); + EXPECT_TRUE(c.WithRankAtMost(in1, 5, &s1).ok()); + EXPECT_TRUE(SameHandle(s1, in1)); + + // Inputs are unchanged. + EXPECT_EQ("?", c.DebugString(in0)); + EXPECT_EQ("[1,?,3]", c.DebugString(in1)); +} + +TEST_F(ShapeInferenceTest, WithRankAtLeast) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), + {Unknown(), S({1, -1, 3})}, {}, {}, {}); + + auto in0 = c.input(0); + auto in1 = c.input(1); + ShapeHandle s1; + ShapeHandle s2; + + // WithRankAtLeast on a shape with unknown dimensionality always succeeds. + EXPECT_TRUE(c.WithRankAtLeast(in0, 1, &s1).ok()); + EXPECT_EQ("?", c.DebugString(s1)); + EXPECT_TRUE(SameHandle(in0, s1)); + + EXPECT_TRUE(c.WithRankAtLeast(in0, 2, &s2).ok()); + EXPECT_EQ("?", c.DebugString(s2)); + EXPECT_TRUE(SameHandle(s1, s2)); + + // WithRankAtLeast on shape with known dimensionality. + s1 = in1; + EXPECT_TRUE( + StringPiece(c.WithRankAtLeast(in1, 4, &s1).ToString()) + .contains( + "Invalid argument: Shape must be at least rank 4 but is rank 3")); + + EXPECT_FALSE(IsSet(s1)); + EXPECT_TRUE(c.WithRankAtLeast(in1, 3, &s1).ok()); + EXPECT_TRUE(SameHandle(s1, in1)); + EXPECT_TRUE(c.WithRankAtLeast(in1, 2, &s1).ok()); + EXPECT_TRUE(SameHandle(s1, in1)); + EXPECT_TRUE(c.WithRankAtLeast(in1, 0, &s1).ok()); + EXPECT_TRUE(SameHandle(s1, in1)); + + // Inputs are unchanged. + EXPECT_EQ("?", c.DebugString(in0)); + EXPECT_EQ("[1,?,3]", c.DebugString(in1)); +} + +TEST_F(ShapeInferenceTest, WithValue) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}); + + auto d0 = c.Dim(c.input(0), 0); + auto d1 = c.Dim(c.input(0), 1); + DimensionHandle out1; + DimensionHandle out2; + + // WithValue on a dimension with unknown value always succeeds. + EXPECT_TRUE(c.WithValue(d1, 1, &out1).ok()); + EXPECT_EQ(1, c.Value(out1)); + + EXPECT_TRUE(c.WithValue(d1, 2, &out2).ok()); + EXPECT_EQ(2, c.Value(out2)); + EXPECT_FALSE(SameHandle(out1, out2)); + EXPECT_FALSE(SameHandle(out1, d1)); + + EXPECT_TRUE(c.WithValue(d1, 1, &out2).ok()); + EXPECT_EQ(1, c.Value(out2)); + EXPECT_FALSE(SameHandle(out1, out2)); + + // WithValue on dimension with known size. 
+ out1 = d0; + + EXPECT_TRUE(StringPiece(c.WithValue(d0, 0, &out1).ToString()) + .contains("Invalid argument: Dimension must be 0 but is 1")); + EXPECT_FALSE(IsSet(out1)); + out1 = d0; + EXPECT_TRUE(StringPiece(c.WithValue(d0, 2, &out1).ToString()) + .contains("Invalid argument: Dimension must be 2 but is 1")); + + EXPECT_FALSE(IsSet(out1)); + EXPECT_TRUE(c.WithValue(d0, 1, &out1).ok()); + EXPECT_TRUE(SameHandle(d0, out1)); + + // Inputs are unchanged. + EXPECT_EQ("1", c.DebugString(d0)); + EXPECT_EQ("?", c.DebugString(d1)); +} + +TEST_F(ShapeInferenceTest, MergeDim) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, + {}, {}, {}); + + auto d2 = c.Dim(c.input(0), 0); + auto d_unknown = c.Dim(c.input(0), 1); + auto d2_b = c.Dim(c.input(0), 2); + auto d1 = c.Dim(c.input(0), 3); + auto d_unknown_b = c.Dim(c.input(0), 4); + DimensionHandle out; + + // Merging anything with unknown returns the same pointer. + EXPECT_TRUE(c.Merge(d2, d_unknown, &out).ok()); + EXPECT_TRUE(SameHandle(d2, out)); + EXPECT_TRUE(c.Merge(d_unknown, d2, &out).ok()); + EXPECT_TRUE(SameHandle(d2, out)); + EXPECT_TRUE(c.Merge(d_unknown, d_unknown_b, &out).ok()); + EXPECT_TRUE(SameHandle(d_unknown, out)); + + auto merged_dims = c.MergedDims(); + ASSERT_EQ(3, merged_dims.size()); + EXPECT_TRUE(merged_dims[0].first.SameHandle(d2)); + EXPECT_TRUE(merged_dims[0].second.SameHandle(d_unknown)); + EXPECT_TRUE(merged_dims[1].first.SameHandle(d_unknown)); + EXPECT_TRUE(merged_dims[1].second.SameHandle(d2)); + EXPECT_TRUE(merged_dims[2].first.SameHandle(d_unknown)); + EXPECT_TRUE(merged_dims[2].second.SameHandle(d_unknown_b)); + + // Merging with self is a no-op and returns self. + EXPECT_TRUE(c.Merge(d2, d2, &out).ok()); + EXPECT_TRUE(SameHandle(d2, out)); + EXPECT_TRUE(c.Merge(d_unknown, d_unknown, &out).ok()); + EXPECT_TRUE(SameHandle(d_unknown, out)); + + merged_dims = c.MergedDims(); + EXPECT_EQ(3, merged_dims.size()); + + // Merging equal values is a no op and returns first one. + EXPECT_TRUE(c.Merge(d2, d2_b, &out).ok()); + EXPECT_TRUE(SameHandle(d2, out)); + EXPECT_TRUE(c.Merge(d2_b, d2, &out).ok()); + EXPECT_TRUE(SameHandle(d2_b, out)); + + merged_dims = c.MergedDims(); + EXPECT_EQ(3, merged_dims.size()); + + // Merging unequal values is an error. + EXPECT_TRUE( + StringPiece(c.Merge(d2, d1, &out).ToString()) + .contains( + "Invalid argument: Dimensions must be equal, but are 2 and 1")); + + EXPECT_FALSE(IsSet(out)); + EXPECT_TRUE( + StringPiece(c.Merge(d1, d2, &out).ToString()) + .contains( + "Invalid argument: Dimensions must be equal, but are 1 and 2")); + + EXPECT_FALSE(IsSet(out)); + + merged_dims = c.MergedDims(); + EXPECT_EQ(3, merged_dims.size()); +} + +TEST_F(ShapeInferenceTest, RelaxDim) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), + {S({2, InferenceContext::kUnknownDim, 2, 1, + InferenceContext::kUnknownDim})}, + {}, {}, {}); + + auto d2 = c.Dim(c.input(0), 0); + auto d_unknown = c.Dim(c.input(0), 1); + auto d2_b = c.Dim(c.input(0), 2); + auto d1 = c.Dim(c.input(0), 3); + auto d_unknown_b = c.Dim(c.input(0), 4); + DimensionHandle out; + + // Relaxing anything with unknown returns a new unknown or the existing + // unknown. 
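+  // (Relax is the counterpart of Merge: rather than erroring on disagreement,
+  // it generalizes, so conflicting information degrades to unknown.)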
+ Relax(&c, d2, d_unknown, &out); + EXPECT_TRUE(SameHandle(d_unknown, out)); + EXPECT_FALSE(SameHandle(d_unknown_b, out)); + EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); + Relax(&c, d_unknown, d2, &out); + EXPECT_FALSE(SameHandle(d_unknown, out)); + EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); + Relax(&c, d_unknown, d_unknown_b, &out); + EXPECT_FALSE(SameHandle(d_unknown, out)); + EXPECT_TRUE(SameHandle(d_unknown_b, out)); + EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); + + // Relaxing with self returns self. + Relax(&c, d2, d2, &out); + EXPECT_TRUE(SameHandle(d2, out)); + Relax(&c, d_unknown, d_unknown, &out); + EXPECT_TRUE(SameHandle(d_unknown, out)); + + // Relaxing equal values returns first one. + Relax(&c, d2, d2_b, &out); + EXPECT_TRUE(SameHandle(d2, out)); + Relax(&c, d2_b, d2, &out); + EXPECT_TRUE(SameHandle(d2_b, out)); + + // Relaxing unequal values returns a new unknown. + Relax(&c, d2, d1, &out); + EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); + Relax(&c, d1, d2, &out); + EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); +} + +TEST_F(ShapeInferenceTest, RelaxShape) { + NodeDef def; + InferenceContext c( + kVersion, &def, MakeOpDef(7, 2), + {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}), + S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})}, + {}, {}, {}); + + auto s_unknown = c.input(0); + auto s_1_2 = c.input(1); + auto s_u_2 = c.input(2); + auto s_1_u = c.input(3); + auto s_1_3 = c.input(4); + auto s_unknown_b = c.input(5); + auto s_1 = c.input(6); + ShapeHandle out; + + // Relaxing any shape with unknown returns a new unknown. + Relax(&c, s_unknown, s_1_2, &out); + EXPECT_FALSE(SameHandle(s_u_2, s_unknown)); + EXPECT_EQ("?", c.DebugString(out)); + Relax(&c, s_u_2, s_unknown, &out); + EXPECT_FALSE(SameHandle(s_u_2, out)); + EXPECT_EQ("?", c.DebugString(out)); + Relax(&c, s_unknown, s_unknown_b, &out); + EXPECT_FALSE(SameHandle(s_unknown, out)); + EXPECT_TRUE(SameHandle(s_unknown_b, out)); + EXPECT_EQ("?", c.DebugString(out)); + + // Relaxing with self returns self. + Relax(&c, s_1_2, s_1_2, &out); + EXPECT_TRUE(SameHandle(out, s_1_2)); + + // Relaxing where one of the inputs has less information. + out = ShapeHandle(); + Relax(&c, s_1_2, s_u_2, &out); + EXPECT_FALSE(SameHandle(s_u_2, out)); + EXPECT_EQ("[?,2]", c.DebugString(out)); + out = ShapeHandle(); + Relax(&c, s_u_2, s_1_2, &out); + EXPECT_FALSE(SameHandle(s_u_2, out)); + EXPECT_EQ("[?,2]", c.DebugString(out)); + + // Relaxing where each input has one distinct unknown dimension. + Relax(&c, s_u_2, s_1_u, &out); + EXPECT_EQ("[?,?]", c.DebugString(out)); + EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0))); + EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 1), c.Dim(out, 1))); + auto s_u1 = c.UnknownShapeOfRank(1); + auto s_u2 = c.UnknownShapeOfRank(1); + Relax(&c, s_u1, s_u2, &out); + EXPECT_FALSE(SameHandle(s_u1, out)); + + // Relaxing with mismatched values in a dimension returns a shape with that + // dimension unknown. + out = s_unknown; + Relax(&c, s_u_2, s_1_3, &out); + EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0))); + EXPECT_EQ("[?,?]", c.DebugString(out)); + out = s_unknown; + Relax(&c, s_1_3, s_u_2, &out); + EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0))); + EXPECT_EQ("[?,?]", c.DebugString(out)); + out = s_unknown; + + // Relaxing with mismatched ranks returns a new unknown. 
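+  // (Contrast with Merge below, which reports an error on mismatched ranks
+  // instead of falling back to an unknown shape.)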
+ Relax(&c, s_1, s_1_2, &out); + EXPECT_EQ("?", c.DebugString(out)); +} + +TEST_F(ShapeInferenceTest, MergeShape) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(7, 2), + {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}), + Unknown(), S({1})}, + {}, {}, {}); + + auto s_unknown = c.input(0); + auto s_1_2 = c.input(1); + auto s_u_2 = c.input(2); + auto s_1_u = c.input(3); + auto s_1_3 = c.input(4); + auto s_unknown_b = c.input(5); + auto s_1 = c.input(6); + ShapeHandle out; + + // Merging any shape with unknown returns the shape. + EXPECT_TRUE(c.Merge(s_unknown, s_1_2, &out).ok()); + EXPECT_TRUE(SameHandle(s_1_2, out)); + EXPECT_TRUE(c.Merge(s_u_2, s_unknown, &out).ok()); + EXPECT_TRUE(SameHandle(s_u_2, out)); + EXPECT_TRUE(c.Merge(s_unknown, s_unknown_b, &out).ok()); + EXPECT_TRUE(SameHandle(s_unknown, out)); + + auto merged_shapes = c.MergedShapes(); + ASSERT_EQ(3, merged_shapes.size()); + EXPECT_TRUE(merged_shapes[0].first.SameHandle(s_unknown)); + EXPECT_TRUE(merged_shapes[0].second.SameHandle(s_1_2)); + EXPECT_TRUE(merged_shapes[1].first.SameHandle(s_u_2)); + EXPECT_TRUE(merged_shapes[1].second.SameHandle(s_unknown)); + EXPECT_TRUE(merged_shapes[2].first.SameHandle(s_unknown)); + EXPECT_TRUE(merged_shapes[2].second.SameHandle(s_unknown_b)); + + // Merging with self returns self. + EXPECT_TRUE(c.Merge(s_1_2, s_1_2, &out).ok()); + EXPECT_TRUE(SameHandle(out, s_1_2)); + + merged_shapes = c.MergedShapes(); + EXPECT_EQ(3, merged_shapes.size()); + + // Merging where one of the inputs is the right answer - return that input. + out = ShapeHandle(); + EXPECT_TRUE(c.Merge(s_1_2, s_u_2, &out).ok()); + EXPECT_TRUE(SameHandle(s_1_2, out)); + out = ShapeHandle(); + EXPECT_TRUE(c.Merge(s_u_2, s_1_2, &out).ok()); + EXPECT_TRUE(SameHandle(s_1_2, out)); + + merged_shapes = c.MergedShapes(); + ASSERT_EQ(5, merged_shapes.size()); + EXPECT_TRUE(merged_shapes[3].first.SameHandle(s_1_2)); + EXPECT_TRUE(merged_shapes[3].second.SameHandle(s_u_2)); + EXPECT_TRUE(merged_shapes[4].first.SameHandle(s_u_2)); + EXPECT_TRUE(merged_shapes[4].second.SameHandle(s_1_2)); + + // Merging where neither input is the right answer. + EXPECT_TRUE(c.Merge(s_u_2, s_1_u, &out).ok()); + EXPECT_FALSE(SameHandle(out, s_u_2)); + EXPECT_FALSE(SameHandle(out, s_1_u)); + EXPECT_EQ("[1,2]", c.DebugString(out)); + EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 0), c.Dim(out, 0))); + EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 1), c.Dim(out, 1))); + + merged_shapes = c.MergedShapes(); + ASSERT_EQ(7, merged_shapes.size()); + EXPECT_TRUE(merged_shapes[5].first.SameHandle(s_u_2)); + EXPECT_TRUE(merged_shapes[5].second.SameHandle(s_1_u)); + EXPECT_TRUE(merged_shapes[6].first.SameHandle(s_u_2)); + EXPECT_TRUE(merged_shapes[6].second.SameHandle(out)); + + auto s_u1 = c.UnknownShapeOfRank(1); + auto s_u2 = c.UnknownShapeOfRank(1); + TF_EXPECT_OK(c.Merge(s_u1, s_u2, &out)); + EXPECT_TRUE(SameHandle(s_u1, out)); + + merged_shapes = c.MergedShapes(); + ASSERT_EQ(8, merged_shapes.size()); + EXPECT_TRUE(merged_shapes[7].first.SameHandle(s_u1)); + EXPECT_TRUE(merged_shapes[7].second.SameHandle(s_u2)); + + // Incompatible merges give errors and set out to nullptr. 
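+  // (Merge requires equal ranks and equal known dimension values; the cases
+  // below check that the error names the offending dimension and that the
+  // output handle is cleared.)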
+ out = s_unknown; + EXPECT_TRUE( + StringPiece(c.Merge(s_u_2, s_1_3, &out).ToString()) + .contains( + "Invalid argument: Dimension 1 in both shapes must be equal, but " + "are 2 and 3")); + + EXPECT_FALSE(IsSet(out)); + out = s_unknown; + EXPECT_TRUE( + StringPiece(c.Merge(s_1_3, s_u_2, &out).ToString()) + .contains( + "Invalid argument: Dimension 1 in both shapes must be equal, but " + "are 3 and 2")); + + EXPECT_FALSE(IsSet(out)); + out = s_unknown; + EXPECT_TRUE( + StringPiece(c.Merge(s_1, s_1_2, &out).ToString()) + .contains( + "Invalid argument: Shapes must be equal rank, but are 1 and 2")); + + EXPECT_FALSE(IsSet(out)); + + merged_shapes = c.MergedShapes(); + EXPECT_EQ(8, merged_shapes.size()); +} + +TEST_F(ShapeInferenceTest, MergePrefix) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(4, 2), + { + Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}), + }, + {}, {}, {}); + + auto s_unknown = c.input(0); + auto s_u_2 = c.input(1); + auto s_1_u_3 = c.input(2); + auto s_2_4 = c.input(3); + + ShapeHandle s_out; + ShapeHandle s_prefix_out; + + // Merging with unknown returns the inputs. + EXPECT_TRUE(c.MergePrefix(s_unknown, s_u_2, &s_out, &s_prefix_out).ok()); + EXPECT_TRUE(SameHandle(s_out, s_unknown)); + EXPECT_TRUE(SameHandle(s_prefix_out, s_u_2)); + EXPECT_TRUE(c.MergePrefix(s_1_u_3, s_unknown, &s_out, &s_prefix_out).ok()); + EXPECT_TRUE(SameHandle(s_out, s_1_u_3)); + EXPECT_TRUE(SameHandle(s_prefix_out, s_unknown)); + + EXPECT_TRUE(c.MergePrefix(s_1_u_3, s_u_2, &s_out, &s_prefix_out).ok()); + EXPECT_FALSE(SameHandle(s_out, s_1_u_3)); + EXPECT_EQ("[1,2]", c.DebugString(s_prefix_out)); + EXPECT_EQ("[1,2,3]", c.DebugString(s_out)); + EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 0), c.Dim(s_out, 0))); + EXPECT_TRUE(SameHandle(c.Dim(s_out, 0), c.Dim(s_1_u_3, 0))); + EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 1), c.Dim(s_out, 1))); + EXPECT_TRUE(SameHandle(c.Dim(s_prefix_out, 1), c.Dim(s_u_2, 1))); + + // Incompatible merges give errors and set outs to nullptr. 
+ s_out = s_unknown; + s_prefix_out = s_unknown; + EXPECT_TRUE( + StringPiece( + c.MergePrefix(s_1_u_3, s_2_4, &s_out, &s_prefix_out).ToString()) + .contains( + "Invalid argument: Dimensions must be equal, but are 1 and 2")); + + EXPECT_FALSE(IsSet(s_out)); + EXPECT_FALSE(IsSet(s_prefix_out)); + + s_out = s_unknown; + s_prefix_out = s_unknown; + EXPECT_TRUE( + StringPiece( + c.MergePrefix(s_2_4, s_1_u_3, &s_out, &s_prefix_out).ToString()) + .contains( + "Invalid argument: Shape must be at least rank 3 but is rank 2")); + EXPECT_FALSE(IsSet(s_out)); + EXPECT_FALSE(IsSet(s_prefix_out)); +} + +TEST_F(ShapeInferenceTest, Subshape) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), + {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {}); + + ShapeHandle unknown = c.input(1); + ShapeHandle out; + EXPECT_TRUE(c.Subshape(unknown, 0, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(SameHandle(out, unknown)); + EXPECT_TRUE(c.Subshape(unknown, 1, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_FALSE(SameHandle(out, unknown)); + EXPECT_TRUE(c.Subshape(unknown, 200, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_FALSE(SameHandle(out, unknown)); + + const int kFullRank = 5; + ShapeHandle out_arr[4]; + auto in0 = c.input(0); + EXPECT_TRUE(c.Subshape(in0, 0, &out).ok()); + EXPECT_EQ("[1,2,3,?,5]", c.DebugString(out)); + EXPECT_TRUE(SameHandle(out, in0)); + EXPECT_EQ(kFullRank, c.Rank(out)); + for (int start = 0; start <= kFullRank + 1; ++start) { + for (int end = start; end <= kFullRank + 1; ++end) { + // Get subshapes using different start and end values that give the same + // range. + const int neg_start = + start >= kFullRank ? kFullRank : (start - kFullRank); + const int neg_end = end >= kFullRank ? kFullRank : (end - kFullRank); + ASSERT_TRUE(c.Subshape(in0, start, end, &out_arr[0]).ok()); + ASSERT_TRUE(c.Subshape(in0, neg_start, end, &out_arr[1]).ok()); + ASSERT_TRUE(c.Subshape(in0, start, neg_end, &out_arr[2]).ok()); + ASSERT_TRUE(c.Subshape(in0, neg_start, neg_end, &out_arr[3]).ok()); + + // Verify all computed subshapes. + for (int arr_idx = 0; arr_idx < 4; ++arr_idx) { + out = out_arr[arr_idx]; + ASSERT_EQ(std::min(kFullRank, end) - std::min(kFullRank, start), + c.Rank(out)) + << "start: " << start << " end: " << end << " arr_idx: " << arr_idx + << " in0: " << c.DebugString(in0) << " out: " << c.DebugString(out); + for (int d = 0; d < c.Rank(out); ++d) { + EXPECT_TRUE(SameHandle(c.Dim(in0, start + d), c.Dim(out, d))) + << "arr_idx: " << arr_idx; + } + } + } + } + + // Errors. 
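+  // (Subshape fails when the resolved start exceeds the resolved end, or when
+  // a negative index reaches back past the rank; the output handle is left
+  // unset in each case.)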
+ out = unknown; + EXPECT_TRUE(StringPiece(c.Subshape(in0, 6, -3, &out).ToString()) + .contains("Invalid argument: Subshape must have computed " + "start <= end, but is 5 " + "and 2 (computed from start 6 and end -3 over " + "shape with rank 5)")); + EXPECT_FALSE(IsSet(out)); + out = unknown; + EXPECT_TRUE(StringPiece(c.Subshape(in0, -50, 100, &out).ToString()) + .contains("Invalid argument: Subshape start out of " + "bounds: -50, for shape with " + "rank 5")); + + EXPECT_FALSE(IsSet(out)); + out = unknown; + EXPECT_TRUE(StringPiece(c.Subshape(in0, 0, -50, &out).ToString()) + .contains("Invalid argument: Subshape end out of bounds: " + "-50, for shape with rank " + "5")); + + EXPECT_FALSE(IsSet(out)); +} + +TEST_F(ShapeInferenceTest, Concatenate) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(3, 2), + {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}); + + auto in0 = c.input(0); + auto in1 = c.input(1); + ShapeHandle unknown = c.input(2); + ShapeHandle out; + EXPECT_TRUE(c.Concatenate(unknown, unknown, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_FALSE(SameHandle(out, unknown)); + EXPECT_TRUE(c.Concatenate(unknown, in0, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_FALSE(SameHandle(out, unknown)); + + EXPECT_TRUE(c.Concatenate(in0, in1, &out).ok()); + EXPECT_EQ("[1,?,3,4,5]", c.DebugString(out)); + int out_i = 0; + for (int i = 0; i < c.Rank(in0); ++i, ++out_i) { + EXPECT_TRUE(SameHandle(c.Dim(in0, i), c.Dim(out, out_i))); + } + for (int i = 0; i < c.Rank(in1); ++i, ++out_i) { + EXPECT_TRUE(SameHandle(c.Dim(in1, i), c.Dim(out, out_i))); + } +} + +TEST_F(ShapeInferenceTest, ReplaceDim) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, + {}, {}, {}); + + auto in = c.input(0); + auto unknown = c.input(1); + + ShapeHandle replaced; + EXPECT_TRUE(c.ReplaceDim(in, 0, c.Dim(in, 1), &replaced).ok()); + EXPECT_EQ("[2,2,3]", c.DebugString(replaced)); + EXPECT_TRUE(c.ReplaceDim(in, 2, c.Dim(in, 1), &replaced).ok()); + EXPECT_EQ("[1,2,2]", c.DebugString(replaced)); + EXPECT_TRUE(c.ReplaceDim(in, 1, c.Dim(in, 2), &replaced).ok()); + EXPECT_EQ("[1,3,3]", c.DebugString(replaced)); + EXPECT_TRUE(c.ReplaceDim(unknown, 0, c.Dim(in, 1), &replaced).ok()); + EXPECT_EQ("?", c.DebugString(replaced)); + + // Negative indexing. + EXPECT_TRUE(c.ReplaceDim(in, -1, c.Dim(in, 1), &replaced).ok()); + EXPECT_EQ("[1,2,2]", c.DebugString(replaced)); + EXPECT_TRUE(c.ReplaceDim(unknown, -1, c.Dim(in, 1), &replaced).ok()); + EXPECT_EQ("?", c.DebugString(replaced)); + + // out of range indexing. 
+ EXPECT_FALSE(c.ReplaceDim(in, 3, c.Dim(in, 1), &replaced).ok()); + EXPECT_FALSE(IsSet(replaced)); + replaced = in; + EXPECT_FALSE(c.ReplaceDim(in, -4, c.Dim(in, 1), &replaced).ok()); + EXPECT_FALSE(IsSet(replaced)); +} + +TEST_F(ShapeInferenceTest, MakeShape) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, + {}, {}); + + std::vector dims; + auto in0 = c.input(0); + const int rank = c.Rank(in0); + dims.reserve(rank); + for (int i = 0; i < rank; ++i) { + dims.push_back(c.Dim(in0, rank - i - 1)); + } + + auto s = c.MakeShape(dims); + EXPECT_EQ("[5,?,3,2,1]", c.DebugString(s)); + EXPECT_TRUE(SameHandle(c.Dim(s, 0), c.Dim(in0, rank - 1))); + + auto s2 = c.MakeShape(dims); + EXPECT_FALSE(SameHandle(s, s2)); + EXPECT_TRUE(SameHandle(c.Dim(s2, 0), c.Dim(in0, rank - 1))); + + auto s3 = c.MakeShape({1, 2, dims[2]}); + EXPECT_FALSE(SameHandle(s, s3)); + EXPECT_EQ("[1,2,3]", c.DebugString(s3)); +} + +TEST_F(ShapeInferenceTest, UnknownShape) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + + auto u0 = c.UnknownShape(); + auto u1 = c.UnknownShape(); + EXPECT_EQ("?", c.DebugString(u0)); + EXPECT_EQ("?", c.DebugString(u1)); + EXPECT_FALSE(SameHandle(u0, u1)); +} + +TEST_F(ShapeInferenceTest, KnownShapeToProto) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + + auto s = c.MakeShape({1, 2, 3}); + TensorShapeProto proto; + c.ShapeHandleToProto(s, &proto); + + EXPECT_FALSE(proto.unknown_rank()); + EXPECT_EQ(3, proto.dim_size()); + EXPECT_EQ(1, proto.dim(0).size()); +} + +TEST_F(ShapeInferenceTest, UnknownShapeToProto) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + + auto u0 = c.UnknownShape(); + TensorShapeProto proto; + c.ShapeHandleToProto(u0, &proto); + + EXPECT_TRUE(proto.unknown_rank()); + EXPECT_EQ(0, proto.dim_size()); +} + +TEST_F(ShapeInferenceTest, Scalar) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + + auto s0 = c.Scalar(); + EXPECT_EQ("[]", c.DebugString(s0)); + auto s1 = c.Scalar(); + EXPECT_EQ("[]", c.DebugString(s1)); +} + +TEST_F(ShapeInferenceTest, Vector) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + + auto s0 = c.Vector(1); + EXPECT_EQ("[1]", c.DebugString(s0)); + auto s1 = c.Vector(InferenceContext::kUnknownDim); + EXPECT_EQ("[?]", c.DebugString(s1)); + + auto d1 = c.UnknownDim(); + auto s2 = c.Vector(d1); + EXPECT_EQ("[?]", c.DebugString(s2)); + EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0))); +} + +TEST_F(ShapeInferenceTest, Matrix) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + + auto s0 = c.Matrix(1, 2); + EXPECT_EQ("[1,2]", c.DebugString(s0)); + auto s1 = c.Matrix(0, InferenceContext::kUnknownDim); + EXPECT_EQ("[0,?]", c.DebugString(s1)); + + auto d1 = c.UnknownDim(); + auto d2 = c.UnknownDim(); + auto s2 = c.Matrix(d1, d2); + EXPECT_EQ("[?,?]", c.DebugString(s2)); + EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0))); + EXPECT_TRUE(SameHandle(d2, c.Dim(s2, 1))); + + auto s3 = c.Matrix(d1, 100); + EXPECT_EQ("[?,100]", c.DebugString(s3)); + EXPECT_TRUE(SameHandle(d1, c.Dim(s2, 0))); +} + +TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { + auto create = [&](Tensor* t) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, 
{t}, {}, + {}); + ShapeHandle out; + Status s = c.MakeShapeFromShapeTensor(0, &out); + if (s.ok()) { + return c.DebugString(out); + } else { + EXPECT_FALSE(IsSet(out)); + return s.error_message(); + } + }; + + Tensor t; + EXPECT_EQ("?", create(nullptr)); + + t = ::tensorflow::test::AsTensor({1, 2, 3}); + EXPECT_EQ("[1,2,3]", create(&t)); + + t = ::tensorflow::test::AsTensor({3, 2, 1}); + EXPECT_EQ("[3,2,1]", create(&t)); + + t = ::tensorflow::test::AsTensor({3, -1, 1}); + EXPECT_EQ("[3,?,1]", create(&t)); + + t = ::tensorflow::test::AsTensor({}); + EXPECT_EQ("[]", create(&t)); + + t = ::tensorflow::test::AsTensor({1, 2, 3}); + EXPECT_TRUE( + StringPiece(create(&t)) + .contains("Input tensor must be int32 or int64, but was float")); + + t = ::tensorflow::test::AsScalar(1); + EXPECT_TRUE(StringPiece(create(&t)) + .contains("Input tensor must be rank 1, but was rank 0")); + + t = ::tensorflow::test::AsTensor({1, 2}, TensorShape{2, 1}); + EXPECT_TRUE(StringPiece(create(&t)) + .contains("Input tensor must be rank 1, but was rank 2")); + + // Test negative values for the dims. + t = ::tensorflow::test::AsTensor({3, -2, 1}); + EXPECT_TRUE(StringPiece(create(&t)) + .contains("Invalid value in tensor used for shape: -2")); + + // Test negative values for the dims. + t = ::tensorflow::test::AsTensor({3, -2, 1}); + EXPECT_TRUE(StringPiece(create(&t)) + .contains("Invalid value in tensor used for shape: -2")); + + // Test when the input shape is wrong. + { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, + {}, {}); + ShapeHandle out; + EXPECT_EQ("Shape must be rank 1 but is rank 2", + c.MakeShapeFromShapeTensor(0, &out).error_message()); + } +} + +TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + + // With an unknown rank. + ShapeHandle out; + TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(PartialTensorShape(), &out)); + EXPECT_EQ("?", c.DebugString(out)); + + // With a known rank. + TF_ASSERT_OK( + c.MakeShapeFromPartialTensorShape(PartialTensorShape({0}), &out)); + EXPECT_EQ("[0]", c.DebugString(out)); + TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape( + PartialTensorShape({0, -1, 1000}), &out)); + EXPECT_EQ("[0,?,1000]", c.DebugString(out)); +} + +TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + + ShapeHandle out; + TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape(), &out)); + EXPECT_EQ("[]", c.DebugString(out)); + TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape({0}), &out)); + EXPECT_EQ("[0]", c.DebugString(out)); + TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape({0, 7, 1000}), &out)); + EXPECT_EQ("[0,7,1000]", c.DebugString(out)); +} + +TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + TensorShapeProto proto; + + // With a set unknown rank. + ShapeHandle out; + proto.set_unknown_rank(true); + EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + proto.add_dim()->set_size(0); + EXPECT_TRUE( + StringPiece(c.MakeShapeFromShapeProto(proto, &out).error_message()) + .contains("An unknown shape must not have any dimensions set.")); + EXPECT_FALSE(IsSet(out)); + + // With known rank. 
+ proto.set_unknown_rank(false); + EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok()); + EXPECT_EQ("[0]", c.DebugString(out)); + proto.add_dim()->set_size(-1); + proto.add_dim()->set_size(1000); + EXPECT_TRUE(c.MakeShapeFromShapeProto(proto, &out).ok()); + EXPECT_EQ("[0,?,1000]", c.DebugString(out)); + + // With invalid dimension value. + proto.add_dim()->set_size(-2); + EXPECT_TRUE( + StringPiece(c.MakeShapeFromShapeProto(proto, &out).error_message()) + .contains("Shape [0,?,1000,-2] has dimensions with values below -1 " + "(where -1 means unknown)")); + + EXPECT_FALSE(IsSet(out)); +} + +TEST_F(ShapeInferenceTest, MakeDim) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + + auto d0 = c.MakeDim(1); + auto d1 = c.MakeDim(1); + auto d2 = c.MakeDim(2); + EXPECT_EQ("1", c.DebugString(d0)); + EXPECT_EQ("1", c.DebugString(d1)); + EXPECT_FALSE(SameHandle(d0, d1)); + EXPECT_EQ("2", c.DebugString(d2)); +} + +TEST_F(ShapeInferenceTest, UnknownDim) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + + auto d0 = c.UnknownDim(); + auto d1 = c.UnknownDim(); + EXPECT_EQ("?", c.DebugString(d0)); + EXPECT_EQ("?", c.DebugString(d1)); + EXPECT_FALSE(SameHandle(d0, d1)); +} + +TEST_F(ShapeInferenceTest, UnknownShapeOfRank) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + + auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3); + EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3)); + + auto unknown_shape_of_rank_0 = c.UnknownShapeOfRank(0); + EXPECT_EQ("[]", c.DebugString(unknown_shape_of_rank_0)); +} + +TEST_F(ShapeInferenceTest, InputTensors) { + const Tensor t1 = tensorflow::test::AsTensor({10}); + const Tensor t2 = tensorflow::test::AsTensor({20, 30}); + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})}, + {&t1, &t2}, {}, {}); + + EXPECT_TRUE(c.input_tensor(0) == &t1); + EXPECT_TRUE(c.input_tensor(1) == &t2); + EXPECT_TRUE(c.input_tensor(2) == nullptr); +} + +TEST_F(ShapeInferenceTest, MakeDimForScalarInput) { + Tensor t1 = tensorflow::test::AsScalar(20); + Tensor t2 = tensorflow::test::AsScalar(-1); + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, + {&t1, &t2}, {}, {}); + + DimensionHandle d; + EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok()); + EXPECT_EQ("20", c.DebugString(d)); + + EXPECT_TRUE(StringPiece(c.MakeDimForScalarInput(1, &d).error_message()) + .contains("Dimension size, given by scalar input 1, must " + "be non-negative but is -1")); + + // Same tests, with int64 values. 
+ t1 = tensorflow::test::AsScalar(20); + t2 = tensorflow::test::AsScalar(-1); + EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok()); + EXPECT_EQ("20", c.DebugString(d)); + + EXPECT_TRUE(StringPiece(c.MakeDimForScalarInput(1, &d).error_message()) + .contains("Dimension size, given by scalar input 1, must " + "be non-negative but is -1")); +} + +TEST_F(ShapeInferenceTest, GetAttr) { + OpRegistrationData op_reg_data; + op_reg_data.op_def = MakeOpDef(0, 2); + NodeDef def; + CHECK(NodeDefBuilder("dummy", &op_reg_data.op_def) + .Attr("foo", "bar") + .Finalize(&def) + .ok()); + + std::vector empty; + InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {}); + string value; + EXPECT_TRUE(c.GetAttr("foo", &value).ok()); + EXPECT_EQ("bar", value); +} + +TEST_F(ShapeInferenceTest, Divide) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, + {}, {}); + + auto s = c.input(0); + auto d_6 = c.Dim(s, 0); + auto d_unknown = c.Dim(s, 1); + auto d_1 = c.Dim(s, 2); + auto d_2 = c.Dim(s, 3); + auto d_0 = c.Dim(s, 4); + bool evenly_divisible = true; + + // Dividing unknown by non-1 gives new unknown. + DimensionHandle out; + EXPECT_TRUE(c.Divide(d_unknown, 2, evenly_divisible, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_FALSE(SameHandle(out, d_unknown)); + + // Dividing anything by 1 returns the input. + EXPECT_TRUE(c.Divide(d_unknown, 1, evenly_divisible, &out).ok()); + EXPECT_TRUE(SameHandle(out, d_unknown)); + EXPECT_TRUE(c.Divide(d_6, 1, evenly_divisible, &out).ok()); + EXPECT_TRUE(SameHandle(out, d_6)); + EXPECT_TRUE(c.Divide(d_unknown, d_1, evenly_divisible, &out).ok()); + EXPECT_TRUE(SameHandle(out, d_unknown)); + EXPECT_TRUE(c.Divide(d_6, d_1, evenly_divisible, &out).ok()); + EXPECT_TRUE(SameHandle(out, d_6)); + + EXPECT_TRUE(c.Divide(d_6, 2, evenly_divisible, &out).ok()); + EXPECT_EQ("3", c.DebugString(out)); + EXPECT_TRUE(c.Divide(d_6, d_2, evenly_divisible, &out).ok()); + EXPECT_EQ("3", c.DebugString(out)); + + EXPECT_TRUE( + StringPiece(c.Divide(d_6, 5, evenly_divisible, &out).error_message()) + .contains("Dimension size must be evenly divisible by 5 but is 6")); + + EXPECT_TRUE( + StringPiece(c.Divide(d_6, 0, evenly_divisible, &out).error_message()) + .contains("Divisor must be positive but is 0")); + EXPECT_TRUE( + StringPiece(c.Divide(d_6, d_0, evenly_divisible, &out).error_message()) + .contains("Divisor must be positive but is 0")); + + EXPECT_TRUE( + StringPiece(c.Divide(d_6, -1, evenly_divisible, &out).error_message()) + .contains("Divisor must be positive but is -1")); + + // Repeat error cases above with evenly_divisible=false. + evenly_divisible = false; + EXPECT_TRUE(c.Divide(d_6, 5, evenly_divisible, &out).ok()); + EXPECT_EQ("1", c.DebugString(out)); + + EXPECT_TRUE( + StringPiece(c.Divide(d_6, 0, evenly_divisible, &out).error_message()) + .contains("Divisor must be positive but is 0")); + + EXPECT_TRUE( + StringPiece(c.Divide(d_6, -1, evenly_divisible, &out).error_message()) + .contains("Divisor must be positive but is -1")); +} + +TEST_F(ShapeInferenceTest, Add) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, + {}); + + auto s = c.input(0); + auto d_6 = c.Dim(s, 0); + auto d_unknown = c.Dim(s, 1); + auto d_0 = c.Dim(s, 2); + + // Adding non-zero to unknown gives new unknown. 
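+  // (Add treats 0 as an identity and returns the original handle; it reports
+  // an error if the sum would overflow int64, as checked below.)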
+ DimensionHandle out; + EXPECT_TRUE(c.Add(d_unknown, 1, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_FALSE(SameHandle(out, d_unknown)); + + // Adding 0 to anything gives input. + EXPECT_TRUE(c.Add(d_unknown, 0, &out).ok()); + EXPECT_TRUE(SameHandle(out, d_unknown)); + EXPECT_TRUE(c.Add(d_6, 0, &out).ok()); + EXPECT_TRUE(SameHandle(out, d_6)); + + // Adding dimension with value 0 to anything gives input. + EXPECT_TRUE(c.Add(d_unknown, c.MakeDim(0ll), &out).ok()); + EXPECT_TRUE(SameHandle(out, d_unknown)); + EXPECT_TRUE(c.Add(d_6, c.MakeDim(0ll), &out).ok()); + EXPECT_TRUE(SameHandle(out, d_6)); + + // Test addition. + EXPECT_TRUE(c.Add(d_6, 2, &out).ok()); + EXPECT_EQ("8", c.DebugString(out)); + EXPECT_TRUE(c.Add(d_6, std::numeric_limits::max() - 6, &out).ok()); + EXPECT_EQ(std::numeric_limits::max(), c.Value(out)); + + // Test addition using dimension as second value. + EXPECT_TRUE(c.Add(d_6, c.MakeDim(2), &out).ok()); + EXPECT_EQ("8", c.DebugString(out)); + EXPECT_TRUE( + c.Add(d_6, c.MakeDim(std::numeric_limits::max() - 6), &out).ok()); + EXPECT_EQ(std::numeric_limits::max(), c.Value(out)); + EXPECT_TRUE(c.Add(d_6, c.UnknownDim(), &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Add(d_0, d_6, &out).ok()); + EXPECT_TRUE(SameHandle(out, d_6)); + + EXPECT_TRUE( + StringPiece(c.Add(d_6, std::numeric_limits::max() - 5, &out) + .error_message()) + .contains( + "Dimension size overflow from adding 6 and 9223372036854775802")); +} + +TEST_F(ShapeInferenceTest, Subtract) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, + {}, {}); + + auto s = c.input(0); + auto d_6 = c.Dim(s, 0); + auto d_unknown = c.Dim(s, 1); + auto d_0 = c.Dim(s, 2); + auto d_5 = c.Dim(s, 3); + + // Subtracting non-zero from unknown gives new unknown. + DimensionHandle out; + EXPECT_TRUE(c.Subtract(d_unknown, 1, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_FALSE(SameHandle(out, d_unknown)); + + // Subtracting 0 from anything gives input. + EXPECT_TRUE(c.Subtract(d_unknown, 0ll, &out).ok()); + EXPECT_TRUE(SameHandle(out, d_unknown)); + EXPECT_TRUE(c.Subtract(d_6, 0ll, &out).ok()); + EXPECT_TRUE(SameHandle(out, d_6)); + + // Subtracting dimension with value 0 from anything gives input. + EXPECT_TRUE(c.Subtract(d_unknown, c.MakeDim(0ll), &out).ok()); + EXPECT_TRUE(SameHandle(out, d_unknown)); + EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(0ll), &out).ok()); + EXPECT_TRUE(SameHandle(out, d_6)); + + // Test subtraction. + EXPECT_TRUE(c.Subtract(d_6, 2, &out).ok()); + EXPECT_EQ("4", c.DebugString(out)); + EXPECT_TRUE(c.Subtract(d_6, 6, &out).ok()); + EXPECT_EQ("0", c.DebugString(out)); + + // Test subtraction using dimension as second value. 
+ EXPECT_TRUE(c.Subtract(d_6, c.MakeDim(2), &out).ok()); + EXPECT_EQ("4", c.DebugString(out)); + EXPECT_TRUE(c.Subtract(d_6, d_5, &out).ok()); + EXPECT_EQ("1", c.DebugString(out)); + EXPECT_TRUE(c.Subtract(d_6, c.UnknownDim(), &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Subtract(d_6, d_0, &out).ok()); + EXPECT_TRUE(SameHandle(out, d_6)); + + EXPECT_TRUE( + StringPiece(c.Subtract(d_5, d_6, &out).error_message()) + .contains("Negative dimension size caused by subtracting 6 from 5")); +} + +TEST_F(ShapeInferenceTest, Multiply) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, + {}, {}); + + auto s = c.input(0); + auto d_6 = c.Dim(s, 0); + auto d_unknown = c.Dim(s, 1); + auto d_0 = c.Dim(s, 2); + auto d_1 = c.Dim(s, 3); + + // Multiplying non-zero to unknown gives new unknown. + DimensionHandle out; + EXPECT_TRUE(c.Multiply(d_unknown, 2, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + + // Multiplying 0 to anything gives 0. + EXPECT_TRUE(c.Multiply(d_unknown, 0, &out).ok()); + EXPECT_EQ("0", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_unknown, d_0, &out).ok()); + EXPECT_EQ("0", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_0, d_unknown, &out).ok()); + EXPECT_EQ("0", c.DebugString(out)); + + // Multiplying 1 to anything gives the original. + // (unknown -> unknown) + EXPECT_TRUE(c.Multiply(d_unknown, 1, &out).ok()); + EXPECT_TRUE(SameHandle(d_unknown, out)); + EXPECT_TRUE(c.Multiply(d_unknown, d_1, &out).ok()); + EXPECT_TRUE(SameHandle(d_unknown, out)); + EXPECT_TRUE(c.Multiply(d_1, d_unknown, &out).ok()); + EXPECT_TRUE(SameHandle(d_unknown, out)); + // (known -> known) + EXPECT_TRUE(c.Multiply(d_6, 1, &out).ok()); + EXPECT_TRUE(SameHandle(d_6, out)); + EXPECT_TRUE(c.Multiply(d_6, d_1, &out).ok()); + EXPECT_TRUE(SameHandle(d_6, out)); + EXPECT_TRUE(c.Multiply(d_1, d_6, &out).ok()); + EXPECT_TRUE(SameHandle(d_6, out)); + + // Test multiplication. + EXPECT_TRUE(c.Multiply(d_6, 2, &out).ok()); + EXPECT_EQ("12", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_6, 6, &out).ok()); + EXPECT_EQ("36", c.DebugString(out)); + + // Test multiplication using dimension as second value. + EXPECT_TRUE(c.Multiply(d_6, c.MakeDim(2), &out).ok()); + EXPECT_EQ("12", c.DebugString(out)); + EXPECT_TRUE(c.Multiply(d_6, c.UnknownDim(), &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); +} + +TEST_F(ShapeInferenceTest, FullyDefined) { + NodeDef def; + std::vector empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}); + + // No rank or missing dimension information should return false. + EXPECT_FALSE(c.FullyDefined(c.UnknownShape())); + EXPECT_FALSE(c.FullyDefined(c.Matrix(c.MakeDim(1), c.UnknownDim()))); + + // Return true if all information exists. + EXPECT_TRUE(c.FullyDefined(c.Matrix(c.MakeDim(1), c.MakeDim(2)))); + EXPECT_TRUE(c.FullyDefined(c.Scalar())); +} + +TEST_F(ShapeInferenceTest, Min) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, + {}, {}); + + auto s = c.input(0); + auto d_1 = c.Dim(s, 0); + auto d_2 = c.Dim(s, 1); + auto d_unknown = c.Dim(s, 2); + auto d_0 = c.Dim(s, 3); + + // Minimum involving zero and unknown returns zero. 
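+  // (A zero operand short-circuits Min even when the other operand is
+  // unknown; otherwise an unknown operand makes the result unknown.)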
+ DimensionHandle out; + EXPECT_TRUE(c.Min(d_0, d_unknown, &out).ok()); + EXPECT_TRUE(SameHandle(d_0, out)); + EXPECT_TRUE(c.Min(d_unknown, d_0, &out).ok()); + EXPECT_TRUE(SameHandle(d_0, out)); + EXPECT_TRUE(c.Min(c.MakeDim(0ll), d_unknown, &out).ok()); + EXPECT_EQ("0", c.DebugString(out)); + EXPECT_TRUE(c.Min(d_unknown, 0ll, &out).ok()); + EXPECT_EQ("0", c.DebugString(out)); + + // Minimum involving unknowns and non-zeros gives new unknown. + EXPECT_TRUE(c.Min(d_unknown, d_unknown, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Min(d_unknown, 1, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Min(d_1, d_unknown, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + + // Minimum with constant second arg. + EXPECT_TRUE(c.Min(d_1, 1, &out).ok()); + EXPECT_TRUE(SameHandle(d_1, out)); + EXPECT_TRUE(c.Min(d_1, 3, &out).ok()); + EXPECT_TRUE(SameHandle(d_1, out)); + EXPECT_TRUE(c.Min(d_2, 1, &out).ok()); + EXPECT_EQ("1", c.DebugString(out)); + + // Minimum with two dimensions. + EXPECT_TRUE(c.Min(d_1, d_1, &out).ok()); + EXPECT_TRUE(SameHandle(d_1, out)); + EXPECT_TRUE(c.Min(d_1, d_2, &out).ok()); + EXPECT_TRUE(SameHandle(d_1, out)); + EXPECT_TRUE(c.Min(d_2, d_1, &out).ok()); + EXPECT_TRUE(SameHandle(d_1, out)); + EXPECT_TRUE(c.Min(d_2, d_2, &out).ok()); + EXPECT_TRUE(SameHandle(d_2, out)); +} + +TEST_F(ShapeInferenceTest, Max) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, + {}); + + auto s = c.input(0); + auto d_1 = c.Dim(s, 0); + auto d_2 = c.Dim(s, 1); + auto d_unknown = c.Dim(s, 2); + + // Maximum involving unknowns gives new unknown. + DimensionHandle out; + EXPECT_TRUE(c.Max(d_unknown, d_unknown, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Max(d_unknown, 1, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + EXPECT_TRUE(c.Max(d_1, d_unknown, &out).ok()); + EXPECT_EQ("?", c.DebugString(out)); + + // Maximum with constant second arg. + EXPECT_TRUE(c.Max(d_1, 1, &out).ok()); + EXPECT_TRUE(SameHandle(d_1, out)); + EXPECT_TRUE(c.Max(d_2, 1, &out).ok()); + EXPECT_TRUE(SameHandle(d_2, out)); + EXPECT_TRUE(c.Max(d_2, 3, &out).ok()); + EXPECT_EQ("3", c.DebugString(out)); + + // Maximum with two dimensions. 
+  EXPECT_TRUE(c.Max(d_1, d_1, &out).ok());
+  EXPECT_TRUE(SameHandle(d_1, out));
+  EXPECT_TRUE(c.Max(d_1, d_2, &out).ok());
+  EXPECT_TRUE(SameHandle(d_2, out));
+  EXPECT_TRUE(c.Max(d_2, d_1, &out).ok());
+  EXPECT_TRUE(SameHandle(d_2, out));
+  EXPECT_TRUE(c.Max(d_2, d_2, &out).ok());
+  EXPECT_TRUE(SameHandle(d_2, out));
+}
+
+void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
+  NodeDef def;
+  InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
+                     {});
+  auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
+    ShapeHandle s;
+    TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
+    return s;
+  };
+  auto get_shapes_and_types_from_context = [&](int idx) {
+    if (input_not_output) {
+      return c.input_handle_shapes_and_types(idx);
+    } else {
+      return c.output_handle_shapes_and_types(idx);
+    }
+  };
+  auto merge_shapes_and_types_to_context =
+      [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
+        if (input_not_output) {
+          return c.MergeInputHandleShapesAndTypes(idx, shapes_and_types);
+        } else {
+          return c.MergeOutputHandleShapesAndTypes(idx, shapes_and_types);
+        }
+      };
+
+  EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
+  EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
+
+  // First merge will take the input completely.
+  std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
+                              {c.UnknownShape(), DT_INVALID},
+                              {make_shape({4, 3, 2, 1}), DT_INT32}};
+  ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
+  ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
+  std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
+  ASSERT_EQ(3, v.size());
+  for (int i = 0; i < v.size(); ++i) {
+    EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
+    EXPECT_EQ(t[i].dtype, v[i].dtype);
+  }
+
+  // Merge that fails because wrong number of values passed.
+  // Fails, and no changes made.
+  ASSERT_FALSE(merge_shapes_and_types_to_context(
+      0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
+  v = *get_shapes_and_types_from_context(0);
+  ASSERT_EQ(3, v.size());
+  for (int i = 0; i < v.size(); ++i) {
+    EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
+    EXPECT_EQ(t[i].dtype, v[i].dtype);
+  }
+
+  // Only difference is in a mismatched shape. That is ignored,
+  // and there are no other changes, so nothing is done.
+  //
+  // TODO(cwhipkey): in mismatch cases, change Merge*HandleShapesAndTypes to
+  // return an error (separate error from 'refined' output)?
+  auto t2 = t;
+  t2[2].shape = make_shape({4, 3, 4, 1});
+  ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
+  v = *get_shapes_and_types_from_context(0);
+  ASSERT_EQ(3, v.size());
+  for (int i = 0; i < v.size(); ++i) {
+    EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
+    EXPECT_EQ(t[i].dtype, v[i].dtype);
+  }
+
+  // Only difference is in a mismatched dtype, but that cannot be
+  // updated unless original dtype is DT_INVALID.
+  t2 = t;
+  t2[2].dtype = DT_FLOAT;
+  ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
+  v = *get_shapes_and_types_from_context(0);
+  ASSERT_EQ(3, v.size());
+  for (int i = 0; i < v.size(); ++i) {
+    EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
+    EXPECT_EQ(t[i].dtype, v[i].dtype);
+  }
+
+  // Difference is mergeable (new shape).
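+  // (Slot 1 currently holds an unknown shape with DT_INVALID, so the fully
+  // known {1, 10} shape assigned below can be merged in without conflict.)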
+ t[1].shape = make_shape({1, 10}); + ASSERT_TRUE(merge_shapes_and_types_to_context(0, t)); + v = *get_shapes_and_types_from_context(0); + ASSERT_EQ(3, v.size()); + for (int i = 0; i < v.size(); ++i) { + EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; + EXPECT_EQ(t[i].dtype, v[i].dtype); + } + + // Difference is mergeable (new type). + t[1].dtype = DT_DOUBLE; + ASSERT_TRUE(merge_shapes_and_types_to_context(0, t)); + v = *get_shapes_and_types_from_context(0); + ASSERT_EQ(3, v.size()); + for (int i = 0; i < v.size(); ++i) { + EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; + EXPECT_EQ(t[i].dtype, v[i].dtype); + } + + // No difference. + ASSERT_FALSE(merge_shapes_and_types_to_context(0, t)); +} + +TEST_F(ShapeInferenceTest, MergeInputHandleShapesAndTypes) { + TestMergeHandles(true /* input_not_output */); +} + +TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) { + TestMergeHandles(false /* input_not_output */); +} + +void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {}, + {}); + auto make_shape = [&c](std::initializer_list dim_sizes) { + ShapeHandle s; + TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s)); + return s; + }; + auto get_shapes_and_types_from_context = [&](int idx) { + if (input_not_output) { + return c.input_handle_shapes_and_types(idx); + } else { + return c.output_handle_shapes_and_types(idx); + } + }; + auto relax_shapes_and_types_to_context = + [&](int idx, const std::vector& shapes_and_types) { + if (input_not_output) { + return c.RelaxInputHandleShapesAndMergeTypes(idx, shapes_and_types); + } else { + return c.RelaxOutputHandleShapesAndMergeTypes(idx, shapes_and_types); + } + }; + + EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr); + EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr); + + // First relax will take the input completely. + std::vector t{{make_shape({1, 2, 3}), DT_FLOAT}, + {c.UnknownShape(), DT_INVALID}, + {make_shape({4, 3, 2, 1}), DT_INT32}}; + ASSERT_TRUE(relax_shapes_and_types_to_context(0, t)); + ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr); + std::vector v = *get_shapes_and_types_from_context(0); + ASSERT_EQ(3, v.size()); + for (int i = 0; i < v.size(); ++i) { + EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; + EXPECT_EQ(t[i].dtype, v[i].dtype); + } + + // Relax that fails because wrong number of values passed. + // Fails, and no changes made. + ASSERT_FALSE(relax_shapes_and_types_to_context( + 0, std::vector{{make_shape({1, 2, 3}), DT_FLOAT}})); + v = *get_shapes_and_types_from_context(0); + ASSERT_EQ(3, v.size()); + for (int i = 0; i < v.size(); ++i) { + EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; + EXPECT_EQ(t[i].dtype, v[i].dtype); + } + + // Only difference is in a mismatched shape. This should replace + // the mismatched dimension with an UnknownDim. + auto t2 = t; + t2[2].shape = make_shape({4, 3, 4, 1}); + ASSERT_TRUE(relax_shapes_and_types_to_context(0, t2)); + v = *get_shapes_and_types_from_context(0); + EXPECT_EQ("[4,3,?,1]", c.DebugString(v[2].shape)); + for (int i = 0; i < v.size(); ++i) { + EXPECT_EQ(t[i].dtype, v[i].dtype); + } + + // Only difference is in a mismatched dtype, but that cannot be + // updated unless original dtype is DT_INVALID. 
+ t2 = t; + t2[2].dtype = DT_FLOAT; + ASSERT_FALSE(relax_shapes_and_types_to_context(0, t2)); + v = *get_shapes_and_types_from_context(0); + ASSERT_EQ(3, v.size()); + for (int i = 0; i < v.size(); ++i) { + EXPECT_EQ(t[i].dtype, v[i].dtype); + } + + // Difference is a new shape, which will result in a new UnknownShape. + t[1].shape = make_shape({1, 10}); + ASSERT_TRUE(relax_shapes_and_types_to_context(0, t)); + v = *get_shapes_and_types_from_context(0); + ASSERT_EQ(3, v.size()); + EXPECT_FALSE(SameHandle(t[1].shape, v[1].shape)); + EXPECT_EQ("?", c.DebugString(v[1].shape)); + for (int i = 0; i < v.size(); ++i) { + EXPECT_EQ(t[i].dtype, v[i].dtype); + } + + // Difference is relaxable (new type). + t[1].dtype = DT_DOUBLE; + ASSERT_TRUE(relax_shapes_and_types_to_context(0, t)); + v = *get_shapes_and_types_from_context(0); + EXPECT_EQ(t[1].dtype, v[1].dtype); +} + +TEST_F(ShapeInferenceTest, RelaxInputHandleShapesAndTypes) { + TestRelaxHandles(true /* input_not_output */); +} + +TEST_F(ShapeInferenceTest, RelaxOutputHandleShapesAndTypes) { + TestRelaxHandles(false /* input_not_output */); +} + +} // namespace shape_inference +} // namespace tensorflow diff --git a/shape_inference_testutil.cc b/shape_inference_testutil.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4765ab0b2c41a1b510364d755984b6ae68dd07a --- /dev/null +++ b/shape_inference_testutil.cc @@ -0,0 +1,273 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/framework/shape_inference_testutil.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { +namespace shape_inference { + +using errors::Unknown; + +Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, + const string& ins, + const string& expected_outs) { + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op.name, &op_reg_data)); + + std::vector ins_v = str_util::Split(ins, ';'); + std::unique_ptr new_node_def; + + InferenceContext::ShapeManager manager; + std::vector in_shapes; + for (const string& spec : ins_v) { + ShapeHandle shape; + TF_RETURN_IF_ERROR(MakeShapeFromString(&manager, spec, &shape)); + in_shapes.push_back(shape); + } + + std::vector>> + input_resource_handle_shapes_and_types; + for (const auto p : op.input_resource_handle_shapes_and_types) { + if (p == nullptr) { + input_resource_handle_shapes_and_types.push_back(nullptr); + } else { + std::unique_ptr> v( + new std::vector()); + for (const auto& shape_and_type : *p) { + ShapeHandle shape; + TF_RETURN_IF_ERROR( + MakeShapeFromString(&manager, shape_and_type.first, &shape)); + v->emplace_back(shape, shape_and_type.second); + } + input_resource_handle_shapes_and_types.emplace_back(v.release()); + } + } + shape_inference::InferenceContext c( + op.graph_def_version, &op.node_def, op_reg_data->op_def, in_shapes, + op.input_tensors, {}, std::move(input_resource_handle_shapes_and_types)); + TF_RETURN_IF_ERROR(c.construction_status()); + if (op_reg_data->shape_inference_fn == nullptr) { + return errors::InvalidArgument( + "No shape inference function exists for op '", op.name, + "', did you forget to define it?"); + } + + TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn)); + + const int num_outputs = c.num_outputs(); + + if (expected_outs == "e") { + return Unknown("Shape inference should have returned error"); + } + + // Verify the output shape. + std::vector expected_outs_v = str_util::Split(expected_outs, ';'); + if (num_outputs != expected_outs_v.size()) { + return Unknown("The expected output string lists the wrong number of ", + "outputs. It lists ", expected_outs_v.size(), + " but should list ", num_outputs); + } + for (int i = 0; i < num_outputs; ++i) { + StringPiece expected(expected_outs_v[i]); + shape_inference::ShapeHandle out = c.output(i); + + string err_prefix = strings::StrCat("Output ", i); + string err_suffix = + strings::StrCat(". Output shape was ", c.DebugString(out)); + + int in_index = -1; + for (int i = 0; i < c.num_inputs(); ++i) { + if (c.input(i).SameHandle(out)) { + in_index = i; + } + } + + if (expected.starts_with("in")) { + if (in_index == -1) { + return Unknown(err_prefix, + " should have matched an input shape by " + "handle, but matched no input shape. This means the ", + "shape function was expected to pass an input " + "ShapeHandle through for this output, but did not", + err_suffix); + } + auto v = str_util::Split(expected, '|'); + if (std::find(v.begin(), v.end(), strings::StrCat("in", in_index)) == + v.end()) { + return Unknown( + err_prefix, " matched input ", in_index, + " by handle, but should have matched one of (", expected, + ") instead. 
This means the shape function passed the ShapeHandle ", + "for input ", in_index, + " to the output, but should have passed a different input ", + "ShapeHandle through", err_suffix); + } + continue; + } + if (in_index != -1) { + return Unknown(err_prefix, " matched input ", in_index, + " by ShapeHandle, but was expected to not match an input ", + "shape by handle", err_suffix); + } + if (expected == "?") { + if (c.RankKnown(out)) { + return Unknown(err_prefix, " expected to be unknown", err_suffix); + } + continue; + } + + // Verify the dimensions. + CHECK(expected.starts_with("[") && expected.ends_with("]")) << expected; + expected.remove_prefix(1); + expected.remove_suffix(1); + + // Split expected as a dimension. + auto expected_dims = str_util::Split(expected, ','); + if (!c.RankKnown(out)) { + return Unknown(err_prefix, " expected rank ", expected_dims.size(), + " but was ?", err_suffix); + } + if (c.Rank(out) != expected_dims.size()) { + return Unknown(err_prefix, " expected rank ", expected_dims.size(), + " but was ", c.Rank(out), err_suffix); + } + for (int j = 0; j < expected_dims.size(); ++j) { + err_prefix = strings::StrCat("Output dim ", i, ",", j); + StringPiece expected_dim(expected_dims[j]); + DimensionHandle out_dim = c.Dim(out, j); + + std::pair in_dim_idx(-1, -1); + for (int i = 0; i < c.num_inputs(); ++i) { + auto in = c.input(i); + for (int j = 0; j < c.Rank(in); ++j) { + if (c.Dim(in, j).SameHandle(out_dim)) { + in_dim_idx = std::make_pair(i, j); + } + } + } + + if (expected_dim == "?") { + if (in_dim_idx.first != -1) { + return Unknown(err_prefix, + " expected to be an unknown but matched input d", + in_dim_idx.first, "_", in_dim_idx.second, + ". The shape function passed through ", + "a DimensionHandle from an input instead of making ", + "a new unknown dimension", err_suffix); + } else if (c.ValueKnown(out_dim)) { + return Unknown(err_prefix, " expected to be unknown but was ", + c.Value(out_dim), err_suffix); + } + } else if (expected_dim.starts_with("d")) { + // Compare the dimension values. + auto v = str_util::Split(expected_dim, '|'); + if (in_dim_idx.first == -1) { + return Unknown( + err_prefix, " was expected to match the dimension of an input, ", + "but did not match any input dimension. The shape ", + "function was expected to pass through a ", + "DimensionHandle for an input, but did not", err_suffix); + } + if (std::find(v.begin(), v.end(), + strings::StrCat("d", in_dim_idx.first, "_", + in_dim_idx.second)) == v.end()) { + return Unknown(err_prefix, " matched input d", in_dim_idx.first, "_", + in_dim_idx.second, + ", but should have matched one of (", expected_dim, + "). The shape function passed through " + "the DimensionHandle for an input, but ", + "was expected to pass a different one", err_suffix); + } + } else { + // Parse it as a value. + int64 value = -1; + if (!strings::safe_strto64(expected_dim, &value)) { + return Unknown(err_prefix, ": the expected dimension value '", + expected_dim, "' failed to parse as int64", + err_suffix); + } + if (in_dim_idx.first != -1) { + return Unknown( // + err_prefix, " expected to be ", value, " but matched input d", + in_dim_idx.first, "_", in_dim_idx.second, + ". The shape function was not expected to pass a DimensionHandle " + "from the input to the output, but did. 
Note that even if the " + "passed through output has the same dimension value as the " + "expected value, this is considered a failure for the test; " + "switch to using d#_# syntax if passing through the " + "DimensionHandle should be the expected behavior", + err_suffix); + } else if (value != c.Value(out_dim)) { + return Unknown(err_prefix, " expected to be ", value, " but was ", + c.DebugString(out_dim), err_suffix); + } + } + } + } + return Status::OK(); +} + +// static +Status ShapeInferenceTestutil::MakeShapeFromString( + InferenceContext::ShapeManager* manager, const string& spec, + ShapeHandle* output) { + if (spec == "?") { + *output = manager->UnknownShape(); + return Status::OK(); + } + + std::vector dims; + strings::Scanner scanner(spec); + scanner.OneLiteral("["); + while (scanner.Peek() != ']') { + if (scanner.Peek() == '?') { + scanner.OneLiteral("?"); + dims.push_back(manager->MakeDim(InferenceContext::kUnknownDim)); + } else { + scanner.RestartCapture().Many(strings::Scanner::DIGIT); + StringPiece match; + int64 dim_size = 0; + + if (!scanner.GetResult(nullptr, &match) || + !strings::safe_strto64(match, &dim_size)) { + return errors::InvalidArgument("Could not parse number in ", spec); + } + + dims.push_back(manager->MakeDim(dim_size)); + } + + if (scanner.Peek() == ',') { + scanner.OneLiteral(","); + } else if (scanner.Peek() != ']') { + return errors::InvalidArgument( + "Invalid input spec (] not found in dim shape): ", spec); + } + } + if (!scanner.OneLiteral("]").Eos().GetResult()) { + return errors::InvalidArgument("Malformed shape spec: did not end in ']'."); + } + *output = manager->MakeShape(dims); + + return Status::OK(); +} + +} // namespace shape_inference +} // namespace tensorflow diff --git a/shape_inference_testutil.h b/shape_inference_testutil.h new file mode 100644 index 0000000000000000000000000000000000000000..fbfd24538bc7a5b1f3ee3805d4a803a0e7239fca --- /dev/null +++ b/shape_inference_testutil.h @@ -0,0 +1,101 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ + +#include +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/version.h" + +// Contains utilities for writing tests for shape inference functions. 
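+//
+// Tests typically construct a ShapeInferenceTestOp and check the inferred
+// output shapes with the INFER_OK / INFER_ERROR macros declared below.
+// A minimal usage sketch (using a hypothetical op "MyOp" whose shape function
+// is expected to pass input 0 through and produce a second, unknown output):
+//
+//   ShapeInferenceTestOp op("MyOp");
+//   INFER_OK(op, "[2,3];[?]", "in0;?");
+//   INFER_ERROR("must be rank 2", op, "[1];[?]");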
+
+namespace tensorflow {
+
+class Tensor;
+
+struct ShapeInferenceTestOp {
+  typedef std::pair<string, DataType> ShapeAndType;
+  explicit ShapeInferenceTestOp(StringPiece name) : name(name.ToString()) {}
+  string name;
+  NodeDef node_def;
+  std::vector<const Tensor*> input_tensors;
+  std::vector<std::vector<ShapeAndType>*>
+      input_resource_handle_shapes_and_types;
+  int graph_def_version = TF_GRAPH_DEF_VERSION;
+};
+
+namespace shape_inference {
+
+class ShapeInferenceTestutil {
+ public:
+  // Run shape inference for <op.name>, given inputs specified by <ins>,
+  // and returns an error if the inferred shape does not match expected_outs.
+  //
+  // <ins> is a semicolon separated list of shapes. Each shape is formatted
+  // according to the formatting per
+  // shape_inference::InferenceContext::InferenceContext.
+  //
+  // <expected_outs> is a semicolon separated list of shapes. Each shape is
+  // formatted as one of:
+  // * ? - an unknown shape, but not matching an input shape
+  // * in0|in2|... - output shape must be the same as one of these input shapes.
+  // * [1,?,d0_0|d0_1] - output shape is of known rank, with comma-separated
+  //      dimension values.
+  //   Each dimension value is one of:
+  //   * a constant, which means that constant not equal to a specific input
+  //   * ?, which means an unknown dim size not equal to a specific input
+  //   * d0_0|d1_2, indicating that the dim size must be equal to one of
+  //       the given input dimensions; the first number is the input # and
+  //       the second is which dimension in that input it corresponds to.
+  // <expected_outs> can be "e"; this is used to indicate that shape inference
+  // should have failed.
+  static Status InferShapes(ShapeInferenceTestOp op, const string& ins,
+                            const string& expected_outs);
+
+ private:
+  ShapeInferenceTestutil() {}
+
+  // Makes a shape out of 'spec'.
+  static Status MakeShapeFromString(InferenceContext::ShapeManager* manager,
+                                    const string& spec, ShapeHandle* output);
+};
+
+}  // namespace shape_inference
+
+#define INFER_OK(op, i, o)                                                    \
+  EXPECT_EQ(                                                                  \
+      "", ::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \
+              op, i, o)                                                       \
+              .error_message())
+#define INFER_ERROR(error_substring, op, i)                                   \
+  {                                                                           \
+    string error_message =                                                    \
+        ::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes(   \
+            op, i, "e")                                                       \
+            .error_message();                                                 \
+    const string& substring = error_substring;                                \
+    EXPECT_NE("", error_message);                                             \
+    EXPECT_TRUE(StringPiece(error_message).contains(substring))               \
+        << "Expected to see '" << substring << "' in '" << error_message     \
+        << "'";                                                               \
+  }
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
diff --git a/shape_inference_testutil_test.cc b/shape_inference_testutil_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..20a6807064bea96f41cbd6035327d7a6db2f73b8
--- /dev/null
+++ b/shape_inference_testutil_test.cc
@@ -0,0 +1,171 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/core/framework/shape_inference_testutil.h" + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace shape_inference { + +namespace { + +#define EXPECT_CONTAINS(str, substr) \ + do { \ + string s = (str); \ + EXPECT_TRUE(StringPiece(s).contains(substr)) << "String: " << s; \ + } while (false) + +static OpShapeInferenceFn* global_fn_ptr = nullptr; +REGISTER_OP("OpOneOut") + .Input("inputs: N * T") + .Output("o1: T") + .Attr("N: int >= 1") + .Attr("T: numbertype") + .SetShapeFn([](InferenceContext* c) { return (*global_fn_ptr)(c); }); +REGISTER_OP("OpTwoOut") + .Input("inputs: N * T") + .Output("o1: T") + .Output("o2: T") + .Attr("N: int >= 1") + .Attr("T: numbertype") + .SetShapeFn([](InferenceContext* c) { return (*global_fn_ptr)(c); }); + +string RunInferShapes(const string& op_name, const string& ins, + const string& expected_outs, OpShapeInferenceFn fn) { + ShapeInferenceTestOp op(op_name); + const int num_inputs = 1 + std::count(ins.begin(), ins.end(), ';'); + std::vector src_list; + src_list.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) src_list.emplace_back("a", 0, DT_FLOAT); + NodeDef node_def; + TF_CHECK_OK(NodeDefBuilder("dummy", op_name) + .Input(src_list) + .Attr("N", num_inputs) + .Finalize(&op.node_def)); + global_fn_ptr = &fn; + return ShapeInferenceTestutil::InferShapes(op, ins, expected_outs) + .error_message(); +} + +} // namespace + +TEST(ShapeInferenceTestutilTest, Failures) { + auto fn_copy_input_0 = [](InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }; + auto fn_copy_input_2 = [](InferenceContext* c) { + c->set_output(0, c->input(2)); + return Status::OK(); + }; + auto fn_output_unknown_shapes = [](InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->UnknownShape()); + } + return Status::OK(); + }; + auto fn_output_1_2 = [](InferenceContext* c) { + c->set_output(0, c->Matrix(1, 2)); + return Status::OK(); + }; + auto fn_output_u_2 = [](InferenceContext* c) { + c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2)); + return Status::OK(); + }; + const string& op = "OpOneOut"; + + EXPECT_EQ("Shape inference should have returned error", + RunInferShapes(op, "[1];[2];[1]", "e", fn_copy_input_0)); + EXPECT_CONTAINS(RunInferShapes(op, "[1];[2];[1]", "[1];[2]", fn_copy_input_0), + "wrong number of outputs"); + auto error_message = ShapeInferenceTestutil::InferShapes( + ShapeInferenceTestOp("NoSuchOp"), "", "") + .error_message(); + EXPECT_TRUE(StringPiece(error_message) + .starts_with("Op type not registered 'NoSuchOp'")); + + // Wrong shape error messages. + EXPECT_CONTAINS(RunInferShapes(op, "[1];[2];[1]", "?", fn_copy_input_0), + "expected to not match"); + EXPECT_CONTAINS(RunInferShapes(op, "[1];[2];[1]", "in2", fn_copy_input_0), + "should have matched one of (in2)"); + EXPECT_CONTAINS(RunInferShapes(op, "[1];[2];[1]", "in1|in2", fn_copy_input_0), + "should have matched one of (in1|in2)"); + EXPECT_CONTAINS(RunInferShapes(op, "[1];[2];[1]", "[1]", fn_copy_input_2), + "but was expected to not match"); + EXPECT_CONTAINS(RunInferShapes(op, "[1];[2];[1]", "in0|in1", fn_output_1_2), + "Output 0 should have matched an input shape"); + EXPECT_EQ("Output 0 expected to be unknown. 
Output shape was [1,2]", + RunInferShapes(op, "[1];[2];[1]", "?", fn_output_1_2)); + EXPECT_EQ("Output 0 expected rank 3 but was 2. Output shape was [1,2]", + RunInferShapes(op, "[1];[2];[1]", "[1,2,3]", fn_output_1_2)); + EXPECT_EQ( + "Output 0 expected rank 2 but was ?. Output shape was ?", + RunInferShapes(op, "[1];[2];[1]", "[1,2]", fn_output_unknown_shapes)); + + // Wrong shape error messages on the second output. + EXPECT_EQ("Output 1 expected rank 3 but was ?. Output shape was ?", + RunInferShapes("OpTwoOut", "[1];[2];[1]", "?;[1,2,3]", + fn_output_unknown_shapes)); + + // Wrong dimension error messages. + EXPECT_EQ("Output dim 0,1 expected to be 3 but was 2. Output shape was [1,2]", + RunInferShapes(op, "[1];[2];[1]", "[1,3]", fn_output_1_2)); + EXPECT_EQ("Output dim 0,0 expected to be 2 but was 1. Output shape was [1,2]", + RunInferShapes(op, "[1];[2];[1]", "[2,2]", fn_output_1_2)); + EXPECT_EQ( + "Output dim 0,0 expected to be unknown but was 1. Output shape was [1,2]", + RunInferShapes(op, "[1];[2];[1]", "[?,2]", fn_output_1_2)); + EXPECT_EQ("Output dim 0,1 expected to be 1 but was 2. Output shape was [?,2]", + RunInferShapes(op, "[1];[2];[1]", "[?,1]", fn_output_u_2)); + EXPECT_EQ("Output dim 0,0 expected to be 1 but was ?. Output shape was [?,2]", + RunInferShapes(op, "[0,1,?];[2];[1]", "[1,2]", fn_output_u_2)); + auto fn = [](InferenceContext* c) { + c->set_output(0, c->MakeShape({c->Dim(c->input(0), 1), c->MakeDim(2), + c->UnknownDim(), c->Dim(c->input(2), 0)})); + return Status::OK(); + }; + const string ins = "[0,1,?];[2];[1]"; + EXPECT_CONTAINS(RunInferShapes(op, ins, "[?,2,?,d2_0]", fn), + "Output dim 0,0 expected to be an unknown"); + EXPECT_CONTAINS(RunInferShapes(op, ins, "[0,2,?,d2_0]", fn), + "Output dim 0,0 expected to be 0 but matched input d0_1."); + EXPECT_CONTAINS( + RunInferShapes(op, ins, "[d0_0,2,?,d2_0]", fn), + "dim 0,0 matched input d0_1, but should have matched one of (d0_0)."); + EXPECT_CONTAINS(RunInferShapes(op, ins, "[x,2,?,d2_0]", fn), + "Output dim 0,0: the expected dimension value 'x' failed to " + "parse as int64."); + EXPECT_CONTAINS(RunInferShapes(op, ins, "[d0_0|d0_2,2,?,d2_0]", fn), + "dim 0,0 matched input d0_1, but should have matched one of " + "(d0_0|d0_2)."); + EXPECT_CONTAINS(RunInferShapes(op, ins, "[d0_1,?,?,d0_0|d2_0]", fn), + ("Output dim 0,1 expected to be unknown but was 2. " + "Output shape was [1,2,?,1]")); + EXPECT_EQ( + "Output dim 0,2 expected to be 8 but was ?. Output shape was [1,2,?,1]", + RunInferShapes(op, ins, "[d0_1,2,8,d0_0|d2_0]", fn)); + EXPECT_CONTAINS(RunInferShapes(op, ins, "[d0_1,2,d0_1|d2_0,d0_0|d2_0]", fn), + "expected to match"); + EXPECT_EQ("", // OK, no error. + RunInferShapes(op, ins, "[d0_1,2,?,d0_0|d2_0]", fn)); +} + +} // namespace shape_inference +} // namespace tensorflow diff --git a/step_stats.proto b/step_stats.proto new file mode 100644 index 0000000000000000000000000000000000000000..99dee2257e0a4ccab4098f5ee49feda9ed21d2cf --- /dev/null +++ b/step_stats.proto @@ -0,0 +1,78 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "StepStatsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/allocation_description.proto"; +import "tensorflow/core/framework/tensor_description.proto"; + +// An allocation/de-allocation operation performed by the allocator. +message AllocationRecord { + // The timestamp of the operation. 
+ int64 alloc_micros = 1; + // Number of bytes allocated, or de-allocated if negative. + int64 alloc_bytes = 2; +} + +message AllocatorMemoryUsed { + string allocator_name = 1; + // These are per-node allocator memory stats. + int64 total_bytes = 2; + int64 peak_bytes = 3; + // The bytes that are not deallocated. + int64 live_bytes = 4; + // The allocation and deallocation timeline. + repeated AllocationRecord allocation_records = 6; + + // These are snapshots of the overall allocator memory stats. + // The number of live bytes currently allocated by the allocator. + int64 allocator_bytes_in_use = 5; +} + +// Output sizes recorded for a single execution of a graph node. +message NodeOutput { + int32 slot = 1; + TensorDescription tensor_description = 3; +}; + +// For memory tracking. +message MemoryStats { + int64 host_temp_memory_size = 1; + int64 device_temp_memory_size = 2; + int64 host_persistent_memory_size = 3; + int64 device_persistent_memory_size = 4; + repeated int64 host_persistent_tensor_alloc_ids = 5; + repeated int64 device_persistent_tensor_alloc_ids = 6; +} + +// Time/size stats recorded for a single execution of a graph node. +message NodeExecStats { + // TODO(tucker): Use some more compact form of node identity than + // the full string name. Either all processes should agree on a + // global id (cost_id?) for each node, or we should use a hash of + // the name. + string node_name = 1; + int64 all_start_micros = 2; + int64 op_start_rel_micros = 3; + int64 op_end_rel_micros = 4; + int64 all_end_rel_micros = 5; + repeated AllocatorMemoryUsed memory = 6; + repeated NodeOutput output = 7; + string timeline_label = 8; + int64 scheduled_micros = 9; + uint32 thread_id = 10; + repeated AllocationDescription referenced_tensor = 11; + MemoryStats memory_stats = 12; +}; + +message DeviceStepStats { + string device = 1; + repeated NodeExecStats node_stats = 2; +} + +message StepStats { + repeated DeviceStepStats dev_stats = 1; +}; diff --git a/summary.proto b/summary.proto new file mode 100644 index 0000000000000000000000000000000000000000..55879f87831eb968ee900e01697fbb99ba4cfe99 --- /dev/null +++ b/summary.proto @@ -0,0 +1,124 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "SummaryProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/tensor.proto"; + +// Metadata associated with a series of Summary data +message SummaryDescription { + // Hint on how plugins should process the data in this series. + // Supported values include "scalar", "histogram", "image", "audio" + string type_hint = 1; +} + +// Serialization format for histogram module in +// core/lib/histogram/histogram.h +message HistogramProto { + double min = 1; + double max = 2; + double num = 3; + double sum = 4; + double sum_squares = 5; + + // Parallel arrays encoding the bucket boundaries and the bucket values. + // bucket(i) is the count for the bucket i. The range for + // a bucket is: + // i == 0: -DBL_MAX .. bucket_limit(0) + // i != 0: bucket_limit(i-1) .. bucket_limit(i) + repeated double bucket_limit = 6 [packed = true]; + repeated double bucket = 7 [packed = true]; +}; + +// A SummaryMetadata encapsulates information on which plugins are able to make +// use of a certain summary value. +message SummaryMetadata { + message PluginData { + // The name of the plugin this data pertains to. + string plugin_name = 1; + + // The content to store for the plugin. 
The best practice is for this to be + // a binary serialized protocol buffer. + bytes content = 2; + } + + // Data that associates a summary with a certain plugin. + PluginData plugin_data = 1; + + // Display name for viewing in TensorBoard. + string display_name = 2; + + // Longform readable description of the summary sequence. Markdown supported. + string summary_description = 3; +}; + +// A Summary is a set of named values to be displayed by the +// visualizer. +// +// Summaries are produced regularly during training, as controlled by +// the "summary_interval_secs" attribute of the training operation. +// Summaries are also produced at the end of an evaluation. +message Summary { + message Image { + // Dimensions of the image. + int32 height = 1; + int32 width = 2; + // Valid colorspace values are + // 1 - grayscale + // 2 - grayscale + alpha + // 3 - RGB + // 4 - RGBA + // 5 - DIGITAL_YUV + // 6 - BGRA + int32 colorspace = 3; + // Image data in encoded format. All image formats supported by + // image_codec::CoderUtil can be stored here. + bytes encoded_image_string = 4; + } + + message Audio { + // Sample rate of the audio in Hz. + float sample_rate = 1; + // Number of channels of audio. + int64 num_channels = 2; + // Length of the audio in frames (samples per channel). + int64 length_frames = 3; + // Encoded audio data and its associated RFC 2045 content type (e.g. + // "audio/wav"). + bytes encoded_audio_string = 4; + string content_type = 5; + } + + message Value { + // This field is deprecated and will not be set. + string node_name = 7; + + // Tag name for the data. Used by TensorBoard plugins to organize data. Tags + // are often organized by scope (which contains slashes to convey + // hierarchy). For example: foo/bar/0 + string tag = 1; + + // Contains metadata on the summary value such as which plugins may use it. + // Take note that many summary values may lack a metadata field. This is + // because the FileWriter only keeps a metadata object on the first summary + // value with a certain tag for each tag. TensorBoard then remembers which + // tags are associated with which plugins. This saves space. + SummaryMetadata metadata = 9; + + // Value associated with the tag. + oneof value { + float simple_value = 2; + bytes obsolete_old_style_histogram = 3; + Image image = 4; + HistogramProto histo = 5; + Audio audio = 6; + TensorProto tensor = 8; + } + } + + // Set of values for the summary. + repeated Value value = 1; +} diff --git a/tensor.cc b/tensor.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f08cdc1d7c130bd351de7b5f7574ea199977804 --- /dev/null +++ b/tensor.cc @@ -0,0 +1,1073 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation notes: +// +// Tensor.cc uses a few templated classes and structs to facilitate +// implementation of the Tensor class. 
+// +// * Buffer: provides the implementation for a typed array T[n]. +// The array is allocated by the given allocator. It runs T's +// default constructors and destructors when T is not a simple type +// (e.g., string.), and skips them otherwise. +// +// * Helper: provides various routines given type T. The routines +// includes running the constructor and destructor of T[], encoding +// an decoding T[] into/from a Cord, etc. + +#include "tensorflow/core/framework/tensor.h" + +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/log_memory.h" +#include "tensorflow/core/framework/resource_handle.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_description.pb.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/tensor_coding.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/platform/variant_coding.h" + +namespace tensorflow { + +// Allow Tensors to be stored inside Variants with automatic +// encoding/decoding when those Variants are themselves being decoded +// in a Tensor's FromProto. +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor"); + +namespace { + +// An un-templated base class for Buffer. +class BufferBase : public TensorBuffer { + public: + explicit BufferBase(Allocator* alloc) : alloc_(alloc) {} + + TensorBuffer* root_buffer() override { return this; } + void FillAllocationDescription(AllocationDescription* proto) const override { + void* data_ptr = data(); + int64 rb = size(); + proto->set_requested_bytes(rb); + proto->set_allocator_name(alloc_->Name()); + proto->set_ptr(reinterpret_cast(data_ptr)); + if (alloc_->TracksAllocationSizes()) { + int64 ab = alloc_->AllocatedSize(data_ptr); + proto->set_allocated_bytes(ab); + int64 id = alloc_->AllocationId(data_ptr); + if (id > 0) { + proto->set_allocation_id(id); + } + if (RefCountIsOne()) { + proto->set_has_single_reference(true); + } + } + } + + protected: + void RecordDeallocation() { + LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()), + alloc_->Name()); + } + + Allocator* const alloc_; +}; + +// Typed ref-counted buffer: T[n]. +template +class Buffer : public BufferBase { + public: + Buffer(Allocator* a, int64 n); + Buffer(Allocator* a, int64 n, const AllocationAttributes& allocation_attr); + + void* data() const override { return data_; } + size_t size() const override { return sizeof(T) * elem_; } + + private: + T* data_; + int64 elem_; + + ~Buffer() override; + + TF_DISALLOW_COPY_AND_ASSIGN(Buffer); +}; + +void LogUnexpectedSize(int64 actual, int64 expected) { + LOG(ERROR) << "Input size was " << actual << " and expected " << expected; +} + +// A set of helper functions depending on T. 
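+// Each Helper provides Encode() (copy the T[n] held by a TensorBuffer into a
+// string/Cord destination), Decode() (rebuild a Buffer from such serialized
+// bytes, returning nullptr on failure), and TotalBytes() (estimate the memory
+// held by the buffer).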
+template +struct Helper { + // By default, we assume T is a simple type (float, int32, etc.) + static_assert(is_simple_type::value, "T is not a simple type."); + typedef protobuf::RepeatedField RepeatedFieldType; + + // Encoder of simple type T to a string. We do a copy. + template + static void Encode(TensorBuffer* in, int64 n, Destination* out) { + DCHECK_EQ(in->size(), sizeof(T) * n); + port::AssignRefCounted(StringPiece(in->base(), in->size()), in, + out); + } + + // Decoder of simple type T. Copy the bytes from "in" into the + // tensor buffer. + template + static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) { + if (in.size() != sizeof(T) * n) { + LogUnexpectedSize(in.size(), sizeof(T) * n); + return nullptr; + } + Buffer* buf = new Buffer(a, n); + char* data = buf->template base(); + if (data == nullptr) { + buf->Unref(); + return nullptr; + } + port::CopyToArray(in, data); + return buf; + } + + // Memory usage. + static int64 TotalBytes(TensorBuffer* in, int64 n) { + DCHECK_EQ(in->size(), sizeof(T) * n); + return in->size(); + } +}; + +// Helper specialization for string (the only non-simple type we +// support). +template <> +struct Helper { + // Proto message uses RepeatedFieldType to hold repeated T. + typedef protobuf::RepeatedPtrField RepeatedFieldType; + + // Encodes "n" elements of type string stored in "in" into Cord + // "out", which is usually the TensorProto::tensor_content. + template + static void Encode(TensorBuffer* in, int64 n, Destination* out) { + port::EncodeStringList(in->base(), n, out); + } + + // Decodes "n" elements of type string from "in" and constructs a + // buffer out of it. Returns nullptr if the decoding fails. "in" is + // usually the TensorProto::tensor_content. + template + static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) { + Buffer* buf = new Buffer(a, n); + string* strings = buf->template base(); + if (strings == nullptr || !port::DecodeStringList(in, strings, n)) { + buf->Unref(); + return nullptr; + } + return buf; + } + + // Returns the estimated memory usage of "n" elements of type T + // stored in buffer "in". + static int64 TotalBytes(TensorBuffer* in, int n) { + int64 tot = in->size(); + DCHECK_EQ(tot, sizeof(string) * n); + const string* p = in->base(); + for (int i = 0; i < n; ++i, ++p) tot += p->size(); + return tot; + } +}; + +template <> +struct Helper { + // Proto message uses RepeatedFieldType to hold repeated T. + typedef protobuf::RepeatedPtrField RepeatedFieldType; + + // Encodes "n" elements of type ResourceHandle stored in "in" into destination + // "out", which is usually the TensorProto::tensor_content. + template + static void Encode(TensorBuffer* in, int64 n, Destination* out) { + port::EncodeResourceHandleList(in->base(), n, out); + } + + // Decodes "n" elements of type string from "in" and constructs a + // buffer out of it. Returns nullptr if the decoding fails. "in" is + // usually the TensorProto::tensor_content. + template + static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) { + auto* buf = new Buffer(a, n); + ResourceHandle* ps = buf->template base(); + if (ps == nullptr || !port::DecodeResourceHandleList(in, ps, n)) { + buf->Unref(); + return nullptr; + } + return buf; + } + + // Returns the estimated memory usage of "n" elements of type T + // stored in buffer "in". 
+ static int64 TotalBytes(TensorBuffer* in, int n) { + return n * sizeof(ResourceHandle); + } +}; + +template <> +struct Helper { + // Encodes "n" elements of type Variant stored in "in" into destination + // "out", which is usually the TensorProto::tensor_content. + template + static void Encode(TensorBuffer* in, int64 n, Destination* out) { + port::EncodeVariantList(in->base(), n, out); + } + + // Decodes "n" elements of type Variant from "in" and constructs a + // buffer out of it. Returns nullptr if the decoding fails. "in" is + // usually the TensorProto::tensor_content. + template + static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) { + auto* buf = new Buffer(a, n); + Variant* ps = buf->template base(); + if (ps == nullptr || !port::DecodeVariantList(in, ps, n)) { + buf->Unref(); + return nullptr; + } + return buf; + } + + // Returns the estimated memory usage of "n" elements of type T + // stored in buffer "in". + static int64 TotalBytes(TensorBuffer* in, int n) { + return n * sizeof(Variant); + } +}; + +template +struct ProtoHelper {}; + +// For a C++ type "T" (float, double, int32, etc.), the repeated field +// "N"_val (float_val, int_val, label_val, etc.) of type "F" (float, +// int32, string, etc) in the TensorProto is used for serializing the +// tensor of type "T". +#define PROTO_TRAITS(T, F, N) \ + template <> \ + struct ProtoHelper { \ + typedef Helper::RepeatedFieldType FieldType; \ + static FieldType::const_iterator Begin(const TensorProto& proto) { \ + return proto.N##_val().begin(); \ + } \ + static size_t NumElements(const TensorProto& proto) { \ + return proto.N##_val().size(); \ + } \ + static void Fill(const T* data, size_t n, TensorProto* proto) { \ + typename ProtoHelper::FieldType copy(data, data + n); \ + proto->mutable_##N##_val()->Swap(©); \ + } \ + }; +PROTO_TRAITS(float, float, float); +PROTO_TRAITS(double, double, double); +PROTO_TRAITS(int32, int32, int); +PROTO_TRAITS(uint8, int32, int); +PROTO_TRAITS(uint16, int32, int); +PROTO_TRAITS(uint32, uint32, uint32); +PROTO_TRAITS(int16, int32, int); +PROTO_TRAITS(int8, int32, int); +PROTO_TRAITS(bool, bool, bool); +PROTO_TRAITS(string, string, string); +PROTO_TRAITS(qint8, int32, int); +PROTO_TRAITS(quint8, int32, int); +PROTO_TRAITS(qint16, int32, int); +PROTO_TRAITS(quint16, int32, int); +#undef PROTO_TRAITS + +template <> +struct ProtoHelper { + static const int64* Begin(const TensorProto& proto) { + return reinterpret_cast(proto.int64_val().begin()); + } + static size_t NumElements(const TensorProto& proto) { + return proto.int64_val().size(); + } + static void Fill(const int64* data, size_t n, TensorProto* proto) { + protobuf::RepeatedField copy(data, data + n); + proto->mutable_int64_val()->Swap(©); + } +}; + +template <> +struct ProtoHelper { + static const uint64* Begin(const TensorProto& proto) { + return reinterpret_cast(proto.uint64_val().begin()); + } + static size_t NumElements(const TensorProto& proto) { + return proto.uint64_val().size(); + } + static void Fill(const uint64* data, size_t n, TensorProto* proto) { + protobuf::RepeatedField copy(data, data + n); + proto->mutable_uint64_val()->Swap(©); + } +}; + +template <> +struct ProtoHelper { + static protobuf::RepeatedPtrField::const_iterator Begin( + const TensorProto& proto) { + return proto.resource_handle_val().begin(); + } + static size_t NumElements(const TensorProto& proto) { + return proto.resource_handle_val().size(); + } + static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) { + auto* handles = 
proto->mutable_resource_handle_val(); + handles->Clear(); + for (size_t i = 0; i < n; i++) { + data[i].AsProto(handles->Add()); + } + } +}; + +template <> +struct ProtoHelper { + static protobuf::RepeatedPtrField::const_iterator + Begin(const TensorProto& proto) { + return proto.variant_val().begin(); + } + static size_t NumElements(const TensorProto& proto) { + return proto.variant_val().size(); + } + static void Fill(const Variant* data, size_t n, TensorProto* proto) { + auto* variant_values = proto->mutable_variant_val(); + variant_values->Clear(); + for (size_t i = 0; i < n; ++i) { + VariantTensorData tmp; + data[i].Encode(&tmp); + tmp.ToProto(variant_values->Add()); + } + } +}; + +template <> +struct ProtoHelper { + typedef Helper::RepeatedFieldType FieldType; + static const complex64* Begin(const TensorProto& proto) { + return reinterpret_cast(proto.scomplex_val().data()); + } + static size_t NumElements(const TensorProto& proto) { + return proto.scomplex_val().size() / 2; + } + static void Fill(const complex64* data, size_t n, TensorProto* proto) { + const float* p = reinterpret_cast(data); + FieldType copy(p, p + n * 2); + proto->mutable_scomplex_val()->Swap(©); + } +}; + +template <> +struct ProtoHelper { + typedef Helper::RepeatedFieldType FieldType; + static const complex128* Begin(const TensorProto& proto) { + return reinterpret_cast(proto.dcomplex_val().data()); + } + static size_t NumElements(const TensorProto& proto) { + return proto.dcomplex_val().size() / 2; + } + static void Fill(const complex128* data, size_t n, TensorProto* proto) { + const double* p = reinterpret_cast(data); + FieldType copy(p, p + n * 2); + proto->mutable_dcomplex_val()->Swap(©); + } +}; + +template <> +struct ProtoHelper { + typedef Helper::RepeatedFieldType FieldType; + static const qint32* Begin(const TensorProto& proto) { + return reinterpret_cast(proto.int_val().data()); + } + static size_t NumElements(const TensorProto& proto) { + return proto.int_val().size(); + } + static void Fill(const qint32* data, size_t n, TensorProto* proto) { + const int32* p = reinterpret_cast(data); + FieldType copy(p, p + n); + proto->mutable_int_val()->Swap(©); + } +}; + +template <> +struct ProtoHelper { + static void Fill(const bfloat16* data, size_t n, TensorProto* proto) { + proto->mutable_half_val()->Reserve(n); + for (size_t i = 0; i < n; ++i) { + proto->mutable_half_val()->AddAlreadyReserved(data[i].value); + } + } +}; + +template <> +struct ProtoHelper { + static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) { + proto->mutable_half_val()->Reserve(n); + for (size_t i = 0; i < n; ++i) { + proto->mutable_half_val()->AddAlreadyReserved(data[i].x); + } + } +}; + +template +Buffer::Buffer(Allocator* a, int64 n) + : BufferBase(a), data_(a->Allocate(n)), elem_(n) {} + +template +Buffer::Buffer(Allocator* a, int64 n, + const AllocationAttributes& allocation_attr) + : BufferBase(a), data_(a->Allocate(n, allocation_attr)), elem_(n) {} + +template +Buffer::~Buffer() { + if (data_) { + if (LogMemory::IsEnabled()) { + RecordDeallocation(); + } + alloc_->Deallocate(data_, elem_); + } +} + +// Allocates a T[n] buffer. Fills in the buffer with repeated values +// in "in". If "in" has less values than "n", fills the rest of T[n] +// with the last value. If "in" has no values, fills T[n] with the +// default value for T. +// +// This routine is using the typed fields (float_val, etc.) in the +// tensor proto as opposed to the untyped binary representation +// (tensor_content). 
This is used when we expect the TensorProto is +// used by a client program which may not know how to encode a tensor +// in the compact binary representation. +template +TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) { + CHECK_GT(n, 0); + Buffer* buf = new Buffer(a, n); + T* data = buf->template base(); + if (data == nullptr) { + buf->Unref(); + return nullptr; + } + + const int64 in_n = ProtoHelper::NumElements(in); + if (in_n <= 0) { + std::fill_n(data, n, T()); + } else { + auto begin = ProtoHelper::Begin(in); + if (n <= in_n) { + std::copy_n(begin, n, data); + } else { + std::copy_n(begin, in_n, data); + const T& last = *(data + in_n - 1); + std::fill_n(data + in_n, n - in_n, last); + } + } + + return buf; +} + +template <> +TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, + int64 n) { + CHECK_GT(n, 0); + Buffer* buf = new Buffer(a, n); + Variant* data = buf->template base(); + if (data == nullptr) { + buf->Unref(); + return nullptr; + } + const int64 in_n = ProtoHelper::NumElements(in); + if (in_n <= 0) { + std::fill_n(data, n, Variant()); + } else { + for (int64 i = 0; i < in_n; ++i) { + data[i] = in.variant_val(i); + if (!DecodeUnaryVariant(&data[i])) { + LOG(ERROR) << "Could not decode variant with type_name: \"" + << data[i].TypeName() + << "\". Perhaps you forgot to register a " + "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?"; + buf->Unref(); + return nullptr; + } + } + for (int64 i = in_n; i < n; ++i) { + data[i] = Variant(); + } + } + return buf; +} + +// fp16 and bfloat16 are opaque to the protobuf, so we deserialize these +// identical to uint16 but with data stored in half_val instead of int_val (ie., +// we don't use ProtoHelper). +template <> +TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, + int64 n) { + CHECK_GT(n, 0); + Buffer* buf = new Buffer(a, n); + uint16* data = buf->template base(); + if (data == nullptr) { + buf->Unref(); + return nullptr; + } + const int64 in_n = in.half_val().size(); + auto begin = in.half_val().begin(); + if (n <= in_n) { + std::copy_n(begin, n, data); + } else if (in_n > 0) { + std::copy_n(begin, in_n, data); + const uint16 last = *(data + in_n - 1); + std::fill_n(data + in_n, n - in_n, last); + } else { + std::fill_n(data, n, 0); + } + return buf; +} + +template <> +TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, + int64 n) { + CHECK_GT(n, 0); + Buffer* buf = new Buffer(a, n); + uint16* data = buf->template base(); + if (data == nullptr) { + buf->Unref(); + return nullptr; + } + const int64 in_n = in.half_val().size(); + auto begin = in.half_val().begin(); + if (n <= in_n) { + std::copy_n(begin, n, data); + } else if (in_n > 0) { + std::copy_n(begin, in_n, data); + const uint16 last = *(data + in_n - 1); + std::fill_n(data + in_n, n - in_n, last); + } else { + std::fill_n(data, n, 0); + } + return buf; +} + +// Copies T[n] stored in the buffer "in" into the repeated field in +// "out" corresponding to type T. +template +void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) { + const T* data = in.base(); + // NOTE: T may not the same as + // ProtoHelper::FieldType::value_type. E.g., T==int16, + // ProtoHelper::FieldType::value_type==int32. If performance is + // critical, we can specialize T=float and do memcpy directly. 
+ ProtoHelper::Fill(data, n, out); +} + +void RefIfNonNull(core::RefCounted* buf) { + if (buf) buf->Ref(); +} + +void UnrefIfNonNull(core::RefCounted* buf) { + if (buf) buf->Unref(); +} + +} // end namespace + +Tensor::Tensor() : Tensor(DT_FLOAT) {} + +Tensor::Tensor(DataType type) : shape_({0}), buf_(nullptr) { set_dtype(type); } + +Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf) + : shape_(shape), buf_(buf) { + set_dtype(type); + RefIfNonNull(buf); +} + +bool Tensor::IsInitialized() const { + return (buf_ != nullptr && buf_->data() != nullptr) || + shape_.num_elements() == 0; +} + +void Tensor::CheckType(DataType expected_dtype) const { + CHECK_EQ(dtype(), expected_dtype); +} + +void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const { + CHECK_EQ(dtype(), expected_dtype); + CHECK(IsAligned()); +} + +void Tensor::CheckIsAlignedAndSingleElement() const { + CHECK(IsAligned()); + CHECK_EQ(1, NumElements()) << "Must have a one element tensor"; +} + +Tensor::~Tensor() { UnrefIfNonNull(buf_); } + +void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) { + CHECK_EQ(shape.num_elements(), other.NumElements()); + // Data type will be overwritten if this == &other, since dtype is part of + // shape. + DataType other_dtype = other.dtype(); + shape_ = shape; + set_dtype(other_dtype); + if (buf_ != other.buf_) { + UnrefIfNonNull(buf_); + buf_ = other.buf_; + RefIfNonNull(buf_); + } +} + +void Tensor::UnsafeCopyFromInternal(const Tensor& other, DataType dtype, + const TensorShape& shape) { + int in_size = DataTypeSize(other.dtype()); + int out_size = DataTypeSize(dtype); + CHECK_NE(in_size, 0); + CHECK_NE(out_size, 0); + CHECK_EQ(shape.num_elements() * out_size, + other.shape().num_elements() * in_size); + shape_ = shape; + shape_.set_data_type(dtype); + if (buf_ != other.buf_) { + UnrefIfNonNull(buf_); + buf_ = other.buf_; + RefIfNonNull(buf_); + } +} + +// Notice that buf_ either points to a regular TensorBuffer or a SubBuffer. +// For the latter case, we have to make sure that the refcount is +// one both for the SubBuffer _and_ the underlying TensorBuffer. +bool Tensor::RefCountIsOne() const { + return buf_ != nullptr && buf_->RefCountIsOne() && + buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory(); +} + +// The macro CASES() expands to a switch statement conditioned on +// TYPE_ENUM. Each case expands the STMTS after a typedef for T. +#define SINGLE_ARG(...) 
__VA_ARGS__ +#define CASE(TYPE, STMTS) \ + case DataTypeToEnum::value: { \ + typedef TYPE T; \ + STMTS; \ + break; \ + } +#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \ + switch (TYPE_ENUM) { \ + CASE(float, SINGLE_ARG(STMTS)) \ + CASE(double, SINGLE_ARG(STMTS)) \ + CASE(int32, SINGLE_ARG(STMTS)) \ + CASE(uint8, SINGLE_ARG(STMTS)) \ + CASE(uint16, SINGLE_ARG(STMTS)) \ + CASE(uint32, SINGLE_ARG(STMTS)) \ + CASE(uint64, SINGLE_ARG(STMTS)) \ + CASE(int16, SINGLE_ARG(STMTS)) \ + CASE(int8, SINGLE_ARG(STMTS)) \ + CASE(string, SINGLE_ARG(STMTS)) \ + CASE(complex64, SINGLE_ARG(STMTS)) \ + CASE(complex128, SINGLE_ARG(STMTS)) \ + CASE(int64, SINGLE_ARG(STMTS)) \ + CASE(bool, SINGLE_ARG(STMTS)) \ + CASE(qint32, SINGLE_ARG(STMTS)) \ + CASE(quint8, SINGLE_ARG(STMTS)) \ + CASE(qint8, SINGLE_ARG(STMTS)) \ + CASE(quint16, SINGLE_ARG(STMTS)) \ + CASE(qint16, SINGLE_ARG(STMTS)) \ + CASE(bfloat16, SINGLE_ARG(STMTS)) \ + CASE(Eigen::half, SINGLE_ARG(STMTS)) \ + CASE(ResourceHandle, SINGLE_ARG(STMTS)) \ + CASE(Variant, SINGLE_ARG(STMTS)) \ + case DT_INVALID: \ + INVALID; \ + break; \ + default: \ + DEFAULT; \ + break; \ + } + +#define CASES(TYPE_ENUM, STMTS) \ + CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \ + , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;) + +Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape) + : shape_(shape), buf_(nullptr) { + set_dtype(type); + CHECK_NOTNULL(a); + if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) { + CASES(type, buf_ = new Buffer(a, shape.num_elements())); + } + if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) { + LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID, + *this); + } +} + +Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape, + const AllocationAttributes& allocation_attr) + : shape_(shape), buf_(nullptr) { + set_dtype(type); + CHECK_NOTNULL(a); + if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) { + CASES(type, buf_ = new Buffer(a, shape.num_elements(), allocation_attr)); + } + if (!allocation_attr.allocation_will_be_logged && buf_ != nullptr && + buf_->data() != nullptr && LogMemory::IsEnabled()) { + LogMemory::RecordTensorAllocation("Unknown (with attributes)", + LogMemory::UNKNOWN_STEP_ID, *this); + } +} + +Tensor::Tensor(DataType type, const TensorShape& shape) + : Tensor(cpu_allocator(), type, shape) {} + +template +class SubBuffer : public TensorBuffer { + public: + // This buffer is an alias to buf[delta, delta + n). + SubBuffer(TensorBuffer* buf, int64 delta, int64 n) + : root_(buf->root_buffer()), data_(buf->base() + delta), elem_(n) { + // Sanity check. The caller should ensure the sub buffer is valid. + CHECK_LE(root_->base(), this->base()); + T* root_limit = root_->base() + root_->size() / sizeof(T); + CHECK_LE(this->base(), root_limit); + CHECK_LE(this->base() + n, root_limit); + // Hold a ref of the underlying root buffer. + // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer. 
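+    //
+    // Aside on the CASES()/CASE() macros defined above: they dispatch a
+    // runtime DataType to a compile-time type T for the expanded statement.
+    // For example, CASES(type, buf_ = new Buffer<T>(a, n)) expands to roughly
+    // the following switch (shown for illustration only; one case per
+    // supported dtype):
+    //
+    //   switch (type) {
+    //     case DT_FLOAT:  { typedef float T;  buf_ = new Buffer<T>(a, n); break; }
+    //     case DT_DOUBLE: { typedef double T; buf_ = new Buffer<T>(a, n); break; }
+    //     // ... one case for each remaining dtype ...
+    //     case DT_INVALID: LOG(FATAL) << "Type not set"; break;
+    //     default: LOG(FATAL) << "Unexpected type: " << type; break;
+    //   }
+    //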
+ root_->Ref(); + } + + void* data() const override { return data_; } + size_t size() const override { return sizeof(T) * elem_; } + TensorBuffer* root_buffer() override { return root_; } + void FillAllocationDescription(AllocationDescription* proto) const override { + root_->FillAllocationDescription(proto); + } + + private: + TensorBuffer* root_; + T* data_; + int64 elem_; + + ~SubBuffer() override { root_->Unref(); } + + TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer); +}; + +Tensor Tensor::Slice(int64 start, int64 limit) const { + CHECK_GE(dims(), 1); + CHECK_LE(0, start); + CHECK_LE(start, limit); + int64 dim0_size = shape_.dim_size(0); + CHECK_LE(limit, dim0_size); + if ((start == 0) && (limit == dim0_size)) { + return *this; + } + Tensor ret; + ret.shape_ = shape_; + ret.set_dtype(dtype()); + ret.buf_ = nullptr; + if (dim0_size > 0) { + const int64 elems_per_dim0 = NumElements() / dim0_size; + const int64 delta = start * elems_per_dim0; + dim0_size = limit - start; + ret.shape_.set_dim(0, dim0_size); + const int64 num_elems = dim0_size * elems_per_dim0; + if (buf_) { + DataType dt = dtype(); + CASES(dt, ret.buf_ = new SubBuffer(buf_, delta, num_elems)); + } + } + return ret; +} + +bool Tensor::FromProto(const TensorProto& proto) { + return FromProto(cpu_allocator(), proto); +} + +bool Tensor::FromProto(Allocator* a, const TensorProto& proto) { + CHECK_NOTNULL(a); + TensorBuffer* p = nullptr; + if (!TensorShape::IsValid(proto.tensor_shape())) return false; + if (proto.dtype() == DT_INVALID) return false; + TensorShape shape(proto.tensor_shape()); + const int64 N = shape.num_elements(); + if (N > 0 && proto.dtype()) { + bool dtype_error = false; + if (!proto.tensor_content().empty()) { + const auto& content = proto.tensor_content(); + CASES_WITH_DEFAULT(proto.dtype(), p = Helper::Decode(a, content, N), + dtype_error = true, dtype_error = true); + } else { + CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField(a, proto, N), + dtype_error = true, dtype_error = true); + } + if (dtype_error || p == nullptr) return false; + } + shape_ = shape; + set_dtype(proto.dtype()); + UnrefIfNonNull(buf_); + buf_ = p; + // TODO(misard) add tracking of which kernels and steps are calling + // FromProto. + if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) { + LogMemory::RecordTensorAllocation("Unknown (from Proto)", + LogMemory::UNKNOWN_STEP_ID, *this); + } + return true; +} + +void Tensor::AsProtoField(TensorProto* proto) const { + proto->Clear(); + shape_.AsProto(proto->mutable_tensor_shape()); + proto->set_dtype(dtype()); + if (buf_) { + CASES(dtype(), ToProtoField(*buf_, shape_.num_elements(), proto)); + } +} + +void Tensor::AsProtoTensorContent(TensorProto* proto) const { + proto->Clear(); + proto->set_dtype(dtype()); + shape_.AsProto(proto->mutable_tensor_shape()); + if (buf_) { + CASES(dtype(), Helper::Encode(buf_, shape_.num_elements(), + proto->mutable_tensor_content())); + } +} + +size_t Tensor::TotalBytes() const { + if (shape_.num_elements() == 0) return 0; + CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements(); + CASES(dtype(), return Helper::TotalBytes(buf_, shape_.num_elements())); + return 0; // Makes compiler happy. 
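+  //
+  // Illustrative usage sketch for Slice() above (variable names are
+  // assumptions, not code from this file): the result aliases the original
+  // buffer through SubBuffer, so no element data is copied, and the caller
+  // should re-check alignment before using accessors that require it:
+  //
+  //   Tensor batch(DT_FLOAT, TensorShape({8, 32}));
+  //   Tensor head = batch.Slice(0, 4);    // aliases rows [0, 4) of `batch`
+  //   if (head.IsAligned()) {
+  //     auto view = head.flat<float>();   // safe only when the slice is aligned
+  //   }
+  //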
+} + +size_t Tensor::AllocatedBytes() const { + TensorDescription tensor_description; + FillDescription(&tensor_description); + if (tensor_description.has_allocation_description() && + tensor_description.allocation_description().allocated_bytes() > 0) { + return tensor_description.allocation_description().allocated_bytes(); + } else { + // Fall back to TotalBytes() if the allocator doesn't have its size. + return TotalBytes(); + } +} + +bool Tensor::CanUseDMA() const { + CASES(dtype(), return is_simple_type::value); + return false; // Makes compiler happy. +} + +#undef CASES +#undef CASE + +namespace { +// Print from left dim to right dim recursively. +template +void PrintOneDim(int dim_index, gtl::InlinedVector shape, int64 limit, + int shape_size, T* data, int64* data_index, string* result) { + if (*data_index >= limit) return; + int64 element_count = shape[dim_index]; + // We have reached the right-most dimension of the tensor. + if (dim_index == shape_size - 1) { + for (int64 i = 0; i < element_count; i++) { + if (*data_index >= limit) return; + if (i > 0) strings::StrAppend(result, " "); + strings::StrAppend(result, data[(*data_index)++]); + } + return; + } + // Loop every element of one dim. + for (int64 i = 0; i < element_count; i++) { + bool flag = false; + if (*data_index < limit) { + strings::StrAppend(result, "["); + flag = true; + } + // As for each element, print the sub-dim. + PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index, + result); + if (*data_index < limit || flag) { + strings::StrAppend(result, "]"); + flag = false; + } + } +} + +template +string SummarizeArray(int64 limit, int64 num_elts, + const TensorShape& tensor_shape, const char* data) { + string ret; + const T* array = reinterpret_cast(data); + + const gtl::InlinedVector shape = tensor_shape.dim_sizes(); + if (shape.empty()) { + for (int64 i = 0; i < limit; ++i) { + if (i > 0) strings::StrAppend(&ret, " "); + strings::StrAppend(&ret, array[i]); + } + if (num_elts > limit) strings::StrAppend(&ret, "..."); + return ret; + } + int64 data_index = 0; + const int shape_size = tensor_shape.dims(); + PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret); + + if (num_elts > limit) strings::StrAppend(&ret, "..."); + return ret; +} +} // namespace + +string Tensor::SummarizeValue(int64 max_entries) const { + const int64 num_elts = NumElements(); + size_t limit = std::min(max_entries, num_elts); + if ((limit > 0) && (buf_ == nullptr)) { + return strings::StrCat("uninitialized Tensor of ", num_elts, + " elements of type ", dtype()); + } + const char* data = limit > 0 ? 
tensor_data().data() : nullptr; + switch (dtype()) { + case DT_HALF: + return SummarizeArray(limit, num_elts, shape_, data); + break; + case DT_FLOAT: + return SummarizeArray(limit, num_elts, shape_, data); + break; + case DT_DOUBLE: + return SummarizeArray(limit, num_elts, shape_, data); + break; + case DT_UINT32: + return SummarizeArray(limit, num_elts, shape_, data); + break; + case DT_INT32: + return SummarizeArray(limit, num_elts, shape_, data); + break; + case DT_UINT8: + case DT_QUINT8: + return SummarizeArray(limit, num_elts, shape_, data); + break; + case DT_UINT16: + case DT_QUINT16: + return SummarizeArray(limit, num_elts, shape_, data); + break; + case DT_INT16: + case DT_QINT16: + return SummarizeArray(limit, num_elts, shape_, data); + break; + case DT_INT8: + case DT_QINT8: + return SummarizeArray(limit, num_elts, shape_, data); + break; + case DT_UINT64: + return SummarizeArray(limit, num_elts, shape_, data); + break; + case DT_INT64: + return SummarizeArray(limit, num_elts, shape_, data); + break; + case DT_BOOL: + // TODO(tucker): Is it better to emit "True False..."? This + // will emit "1 0..." which is more compact. + return SummarizeArray(limit, num_elts, shape_, data); + break; + default: { + // All irregular cases + string ret; + // TODO(irving): Don't call flat every time around this + // loop. + for (size_t i = 0; i < limit; ++i) { + if (i > 0) strings::StrAppend(&ret, " "); + switch (dtype()) { + case DT_STRING: + strings::StrAppend(&ret, str_util::CEscape(flat()(i))); + break; + case DT_VARIANT: { + const Variant& v = flat()(i); + strings::StrAppend(&ret, v.DebugString()); + } break; + default: + // TODO(zhifengc, josh11b): Pretty-print other types (bool, + // complex64, quantized). + strings::StrAppend(&ret, "?"); + } + } + if (max_entries < num_elts) strings::StrAppend(&ret, "..."); + return ret; + } + } +} + +StringPiece Tensor::tensor_data() const { + if (buf_ == nullptr) return StringPiece(); // Don't die for empty tensors + return StringPiece(static_cast(buf_->data()), TotalBytes()); +} + +bool Tensor::SharesBufferWith(const Tensor& b) const { + CHECK_NE(nullptr, buf_); + CHECK_NE(nullptr, b.buf_); + return buf_->root_buffer() == b.buf_->root_buffer(); +} + +string Tensor::DebugString() const { + return strings::StrCat("Tensor"); +} + +void Tensor::FillDescription(TensorDescription* description) const { + description->set_dtype(dtype()); + shape().AsProto(description->mutable_shape()); + if (buf_ != nullptr && buf_->data() != nullptr) { + buf_->FillAllocationDescription( + description->mutable_allocation_description()); + } +} + +gtl::InlinedVector Tensor::ComputeFlatInnerDims( + gtl::ArraySlice orig, int64 num_out_dims) { + gtl::InlinedVector out_dims(num_out_dims, 0); + int64 offset = orig.size() - num_out_dims; + for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) { + const int64 in_dim = out_dim + offset; + out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim]; + } + for (int64 in_dim = 0; in_dim < offset; ++in_dim) { + out_dims[0] *= orig[in_dim]; + } + return out_dims; +} + +gtl::InlinedVector Tensor::ComputeFlatOuterDims( + gtl::ArraySlice orig, int64 num_out_dims) { + gtl::InlinedVector out_dims(num_out_dims, 0); + for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) { + out_dims[out_dim] = out_dim >= orig.size() ? 
1 : orig[out_dim]; + } + for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) { + out_dims[num_out_dims - 1] *= orig[in_dim]; + } + return out_dims; +} + +} // namespace tensorflow diff --git a/tensor.h b/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..92d10f0d8cf452264885917bc0c897e03527a782 --- /dev/null +++ b/tensor.h @@ -0,0 +1,786 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Forward declarations. In particular, we forward declare protos so that their +// symbols can be removed from .so exports. +class AllocationDescription; +class Allocator; +class OpKernelContext; +class TensorBuffer; +class TensorCApi; +class TensorDescription; +class TensorProto; +class VariantTensorData; +namespace batch_util { +Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index); +} // namespace batch_util + +/// @ingroup core +/// Represents an n-dimensional array of values. +class Tensor { + public: + /// \brief Creates a 1-dimensional, 0-element float tensor. + /// + /// The returned Tensor is not a scalar (shape {}), but is instead + /// an empty one-dimensional Tensor (shape {0}, NumElements() == + /// 0). Since it has no elements, it does not need to be assigned a + /// value and is initialized by default (IsInitialized() is + /// true). If this is undesirable, consider creating a one-element + /// scalar which does require initialization: + /// + /// ```c++ + /// + /// Tensor(DT_FLOAT, TensorShape({})) + /// + /// ``` + Tensor(); + + /// \brief Creates a Tensor of the given `type` and `shape`. If + /// LogMemory::IsEnabled() the allocation is logged as coming from + /// an unknown kernel and step. Calling the Tensor constructor + /// directly from within an Op is deprecated: use the + /// OpKernelConstruction/OpKernelContext allocate_* methods to + /// allocate a new tensor, which record the kernel and step. + /// + /// The underlying buffer is allocated using a `CPUAllocator`. + Tensor(DataType type, const TensorShape& shape); + + /// \brief Creates a tensor with the input `type` and `shape`, using + /// the allocator `a` to allocate the underlying buffer. 
If + /// LogMemory::IsEnabled() the allocation is logged as coming from + /// an unknown kernel and step. Calling the Tensor constructor + /// directly from within an Op is deprecated: use the + /// OpKernelConstruction/OpKernelContext allocate_* methods to + /// allocate a new tensor, which record the kernel and step. + /// + /// `a` must outlive the lifetime of this Tensor. + Tensor(Allocator* a, DataType type, const TensorShape& shape); + + /// \brief Creates a tensor with the input `type` and `shape`, using + /// the allocator `a` and the specified "allocation_attr" to + /// allocate the underlying buffer. If the kernel and step are known + /// allocation_attr.allocation_will_be_logged should be set to true + /// and LogMemory::RecordTensorAllocation should be called after the + /// tensor is constructed. Calling the Tensor constructor directly + /// from within an Op is deprecated: use the + /// OpKernelConstruction/OpKernelContext allocate_* methods to + /// allocate a new tensor, which record the kernel and step. + /// + /// `a` must outlive the lifetime of this Tensor. + Tensor(Allocator* a, DataType type, const TensorShape& shape, + const AllocationAttributes& allocation_attr); + + /// \brief Creates an empty Tensor of the given data type. + /// + /// Like Tensor(), returns a 1-dimensional, 0-element Tensor with + /// IsInitialized() returning True. See the Tensor() documentation + /// for details. + explicit Tensor(DataType type); + + /// Copy constructor. + Tensor(const Tensor& other); + + /// \brief Move constructor. After this call, is safely destructible + /// and can be assigned to, but other calls on it (e.g. shape manipulation) + /// are not valid. + Tensor(Tensor&& other); + + ~Tensor(); + + /// Returns the data type. + DataType dtype() const { return shape_.data_type(); } + + /// Returns the shape of the tensor. + const TensorShape& shape() const { return shape_; } + + /// \brief Convenience accessor for the tensor shape. + /// + /// For all shape accessors, see comments for relevant methods of + /// `TensorShape` in `tensor_shape.h`. + int dims() const { return shape().dims(); } + + /// Convenience accessor for the tensor shape. + int64 dim_size(int d) const { return shape().dim_size(d); } + + /// Convenience accessor for the tensor shape. + int64 NumElements() const { return shape().num_elements(); } + + bool IsSameSize(const Tensor& b) const { + return shape().IsSameSize(b.shape()); + } + + // True iff the two tensors use the same underlying refcounted storage + bool SharesBufferWith(const Tensor& b) const; + + /// \brief If necessary, has this Tensor been initialized? + /// + /// Zero-element Tensors are always considered initialized, even if they + /// have never been assigned to and do not have any memory allocated. + bool IsInitialized() const; + + /// Returns the estimated memory usage of this tensor. + size_t TotalBytes() const; + + // Returns the size of sallocated memory for this tensor. + size_t AllocatedBytes() const; + + /// Returns true iff this tensor is aligned. + bool IsAligned() const { +#if EIGEN_MAX_ALIGN_BYTES == 0 + return true; +#else + void* ptr = base(); + return reinterpret_cast(ptr) % EIGEN_MAX_ALIGN_BYTES == 0; +#endif + } + + /// Assign operator. This tensor shares other's underlying storage. + Tensor& operator=(const Tensor& other) { + CopyFromInternal(other, other.shape()); + return *this; + } + + /// Move operator. See move constructor for details. 
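+  /// An illustrative sketch contrasting copy- and move-assignment (variable
+  /// names are examples, not from this file); after the move, `a` may only
+  /// be destroyed or assigned to:
+  ///
+  /// ```c++
+  ///
+  /// Tensor a(DT_FLOAT, TensorShape({2, 2}));
+  /// Tensor b;
+  /// b = a;             // copy-assign: b shares a's underlying buffer
+  /// Tensor c;
+  /// c = std::move(a);  // move-assign: c takes the buffer; a is left empty
+  ///
+  /// ```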
+ Tensor& operator=(Tensor&& other); + + /// \brief Copy the other tensor into this tensor and reshape it. + /// + /// This tensor shares other's underlying storage. Returns `true` + /// iff `other.shape()` has the same number of elements of the given + /// `shape`. + bool CopyFrom(const Tensor& other, + const TensorShape& shape) TF_MUST_USE_RESULT { + if (other.NumElements() != shape.num_elements()) return false; + CopyFromInternal(other, shape); + return true; + } + + /// \brief Slice this tensor along the 1st dimension. + + /// I.e., the returned tensor satisfies + /// returned[i, ...] == this[dim0_start + i, ...]. + /// The returned tensor shares the underlying tensor buffer with this + /// tensor. + /// + /// NOTE: The returned tensor may not satisfy the same alignment + /// requirement as this tensor depending on the shape. The caller + /// must check the returned tensor's alignment before calling certain + /// methods that have alignment requirement (e.g., `flat()`, `tensor()`). + /// + /// REQUIRES: `dims()` >= 1 + /// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)` + Tensor Slice(int64 dim0_start, int64 dim0_limit) const; + + /// \brief Parse `other` and construct the tensor. + + /// Returns `true` iff the parsing succeeds. If the parsing fails, + /// the state of `*this` is unchanged. + bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT; + bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT; + + /// \brief Fills in `proto` with `*this` tensor's content. + /// + /// `AsProtoField()` fills in the repeated field for `proto.dtype()`, while + /// `AsProtoTensorContent()` encodes the content in `proto.tensor_content()` + /// in a compact form. + void AsProtoField(TensorProto* proto) const; + void AsProtoTensorContent(TensorProto* proto) const; + + /// \brief Return the tensor data as an `Eigen::Tensor` with the type and + /// sizes of this `Tensor`. + /// + /// Use these methods when you know the data type and the number of + /// dimensions of the Tensor and you want an `Eigen::Tensor` + /// automatically sized to the `Tensor` sizes. The implementation check + /// fails if either type or sizes mismatch. + /// + /// Example: + /// + /// ```c++ + /// + /// typedef float T; + /// Tensor my_mat(...built with Shape{rows: 3, cols: 5}...); + /// auto mat = my_mat.matrix(); // 2D Eigen::Tensor, 3 x 5. + /// auto mat = my_mat.tensor(); // 2D Eigen::Tensor, 3 x 5. + /// auto vec = my_mat.vec(); // CHECK fails as my_mat is 2D. + /// auto vec = my_mat.tensor(); // CHECK fails as my_mat is 2D. + /// auto mat = my_mat.matrix();// CHECK fails as type mismatch. + /// + /// ``` + template + typename TTypes::Vec vec() { + return tensor(); + } + + template + typename TTypes::Matrix matrix() { + return tensor(); + } + + template + typename TTypes::Tensor tensor(); + + /// \brief Return the tensor data to an `Eigen::Tensor` with the + /// same size but a bitwise cast to the specified dtype `T`. + /// + /// Using a bitcast is useful for move and copy operations. + /// NOTE: this is the same as `tensor()` except a bitcast is allowed. + template + typename TTypes::Tensor bit_casted_tensor(); + + /// \brief Return the tensor data to an `Eigen::Tensor` with the + /// last dimension elements converted into single elements of a larger type. + /// + /// For example, this is useful for kernels that can treat NCHW_VECT_C int8 + /// tensors as NCHW int32 tensors. 
The sizeof(T) should equal the size of + /// the original element type * num elements in the original last dimension. + /// NDIMS should be 1 less than the original number of dimensions. + template + typename TTypes::Tensor reinterpret_last_dimension(); + + /// \brief Return the tensor data as an `Eigen::Tensor` of the data type and a + /// specified shape. + /// + /// These methods allow you to access the data with the dimensions + /// and sizes of your choice. You do not need to know the number of + /// dimensions of the Tensor to call them. However, they `CHECK` that + /// the type matches and the dimensions requested creates an + /// `Eigen::Tensor` with the same number of elements as the tensor. + /// + /// Example: + /// + /// ```c++ + /// + /// typedef float T; + /// Tensor my_ten(...built with Shape{planes: 4, rows: 3, cols: 5}...); + /// // 1D Eigen::Tensor, size 60: + /// auto flat = my_ten.flat(); + /// // 2D Eigen::Tensor 12 x 5: + /// auto inner = my_ten.flat_inner_dims(); + /// // 2D Eigen::Tensor 4 x 15: + /// auto outer = my_ten.shaped({4, 15}); + /// // CHECK fails, bad num elements: + /// auto outer = my_ten.shaped({4, 8}); + /// // 3D Eigen::Tensor 6 x 5 x 2: + /// auto weird = my_ten.shaped({6, 5, 2}); + /// // CHECK fails, type mismatch: + /// auto bad = my_ten.flat(); + /// + /// ``` + template + typename TTypes::Flat flat() { + return shaped({NumElements()}); + } + + template + typename TTypes::UnalignedFlat unaligned_flat() { + return unaligned_shaped({NumElements()}); + } + + /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all + /// Tensor dimensions but the last NDIMS-1 into the first dimension of the + /// result. If NDIMS > dims() then leading dimensions of size 1 will be + /// added to make the output rank NDIMS. + template + typename TTypes::Tensor flat_inner_dims(); + + /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all + /// Tensor dimensions but the first NDIMS-1 into the last dimension of the + /// result. If NDIMS > dims() then trailing dimensions of size 1 will be + /// added to make the output rank NDIMS. + template + typename TTypes::Tensor flat_outer_dims(); + + /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing the + /// first 'begin' Tensor dimensions into the first dimension of the result and + /// the Tensor dimensions of the last dims() - 'begin' - NDIMS into the last + /// dimension of the result. If 'begin' < 0 then the |'begin'| leading + /// dimensions of size 1 will be added. If 'begin' + NDIMS > dims() then + /// 'begin' + NDIMS - dims() trailing dimensions of size 1 will be added. + template + typename TTypes::Tensor flat_inner_outer_dims(int64 begin); + + template + typename TTypes::Tensor shaped(gtl::ArraySlice new_sizes); + + /// \brief Return the tensor data to an `Eigen::Tensor` with the new + /// shape specified in `new_sizes` and cast to a new dtype `T`. + /// + /// Using a bitcast is useful for move and copy operations. + /// The allowed bitcast is the only difference from `shaped()`. + template + typename TTypes::Tensor bit_casted_shaped( + gtl::ArraySlice new_sizes); + + template + typename TTypes::UnalignedTensor unaligned_shaped( + gtl::ArraySlice new_sizes); + + /// \brief Return the Tensor data as a `TensorMap` of fixed size 1: + /// `TensorMap>`. + + /// Using `scalar()` allows the compiler to perform optimizations as + /// the size of the tensor is known at compile time. 
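+  ///
+  /// Illustrative example (the variable names are not from this file):
+  ///
+  /// ```c++
+  ///
+  /// Tensor t(DT_INT64, TensorShape({}));
+  /// t.scalar<int64>()() = 42;       // write the single element
+  /// int64 v = t.scalar<int64>()();  // v == 42
+  ///
+  /// ```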
+ template + typename TTypes::Scalar scalar(); + + /// Const versions of all the methods above. + template + typename TTypes::ConstVec vec() const { + return tensor(); + } + + template + typename TTypes::ConstMatrix matrix() const { + return tensor(); + } + + template + typename TTypes::ConstTensor tensor() const; + + /// \brief Return the tensor data to an `Eigen::Tensor` with the + /// same size but a bitwise cast to the specified dtype `T`. + /// + /// Using a bitcast is useful for move and copy operations. + /// NOTE: this is the same as `tensor()` except a bitcast is allowed. + template + typename TTypes::ConstTensor bit_casted_tensor() const; + + /// \brief Return the tensor data to an `Eigen::Tensor` with the + /// last dimension elements converted into single elements of a larger type. + /// + /// For example, this is useful for kernels that can treat NCHW_VECT_C int8 + /// tensors as NCHW int32 tensors. The sizeof(T) should equal the size of + /// the original element type * num elements in the original last dimension. + /// NDIMS should be 1 less than the original number of dimensions. + template + typename TTypes::ConstTensor reinterpret_last_dimension() const; + + template + typename TTypes::ConstFlat flat() const { + return shaped({NumElements()}); + } + + template + typename TTypes::UnalignedConstFlat unaligned_flat() const { + return unaligned_shaped({NumElements()}); + } + + template + typename TTypes::ConstTensor shaped( + gtl::ArraySlice new_sizes) const; + + /// \brief Return the tensor data to an `Eigen::Tensor` with the new + /// shape specified in `new_sizes` and cast to a new dtype `T`. + /// + /// Using a bitcast is useful for move and copy operations. + /// The allowed bitcast is the only difference from `shaped()`. + template + typename TTypes::ConstTensor bit_casted_shaped( + gtl::ArraySlice new_sizes) const; + + template + typename TTypes::UnalignedConstTensor unaligned_shaped( + gtl::ArraySlice new_sizes) const; + + template + typename TTypes::ConstScalar scalar() const; + + template + typename TTypes::ConstTensor flat_inner_dims() const; + + template + typename TTypes::ConstTensor flat_outer_dims() const; + + template + typename TTypes::ConstTensor flat_inner_outer_dims( + int64 begin) const; + + /// Render the first `max_entries` values in `*this` into a string. + string SummarizeValue(int64 max_entries) const; + + /// A human-readable summary of the tensor suitable for debugging. + string DebugString() const; + + /// Fill in the `TensorDescription` proto with metadata about the + /// tensor that is useful for monitoring and debugging. + void FillDescription(TensorDescription* description) const; + + /// \brief Returns a `StringPiece` mapping the current tensor's buffer. + /// + /// The returned `StringPiece` may point to memory location on devices + /// that the CPU cannot address directly. + /// + /// NOTE: The underlying tensor buffer is refcounted, so the lifetime + /// of the contents mapped by the `StringPiece` matches the lifetime of + /// the buffer; callers should arrange to make sure the buffer does + /// not get destroyed while the `StringPiece` is still used. + /// + /// REQUIRES: `DataTypeCanUseMemcpy(dtype())`. + StringPiece tensor_data() const; + + /// Copy the other tensor into this tensor and reshape it and reinterpret the + /// buffer's datatype. + /// + /// This tensor shares other's underlying storage. 
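+  ///
+  /// Illustrative sketch (variable names are not from this file); the total
+  /// byte size must match, e.g. four int32 values can be viewed as four
+  /// floats without copying:
+  ///
+  /// ```c++
+  ///
+  /// Tensor raw(DT_INT32, TensorShape({4}));
+  /// Tensor view;
+  /// view.UnsafeCopyFromInternal(raw, DT_FLOAT, TensorShape({4}));
+  ///
+  /// ```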
+ void UnsafeCopyFromInternal(const Tensor&, DataType dtype, + const TensorShape&); + + private: + // Returns true if the refcount on buf_ and any possible underlying root + // buffer is one. + bool RefCountIsOne() const; + void CheckType(DataType expected_dtype) const; + void CheckTypeAndIsAligned(DataType expected_dtype) const; + void CheckIsAlignedAndSingleElement() const; + void set_dtype(DataType t) { shape_.set_data_type(t); } + + // TensorShape's InlineVector. + static gtl::InlinedVector ComputeFlatInnerDims( + gtl::ArraySlice orig, int64 num_out_dims); + static gtl::InlinedVector ComputeFlatOuterDims( + gtl::ArraySlice orig, int64 num_out_dims); + + TensorShape shape_; + TensorBuffer* buf_; + + friend class DMAHelper; + friend class TensorCApi; + friend class TensorReference; // For access to buf_ + friend class VariableOp; // For access to set_shape + friend class AutoReloadVariableOp; // For access to set_shape + friend class TensorTestHelper; // For access to set_shape + friend class OpKernelContext; // For access to RefCountIsOne(). + template + friend class AssignVariableOp; // For access to RefCountIsOne(). + template + friend Status PrepareToUpdateVariable( + OpKernelContext* ctx, Tensor* tensor); // For access to RefCountIsOne(). + friend Status batch_util::CopyElementToSlice( + Tensor element, Tensor* parent, + int64 index); // For access to RefCountIsOne(). + friend class NumpyTensorBuffer; // For access to the private constructor + // taking the buffer. + + // Creates a tensor with the input datatype, shape and buf. + // + // Acquires a ref on buf that belongs to this Tensor. + Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf); + + bool CanUseDMA() const; + + // Only needed by variable op to set the shape of an uninitialized + // Tensor. + // TODO: Remove this when we have a better story for detecting + // uninitialized tensors. + void set_shape(const TensorShape& shape) { + DataType dt = dtype(); + shape_ = shape; + set_dtype(dt); + } + + void CopyFromInternal(const Tensor& other, const TensorShape& shape); + + template + T* base() const; + + template + void FillDimsAndValidateCompatibleShape( + gtl::ArraySlice new_sizes, + Eigen::array* dims) const; + + template + void FillDimsAndValidateCompatibleShape( + gtl::ArraySlice new_sizes, + Eigen::array* dims) const; +}; + +// Implementation details + +// START_SKIP_DOXYGEN + +// Interface to access the raw ref-counted data buffer. +class TensorBuffer : public core::RefCounted { + public: + ~TensorBuffer() override {} + + // data() points to a memory region of size() bytes. + virtual void* data() const = 0; + virtual size_t size() const = 0; + + // If this TensorBuffer is sub-buffer of another TensorBuffer, + // returns that TensorBuffer. Otherwise, returns this. + virtual TensorBuffer* root_buffer() = 0; + + // Fill metadata about the allocation into the proto. + virtual void FillAllocationDescription( + AllocationDescription* proto) const = 0; + + template + T* base() const { + return reinterpret_cast(data()); + } + + // Whether this TensorBuffer owns the underlying memory. + virtual bool OwnsMemory() const { return true; } +}; + +template +T* Tensor::base() const { + return buf_ == nullptr ? 
nullptr : buf_->base(); +} + +template +typename TTypes::Tensor Tensor::tensor() { + CheckTypeAndIsAligned(DataTypeToEnum::v()); + return typename TTypes::Tensor(base(), + shape().AsEigenDSizes()); +} + +template +typename TTypes::ConstTensor Tensor::tensor() const { + CheckTypeAndIsAligned(DataTypeToEnum::v()); + return typename TTypes::ConstTensor(base(), + shape().AsEigenDSizes()); +} + +template +typename TTypes::Tensor Tensor::bit_casted_tensor() { + CHECK(IsAligned()); + return typename TTypes::Tensor(base(), + shape().AsEigenDSizes()); +} + +template +typename TTypes::ConstTensor Tensor::bit_casted_tensor() const { + CHECK(IsAligned()); + return typename TTypes::ConstTensor(base(), + shape().AsEigenDSizes()); +} + +template +typename TTypes::Tensor Tensor::reinterpret_last_dimension() { + if (NDIMS == dims()) { + return tensor(); + } + CHECK(IsAligned()); + CHECK_EQ(NDIMS, dims() - 1); + CHECK_EQ(sizeof(T), shape_.dim_sizes()[NDIMS] * DataTypeSize(dtype())); + Eigen::array dims; + for (int d = 0; d < NDIMS; ++d) { + dims[d] = shape_.dim_sizes()[d]; + } + return typename TTypes::Tensor(base(), dims); +} + +template +typename TTypes::ConstTensor Tensor::reinterpret_last_dimension() + const { + if (NDIMS == dims()) { + return tensor(); + } + CHECK(IsAligned()); + CHECK_EQ(NDIMS, dims() - 1); + CHECK_EQ(sizeof(T), shape_.dim_sizes()[NDIMS] * DataTypeSize(dtype())); + Eigen::array dims; + for (int d = 0; d < NDIMS; ++d) { + dims[d] = shape_.dim_sizes()[d]; + } + return typename TTypes::ConstTensor(base(), dims); +} + +template +void Tensor::FillDimsAndValidateCompatibleShape( + gtl::ArraySlice new_sizes, + Eigen::array* dims) const { + CHECK_EQ(NDIMS, new_sizes.size()); + int64 new_num_elements = 1; + for (size_t d = 0; d < NDIMS; d++) { + new_num_elements *= new_sizes[d]; + (*dims)[d] = new_sizes[d]; + } + CHECK_EQ(new_num_elements, NumElements()); +} + +template +void Tensor::FillDimsAndValidateCompatibleShape( + gtl::ArraySlice new_sizes, + Eigen::array* dims) const { + CHECK_EQ(NDIMS, new_sizes.size()); + int64 new_num_elements = 1; + for (size_t d = 0; d < NDIMS; d++) { + new_num_elements *= new_sizes[d]; + (*dims)[d] = new_sizes[d]; + } + const int element_size = DataTypeSize(BaseType(dtype())); + if (element_size > 0) { + CHECK_EQ(new_num_elements * sizeof(T), NumElements() * element_size); + } else { + // DataTypeSize() returns 0 for some data types. In this case, assume that T + // has the same size as the buffer type. + // NOTE: If we can be sure that DataTypeSize() does not return 0 for all POD + // types, then we should check DataTypeToEnum::v() == dtype(). Or simply + // check if `element_size > 0` to err when bit cast is attempted on Tensor + // of unknown data type size. 
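+    //
+    // Worked example (illustrative) for the element_size > 0 branch above:
+    // bit-casting a DT_FLOAT tensor of shape {2, 3} (6 elements * 4 bytes =
+    // 24 bytes) to T == double only validates for a new shape holding 3
+    // elements, e.g. bit_casted_shaped<double, 1>({3}), since
+    // 3 * sizeof(double) is also 24 bytes. When DataTypeSize() returns 0
+    // (e.g. for DT_STRING or DT_VARIANT), byte counts cannot be compared, so
+    // the element counts must match exactly, as checked below.
+    //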
+ CHECK_EQ(new_num_elements, NumElements()); + } +} + +template +typename TTypes::Tensor Tensor::shaped( + gtl::ArraySlice new_sizes) { + CheckTypeAndIsAligned(DataTypeToEnum::v()); + Eigen::array dims; + FillDimsAndValidateCompatibleShape(new_sizes, &dims); + return typename TTypes::Tensor(base(), dims); +} + +template +typename TTypes::Tensor Tensor::bit_casted_shaped( + gtl::ArraySlice new_sizes) { + CHECK(IsAligned()); + Eigen::array dims; + FillDimsAndValidateCompatibleShape(new_sizes, &dims); + return typename TTypes::Tensor(base(), dims); +} + +template +typename TTypes::UnalignedTensor Tensor::unaligned_shaped( + gtl::ArraySlice new_sizes) { + CheckType(DataTypeToEnum::v()); + Eigen::array dims; + FillDimsAndValidateCompatibleShape(new_sizes, &dims); + return typename TTypes::UnalignedTensor(base(), dims); +} + +template +typename TTypes::ConstTensor Tensor::shaped( + gtl::ArraySlice new_sizes) const { + CheckTypeAndIsAligned(DataTypeToEnum::v()); + Eigen::array dims; + FillDimsAndValidateCompatibleShape(new_sizes, &dims); + return typename TTypes::ConstTensor(base(), dims); +} + +template +typename TTypes::ConstTensor Tensor::bit_casted_shaped( + gtl::ArraySlice new_sizes) const { + CHECK(IsAligned()); + Eigen::array dims; + FillDimsAndValidateCompatibleShape(new_sizes, &dims); + return typename TTypes::ConstTensor(base(), dims); +} + +template +typename TTypes::UnalignedConstTensor Tensor::unaligned_shaped( + gtl::ArraySlice new_sizes) const { + CheckType(DataTypeToEnum::v()); + Eigen::array dims; + FillDimsAndValidateCompatibleShape(new_sizes, &dims); + return typename TTypes::UnalignedConstTensor(base(), dims); +} + +template +typename TTypes::Scalar Tensor::scalar() { + CheckIsAlignedAndSingleElement(); + return typename TTypes::Scalar(base()); +} + +template +typename TTypes::ConstScalar Tensor::scalar() const { + CheckIsAlignedAndSingleElement(); + return typename TTypes::ConstScalar(base()); +} + +template +typename TTypes::Tensor Tensor::flat_inner_dims() { + return shaped(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS)); +} + +template +typename TTypes::Tensor Tensor::flat_outer_dims() { + return shaped(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS)); +} + +template +typename TTypes::Tensor Tensor::flat_inner_outer_dims(int64 begin) { + gtl::InlinedVector flat_outer = + ComputeFlatOuterDims(shape_.dim_sizes(), begin + NDIMS); + return shaped(ComputeFlatInnerDims(flat_outer, NDIMS)); +} + +template +typename TTypes::ConstTensor Tensor::flat_inner_dims() const { + return shaped(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS)); +} + +template +typename TTypes::ConstTensor Tensor::flat_outer_dims() const { + return shaped(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS)); +} + +template +typename TTypes::ConstTensor Tensor::flat_inner_outer_dims( + int64 begin) const { + gtl::InlinedVector flat_outer = + ComputeFlatOuterDims(shape_.dim_sizes(), begin + NDIMS); + return shaped(ComputeFlatInnerDims(flat_outer, NDIMS)); +} + +inline Tensor::Tensor(const Tensor& other) + : shape_(other.shape()), buf_(other.buf_) { + if (buf_) buf_->Ref(); +} + +inline Tensor::Tensor(Tensor&& other) + : shape_(std::move(other.shape())), buf_(other.buf_) { + other.buf_ = nullptr; +} + +inline Tensor& Tensor::operator=(Tensor&& other) { + // Avoid self-assignment, since we might destroy our underlying buffer. 
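+  //
+  // Without this guard, a self-move such as `t = std::move(t);` would Unref()
+  // the buffer (possibly freeing it) and then null it out, leaving the tensor
+  // empty and its data lost.
+  //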
+ if (&other != this) { + shape_ = std::move(other.shape_); + if (buf_) buf_->Unref(); + buf_ = other.buf_; + other.buf_ = nullptr; + } + return *this; +} + +// END_SKIP_DOXYGEN + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_ diff --git a/tensor.proto b/tensor.proto new file mode 100644 index 0000000000000000000000000000000000000000..abbf16e8103326011525feb0017922474ff8d2cf --- /dev/null +++ b/tensor.proto @@ -0,0 +1,94 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/resource_handle.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. TODO(touts): sort out the 0-rank issues. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; +}; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. 
+message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/tensor_description.proto b/tensor_description.proto new file mode 100644 index 0000000000000000000000000000000000000000..6ac3c1b881087892893fe22d5e5ea383a6f616db --- /dev/null +++ b/tensor_description.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorDescriptionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/types.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/allocation_description.proto"; + +message TensorDescription { + // Data type of tensor elements + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto shape = 2; + + // Information about the size and allocator used for the data + AllocationDescription allocation_description = 4; +}; diff --git a/tensor_reference.cc b/tensor_reference.cc new file mode 100644 index 0000000000000000000000000000000000000000..8429bd421ee2c8ed8e4a3e590b0b53f75167c0bb --- /dev/null +++ b/tensor_reference.cc @@ -0,0 +1,25 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/tensor_reference.h" + +namespace tensorflow { + +TensorReference::TensorReference(const Tensor& tensor) + : buf_(tensor.buf_ ? tensor.buf_->root_buffer() : nullptr) { + if (buf_) buf_->Ref(); +} + +} // namespace tensorflow diff --git a/tensor_reference.h b/tensor_reference.h new file mode 100644 index 0000000000000000000000000000000000000000..37e588d4f108987f3f03ed503e9c6b66dfd7e5c7 --- /dev/null +++ b/tensor_reference.h @@ -0,0 +1,77 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { + +// An opaque class that holds a reference to an underlying TensorBuffer. 
+// Unlike Tensor, it does not have any shape or type information, so +// it is cheaper to construct/move, but the only thing you can really do +// with it is Unref it, which releases one of the references to the underlying +// TensorBuffer. +// IMPORTANT: If you do not call Unref(), you will likely leak tensor memory. +class TensorReference { + public: + // Take the reference of the root buffer so the size will be more accurate + explicit TensorReference(const Tensor& tensor); + + ~TensorReference() {} + + void Unref() const { + if (buf_) buf_->Unref(); + } + + // Return an estimate of the total bytes being kept alive by this reference. + size_t TotalBytes() const { + // We add 128 as a baseline to account for per-Tensor metadata + return 128 + (buf_ ? buf_->size() : 0); + } + + void FillDescription(AllocationDescription* description) const { + if (buf_) buf_->FillAllocationDescription(description); + } + + // Convenience function for de-duplicating tensor references. + bool SharesBufferWith(const TensorReference& t) const { + return buf_ == t.buf_; + } + + // Convenience function for de-duplicating tensor references. + bool SharesBufferWith(const Tensor& t) const { + return buf_ == (t.buf_ ? t.buf_->root_buffer() : nullptr); + } + + // Convenience function for de-duplicating tensor references. + size_t BufferHash() const { return std::hash()(buf_); } + + // A constructor used only for tests + explicit TensorReference(TensorBuffer* test_buffer) : buf_(test_buffer) { + if (buf_) buf_->Ref(); + } + + private: + TensorBuffer* buf_; +}; + +typedef gtl::InlinedVector TensorReferenceVector; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_ diff --git a/tensor_shape.cc b/tensor_shape.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e0b976e1736dff6b8a18c7b801cb6d1ef500f11 --- /dev/null +++ b/tensor_shape.cc @@ -0,0 +1,690 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/tensor_shape.h" + +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/overflow.h" + +namespace tensorflow { + +// TensorShape and PartialTensorShape should have no fields beyond +// TensorShapeRep. In particular, their sizes should be the same. 
+static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape), + "TensorShape must have no fields beyond TensorShapeRep"); +static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape), + "PartialTensorShape must have no fields beyond TensorShapeRep"); + +template +static void AppendTo(const TensorShapeBase& s, + gtl::InlinedVector* vals) { + for (auto dim : s) { + vals->push_back(dim.size); + } +} + +void TensorShape::CheckDimsEqual(int NDIMS) const { + CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions" + << " from a tensor of " << dims() << " dimensions"; +} + +void TensorShape::CheckDimsAtLeast(int NDIMS) const { + CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS + << " dimensions from a tensor of " << dims() + << " dimensions"; +} + +template +bool TensorShapeBase::IsValid(const TensorShapeProto& proto) { + // NOTE(irving): Unfortunately, TensorShape allows parsing protos with + // unknown_shape() set, and it seems hard to remove this without backwards + // compatibility issues. + if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0; + int64 num_elements = 1; + if (proto.dim().size() > MaxDimensions()) return false; + for (const auto& d : proto.dim()) { + if (d.size() < (kIsPartial ? -1 : 0)) return false; + if (d.size() == -1) { + num_elements = -1; + } else if (!kIsPartial || num_elements >= 0) { + num_elements = MultiplyWithoutOverflow(num_elements, d.size()); + if (num_elements < 0) return false; + } + } + return true; +} + +template +Status TensorShapeBase::IsValidShape(const TensorShapeProto& proto) { + // NOTE(irving): Unfortunately, TensorShape allows parsing protos with + // unknown_shape() set, and it seems hard to remove this without backwards + // compatibility issues. + if (kIsPartial && proto.unknown_rank()) { + if (proto.dim_size() > 0) { + return errors::InvalidArgument( + "An unknown shape must not have any dimensions set."); + } + return Status::OK(); + } + int64 num_elements = 1; + if (proto.dim().size() > MaxDimensions()) { + return errors::InvalidArgument("Shape ", DebugString(proto), + " has too many dimensions"); + } + for (const auto& d : proto.dim()) { + if (d.size() < (kIsPartial ? -1 : 0)) { + if (kIsPartial) { + return errors::InvalidArgument( + "Shape ", DebugString(proto), + " has dimensions with values below -1 (where -1 means unknown)"); + } else { + return errors::InvalidArgument("Shape ", DebugString(proto), + " is not fully defined"); + } + } + if (d.size() == -1) { + num_elements = -1; + } else if (!kIsPartial || num_elements >= 0) { + num_elements = MultiplyWithoutOverflow(num_elements, d.size()); + if (num_elements < 0) { + return errors::InvalidArgument( + "Shape ", DebugString(proto), + " is too large (more than 2**63 - 1 entries)"); + } + } + } + return Status::OK(); +} + +template +TensorShapeBase::TensorShapeBase(const TensorShapeProto& proto) { + set_tag(REP16); + set_data_type(DT_INVALID); + // NOTE(irving): Unfortunately, TensorShape allows parsing protos with + // unknown_shape() set, and it seems hard to remove this without backwards + // compatibility issues. 
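+  //
+  // Aside: a minimal standalone sketch of the overflow-guarded counting that
+  // IsValid() and IsValidShape() above rely on (SafeNumElements is an assumed
+  // name, not a helper in this file). MultiplyWithoutOverflow() returns a
+  // negative value instead of silently wrapping past 2**63 - 1:
+  //
+  //   int64 SafeNumElements(gtl::ArraySlice<int64> dims) {
+  //     int64 n = 1;
+  //     for (int64 d : dims) {
+  //       if (d < 0) return -1;              // unknown dimension
+  //       n = MultiplyWithoutOverflow(n, d);
+  //       if (n < 0) return -1;              // overflow
+  //     }
+  //     return n;
+  //   }
+  //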
+ if (kIsPartial && proto.unknown_rank()) { + set_ndims_byte(kUnknownRank); + set_num_elements(-1); + } else { + set_ndims_byte(0); + set_num_elements(1); + for (const auto& d : proto.dim()) { + AddDim(d.size()); + } + } +} + +template +TensorShapeBase::TensorShapeBase(gtl::ArraySlice dim_sizes) { + set_tag(REP16); + set_data_type(DT_INVALID); + set_ndims_byte(0); + set_num_elements(1); + for (int64 s : dim_sizes) { + AddDim(internal::SubtleMustCopy(s)); + } +} + +template +TensorShapeBase::TensorShapeBase() { + set_tag(REP16); + set_data_type(DT_INVALID); + if (kIsPartial) { + set_ndims_byte(kUnknownRank); + set_num_elements(-1); + } else { + set_ndims_byte(0); + set_num_elements(1); + } +} + +void TensorShapeRep::DestructorOutOfLine() { + DCHECK(tag() == REP_OUT_OF_LINE); + delete as64()->dims_; +} + +void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) { + if (b.tag() != REP_OUT_OF_LINE) { + if (tag() == REP_OUT_OF_LINE) { + delete as64()->dims_; + } + memcpy(buf(), b.buf(), sizeof(u_.buf)); + // memcpy above implicitly also does: + // set_tag(b.tag()); + // set_ndims_byte(b.ndims_byte()); + // set_data_type(b.data_type()); + } else { + DCHECK_EQ(b.tag(), REP_OUT_OF_LINE); + set_ndims_byte(b.ndims_byte()); + set_data_type(b.data_type()); + if (tag() == REP_OUT_OF_LINE) { + // vector already allocated + *(as64()->dims_) = *(b.as64()->dims_); + } else { + set_tag(REP_OUT_OF_LINE); + as64()->dims_ = new gtl::InlinedVector(*(b.as64()->dims_)); + } + } +} + +template +int64 TensorShapeBase::dim_size(int d) const { + if (unknown_rank()) return -1; + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + if (tag() == REP16) { + uint16 dim = as16()->dims_[d]; + if (kIsPartial && dim == kUnknownRep16) return -1; + return dim; + } else if (tag() == REP32) { + uint32 dim = as32()->dims_[d]; + if (kIsPartial && dim == kUnknownRep32) return -1; + return dim; + } else { + return (*as64()->dims_)[d]; + } +} + +void TensorShapeRep::Clear() { + ClearAllButDataType(); + set_data_type(DT_INVALID); +} + +void TensorShapeRep::ClearAllButDataType() { + if (tag() == REP_OUT_OF_LINE) { + delete as64()->dims_; + } + set_tag(REP16); + set_ndims_byte(0); + // Leaves data_type alone + set_num_elements(1); +} + +template +void TensorShapeBase::RecomputeNumElements() { + if (unknown_rank()) { + set_num_elements(-1); + return; + } + int64 n = 1; + for (auto dim : *this) { + if (kIsPartial && dim.size < 0) { + n = -1; + break; + } + n = MultiplyWithoutOverflow(n, dim.size); + CHECK_LE(0, n); + } + set_num_elements(n); +} + +template +void TensorShapeBase::AddDim(int64 size) { + if (!kIsPartial) CHECK_GE(size, 0); + if (unknown_rank()) return; + CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor"; + int64 new_num_elements; + if (kIsPartial && (num_elements() < 0 || size < 0)) { + new_num_elements = -1; + } else { + new_num_elements = MultiplyWithoutOverflow(num_elements(), size); + CHECK_LE(0, new_num_elements); + } + UnsafeAddDim(size, new_num_elements); +} + +template +void TensorShapeBase::UnsafeAddDim(int64 size, int64 new_num_elements) { + const int nd = ndims_byte(); + if (tag() == REP16 && nd < 6 && size < kMaxRep16) { + as16()->dims_[nd] = + kIsPartial && size < 0 ? kUnknownRep16 : static_cast(size); + } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) { + as32()->dims_[nd] = + kIsPartial && size < 0 ? 
kUnknownRep32 : static_cast(size); + } else if (tag() == REP_OUT_OF_LINE) { + as64()->dims_->push_back(size); + } else { + // Need to change representation + gtl::InlinedVector vals; + AppendTo(*this, &vals); + vals.push_back(size); + // We know we can't be REP16. See if we have a small enough + // number of dimensions and each dimension's size is small enough + // to allow REP32. + bool can_be_rep32 = (vals.size() <= 3); + if (can_be_rep32) { + for (size_t i = 0; i < vals.size(); i++) { + if (vals[i] >= kMaxRep32) { + can_be_rep32 = false; + break; + } + } + } + if (can_be_rep32) { + set_tag(REP32); + for (size_t d = 0; d < vals.size(); d++) { + as32()->dims_[d] = kIsPartial && vals[d] < 0 + ? kUnknownRep32 + : static_cast(vals[d]); + } + } else { + set_tag(REP_OUT_OF_LINE); + as64()->dims_ = + new gtl::InlinedVector(vals.begin(), vals.end()); + } + } + set_ndims_byte(nd + 1); + set_num_elements(new_num_elements); +} + +template +void TensorShapeBase::AppendShape(const TensorShapeBase& shape) { + for (auto d : shape) AddDim(d.size); +} + +template +void TensorShapeBase::InsertDim(int d, int64 size) { + CHECK_GE(d, 0); + CHECK_LE(d, dims()); + if (!kIsPartial) CHECK_GE(size, 0); + CHECK_LT(dims(), MaxDimensions()); + gtl::InlinedVector vals; + AppendTo(*this, &vals); + vals.insert(vals.begin() + d, size); + ClearAllButDataType(); + for (auto dval : vals) { + AddDim(dval); + } +} + +template +gtl::InlinedVector TensorShapeBase::dim_sizes() const { + gtl::InlinedVector result; + for (auto dim : *this) { + result.push_back(dim.size); + } + return result; +} + +template +void TensorShapeBase::set_dim(int d, int64 size) { + CHECK_GE(d, 0); + CHECK_LT(d, dims()); + CHECK_GE(size, 0); + if (tag() == REP16 && size < kMaxRep16) { + as16()->dims_[d] = + kIsPartial && size < 0 ? kUnknownRep16 : static_cast(size); + } else if (tag() == REP32 && size < kMaxRep32) { + as32()->dims_[d] = + kIsPartial && size < 0 ? kUnknownRep32 : static_cast(size); + } else if (tag() == REP_OUT_OF_LINE) { + (*as64()->dims_)[d] = size; + } else { + // Must upgrade + gtl::InlinedVector vals; + AppendTo(*this, &vals); + vals[d] = size; + ClearAllButDataType(); + for (auto dval : vals) { + AddDim(dval); + } + } + RecomputeNumElements(); +} + +template +void TensorShapeBase::RemoveDimRange(int begin, int end) { + if (unknown_rank()) return; + begin = begin < 0 ? dims() + begin + 1 : begin; + end = end < 0 ? 
dims() + end + 1 : end; + CHECK_GE(begin, 0); + CHECK_LE(begin, dims()); + CHECK_GE(end, 0); + CHECK_LE(end, dims()); + if (begin >= end) return; + gtl::InlinedVector vals; + AppendTo(*this, &vals); + vals.erase(vals.begin() + begin, vals.begin() + end); + ClearAllButDataType(); + for (auto dval : vals) { + AddDim(dval); + } + RecomputeNumElements(); +} + +bool TensorShape::IsSameSize(const TensorShape& b) const { + if (b.dims() != dims()) return false; + for (int d = 0; d < dims(); d++) { + if (dim_size(d) != b.dim_size(d)) return false; + } + return true; +} + +template +void TensorShapeBase::AsProto(TensorShapeProto* proto) const { + proto->Clear(); + if (unknown_rank()) { + proto->set_unknown_rank(true); + } else { + for (int i = 0; i < dims(); i++) { + proto->add_dim()->set_size(dim_size(i)); + } + } +} + +void TensorShapeRep::DumpRep() const { +#if 0 + fprintf(stderr, "Rep: %d %d dims\n", tag(), dims()); + if (tag() == REP16) { + fprintf(stderr, "REP16 NDIMS: %d\n", ndims_byte()); + for (int i = 0; i < ndims_byte(); i++) { + fprintf(stderr, "dim %d: %d\n", i, as16()->dims_[i]); + } + } else if (tag_ == REP32) { + fprintf(stderr, "REP32 NDIMS: %d\n", ndims_); + for (int i = 0; i < ndims_byte(); i++) { + fprintf(stderr, "dim %d: %d\n", i, as32()->dims_[i]); + } + } else if (tag_ == REP_OUT_OF_LINE) { + fprintf(stderr, "REP_OUT_OF_LINE NDIMS: %d %p\n", ndims_, as16()->dims_); + for (int i = 0; i < ndims_byte(); i++) { + fprintf(stderr, "dim %d: %lld\n", i, (*as64()->dims_)[i]); + } + } +#endif +} + +template +TensorShapeIter TensorShapeBase::begin() const { + return TensorShapeIter(static_cast(this), 0); +} + +template +TensorShapeIter TensorShapeBase::end() const { + CHECK(!unknown_rank()); + return TensorShapeIter(static_cast(this), dims()); +} + +string TensorShapeRep::DebugString() const { + const auto& shape = *static_cast(this); + if (shape.unknown_rank()) return ""; + string s = "["; + for (int i = 0; i < shape.dims(); i++) { + if (i > 0) strings::StrAppend(&s, ","); + int64 dim = shape.dim_size(i); + if (dim < 0) { + strings::StrAppend(&s, "?"); + } else { + strings::StrAppend(&s, dim); + } + } + strings::StrAppend(&s, "]"); + return s; +} + +string TensorShapeRep::DebugString(const TensorShapeProto& proto) { + string s; + if (proto.unknown_rank()) { + strings::StrAppend(&s, ""); + if (proto.dim_size() == 0) return s; + } + strings::StrAppend(&s, "["); + bool first = true; + for (const auto& d : proto.dim()) { + if (!first) strings::StrAppend(&s, ","); + if (d.size() == -1) { + strings::StrAppend(&s, "?"); + } else { + strings::StrAppend(&s, d.size()); + } + first = false; + } + strings::StrAppend(&s, "]"); + return s; +} + +bool TensorShapeUtils::StartsWith(const TensorShape& shape, + const TensorShape& prefix) { + if (shape.dims() < prefix.dims()) return false; + for (int i = 0; i < prefix.dims(); ++i) { + if (shape.dim_size(i) != prefix.dim_size(i)) return false; + } + return true; +} + +bool TensorShapeUtils::EndsWith(const TensorShape& shape, + const TensorShape& suffix) { + const int suffix_size = suffix.dims(); + if (shape.dims() < suffix_size) return false; + for (int i = 0; i < suffix_size; ++i) { + if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) { + return false; + } + } + return true; +} + +template +Status MakeShapeHelper(const T* dims, int64 n, Shape* out) { + out->Clear(); + if (n > TensorShape::MaxDimensions()) { + return errors::InvalidArgument("Too many dimensions"); + } + if (n < 0) { + return errors::InvalidArgument("Negative number 
of dimensions ", n); + } + for (int64 i = 0; i < n; ++i) { + T dim = internal::SubtleMustCopy(dims[i]); + int64 new_num_elements; + if (dim < 0) { + if (!out->kIsPartial) { + return errors::InvalidArgument("Dimension ", dim, " must be >= 0"); + } + if (dim < -1) { + return errors::InvalidArgument("Dimension ", dim, " must be >= -1"); + } + dim = -1; + new_num_elements = -1; + } else if (out->num_elements() < 0) { + new_num_elements = -1; + } else { + new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim); + if (TF_PREDICT_FALSE(new_num_elements < 0)) { + TensorShapeProto proto; + for (int64 j = 0; j < n; ++j) { + proto.add_dim()->set_size(dim); + } + return errors::InvalidArgument( + "Shape ", TensorShape::DebugString(proto), + " would have more than 2**63 - 1 elements"); + } + } + out->UnsafeAddDim(dim, new_num_elements); + } + return Status::OK(); +} + +#define MAKE_SHAPE(T, Shape) \ + Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) { \ + return MakeShapeHelper(dims, n, out); \ + } \ + Status TensorShapeUtils::MakeShape(gtl::ArraySlice shape, Shape* out) { \ + return MakeShapeHelper(shape.data(), shape.size(), out); \ + } +MAKE_SHAPE(int32, TensorShape) +MAKE_SHAPE(int64, TensorShape) +MAKE_SHAPE(int32, PartialTensorShape) +MAKE_SHAPE(int64, PartialTensorShape) +#undef MAKE_SHAPE + +string TensorShapeUtils::ShapeListString( + const gtl::ArraySlice& shapes) { + string result = "["; + bool first = true; + for (const TensorShape& shape : shapes) { + strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString()); + first = false; + } + strings::StrAppend(&result, "]"); + return result; +} + +PartialTensorShape PartialTensorShape::Concatenate(int64 size) const { + PartialTensorShape out = *this; + out.AddDim(size); + return out; +} + +PartialTensorShape PartialTensorShape::Concatenate( + const PartialTensorShape& shape) const { + if (unknown_rank() || shape.unknown_rank()) { + return PartialTensorShape(); + } + PartialTensorShape out = *this; + for (auto dim : shape) out.AddDim(dim.size); + return out; +} + +Status PartialTensorShape::MergeWith(const PartialTensorShape& shape, + PartialTensorShape* result) const { + if (unknown_rank()) { + *result = shape; + return Status::OK(); + } + if (shape.unknown_rank()) { + *result = *this; + return Status::OK(); + } + const int dims_ = dims(); + if (dims_ != shape.dims()) { + return errors::InvalidArgument( + "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ", + shape.dims()); + } + CHECK(result != this); + result->Clear(); + for (int i = 0; i < dims_; ++i) { + const int64 dim0 = dim_size(i); + const int64 dim1 = shape.dim_size(i); + if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) { + return errors::InvalidArgument( + "PartialTensorShape: Incompatible shapes during merge: ", + DebugString(), " vs. ", shape.DebugString()); + } + result->AddDim(dim0 >= 0 ? 
dim0 : dim1); + } + return Status::OK(); +} + +bool PartialTensorShape::AsTensorShape(TensorShape* shape) const { + if (IsFullyDefined()) { + const TensorShapeRep* rep = this; + *shape = *static_cast(rep); + return true; + } + return false; +} + +bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const { + if (unknown_rank() || shape.unknown_rank()) { + return unknown_rank() == shape.unknown_rank(); + } + if (dims() != shape.dims()) return false; + for (int i = 0; i < dims(); i++) { + if (dim_size(i) != shape.dim_size(i)) return false; + } + return true; +} + +bool PartialTensorShape::IsCompatibleWith( + const PartialTensorShape& shape) const { + if (unknown_rank() || shape.unknown_rank()) return true; + if (dims() != shape.dims()) return false; + for (int i = 0; i < dims(); i++) { + const int64 dim0 = dim_size(i); + const int64 dim1 = shape.dim_size(i); + if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false; + } + return true; +} + +string PartialTensorShapeUtils::PartialShapeListString( + const gtl::ArraySlice& shapes) { + string result = "["; + bool first = true; + for (const PartialTensorShape& shape : shapes) { + strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString()); + first = false; + } + strings::StrAppend(&result, "]"); + return result; +} + +bool PartialTensorShapeUtils::AreCompatible( + const gtl::ArraySlice& shapes0, + const gtl::ArraySlice& shapes1) { + if (shapes0.size() == shapes1.size()) { + for (size_t i = 0; i < shapes0.size(); ++i) { + if (!shapes0[i].IsCompatibleWith(shapes1[i])) { + return false; + } + } + return true; + } else { + return false; + } +} + +bool PartialTensorShapeUtils::AreIdentical( + const gtl::ArraySlice& shapes0, + const gtl::ArraySlice& shapes1) { + if (shapes0.size() == shapes1.size()) { + for (size_t i = 0; i < shapes0.size(); ++i) { + if (!shapes0[i].IsIdenticalTo(shapes1[i])) { + return false; + } + } + return true; + } else { + return false; + } +} + +Status TensorShapeUtils::NumElements(gtl::ArraySlice shape, + int64* num_elements) { + int64 n = 1; + for (auto dim : shape) { + n = MultiplyWithoutOverflow(n, dim); + if (n < 0) { + return errors::InvalidArgument("Can't compute total size of shape [", + str_util::Join(shape, ","), + "]; product would overflow int64"); + } + } + *num_elements = n; + return Status::OK(); +} + +template class TensorShapeBase; +template class TensorShapeBase; + +} // namespace tensorflow diff --git a/tensor_shape.h b/tensor_shape.h new file mode 100644 index 0000000000000000000000000000000000000000..adb41b81c6ec019ce51a3871ca329c82f8a1f4b7 --- /dev/null +++ b/tensor_shape.h @@ -0,0 +1,547 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ + +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// START_SKIP_DOXYGEN +template +class TensorShapeIter; +class TensorShape; +class TensorShapeProto; +class PartialTensorShape; +// END_SKIP_DOXYGEN + +/// Internal representation for both TensorShape and PartialTensorShape. +class TensorShapeRep { + public: + ~TensorShapeRep(); + + /// Copy the specified shape + TensorShapeRep(const TensorShapeRep& b); + void operator=(const TensorShapeRep& b); + + /// Move the specified shape. After moving, is safe for destruction and + // can be reassigned into, but its dimensions and number of elements can be + // nonsensical (e.g., negative dimension sizes, or number of elements not + // properly recomputed). + TensorShapeRep(TensorShapeRep&& b); + void operator=(TensorShapeRep&& b); + + /// Clear a tensor shape, producing the scalar shape. + void Clear(); + + // Maximum number of dimensions in a tensor. + // It's 254 because 255 = kUnknownRank is used to represent unknown rank. + static constexpr int MaxDimensions() { return 254; } + + /// \brief Returns the number of elements in the tensor. + /// + /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor` + /// which uses `ptrdiff_t`. For PartialTensorShape, -1 means not fully + /// defined. + int64 num_elements() const { return num_elements_; } + + /// For error messages. + string DebugString() const; + static string DebugString(const TensorShapeProto& proto); + + void DumpRep() const; // XXX + + protected: + // Constructable only via TensorShapeBase + TensorShapeRep() = default; + + void ClearAllButDataType(); + + // We use 16 bytes to represent a TensorShape. Because we need to + // be able to support full 64-bit dimension sizes and an arbitrary + // number of dimensions for a Tensor, but most tensor dimensions are + // significantly smaller than 64 bits and most tensors are 1, 2, or 3 + // dimensions, we have several representations. + // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1 + // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1 + // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using + // an out of line vector. + // For PartialTensorShape, a dimension of static_cast(-1) is unknown. + // This value is not allowed in TensorShape either for format compatibility. + struct Rep16 { + uint16 dims_[6]; + }; + struct Rep32 { + uint32 dims_[3]; + }; + struct Rep64 { + gtl::InlinedVector* dims_; + }; + + // We use the max value of uint16 or uint32 to represent unknown shapes, so + // the maximum representable valid shape in these representations is one less. 
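+  // (Hedged summary of the encoding below: REP16 and REP32 keep known
+  //  dimension sizes inline and reserve the all-ones value to mean
+  //  "unknown"; sizes that do not fit are promoted to REP32 and
+  //  ultimately to the out-of-line Rep64 vector.)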
+ static const int64 kMaxRep16 = std::numeric_limits::max() - 1; + static const int64 kMaxRep32 = std::numeric_limits::max() - 1; + static const uint16 kUnknownRep16 = std::numeric_limits::max(); + static const uint32 kUnknownRep32 = std::numeric_limits::max(); + + Rep16* as16() { return reinterpret_cast(buf()); } + Rep32* as32() { return reinterpret_cast(buf()); } + Rep64* as64() { return reinterpret_cast(buf()); } + + const Rep16* as16() const { return reinterpret_cast(buf()); } + const Rep32* as32() const { return reinterpret_cast(buf()); } + const Rep64* as64() const { return reinterpret_cast(buf()); } + + enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 }; + + // Since we have a convenient extra byte available, we allow the + // Tensor class to store an 8-bit value in this extra storage. This + // allows it to store the Tensor's datatype enum value here and avoid + // an extra word of storage. + friend class Tensor; + friend class TensorShapeTestHelper; + DataType data_type() const { return static_cast(buf()[13]); } + void set_data_type(DataType dt) { + // We only have 8 bits available to store DataType, so make sure it fits + DCHECK_LT(static_cast(dt), 256u); + buf()[13] = static_cast(dt); + } + + // We store the number of dimensions in byte 14, and the RepTag in byte 15. + // Bytes [0..13] vary depending on the representation. + // A value of 255 indicates unknown rank in the PartialTensorShape case. + static const uint8 kUnknownRank = 255; + uint8 ndims_byte() const { return buf()[14]; } + void set_ndims_byte(uint8 nd) { buf()[14] = nd; } + + RepTag tag() const { return static_cast(buf()[15]); } + void set_tag(RepTag tag) { buf()[15] = static_cast(tag); } + + void set_num_elements(int64 n) { num_elements_ = n; } + + private: + void DestructorOutOfLine(); + void SlowCopyFrom(const TensorShapeRep& b); + + uint8* buf() { return &u_.buf[0]; } + const uint8* buf() const { return &u_.buf[0]; } + + union { + uint8 buf[16]; + // Force data to be aligned enough for a pointer. + Rep64* unused_aligner; + } u_; + int64 num_elements_; +}; + +/// Base class for TensorShape and PartialTensorShape. +/// The class is templatized by either TensorShape or PartialTensorShape to +/// allow skipping known/unknown checks in the TensorShape case, but the +/// representation is shared exactly for fast conversion. +template +class TensorShapeBase : public TensorShapeRep { + public: + /// \brief Construct a `TensorShapeBase` from the provided sizes. + /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape) + explicit TensorShapeBase(gtl::ArraySlice dim_sizes); + TensorShapeBase(std::initializer_list dim_sizes) + : TensorShapeBase(gtl::ArraySlice(dim_sizes)) {} + + /// Construct an empty TensorShape, or an unknown rank PartialTensorShape + TensorShapeBase(); + + TensorShapeBase(const TensorShapeProto& proto); + + /// Returns `true` iff `proto` is a valid tensor shape. + // For TensorShape, the proto shape must be fully defined. + static bool IsValid(const TensorShapeProto& proto); + + /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error + /// status otherwise. + static Status IsValidShape(const TensorShapeProto& proto); + + /// \brief Add a dimension to the end ("inner-most"). + /// REQUIRES: `size >= 0` + void AddDim(int64 size); + + /// Appends all the dimensions from `shape`. + void AppendShape(const TensorShapeBase& shape); + + // Maximum number of dimensions in a tensor. 
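+  // (Mirrors TensorShapeRep::MaxDimensions(): 255 is reserved as
+  //  kUnknownRank, so 254 is the largest representable rank.)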
+ static constexpr int MaxDimensions() { return 254; } + + /// \brief Insert a dimension somewhere in the `TensorShape`. + /// REQUIRES: `0 <= d <= dims()` + /// REQUIRES: `size >= 0` + void InsertDim(int d, int64 size); + + /// \brief Modifies the size of the dimension `d` to be `size` + /// REQUIRES: `0 <= d < dims()` + /// REQUIRES: `size >= 0` + void set_dim(int d, int64 size); + + /// \brief Removes dimension `d` from the `TensorShape`. + /// REQUIRES: `0 <= d < dims()` + void RemoveDim(int d) { + CHECK_GE(d, 0); + RemoveDimRange(d, d + 1); + } + + /// \brief Removes last `n` dimensions from the `TensorShape`. + /// REQUIRES: `0 <= n <= dims()` + void RemoveLastDims(int n) { + CHECK_LE(n, dims()); + RemoveDimRange(dims() - n, dims()); + } + + /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`. + /// Negative values of `end` are interpreted as `dims() + end + 1` (as in + /// Python). The same is true for negative values of `begin`. REQUIRES: + /// `-(dims()+1) <= begin <= dims()` REQUIRES: `-(dims()+1) <= end <= dims()` + void RemoveDimRange(int begin, int end); + + /// Return whether the rank is unknown + bool unknown_rank() const { + return kIsPartial && ndims_byte() == kUnknownRank; + } + + /// Return the number of dimensions in the tensor. + /// Can be -1 meaning unknown rank for PartialTensorShape. + int dims() const { + uint8 dims = ndims_byte(); + return kIsPartial && dims == kUnknownRank ? -1 : dims; + } + + /// \brief Returns the number of elements in dimension `d`. + /// REQUIRES: `0 <= d < dims()` + // TODO(touts): Rename to `dimension()` to match + // `Eigen::Tensor::dimension()`? + int64 dim_size(int d) const; + + /// Returns sizes of all dimensions. + // Returns an empty list for unknown rank PartialTensorShape. + gtl::InlinedVector dim_sizes() const; + + /// Return true iff the rank and all of the dimensions are well defined + // TODO(irving): Rename to is_fully_defined now that it's fast. + bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; } + + /// Fill `*proto` from `*this`. + void AsProto(TensorShapeProto* proto) const; + + /// For iterating through the dimensions. + TensorShapeIter begin() const; + TensorShapeIter end() const; + + private: + void RecomputeNumElements(); + + // True for PartialTensorShape, false for TensorShape + static constexpr bool kIsPartial = + std::is_same::value; + static_assert(kIsPartial || std::is_same::value, + "Shape is neither TensorShape nor PartialTensorShape"); + + // Used by AddDim and MakeShapeHelper. Does no error checking. + void UnsafeAddDim(int64 size, int64 new_num_elements); + + // For use by TensorShapeUtils::MakeShape + template + friend Status MakeShapeHelper(const T*, int64, S*); +}; + +/// Represents the shape of a Tensor. +/// +/// A tensor's shape is denoted by its number of dimensions and a size for each +/// dimension. For example, a Tensor represented by a 3 x 4 matrix would have +/// a shape of 2-D, [3,4]. +/// +/// If you know the exact shape of your Tensor when you create the TensorShape +/// object, you can specify it then, or you can create a TensorShape with +/// zero dimensions and one element, and call AddDim() to add dimensions later. +class TensorShape : public TensorShapeBase { + public: + using TensorShapeBase::TensorShapeBase; + + /// Allow a TensorShape to be used as a PartialTensorShape without copying + operator const PartialTensorShape&() const; // NOLINT(runtime/explicit) + + /// Returns true if `*this` and `b` have the same sizes. 
Ignores + /// dimension names. + bool IsSameSize(const TensorShape& b) const; + bool operator==(const TensorShape& b) const { return IsSameSize(b); } + bool operator!=(const TensorShape& b) const { return !IsSameSize(b); } + + /// Fill `*dsizes` from `*this`. + template + Eigen::DSizes AsEigenDSizes() const; + + /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in + /// which case we pad the rest of the sizes with 1. + template + Eigen::DSizes AsEigenDSizesWithPadding() const; + + private: + // These CHECK fail to ease debugging. + // REQUIRES: dims() == NDIMS + void CheckDimsEqual(int NDIMS) const; + // REQUIRES: dims() >= NDIMS + void CheckDimsAtLeast(int NDIMS) const; +}; + +/// Represents the value of one dimension in a TensorShape. +struct TensorShapeDim { + explicit TensorShapeDim(int64 s) : size(s) {} + int64 size; +}; + +// START_SKIP_DOXYGEN +template +class TensorShapeIter { + public: + TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {} + bool operator==(const TensorShapeIter& rhs) { + DCHECK(shape_ == rhs.shape_); + return d_ == rhs.d_; + } + bool operator!=(const TensorShapeIter& rhs) { + DCHECK(shape_ == rhs.shape_); + return d_ != rhs.d_; + } + void operator++() { ++d_; } + TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); } + + private: + const Shape* shape_; + int d_; +}; +// END_SKIP_DOXYGEN + +/// \brief Static helper routines for `TensorShape`. Includes a few common +/// predicates on a tensor shape. +class TensorShapeUtils { + public: + static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; } + + static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; } + + static bool IsVectorOrHigher(const TensorShape& shape) { + return shape.dims() >= 1; + } + + static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; } + + static bool IsSquareMatrix(const TensorShape& shape) { + return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1); + } + + static bool IsMatrixOrHigher(const TensorShape& shape) { + return shape.dims() >= 2; + } + + /// \brief Returns a `TensorShape` whose dimensions are + /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. + static Status MakeShape(const int32* dims, int64 n, TensorShape* out); + static Status MakeShape(const int64* dims, int64 n, TensorShape* out); + static Status MakeShape(gtl::ArraySlice shape, TensorShape* out); + static Status MakeShape(gtl::ArraySlice shape, TensorShape* out); + static Status MakeShape(const int32* dims, int64 n, PartialTensorShape* out); + static Status MakeShape(const int64* dims, int64 n, PartialTensorShape* out); + static Status MakeShape(gtl::ArraySlice shape, + PartialTensorShape* out); + static Status MakeShape(gtl::ArraySlice shape, + PartialTensorShape* out); + + static string ShapeListString(const gtl::ArraySlice& shapes); + + /// \brief Returns true iff `shape` starts with `prefix`. + static bool StartsWith(const TensorShape& shape, const TensorShape& prefix); + + /// \brief Returns true iff `shape` ends with `suffix`. + static bool EndsWith(const TensorShape& shape, const TensorShape& suffix); + + /// \brief Returns the product of values in an int64 array, + /// or a failing Status if the array represents a value larger than + /// a `TensorShape` can hold. + static Status NumElements(gtl::ArraySlice shape, int64* num_elements); +}; + +/// Manages the partially known dimensions of a Tensor and their sizes. 
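+///
+/// Illustrative usage sketch (not part of the original header; it uses only
+/// the API declared below, with -1 marking a dimension of unknown size):
+///
+///   PartialTensorShape batch({-1, 28, 28});
+///   PartialTensorShape known({32, 28, 28});
+///   PartialTensorShape merged;
+///   if (batch.MergeWith(known, &merged).ok()) {
+///     TensorShape full;
+///     merged.AsTensorShape(&full);  // true once every dimension is known
+///   }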
+class PartialTensorShape : public TensorShapeBase { + public: + PartialTensorShape() {} + using TensorShapeBase::TensorShapeBase; + + /// Add a dimension to the end ("inner-most"), returns a new + /// PartialTensorShape. + /// REQUIRES: `size >= -1`, where -1 means unknown + PartialTensorShape Concatenate(int64 size) const; + + /// Appends all the dimensions from `shape`. Returns a new + /// PartialTensorShape. + PartialTensorShape Concatenate(const PartialTensorShape& shape) const; + + /// Merges all the dimensions from `shape`. Returns + /// `InvalidArgument` error if either `shape` has a different rank + /// or if any of the dimensions are incompatible. + Status MergeWith(const PartialTensorShape& shape, + PartialTensorShape* result) const; + + /// Exact equality test. Returns true iff the ranks match (i.e., both are + /// unknown, or both are known and equal), and all dimensions are equal (i.e., + /// both dimensions are known, or both are known and equal). This is a + /// stronger condition that IsCompatibleWith. + bool IsIdenticalTo(const PartialTensorShape& shape) const; + + /// Return true iff the ranks match, and if the + /// dimensions all either match or one is unknown. + bool IsCompatibleWith(const PartialTensorShape& shape) const; + + // Fill `*shape` from `*this`. + // If `*this` is not fully defined, returns false and + // `*shape` is left in an intermediate state. Otherwise + // returns true. + bool AsTensorShape(TensorShape* shape) const; + + /// \brief Returns a `PartialTensorShape` whose dimensions are + /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are + /// considered "unknown". + template + static Status MakePartialShape(const T* dims, int n, + PartialTensorShape* out) { + return TensorShapeUtils::MakeShape(dims, n, out); + } +}; + +/// \brief Static helper routines for `PartialTensorShape`. Includes a few +/// common predicates on a partially known tensor shape. 
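+///
+/// Minimal sketch (illustrative only; the shapes are built with the class
+/// above): two lists are compatible element-wise, with unknown dimensions
+/// matching any size.
+///
+///   std::vector<PartialTensorShape> a = {PartialTensorShape({-1, 3})};
+///   std::vector<PartialTensorShape> b = {PartialTensorShape({2, 3})};
+///   bool ok = PartialTensorShapeUtils::AreCompatible(a, b);  // true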
+class PartialTensorShapeUtils { + public: + static string PartialShapeListString( + const gtl::ArraySlice& shapes); + + static bool AreIdentical(const gtl::ArraySlice& shapes0, + const gtl::ArraySlice& shapes1); + + static bool AreCompatible(const gtl::ArraySlice& shapes0, + const gtl::ArraySlice& shapes1); +}; + +// ---------------------------------------------------------------------------- +// Template method implementation details below +// ---------------------------------------------------------------------------- + +template +Eigen::DSizes TensorShape::AsEigenDSizes() const { + CheckDimsEqual(NDIMS); + return AsEigenDSizesWithPadding(); +} + +template +Eigen::DSizes TensorShape::AsEigenDSizesWithPadding() + const { + CheckDimsAtLeast(NDIMS); + static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions"); + Eigen::DSizes dsizes; + for (int d = 0; d < dims(); d++) { + dsizes[d] = dim_size(d); + } + for (int d = dims(); d < NDIMS; d++) { + dsizes[d] = 1; + } + return dsizes; +} + +// ---------------------------------------------------------------------------- +// Inlining of some performance critical routines +// ---------------------------------------------------------------------------- + +inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) { + num_elements_ = b.num_elements_; + if (b.tag() != REP_OUT_OF_LINE) { + memcpy(buf(), b.buf(), sizeof(u_.buf)); + // memcpy above Implicitly does: + // set_ndims_byte(b.ndims_byte()); + // set_tag(b.tag()); + } else { + set_tag(REP16); // So that SlowCopyFrom does not try to deallocate + SlowCopyFrom(b); + } +} + +inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) { + num_elements_ = b.num_elements_; + memcpy(buf(), b.buf(), sizeof(u_.buf)); + // memcpy above Implicitly does: + // set_ndims_byte(b.ndims_byte()); + // set_tag(b.tag()); + b.set_tag(REP16); // other shape no longer owns out-of-line data, if any. +} + +inline TensorShapeRep::~TensorShapeRep() { + if (tag() == REP_OUT_OF_LINE) { + DestructorOutOfLine(); + } +} + +inline void TensorShapeRep::operator=(const TensorShapeRep& b) { + num_elements_ = b.num_elements_; + if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) { + memcpy(buf(), b.buf(), sizeof(u_.buf)); + // memcpy above implicitly also does: + // set_tag(b.tag()); + // set_ndims_byte(b.ndims_byte()); + } else { + SlowCopyFrom(b); + } +} + +inline void TensorShapeRep::operator=(TensorShapeRep&& b) { + if (tag() == REP_OUT_OF_LINE) { + DestructorOutOfLine(); + } + num_elements_ = b.num_elements_; + memcpy(buf(), b.buf(), sizeof(u_.buf)); + // memcpy above Implicitly does: + // set_ndims_byte(b.ndims_byte()); + // set_tag(b.tag()); + b.set_tag(REP16); // other shape no longer owns out-of-line data, if any. +} + +inline TensorShape::operator const PartialTensorShape&() const { + // Downcast to the shared representation and upcast to PartialTensorShape + const TensorShapeRep* rep = this; + return *static_cast(rep); +} + +// Declare explicit instantiations in .cc file +extern template class TensorShapeBase; +extern template class TensorShapeBase; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ diff --git a/tensor_shape.proto b/tensor_shape.proto new file mode 100644 index 0000000000000000000000000000000000000000..1ec3c5323c2c7306131ac8c278247a841f0e1b64 --- /dev/null +++ b/tensor_shape.proto @@ -0,0 +1,45 @@ +// Protocol buffer representing the shape of tensors. 
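+//
+// For illustration (not part of the original comment), a 30 x 40 matrix is
+// written in text format roughly as:
+//   dim { size: 30 } dim { size: 40 }
+// while a shape of unknown rank is simply:
+//   unknown_rank: true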
+ +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/tensor_shape_test.cc b/tensor_shape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8a9c0bac5b950157044dae07771b6733481ac9e --- /dev/null +++ b/tensor_shape_test.cc @@ -0,0 +1,688 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/tensor_shape.h" + +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +class TensorShapeTestHelper { + public: + static void set_data_type(TensorShape* s, DataType t) { s->set_data_type(t); } + static uint8 data_type(const TensorShape* s) { return s->data_type(); } +}; + +namespace { + +TEST(TensorShapeTest, Default) { + // The default TensorShape constructor constructs a shape of 0-dim + // and 1-element. 
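+  // (That is, the default shape is a scalar: DebugString() prints "[]".)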
+ TensorShape s; + EXPECT_EQ(s.dims(), 0); + EXPECT_EQ(s.num_elements(), 1); +} + +TEST(TensorShapeTest, set_dim) { + TensorShape s({10, 5}); + + s.set_dim(0, 20); + ASSERT_EQ(2, s.dims()); + EXPECT_EQ(20, s.dim_size(0)); + EXPECT_EQ(100, s.num_elements()); + + s.set_dim(1, 2); + ASSERT_EQ(2, s.dims()); + EXPECT_EQ(2, s.dim_size(1)); + EXPECT_EQ(40, s.num_elements()); +} + +TEST(TensorShapeTest, RemoveDim) { + TensorShape s({10, 5}); + s.RemoveDim(0); + EXPECT_EQ(5, s.num_elements()); + ASSERT_EQ(1, s.dims()); +} + +TEST(TensorShapeTest, RemoveAndAddDim) { + TensorShape s({10, 5, 20}); + s.RemoveDim(1); + s.AddDim(100); + + EXPECT_EQ(20000, s.num_elements()); + ASSERT_EQ(3, s.dims()); +} + +TEST(TensorShapeTest, RemoveLastDims) { + TensorShape s({2, 3, 5, 7}); + s.RemoveLastDims(1); + + ASSERT_EQ(3, s.dims()); + EXPECT_EQ(30, s.num_elements()); + + s.RemoveLastDims(2); + ASSERT_EQ(1, s.dims()); + EXPECT_EQ(2, s.dim_size(0)); +} + +TEST(TensorShapeTest, RemoveDimRange) { + TensorShape s0({2, 3, 5, 7}); + // Empty interval => noop. + for (int i = -4; i <= 4; ++i) { + s0.RemoveDimRange(i, i); + ASSERT_EQ(4, s0.dims()); + ASSERT_EQ(210, s0.num_elements()); + } + + // Positive begin and end. + s0.RemoveDimRange(3, 1); // Empty interval. + ASSERT_EQ(4, s0.dims()); + ASSERT_EQ(210, s0.num_elements()); + s0.RemoveDimRange(0, 3); + ASSERT_EQ(1, s0.dims()); + EXPECT_EQ(7, s0.dim_size(0)); + TensorShape s1({2, 3, 5, 7}); + s1.RemoveDimRange(2, 3); + ASSERT_EQ(3, s1.dims()); + ASSERT_EQ(42, s1.num_elements()); + + // Negative begin or end. + TensorShape s2({2, 3, 5, 7}); + s2.RemoveDimRange(-2, -3); // Empty interval. + ASSERT_EQ(4, s2.dims()); + ASSERT_EQ(210, s2.num_elements()); + s2.RemoveDimRange(0, -2); + ASSERT_EQ(1, s2.dims()); + ASSERT_EQ(7, s2.dim_size(0)); + TensorShape s3({2, 3, 5, 7}); + s3.RemoveDimRange(-3, -2); + ASSERT_EQ(3, s3.dims()); + ASSERT_EQ(42, s3.num_elements()); +} + +TEST(TensorShapeTest, InvalidShapeProto) { + TensorShapeProto proto; + EXPECT_TRUE(TensorShape::IsValid(proto)); + + proto.add_dim()->set_size(357); + proto.add_dim()->set_size(982); + EXPECT_TRUE(TensorShape::IsValid(proto)); + + proto.Clear(); + proto.add_dim()->set_size(-357); + proto.add_dim()->set_size(-982); + EXPECT_FALSE(TensorShape::IsValid(proto)); + + proto.Clear(); + proto.add_dim()->set_size(1LL << 35); + proto.add_dim()->set_size((1LL << 35) + 1); + EXPECT_FALSE(TensorShape::IsValid(proto)); +} + +TEST(TensorShapeTest, TooManyDimsProto) { + TensorShapeProto proto; + // Deliberate redundancy to ensure that both paths work. 
+ EXPECT_TRUE(TensorShape::IsValid(proto)); + TF_EXPECT_OK(TensorShape::IsValidShape(proto)); + for (int i = 0; i < TensorShape::MaxDimensions(); i++) { + proto.add_dim()->set_size(1); + } + EXPECT_TRUE(TensorShape::IsValid(proto)); + TF_EXPECT_OK(TensorShape::IsValidShape(proto)); + proto.add_dim()->set_size(1); + EXPECT_FALSE(TensorShape::IsValid(proto)); + EXPECT_FALSE(TensorShape::IsValidShape(proto).ok()); +} + +TEST(TensorShapeTest, SetDimForEmptyTensor) { + TensorShape s({10, 5, 20}); + EXPECT_EQ(1000, s.num_elements()); + s.set_dim(1, 0); + EXPECT_EQ(0, s.num_elements()); + s.set_dim(1, 7); + EXPECT_EQ(1400, s.num_elements()); +} + +TEST(TensorShapeTest, AppendShape64BitIndices) { + TensorShape s({10, 2147483648}); + + EXPECT_EQ(10, s.dim_size(0)); + EXPECT_EQ(2147483648, s.dim_size(1)); + + TensorShape s2; + s2.AppendShape(s); + EXPECT_EQ(10, s2.dim_size(0)); + EXPECT_EQ(2147483648, s2.dim_size(1)); +} + +TEST(TensorShapeTest, DataType) { + TensorShape s({}); + EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_INVALID); + TensorShapeTestHelper::set_data_type(&s, DT_INT32); + s.AddDim(1); + EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_INT32); + s.AddDim(100000); + EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_INT32); + TensorShapeTestHelper::set_data_type(&s, DT_UINT16_REF); + s.AddDim(2); + EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_UINT16_REF); + s.AddDim(4); + EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_UINT16_REF); + s.AddDim(3); + EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_UINT16_REF); + + TensorShape s2 = s; + EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_UINT16_REF); + s2.RemoveDim(2); + EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_UINT16_REF); + TensorShapeTestHelper::set_data_type(&s2, DT_FLOAT); + EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_FLOAT); + s2.Clear(); + EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_INVALID); +} + +// ----------------------------------------------------------------------- +// An old implementation of TensorShape using a different representation, +// preserved here in the unittest to allow us to have a randomized unittest +// that makes sure the behavior of TensorShape and TensorShapeOld are +// the same. +class TensorShapeIterOld; // Declared below + +/// Manages the dimensions of a Tensor and their sizes. +class TensorShapeOld { + public: + /// \brief Construct a `TensorShape` from the provided sizes. + /// REQUIRES: `dim_sizes[i] >= 0` + explicit TensorShapeOld(gtl::ArraySlice dim_sizes); + TensorShapeOld(std::initializer_list dim_sizes) + : TensorShapeOld(gtl::ArraySlice(dim_sizes)) {} + + /// REQUIRES: `IsValid(proto)` + explicit TensorShapeOld(const TensorShapeProto& proto); + + /// Create a tensor shape with no dimensions and one element, which you can + /// then call `AddDim()` on. + TensorShapeOld(); + + /// Returns `true` iff `proto` is a valid tensor shape. + static bool IsValid(const TensorShapeProto& proto); + + /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error + /// status otherwise. + static Status IsValidShape(const TensorShapeProto& proto); + + /// Clear a tensor shape + void Clear(); + + /// \brief Add a dimension to the end ("inner-most"). + /// REQUIRES: `size >= 0` + void AddDim(int64 size); + + /// Appends all the dimensions from `shape`. + void AppendShape(const TensorShapeOld& shape); + + /// \brief Insert a dimension somewhere in the `TensorShape`. 
+ /// REQUIRES: `0 <= d <= dims()` + /// REQUIRES: `size >= 0` + void InsertDim(int d, int64 size); + + /// \brief Modifies the size of the dimension `d` to be `size` + /// REQUIRES: `0 <= d < dims()` + /// REQUIRES: `size >= 0` + void set_dim(int d, int64 size); + + /// \brief Removes dimension `d` from the `TensorShape`. + /// REQUIRES: `0 <= d < dims()` + void RemoveDim(int d); + + /// Return the number of dimensions in the tensor. + int dims() const { return dim_sizes_.size(); } + + /// \brief Returns the number of elements in dimension `d`. + /// REQUIRES: `0 <= d < dims()` + // TODO(touts): Rename to `dimension()` to match + // `Eigen::Tensor::dimension()`? + int64 dim_size(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return dim_sizes_[d]; + } + + /// Returns sizes of all dimensions. + gtl::ArraySlice dim_sizes() const { return dim_sizes_; } + + /// \brief Returns the number of elements in the tensor. + /// + /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor` + /// which uses `ptrdiff_t`. + int64 num_elements() const { return num_elements_; } + + /// Returns true if `*this` and `b` have the same sizes. Ignores + /// dimension names. + bool IsSameSize(const TensorShapeOld& b) const; + bool operator==(const TensorShapeOld& b) const { return IsSameSize(b); } + + /// Fill `*proto` from `*this`. + void AsProto(TensorShapeProto* proto) const; + + /// Fill `*dsizes` from `*this`. + template + Eigen::DSizes AsEigenDSizes() const; + + /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in + /// which case we pad the rest of the sizes with 1. + template + Eigen::DSizes AsEigenDSizesWithPadding() const; + + /// For iterating through the dimensions. + TensorShapeIterOld begin() const; + TensorShapeIterOld end() const; + + /// For error messages. + string DebugString() const; + + /// Same as `TensorShape(proto).DebugString()` but doesn't crash for + /// invalid protos. + static string DebugString(const TensorShapeProto& proto); + + private: + // Recalculates the dimensions of this tensor after they are modified. + void recompute_dims(); + + // TODO(josh11b): Maybe use something from the Eigen Tensor library + // for the sizes. + gtl::InlinedVector dim_sizes_; + + // total number of elements (avoids recomputing it each time). + int64 num_elements_; +}; + +struct TensorShapeDimOld { + explicit TensorShapeDimOld(int64 s) : size(s) {} + int64 size; +}; + +class TensorShapeIterOld { + public: + TensorShapeIterOld(const TensorShapeOld* shape, int d) + : shape_(shape), d_(d) {} + bool operator==(const TensorShapeIterOld& rhs) { + DCHECK(shape_ == rhs.shape_); + return d_ == rhs.d_; + } + bool operator!=(const TensorShapeIterOld& rhs) { + DCHECK(shape_ == rhs.shape_); + return d_ != rhs.d_; + } + void operator++() { ++d_; } + TensorShapeDimOld operator*() { + return TensorShapeDimOld(shape_->dim_size(d_)); + } + + private: + const TensorShapeOld* shape_; + int d_; +}; + +// An upper limit of the total number of elements in a tensor. 
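+// (This cap applies only to the old reference implementation used by the
+//  randomized test below; the current TensorShape no longer enforces it, as
+//  the `Large` test verifies.)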
+static const int64 kMaxElements = (1LL << 40); + +bool TensorShapeOld::IsValid(const TensorShapeProto& proto) { + int64 num_elements = 1; + for (const auto& d : proto.dim()) { + if (d.size() < 0) return false; + num_elements *= d.size(); + if (num_elements > kMaxElements) return false; + } + return true; +} + +Status TensorShapeOld::IsValidShape(const TensorShapeProto& proto) { + int64 num_elements = 1; + for (const auto& d : proto.dim()) { + if (d.size() < 0) { + return errors::InvalidArgument("Shape ", DebugString(proto), + " has negative dimensions; ", + "perhaps an un-fed placeholder?"); + } + num_elements *= d.size(); + if (num_elements > kMaxElements) { + return errors::InvalidArgument("Shape ", DebugString(proto), + " is too large (more than ", kMaxElements, + " entries)"); + } + } + return Status::OK(); +} + +TensorShapeOld::TensorShapeOld(const TensorShapeProto& proto) { + dim_sizes_.reserve(proto.dim_size()); + num_elements_ = 1; + for (const auto& d : proto.dim()) { + AddDim(d.size()); + } +} + +TensorShapeOld::TensorShapeOld(gtl::ArraySlice dim_sizes) { + dim_sizes_.reserve(dim_sizes.size()); + num_elements_ = 1; + for (auto s : dim_sizes) { + AddDim(s); + } +} + +TensorShapeOld::TensorShapeOld() : num_elements_(1) {} + +void TensorShapeOld::Clear() { + dim_sizes_.clear(); + num_elements_ = 1; +} + +void TensorShapeOld::AddDim(int64 size) { + CHECK_GE(size, 0); + dim_sizes_.push_back(size); + num_elements_ *= size; + CHECK_LE(0, num_elements_); + CHECK_LE(num_elements_, kMaxElements); +} + +void TensorShapeOld::AppendShape(const TensorShapeOld& shape) { + for (auto d : shape) AddDim(d.size); +} + +void TensorShapeOld::InsertDim(int d, int64 size) { + CHECK_GE(d, 0); + CHECK_LE(d, dims()); + CHECK_GE(size, 0); + dim_sizes_.insert(dim_sizes_.begin() + d, size); + num_elements_ *= size; + CHECK_LE(0, num_elements_); + CHECK_LE(num_elements_, kMaxElements); +} + +void TensorShapeOld::set_dim(int d, int64 size) { + CHECK_GE(d, 0); + CHECK_LT(d, dims()); + CHECK_GE(size, 0); + + // Update the number of elements. num_elements_ is int64. + dim_sizes_[d] = size; + recompute_dims(); +} + +void TensorShapeOld::RemoveDim(int d) { + CHECK_GE(d, 0); + CHECK_LT(d, dims()); + + // Update the number of elements and remove the dimension from the + // sizes. + dim_sizes_.erase(dim_sizes_.begin() + d); + recompute_dims(); +} + +void TensorShapeOld::recompute_dims() { + num_elements_ = 1; + for (auto s : dim_sizes_) { + num_elements_ *= s; + CHECK_LE(0, num_elements_); + CHECK_LE(num_elements_, kMaxElements); + } +} + +bool TensorShapeOld::IsSameSize(const TensorShapeOld& b) const { + if (b.dims() != dims()) return false; + for (int d = 0; d < dims(); d++) { + if (dim_size(d) != b.dim_size(d)) return false; + } + return true; +} + +void TensorShapeOld::AsProto(TensorShapeProto* proto) const { + proto->Clear(); + for (size_t d = 0; d < dim_sizes_.size(); ++d) { + auto* dim = proto->add_dim(); + dim->set_size(dim_sizes_[d]); + } +} + +TensorShapeIterOld TensorShapeOld::begin() const { + return TensorShapeIterOld(this, 0); +} + +TensorShapeIterOld TensorShapeOld::end() const { + return TensorShapeIterOld(this, dims()); +} + +string TensorShapeOld::DebugString() const { + return strings::StrCat( + "[", str_util::Join(gtl::ArraySlice(dim_sizes_), ","), "]"); +} + +string TensorShapeOld::DebugString(const TensorShapeProto& proto) { + string s = "["; + bool first = true; + for (const auto& d : proto.dim()) { + strings::StrAppend(&s, first ? 
"" : ",", d.size()); + first = false; + } + strings::StrAppend(&s, "]"); + return s; +} +// End of old implementation +// ------------------------------------------------------------------------ + +static int64 SkewedSize(random::SimplePhilox* gen, int64 current_elements) { + int64 result = 0; + do { + if (current_elements < 100) { + result = gen->Uniform(100000); + } else { + result = gen->Uniform(2); + } + } while ((result * current_elements >= 1LL << 34) || + (result * current_elements < 0)); + return result; +} + +TEST(TensorShapeTest, Randomized) { + // We do a randomized test to verify that the behavior of the + // TensorShape implementation (which changes representations depending + // on the values) is identical to our older, more straightforward (but + // more memory hungry) implementation (TensorShapeOld). + random::PhiloxRandom philox(7, 7); + random::SimplePhilox gen(&philox); + TensorShape s; + TensorShapeOld sold; + TensorShapeProto sp; + TensorShapeProto spold; + LOG(INFO) << "Sizes: " << sizeof(TensorShape) << " vs " + << sizeof(TensorShapeOld); + for (int i = 0; i < 100000; i++) { + s.AsProto(&sp); + sold.AsProto(&spold); + EXPECT_EQ(sp.DebugString(), spold.DebugString()); + if ((i % 1000) == 0) { + fprintf(stderr, "ITERATION %d: %s\n", i, sp.DebugString().c_str()); + } + EXPECT_EQ(s.num_elements(), sold.num_elements()); + + // Test moves. + TensorShape copy = s; + TensorShape moved(std::move(copy)); + EXPECT_EQ(s, moved); + copy = s; + moved = std::move(copy); + EXPECT_EQ(s, moved); + + int64 ne = sold.num_elements(); + int r = gen.Uniform(100); + if (r < 10) { + int64 sz = SkewedSize(&gen, sold.num_elements()); + s.AddDim(sz); + sold.AddDim(sz); + } else if (r < 15) { + s.Clear(); + sold.Clear(); + } else if (r < 35 && s.dims() > 0 && ne > 0 && ne < 100000000) { + int dim = gen.Uniform(s.dims()); + s.RemoveDim(dim); + sold.RemoveDim(dim); + } else if (r < 50 && ne > 0 && ne < 100000000) { + int dim = gen.Uniform(s.dims() + 1); + int64 sz = SkewedSize(&gen, sold.num_elements()); + s.InsertDim(dim, sz); + sold.InsertDim(dim, sz); + } else { + std::vector sizes; + const int N = (gen.Uniform(4) == 0) ? gen.Uniform(10) : gen.Uniform(3); + int64 num_elements = 1; + for (int i = 0; i < N; i++) { + int64 sz = SkewedSize(&gen, num_elements); + sizes.push_back(sz); + num_elements *= std::max(1, sz); + } + + s = TensorShape(sizes); + sold = TensorShapeOld(sizes); + } + } +} + +TEST(TensorShapeTest, Large) { + // We used to cap shapes at 2**40 elements. Ensure the + // bound is now higher. 
+ int64 one = 1; + int64 max = std::numeric_limits::max(); + EXPECT_EQ(TensorShape({max}).num_elements(), max); + EXPECT_EQ(TensorShape({1, max}).num_elements(), max); + EXPECT_EQ(TensorShape({max, 1}).num_elements(), max); + EXPECT_EQ(TensorShape({one << 62}).num_elements(), one << 62); + EXPECT_EQ(TensorShape({one << 20, one << 41}).num_elements(), one << 61); + EXPECT_EQ(TensorShape({1000, 1000, 1000, 1000, 1000, 1000}).num_elements(), + 1e18); +} + +TEST(TensorShapeTest, Overflow) { + int64 one = 1; + std::vector> overflows = { + {1 << 30, 1 << 30, 1 << 30}, {1 << 5, (one << 60) + 1}, + }; + for (const auto& overflow : overflows) { + TensorShapeProto proto; + for (auto dim : overflow) { + proto.add_dim()->set_size(dim); + } + EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, + TensorShape::IsValidShape(proto).code()); + TensorShape shape; + EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, + TensorShapeUtils::MakeShape(overflow, &shape).code()); + } +} + +TEST(TensorShapeTest, UnknownRank) { + // NOTE(irving): Unfortunately, for historical reasons we have to allow an + // TensorShapeProto with unknown_rank() set to be parsed as a TensorShape. + // Would be nice to tighten this, but it's tricky given backwards + // compatibility requirements. + TensorShapeProto proto; + proto.set_unknown_rank(true); + EXPECT_TRUE(TensorShape::IsValid(proto)); + TF_EXPECT_OK(TensorShape::IsValidShape(proto)); + EXPECT_EQ(TensorShape(), TensorShape(proto)); + + proto.add_dim()->set_size(7); + EXPECT_TRUE(TensorShape::IsValid(proto)); + TF_EXPECT_OK(TensorShape::IsValidShape(proto)); + EXPECT_EQ(TensorShape({7}), TensorShape(proto)); +} + +TEST(TensorShapeUtilsTest, StartsWith) { + EXPECT_TRUE(TensorShapeUtils::StartsWith(TensorShape({}), TensorShape({}))); + EXPECT_TRUE( + TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({}))); + EXPECT_TRUE( + TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({2}))); + EXPECT_TRUE( + TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({2, 3}))); + EXPECT_TRUE(TensorShapeUtils::StartsWith(TensorShape({2, 3, 4}), + TensorShape({2, 3}))); + EXPECT_FALSE( + TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({3}))); + EXPECT_FALSE( + TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({2, 4}))); + EXPECT_FALSE(TensorShapeUtils::StartsWith(TensorShape({2, 3}), + TensorShape({2, 3, 4}))); + EXPECT_FALSE(TensorShapeUtils::StartsWith(TensorShape({2, 3, 4}), + TensorShape({3, 4}))); +} + +TEST(TensorShapeUtilsTest, EndsWith) { + EXPECT_TRUE(TensorShapeUtils::EndsWith(TensorShape({}), TensorShape({}))); + EXPECT_TRUE(TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({}))); + EXPECT_TRUE( + TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({3}))); + EXPECT_TRUE( + TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2, 3}))); + EXPECT_TRUE( + TensorShapeUtils::EndsWith(TensorShape({2, 3, 4}), TensorShape({3, 4}))); + EXPECT_FALSE( + TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2}))); + EXPECT_FALSE( + TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2, 4}))); + EXPECT_FALSE( + TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2, 3, 4}))); + EXPECT_FALSE( + TensorShapeUtils::EndsWith(TensorShape({2, 3, 4}), TensorShape({2, 3}))); +} + +// A few different test cases for tensor sizes for benchmarks +static std::vector MakeSizes(int arg) { + std::vector sizes; + switch (arg) { + case 0: + sizes = {100}; + break; + case 1: + sizes = {100, 1000}; + break; + case 2: + sizes 
= {100, 1000000}; + break; + case 3: + sizes = {100, 256, 192, 3}; + break; + case 4: + sizes = {1, 2, 1ll << 34, 1, 1, 1}; + break; + } + return sizes; +} + +static void BM_TensorShape_Assign(int iters, int arg) { + TensorShape s(MakeSizes(arg)); + while (--iters > 0) { + TensorShape s2 = s; + } +} +BENCHMARK(BM_TensorShape_Assign)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4); + +} // namespace +} // namespace tensorflow diff --git a/tensor_slice.cc b/tensor_slice.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb3a7f52c2ba5f9622242ff424abeca3457b8ec4 --- /dev/null +++ b/tensor_slice.cc @@ -0,0 +1,273 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/tensor_slice.h" +#include +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); } + +TensorSlice::TensorSlice(const TensorSliceProto& proto) { + starts_.reserve(proto.extent_size()); + lengths_.reserve(proto.extent_size()); + for (const auto& e : proto.extent()) { + starts_.push_back(e.start()); + lengths_.push_back(GetExtentLength(e)); + } +} + +TensorSlice::TensorSlice( + std::initializer_list> extents) { + starts_.reserve(extents.size()); + lengths_.reserve(extents.size()); + for (const auto& e : extents) { + starts_.push_back(e.first); + lengths_.push_back(e.second); + } +} + +Status TensorSlice::Parse(const string& str, TensorSlice* slice) { + std::vector items = str_util::Split(str, ':', str_util::SkipEmpty()); + slice->starts_.reserve(items.size()); + slice->lengths_.reserve(items.size()); + for (const string& x : items) { + int64 s, l; + if (x == "-") { + // "everything" + s = 0; + l = kFullExtent; + } else { + std::vector sl = str_util::Split(x, ',', str_util::SkipEmpty()); + if (sl.size() != 2 || !strings::safe_strto64(sl[0], &s) || + !strings::safe_strto64(sl[1], &l)) { + return errors::InvalidArgument( + "Expected a pair of numbers or '-' " + "but got '", + x, "': string = ", str); + } + if (s < 0 || l <= 0) { + return errors::InvalidArgument( + "Expected non-negative start and " + "positive length but got start = ", + s, ", length = ", l, ": string = ", str); + } + } + slice->starts_.push_back(s); + slice->lengths_.push_back(l); + } + + return Status::OK(); +} + +void TensorSlice::Clear() { + starts_.clear(); + lengths_.clear(); +} + +bool TensorSlice::IsFull() const { + for (int d = 0; d < dims(); ++d) { + if (!IsFullAt(d)) return false; + } + return true; +} + +void TensorSlice::SetFullSlice(int dim) { + Clear(); + starts_.reserve(dim); + lengths_.reserve(dim); + for (int d = 0; d < dim; ++d) { + starts_.push_back(0); + lengths_.push_back(kFullExtent); + } +} + +void 
TensorSlice::Extend(int dim) { + int old_dim = dims(); + DCHECK_LE(old_dim, dim); + starts_.resize(dim); + lengths_.resize(dim); + for (int d = old_dim; d < dim; ++d) { + starts_[d] = 0; + lengths_[d] = kFullExtent; + } +} + +void TensorSlice::AsProto(TensorSliceProto* proto) const { + for (int d = 0; d < dims(); ++d) { + TensorSliceProto::Extent* e = proto->add_extent(); + // We only need to record the explicit slice for non-full slices + if (!IsFullAt(d)) { + e->set_start(starts_[d]); + e->set_length(lengths_[d]); + } + } +} + +string TensorSlice::DebugString() const { + string buffer; + bool first = true; + for (int d = 0; d < dims(); ++d) { + if (!first) { + buffer.append(":"); + } + string s; + if (IsFullAt(d)) { + buffer.append("-"); + } else { + strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]); + } + first = false; + } + return buffer; +} + +bool TensorSlice::Intersect(const TensorSlice& other, + TensorSlice* result) const { + // First, if two slices have different ranks, they obviously don't overlap + // -- in fact they are not compatible. + if (dims() != other.dims()) { + return false; + } + + // Setting the result to the right dimension + if (result) { + result->SetFullSlice(dims()); + } + // The two slices overlap if they overlap in all dimensions. + for (int d = 0; d < dims(); ++d) { + if (IsFullAt(d)) { + if (result) { + result->set_start(d, other.start(d)); + result->set_length(d, other.length(d)); + } + } else if (other.IsFullAt(d)) { + if (result) { + result->set_start(d, start(d)); + result->set_length(d, length(d)); + } + } else { + // If we have an intersection here, it should have a start that is the + // max of the two starts and an end that is the min of the two ends. + int64 s = std::max(start(d), other.start(d)); + int64 l = std::min(end(d), other.end(d)) - s; + if (l > 0) { + // We have a real intersection + if (result) { + result->set_start(d, s); + result->set_length(d, l); + } + } else { + // We don't have an intersection for this dimension -- thus we don't + // have any intersection at all. + if (result) { + result->Clear(); + } + return false; + } + } + } + // If we are here, we know there is overlap in every dimension. 
+ return true; +} + +bool TensorSlice::operator==(const TensorSlice& other) const { + return dims() == other.dims() && starts_ == other.starts_ && + lengths_ == other.lengths_; +} + +void TensorSlice::ComputeRelative(const TensorSlice& sub, + TensorSlice* relative) const { + DCHECK_EQ(dims(), sub.dims()); + relative->SetFullSlice(dims()); + for (int d = 0; d < dims(); ++d) { + if (IsFullAt(d)) { + relative->set_start(d, sub.start(d)); + relative->set_length(d, sub.length(d)); + } else { + // Otherwise the relative start is the difference between the start of + // sub and the start of base + relative->set_start(d, sub.start(d) - start(d)); + relative->set_length(d, sub.length(d)); + } + } +} + +void TensorSlice::UpdateToCover(const TensorSlice& other) { + DCHECK_EQ(dims(), other.dims()); + for (int d = 0; d < dims(); ++d) { + if (!IsFullAt(d)) { + if (other.IsFullAt(d)) { + starts_[d] = 0; + lengths_[d] = kFullExtent; + } else { + const auto new_end = std::max(end(d), other.end(d)); + set_start(d, std::min(start(d), other.start(d))); + set_length(d, new_end - start(d)); + } + } + } +} + +// static +bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) { + return extent.has_length_case() == TensorSliceProto::Extent::kLength; +} + +// static +int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) { + if (!HasExtentLength(extent)) return -1; + return extent.length(); +} + +Status TensorSlice::SliceTensorShape(const TensorShape& shape, + TensorShape* result_shape) const { + result_shape->Clear(); + // Mismatching ranks: we can't apply the slice at all. + if (shape.dims() != dims()) { + return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(), + ", slice = ", DebugString()); + } + for (int d = 0; d < dims(); ++d) { + if (IsFullAt(d)) { + result_shape->AddDim(shape.dim_size(d)); + } else { + // Check if the extent applies to the dimension + if (end(d) <= shape.dim_size(d)) { + // Yes: the end is within the range of the dim -- we adjust the result + // shape so that its size along this dimension is the length of the + // slice. + result_shape->AddDim(length(d)); + } else { + // The extent doesn't apply to the dimension + result_shape->Clear(); + return errors::Internal("Extent in dimension ", d, + " out of bounds: shape = ", shape.DebugString(), + ", slice = ", DebugString()); + } + } + } + // If we are here, we have successfully applied the shape. + return Status::OK(); +} + +const int64 TensorSlice::kFullExtent = -1; + +} // namespace tensorflow diff --git a/tensor_slice.h b/tensor_slice.h new file mode 100644 index 0000000000000000000000000000000000000000..6019737342a1d5033411a1080d849585ec8544bf --- /dev/null +++ b/tensor_slice.h @@ -0,0 +1,224 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_ + +#include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// A tensor slice represents a slice of a given tensor. It is represented by a +// list of (start, length) pairs, where the size of the list is the rank of the +// tensor. + +class TensorSlice { + public: + // Construct a tensor slice: you have a number of ways: + // -- creating an empty slice + // -- from just a dimension (in this case it will create a full slice) + // -- from an array of pairs of integers. + // -- from a TensorSliceProto protocol buffer + // -- from a string format of "start,length:start,length..." where each + // "start,length" pair represents the slice on one dimension. We allow a + // special "-" that means "everything for this dimension". One such example + // is: 0,10:-:14,1:-:- + TensorSlice() {} + explicit TensorSlice(int dim); + explicit TensorSlice(const TensorSliceProto& proto); + explicit TensorSlice(std::initializer_list> extents); + + static Status Parse(const string& str, TensorSlice* output); + static TensorSlice ParseOrDie(const string& str) { + TensorSlice ret; + Status s = Parse(str, &ret); + if (!s.ok()) { + LOG(FATAL) << "Could not parse TensorSlice"; + } + return ret; + } + + void Clear(); + + // Accessors + int dims() const { return starts_.size(); } + + int64 start(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return starts_[d]; + } + + int64 length(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return lengths_[d]; + } + + int64 end(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return start(d) + length(d); + } + + void set_start(int d, int64 x) { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + DCHECK_GE(x, 0); + starts_[d] = x; + } + + void set_length(int d, int64 x) { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + lengths_[d] = x; + } + + // If we have a full slice along dimension "d". + bool IsFullAt(int d) const { + return lengths_[d] == kFullExtent && starts_[d] == 0; + } + + // If this is a full slice, i.e. IsFullAt(d) for every d. + bool IsFull() const; + + // Set the slice to be a full slice of "dim" dimensions + void SetFullSlice(int dim); + + // Extend a slice to "dim" dimensions: all the added dimensions are full. + // Requires: dim >= dims(). + void Extend(int dim); + + // Conversion of a TensorSlice to other formats + void AsProto(TensorSliceProto* proto) const; + string DebugString() const; + + // Fill *indices and *sizes from *this (so that we can use the slice() + // function in eigen tensor). We need a tensor shape in case some of the + // slices are full slices. + // We allow NDIMS to be greater than dims(), in which case we will pad the + // higher dimensions with trivial dimensions. + template + void FillIndicesAndSizes( + const TensorShape& shape, + Eigen::DSizes* indices, + Eigen::DSizes* sizes) const; + + // Interaction with other TensorSlices. + + // Compute the intersection with another slice and if "result" is not + // nullptr, store the results in *result; returns true if there is any real + // intersection. 
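+  // For example (illustrative): intersecting "1,5:-" with "3,7:0,2" returns
+  // true and stores "3,3:0,2" in *result, whereas slices of different ranks
+  // never intersect and the call returns false.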
+ bool Intersect(const TensorSlice& other, TensorSlice* result) const; + // A short hand. + bool Overlaps(const TensorSlice& other) const { + return Intersect(other, nullptr); + } + + // Equals iff "*this" and "other" are logically equivalent. + bool operator==(const TensorSlice& other) const; + bool operator!=(const TensorSlice& other) const { return !(*this == other); } + + // Interaction with TensorShape. + + // Slices a shape and stores the result into *result_shape. + // Requires that the shape and *this have the same rank. + // For example, given a tensor shape of {3, 4, 5}, and a slice of + // 1,2:-:0,2, the result shape is {2, 4, 2}. + Status SliceTensorShape(const TensorShape& shape, + TensorShape* result_shape) const; + + // Given slice "sub" where "sub" is fully contained in *this, + // (meaning that the intersection of "sub" and *this equals "sub"), computes + // the "relative" slice of "sub" with respect to *this. + // + // In other words, if we use A>S to denote slicing a shape S with a slice A, + // then the function is computing a slice X such that: + // X > (this > S) = sub > S + // for any shape S. + // + // In general, along every dimension, the start of the relative slice is the + // start of the "sub" slice minus the start of *this; the length of the + // relative slice is the length of the "sub" slice. + // + // For example, say we have a shape of {3, 4, 5}, "this" is 0,2:-:1,2, and + // "sub" is 1,1:2:2,1,2, then the related slice is 1,1:2,2:0,2. + // + // The caller needs to make sure that "sub" is indeed a sub-slice of *this; + // otherwise the result is undefined. + void ComputeRelative(const TensorSlice& sub, TensorSlice* relative) const; + + // Updates the slice in such a way that it fully covers "other" slice. + // Note, "other" slice should refer to the same tensor shape. + // Example: + // given a slice [2:4, :, 3:] and "other" slice [:, 1:4, 2:4] the + // updated slice would be [:, :, 2:]. Here is why: + // dim 0: "2:4" U ":" -> ":" + // dim 1: ":" U "1-4" -> ":" + // dim 2: "3:" U "2:4" -> "2:" + void UpdateToCover(const TensorSlice& other); + + // Returns true if the length field was specified in an Extent. + static bool HasExtentLength(const TensorSliceProto::Extent& extent); + + // Returns the value of the length field in an Extent, or -1 if it + // is not present. + static int64 GetExtentLength(const TensorSliceProto::Extent& extent); + + private: + // a length value of kFullExtent (-1) means we have a full slice at this + // dimension. It's defined in tensor_slice.cc. + static const int64 kFullExtent; + + // TODO(yangke): switch to Eigen once it supports variable size arrays. 
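+  // (Illustrative) For the slice "0,10:-", starts_ holds {0, 0} and
+  // lengths_ holds {10, kFullExtent}.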
+ // A value of + gtl::InlinedVector starts_; + gtl::InlinedVector lengths_; +}; + +template +void TensorSlice::FillIndicesAndSizes( + const TensorShape& shape, Eigen::DSizes* indices, + Eigen::DSizes* sizes) const { + CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape " + << "slices: shape = " << shape.DebugString() + << ", slice = " << DebugString(); + CHECK_GE(NDIMS, dims()) << "Asking for a " << NDIMS << "-dim slice from " + << "a slice of dimension " << dims(); + for (int d = 0; d < dims(); ++d) { + if (IsFullAt(d)) { + (*indices)[d] = 0; + (*sizes)[d] = shape.dim_size(d); + } else { + (*indices)[d] = starts_[d]; + (*sizes)[d] = lengths_[d]; + } + } + for (int d = dims(); d < NDIMS; ++d) { + (*indices)[d] = 0; + (*sizes)[d] = 1; + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_ diff --git a/tensor_slice.proto b/tensor_slice.proto new file mode 100644 index 0000000000000000000000000000000000000000..24b01661dc4691207da08babbc5e04a51b840396 --- /dev/null +++ b/tensor_slice.proto @@ -0,0 +1,37 @@ +// Protocol buffer representing slices of a tensor + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorSliceProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package tensorflow; + +// Can only be interpreted if you know the corresponding TensorShape. +message TensorSliceProto { + // Extent of the slice in one dimension. + message Extent { + // Either both or no attributes must be set. When no attribute is set + // means: All data in that dimension. + + // Start index of the slice, starting at 0. + int64 start = 1; + + // Length of the slice: if the length is missing or -1 we will + // interpret this as "everything in this dimension". We use + // "oneof" to preserve information about whether the length is + // present without changing the serialization format from the + // prior proto2 version of this proto. + oneof has_length { + int64 length = 2; + } + }; + + // Extent of the slice in all tensor dimensions. + // + // Must have one entry for each of the dimension of the tensor that this + // slice belongs to. The order of sizes is the same as the order of + // dimensions in the TensorShape. + repeated Extent extent = 1; +}; diff --git a/tensor_slice_test.cc b/tensor_slice_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..54e680484e228b2256a0f6e8689525aa2b8300bc --- /dev/null +++ b/tensor_slice_test.cc @@ -0,0 +1,324 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/tensor_slice.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// Basic tests +TEST(TensorSliceTest, Basic) { + { + // Repeatedly setting FullSlice should work. + TensorSlice s(3); + EXPECT_EQ("-:-:-", s.DebugString()); + EXPECT_TRUE(s.IsFull()); + + s.SetFullSlice(4); + EXPECT_EQ("-:-:-:-", s.DebugString()); + EXPECT_TRUE(s.IsFull()); + } +} + +// Testing for serialization and parsing for the string format of slices. +TEST(TensorSliceTest, Serialization) { + // Serialization + { + TensorSlice s({{0, -1}, {0, 10}, {14, 1}, {0, -1}}); + EXPECT_EQ("-:0,10:14,1:-", s.DebugString()); + EXPECT_TRUE(!s.IsFull()); + } + + { + TensorSliceProto proto; + // Define ptxt outside ASSERT_TRUE call to work around bug in some + // versions of gcc that breaks when you use raw string literals + // inside macro expansions. + // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971 + const char* ptxt = R"PROTO( + extent { } + extent { start: 0 length: 10 } + extent { start: 14 length: 1 } + extent { } + )PROTO"; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto)); + TensorSlice s(proto); + EXPECT_EQ("-:0,10:14,1:-", s.DebugString()); + EXPECT_TRUE(!s.IsFull()); + } + + // Parsing + { + TensorSlice s = TensorSlice::ParseOrDie("-:-:1,3:4,5"); + TensorSliceProto proto; + s.AsProto(&proto); + EXPECT_EQ( + "extent { } " + "extent { } " + "extent { start: 1 length: 3 } " + "extent { start: 4 length: 5 }", + proto.ShortDebugString()); + EXPECT_TRUE(!s.IsFull()); + } + + // Failed parsing + { + TensorSlice slice; + Status s = TensorSlice::Parse("-:-:1,3:4:5", &slice); + EXPECT_EQ( + "Invalid argument: " + "Expected a pair of numbers or '-' but got '4': " + "string = -:-:1,3:4:5", + s.ToString()); + } + { + TensorSlice slice; + Status s = TensorSlice::Parse("-:-1,3", &slice); + EXPECT_EQ( + "Invalid argument: " + "Expected non-negative start and positive length but got " + "start = -1, length = 3: string = -:-1,3", + s.ToString()); + } + + // int64 parsing + { + TensorSlice s = + TensorSlice::ParseOrDie("9223372036854775807,9223372036854775807"); + TensorSliceProto proto; + s.AsProto(&proto); + EXPECT_EQ( + "extent { start: 9223372036854775807 length: 9223372036854775807 }", + proto.ShortDebugString()); + EXPECT_TRUE(!s.IsFull()); + } + + // int64 parsing failure + { + TensorSlice slice; + Status s = + TensorSlice::Parse("19223372036854775808,19223372036854775808", &slice); + EXPECT_EQ( + "Invalid argument: Expected a pair of numbers or '-' but got " + "'19223372036854775808,19223372036854775808': string = " + "19223372036854775808,19223372036854775808", + s.ToString()); + } +} + +// Testing the slice intersection +TEST(TensorSliceTest, Intersection) { + // "EVERYTHING" intersects with everything + { + TensorSlice a = TensorSlice::ParseOrDie("-:-"); + TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4"); + TensorSlice c; + EXPECT_TRUE(a.Intersect(b, &c)); + EXPECT_EQ("1,2:3,4", c.DebugString()); + } + + { + TensorSlice a = TensorSlice::ParseOrDie("-:-"); + TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4"); + TensorSlice c; + EXPECT_TRUE(b.Intersect(a, &c)); + EXPECT_EQ("1,2:3,4", c.DebugString()); + } + + // Overlap at all dimensions + { + TensorSlice a = 
TensorSlice::ParseOrDie("1,5:2,6:3,7:5,10"); + TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4:9,10:12,1"); + TensorSlice c; + EXPECT_TRUE(a.Intersect(b, &c)); + EXPECT_EQ("1,2:3,4:9,1:12,1", c.DebugString()); + } + + // A mixture of everything and non-trivial slices + { + TensorSlice a = TensorSlice::ParseOrDie("-:1,1"); + TensorSlice b = TensorSlice::ParseOrDie("-:0,2"); + TensorSlice c; + EXPECT_TRUE(a.Intersect(b, &c)); + EXPECT_EQ("-:1,1", c.DebugString()); + } + + // No overlap on dimension 3: "3,1" and "4,5" don't intersect + { + TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:5,6"); + TensorSlice b = TensorSlice::ParseOrDie("1,3:4,5:1,6"); + TensorSlice c; + EXPECT_FALSE(a.Intersect(b, &c)); + EXPECT_EQ("", c.DebugString()); + } + // No intersection when there are different dimensions + { + TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:-"); + TensorSlice b = TensorSlice::ParseOrDie("-:-"); + TensorSlice c; + EXPECT_FALSE(a.Intersect(b, &c)); + EXPECT_EQ("", c.DebugString()); + } +} + +// Testing applying a slice to a tensor shape +TEST(TensorSliceTest, SliceTensorShape) { + // A proper application + { + TensorSlice a = TensorSlice::ParseOrDie("1,1:-:4,1:2,6"); + TensorShape x({2, 4, 5, 8}); + TensorShape y; + TF_EXPECT_OK(a.SliceTensorShape(x, &y)); + EXPECT_EQ("[1,4,1,6]", y.DebugString()); + } + + // An invalid application -- dimension 2 is out of bounds + { + TensorSlice a = TensorSlice::ParseOrDie("1,1:1,4:-:-"); + TensorShape x({2, 4, 5, 8}); + TensorShape y; + EXPECT_EQ( + "Internal: " + "Extent in dimension 1 out of bounds: " + "shape = [2,4,5,8], slice = 1,1:1,4:-:-", + a.SliceTensorShape(x, &y).ToString()); + EXPECT_EQ("[]", y.DebugString()); + } +} + +// Testing the computation of relative slices. +TEST(TensorSliceTest, ComputeRelative) { + // Easy case: base is "everything" + { + TensorSlice base = TensorSlice::ParseOrDie("-:-:-:-"); + TensorSlice sub = TensorSlice::ParseOrDie("-:1,2:-:3,4"); + TensorSlice relative; + base.ComputeRelative(sub, &relative); + EXPECT_EQ("-:1,2:-:3,4", relative.DebugString()); + } + + // A slightly more complicated case + { + TensorSlice base = TensorSlice::ParseOrDie("1,2:3,4:-:5,1"); + TensorSlice sub = TensorSlice::ParseOrDie("1,1:4,2:3,3:5,1"); + TensorSlice relative; + base.ComputeRelative(sub, &relative); + EXPECT_EQ("0,1:1,2:3,3:0,1", relative.DebugString()); + } +} + +TEST(TensorSliceTest, ExtentLength) { + TensorSliceProto proto; + // Define ptxt outside ASSERT_TRUE call to work around bug in some + // versions of gcc that breaks when you use raw string literals + // inside macro expansions. 
+ // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971 + const char* ptxt = R"PROTO( + extent { } + extent { start: 0 length: 10 } + extent { start: 14 length: 1 } + extent { } + )PROTO"; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto)); + EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(0))); + EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(1))); + EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(2))); + EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(3))); + EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(0))); + EXPECT_EQ(10, TensorSlice::GetExtentLength(proto.extent(1))); + EXPECT_EQ(1, TensorSlice::GetExtentLength(proto.extent(2))); + EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(3))); +} + +TEST(TensorSliceTest, Deserialization) { + // Serialization of + // extent { length: 5 } + // extent { start: 0 length: 10 } + // extent { start: 14 length: 1 } + // extent { start: 1 } + // extent { } + // in proto2 and proto3: + const char pb2[] = + "\x0A\x02\x10\x05\x0A\x04\x08\x00" + "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00"; + const char pb3[] = + "\x0A\x02\x10\x05\x0A\x02" + "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00"; + // (The difference is that in the proto3 version, "start: 0" isn't included + // since 0 is start's default value.) + + TensorSliceProto proto2; + ASSERT_TRUE(proto2.ParseFromArray(pb2, sizeof(pb2) - 1)); + TensorSlice ts2(proto2); + + TensorSliceProto proto3; + ASSERT_TRUE(proto3.ParseFromArray(pb3, sizeof(pb3) - 1)); + TensorSlice ts3(proto3); + + // Both serializations should be interpreted the same. + EXPECT_EQ("0,5:0,10:14,1:1,-1:-", ts2.DebugString()); + EXPECT_EQ("0,5:0,10:14,1:1,-1:-", ts3.DebugString()); +} + +TEST(TensorSliceTest, UpdateToCover) { + // [2:4, :, 3:] + TensorSlice s({{2, 2}, {0, -1}, {3, 7}}); + // [:, 1:4, 2:4] + TensorSlice other({{0, -1}, {1, 3}, {2, 2}}); + + s.UpdateToCover(other); + // [:, :, 2:] + EXPECT_EQ("-:-:2,8", s.DebugString()); +} + +TEST(TensorSliceTest, IsFull) { + TensorSlice slice(3); + EXPECT_TRUE(slice.IsFull()); + + TensorSlice slice2({{0, -1}}); + EXPECT_TRUE(slice2.IsFull()); + + TensorSlice slice3({{0, -1}, {0, -1}, {14, 1}}); + EXPECT_TRUE(!slice3.IsFull()); +} + +TEST(TensorSliceTest, Equality) { + { // Dims are different. + TensorSlice slice1(3); + TensorSlice slice2(2); + EXPECT_TRUE(slice1 != slice2); + EXPECT_TRUE(slice2 != slice1); + } + { // Both are 3-dim full slices. + TensorSlice slice1(3); + TensorSlice slice2({{0, -1}, {0, -1}, {0, -1}}); + EXPECT_TRUE(slice1 == slice2); + EXPECT_TRUE(slice2 == slice1); + } + { // Differs in one dimension. + TensorSlice slice1(3); + TensorSlice slice2({{0, -1}, {0, 1}, {0, -1}}); + EXPECT_TRUE(slice1 != slice2); + EXPECT_TRUE(slice2 != slice1); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensor_test.cc b/tensor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..14828804285a8115dd49f596b4aea38f7f6af1ff --- /dev/null +++ b/tensor_test.cc @@ -0,0 +1,1337 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/tensor.h" + +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +class TensorTestHelper { + public: + // This is an operation that can be done by VariableOp. + static void set_shape(Tensor* t, const TensorShape& s) { t->set_shape(s); } +}; + +// To make TestCopies do the right thing. +inline bool operator==(const ResourceHandle& a, const ResourceHandle& b) { + return a.device() == b.device() && a.container() == b.container() && + a.name() == b.name() && a.hash_code() == b.hash_code() && + a.maybe_type_name() == b.maybe_type_name(); +} + +inline bool operator==(const Variant& a, const Variant& b) { + if (a.is_empty()) { + return b.is_empty(); + } + + if (a.TypeId() != b.TypeId()) return false; + if (a.TypeName() != b.TypeName()) return false; + + VariantTensorData a_data, b_data; + a.Encode(&a_data); + b.Encode(&b_data); + + string a_metadata; + string b_metadata; + a_data.get_metadata(&a_metadata); + b_data.get_metadata(&b_metadata); + if (a_metadata != b_metadata) return false; + + if (a_data.tensors_size() != b_data.tensors_size()) return false; + + for (int i = 0; i < a_data.tensors_size(); ++i) { + TensorProto a_proto, b_proto; + a_data.tensors(i).AsProtoTensorContent(&a_proto); + b_data.tensors(i).AsProtoTensorContent(&b_proto); + string a_str, b_str; + a_proto.SerializeToString(&a_str); + b_proto.SerializeToString(&b_str); + if (a_str != b_str) return false; + } + + return true; +} + +TEST(TensorTest, Default) { + Tensor t; + EXPECT_EQ(t.dtype(), DT_FLOAT); + EXPECT_EQ(t.dims(), 1); + EXPECT_EQ(t.NumElements(), 0); +} + +TEST(TensorTest, DataType_Traits) { + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_FALSE(std::is_trivial::value); + + EXPECT_EQ(sizeof(bool), 1); + + // Unfortunately. std::complex::complex() initializes (0, 0). 
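+  // (Illustrative) std::complex has a user-provided default constructor that
+  // zero-initializes both components, so std::is_trivial reports false for
+  // complex64 and complex128 below.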
+ EXPECT_FALSE(std::is_trivial::value); + EXPECT_FALSE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + EXPECT_TRUE(std::is_trivial::value); + struct MyComplex64 { + float re, im; + }; + EXPECT_TRUE(std::is_trivial::value); + struct MyComplex128 { + double re, im; + }; + EXPECT_TRUE(std::is_trivial::value); +} + +template +void TestCopies(const Tensor& t) { + { + LOG(INFO) << "CopyFrom()"; + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.CopyFrom(t, t.shape())); + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "operator=()"; + Tensor t2(t.dtype()); + t2 = t; + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "deep copy"; + Tensor t2(t.dtype(), t.shape()); + t2.flat() = t.flat(); + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "AsProtoField()"; + TensorProto proto; + t.AsProtoField(&proto); + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.FromProto(proto)); + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "AsProtoTensorContent()"; + TensorProto proto; + t.AsProtoTensorContent(&proto); + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.FromProto(proto)); + test::ExpectTensorEqual(t, t2); + // Make another copy via tensor_content field. + *proto.mutable_tensor_content() = proto.tensor_content(); + Tensor t3(t.dtype()); + EXPECT_TRUE(t3.FromProto(proto)); + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "AsTensor"; + gtl::ArraySlice values(t.flat().data(), t.NumElements()); + Tensor t2 = test::AsTensor(values, t.shape()); + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "Move constructor"; + Tensor t2 = t; + Tensor t3(std::move(t2)); + test::ExpectTensorEqual(t, t3); + EXPECT_TRUE(t3.IsInitialized()); + EXPECT_FALSE(t2.IsInitialized()); + } + { + LOG(INFO) << "Move assignment"; + Tensor t2 = t; + Tensor t3 = std::move(t2); + Tensor* t4 = &t3; + *t4 = std::move(t3); + test::ExpectTensorEqual(t, t3); + EXPECT_TRUE(t3.IsInitialized()); + EXPECT_FALSE(t2.IsInitialized()); + } +} + +TEST(Tensor_Half, Simple) { + Tensor t(DT_HALF, TensorShape({5, 7})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({5, 7}))); + for (int64 a = 0; a < t.shape().dim_size(0); a++) { + for (int64 b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = static_cast(a * b); + } + } + TestCopies(t); +} + +TEST(Tensor_Bfloat16, Simple) { + Tensor t(DT_BFLOAT16, TensorShape({5, 7})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({5, 7}))); + for (int64 a = 0; a < t.shape().dim_size(0); a++) { + for (int64 b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = static_cast(a * b); + } + } + TestCopies(t); +} + +TEST(Tensor_Float, Simple) { + Tensor t(DT_FLOAT, TensorShape({10, 20})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({10, 20}))); + for (int64 a = 0; a < t.shape().dim_size(0); a++) { + for (int64 b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = static_cast(a * b); + } + } + TestCopies(t); +} + +TEST(Tensor_ResourceHandle, Simple) { + Tensor t(DT_RESOURCE, TensorShape({})); + ResourceHandle tmp; + tmp.set_name("a"); + t.flat()(0) = tmp; + TestCopies(t); +} + +TEST(Tensor_Variant, Simple) { + Tensor t(DT_VARIANT, TensorShape({})); + Tensor value(DT_FLOAT, TensorShape({})); + value.flat()(0) = 42.0f; + t.flat()(0) = value; + // All the tests in TestCopies except the ones that serialize and deserialize + // the tensor. The consumer of a serialized Variant Tensor should know what + // type is stored in the Tensor, so not testing the generic + // serialize/deserialize case here. 
+ { + LOG(INFO) << "CopyFrom()"; + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.CopyFrom(t, t.shape())); + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "operator=()"; + Tensor t2(t.dtype()); + t2 = t; + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "deep copy"; + Tensor t2(t.dtype(), t.shape()); + t2.flat() = t.flat(); + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "AsTensor"; + gtl::ArraySlice values(t.flat().data(), t.NumElements()); + Tensor t2 = test::AsTensor(values, t.shape()); + test::ExpectTensorEqual(t, t2); + } + { + LOG(INFO) << "Move constructor"; + Tensor t2 = t; + Tensor t3(std::move(t2)); + test::ExpectTensorEqual(t, t3); + EXPECT_TRUE(t3.IsInitialized()); + EXPECT_FALSE(t2.IsInitialized()); + } + { + LOG(INFO) << "Move assignment"; + Tensor t2 = t; + Tensor t3 = std::move(t2); + Tensor* t4 = &t3; + *t4 = std::move(t3); + test::ExpectTensorEqual(t, t3); + EXPECT_TRUE(t3.IsInitialized()); + EXPECT_FALSE(t2.IsInitialized()); + } +} + +TEST(Tensor_Variant, Marshal) { + Tensor t(DT_VARIANT, TensorShape({})); + + Tensor internal(DT_FLOAT, TensorShape({})); + internal.flat()(0) = 42.0f; + t.flat()(0) = internal; + + LOG(INFO) << "AsProtoField()"; + TensorProto proto; + t.AsProtoField(&proto); + + // This performs a decode operation. + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.FromProto(proto)); + + Tensor* out = t2.flat()(0).get(); + EXPECT_NE(out, nullptr); + EXPECT_FLOAT_EQ(out->scalar()(), 42.0f); +} + +TEST(Tensor_UInt16, Simple) { + Tensor t(DT_UINT16, TensorShape({2, 2})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2}))); + for (int64 a = 0; a < t.shape().dim_size(0); a++) { + for (int64 b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = uint16(a * b); + } + } + TestCopies(t); +} + +TEST(Tensor_QInt8, Simple) { + Tensor t(DT_QINT8, TensorShape({2, 2})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2}))); + for (int64 a = 0; a < t.shape().dim_size(0); a++) { + for (int64 b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = qint8(a * b); + } + } + TestCopies(t); +} + +TEST(Tensor_QUInt8, Simple) { + Tensor t(DT_QUINT8, TensorShape({2, 2})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2}))); + for (int64 a = 0; a < t.shape().dim_size(0); a++) { + for (int64 b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = Eigen::QUInt8(a * b); + } + } + TestCopies(t); +} + +TEST(Tensor_QInt32, Simple) { + Tensor t(DT_QINT32, TensorShape({2, 2})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2}))); + for (int64 a = 0; a < t.shape().dim_size(0); a++) { + for (int64 b = 0; b < t.shape().dim_size(1); b++) { + t.matrix()(a, b) = qint32(static_cast(a * b)); + } + } + TestCopies(t); +} + +class TensorReshapeTest : public ::testing::Test { + protected: + Tensor t; + Tensor zero_t; + + TensorReshapeTest() + : t(DT_FLOAT, TensorShape({2, 3, 4, 5})), + zero_t(DT_FLOAT, TensorShape({3, 0, 2, 0, 5})) {} + + void SetUp() override { + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 3, 4, 5}))); + EXPECT_TRUE(zero_t.shape().IsSameSize(TensorShape({3, 0, 2, 0, 5}))); + + auto tensor = t.tensor(); + EXPECT_EQ(2, tensor.dimension(0)); + EXPECT_EQ(3, tensor.dimension(1)); + EXPECT_EQ(4, tensor.dimension(2)); + EXPECT_EQ(5, tensor.dimension(3)); + + // Set first and last elements. 
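+    // (Illustrative) These two sentinel values let every reshape and
+    // flattening test below check that element order is preserved: 0.01f
+    // must reappear at the first index and 0.02f at the last.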
+ tensor(0, 0, 0, 0) = 0.01f; + tensor(1, 2, 3, 4) = 0.02f; + } + + template + using ReshapeFunc = T (Tensor::*)(gtl::ArraySlice); + template + using ConstReshapeFunc = T (Tensor::*)(gtl::ArraySlice) const; + + template Func> + void TestReshape(std::initializer_list sizes) { + T shaped = (t.*Func)(sizes); + TestReshapeImpl(shaped, sizes); + } + + template Func> + void TestReshape(std::initializer_list sizes) { + T shaped = (static_cast(t).*Func)(sizes); + TestReshapeImpl(shaped, sizes); + } + + template + void TestReshapeImpl(T shaped, std::initializer_list sizes) { + auto iter = sizes.begin(); + for (int i = 0; i < shaped.rank(); ++i, ++iter) { + EXPECT_EQ(*iter, shaped.dimension(i)); + } + + using Index = typename T::Index; + using Scalar = typename T::Scalar; + constexpr int N = T::NumIndices; + + // To handle the cast when `shaped` is bit casted into a different type. + const float expected_first = 0.01f; + Eigen::DSizes coord; + EXPECT_EQ(shaped(coord), *reinterpret_cast(&expected_first)); + + for (int i = 0; i < N; ++i) { + coord[i] = shaped.dimension(i) - 1; + } + const float expected_last = 0.02f; + constexpr int kNumScalarPerFloat = + sizeof(float) / sizeof(Scalar); // Assuming even divide. + EXPECT_EQ(shaped(coord), reinterpret_cast( + &expected_last)[kNumScalarPerFloat - 1]); + } +}; + +TEST_F(TensorReshapeTest, Reshape) { + LOG(INFO) << "shaped"; + +#define TEST_RESHAPE(...) \ + { \ + constexpr int N = (sizeof((int[]){__VA_ARGS__}) / sizeof(int)); \ + TestReshape::Tensor, &Tensor::shaped>( \ + {__VA_ARGS__}); \ + TestReshape::ConstTensor, &Tensor::shaped>( \ + {__VA_ARGS__}); \ + TestReshape::UnalignedTensor, \ + &Tensor::unaligned_shaped>({__VA_ARGS__}); \ + TestReshape::UnalignedConstTensor, \ + &Tensor::unaligned_shaped>({__VA_ARGS__}); \ + TestReshape::Tensor, \ + &Tensor::bit_casted_shaped>({__VA_ARGS__}); \ + TestReshape::ConstTensor, \ + &Tensor::bit_casted_shaped>({__VA_ARGS__}); \ + TestReshape::Tensor, \ + &Tensor::bit_casted_shaped>({__VA_ARGS__}); \ + TestReshape::ConstTensor, \ + &Tensor::bit_casted_shaped>({__VA_ARGS__}); \ + } + + TEST_RESHAPE(120); + TEST_RESHAPE(6, 20); + TEST_RESHAPE(6, 4, 5); + TEST_RESHAPE(2, 3, 4, 5); +#undef TEST_RESHAPE +} + +TEST_F(TensorReshapeTest, BitcastReshapeDifferentSize) { +#define TEST_BITCAST8_RESHAPE(...) \ + { \ + constexpr int N = (sizeof((int[]){__VA_ARGS__}) / sizeof(int)); \ + TestReshape::Tensor, \ + &Tensor::bit_casted_shaped>({__VA_ARGS__}); \ + } + + TEST_BITCAST8_RESHAPE(480); + TEST_BITCAST8_RESHAPE(24, 20); + TEST_BITCAST8_RESHAPE(6, 16, 5); + TEST_BITCAST8_RESHAPE(2, 3, 4, 20); +#undef TEST_BITCAST8_RESHAPE +#define TEST_BITCAST16_RESHAPE(...) \ + { \ + constexpr int N = (sizeof((int[]){__VA_ARGS__}) / sizeof(int)); \ + TestReshape::Tensor, \ + &Tensor::bit_casted_shaped>({__VA_ARGS__}); \ + } + + TEST_BITCAST16_RESHAPE(240); + TEST_BITCAST16_RESHAPE(6, 40); + TEST_BITCAST16_RESHAPE(12, 4, 5); + TEST_BITCAST16_RESHAPE(2, 3, 8, 5); + TEST_BITCAST16_RESHAPE(2, 3, 4, 1, 10); +#undef TEST_BITCAST16_RESHAPE +} + +TEST_F(TensorReshapeTest, ReshapeError) { + EXPECT_DEATH((t.shaped({})), "1 vs. 120"); + EXPECT_DEATH((t.shaped({119})), "119 vs. 120"); + EXPECT_DEATH((t.shaped({2, 3, 4, 6})), "144 vs. 120"); + + EXPECT_DEATH((t.unaligned_shaped({})), "1 vs. 120"); + EXPECT_DEATH((t.unaligned_shaped({119})), "119 vs. 120"); + EXPECT_DEATH((t.unaligned_shaped({2, 3, 4, 6})), "144 vs. 120"); + + EXPECT_DEATH((t.bit_casted_shaped({})), "4 vs. 480"); + EXPECT_DEATH((t.bit_casted_shaped({119})), "476 vs. 
480"); + EXPECT_DEATH((t.bit_casted_shaped({2, 3, 4, 6})), "576 vs. 480"); + + Tensor string_tensor{DT_STRING, {10}}; + // Note that the error message compare # of elements, not # of bytes. + EXPECT_DEATH((string_tensor.bit_casted_shaped({9})), "9 vs. 10"); +} + +TEST_F(TensorReshapeTest, Flat) { + LOG(INFO) << "flat"; + { + auto flat = t.flat(); + EXPECT_EQ(flat(0), 0.01f); + EXPECT_EQ(120, flat.dimension(0)); + EXPECT_EQ(flat(0), 0.01f); + EXPECT_EQ(flat(119), 0.02f); + } +} + +TEST_F(TensorReshapeTest, FlatInnerDims) { + LOG(INFO) << "flat_inner_dims"; + { + auto flat_inner_dims = t.flat_inner_dims(); + EXPECT_EQ(24, flat_inner_dims.dimension(0)); + EXPECT_EQ(5, flat_inner_dims.dimension(1)); + EXPECT_EQ(flat_inner_dims(0, 0), 0.01f); + EXPECT_EQ(flat_inner_dims(23, 4), 0.02f); + } + { + auto flat_inner_dims = t.flat_inner_dims(); + EXPECT_EQ(6, flat_inner_dims.dimension(0)); + EXPECT_EQ(4, flat_inner_dims.dimension(1)); + EXPECT_EQ(5, flat_inner_dims.dimension(2)); + EXPECT_EQ(flat_inner_dims(0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_dims(5, 3, 4), 0.02f); + } + { + auto flat_inner_dims = t.flat_inner_dims(); + EXPECT_EQ(1, flat_inner_dims.dimension(0)); + EXPECT_EQ(2, flat_inner_dims.dimension(1)); + EXPECT_EQ(3, flat_inner_dims.dimension(2)); + EXPECT_EQ(4, flat_inner_dims.dimension(3)); + EXPECT_EQ(5, flat_inner_dims.dimension(4)); + EXPECT_EQ(flat_inner_dims(0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_dims(0, 1, 2, 3, 4), 0.02f); + } + { + auto flat_inner_dims = zero_t.flat_inner_dims(); + EXPECT_EQ(0, flat_inner_dims.dimension(0)); + EXPECT_EQ(5, flat_inner_dims.dimension(1)); + } + { + auto flat_inner_dims = zero_t.flat_inner_dims(); + EXPECT_EQ(0, flat_inner_dims.dimension(0)); + EXPECT_EQ(0, flat_inner_dims.dimension(1)); + EXPECT_EQ(5, flat_inner_dims.dimension(2)); + } + { + auto flat_inner_dims = zero_t.flat_inner_dims(); + EXPECT_EQ(3, flat_inner_dims.dimension(0)); + EXPECT_EQ(0, flat_inner_dims.dimension(1)); + EXPECT_EQ(2, flat_inner_dims.dimension(2)); + EXPECT_EQ(0, flat_inner_dims.dimension(3)); + EXPECT_EQ(5, flat_inner_dims.dimension(4)); + } +} + +TEST_F(TensorReshapeTest, FlatOuterDims) { + LOG(INFO) << "flat_outer_dims"; + { + auto flat_outer_dims = t.flat_outer_dims(); + EXPECT_EQ(2, flat_outer_dims.dimension(0)); + EXPECT_EQ(60, flat_outer_dims.dimension(1)); + EXPECT_EQ(flat_outer_dims(0, 0), 0.01f); + EXPECT_EQ(flat_outer_dims(1, 59), 0.02f); + } + { + auto flat_outer_dims = t.flat_outer_dims(); + EXPECT_EQ(2, flat_outer_dims.dimension(0)); + EXPECT_EQ(3, flat_outer_dims.dimension(1)); + EXPECT_EQ(20, flat_outer_dims.dimension(2)); + EXPECT_EQ(flat_outer_dims(0, 0, 0), 0.01f); + EXPECT_EQ(flat_outer_dims(1, 2, 19), 0.02f); + } + { + auto flat_outer_dims = t.flat_outer_dims(); + EXPECT_EQ(2, flat_outer_dims.dimension(0)); + EXPECT_EQ(3, flat_outer_dims.dimension(1)); + EXPECT_EQ(4, flat_outer_dims.dimension(2)); + EXPECT_EQ(5, flat_outer_dims.dimension(3)); + EXPECT_EQ(1, flat_outer_dims.dimension(4)); + EXPECT_EQ(flat_outer_dims(0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_outer_dims(1, 2, 3, 4, 0), 0.02f); + } + { + auto flat_outer_dims = zero_t.flat_outer_dims(); + EXPECT_EQ(3, flat_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_outer_dims.dimension(1)); + } + { + auto flat_outer_dims = zero_t.flat_outer_dims(); + EXPECT_EQ(3, flat_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_outer_dims.dimension(1)); + EXPECT_EQ(0, flat_outer_dims.dimension(2)); + } + { + auto flat_outer_dims = zero_t.flat_outer_dims(); + EXPECT_EQ(3, flat_outer_dims.dimension(0)); + 
EXPECT_EQ(0, flat_outer_dims.dimension(1)); + EXPECT_EQ(2, flat_outer_dims.dimension(2)); + EXPECT_EQ(0, flat_outer_dims.dimension(3)); + EXPECT_EQ(5, flat_outer_dims.dimension(4)); + } +} + +TEST_F(TensorReshapeTest, FlatInnerOuterDims) { + LOG(INFO) << "flat_inner_outer_dims"; + { + auto flat_inner_outer_dims = t.flat_inner_outer_dims(0); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(4, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(1, 2, 3, 4), 0.02f); + } + { + auto flat_inner_outer_dims = t.flat_inner_outer_dims(-2); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(4, flat_inner_outer_dims.dimension(4)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(5)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 1, 2, 3, 4), 0.02f); + } + { + auto flat_inner_outer_dims = t.flat_inner_outer_dims(0); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(4, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(4)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(5)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(1, 2, 3, 4, 0, 0), 0.02f); + } + { + auto flat_inner_outer_dims = t.flat_inner_outer_dims(-2); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(4, flat_inner_outer_dims.dimension(4)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(5)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(6)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(7)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0, 0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 1, 2, 3, 4, 0, 0), 0.02f); + } + { + auto flat_inner_outer_dims = t.flat_inner_outer_dims(1); + EXPECT_EQ(6, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(4, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(5, 3, 4), 0.02f); + } + { + auto flat_inner_outer_dims = t.flat_inner_outer_dims(1); + EXPECT_EQ(6, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(4, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(4)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(5, 3, 4, 0, 0), 0.02f); + } + { + auto flat_inner_outer_dims = t.flat_inner_outer_dims(0); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(20, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(1, 2, 19), 0.02f); + } + { + auto flat_inner_outer_dims = t.flat_inner_outer_dims(-2); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(1)); + 
EXPECT_EQ(2, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(20, flat_inner_outer_dims.dimension(4)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 1, 2, 19), 0.02f); + } + { + auto flat_inner_outer_dims = t.flat_inner_outer_dims(1); + EXPECT_EQ(6, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(20, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(flat_inner_outer_dims(0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(5, 19), 0.02f); + } + { + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(0); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(1)); + } + { + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(0); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(2)); + } + { + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(0); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(4)); + } + { + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(3); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(1)); + } + { + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(2); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(2)); + } + { + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(1); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(2)); + } +} + +TEST(ReinterpretLastDimension, Reinterpret_NCHW_VECT_C_as_NCHW) { + LOG(INFO) << "reinterpret_last_dimension"; + { + Tensor t_nchw_vect_c(DT_QINT8, TensorShape({2, 3, 5, 7, 4})); + auto nchw_vect_c = t_nchw_vect_c.tensor(); + Tensor t_expected_nchw(DT_INT32, TensorShape({2, 3, 5, 7})); + auto expected_nchw = t_expected_nchw.tensor(); + int8 val = 0; + for (int n = 0; n < t_nchw_vect_c.shape().dim_size(0); ++n) { + for (int c = 0; c < t_nchw_vect_c.shape().dim_size(1); ++c) { + for (int h = 0; h < t_nchw_vect_c.shape().dim_size(2); ++h, ++val) { + int8 packet[4]; + for (int w = 0; w < t_nchw_vect_c.shape().dim_size(3); ++w) { + packet[0] = nchw_vect_c(n, c, h, w, 0) = ++val; + packet[1] = nchw_vect_c(n, c, h, w, 1) = ++val; + packet[2] = nchw_vect_c(n, c, h, w, 2) = ++val; + packet[3] = nchw_vect_c(n, c, h, w, 3) = ++val; + expected_nchw(n, c, h, w) = *reinterpret_cast(&packet[0]); + } + } + } + } + auto actual_nchw = t_nchw_vect_c.reinterpret_last_dimension(); + const auto& const_t_nchw_vect_c = t_nchw_vect_c; + auto const_actual_nchw = + const_t_nchw_vect_c.reinterpret_last_dimension(); + for (int n = 0; n < t_nchw_vect_c.shape().dim_size(0); ++n) { + for (int c = 0; c < t_nchw_vect_c.shape().dim_size(1); ++c) { + for (int h = 0; h < t_nchw_vect_c.shape().dim_size(2); ++h) { + for (int w = 0; w < t_nchw_vect_c.shape().dim_size(3); ++w) { + EXPECT_EQ(expected_nchw(n, c, h, w), actual_nchw(n, c, h, w)); + EXPECT_EQ(expected_nchw(n, c, h, w), const_actual_nchw(n, c, h, w)); + } + } + } + } + } +} + +TEST(Tensor_Scalar, Basics) { + { + Tensor t(DT_BOOL, TensorShape({})); + EXPECT_EQ(1, t.NumElements()); + 
auto Tt = t.scalar(); + EXPECT_EQ(1, Tt.size()); + EXPECT_EQ(0, Tt.rank()); + t.scalar()() = true; + EXPECT_TRUE(Tt()); + } + { + Tensor t(DT_FLOAT, TensorShape({})); + EXPECT_EQ(1, t.NumElements()); + auto Tt = t.scalar(); + EXPECT_EQ(1, Tt.size()); + EXPECT_EQ(0, Tt.rank()); + t.scalar()() = 123.45f; + EXPECT_FLOAT_EQ(123.45f, Tt()); + } + { + Tensor t(DT_FLOAT, TensorShape({1})); + EXPECT_EQ(1, t.NumElements()); + auto Tt = t.vec(); + EXPECT_EQ(1, Tt.size()); + t.vec()(0) = 123.45f; + EXPECT_FLOAT_EQ(123.45f, Tt(0)); + } + { + Tensor t(DT_FLOAT, TensorShape({1, 1, 1})); + EXPECT_EQ(1, t.NumElements()); + auto Tt = t.scalar(); + EXPECT_EQ(1, Tt.size()); + EXPECT_EQ(0, Tt.rank()); + t.flat()(0) = 123.45f; + EXPECT_FLOAT_EQ(123.45f, Tt()); + } + { + Tensor t(DT_STRING, TensorShape({})); + EXPECT_EQ(1, t.NumElements()); + auto Tt = t.scalar(); + EXPECT_EQ(1, Tt.size()); + EXPECT_EQ(0, Tt.rank()); + t.scalar()() = "foo"; + EXPECT_EQ("foo", Tt()); + } + { + Tensor t(DT_STRING, TensorShape({1})); + EXPECT_EQ(1, t.NumElements()); + auto Tt = t.vec(); + EXPECT_EQ(1, Tt.size()); + t.flat()(0) = "foo"; + EXPECT_EQ("foo", Tt(0)); + } + { + Tensor t(DT_STRING, TensorShape({1, 1, 1})); + EXPECT_EQ(1, t.NumElements()); + auto Tt = t.scalar(); + EXPECT_EQ(1, Tt.size()); + EXPECT_EQ(0, Tt.rank()); + t.flat()(0) = "bar"; + EXPECT_EQ("bar", Tt()); + } + { + Tensor t(DT_FLOAT, TensorShape({0, 1})); + EXPECT_EQ(0, t.NumElements()); + auto Tt = t.flat(); + EXPECT_EQ(0, Tt.size()); + auto Tm = t.matrix(); + EXPECT_EQ(0, Tm.size()); + EXPECT_EQ(0, Tm.dimensions()[0]); + EXPECT_EQ(1, Tm.dimensions()[1]); + } +} + +TEST(Tensor_Float, Reshape_And_Slice_Assignment) { + // A test to experiment with a way to assign to a subset of a tensor + Tensor t(DT_FLOAT, TensorShape({10, 4, 3, 2})); + EXPECT_TRUE(t.shape().IsSameSize(TensorShape({10, 4, 3, 2}))); + + // Get the N dimensional tensor (N==4 here) + auto e_t = t.tensor(); + // Reshape to view it as a two-dimensional tensor + auto e_2d = t.shaped({10, 4 * 3 * 2}); + for (int i = 0; i < 10; i++) { + // Assign a 1 x 4*3*2 matrix (really vector) to a slice of size + // 1 x 4*3*2 in e_t. 
+ Eigen::Tensor m(1, 4 * 3 * 2); + m.setConstant(i * 2.0); + + Eigen::DSizes indices(i, 0); + Eigen::DSizes sizes(1, 4 * 3 * 2); + e_2d.slice(indices, sizes) = m; + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 3; k++) { + for (int l = 0; l < 2; l++) { + EXPECT_EQ(e_t(i, j, k, l), i * 2.0f); + LOG(INFO) << i << "," << j << "," << k << "," << l + << " &e_t(i, j, k, l): " << &e_t(i, j, k, l) << " = " + << e_t(i, j, k, l); + } + } + } + } +} + +TEST(Tensor_String, Simple) { + Tensor t = test::AsTensor( + {"hello", "world", "machine", "learning", "new", "york"}, + TensorShape({3, 2})); + auto s = t.shape(); + ASSERT_EQ(s.dims(), 2); + ASSERT_EQ(s.dim_size(0), 3); + ASSERT_EQ(s.dim_size(1), 2); + auto m = t.matrix(); + EXPECT_EQ(t.TotalBytes(), 3 * 2 * sizeof(string) + 5 + 5 + 7 + 8 + 3 + 4); + + EXPECT_EQ(m(0, 0), "hello"); + EXPECT_EQ(m(0, 1), "world"); + EXPECT_EQ(m(1, 0), "machine"); + EXPECT_EQ(m(1, 1), "learning"); + EXPECT_EQ(m(2, 0), "new"); + EXPECT_EQ(m(2, 1), "york"); + + TestCopies(t); +} + +TEST(Tensor_Float, SimpleWithHelper) { + Tensor t1 = test::AsTensor({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() * 2.0f; + Tensor t3 = test::AsTensor({0, 2, 4, 6, 8, 10}, t1.shape()); + test::ExpectTensorEqual(t2, t3); +} + +TEST(Tensor_Int32, SimpleWithHelper) { + Tensor t1 = test::AsTensor({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() * 2; + Tensor t3 = test::AsTensor({0, 2, 4, 6, 8, 10}, t1.shape()); + test::ExpectTensorEqual(t2, t3); +} + +TEST(Tensor_UInt16, SimpleWithHelper) { + Tensor t1 = test::AsTensor({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() * uint16(2); + Tensor t3 = test::AsTensor({0, 2, 4, 6, 8, 10}, t1.shape()); + test::ExpectTensorEqual(t2, t3); +} + +TEST(Tensor_QInt8, SimpleWithHelper) { + Tensor t1 = test::AsTensor({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() + qint8(-2); + Tensor t3 = test::AsTensor({-2, -1, 0, 1, 2, 3}, {2, 3}); + test::ExpectTensorEqual(t2, t3); +} + +TEST(Tensor_QUInt8, SimpleWithHelper) { + Tensor t1 = test::AsTensor({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() + quint8(2); + Tensor t3 = test::AsTensor({2, 3, 4, 5, 6, 7}, {2, 3}); + test::ExpectTensorEqual(t2, t3); +} + +TEST(Tensor_Int64, SimpleWithHelper) { + Tensor t1 = test::AsTensor( + {0LL << 48, 1LL << 48, 2LL << 48, 3LL << 48, 4LL << 48, 5LL << 48}, + {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() * static_cast(2); + Tensor t3 = test::AsTensor( + {0LL << 48, 2LL << 48, 4LL << 48, 6LL << 48, 8LL << 48, 10LL << 48}, + {2, 3}); + test::ExpectTensorEqual(t2, t3); +} + +TEST(Tensor_String, SimpleWithHelper) { + Tensor t1 = test::AsTensor({"0", "1", "2", "3", "4", "5"}, {2, 3}); + Tensor t2(DT_STRING, {2, 3}); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + t2.matrix()(i, j) = strings::StrCat(i * 3 + j); + } + } + + // Test with helper. + test::ExpectTensorEqual(t1, t2); +} + +TEST(Tensor_Bool, SimpleWithHelper) { + Tensor t1 = + test::AsTensor({false, true, false, true, false, true}, {2, 3}); + + Tensor t2(DT_BOOL, {2, 3}); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + t2.matrix()(i, j) = (((i + j) % 2) != 0); + } + } + + // Test with helper. 
+ test::ExpectTensorEqual(t1, t2); +} + +TEST(Tensor_Complex, Simple64) { + Tensor t(DT_COMPLEX64, {4, 5, 3, 7}); + t.flat().setRandom(); + TestCopies(t); +} + +TEST(Tensor_Complex, Simple128) { + Tensor t(DT_COMPLEX128, {4, 5, 3, 7}); + t.flat().setRandom(); + TestCopies(t); +} + +TEST(Tensor_Complex, SimpleWithHelper64) { + { + Tensor t1 = test::AsTensor({0, + {1, 1}, + complex64(2), + complex64(3, 3), + complex64(0, 4), + complex64(2, 5)}, + {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() * complex64(0, 2); + Tensor t3 = test::AsTensor( + {0, {-2, 2}, {0, 4}, {-6, 6}, {-8, 0}, {-10, 4}}, + // shape + {2, 3}); + test::ExpectTensorEqual(t2, t3); + } + + // Does some numeric operations for complex64 numbers. + { + const float PI = std::acos(-1); + const complex64 rotate_45 = std::polar(1.0f, PI / 4); + + // x contains all the 8-th root of unity. + Tensor x(DT_COMPLEX64, TensorShape({8})); + for (int i = 0; i < 8; ++i) { + x.vec()(i) = std::pow(rotate_45, i); + } + + // Shift the roots by 45 degree. + Tensor y(DT_COMPLEX64, TensorShape({8})); + y.vec() = x.vec() * rotate_45; + Tensor y_expected(DT_COMPLEX64, TensorShape({8})); + for (int i = 0; i < 8; ++i) { + y_expected.vec()(i) = std::pow(rotate_45, i + 1); + } + test::ExpectTensorNear(y, y_expected, 1e-5); + + // Raise roots to the power of 8. + Tensor z(DT_COMPLEX64, TensorShape({8})); + z.vec() = x.vec().pow(8); + Tensor z_expected(DT_COMPLEX64, TensorShape({8})); + for (int i = 0; i < 8; ++i) { + z_expected.vec()(i) = 1; + } + test::ExpectTensorNear(z, z_expected, 1e-5); + } +} + +TEST(Tensor_Complex, SimpleWithHelper128) { + { + Tensor t1 = test::AsTensor({0, + {1, 1}, + complex128(2), + complex128(3, 3), + complex128(0, 4), + complex128(2, 5)}, + {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat() = t1.flat() * complex128(0, 2); + Tensor t3 = test::AsTensor( + {0, {-2, 2}, {0, 4}, {-6, 6}, {-8, 0}, {-10, 4}}, + // shape + {2, 3}); + test::ExpectTensorEqual(t2, t3); + } + + // Does some numeric operations for complex128 numbers. + { + const double PI = std::acos(-1); + const complex128 rotate_45 = std::polar(1.0, PI / 4); + + // x contains all the 8-th root of unity. + Tensor x(DT_COMPLEX128, TensorShape({8})); + for (int i = 0; i < 8; ++i) { + x.vec()(i) = std::pow(rotate_45, i); + } + + // Shift the roots by 45 degree. + Tensor y(DT_COMPLEX128, TensorShape({8})); + y.vec() = x.vec() * rotate_45; + Tensor y_expected(DT_COMPLEX128, TensorShape({8})); + for (int i = 0; i < 8; ++i) { + y_expected.vec()(i) = std::pow(rotate_45, i + 1); + } + test::ExpectTensorNear(y, y_expected, 1e-5); + + // Raise roots to the power of 8. + Tensor z(DT_COMPLEX128, TensorShape({8})); + z.vec() = x.vec().pow(8); + Tensor z_expected(DT_COMPLEX128, TensorShape({8})); + for (int i = 0; i < 8; ++i) { + z_expected.vec()(i) = 1; + } + test::ExpectTensorNear(z, z_expected, 1e-5); + } +} + +namespace { + +// An allocator that always returns nullptr, for testing +// failures to allocate. 
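+// (Illustrative) The tests below expect allocation failure to be reported
+// gracefully: the Tensor stays uninitialized and FromProto() returns false,
+// rather than crashing.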
+class DummyCPUAllocator : public Allocator { + public: + DummyCPUAllocator() = default; + string Name() override { return "cpu"; } + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + return nullptr; + } + void DeallocateRaw(void* ptr) override {} +}; + +TEST(Tensor, FailureToAllocate) { + TensorShape shape({1}); + DummyCPUAllocator allocator; + { + Tensor a(&allocator, DT_FLOAT, shape); + ASSERT_FALSE(a.IsInitialized()); + } + + // Float + { + Tensor t(DT_FLOAT, TensorShape({1})); + t.vec()(0) = 1.0; + TensorProto proto; + t.AsProtoField(&proto); + + // FromProto should fail nicely. + Tensor a(&allocator, DT_FLOAT, TensorShape({1})); + ASSERT_FALSE(a.FromProto(&allocator, proto)); + } + + // String + { + Tensor t(DT_STRING, TensorShape({1})); + t.vec()(0) = "foo"; + TensorProto proto; + t.AsProtoField(&proto); + + // FromProto should fail nicely. + Tensor a(&allocator, DT_STRING, TensorShape({1})); + ASSERT_FALSE(a.FromProto(&allocator, proto)); + } + + // Half + { + Tensor t(DT_HALF, TensorShape({1})); + t.vec()(0) = Eigen::half(1.0); + TensorProto proto; + t.AsProtoField(&proto); + + // FromProto should fail nicely. + Tensor a(&allocator, DT_HALF, TensorShape({1})); + ASSERT_FALSE(a.FromProto(&allocator, proto)); + } +} + +// On the alignment. +// +// As of 2015/8, tensorflow::Tensor allocates its buffer with 32-byte +// alignment. Tensor::tensor/flat/vec/matrix methods requires the +// buffer satisfies Eigen::Aligned (e.g., 16-bytes aligned usually, +// and 32-bytes for AVX). Tensor::Slice requires the caller to ensure +// its result is aligned if the caller intends to use those methods. +// In this test case, we simply make sure each slice is 32-byte +// aligned: sizeof(float) * 4 * 2 = 32. +TEST(Tensor, Slice_Basic) { + Tensor saved; + { // General + Tensor x(DT_FLOAT, TensorShape({10, 4, 34})); + // Fills in known values. + for (int i = 0; i < 10; ++i) { + x.Slice(i, i + 1).flat().setConstant(i * 1.f); + } + // A simple slice along dim0. + Tensor y = x.Slice(4, 8); + EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 4, 34}))); + auto tx = x.tensor(); + auto ty = y.tensor(); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 34; ++k) { + EXPECT_EQ(ty(i, j, k), 4.0 + i); + EXPECT_EQ(&tx(4 + i, j, k), &ty(i, j, k)); + } + } + } + // A simple slice equivalent to identity. + TestCopies(y); + y = x.Slice(0, 10); + test::ExpectTensorEqual(x, y); + EXPECT_EQ(x.flat().data(), y.flat().data()); + + // A slice of a slice. + auto z = x.Slice(4, 8).Slice(2, 3); + auto tz = z.tensor(); + EXPECT_EQ(1, z.dim_size(0)); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 34; ++k) { + EXPECT_EQ(tz(0, j, k), 6.0); + } + } + + // x and y will be out of scope. But 'saved' should be alive. + saved = z; + } + { + EXPECT_EQ(1, saved.dim_size(0)); + auto tsaved = saved.tensor(); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 34; ++k) { + EXPECT_EQ(tsaved(0, j, k), 6.0); + } + } + } + { // Empty + Tensor x(DT_FLOAT, TensorShape({10, 0, 34})); + x.flat().setRandom(); + Tensor y = x.Slice(4, 8); + EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 0, 34}))); + } + + { + // Test unaligned access via a Slice. + Tensor x(DT_FLOAT, TensorShape({30})); + x.flat().setConstant(0.0); + + // Take an unaligned slice. 
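+    // (Illustrative) Starting at element 1 offsets the slice's data pointer
+    // by sizeof(float) = 4 bytes from the 32-byte-aligned buffer, so the
+    // slice cannot satisfy Eigen's alignment requirement when
+    // EIGEN_MAX_ALIGN_BYTES > 0.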
+ Tensor y = x.Slice(1, 13); +#if EIGEN_MAX_ALIGN_BYTES > 0 + EXPECT_FALSE(y.IsAligned()); +#endif + y.unaligned_flat().setConstant(1.0); + for (int64 i = 0; i < y.NumElements(); ++i) { + EXPECT_EQ(1.0, y.unaligned_flat()(i)); + } + } +} + +namespace { +template +Tensor MkTensor(DataType dt, const TensorShape& shape, + std::vector init_values) { + Tensor x(dt, shape); + const int limit = x.NumElements(); + int vi = 0; + for (int i = 0; i < limit; ++i) { + x.flat()(i) = init_values[vi++]; + if (vi >= init_values.size()) vi = 0; + } + return x; +} +} // namespace + +TEST(SummarizeValue, Uninitialized) { + Tensor x(DT_INT32); + TensorTestHelper::set_shape(&x, TensorShape({4, 4})); + EXPECT_EQ( + strings::StrCat("uninitialized Tensor of 16 elements of type ", DT_INT32), + x.SummarizeValue(16)); +} + +TEST(SummarizeValue, INT32) { + Tensor x = MkTensor(DT_INT32, TensorShape({5}), {1, 2, 3, 4, 0}); + EXPECT_EQ("1 2 3 4 0", x.SummarizeValue(16)); + x = MkTensor(DT_INT32, TensorShape({2, 2}), {1, 2, 3, 4, 0}); + EXPECT_EQ("[1 2][3 4]", x.SummarizeValue(16)); + x = MkTensor(DT_INT32, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0}); + EXPECT_EQ("[[[1]][[2]]][[[3]][[4]]]", x.SummarizeValue(16)); + EXPECT_EQ("[[[1]][[2]]][[[3]]]...", x.SummarizeValue(3)); + x = MkTensor(DT_INT32, TensorShape({0}), {}); + EXPECT_EQ("", x.SummarizeValue(16)); +} + +TEST(SummarizeValue, FLOAT) { + Tensor x = MkTensor(DT_FLOAT, TensorShape({5}), {1, 2, 3, 4, 0}); + EXPECT_EQ("1 2 3 4 0", x.SummarizeValue(16)); + x = MkTensor(DT_FLOAT, TensorShape({2, 2}), {1, 2, 3, 4, 0}); + EXPECT_EQ("[1 2][3 4]", x.SummarizeValue(16)); + x = MkTensor(DT_FLOAT, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0}); + EXPECT_EQ("[[[1]][[2]]][[[3]][[4]]]", x.SummarizeValue(16)); + EXPECT_EQ("[[[1]][[2]]][[[3]]]...", x.SummarizeValue(3)); + x = MkTensor(DT_FLOAT, TensorShape({0}), {}); + EXPECT_EQ("", x.SummarizeValue(16)); +} + +TEST(SummarizeValue, BOOL) { + Tensor x = MkTensor(DT_BOOL, TensorShape({5}), {false, true, true}); + EXPECT_EQ("0 1 1 0 1", x.SummarizeValue(16)); + EXPECT_EQ("0 1 1...", x.SummarizeValue(3)); +} + +TEST(SummarizeValue, STRING) { + Tensor x = MkTensor(DT_STRING, TensorShape({5}), + {"one", "two", "three", "four", "five"}); + EXPECT_EQ("one two three four five", x.SummarizeValue(16)); + x = MkTensor(DT_STRING, TensorShape({5, 1, 5}), + {"one", "two", "three", "four", "five"}); + EXPECT_EQ("one two three four five one...", x.SummarizeValue(6)); +} + +static void BM_CreateAndDestroy(int iters) { + TensorShape shape({10, 20}); + while (--iters) { + Tensor t(DT_FLOAT, shape); + } +} +BENCHMARK(BM_CreateAndDestroy); + +static void BM_Assign(int iters) { + Tensor a(DT_FLOAT, TensorShape({10, 20})); + Tensor b(DT_FLOAT, TensorShape({10, 20})); + bool a_to_b = true; + while (--iters) { + if (a_to_b) { + b = a; + } else { + a = b; + } + a_to_b = !a_to_b; + } +} +BENCHMARK(BM_Assign); + +// Ensure tensor_data() works on empty tensors +TEST(Tensor, EmptyTensorData) { + Tensor empty; + EXPECT_EQ(empty.tensor_data().size(), 0); +} + +// Benchmark create and destroy a tensor, with an allocated buffer. +static void BM_CreateAndDestroyWithBuf(int iters) { + TensorShape shape({10, 20}); + Allocator* allocator = cpu_allocator(); + while (--iters) { + Tensor a(allocator, DT_FLOAT, shape); + } +} +BENCHMARK(BM_CreateAndDestroyWithBuf); + +// Benchmark create+copy a tensor, with an allocated buffer. 
+static void BM_CreateAndCopyCtrWithBuf(int iters) { + TensorShape shape({10, 20}); + Allocator* allocator = cpu_allocator(); + while (--iters) { + Tensor a(allocator, DT_FLOAT, shape); + Tensor b(a); + } +} +BENCHMARK(BM_CreateAndCopyCtrWithBuf); + +// Benchmark create+move a tensor, with an allocated buffer. +static void BM_CreateAndMoveCtrWithBuf(int iters) { + TensorShape shape({10, 20}); + Allocator* allocator = cpu_allocator(); + while (--iters) { + Tensor a(allocator, DT_FLOAT, shape); + Tensor b(std::move(a)); + } +} +BENCHMARK(BM_CreateAndMoveCtrWithBuf); + +} // namespace +} // namespace tensorflow diff --git a/tensor_testutil.cc b/tensor_testutil.cc new file mode 100644 index 0000000000000000000000000000000000000000..a8d141230093152397c792588a716c00556df77d --- /dev/null +++ b/tensor_testutil.cc @@ -0,0 +1,65 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/core/framework/tensor_testutil.h" + +namespace tensorflow { +namespace test { + +template +bool IsClose(const T& x, const T& y, double atol, double rtol) { + // Need x == y so that infinities are close to themselves + return x == y || std::abs(x - y) < atol + rtol * std::abs(x); +} + +template +void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) { + auto Tx = x.flat(); + auto Ty = y.flat(); + for (int i = 0; i < Tx.size(); ++i) { + if (!IsClose(Tx(i), Ty(i), atol, rtol)) { + LOG(ERROR) << "x = " << x.DebugString(); + LOG(ERROR) << "y = " << y.DebugString(); + LOG(ERROR) << "atol = " << atol << " rtol = " << rtol + << " tol = " << atol + rtol * std::abs(Tx(i)); + EXPECT_TRUE(false) << i << "-th element is not close " << Tx(i) << " vs. " + << Ty(i); + } + } +} + +void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) { + internal::AssertSameTypeDims(x, y); + switch (x.dtype()) { + case DT_FLOAT: + ExpectClose(x, y, atol, rtol); + break; + case DT_DOUBLE: + ExpectClose(x, y, atol, rtol); + break; + case DT_COMPLEX64: + ExpectClose(x, y, atol, rtol); + break; + case DT_COMPLEX128: + ExpectClose(x, y, atol, rtol); + break; + default: + LOG(FATAL) << "Unexpected type : " << DataTypeString(x.dtype()); + } +} + +} // end namespace test +} // end namespace tensorflow diff --git a/tensor_testutil.h b/tensor_testutil.h new file mode 100644 index 0000000000000000000000000000000000000000..4c216a84f04389f9a2ef761aa6b6cec2c20a0be8 --- /dev/null +++ b/tensor_testutil.h @@ -0,0 +1,230 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace test { + +// Constructs a scalar tensor with 'val'. +template +Tensor AsScalar(const T& val) { + Tensor ret(DataTypeToEnum::value, {}); + ret.scalar()() = val; + return ret; +} + +// Constructs a flat tensor with 'vals'. +template +Tensor AsTensor(gtl::ArraySlice vals) { + Tensor ret(DataTypeToEnum::value, {static_cast(vals.size())}); + std::copy_n(vals.data(), vals.size(), ret.flat().data()); + return ret; +} + +// Constructs a tensor of "shape" with values "vals". +template +Tensor AsTensor(gtl::ArraySlice vals, const TensorShape& shape) { + Tensor ret; + CHECK(ret.CopyFrom(AsTensor(vals), shape)); + return ret; +} + +// Fills in '*tensor' with 'vals'. E.g., +// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); +// test::FillValues(&x, {11, 21, 21, 22}); +template +void FillValues(Tensor* tensor, gtl::ArraySlice vals) { + auto flat = tensor->flat(); + CHECK_EQ(flat.size(), vals.size()); + if (flat.size() > 0) { + std::copy_n(vals.data(), vals.size(), flat.data()); + } +} + +// Fills in '*tensor' with 'vals', converting the types as needed. +template +void FillValues(Tensor* tensor, std::initializer_list vals) { + auto flat = tensor->flat(); + CHECK_EQ(flat.size(), vals.size()); + if (flat.size() > 0) { + size_t i = 0; + for (auto itr = vals.begin(); itr != vals.end(); ++itr, ++i) { + flat(i) = T(*itr); + } + } +} + +// Fills in '*tensor' with a sequence of value of val, val+1, val+2, ... +// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); +// test::FillIota(&x, 1.0); +template +void FillIota(Tensor* tensor, const T& val) { + auto flat = tensor->flat(); + std::iota(flat.data(), flat.data() + flat.size(), val); +} + +// Fills in '*tensor' with a sequence of value of fn(0), fn(1), ... +// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); +// test::FillFn(&x, [](int i)->float { return i*i; }); +template +void FillFn(Tensor* tensor, std::function fn) { + auto flat = tensor->flat(); + for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i); +} + +// Expects "x" and "y" are tensors of the same type, same shape, and +// identical values. +template +void ExpectTensorEqual(const Tensor& x, const Tensor& y); + +// Expects "x" and "y" are tensors of the same type, same shape, and +// approximate equal values, each within "abs_err". +template +void ExpectTensorNear(const Tensor& x, const Tensor& y, const T& abs_err); + +// Expects "x" and "y" are tensors of the same type (float or double), +// same shape and element-wise difference between x and y is no more +// than atol + rtol * abs(x). +void ExpectClose(const Tensor& x, const Tensor& y, double atol = 1e-6, + double rtol = 1e-6); + +// Implementation details. 
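+//
+// Example use of the public helpers declared above (editorial sketch, added
+// for reference; 'expected' and 'actual' are hypothetical tensors, not part
+// of the original sources). The internal machinery follows below.
+//
+//   Tensor expected = test::AsTensor<float>({1.f, 2.f, 3.f, 4.f},
+//                                           TensorShape({2, 2}));
+//   Tensor actual(DT_FLOAT, TensorShape({2, 2}));
+//   test::FillFn<float>(&actual, [](int i) -> float { return i + 1.f; });
+//   test::ExpectTensorEqual<float>(expected, actual);
+//   test::ExpectClose(expected, actual, /*atol=*/1e-6, /*rtol=*/1e-6);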
+ +namespace internal { + +template +struct is_floating_point_type { + static const bool value = std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same >::value || + std::is_same >::value; +}; + +template +inline void ExpectEqual(const T& a, const T& b) { + EXPECT_EQ(a, b); +} + +template <> +inline void ExpectEqual(const float& a, const float& b) { + EXPECT_FLOAT_EQ(a, b); +} + +template <> +inline void ExpectEqual(const double& a, const double& b) { + EXPECT_DOUBLE_EQ(a, b); +} + +template <> +inline void ExpectEqual(const complex64& a, const complex64& b) { + EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b; + EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b; +} + +template <> +inline void ExpectEqual(const complex128& a, const complex128& b) { + EXPECT_DOUBLE_EQ(a.real(), b.real()) << a << " vs. " << b; + EXPECT_DOUBLE_EQ(a.imag(), b.imag()) << a << " vs. " << b; +} + +inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) { + ASSERT_EQ(x.dtype(), y.dtype()); + ASSERT_TRUE(x.IsSameSize(y)) + << "x.shape [" << x.shape().DebugString() << "] vs " + << "y.shape [ " << y.shape().DebugString() << "]"; +} + +template ::value> +struct Expector; + +template +struct Expector { + static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } + + static void Equal(const Tensor& x, const Tensor& y) { + ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); + AssertSameTypeDims(x, y); + const auto size = x.NumElements(); + const T* a = x.flat().data(); + const T* b = y.flat().data(); + for (int i = 0; i < size; ++i) { + ExpectEqual(a[i], b[i]); + } + } +}; + +// Partial specialization for float and double. +template +struct Expector { + static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } + + static void Equal(const Tensor& x, const Tensor& y) { + ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); + AssertSameTypeDims(x, y); + const auto size = x.NumElements(); + const T* a = x.flat().data(); + const T* b = y.flat().data(); + for (int i = 0; i < size; ++i) { + ExpectEqual(a[i], b[i]); + } + } + + static void Near(const T& a, const T& b, const double abs_err, int index) { + if (a != b) { // Takes care of inf. + EXPECT_LE(double(Eigen::numext::abs(a - b)), abs_err) + << "a = " << a << " b = " << b << " index = " << index; + } + } + + static void Near(const Tensor& x, const Tensor& y, const double abs_err) { + ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); + AssertSameTypeDims(x, y); + const auto size = x.NumElements(); + const T* a = x.flat().data(); + const T* b = y.flat().data(); + for (int i = 0; i < size; ++i) { + Near(a[i], b[i], abs_err, i); + } + } +}; + +} // namespace internal + +template +void ExpectTensorEqual(const Tensor& x, const Tensor& y) { + internal::Expector::Equal(x, y); +} + +template +void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) { + static_assert(internal::is_floating_point_type::value, + "T is not a floating point types."); + internal::Expector::Near(x, y, abs_err); +} + +} // namespace test +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_ diff --git a/tensor_types.h b/tensor_types.h new file mode 100644 index 0000000000000000000000000000000000000000..921f88dc0ba09e7904333613b728021751d5425c --- /dev/null +++ b/tensor_types.h @@ -0,0 +1,114 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +// Helper to define Tensor types given that the scalar is of type T. +template +struct TTypes { + // Rank- tensor of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> Tensor; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstTensor; + + // Unaligned Rank- tensor of scalar type T. + typedef Eigen::TensorMap > + UnalignedTensor; + typedef Eigen::TensorMap > UnalignedConstTensor; + + typedef Eigen::TensorMap, + Eigen::Aligned> Tensor32Bit; + + // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. + typedef Eigen::TensorMap< + Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>, + Eigen::Aligned> Scalar; + typedef Eigen::TensorMap, + Eigen::RowMajor, IndexType>, + Eigen::Aligned> ConstScalar; + + // Unaligned Scalar tensor of scalar type T. + typedef Eigen::TensorMap, Eigen::RowMajor, IndexType> > UnalignedScalar; + typedef Eigen::TensorMap, + Eigen::RowMajor, IndexType> > + UnalignedConstScalar; + + // Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> Flat; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstFlat; + typedef Eigen::TensorMap, + Eigen::Aligned> Vec; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstVec; + + // Unaligned Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap > + UnalignedFlat; + typedef Eigen::TensorMap > UnalignedConstFlat; + typedef Eigen::TensorMap > + UnalignedVec; + typedef Eigen::TensorMap< + Eigen::Tensor > UnalignedConstVec; + + // Rank-2 tensor (matrix) of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> Matrix; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstMatrix; + + // Unaligned Rank-2 tensor (matrix) of scalar type T. + typedef Eigen::TensorMap > + UnalignedMatrix; + typedef Eigen::TensorMap > UnalignedConstMatrix; +}; + +typedef typename TTypes::Tensor32Bit::Index Index32; + +template +Eigen::DSizes To32BitDims(const DSizes& in) { + Eigen::DSizes out; + for (int i = 0; i < DSizes::count; ++i) { + out[i] = in[i]; + } + return out; +} + +template +typename TTypes::Tensor32Bit +To32Bit(TensorType in) { + typedef typename TTypes::Tensor32Bit RetType; + return RetType(in.data(), To32BitDims(in.dimensions())); +} + +} // namespace tensorflow +#endif // TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_ diff --git a/tensor_util.cc b/tensor_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..8e3ac25512edd283152b72cb4d56a7fed8191428 --- /dev/null +++ b/tensor_util.cc @@ -0,0 +1,172 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/tensor_util.h" + +#include +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace tensor { + +Tensor DeepCopy(const Tensor& other) { + Tensor tmp = Tensor(other.dtype(), other.shape()); + if (DataTypeCanUseMemcpy(other.dtype())) { + if (other.NumElements() > 0) { + StringPiece other_data = other.tensor_data(); + + // We use StringPiece as a convenient map over the tensor buffer, + // but we cast the type to get to the underlying buffer to do the + // copy. + StringPiece tmp_data = tmp.tensor_data(); + memcpy(const_cast(tmp_data.data()), other_data.data(), + other_data.size()); + } + } else if (other.dtype() == DT_STRING) { + tmp.flat() = other.flat(); + } else { + CHECK_EQ(DT_VARIANT, other.dtype()); + tmp.flat() = other.flat(); + } + return tmp; +} + +Status Concat(const gtl::ArraySlice& tensors, Tensor* result) { + if (tensors.empty()) { + return errors::InvalidArgument("Cannot concatenate zero tensors"); + } + int64 total_dim0_size = 0; + for (const Tensor& tensor : tensors) { + if (tensor.dims() == 0) { + return errors::InvalidArgument( + "Cannot concatenate a zero-dimensional tensor"); + } + total_dim0_size += tensor.dim_size(0); + } + TensorShape shape = tensors[0].shape(); + shape.set_dim(0, total_dim0_size); + + const DataType dtype = tensors[0].dtype(); + for (int i = 1; i < tensors.size(); ++i) { + if (tensors[i].dtype() != dtype) { + return errors::InvalidArgument( + "Cannot concatenate tensors that have different data types"); + } + } + *result = Tensor(dtype, shape); + + // We use StringPiece as a convenient map over the tensor buffer, + // but we cast the type to get to the underlying buffer to do the + // copy. 
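+  // (Editorial note: tensor_data() yields a read-only StringPiece view of the
+  // freshly allocated result buffer; the const_cast below is safe because
+  // 'result' is the sole owner of that buffer at this point.)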
+ StringPiece to_data = result->tensor_data(); + + if (DataTypeCanUseMemcpy(dtype)) { + int64 offset = 0; + for (const Tensor& tensor : tensors) { + StringPiece from_data = tensor.tensor_data(); + CHECK_LE(offset + from_data.size(), to_data.size()); + memcpy(const_cast(to_data.data()) + offset, from_data.data(), + from_data.size()); + + offset += from_data.size(); + } + } else { + if (dtype != DT_STRING) { + return errors::Internal("Unexpected data type"); + } + string* to_strings = + reinterpret_cast(const_cast(to_data.data())); + + int64 offset = 0; + for (const Tensor& tensor : tensors) { + auto from_strings = tensor.flat(); + CHECK_LE(offset + tensor.NumElements(), result->NumElements()); + for (int i = 0; i < tensor.NumElements(); ++i) { + to_strings[offset + i] = from_strings(i); + } + + offset += tensor.NumElements(); + } + } + + return Status::OK(); +} + +Status Split(const Tensor& tensor, const gtl::ArraySlice& sizes, + std::vector* result) { + if (tensor.dims() == 0) { + return errors::InvalidArgument("Cannot split a zero-dimensional tensor"); + } + int64 total_size = 0; + for (int64 size : sizes) { + total_size += size; + } + if (total_size != tensor.dim_size(0)) { + return errors::InvalidArgument( + "The values in 'sizes' do not sum to the zeroth-dimension size of " + "'tensor'"); + } + + StringPiece from_data = tensor.tensor_data(); + + if (DataTypeCanUseMemcpy(tensor.dtype())) { + int64 offset = 0; + for (int64 size : sizes) { + TensorShape shape = tensor.shape(); + shape.set_dim(0, size); + result->emplace_back(tensor.dtype(), shape); + Tensor* split = &(*result)[result->size() - 1]; + + // We use StringPiece as a convenient map over the tensor buffer, + // but we cast the type to get to the underlying buffer to do the + // copy. + StringPiece to_data = split->tensor_data(); + CHECK_LE(offset + to_data.size(), from_data.size()); + memcpy(const_cast(to_data.data()), from_data.data() + offset, + to_data.size()); + + offset += to_data.size(); + } + } else { + if (tensor.dtype() != DT_STRING) { + return errors::Internal("Unexpected data type"); + } + auto from_strings = tensor.flat(); + + int64 offset = 0; + for (int64 size : sizes) { + TensorShape shape = tensor.shape(); + shape.set_dim(0, size); + result->emplace_back(tensor.dtype(), shape); + Tensor& split = (*result)[result->size() - 1]; + string* to_strings = reinterpret_cast( + const_cast(split.tensor_data().data())); + + CHECK_LE(offset + split.NumElements(), tensor.NumElements()); + for (int i = 0; i < split.NumElements(); ++i) { + to_strings[i] = from_strings(offset + i); + } + + offset += split.NumElements(); + } + } + + return Status::OK(); +} + +} // namespace tensor +} // namespace tensorflow diff --git a/tensor_util.h b/tensor_util.h new file mode 100644 index 0000000000000000000000000000000000000000..6c218b69e07a0ea9b19792e2da9f2032a272e293 --- /dev/null +++ b/tensor_util.h @@ -0,0 +1,60 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_ + +#include "tensorflow/core/framework/tensor.h" + +#include +namespace tensorflow { +namespace tensor { + +// DeepCopy returns a tensor whose contents are a deep copy of the +// contents of 'other'. This function is intended only for +// convenience, not speed. +// +// REQUIRES: 'other' must point to data stored in CPU memory. +// REQUIRES: 'other' must be a Tensor of a copy-able type if +// 'other' is not appropriately memory-aligned. +Tensor DeepCopy(const Tensor& other); + +// Concatenates 'tensors' into a single tensor, along their 0th dimension. +// +// REQUIRES: All members of 'tensors' must have the same data type parameter. +// REQUIRES: Each member of 'tensors' must have at least one dimension. +// REQUIRES: Each member of 'tensors' must point to data stored in CPU memory. +// REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it +// is not appropriately memory-aligned. +Status Concat(const gtl::ArraySlice& tensors, + Tensor* result) TF_MUST_USE_RESULT; + +// Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th +// dimension. The ith output tensor has 0th-dimension size 'sizes[i]'. +// +// REQUIRES: 'tensor' must have at least one dimension. +// REQUIRES: 'tensor.dim_size(0)' must equal the sum of the elements of 'sizes'. +// REQUIRES: 'tensor' must point to data stored in CPU memory. +// REQUIRES: 'tensor' must be a Tensor of a copy-able type if it is not +// appropriately memory-aligned. +// +// Split() and Concat() are inverse operations. +Status Split(const Tensor& tensor, const gtl::ArraySlice& sizes, + std::vector* result) TF_MUST_USE_RESULT; + +} // namespace tensor +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_ diff --git a/tensor_util_test.cc b/tensor_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..69eb8363b2ceee89a58e7bc8844684dd200fbb27 --- /dev/null +++ b/tensor_util_test.cc @@ -0,0 +1,230 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/tensor_util.h" + +#include +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(TensorUtil, DeepCopy0d) { + Tensor x(DT_FLOAT, TensorShape({})); + x.scalar()() = 10.0; + + // Make y a deep copy of x and then change it. + Tensor y = tensor::DeepCopy(x); + y.scalar()() = 20.0; + + // x doesn't change + EXPECT_EQ(10.0, x.scalar()()); + + // Change x. + x.scalar()() = 30.0; + + // Y doesn't change. + EXPECT_EQ(20.0, y.scalar()()); + + Tensor z = tensor::DeepCopy(y); + + // Change y. 
+ y.scalar()() = 40.0; + + // The final states should all be different. + EXPECT_EQ(20.0, z.scalar()()); + EXPECT_EQ(30.0, x.scalar()()); + EXPECT_EQ(40.0, y.scalar()()); + + // Should have the same shape and type. + EXPECT_EQ(TensorShape({}), x.shape()); + EXPECT_EQ(TensorShape({}), y.shape()); + EXPECT_EQ(TensorShape({}), z.shape()); + + EXPECT_EQ(DT_FLOAT, x.dtype()); + EXPECT_EQ(DT_FLOAT, y.dtype()); + EXPECT_EQ(DT_FLOAT, z.dtype()); +} + +TEST(TensorUtil, DeepCopyZeroElements) { + Tensor x; + Tensor y = tensor::DeepCopy(x); + EXPECT_EQ(TensorShape({0}), y.shape()); + EXPECT_EQ(DT_FLOAT, y.dtype()); + EXPECT_EQ(0, y.NumElements()); +} + +TEST(TensorUtil, DeepCopy) { + Tensor x(DT_FLOAT, TensorShape({1})); + x.flat()(0) = 10.0; + + // Make y a deep copy of x and then change it. + Tensor y = tensor::DeepCopy(x); + y.flat()(0) = 20.0; + + // x doesn't change + EXPECT_EQ(10.0, x.flat()(0)); + + // Change x. + x.flat()(0) = 30.0; + + // Y doesn't change. + EXPECT_EQ(20.0, y.flat()(0)); + + Tensor z = tensor::DeepCopy(y); + + // Change y. + y.flat()(0) = 40.0; + + // The final states should all be different. + EXPECT_EQ(20.0, z.flat()(0)); + EXPECT_EQ(30.0, x.flat()(0)); + EXPECT_EQ(40.0, y.flat()(0)); + + // Should have the same shape and type. + EXPECT_EQ(TensorShape({1}), x.shape()); + EXPECT_EQ(TensorShape({1}), y.shape()); + EXPECT_EQ(TensorShape({1}), z.shape()); + + EXPECT_EQ(DT_FLOAT, x.dtype()); + EXPECT_EQ(DT_FLOAT, y.dtype()); + EXPECT_EQ(DT_FLOAT, z.dtype()); + + // Test string deep copy + Tensor str1(DT_STRING, TensorShape({2})); + str1.flat()(0) = "foo1"; + str1.flat()(1) = "foo2"; + Tensor str2 = tensor::DeepCopy(str1); + str2.flat()(0) = "bar1"; + str2.flat()(1) = "bar2"; + EXPECT_NE(str2.flat()(0), str1.flat()(0)); +} + +TEST(TensorUtil, DeepCopySlice) { + Tensor x(DT_INT32, TensorShape({10})); + x.flat().setConstant(1); + + // Slice 'x' -- y still refers to the same buffer. + Tensor y = x.Slice(2, 6); + + // Do a deep copy of y, which is a slice. + Tensor z = tensor::DeepCopy(y); + + // Set x to be different. + x.flat().setConstant(2); + + EXPECT_EQ(TensorShape({10}), x.shape()); + EXPECT_EQ(TensorShape({4}), y.shape()); + EXPECT_EQ(TensorShape({4}), z.shape()); + EXPECT_EQ(DT_INT32, x.dtype()); + EXPECT_EQ(DT_INT32, y.dtype()); + EXPECT_EQ(DT_INT32, z.dtype()); + + // x and y should now all be '2', but z should be '1'. 
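+  // (Editorial note: y was made with Slice() and therefore aliases x's
+  // buffer, so the write to x is visible through y; z came from DeepCopy()
+  // and owns its own storage, so it still holds the original 1s.)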
+ for (int i = 0; i < 10; ++i) { + EXPECT_EQ(2, x.flat()(i)); + } + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(2, y.unaligned_flat()(i)); + EXPECT_EQ(1, z.flat()(i)); + } +} + +TEST(TensorUtil, Concat) { + std::vector sizes = {1, 4, 5}; + std::vector to_concat; + int64 total_size = 0; + int offset = 0; + for (size_t entry = 0; entry < sizes.size(); ++entry) { + const int64 size = sizes[entry]; + Tensor tensor(DT_INT32, TensorShape({size, 2})); + for (int i = offset; i < offset + size; ++i) { + for (int j = 0; j < 2; ++j) { + tensor.matrix()(i - offset, j) = 2 * i + j; + } + } + to_concat.push_back(tensor); + total_size += size; + offset += size; + } + + Tensor concated; + TF_ASSERT_OK(tensor::Concat(to_concat, &concated)); + ASSERT_EQ(TensorShape({total_size, 2}), concated.shape()); + for (int i = 0; i < total_size; ++i) { + for (int j = 0; j < 2; ++j) { + EXPECT_EQ(2 * i + j, concated.matrix()(i, j)); + } + } +} + +TEST(TensorUtil, Split) { + Tensor to_split(DT_INT64, TensorShape({10, 2})); + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 2; ++j) { + to_split.matrix()(i, j) = 2 * i + j; + } + } + + std::vector sizes = {1, 4, 5}; + std::vector splits; + TF_ASSERT_OK(tensor::Split(to_split, sizes, &splits)); + ASSERT_EQ(sizes.size(), splits.size()); + + int offset = 0; + for (size_t entry = 0; entry < splits.size(); ++entry) { + const int64 size = sizes[entry]; + const Tensor& split = splits[entry]; + + ASSERT_EQ(TensorShape({size, 2}), split.shape()); + for (int i = offset; i < offset + size; ++i) { + for (int j = 0; j < 2; ++j) { + EXPECT_EQ(2 * i + j, split.matrix()(i - offset, j)); + } + } + + offset += size; + } +} + +TEST(TensorUtil, ConcatSplitStrings) { + Tensor x(DT_STRING, TensorShape({4, 3})); + for (int i = 0; i < 4 * 3; ++i) { + x.flat()(i) = strings::StrCat("foo_", i); + } + + std::vector split; + TF_ASSERT_OK(tensor::Split(x, {2, 1, 1}, &split)); + Tensor x_round_tripped; + TF_ASSERT_OK(tensor::Concat(split, &x_round_tripped)); + ASSERT_EQ(x.shape(), x_round_tripped.shape()); + for (int i = 0; i < 4 * 3; ++i) { + EXPECT_EQ(x.flat()(i), x_round_tripped.flat()(i)); + } + + // Ensure that no memory is being shared between 'x' and 'x_round_tripped'. + for (int i = 0; i < 4 * 3; ++i) { + x_round_tripped.flat()(i) = strings::StrCat("bar_", i); + } + for (int i = 0; i < 4 * 3; ++i) { + EXPECT_NE(x.flat()(i), x_round_tripped.flat()(i)); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tracking_allocator.cc b/tracking_allocator.cc new file mode 100644 index 0000000000000000000000000000000000000000..239dfd13ec2e45acb0a65700f2a8882c61fc03b3 --- /dev/null +++ b/tracking_allocator.cc @@ -0,0 +1,203 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/tracking_allocator.h" + +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +TrackingAllocator::TrackingAllocator(Allocator* allocator, bool track_sizes) + : allocator_(allocator), + ref_(1), + allocated_(0), + high_watermark_(0), + total_bytes_(0), + track_sizes_locally_(track_sizes && !allocator_->TracksAllocationSizes()), + next_allocation_id_(0) {} + +void* TrackingAllocator::AllocateRaw( + size_t alignment, size_t num_bytes, + const AllocationAttributes& allocation_attr) { + void* ptr = allocator_->AllocateRaw(alignment, num_bytes, allocation_attr); + // If memory is exhausted AllocateRaw returns nullptr, and we should + // pass this through to the caller + if (nullptr == ptr) { + return ptr; + } + if (allocator_->TracksAllocationSizes()) { + size_t allocated_bytes = allocator_->AllocatedSize(ptr); + { + mutex_lock lock(mu_); + allocated_ += allocated_bytes; + high_watermark_ = std::max(high_watermark_, allocated_); + total_bytes_ += allocated_bytes; + allocations_.emplace_back(allocated_bytes, Env::Default()->NowMicros()); + ++ref_; + } + } else if (track_sizes_locally_) { + // Call the underlying allocator to try to get the allocated size + // whenever possible, even when it might be slow. If this fails, + // use the requested size as an approximation. + size_t allocated_bytes = allocator_->AllocatedSizeSlow(ptr); + allocated_bytes = std::max(num_bytes, allocated_bytes); + mutex_lock lock(mu_); + next_allocation_id_ += 1; + Chunk chunk = {num_bytes, allocated_bytes, next_allocation_id_}; + in_use_.emplace(std::make_pair(ptr, chunk)); + allocated_ += allocated_bytes; + high_watermark_ = std::max(high_watermark_, allocated_); + total_bytes_ += allocated_bytes; + allocations_.emplace_back(allocated_bytes, Env::Default()->NowMicros()); + ++ref_; + } else { + mutex_lock lock(mu_); + total_bytes_ += num_bytes; + allocations_.emplace_back(num_bytes, Env::Default()->NowMicros()); + ++ref_; + } + return ptr; +} + +void TrackingAllocator::DeallocateRaw(void* ptr) { + // freeing a null ptr is a no-op + if (nullptr == ptr) { + return; + } + bool should_delete; + // fetch the following outside the lock in case the call to + // AllocatedSize is slow + bool tracks_allocation_sizes = allocator_->TracksAllocationSizes(); + size_t allocated_bytes = 0; + if (tracks_allocation_sizes) { + allocated_bytes = allocator_->AllocatedSize(ptr); + } else if (track_sizes_locally_) { + mutex_lock lock(mu_); + auto itr = in_use_.find(ptr); + if (itr != in_use_.end()) { + tracks_allocation_sizes = true; + allocated_bytes = (*itr).second.allocated_size; + in_use_.erase(itr); + } + } + Allocator* allocator = allocator_; + { + mutex_lock lock(mu_); + if (tracks_allocation_sizes) { + CHECK_GE(allocated_, allocated_bytes); + allocated_ -= allocated_bytes; + allocations_.emplace_back(-allocated_bytes, Env::Default()->NowMicros()); + } + should_delete = UnRef(); + } + allocator->DeallocateRaw(ptr); + if (should_delete) { + delete this; + } +} + +bool TrackingAllocator::TracksAllocationSizes() { + return track_sizes_locally_ || allocator_->TracksAllocationSizes(); +} + +size_t TrackingAllocator::RequestedSize(void* ptr) { + if (track_sizes_locally_) { + mutex_lock lock(mu_); + auto it = in_use_.find(ptr); + if (it != in_use_.end()) { + return (*it).second.requested_size; + } + return 0; + } else { + return allocator_->RequestedSize(ptr); + 
} +} + +size_t TrackingAllocator::AllocatedSize(void* ptr) { + if (track_sizes_locally_) { + mutex_lock lock(mu_); + auto it = in_use_.find(ptr); + if (it != in_use_.end()) { + return (*it).second.allocated_size; + } + return 0; + } else { + return allocator_->AllocatedSize(ptr); + } +} + +int64 TrackingAllocator::AllocationId(void* ptr) { + if (track_sizes_locally_) { + mutex_lock lock(mu_); + auto it = in_use_.find(ptr); + if (it != in_use_.end()) { + return (*it).second.allocation_id; + } + return 0; + } else { + return allocator_->AllocationId(ptr); + } +} + +void TrackingAllocator::GetStats(AllocatorStats* stats) { + allocator_->GetStats(stats); +} + +std::tuple TrackingAllocator::GetSizes() { + size_t high_watermark; + size_t total_bytes; + size_t still_live_bytes; + { + mutex_lock lock(mu_); + high_watermark = high_watermark_; + total_bytes = total_bytes_; + still_live_bytes = allocated_; + } + return std::make_tuple(total_bytes, high_watermark, still_live_bytes); +} + +gtl::InlinedVector TrackingAllocator::GetRecordsAndUnRef() { + bool should_delete; + gtl::InlinedVector allocations; + { + mutex_lock lock(mu_); + allocations.swap(allocations_); + should_delete = UnRef(); + } + if (should_delete) { + delete this; + } + return allocations; +} + +gtl::InlinedVector TrackingAllocator::GetCurrentRecords() { + gtl::InlinedVector allocations; + { + mutex_lock lock(mu_); + for (const AllocRecord& alloc : allocations_) { + allocations.push_back(alloc); + } + } + return allocations; +} + +bool TrackingAllocator::UnRef() { + CHECK_GE(ref_, 1); + --ref_; + return (ref_ == 0); +} + +} // end namespace tensorflow diff --git a/tracking_allocator.h b/tracking_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..a6c26c89e51f1fec01886672b91f863ee36bedc8 --- /dev/null +++ b/tracking_allocator.h @@ -0,0 +1,133 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_ +#define TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_ + +#include +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// TrackingAllocator is a wrapper for an Allocator. It keeps a running +// count of the number of bytes allocated through the wrapper. It is +// used by the Executor to "charge" allocations to particular Op +// executions. Each Op gets a separate TrackingAllocator wrapper +// around the underlying allocator. +// +// The implementation assumes the invariant that all calls to +// AllocateRaw by an Op (or work items spawned by the Op) will occur +// before the Op's Compute method returns. 
Thus the high watermark is +// established once Compute returns. +// +// DeallocateRaw can be called long after the Op has finished, +// e.g. when an output tensor is deallocated, and the wrapper cannot +// be deleted until the last of these calls has occurred. The +// TrackingAllocator keeps track of outstanding calls using a +// reference count, and deletes itself once the last call has been +// received and the high watermark has been retrieved. +struct AllocRecord { + AllocRecord(int64 a_btyes, int64 a_micros) + : alloc_bytes(a_btyes), alloc_micros(a_micros) {} + AllocRecord() : AllocRecord(0, 0) {} + + int64 alloc_bytes; + int64 alloc_micros; +}; + +class TrackingAllocator : public Allocator { + public: + explicit TrackingAllocator(Allocator* allocator, bool track_ids); + string Name() override { return allocator_->Name(); } + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + return AllocateRaw(alignment, num_bytes, AllocationAttributes()); + } + void* AllocateRaw(size_t alignment, size_t num_bytes, + const AllocationAttributes& allocation_attr) override; + void DeallocateRaw(void* ptr) override; + bool TracksAllocationSizes() override; + size_t RequestedSize(void* ptr) override; + size_t AllocatedSize(void* ptr) override; + int64 AllocationId(void* ptr) override; + void GetStats(AllocatorStats* stats) override; + + // If the underlying allocator tracks allocation sizes, this returns + // a tuple where the first value is the total number of bytes + // allocated through this wrapper, the second value is the high + // watermark of bytes allocated through this wrapper and the third value is + // the allocated bytes through this wrapper that are still alive. If the + // underlying allocator does not track allocation sizes the first + // value is the total number of bytes requested through this wrapper + // and the second and the third are 0. + // + std::tuple GetSizes(); + // After GetRecordsAndUnRef is called, the only further calls allowed + // on this wrapper are calls to DeallocateRaw with pointers that + // were allocated by this wrapper and have not yet been + // deallocated. After this call completes and all allocated pointers + // have been deallocated the wrapper will delete itself. + gtl::InlinedVector GetRecordsAndUnRef(); + // Returns a copy of allocation records collected so far. + gtl::InlinedVector GetCurrentRecords(); + + protected: + ~TrackingAllocator() override {} + + private: + bool UnRef() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + Allocator* allocator_; // not owned. + mutex mu_; + // the number of calls to AllocateRaw that have not yet been matched + // by a corresponding call to DeAllocateRaw, plus 1 if the Executor + // has not yet read out the high watermark. + int ref_ GUARDED_BY(mu_); + // the current number of outstanding bytes that have been allocated + // by this wrapper, or 0 if the underlying allocator does not track + // allocation sizes. + size_t allocated_ GUARDED_BY(mu_); + // the maximum number of outstanding bytes that have been allocated + // by this wrapper, or 0 if the underlying allocator does not track + // allocation sizes. + size_t high_watermark_ GUARDED_BY(mu_); + // the total number of bytes that have been allocated by this + // wrapper if the underlying allocator tracks allocation sizes, + // otherwise the total number of bytes that have been requested by + // this allocator. 
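+  // (Editorial note: unlike allocated_, this counter is cumulative and is
+  // never decremented when memory is freed.)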
+ size_t total_bytes_ GUARDED_BY(mu_); + + gtl::InlinedVector allocations_ GUARDED_BY(mu_); + + // Track allocations locally if requested in the constructor and the + // underlying allocator doesn't already do it for us. + const bool track_sizes_locally_; + struct Chunk { + size_t requested_size; + size_t allocated_size; + int64 allocation_id; + }; + std::unordered_map in_use_ GUARDED_BY(mu_); + int64 next_allocation_id_ GUARDED_BY(mu_); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_ diff --git a/tracking_allocator_test.cc b/tracking_allocator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e32a907f20f34183abbbc57b93c38197710fa51 --- /dev/null +++ b/tracking_allocator_test.cc @@ -0,0 +1,179 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/tracking_allocator.h" + +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class TestableSizeTrackingAllocator : public Allocator { + public: + string Name() override { return "test"; } + void* AllocateRaw(size_t /*alignment*/, size_t num_bytes) override { + void* ptr = port::Malloc(num_bytes); + size_map_[ptr] = num_bytes; + return ptr; + } + void DeallocateRaw(void* ptr) override { + const auto& iter = size_map_.find(ptr); + EXPECT_NE(size_map_.end(), iter); + size_map_.erase(iter); + port::Free(ptr); + } + bool TracksAllocationSizes() override { return true; } + size_t RequestedSize(void* ptr) override { + const auto& iter = size_map_.find(ptr); + EXPECT_NE(size_map_.end(), iter); + return iter->second; + } + void GetStats(AllocatorStats* stats) override { stats->Clear(); } + + private: + std::unordered_map size_map_; +}; + +class NoMemoryAllocator : public Allocator { + public: + string Name() override { return "test"; } + void* AllocateRaw(size_t /*alignment*/, size_t num_bytes) override { + return nullptr; + } + void DeallocateRaw(void* ptr) override {} + bool TracksAllocationSizes() override { return true; } + void GetStats(AllocatorStats* stats) override { stats->Clear(); } +}; + +TEST(TrackingAllocatorTest, SimpleNoTracking) { + Allocator* a = cpu_allocator(); + + EXPECT_FALSE(a->TracksAllocationSizes()); + + // Don't enable the tracking inside the tracking allocator. 
Since + // the cpu_allocator doesn't track allocations itself the tracking + // will be partial + TrackingAllocator* ta = new TrackingAllocator(a, false); + + void* p1 = ta->AllocateRaw(4, 4); + ta->DeallocateRaw(p1); + void* p2 = ta->AllocateRaw(4, 12); + + std::tuple sizes = ta->GetSizes(); + + EXPECT_EQ(16, std::get<0>(sizes)); + EXPECT_EQ(0, std::get<1>(sizes)); + EXPECT_EQ(0, std::get<2>(sizes)); + + ta->DeallocateRaw(p2); + auto records = ta->GetRecordsAndUnRef(); + EXPECT_EQ(4, records[0].alloc_bytes); + EXPECT_EQ(12, records[1].alloc_bytes); + + // This time enable the tracking inside the tracking allocator + ta = new TrackingAllocator(a, true); + p1 = ta->AllocateRaw(4, 4); + EXPECT_EQ(4, ta->RequestedSize(p1)); + EXPECT_LE(4, ta->AllocatedSize(p1)); + EXPECT_EQ(1, ta->AllocationId(p1)); + + ta->DeallocateRaw(p1); + p2 = ta->AllocateRaw(4, 12); + EXPECT_EQ(12, ta->RequestedSize(p2)); + EXPECT_LE(12, ta->AllocatedSize(p2)); + EXPECT_EQ(2, ta->AllocationId(p2)); + + sizes = ta->GetSizes(); + + EXPECT_LE(16, std::get<0>(sizes)); + EXPECT_LE(12, std::get<1>(sizes)); + EXPECT_LE(12, std::get<2>(sizes)); + + ta->DeallocateRaw(p2); + records = ta->GetRecordsAndUnRef(); + EXPECT_LE(4, records[0].alloc_bytes); + EXPECT_GE(-4, records[1].alloc_bytes); + EXPECT_LE(12, records[2].alloc_bytes); + EXPECT_GE(-12, records[3].alloc_bytes); +} + +TEST(TrackingAllocatorTest, SimpleTracking) { + TestableSizeTrackingAllocator a = TestableSizeTrackingAllocator(); + + EXPECT_TRUE(a.TracksAllocationSizes()); + + TrackingAllocator* ta = new TrackingAllocator(&a, false); + + void* p1 = ta->AllocateRaw(4, 12); + ta->DeallocateRaw(p1); + void* p2 = ta->AllocateRaw(4, 4); + + std::tuple sizes = ta->GetSizes(); + + EXPECT_EQ(16, std::get<0>(sizes)); + EXPECT_EQ(12, std::get<1>(sizes)); + EXPECT_EQ(4, std::get<2>(sizes)); + + ta->DeallocateRaw(p2); + + auto records = ta->GetRecordsAndUnRef(); + EXPECT_EQ(12, records[0].alloc_bytes); + EXPECT_EQ(-12, records[1].alloc_bytes); + EXPECT_EQ(4, records[2].alloc_bytes); + EXPECT_EQ(-4, records[3].alloc_bytes); +} + +TEST(TrackingAllocatorTest, OutOfMemory) { + NoMemoryAllocator a; + + EXPECT_TRUE(a.TracksAllocationSizes()); + + TrackingAllocator* ta = new TrackingAllocator(&a, false); + + void* p1 = ta->AllocateRaw(4, 12); + EXPECT_EQ(nullptr, p1); + + std::tuple sizes = ta->GetSizes(); + + EXPECT_EQ(0, std::get<0>(sizes)); + EXPECT_EQ(0, std::get<1>(sizes)); + EXPECT_EQ(0, std::get<2>(sizes)); + + EXPECT_EQ(0, ta->GetRecordsAndUnRef().size()); +} + +TEST(TrackingAllocatorTest, FreeNullPtr) { + NoMemoryAllocator a; + + EXPECT_TRUE(a.TracksAllocationSizes()); + + TrackingAllocator* ta = new TrackingAllocator(&a, false); + + ta->DeallocateRaw(nullptr); + + std::tuple sizes = ta->GetSizes(); + + EXPECT_EQ(0, std::get<0>(sizes)); + EXPECT_EQ(0, std::get<1>(sizes)); + EXPECT_EQ(0, std::get<2>(sizes)); + + EXPECT_EQ(0, ta->GetRecordsAndUnRef().size()); +} + +} // namespace tensorflow diff --git a/type_index.h b/type_index.h new file mode 100644 index 0000000000000000000000000000000000000000..b978d90fa8001339a3a7ab27f3a428a350f65d46 --- /dev/null +++ b/type_index.h @@ -0,0 +1,87 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_TYPE_INDEX_H_ +#define TENSORFLOW_FRAMEWORK_TYPE_INDEX_H_ + +#include +#if defined(__GXX_RTTI) || defined(_CPPRTTI) +#include +#include +#endif // __GXX_RTTI + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// On some platforms, we would like to avoid using RTTI in order to have smaller +// binary sizes. The following #ifdef section provides a non-RTTI +// replacement for std::type_index (with a minimal set of functions needed by +// the TensorFlow framework, and more can be added if necessary). +#if !defined(__GXX_RTTI) && !defined(_CPPRTTI) + +// A thin TypeIndex class that mimics std::type_index but does not use RTTI. As +// a result, it does not provide the actual name of the type, and only returns a +// pre-baked string specifying that RTTI is disabled. +// The hash code provided in this class is unique for each class. However, it is +// generated at runtime so this hash code should not be serialized - the value +// for the same type can change from run to run. +class TypeIndex { + public: + TypeIndex(const TypeIndex& src) : hash_(src.hash_) {} + TypeIndex& operator=(const TypeIndex& src) { + hash_ = src.hash_; + return *this; + } + bool operator==(const TypeIndex& rhs) const { return (hash_ == rhs.hash_); } + bool operator!=(const TypeIndex& rhs) const { return (hash_ != rhs.hash_); } + ~TypeIndex() {} + + const char* name() const { return "[RTTI disabled for Android]"; } + uint64 hash_code() const { return hash_; } + + // Returns a TypeIndex object that corresponds to a typename. + template + static TypeIndex Make() { + static bool hash_bit[1]; + return TypeIndex(static_cast(reinterpret_cast(hash_bit))); + } + + private: + // We hide the constructor of the TypeIndex class. Use the templated + // Make() function to create a TypeIndex object. + TypeIndex(const uint64 hash) : hash_(hash) {} + uint64 hash_; +}; + +template +inline TypeIndex MakeTypeIndex() { + return TypeIndex::Make(); +} + +#else // __GXX_RTTI + +// In the presence of RTTI, we will simply delegate to std::type_index for +// runtime type inference. +typedef std::type_index TypeIndex; +template +inline TypeIndex MakeTypeIndex() { + return TypeIndex(typeid(T)); +} + +#endif // __GXX_RTTI +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TYPE_INDEX_H_ diff --git a/type_traits.h b/type_traits.h new file mode 100644 index 0000000000000000000000000000000000000000..e8351e494f91c3a428be9ff0fd1a2d3286b125a3 --- /dev/null +++ b/type_traits.h @@ -0,0 +1,109 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_ +#define TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_ + +#include +#include + +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Functions to define quantization attribute of types. +struct true_type { + static const bool value = true; +}; +struct false_type { + static const bool value = false; +}; + +// Default is_quantized is false. +template +struct is_quantized : false_type {}; + +// Specialize the quantized types. +template <> +struct is_quantized : true_type {}; +template <> +struct is_quantized : true_type {}; +template <> +struct is_quantized : true_type {}; +template <> +struct is_quantized : true_type {}; +template <> +struct is_quantized : true_type {}; + +// Default is_complex is false. +template +struct is_complex : false_type {}; + +// Specialize std::complex and std::complex types. +template <> +struct is_complex> : true_type {}; +template <> +struct is_complex> : true_type {}; + +// is_simple_type::value if T[] can be safely constructed and destructed +// without running T() and ~T(). We do not use std::is_trivial +// directly because std::complex and std::complex are +// not trivial, but their arrays can be constructed and destructed +// without running their default ctors and dtors. +template +struct is_simple_type { + static constexpr bool value = + std::is_trivial::value || std::is_same::value || + std::is_same::value || std::is_same::value || + is_quantized::value || std::is_same::value; +}; + +} // namespace tensorflow + +// Define numeric limits for our quantized as subclasses of the +// standard types. +namespace std { +template <> +class numeric_limits + : public numeric_limits {}; +template <> +class numeric_limits + : public numeric_limits {}; +template <> +class numeric_limits + : public numeric_limits {}; +template <> +class numeric_limits + : public numeric_limits {}; +template <> +class numeric_limits + : public numeric_limits {}; + +// Specialize is_signed for quantized types. +template <> +struct is_signed : public is_signed {}; +template <> +struct is_signed : public is_signed {}; +template <> +struct is_signed : public is_signed {}; +template <> +struct is_signed : public is_signed {}; +template <> +struct is_signed : public is_signed {}; + +} // namespace std + +#endif // TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_ diff --git a/types.cc b/types.cc new file mode 100644 index 0000000000000000000000000000000000000000..58354d6f4edea1f29ba033f2579324d400a532ab --- /dev/null +++ b/types.cc @@ -0,0 +1,408 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/register_types.h" + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +bool DeviceType::operator<(const DeviceType& other) const { + return type_ < other.type_; +} + +bool DeviceType::operator==(const DeviceType& other) const { + return type_ == other.type_; +} + +std::ostream& operator<<(std::ostream& os, const DeviceType& d) { + os << d.type(); + return os; +} + +const char* const DEVICE_CPU = "CPU"; +const char* const DEVICE_GPU = "GPU"; +const char* const DEVICE_SYCL = "SYCL"; + +const std::string DeviceName::value = DEVICE_CPU; +#if GOOGLE_CUDA +const std::string DeviceName::value = DEVICE_GPU; +#endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +const std::string DeviceName::value = DEVICE_SYCL; +#endif // TENSORFLOW_USE_SYCL + +string DataTypeString(DataType dtype) { + if (IsRefType(dtype)) { + DataType non_ref = static_cast(dtype - kDataTypeRefOffset); + return strings::StrCat(DataTypeString(non_ref), "_ref"); + } + switch (dtype) { + case DT_INVALID: + return "INVALID"; + case DT_FLOAT: + return "float"; + case DT_DOUBLE: + return "double"; + case DT_INT32: + return "int32"; + case DT_UINT32: + return "uint32"; + case DT_UINT8: + return "uint8"; + case DT_UINT16: + return "uint16"; + case DT_INT16: + return "int16"; + case DT_INT8: + return "int8"; + case DT_STRING: + return "string"; + case DT_COMPLEX64: + return "complex64"; + case DT_COMPLEX128: + return "complex128"; + case DT_INT64: + return "int64"; + case DT_UINT64: + return "uint64"; + case DT_BOOL: + return "bool"; + case DT_QINT8: + return "qint8"; + case DT_QUINT8: + return "quint8"; + case DT_QUINT16: + return "quint16"; + case DT_QINT16: + return "qint16"; + case DT_QINT32: + return "qint32"; + case DT_BFLOAT16: + return "bfloat16"; + case DT_HALF: + return "half"; + case DT_RESOURCE: + return "resource"; + case DT_VARIANT: + return "variant"; + default: + LOG(ERROR) << "Unrecognized DataType enum value " << dtype; + return strings::StrCat("unknown dtype enum (", dtype, ")"); + } +} + +bool DataTypeFromString(StringPiece sp, DataType* dt) { + if (sp.ends_with("_ref")) { + sp.remove_suffix(4); + DataType non_ref; + if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) { + *dt = static_cast(non_ref + kDataTypeRefOffset); + return true; + } else { + return false; + } + } + + if (sp == "float" || sp == "float32") { + *dt = DT_FLOAT; + return true; + } else if (sp == "double" || sp == "float64") { + *dt = DT_DOUBLE; + return true; + } else if (sp == "int32") { + *dt = DT_INT32; + return true; + } else if (sp == "uint32") { + *dt = DT_UINT32; + return true; + } else if (sp == "uint8") { + *dt = DT_UINT8; + return true; + } else if (sp == "uint16") { + *dt = DT_UINT16; + return true; + } else if (sp == "int16") { + *dt = DT_INT16; + return true; + } else if (sp == "int8") { + *dt = DT_INT8; + return true; + } else if (sp == "string") { + *dt = DT_STRING; + return true; + } else if (sp == "complex64") { + *dt = DT_COMPLEX64; + return true; + } else if (sp == "complex128") { + *dt = DT_COMPLEX128; + return true; + } else if (sp == "int64") { + *dt = DT_INT64; + return true; + } else if (sp == "uint64") { + *dt = DT_UINT64; + return true; + } else if (sp == "bool") { + *dt = DT_BOOL; + return true; + } else if (sp == "qint8") { + *dt = 
DT_QINT8; + return true; + } else if (sp == "quint8") { + *dt = DT_QUINT8; + return true; + } else if (sp == "qint16") { + *dt = DT_QINT16; + return true; + } else if (sp == "quint16") { + *dt = DT_QUINT16; + return true; + } else if (sp == "qint32") { + *dt = DT_QINT32; + return true; + } else if (sp == "bfloat16") { + *dt = DT_BFLOAT16; + return true; + } else if (sp == "half" || sp == "float16") { + *dt = DT_HALF; + return true; + } else if (sp == "resource") { + *dt = DT_RESOURCE; + return true; + } else if (sp == "variant") { + *dt = DT_VARIANT; + return true; + } + return false; +} + +string DeviceTypeString(const DeviceType& device_type) { + return device_type.type(); +} + +string DataTypeSliceString(const DataTypeSlice types) { + string out; + for (auto it = types.begin(); it != types.end(); ++it) { + strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "), + DataTypeString(*it)); + } + return out; +} + +DataTypeVector AllTypes() { + return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, + DT_UINT16, DT_INT8, DT_STRING, DT_COMPLEX64, DT_COMPLEX128, + DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT16, + DT_QUINT16, DT_QINT32, DT_HALF, DT_RESOURCE, DT_VARIANT, + DT_UINT32, DT_UINT64, DT_BFLOAT16}; +} + +#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) + +DataTypeVector RealNumberTypes() { + return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, + DT_INT8, DT_UINT16, DT_HALF, DT_UINT32, DT_UINT64, DT_BFLOAT16}; +} + +DataTypeVector QuantizedTypes() { + return {DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32}; +} + +DataTypeVector RealAndQuantizedTypes() { + return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, + DT_UINT16, DT_UINT16, DT_INT8, DT_QINT8, DT_QUINT8, + DT_QINT16, DT_QUINT16, DT_QINT32, DT_HALF, DT_BFLOAT16}; +} + +DataTypeVector NumberTypes() { + return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, + DT_UINT16, DT_INT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, + DT_QINT8, DT_QUINT8, DT_QINT32, DT_HALF, DT_UINT32, + DT_UINT64, DT_BFLOAT16}; +} + +#elif defined(__ANDROID_TYPES_FULL__) + +DataTypeVector RealNumberTypes() { + return {DT_FLOAT, DT_INT32, DT_INT64, DT_HALF}; +} + +DataTypeVector NumberTypes() { + return {DT_FLOAT, DT_INT32, DT_INT64, DT_QINT8, + DT_QUINT8, DT_QINT32, DT_HALF}; +} + +DataTypeVector QuantizedTypes() { + return {DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32}; +} + +DataTypeVector RealAndQuantizedTypes() { + return {DT_FLOAT, DT_INT32, DT_INT64, DT_QINT8, DT_QUINT8, + DT_QINT16, DT_QUINT16, DT_QINT32, DT_HALF}; +} + +#else // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__) + +DataTypeVector RealNumberTypes() { return {DT_FLOAT, DT_INT32}; } + +DataTypeVector NumberTypes() { + return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, DT_QINT32}; +} + +DataTypeVector QuantizedTypes() { + return {DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32}; +} + +DataTypeVector RealAndQuantizedTypes() { + return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, + DT_QINT16, DT_QUINT16, DT_QINT32}; +} + +#endif // defined(IS_MOBILE_PLATFORM) + +// TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying +// is_simple in tensor.cc (and possible choose a more general name?) 
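+// (Editorial note: the predicate below returns true only for the fixed-width
+// POD-like types; DT_STRING, DT_RESOURCE and DT_VARIANT must be copied
+// element by element, which is why DeepCopy and Concat/Split in
+// tensor_util.cc handle strings and variants separately.)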
+bool DataTypeCanUseMemcpy(DataType dt) { + switch (dt) { + case DT_FLOAT: + case DT_DOUBLE: + case DT_INT32: + case DT_UINT32: + case DT_UINT8: + case DT_UINT16: + case DT_INT16: + case DT_INT8: + case DT_COMPLEX64: + case DT_COMPLEX128: + case DT_INT64: + case DT_UINT64: + case DT_BOOL: + case DT_QINT8: + case DT_QUINT8: + case DT_QINT16: + case DT_QUINT16: + case DT_QINT32: + case DT_BFLOAT16: + case DT_HALF: + return true; + default: + return false; + } +} + +bool DataTypeAlwaysOnHost(DataType dt) { + // Includes DT_STRING and DT_RESOURCE. + switch (dt) { + case DT_STRING: + case DT_STRING_REF: + case DT_RESOURCE: + return true; + default: + return false; + } +} + +bool DataTypeIsFloating(DataType dt) { + switch (dt) { + case DT_HALF: + case DT_BFLOAT16: + case DT_FLOAT: + case DT_DOUBLE: + return true; + default: + return false; + } +} + +bool DataTypeIsComplex(DataType dt) { + switch (dt) { + case DT_COMPLEX64: + case DT_COMPLEX128: + return true; + default: + return false; + } +} + +bool DataTypeIsQuantized(DataType dt) { + switch (dt) { + case DT_QINT8: + case DT_QUINT8: + case DT_QINT16: + case DT_QUINT16: + case DT_QINT32: + return true; + default: + return false; + } +} + +bool DataTypeIsInteger(DataType dt) { + switch (dt) { + case DT_INT8: + case DT_UINT8: + case DT_INT16: + case DT_UINT16: + case DT_INT32: + case DT_UINT32: + case DT_INT64: + case DT_UINT64: + return true; + default: + return false; + } +} + +bool DataTypeIsUnsigned(DataType dt) { + switch (dt) { + case DT_UINT8: + case DT_UINT16: + case DT_UINT32: + case DT_UINT64: + return true; + default: + return false; + } +} + +int DataTypeSize(DataType dt) { +#define CASE(T) \ + case DataTypeToEnum::value: \ + return sizeof(T); + switch (dt) { + TF_CALL_POD_TYPES(CASE); + TF_CALL_QUANTIZED_TYPES(CASE); + // TF_CALL_QUANTIZED_TYPES() macro does no cover quint16 and qint16, since + // they are not supported widely, but are explicitly listed here for + // bitcast. + TF_CALL_qint16(CASE); + TF_CALL_quint16(CASE); + + // uint32 and uint64 aren't included in TF_CALL_POD_TYPES because we + // don't want to define kernels for them at this stage to avoid binary + // bloat. + TF_CALL_uint32(CASE); + TF_CALL_uint64(CASE); + default: + return 0; + } +#undef CASE +} + +} // namespace tensorflow diff --git a/types.h b/types.h new file mode 100644 index 0000000000000000000000000000000000000000..27005c0e93267ff4f91d470a011be6d673fe8cc2 --- /dev/null +++ b/types.h @@ -0,0 +1,249 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_TYPES_H_ + +#include +#include +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +// Disable clang-format to prevent 'FixedPoint' header from being included +// before 'Tensor' header on which it depends. 
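An illustrative check (not part of the original sources) of how the DataType predicates and DataTypeSize from types.cc behave; it assumes nothing beyond the functions shown in that file:

// Sketch only: expected behavior of the helpers implemented in types.cc.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void DataTypePredicateDemo() {
  CHECK(DataTypeCanUseMemcpy(DT_FLOAT));     // flat POD layout
  CHECK(!DataTypeCanUseMemcpy(DT_STRING));   // strings need deep copies
  CHECK(DataTypeAlwaysOnHost(DT_RESOURCE));  // handles stay in host memory
  CHECK(DataTypeIsQuantized(DT_QINT8));
  CHECK(DataTypeIsUnsigned(DT_UINT16));

  CHECK_EQ(4, DataTypeSize(DT_FLOAT));   // sizeof(float)
  CHECK_EQ(0, DataTypeSize(DT_STRING));  // 0 signals "no fixed POD size"
}

}  // namespace tensorflow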
+// clang-format off +#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" +// clang-format on +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// MemoryType is used to describe whether input or output Tensors of +// an OpKernel should reside in "Host memory" (e.g., CPU memory) or +// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU +// devices). +enum MemoryType { + DEVICE_MEMORY = 0, + HOST_MEMORY = 1, +}; + +// A DeviceType is just a string, but we wrap it up in a class to give +// some type checking as we're passing these around +class DeviceType { + public: + DeviceType(const char* type) // NOLINT(runtime/explicit) + : type_(type) {} + + explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {} + + const char* type() const { return type_.c_str(); } + const string& type_string() const { return type_; } + + bool operator<(const DeviceType& other) const; + bool operator==(const DeviceType& other) const; + bool operator!=(const DeviceType& other) const { return !(*this == other); } + + private: + string type_; +}; +std::ostream& operator<<(std::ostream& os, const DeviceType& d); + +// Convenient constants that can be passed to a DeviceType constructor +TF_EXPORT extern const char* const DEVICE_CPU; // "CPU" +TF_EXPORT extern const char* const DEVICE_GPU; // "GPU" +TF_EXPORT extern const char* const DEVICE_SYCL; // "SYCL" + +template +struct DeviceName {}; + +template <> +struct DeviceName { + static const std::string value; +}; + +#if GOOGLE_CUDA +template <> +struct DeviceName { + static const std::string value; +}; +#endif // GOOGLE_CUDA + +#ifdef TENSORFLOW_USE_SYCL +template <> +struct DeviceName { + static const std::string value; +}; +#endif // TENSORFLOW_USE_SYCL + +typedef gtl::InlinedVector MemoryTypeVector; +typedef gtl::ArraySlice MemoryTypeSlice; + +typedef gtl::InlinedVector DataTypeVector; +typedef gtl::ArraySlice DataTypeSlice; + +typedef gtl::InlinedVector DeviceTypeVector; + +// Convert the enums to strings for errors: +string DataTypeString(DataType dtype); +string DeviceTypeString(const DeviceType& device_type); +string DataTypeSliceString(const DataTypeSlice dtypes); +inline string DataTypeVectorString(const DataTypeVector& dtypes) { + return DataTypeSliceString(dtypes); +} + +// If "sp" names a valid type, store it in "*dt" and return true. Otherwise, +// return false. +bool DataTypeFromString(StringPiece sp, DataType* dt); + +// DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc. +enum { kDataTypeRefOffset = 100 }; +inline bool IsRefType(DataType dtype) { + return dtype > static_cast(kDataTypeRefOffset); +} +inline DataType MakeRefType(DataType dtype) { + DCHECK(!IsRefType(dtype)); + return static_cast(dtype + kDataTypeRefOffset); +} +inline DataType RemoveRefType(DataType dtype) { + DCHECK(IsRefType(dtype)); + return static_cast(dtype - kDataTypeRefOffset); +} +inline DataType BaseType(DataType dtype) { + return IsRefType(dtype) ? RemoveRefType(dtype) : dtype; +} + +// Returns true if the actual type is the same as or ref of the expected type. 
+inline bool TypesCompatible(DataType expected, DataType actual) { + return expected == actual || expected == BaseType(actual); +} + +// Does not include _ref types. +DataTypeVector AllTypes(); + +// Return the list of all numeric types. +// NOTE: On Android, we only include the float and int32 types for now. +DataTypeVector RealNumberTypes(); // Types that support '<' and '>'. +DataTypeVector NumberTypes(); // Includes complex and quantized types. + +DataTypeVector QuantizedTypes(); +DataTypeVector RealAndQuantizedTypes(); // Types that support '<' and + // '>', including quantized + // types + +// Validates type T for whether it is a supported DataType. +template +struct IsValidDataType; + +// DataTypeToEnum::v() and DataTypeToEnum::value are the DataType +// constants for T, e.g. DataTypeToEnum::v() is DT_FLOAT. +template +struct DataTypeToEnum { + static_assert(IsValidDataType::value, "Specified Data Type not supported"); +}; // Specializations below + +// EnumToDataType::Type is the type for DataType constant VALUE, e.g. +// EnumToDataType::Type is float. +template +struct EnumToDataType {}; // Specializations below + +// Template specialization for both DataTypeToEnum and EnumToDataType. +#define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \ + template <> \ + struct DataTypeToEnum { \ + static DataType v() { return ENUM; } \ + static DataType ref() { return MakeRefType(ENUM); } \ + static constexpr DataType value = ENUM; \ + }; \ + template <> \ + struct IsValidDataType { \ + static constexpr bool value = true; \ + }; \ + template <> \ + struct EnumToDataType { \ + typedef TYPE Type; \ + } + +MATCH_TYPE_AND_ENUM(float, DT_FLOAT); +MATCH_TYPE_AND_ENUM(double, DT_DOUBLE); +MATCH_TYPE_AND_ENUM(int32, DT_INT32); +MATCH_TYPE_AND_ENUM(uint32, DT_UINT32); +MATCH_TYPE_AND_ENUM(uint16, DT_UINT16); +MATCH_TYPE_AND_ENUM(uint8, DT_UINT8); +MATCH_TYPE_AND_ENUM(int16, DT_INT16); +MATCH_TYPE_AND_ENUM(int8, DT_INT8); +MATCH_TYPE_AND_ENUM(string, DT_STRING); +MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64); +MATCH_TYPE_AND_ENUM(complex128, DT_COMPLEX128); +MATCH_TYPE_AND_ENUM(int64, DT_INT64); +MATCH_TYPE_AND_ENUM(uint64, DT_UINT64); +MATCH_TYPE_AND_ENUM(bool, DT_BOOL); +MATCH_TYPE_AND_ENUM(qint8, DT_QINT8); +MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8); +MATCH_TYPE_AND_ENUM(qint16, DT_QINT16); +MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16); +MATCH_TYPE_AND_ENUM(qint32, DT_QINT32); +MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16); +MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF); +MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE); +MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT); + +#undef MATCH_TYPE_AND_ENUM + +// All types not specialized are marked invalid. +template +struct IsValidDataType { + static constexpr bool value = false; +}; + +// Extra validity checking; not part of public API. +static_assert(IsValidDataType::value, "Incorrect impl for int64"); +static_assert(IsValidDataType::value, "Incorrect impl for int32"); + +bool DataTypeCanUseMemcpy(DataType dt); + +// Returns true iff 'dt' is a real, non-quantized floating point type. +bool DataTypeIsFloating(DataType dt); + +// Returns true iff 'dt' is a complex type. +bool DataTypeIsComplex(DataType dt); + +bool DataTypeIsQuantized(DataType dt); + +// Is the dtype nonquantized integral? +bool DataTypeIsInteger(DataType dt); + +// Is the dtype an unsigned integral type? +bool DataTypeIsUnsigned(DataType dt); + +// Returns a 0 on failure +int DataTypeSize(DataType dt); + +// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE. 
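A hypothetical compile-time sketch (not in the original header) of how the DataTypeToEnum, EnumToDataType, and IsValidDataType traits above fit together:

// Sketch only: the trait specializations map C++ types to DataType enum
// values and back at compile time.
#include <type_traits>

#include "tensorflow/core/framework/types.h"

namespace tensorflow {

static_assert(DataTypeToEnum<float>::value == DT_FLOAT,
              "float maps to DT_FLOAT");
static_assert(std::is_same<EnumToDataType<DT_INT32>::Type, int32>::value,
              "DT_INT32 maps back to int32");
static_assert(!IsValidDataType<long double>::value,
              "types without a MATCH_TYPE_AND_ENUM specialization are rejected");

// DataTypeToEnum<T>::v() and DataTypeToEnum<T>::ref() are the runtime
// counterparts, e.g. DataTypeToEnum<qint8>::v() == DT_QINT8.

}  // namespace tensorflow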
+// For DT_RESOURCE, the handle always sits on host (even if the underlying +// object has device-allocated resources). +bool DataTypeAlwaysOnHost(DataType dt); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TYPES_H_ diff --git a/types.proto b/types.proto new file mode 100644 index 0000000000000000000000000000000000000000..e003fd00106fbaeaf4322ed8d599049ba91e3e7d --- /dev/null +++ b/types.proto @@ -0,0 +1,74 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). + DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/c_api.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, +// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, +// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/types_test.cc b/types_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5ddc9865633623561760bbcb06d1edf4eecec7a6 --- /dev/null +++ b/types_test.cc @@ -0,0 +1,151 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/types.h" + +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(TypesTest, DeviceTypeName) { + EXPECT_EQ("CPU", DeviceTypeString(DeviceType(DEVICE_CPU))); + EXPECT_EQ("GPU", DeviceTypeString(DeviceType(DEVICE_GPU))); + EXPECT_EQ("SYCL", DeviceTypeString(DeviceType(DEVICE_SYCL))); +} + +TEST(TypesTest, kDataTypeRefOffset) { + // Basic sanity check + EXPECT_EQ(DT_FLOAT + kDataTypeRefOffset, DT_FLOAT_REF); + + // Use the meta-data provided by proto2 to iterate through the basic + // types and validate that adding kDataTypeRefOffset gives the + // corresponding reference type. + const auto* enum_descriptor = DataType_descriptor(); + int e = DataType_MIN; + if (e == DT_INVALID) ++e; + int e_ref = e + kDataTypeRefOffset; + EXPECT_FALSE(DataType_IsValid(e_ref - 1)) + << "Reference enum " + << enum_descriptor->FindValueByNumber(e_ref - 1)->name() + << " without corresponding base enum with value " << e - 1; + for (; + DataType_IsValid(e) && DataType_IsValid(e_ref) && e_ref <= DataType_MAX; + ++e, ++e_ref) { + string enum_name = enum_descriptor->FindValueByNumber(e)->name(); + string enum_ref_name = enum_descriptor->FindValueByNumber(e_ref)->name(); + EXPECT_EQ(enum_name + "_REF", enum_ref_name) + << enum_name << "_REF should have value " << e_ref << " not " + << enum_ref_name; + // Validate DataTypeString() as well. + DataType dt_e = static_cast(e); + DataType dt_e_ref = static_cast(e_ref); + EXPECT_EQ(DataTypeString(dt_e) + "_ref", DataTypeString(dt_e_ref)); + + // Test DataTypeFromString reverse conversion + DataType dt_e2, dt_e2_ref; + EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e), &dt_e2)); + EXPECT_EQ(dt_e, dt_e2); + EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e_ref), &dt_e2_ref)); + EXPECT_EQ(dt_e_ref, dt_e2_ref); + } + ASSERT_FALSE(DataType_IsValid(e)) + << "Should define " << enum_descriptor->FindValueByNumber(e)->name() + << "_REF to be " << e_ref; + ASSERT_FALSE(DataType_IsValid(e_ref)) + << "Extra reference enum " + << enum_descriptor->FindValueByNumber(e_ref)->name() + << " without corresponding base enum with value " << e; + ASSERT_LT(DataType_MAX, e_ref) << "Gap in reference types, missing value for " + << e_ref; + + // Make sure there are no enums defined after the last regular type before + // the first reference type. 
+ for (; e < DataType_MIN + kDataTypeRefOffset; ++e) { + EXPECT_FALSE(DataType_IsValid(e)) + << "Discontinuous enum value " + << enum_descriptor->FindValueByNumber(e)->name() << " = " << e; + } +} + +TEST(TypesTest, DataTypeFromString) { + DataType dt; + ASSERT_TRUE(DataTypeFromString("int32", &dt)); + EXPECT_EQ(DT_INT32, dt); + ASSERT_TRUE(DataTypeFromString("int32_ref", &dt)); + EXPECT_EQ(DT_INT32_REF, dt); + EXPECT_FALSE(DataTypeFromString("int32_ref_ref", &dt)); + EXPECT_FALSE(DataTypeFromString("foo", &dt)); + EXPECT_FALSE(DataTypeFromString("foo_ref", &dt)); + ASSERT_TRUE(DataTypeFromString("int64", &dt)); + EXPECT_EQ(DT_INT64, dt); + ASSERT_TRUE(DataTypeFromString("int64_ref", &dt)); + EXPECT_EQ(DT_INT64_REF, dt); + ASSERT_TRUE(DataTypeFromString("quint8_ref", &dt)); + EXPECT_EQ(DT_QUINT8_REF, dt); + ASSERT_TRUE(DataTypeFromString("bfloat16", &dt)); + EXPECT_EQ(DT_BFLOAT16, dt); +} + +template +static bool GetQuantized() { + return is_quantized::value; +} + +TEST(TypesTest, QuantizedTypes) { + // NOTE: GUnit cannot parse is::quantized::value() within the + // EXPECT_TRUE() clause, so we delegate through a template function. + EXPECT_TRUE(GetQuantized()); + EXPECT_TRUE(GetQuantized()); + EXPECT_TRUE(GetQuantized()); + + EXPECT_FALSE(GetQuantized()); + EXPECT_FALSE(GetQuantized()); + EXPECT_FALSE(GetQuantized()); + EXPECT_FALSE(GetQuantized()); + + EXPECT_TRUE(DataTypeIsQuantized(DT_QINT8)); + EXPECT_TRUE(DataTypeIsQuantized(DT_QUINT8)); + EXPECT_TRUE(DataTypeIsQuantized(DT_QINT32)); + + EXPECT_FALSE(DataTypeIsQuantized(DT_INT8)); + EXPECT_FALSE(DataTypeIsQuantized(DT_UINT8)); + EXPECT_FALSE(DataTypeIsQuantized(DT_UINT16)); + EXPECT_FALSE(DataTypeIsQuantized(DT_INT16)); + EXPECT_FALSE(DataTypeIsQuantized(DT_INT32)); + EXPECT_FALSE(DataTypeIsQuantized(DT_BFLOAT16)); +} + +TEST(TypesTest, ComplexTypes) { + EXPECT_TRUE(DataTypeIsComplex(DT_COMPLEX64)); + EXPECT_TRUE(DataTypeIsComplex(DT_COMPLEX128)); + EXPECT_FALSE(DataTypeIsComplex(DT_FLOAT)); + EXPECT_FALSE(DataTypeIsComplex(DT_DOUBLE)); +} + +TEST(TypesTest, IntegerTypes) { + for (auto dt : AllTypes()) { + const string name = DataTypeString(dt); + const StringPiece n = name; + EXPECT_EQ(DataTypeIsInteger(dt), + n.starts_with("int") || n.starts_with("uint")) + << "DataTypeInteger failed for " << name; + } +} + +} // namespace +} // namespace tensorflow diff --git a/unique_tensor_references.cc b/unique_tensor_references.cc new file mode 100644 index 0000000000000000000000000000000000000000..ab33d9ede6cde3431857b368a7a213d4323df313 --- /dev/null +++ b/unique_tensor_references.cc @@ -0,0 +1,91 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/unique_tensor_references.h" + +namespace tensorflow { + +UniqueTensorReferences::~UniqueTensorReferences() { + if (!frozen_) { + // The references were not retrieved so discard them to avoid + // leaking memory. 
+ TensorReferenceVector refs; + FreezeAndReturnReferences(&refs); + for (auto& tensor : refs) { + tensor.Unref(); + } + } + delete referenced_tensors_set_; +} + +void UniqueTensorReferences::Add(const Tensor& tensor) { + DCHECK(!frozen_); + // Do nothing if the tensor has a null buffer. + if (tensor.IsInitialized() && tensor.NumElements() > 0) { + if (referenced_tensors_set_ != nullptr) { + // There are enough tensors that we are using a hash set to + // de-duplicate. + const TensorReference tensor_ref(tensor); + if (!referenced_tensors_set_->insert(tensor_ref).second) { + // The tensor was a duplicate, so discard the reference. + tensor_ref.Unref(); + } + } else { + for (size_t i = 0; i < referenced_tensors_vector_.size(); ++i) { + if (referenced_tensors_vector_[i].SharesBufferWith(tensor)) { + // tensor is a duplicate, so nothing to do. + return; + } + } + referenced_tensors_vector_.push_back(TensorReference(tensor)); + if (kInVector == referenced_tensors_vector_.size()) { + // There are too many tensors to keep using the N^2 algorithm + // so start de-duplicating using a set. + // Transfer the refs from the vector to the set. + DCHECK(referenced_tensors_set_ == nullptr); + referenced_tensors_set_ = new ReferencedTensorsSet; + referenced_tensors_set_->reserve(kInVector); + referenced_tensors_set_->insert(referenced_tensors_vector_.begin(), + referenced_tensors_vector_.end()); + DCHECK_EQ(kInVector, referenced_tensors_set_->size()); + referenced_tensors_vector_.clear(); + } + } + } +} + +void UniqueTensorReferences::FreezeAndReturnReferences( + TensorReferenceVector* out_vector) { + // Prevent any further additions. + frozen_ = true; + if (referenced_tensors_set_ != nullptr) { + DCHECK(referenced_tensors_vector_.empty()); + out_vector->reserve(referenced_tensors_set_->size()); + for (const auto& ref : *referenced_tensors_set_) { + out_vector->push_back(ref); + } + referenced_tensors_set_->clear(); + delete referenced_tensors_set_; + referenced_tensors_set_ = nullptr; + } else { + out_vector->reserve(referenced_tensors_vector_.size()); + for (const auto& ref : referenced_tensors_vector_) { + out_vector->push_back(ref); + } + referenced_tensors_vector_.clear(); + } +} + +} // namespace tensorflow diff --git a/unique_tensor_references.h b/unique_tensor_references.h new file mode 100644 index 0000000000000000000000000000000000000000..7520034301aed06ad516ab8715aadbbc9bc84003 --- /dev/null +++ b/unique_tensor_references.h @@ -0,0 +1,81 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_UNIQUE_TENSOR_REFERENCES_H_ +#define TENSORFLOW_FRAMEWORK_UNIQUE_TENSOR_REFERENCES_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_reference.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// Helper class to maintain a unique set of tensor references. In the +// common case there are not many references, so an inline vector is +// used for <= kInVector unique elements, defaulting to 4 since that +// is the inlined size of TensorReferenceVector. To avoid N^2 +// operations when adding N items, any larger number of unique tensor +// references switches to using an unordered set. +class UniqueTensorReferences { + public: + UniqueTensorReferences() : frozen_(false), referenced_tensors_set_(nullptr) {} + + ~UniqueTensorReferences(); + + // Adds a reference to tensor if its buffer is not already referenced. + void Add(const Tensor& tensor); + + // No more references may be added after this is called. The unique + // references are returning in out_vector. + void FreezeAndReturnReferences(TensorReferenceVector* out_vector); + + private: + // Up to kInVector elements are stored in reference_tensors_vector_ + // to avoid any allocations or hash computations in the common + // case. When more unique elements are added they move to + // referenced_tensors_set_ to avoid an N^2 algorithm on insert. + static const int kInVector = 4; // Must be >= 1. + + struct TensorReferenceEqualFn { + bool operator()(const TensorReference& t1, + const TensorReference& t2) const { + return t1.SharesBufferWith(t2); + } + }; + + struct TensorReferenceHashFn { + size_t operator()(const TensorReference& t) const { return t.BufferHash(); } + }; + + bool frozen_; + TensorReferenceVector referenced_tensors_vector_; + + typedef std::unordered_set + ReferencedTensorsSet; + // Lazily allocated hash set for when the number of tensors becomes too large. + // If this is non-NULL, then we use the hash set, otherwise, we use the + // referenced_tensors_vector_ (and do O(N^2) work per insertion). + ReferencedTensorsSet* referenced_tensors_set_; + + TF_DISALLOW_COPY_AND_ASSIGN(UniqueTensorReferences); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_UNIQUE_TENSOR_REFERENCES_H_ diff --git a/unique_tensor_references_test.cc b/unique_tensor_references_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..47d8f9145251a3c0d83df38bc48cda37ed0480ca --- /dev/null +++ b/unique_tensor_references_test.cc @@ -0,0 +1,139 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/unique_tensor_references.h" + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(UniquifyTensors, TestUniqueVector) { + UniqueTensorReferences refs; + Tensor a(DT_FLOAT, TensorShape({2, 2})); + Tensor b(DT_FLOAT, TensorShape({2, 2})); + + EXPECT_FALSE(a.SharesBufferWith(b)); + + refs.Add(a); + refs.Add(b); + TensorReferenceVector tensors; + refs.FreezeAndReturnReferences(&tensors); + EXPECT_EQ(2, tensors.size()); + if (tensors[0].SharesBufferWith(a)) { + EXPECT_TRUE(tensors[1].SharesBufferWith(b)); + } else { + EXPECT_TRUE(tensors[1].SharesBufferWith(a)); + EXPECT_TRUE(tensors[0].SharesBufferWith(b)); + } + for (auto& t : tensors) { + t.Unref(); + } +} + +TEST(UniquifyTensors, TestNonUniqueVector) { + UniqueTensorReferences refs; + Tensor a(DT_FLOAT, TensorShape({2, 2})); + Tensor b(a); + + EXPECT_TRUE(a.SharesBufferWith(b)); + + refs.Add(a); + refs.Add(b); + TensorReferenceVector tensors; + refs.FreezeAndReturnReferences(&tensors); + EXPECT_EQ(1, tensors.size()); + EXPECT_TRUE(tensors[0].SharesBufferWith(a)); + EXPECT_TRUE(tensors[0].SharesBufferWith(b)); + for (auto& t : tensors) { + t.Unref(); + } +} + +TEST(UniquifyTensors, TestNoLeakVector) { + UniqueTensorReferences refs; + Tensor a(DT_FLOAT, TensorShape({2, 2})); + Tensor b(DT_FLOAT, TensorShape({2, 2})); + + EXPECT_FALSE(a.SharesBufferWith(b)); + + refs.Add(a); + refs.Add(b); +} + +TEST(UniquifyTensors, TestUniqueSet) { + UniqueTensorReferences refs; + Tensor a(DT_FLOAT, TensorShape({2, 2})); + Tensor b(DT_FLOAT, TensorShape({2, 2})); + Tensor c(DT_FLOAT, TensorShape({2, 2})); + Tensor d(DT_FLOAT, TensorShape({2, 2})); + Tensor e(DT_FLOAT, TensorShape({2, 2})); + + EXPECT_FALSE(a.SharesBufferWith(b)); + + refs.Add(a); + refs.Add(b); + refs.Add(c); + refs.Add(d); + refs.Add(e); + TensorReferenceVector tensors; + refs.FreezeAndReturnReferences(&tensors); + EXPECT_EQ(5, tensors.size()); + for (auto& t : tensors) { + t.Unref(); + } +} + +TEST(UniquifyTensors, TestNonUniqueSet) { + UniqueTensorReferences refs; + Tensor a(DT_FLOAT, TensorShape({2, 2})); + Tensor b(DT_FLOAT, TensorShape({2, 2})); + Tensor c(DT_FLOAT, TensorShape({2, 2})); + Tensor d(DT_FLOAT, TensorShape({2, 2})); + Tensor e(DT_FLOAT, TensorShape({2, 2})); + Tensor f(c); + + EXPECT_TRUE(f.SharesBufferWith(c)); + + refs.Add(a); + refs.Add(b); + refs.Add(c); + refs.Add(d); + refs.Add(e); + refs.Add(f); + TensorReferenceVector tensors; + refs.FreezeAndReturnReferences(&tensors); + EXPECT_EQ(5, tensors.size()); + for (auto& t : tensors) { + t.Unref(); + } +} + +TEST(UniquifyTensors, TestNoLeakSet) { + UniqueTensorReferences refs; + Tensor a(DT_FLOAT, TensorShape({2, 2})); + Tensor b(DT_FLOAT, TensorShape({2, 2})); + Tensor c(DT_FLOAT, TensorShape({2, 2})); + Tensor d(DT_FLOAT, TensorShape({2, 2})); + Tensor e(DT_FLOAT, TensorShape({2, 2})); + + refs.Add(a); + refs.Add(b); + refs.Add(c); + refs.Add(d); + refs.Add(e); +} + +} // namespace tensorflow diff --git a/variable.proto b/variable.proto new file mode 100644 index 0000000000000000000000000000000000000000..e0df01cc9b77589708049b6ab0f87ac1f1480b4c --- /dev/null +++ b/variable.proto @@ -0,0 +1,39 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VariableProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer 
representing a Variable. +message VariableDef { + // Name of the variable tensor. + string variable_name = 1; + + // Name of the tensor holding the variable's initial value. + string initial_value_name = 6; + + // Name of the initializer op. + string initializer_name = 2; + + // Name of the snapshot tensor. + string snapshot_name = 3; + + // Support for saving variables as slices of a larger variable. + SaveSliceInfoDef save_slice_info_def = 4; + + // Whether to represent this as a ResourceVariable. + bool is_resource = 5; +} + +message SaveSliceInfoDef { + // Name of the full variable of which this is a slice. + string full_name = 1; + // Shape of the full variable. + repeated int64 full_shape = 2; + // Offset of this variable into the full variable. + repeated int64 var_offset = 3; + // Shape of this variable. + repeated int64 var_shape = 4; +} diff --git a/variant.cc b/variant.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ad2fafee778de05da68799344618d94e5780176 --- /dev/null +++ b/variant.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { + +bool Variant::TryDecode(Variant* out) const { + const VariantTensorDataProto* p = get(); + if (p == nullptr) return false; + VariantTensorData data(*p); + return out->Decode(data); +} + +template <> +void* Variant::get() { + if (is_empty()) { + return nullptr; + } + return value_->RawPtr(); +} + +template <> +const void* Variant::get() const { + if (is_empty()) { + return nullptr; + } + return value_->RawPtr(); +} + +template <> +string TypeNameVariant(const VariantTensorDataProto& value) { + return value.type_name(); +} + +template <> +void EncodeVariant(const VariantTensorDataProto& value, + VariantTensorData* data) { + data->FromProto(value); +} + +template <> +bool DecodeVariant(const VariantTensorData& data, + VariantTensorDataProto* value) { + data.ToProto(value); + return true; +} + +template <> +void EncodeVariant(const VariantTensorDataProto& value, string* buf) { + value.SerializeToString(buf); +} + +template <> +bool DecodeVariant(const string& buf, VariantTensorDataProto* value) { + return value->ParseFromString(buf); +} + +} // end namespace tensorflow diff --git a/variant.h b/variant.h new file mode 100644 index 0000000000000000000000000000000000000000..c02391dae32f561d0a2430b91d861551fd85dc72 --- /dev/null +++ b/variant.h @@ -0,0 +1,354 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_VARIANT_H_ +#define TENSORFLOW_FRAMEWORK_VARIANT_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove +#include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +template +string TypeNameVariant(const T& value); + +template +string DebugStringVariant(const T& value); + +template +void EncodeVariant(const T& value, VariantTensorData* data); + +template +bool DecodeVariant(const VariantTensorData& data, T* value); + +template +void EncodeVariant(const T& value, string* buf); + +template +bool DecodeVariant(const string& buf, T* value); + +// This is an implementation of a type-erased container that can store an +// object of any type. The implementation is very similar to std::any, but has +// restrictions on the types of objects that can be stored, and eschews some of +// the fancier constructors available for std::any. An object of +// tensorflow::Variant is intended to be used as the value that will be stored +// in a tensorflow::Tensor object when its type is DT_VARIANT. +// +// tensorflow::Variant can store an object of a class that satisfies the +// following constraints: +// +// * The class is CopyConstructible. +// * The class has a default constructor. +// * It's either a protocol buffer, a tensorflow::Tensor, or defines the +// following functions: +// +// string TypeName() const; +// void Encode(VariantTensorData* data) const; +// void Decode(const VariantTensorData& data); +// +// Simple POD types can elide the Encode/Decode functions, they are provided by +// helper methods. +// Here are some typical usage patterns: +// +// Variant x = 10; +// EXPECT_EQ(*x.get(), 10); +// +// Tensor t(DT_FLOAT, TensorShape({})); +// t.flat()(0) = 42.0f; +// Variant x = t; +// EXPECT_EQ(x.get()->flat()(0), 42.0f); +// +// Accessing the stored object: +// +// The get function is the main mechanism to access the object +// stored in the container. It is type-safe, that is, calling +// get when the stored object's type is not T, returns a +// nullptr. A raw pointer to the stored object can be obtained by calling +// get(). +// +// Serializing/deserializing Variant object: +// +// The Variant class delegates serializing and deserializing operations to the +// contained object. Helper functions to do these operations are provided for +// POD data types, tensorflow::Tensor, and protocol buffer objects. However, +// other classes have to provide Encode/Decode functions to handle +// serialization. +// +// Objects stored in a Variant object often contain references to other +// tensorflow::Tensors of primitive types (Eg., a list of tensorflow::Tensors). +// To efficiently support those use cases, a structure is imposed on the +// serialization format. 
Namely, classes should serialize their contents into a +// VariantTensorData object: +// +// struct VariantTensorData { +// string type_name; +// string metadata; +// std::vector tensors; +// }; +// +// Objects with references to other Tensors can simply store those tensors in +// the `tensors` field, and serialize other metadata content in to the +// `metadata` field. +// +// Serialization example: +// +// Foo f = Foo {...}; +// Variant x = f; +// string serialized_f; +// x.Encode(&serialized_f); +// +// Variant y = Foo(); // default constructed Foo. +// y.Decode(&serialized_f); +// EXPECT_EQ(*x.get(), *y.get()); +// +// +// A Variant storing serialized Variant data (a value of type +// VariantTensorDataProto) has different behavior from a standard Variant. +// Namely, its TypeName matches the TypeName of the original Variant; +// and its non-const get method performs lazy deserialization. +// +// Decode and copy example: +// +// Foo f = Foo {...}; +// Variant x = f; +// +// VariantTensorData serialized_data_f; +// VariantTensorDataProto serialized_proto_f; +// x.Encode(&serialized_data_f); +// serialized_data_f.ToProto(&serialized_proto_f); +// +// Variant y_type_unknown = serialized_proto_f; // Store serialized Variant. +// +// EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo. +// EXPECT_EQ(MakeTypeIndex(), +// y_type_unknown.TypeId()); +// // Decode and get y_type_unknown; compare to value in x. +// Foo f_decoded; +// EXPECT_TRUE(x.MaybeDecodeAndCopy(&f_decoded)); +// EXPECT_EQ(f_decoded, f); +// +class Variant { + public: + constexpr Variant() noexcept = default; + + Variant(const Variant& other) + : value_(other.is_empty() ? std::unique_ptr() + : other.value_->Clone()) {} + + Variant(Variant&& other) noexcept = default; + + // Make sure that the type is CopyConstructible and not a tensorflow::Variant + // object itself. We want the copy constructor to be chosen for the + // tensorflow::Variant case. + template ::type, + typename std::enable_if::value && + std::is_copy_constructible::value, + void>::type* = nullptr> + Variant(T&& value) // NOLINT + : value_(new Value(in_place, std::forward(value))) {} + + Variant& operator=(const Variant& rhs) { + Variant(rhs).swap(*this); + return *this; + } + + Variant& operator=(Variant&& rhs) noexcept { + Variant(std::move(rhs)).swap(*this); + return *this; + } + + bool is_empty() const { return value_ == nullptr; } + + void clear() noexcept { value_.reset(); } + + void swap(Variant& other) noexcept { value_.swap(other.value_); } + + // Note, unlike TypeName(), TypeId() does not return the TypeIndex + // of the original type when a TensorValueDataProto is stored as the + // value. In this case, it returns the TypeIndex of TensorValueDataProto. + TypeIndex TypeId() const { + const TypeIndex VoidTypeIndex = MakeTypeIndex(); + if (is_empty()) { + return VoidTypeIndex; + } + return value_->TypeId(); + } + + string DebugString() const { + return strings::StrCat("VariantDebugString(), ">"); + } + + // Returns a pointer to the stored value if it is type T, or nullptr + // otherwise. + template + T* get() { + const TypeIndex TTypeIndex = MakeTypeIndex(); + if (is_empty() || (TTypeIndex != TypeId())) return nullptr; + return std::addressof(static_cast*>(value_.get())->value); + } + + // Returns a pointer to the stored value if it is type T, or nullptr + // otherwise. 
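A minimal, hypothetical usage sketch of the accessor just defined (not from the original header), assuming only the Variant API shown above:

// Sketch only: storing a value in a Variant and reading it back via get<T>().
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void VariantGetDemo() {
  Variant x = 10;  // stores an int
  CHECK(x.get<int>() != nullptr);
  CHECK_EQ(*x.get<int>(), 10);

  // Asking for the wrong type does not crash; it simply returns nullptr.
  CHECK(x.get<float>() == nullptr);

  // An empty Variant holds nothing and returns nullptr for every type.
  Variant empty;
  CHECK(empty.is_empty());
  CHECK(empty.get<int>() == nullptr);
}

}  // namespace tensorflow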
+ template + const T* get() const { + const TypeIndex TTypeIndex = MakeTypeIndex(); + if (is_empty() || (TTypeIndex != TypeId())) return nullptr; + return std::addressof( + static_cast*>(value_.get())->value); + } + + // Returns TypeNameVariant(value). + // + // In the special case that a serialized Variant is stored (value + // is a VariantTensorDataProto), returns value.TypeName(), the + // TypeName field stored in the VariantTensorDataProto buffer. + string TypeName() const { + if (is_empty()) { + return ""; + } + return value_->TypeName(); + } + + // Serialize the contents of the stored object into `data`. + void Encode(VariantTensorData* data) const { + if (!is_empty()) { + value_->Encode(data); + } + } + + // Deserialize `data` and update the stored object. + bool Decode(const VariantTensorData& data) { + if (!is_empty()) { + return value_->Decode(data); + } + return true; + } + + // Helper methods to directly serialize/deserialize from strings. + void Encode(string* buf) const { + if (!is_empty()) { + value_->Encode(buf); + } + } + bool Decode(const string& buf) { + if (!is_empty()) { + return value_->Decode(buf); + } + return true; + } + + template + bool MaybeDecodeAndCopy(T* out) const { + const T* ret = get(); + if (ret != nullptr) { + *out = std::move(*ret); + return true; + }; + Variant decoded = T(); + if (!TryDecode(&decoded)) return false; + T* decoded_ret = decoded.get(); + CHECK_NOTNULL(decoded_ret); + *out = std::move(*decoded_ret); + return true; + } + + private: + bool TryDecode(Variant* out) const; + + private: + struct in_place_t {}; + static constexpr in_place_t in_place{}; + + struct ValueInterface { + virtual ~ValueInterface() = default; + virtual TypeIndex TypeId() const = 0; + virtual void* RawPtr() = 0; + virtual const void* RawPtr() const = 0; + virtual std::unique_ptr Clone() const = 0; + virtual string TypeName() const = 0; + virtual string DebugString() const = 0; + virtual void Encode(VariantTensorData* data) const = 0; + virtual bool Decode(const VariantTensorData& data) = 0; + virtual void Encode(string* buf) const = 0; + virtual bool Decode(const string& data) = 0; + }; + + template + struct Value : ValueInterface { + template + explicit Value(in_place_t /*tag*/, Args&&... args) + : value(std::forward(args)...) {} + + TypeIndex TypeId() const override { + const TypeIndex value_type_index = + MakeTypeIndex::type>(); + return value_type_index; + } + + void* RawPtr() override { return &value; } + + const void* RawPtr() const override { return &value; } + + std::unique_ptr Clone() const override { + return std::unique_ptr(new Value(in_place, value)); + } + + string TypeName() const override { return TypeNameVariant(value); } + + string DebugString() const override { return DebugStringVariant(value); } + + void Encode(VariantTensorData* data) const override { + EncodeVariant(value, data); + } + + bool Decode(const VariantTensorData& data) override { + return DecodeVariant(data, &value); + } + + void Encode(string* buf) const override { EncodeVariant(value, buf); } + + bool Decode(const string& buf) override { + return DecodeVariant(buf, &value); + } + + T value; + }; + + // value_ can point to any type T as wrapped by a ValueInterface. + // The only real requirement is that T is default-constructible. 
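The string-based Encode/Decode helpers above can be exercised with a plain POD payload, for which the class comment says serialization is supplied by helper methods. A minimal sketch, not part of the original file:

// Sketch only: serializing a Variant to a string and decoding it into
// another Variant that already holds the same payload type.
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void VariantStringRoundTrip() {
  Variant x = 3.14;  // POD payload; Encode/Decode come from the helpers
  string buf;
  x.Encode(&buf);

  // Decode() dispatches to the stored object's Decode, so the target must
  // already contain a value of the expected type (here, a double).
  Variant y = 0.0;
  CHECK(y.Decode(buf));
  CHECK_EQ(*y.get<double>(), 3.14);
}

}  // namespace tensorflow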
+ std::unique_ptr value_; +}; + +template <> +void* Variant::get(); + +template <> +const void* Variant::get() const; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_VARIANT_H_ diff --git a/variant_encode_decode.h b/variant_encode_decode.h new file mode 100644 index 0000000000000000000000000000000000000000..5a84f9d94385a7048a0f4adfe78e1805b367f02d --- /dev/null +++ b/variant_encode_decode.h @@ -0,0 +1,264 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ +#define TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/abi.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// Type used for tag-dispatch of the Encode/Decode Variant implementations. This +// template can determine whether the first type parameter `T` is one of the +// following: +// +// * A POD type (TypeResolver) +// * A tensorflow::Tensor (TypeResolver) +// * A protocol buffer (TypeResolver) +// * None of the above (TypeResolver) +// +template ::type>::value, + bool = std::is_same::type, + ::tensorflow::Tensor>::value, + bool = std::is_base_of::type>::value> +struct TypeResolver {}; + +// Specialization for POD type +template +void EncodeVariantImpl(const T& value, TypeResolver, + VariantTensorData* data) { + data->set_metadata(value); +} + +// Specialization for tensorflow::Tensor +template +void EncodeVariantImpl(const T& value, + TypeResolver, + VariantTensorData* data) { + data->tensors_.clear(); + data->tensors_.push_back(value); +} + +// Specialization for protobuf +template +void EncodeVariantImpl(const T& value, + TypeResolver, + VariantTensorData* data) { + value.SerializeToString(&data->metadata_); +} + +// Specialization for other types +template +void EncodeVariantImpl(const T& value, + TypeResolver, + VariantTensorData* data) { + value.Encode(data); +} + +// Specialization for POD type +template +bool DecodeVariantImpl(const VariantTensorData& data, + TypeResolver, + T* value) { + return data.get_metadata(value); +} + +// Specialization for tensorflow::Tensor +template +bool DecodeVariantImpl(const VariantTensorData& data, + TypeResolver, + T* value) { + *value = data.tensors(0); + return true; +} + +// Specialization for protobuf +template +bool DecodeVariantImpl(const VariantTensorData& data, + TypeResolver, + T* value) { + string metadata; + data.get_metadata(&metadata); + return value->ParseFromString(std::move(metadata)); +} + +// Specialization for other types +template +bool DecodeVariantImpl(const VariantTensorData& data, + TypeResolver, + T* value) { + return value->Decode(data); +} + +template +struct has_type_name : std::false_type {}; + +template +struct 
has_type_name< + C, typename std::enable_if().TypeName()), string>::value>::type> + : std::true_type {}; + +template ::type>::value, + bool = std::is_same::type, + ::tensorflow::Tensor>::value, + bool = std::is_base_of::type>::value> +struct TypeNameResolver {}; + +template +string TypeNameVariantImpl(const T& value, + TypeNameResolver) { + return value.TypeName(); +} + +template +string TypeNameVariantImpl( + const T& value, + TypeNameResolver) { + return "tensorflow::Tensor"; +} + +template +string TypeNameVariantImpl( + const T& value, TypeNameResolver) { + return value.GetTypeName(); +} + +template +string TypeNameVariantImpl( + const T& value, + TypeNameResolver) { + return port::MaybeAbiDemangle(MakeTypeIndex().name()); +} + +template +string TypeNameVariant(const T& value) { + return TypeNameVariantImpl(value, TypeNameResolver()); +} + +template +struct has_debug_string : std::false_type {}; + +template +struct has_debug_string< + C, typename std::enable_if().DebugString()), string>::value>::type> + : std::true_type {}; + +template +struct can_strcat : std::false_type {}; + +template +struct can_strcat< + C, typename std::enable_if())), string>::value>::type> + : std::true_type {}; + +template ::type>::value, + bool = can_strcat::type>::value> +struct DebugStringResolver {}; + +// TODO(ebrevdo): Expand DebugStringResolver to return TypeString if +// there is no StrCat() constructor. +template +string DebugStringVariantImpl( + const T& value, DebugStringResolver) { + return value.DebugString(); +} + +template +string DebugStringVariantImpl( + const T& value, DebugStringResolver) { + return strings::StrCat(value); +} + +template +string DebugStringVariantImpl( + const T& value, DebugStringResolver) { + return "?"; +} + +template +string DebugStringVariant(const T& value) { + return DebugStringVariantImpl(value, DebugStringResolver()); +} + +template +void EncodeVariant(const T& value, VariantTensorData* data) { + EncodeVariantImpl(value, TypeResolver(), data); + data->set_type_name(TypeNameVariant(value)); +} + +template +bool DecodeVariant(const VariantTensorData& data, T* value) { + return DecodeVariantImpl(data, TypeResolver(), value); +} + +template +void EncodeVariant(const T& value, string* buf) { + VariantTensorData data; + EncodeVariantImpl(value, TypeResolver(), &data); + data.set_type_name(TypeNameVariant(value)); + DCHECK(buf != nullptr); + data.SerializeToString(buf); +} + +template +bool DecodeVariant(const string& buf, T* value) { + VariantTensorData data; + if (!data.ParseFromString(buf)) return false; + if (!DecodeVariantImpl(data, TypeResolver(), value)) return false; + return true; +} + +// Specializations for VariantTensorDataProto +template <> +string TypeNameVariant(const VariantTensorDataProto& value); +template <> +void EncodeVariant(const VariantTensorDataProto& value, + VariantTensorData* data); +template <> +bool DecodeVariant(const VariantTensorData& data, + VariantTensorDataProto* value); +template <> +void EncodeVariant(const VariantTensorDataProto& value, string* buf); +template <> +bool DecodeVariant(const string& buf, VariantTensorDataProto* value); + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ diff --git a/variant_op_copy_test.cc b/variant_op_copy_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..85e014f80434d2a2de2851d2cb361f4b0a0c9433 --- /dev/null +++ b/variant_op_copy_test.cc @@ -0,0 +1,378 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/port.h" + +namespace tensorflow { + +namespace { + +static int* GetCopyCPUToGPUCounter() { + static int* counter = new int(0); + return counter; +} + +static int* GetCopyGPUToCPUCounter() { + static int* counter = new int(0); + return counter; +} + +static int* GetCopyGPUToGPUCounter() { + static int* counter = new int(0); + return counter; +} + +struct StoredTensorValue { + Tensor stored; + string TypeName() const { return "StoredTensorValue"; } + void Encode(VariantTensorData* data) const { data->tensors_ = {stored}; } + bool Decode(const VariantTensorData& data) { + CHECK_EQ(1, data.tensors_.size()); + stored = data.tensors_[0]; + return true; + } + static Status CopyCPUToGPU( + const StoredTensorValue& from, StoredTensorValue* to, + const std::function& copy) { + ++*GetCopyCPUToGPUCounter(); + return copy(from.stored, &(to->stored)); + } + static Status CopyGPUToCPU( + const StoredTensorValue& from, StoredTensorValue* to, + const std::function& copy) { + ++*GetCopyGPUToCPUCounter(); + return copy(from.stored, &(to->stored)); + } + static Status CopyGPUToGPU( + const StoredTensorValue& from, StoredTensorValue* to, + const std::function& copy) { + ++*GetCopyGPUToGPUCounter(); + return copy(from.stored, &(to->stored)); + } +}; + +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(StoredTensorValue, "StoredTensorValue"); + +INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( + StoredTensorValue, VariantDeviceCopyDirection::HOST_TO_DEVICE, + "StoredTensorValue", StoredTensorValue::CopyCPUToGPU); + +INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( + StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_HOST, + "StoredTensorValue", StoredTensorValue::CopyGPUToCPU); + +INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( + StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_DEVICE, + "StoredTensorValue", 
StoredTensorValue::CopyGPUToGPU); + +REGISTER_OP("CreateTestVariant") + .Input("input: T") + .Attr("T: type") + .Output("output: variant") + .SetShapeFn(shape_inference::UnknownShape); + +class CreateTestVariantOp : public OpKernel { + public: + explicit CreateTestVariantOp(OpKernelConstruction* c) : OpKernel(c) {} + void Compute(OpKernelContext* c) override { + // Take the scalar tensor fed as input, and emit a Tensor + // containing 10 Variants (StoredTensorValues), both containing + // the input tensor. + const Tensor& stored_t = c->input(0); + Tensor* out; + OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({10}), &out)); + StoredTensorValue store{stored_t}; + auto t = out->flat(); + for (int i = 0; i < 10; ++i) { + t(i) = store; + } + CHECK_EQ("StoredTensorValue", t(0).TypeName()); + } +}; + +REGISTER_KERNEL_BUILDER(Name("CreateTestVariant").Device(DEVICE_CPU), + CreateTestVariantOp); + +class CreateTestVariant { + public: + explicit CreateTestVariant(const ::tensorflow::Scope& scope, + const Input& value) { + if (!scope.ok()) return; + auto _value = ops::AsNodeOut(scope, value); + if (!scope.ok()) return; + ::tensorflow::Node* ret; + const auto unique_name = scope.GetUniqueNameForOp("CreateTestVariant"); + auto builder = ::tensorflow::NodeBuilder(unique_name, "CreateTestVariant") + .Input(_value); + scope.UpdateBuilder(&builder); + scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); + if (!scope.ok()) return; + scope.UpdateStatus(scope.DoShapeInference(ret)); + if (!scope.ok()) return; + this->output_ = Output(ret, 0); + } + + // Intentionally not marked as explicit. + // NOLINTNEXTLINE google-explicit-constructor + operator ::tensorflow::Output() const { return output_; } + // Intentionally not marked as explicit. + // NOLINTNEXTLINE google-explicit-constructor + operator ::tensorflow::Input() const { return output_; } + + ::tensorflow::Node* node() const { return output_.node(); } + + ::tensorflow::Output output_; +}; + +} // end namespace + +TEST(VariantOpCopyTest, CreateConstOnCPU) { + Scope root = Scope::NewRootScope().WithDevice("/cpu:0"); + + // Create the input StoredTensorValue and serialize it. + StoredTensorValue from; + from.stored = Tensor(DT_INT64, TensorShape({})); + from.stored.scalar()() = 0xdeadbeef; + VariantTensorData data; + data.set_type_name(from.TypeName()); + from.Encode(&data); + + TensorProto variant_proto; + variant_proto.set_dtype(DT_VARIANT); + TensorShape scalar_shape({}); + scalar_shape.AsProto(variant_proto.mutable_tensor_shape()); + data.ToProto(variant_proto.add_variant_val()); + + Output create_const = ops::ConstFromProto(root, variant_proto); + TF_ASSERT_OK(root.status()); + ClientSession session(root); + std::vector outputs; + TF_CHECK_OK(session.Run({create_const}, &outputs)); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(DT_VARIANT, outputs[0].dtype()); + EXPECT_EQ(0, outputs[0].dims()); + const Variant& variant = outputs[0].scalar()(); + EXPECT_EQ("StoredTensorValue", variant.TypeName()); + const StoredTensorValue* to = variant.get(); + EXPECT_EQ(to->stored.dtype(), DT_INT64); + EXPECT_EQ(0xdeadbeef, to->stored.scalar()()); +} + +TEST(VariantOpCopyTest, CreateConstOnGPU) { + if (!IsGoogleCudaEnabled()) return; + + Scope root = Scope::NewRootScope().WithDevice("/gpu:0"); + + // Create the input StoredTensorValue and serialize it. 
+ StoredTensorValue from; + from.stored = Tensor(DT_INT64, TensorShape({})); + from.stored.scalar()() = 0xdeadbeef; + VariantTensorData data; + data.set_type_name(from.TypeName()); + from.Encode(&data); + + TensorProto variant_proto; + variant_proto.set_dtype(DT_VARIANT); + TensorShape scalar_shape({}); + scalar_shape.AsProto(variant_proto.mutable_tensor_shape()); + data.ToProto(variant_proto.add_variant_val()); + + Output create_const = ops::ConstFromProto(root, variant_proto); + TF_ASSERT_OK(root.status()); + ClientSession session(root); + std::vector outputs; + + int copy_to_gpu_before = *GetCopyCPUToGPUCounter(); + int copy_to_cpu_before = *GetCopyGPUToCPUCounter(); + TF_CHECK_OK(session.Run({create_const}, &outputs)); + int copy_to_cpu_after = *GetCopyGPUToCPUCounter(); + int copy_to_gpu_after = *GetCopyCPUToGPUCounter(); + + EXPECT_GT(copy_to_cpu_after - copy_to_cpu_before, 0); + EXPECT_GT(copy_to_gpu_after - copy_to_gpu_before, 0); + + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(DT_VARIANT, outputs[0].dtype()); + EXPECT_EQ(0, outputs[0].dims()); + const Variant& variant = outputs[0].scalar()(); + EXPECT_EQ("StoredTensorValue", variant.TypeName()); + const StoredTensorValue* to = variant.get(); + EXPECT_EQ(to->stored.dtype(), DT_INT64); + EXPECT_EQ(0xdeadbeef, to->stored.scalar()()); +} + +TEST(VariantOpCopyTest, CreateConstOnGPUFailsGracefully) { + if (!IsGoogleCudaEnabled()) return; + + Scope root = Scope::NewRootScope().WithDevice("/gpu:0"); + + // Create the input StoredTensorValue and serialize it. + StoredTensorValue from; + from.stored = Tensor(DT_STRING, TensorShape({})); + from.stored.scalar()() = "hi"; + VariantTensorData data; + data.set_type_name(from.TypeName()); + from.Encode(&data); + + TensorProto variant_proto; + variant_proto.set_dtype(DT_VARIANT); + TensorShape scalar_shape({}); + scalar_shape.AsProto(variant_proto.mutable_tensor_shape()); + data.ToProto(variant_proto.add_variant_val()); + + Output create_const = ops::ConstFromProto(root, variant_proto); + TF_ASSERT_OK(root.status()); + ClientSession session(root); + std::vector outputs; + Status s = session.Run({create_const}, &outputs); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("GPU copy from non-DMA string tensor")) + << s.ToString(); +} + +TEST(VariantOpCopyTest, CreateCopyCPUToCPU) { + Scope root = Scope::NewRootScope().WithDevice("/cpu:0"); + Tensor t_42(DT_INT32, TensorShape({})); + t_42.flat()(0) = 42; + Output create_op = CreateTestVariant(root, t_42); + Output identity = ops::Identity(root, create_op); + + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_CHECK_OK(session.Run({create_op, identity}, &outputs)); + EXPECT_EQ(2, outputs.size()); + EXPECT_EQ(10, outputs[1].dim_size(0)); + auto output = outputs[1].flat(); + for (int i = 0; i < 10; ++i) { + const Variant& r1 = output(i); + EXPECT_EQ("StoredTensorValue", r1.TypeName()); + const StoredTensorValue* v1 = r1.get(); + EXPECT_NE(v1, nullptr); + EXPECT_EQ(42, v1->stored.scalar()()); + } +} + +TEST(VariantOpCopyTest, CreateCopyCPUToCPUString) { + Scope root = Scope::NewRootScope().WithDevice("/cpu:0"); + Tensor t_str(DT_STRING, TensorShape({})); + t_str.scalar()() = "hi"; + Output create_op = CreateTestVariant(root, t_str); + Output identity = ops::Identity(root, create_op); + + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_CHECK_OK(session.Run({create_op, identity}, &outputs)); + EXPECT_EQ(2, outputs.size()); + EXPECT_EQ(10, outputs[1].dim_size(0)); + auto 
output = outputs[1].flat(); + for (int i = 0; i < 10; ++i) { + const Variant& r1 = output(i); + EXPECT_EQ("StoredTensorValue", r1.TypeName()); + const StoredTensorValue* v1 = r1.get(); + EXPECT_NE(v1, nullptr); + EXPECT_EQ("hi", v1->stored.scalar()()); + } +} + +TEST(VariantOpCopyTest, CreateCopyCPUToGPU) { + if (!IsGoogleCudaEnabled()) return; + + Scope root = Scope::NewRootScope().WithDevice("/cpu:0"); + Scope with_gpu = root.WithDevice("/gpu:0"); + Tensor t_42(DT_INT32, TensorShape({})); + t_42.scalar()() = 42; + Output create_op = CreateTestVariant(root, t_42); + Output identity = ops::Identity(with_gpu, create_op); + + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + int copy_to_gpu_before = *GetCopyCPUToGPUCounter(); + int copy_to_cpu_before = *GetCopyGPUToCPUCounter(); + // Force the identity to run on GPU, and then the data to be copied + // back to CPU for the final output. + TF_CHECK_OK(session.Run({create_op, identity}, &outputs)); + int copy_to_cpu_after = *GetCopyGPUToCPUCounter(); + int copy_to_gpu_after = *GetCopyCPUToGPUCounter(); + + EXPECT_GT(copy_to_cpu_after - copy_to_cpu_before, 0); + EXPECT_GT(copy_to_gpu_after - copy_to_gpu_before, 0); + + EXPECT_EQ(2, outputs.size()); + EXPECT_EQ(10, outputs[1].dim_size(0)); + auto output = outputs[1].flat(); + for (int i = 0; i < 10; ++i) { + const Variant& r1 = output(i); + EXPECT_EQ("StoredTensorValue", r1.TypeName()); + const StoredTensorValue* v1 = r1.get(); + EXPECT_NE(v1, nullptr); + EXPECT_EQ(42, v1->stored.scalar()()); + } +} + +TEST(VariantOpCopyTest, CreateCopyCPUToGPUStringFailsSafely) { + if (!IsGoogleCudaEnabled()) return; + + Scope root = Scope::NewRootScope().WithDevice("/cpu:0"); + Scope with_gpu = root.WithDevice("/gpu:0"); + Tensor t_str(DT_STRING, TensorShape({})); + t_str.scalar()() = "hi"; + Output create_op = CreateTestVariant(root, t_str); + Output identity = ops::Identity(with_gpu, create_op); + + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + Status err = session.Run({create_op, identity}, &outputs); + EXPECT_EQ(err.code(), errors::Code::INVALID_ARGUMENT); + EXPECT_TRUE(StringPiece(err.error_message()) + .contains("During Variant Host->Device Copy: non-DMA-copy " + "attempted of tensor type: string")) + << err.error_message(); +} + +// TODO(ebrevdo): Identify a way to create two virtual GPUs within a +// single session, so that we can test the Device <-> Device copy +// branch. + +} // end namespace tensorflow diff --git a/variant_op_registry.cc b/variant_op_registry.cc new file mode 100644 index 0000000000000000000000000000000000000000..395329da3bee01cf73c69d52b150b88f34d1b1ff --- /dev/null +++ b/variant_op_registry.cc @@ -0,0 +1,275 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +std::unordered_set* UnaryVariantOpRegistry::PersistentStringStorage() { + static std::unordered_set* string_storage = + new std::unordered_set(); + return string_storage; +} + +// static +UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() { + static UnaryVariantOpRegistry* global_unary_variant_op_registry = + new UnaryVariantOpRegistry; + return global_unary_variant_op_registry; +} + +UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn( + StringPiece type_name) { + auto found = shape_fns.find(type_name); + if (found == shape_fns.end()) return nullptr; + return &found->second; +} + +void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name, + const VariantShapeFn& shape_fn) { + CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape"; + VariantShapeFn* existing = GetShapeFn(type_name); + CHECK_EQ(existing, nullptr) + << "Unary VariantShapeFn for type_name: " << type_name + << " already registered"; + shape_fns.insert(std::pair( + GetPersistentStringPiece(type_name), shape_fn)); +} + +Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { + CHECK_EQ(variant_tensor.dtype(), DT_VARIANT); + CHECK_EQ(variant_tensor.dims(), 0); + const Variant& v = variant_tensor.scalar()(); + UnaryVariantOpRegistry::VariantShapeFn* shape_fn = + UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName()); + if (shape_fn == nullptr) { + return errors::Internal( + "No unary variant shape function found for Variant type_name: ", + v.TypeName()); + } + return (*shape_fn)(v, shape); +} + +// Add some basic registrations for use by others, e.g., for testing. +namespace { +template +Status ScalarShape(const T&, TensorShape* shape) { + *shape = TensorShape({}); + return Status::OK(); +} +} // namespace + +#define REGISTER_VARIANT_SHAPE_TYPE(T) \ + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape); + +// No encode/shape registered for std::complex<> and Eigen::half +// objects yet. 
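+// Each REGISTER_VARIANT_SHAPE_TYPE(T) call below goes through
+// REGISTER_UNARY_VARIANT_SHAPE_FUNCTION with the ScalarShape helper above,
+// so these POD types simply report an empty (scalar) TensorShape.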
+REGISTER_VARIANT_SHAPE_TYPE(int); +REGISTER_VARIANT_SHAPE_TYPE(float); +REGISTER_VARIANT_SHAPE_TYPE(bool); +REGISTER_VARIANT_SHAPE_TYPE(double); + +#undef REGISTER_VARIANT_SHAPE_TYPE + +UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn( + StringPiece type_name) { + auto found = decode_fns.find(type_name); + if (found == decode_fns.end()) return nullptr; + return &found->second; +} + +void UnaryVariantOpRegistry::RegisterDecodeFn( + const string& type_name, const VariantDecodeFn& decode_fn) { + CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDecode"; + VariantDecodeFn* existing = GetDecodeFn(type_name); + CHECK_EQ(existing, nullptr) + << "Unary VariantDecodeFn for type_name: " << type_name + << " already registered"; + decode_fns.insert(std::pair( + GetPersistentStringPiece(type_name), decode_fn)); +} + +bool DecodeUnaryVariant(Variant* variant) { + UnaryVariantOpRegistry::VariantDecodeFn* decode_fn = + UnaryVariantOpRegistry::Global()->GetDecodeFn(variant->TypeName()); + if (decode_fn == nullptr) { + return false; + } + const string type_name = variant->TypeName(); + bool decoded = (*decode_fn)(variant); + if (!decoded) return false; + if (variant->TypeName() != type_name) { + LOG(ERROR) << "DecodeUnaryVariant: Variant type_name before decoding was: " + << type_name + << " but after decoding was: " << variant->TypeName() + << ". Treating this as a failure."; + return false; + } + return true; +} + +// Add some basic registrations for use by others, e.g., for testing. + +#define REGISTER_VARIANT_DECODE_TYPE(T) \ + REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T)); + +// No encode/decode registered for std::complex<> and Eigen::half +// objects yet. +REGISTER_VARIANT_DECODE_TYPE(int); +REGISTER_VARIANT_DECODE_TYPE(float); +REGISTER_VARIANT_DECODE_TYPE(bool); +REGISTER_VARIANT_DECODE_TYPE(double); + +#undef REGISTER_VARIANT_DECODE_TYPE + +UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* +UnaryVariantOpRegistry::GetDeviceCopyFn( + const VariantDeviceCopyDirection direction, StringPiece type_name) { + auto found = device_copy_fns.find(std::make_pair(direction, type_name)); + if (found == device_copy_fns.end()) return nullptr; + return &found->second; +} + +void UnaryVariantOpRegistry::RegisterDeviceCopyFn( + const VariantDeviceCopyDirection direction, const string& type_name, + const AsyncVariantDeviceCopyFn& device_copy_fn) { + CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy"; + AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name); + CHECK_EQ(existing, nullptr) + << "UnaryVariantDeviceCopy for direction: " << direction + << " and type_name: " << type_name << " already registered"; + device_copy_fns.insert( + std::pair, + AsyncVariantDeviceCopyFn>( + std::make_pair(direction, GetPersistentStringPiece(type_name)), + device_copy_fn)); +} + +Status VariantDeviceCopy( + const VariantDeviceCopyDirection direction, const Variant& from, + Variant* to, + const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) { + UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn = + UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction, + from.TypeName()); + if (device_copy_fn == nullptr) { + return errors::Internal( + "No unary variant device copy function found for direction: ", + direction, " and Variant type_name: ", from.TypeName()); + } + return (*device_copy_fn)(from, to, copy_fn); +} + +// Special casing UnaryOpFn per op and per device. 
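+// The lookup key is the (op enum, device string, type_name) tuple, so the
+// same Variant type can register, e.g., separate ZEROS_LIKE implementations
+// for DEVICE_CPU and DEVICE_GPU.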
+UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn( + VariantUnaryOp op, StringPiece device, StringPiece type_name) { + auto found = unary_op_fns.find(std::make_tuple(op, device, type_name)); + if (found == unary_op_fns.end()) return nullptr; + return &found->second; +} + +void UnaryVariantOpRegistry::RegisterUnaryOpFn( + VariantUnaryOp op, const string& device, const string& type_name, + const VariantUnaryOpFn& unary_op_fn) { + CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp"; + VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name); + CHECK_EQ(existing, nullptr) + << "Unary VariantUnaryOpFn for type_name: " << type_name + << " already registered for device type: " << device; + unary_op_fns.insert( + std::pair, + VariantUnaryOpFn>( + std::make_tuple(op, GetPersistentStringPiece(device), + GetPersistentStringPiece(type_name)), + unary_op_fn)); +} + +namespace { +template +Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t, + T* t_out) { + *t_out = T(0); + return Status::OK(); +} +} // namespace + +#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \ + DEVICE_CPU, T, TF_STR(T), \ + ZerosLikeVariantPrimitiveType); + +// No zeros_like registered for std::complex<> or Eigen::half objects yet. +REGISTER_VARIANT_ZEROS_LIKE_TYPE(int); +REGISTER_VARIANT_ZEROS_LIKE_TYPE(float); +REGISTER_VARIANT_ZEROS_LIKE_TYPE(double); +REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool); + +#undef REGISTER_VARIANT_ZEROS_LIKE_TYPE + +// Special casing BinaryOpFn per op and per device. +UnaryVariantOpRegistry::VariantBinaryOpFn* +UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device, + StringPiece type_name) { + auto found = binary_op_fns.find(std::make_tuple(op, device, type_name)); + if (found == binary_op_fns.end()) return nullptr; + return &found->second; +} + +void UnaryVariantOpRegistry::RegisterBinaryOpFn( + VariantBinaryOp op, const string& device, const string& type_name, + const VariantBinaryOpFn& add_fn) { + CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp"; + VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name); + CHECK_EQ(existing, nullptr) + << "Unary VariantBinaryOpFn for type_name: " << type_name + << " already registered for device type: " << device; + binary_op_fns.insert( + std::pair, + VariantBinaryOpFn>( + std::make_tuple(op, GetPersistentStringPiece(device), + GetPersistentStringPiece(type_name)), + add_fn)); +} + +namespace { +template +Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b, + T* out) { + *out = a + b; + return Status::OK(); +} +} // namespace + +#define REGISTER_VARIANT_ADD_TYPE(T) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \ + T, TF_STR(T), \ + AddVariantPrimitiveType); + +// No add registered for std::complex<> or Eigen::half objects yet. +REGISTER_VARIANT_ADD_TYPE(int); +REGISTER_VARIANT_ADD_TYPE(float); +REGISTER_VARIANT_ADD_TYPE(double); +REGISTER_VARIANT_ADD_TYPE(bool); + +#undef REGISTER_VARIANT_ADD_TYPE + +} // namespace tensorflow diff --git a/variant_op_registry.h b/variant_op_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..13f6908cae1ed1b1964bf827dce0fcb2bee4e6d1 --- /dev/null +++ b/variant_op_registry.h @@ -0,0 +1,556 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_ +#define TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_ + +#include +#include +#include + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace tensorflow { + +class OpKernelContext; +// A global UnaryVariantOpRegistry is used to hold callback functions +// for different variant types. To be used by ShapeOp, RankOp, and +// SizeOp, decoding, etc. + +enum VariantUnaryOp { + INVALID_VARIANT_UNARY_OP = 0, + ZEROS_LIKE_VARIANT_UNARY_OP = 1, + CONJ_VARIANT_UNARY_OP = 2, +}; + +enum VariantBinaryOp { + INVALID_VARIANT_BINARY_OP = 0, + ADD_VARIANT_BINARY_OP = 1, +}; + +enum VariantDeviceCopyDirection { + INVALID_DEVICE_COPY_DIRECTION = 0, + HOST_TO_DEVICE = 1, + DEVICE_TO_HOST = 2, + DEVICE_TO_DEVICE = 3, +}; + +class UnaryVariantOpRegistry { + public: + typedef std::function VariantShapeFn; + typedef std::function VariantDecodeFn; + typedef std::function + VariantUnaryOpFn; + typedef std::function + VariantBinaryOpFn; + + // An AsyncTensorDeviceCopyFn is a function provided to + // the user-provided DeviceCopyFn callback as the third argument ("copier"). + // + // Expected inputs: + // from: A Tensor on the host (if performing cpu->gpu copy), or + // device (if performing gpu->cpu or gpu->gpu copy). + // to: An empty/uninitialized tensor. It will be updated upon + // successful return of the function with the correct dtype and shape. + // However, the copied data will not be available until the compute + // stream has been synchronized. + // + // Returns: + // The status upon memory allocation / initialization of the + // "to" tensor, and enqueue of the copy onto the compute stream. + // Any failure of the copy itself will update the underlying + // stream status and propagate through the runtime independent + // of the caller. + typedef std::function + AsyncTensorDeviceCopyFn; + + // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn' + // expected to be passed to the registration macro + // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION. + typedef std::function + AsyncVariantDeviceCopyFn; + + // Add a shape lookup function to the registry. + void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn); + + // Returns nullptr if no shape function was found for the given TypeName. + VariantShapeFn* GetShapeFn(StringPiece type_name); + + // Add a decode function to the registry. + void RegisterDecodeFn(const string& type_name, + const VariantDecodeFn& decode_fn); + + // Returns nullptr if no decode function was found for the given TypeName. + VariantDecodeFn* GetDecodeFn(StringPiece type_name); + + // Add a copy-to-GPU function to the registry. 
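+  // (Despite the name, this registers a copy function for any
+  // VariantDeviceCopyDirection; the direction is part of the lookup key.)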
+ void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction, + const string& type_name, + const AsyncVariantDeviceCopyFn& device_copy_fn); + + // Returns nullptr if no copy function was found for the given + // TypeName and direction. + AsyncVariantDeviceCopyFn* GetDeviceCopyFn( + const VariantDeviceCopyDirection direction, StringPiece type_name); + + // Add a unary op function to the registry. + void RegisterUnaryOpFn(VariantUnaryOp op, const string& device, + const string& type_name, + const VariantUnaryOpFn& unary_op_fn); + + // Returns nullptr if no unary op function was found for the given + // op, device, and TypeName. + VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device, + StringPiece type_name); + + // Add a binary op function to the registry. + void RegisterBinaryOpFn(VariantBinaryOp op, const string& device, + const string& type_name, + const VariantBinaryOpFn& add_fn); + + // Returns nullptr if no binary op function was found for the given + // op, device and TypeName. + VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device, + StringPiece type_name); + + // Get a pointer to a global UnaryVariantOpRegistry object + static UnaryVariantOpRegistry* Global(); + + // Get a pointer to a global persistent string storage object. + // ISO/IEC C++ working draft N4296 clarifies that insertion into an + // std::unordered_set does not invalidate memory locations of + // *values* inside the set (though it may invalidate existing + // iterators). In other words, one may safely point a StringPiece to + // a value in the set without that StringPiece being invalidated by + // future insertions. + static std::unordered_set* PersistentStringStorage(); + + private: + std::unordered_map shape_fns; + std::unordered_map + decode_fns; + + // Map std::pair to function. + struct PairHash { + template + std::size_t operator()(const std::pair& x) const { + // The hash of an enum is just its value as a std::size_t. + std::size_t ret = static_cast(std::get<0>(x)); + ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x))); + return ret; + } + StringPieceHasher sp_hasher_; + }; + + std::unordered_map, + AsyncVariantDeviceCopyFn, PairHash> + device_copy_fns; + + // Map std::tuple to function. + struct TupleHash { + template + std::size_t operator()( + const std::tuple& x) const { + // The hash of an enum is just its value as a std::size_t. + std::size_t ret = static_cast(std::get<0>(x)); + ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x))); + ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x))); + return ret; + } + StringPieceHasher sp_hasher_; + }; + std::unordered_map, + VariantUnaryOpFn, TupleHash> + unary_op_fns; + std::unordered_map, + VariantBinaryOpFn, TupleHash> + binary_op_fns; + + // Find or insert a string into a persistent string storage + // container; return the StringPiece pointing to the permanent + // string location. + static StringPiece GetPersistentStringPiece(const string& str) { + const auto string_storage = PersistentStringStorage(); + auto found = string_storage->find(str); + if (found == string_storage->end()) { + auto inserted = string_storage->insert(str); + return StringPiece(*inserted.first); + } else { + return StringPiece(*found); + } + } +}; + +// Gets a TensorShape from a Tensor containing a scalar Variant. +// Returns an Internal error if the Variant does not have a registered shape +// function, or if it's a serialized Variant that cannot be decoded. 
+//
+// REQUIRES:
+//   variant_tensor.dtype() == DT_VARIANT
+//   variant_tensor.dims() == 0
+//
+Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape);
+
+// Decodes the Variant whose data_type has a registered decode
+// function.  Returns false if the Variant does not have a
+// registered decode function, or if the decoding function fails.
+//
+// REQUIRES:
+//   variant is not null.
+//
+bool DecodeUnaryVariant(Variant* variant);
+
+// Copies a variant between CPU<->GPU, or between GPU<->GPU.
+// The variant 'from' must have a registered DeviceCopyFn for the
+// given direction.  The returned variant 'to' will have
+// (some subset of its) tensors stored on the destination according to the
+// registered DeviceCopyFn for the given direction.  Returns
+// an Internal error if the Variant does not have a registered
+// DeviceCopyFn for the given direction, or if initiating the
+// copy fails.
+//
+// REQUIRES:
+//   'to' is not null.
+//
+Status VariantDeviceCopy(
+    const VariantDeviceCopyDirection direction, const Variant& from,
+    Variant* to,
+    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn);
+
+// Sets *v_out = unary_op(v).  The variant v must have a registered
+// UnaryOp function for the given Device.  Returns an Internal error
+// if v does not have a registered unary_op function for this device, or if
+// UnaryOp fails.
+//
+// REQUIRES:
+//   v_out is not null.
+//
+template <typename Device>
+Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
+                      Variant* v_out) {
+  const string& device = DeviceName<Device>::value;
+  UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
+      UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName());
+  if (unary_op_fn == nullptr) {
+    return errors::Internal(
+        "No unary variant unary_op function found for unary variant op enum: ",
+        op, " Variant type_name: ", v.TypeName(), " for device type: ", device);
+  }
+  return (*unary_op_fn)(ctx, v, v_out);
+}
+
+// Sets *out = binary_op(a, b).  The variants a and b must be the same type
+// and have a registered binary_op function for the given Device.  Returns an
+// Internal error if a and b are not the same type_name, if
+// a does not have a registered op function for this device, or if
+// BinaryOp fails.
+//
+// REQUIRES:
+//   out is not null.
+//
+template <typename Device>
+Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
+                        const Variant& a, const Variant& b, Variant* out) {
+  if (a.TypeName() != b.TypeName()) {
+    return errors::Internal(
+        "BinaryOpVariants: Variants a and b have different "
+        "type names: '",
+        a.TypeName(), "' vs. 
'", b.TypeName(), "'"); + } + const string& device = DeviceName::value; + UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn = + UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName()); + if (binary_op_fn == nullptr) { + return errors::Internal( + "No unary variant binary_op function found for binary variant op " + "enum: ", + op, " Variant type_name: '", a.TypeName(), + "' for device type: ", device); + } + return (*binary_op_fn)(ctx, a, b, out); +} + +namespace variant_op_registry_fn_registration { + +template +class UnaryVariantShapeRegistration { + public: + typedef std::function LocalVariantShapeFn; + + UnaryVariantShapeRegistration(const string& type_name, + const LocalVariantShapeFn& shape_fn) { + UnaryVariantOpRegistry::Global()->RegisterShapeFn( + type_name, + [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status { + const T* t = v.get(); + if (t == nullptr) { + return errors::Internal( + "VariantShapeFn: Could not access object, type_name: ", + type_name); + } + return shape_fn(*t, s); + }); + } +}; + +template +class UnaryVariantDecodeRegistration { + public: + UnaryVariantDecodeRegistration(const string& type_name) { + // The Variant is passed by pointer because it should be + // mutable: get below may Decode the variant, which + // is a self-mutating behavior. The variant is not modified in + // any other way. + UnaryVariantOpRegistry::Global()->RegisterDecodeFn( + type_name, [type_name](Variant* v) -> bool { + DCHECK_NE(v, nullptr); + VariantTensorDataProto* t = v->get(); + if (t == nullptr) { + return false; + } + Variant decoded = T(); + VariantTensorData data(*t); + if (!decoded.Decode(data)) { + return false; + } + *v = std::move(decoded); + return true; + }); + } +}; + +template +class UnaryVariantDeviceCopyRegistration { + public: + typedef std::function + LocalVariantDeviceCopyFn; + UnaryVariantDeviceCopyRegistration( + const VariantDeviceCopyDirection direction, const string& type_name, + const LocalVariantDeviceCopyFn& device_copy_fn) { + UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn( + direction, type_name, + [type_name, device_copy_fn]( + const Variant& from, Variant* to, + UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn + device_copy_tensor_fn) -> Status { + DCHECK_NE(to, nullptr); + *to = T(); + if (from.get() == nullptr) { + return errors::Internal( + "VariantCopyToGPUFn: Could not access object, type_name: ", + type_name); + } + const T& t = *from.get(); + T* t_out = to->get(); + return device_copy_fn(t, t_out, device_copy_tensor_fn); + }); + } +}; + +template +class UnaryVariantUnaryOpRegistration { + typedef std::function + LocalVariantUnaryOpFn; + + public: + UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device, + const string& type_name, + const LocalVariantUnaryOpFn& unary_op_fn) { + UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn( + op, device, type_name, + [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v, + Variant* v_out) -> Status { + DCHECK_NE(v_out, nullptr); + *v_out = T(); + if (v.get() == nullptr) { + return errors::Internal( + "VariantUnaryOpFn: Could not access object, type_name: ", + type_name); + } + const T& t = *v.get(); + T* t_out = v_out->get(); + return unary_op_fn(ctx, t, t_out); + }); + } +}; + +template +class UnaryVariantBinaryOpRegistration { + typedef std::function + LocalVariantBinaryOpFn; + + public: + UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device, + const string& type_name, + const LocalVariantBinaryOpFn& binary_op_fn) { 
+ UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn( + op, device, type_name, + [type_name, binary_op_fn](OpKernelContext* ctx, const Variant& a, + const Variant& b, Variant* out) -> Status { + DCHECK_NE(out, nullptr); + *out = T(); + if (a.get() == nullptr) { + return errors::Internal( + "VariantBinaryOpFn: Could not access object 'a', type_name: ", + type_name); + } + if (b.get() == nullptr) { + return errors::Internal( + "VariantBinaryOpFn: Could not access object 'b', type_name: ", + type_name); + } + const T& t_a = *a.get(); + const T& t_b = *b.get(); + T* t_out = out->get(); + return binary_op_fn(ctx, t_a, t_b, t_out); + }); + } +}; + +}; // namespace variant_op_registry_fn_registration + +// Register a unary shape variant function with the signature: +// Status ShapeFn(const T& t, TensorShape* s); +// to Variants having TypeName type_name. +#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function) \ + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \ + shape_function) + +#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \ + shape_function) \ + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function) + +#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, \ + shape_function) \ + static variant_op_registry_fn_registration::UnaryVariantShapeRegistration \ + register_unary_variant_op_shape_registration_fn_##ctr(type_name, \ + shape_function) + +// Register a unary decode variant function for the given type. +#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \ + REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name) + +#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(ctr, T, type_name) \ + REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name) + +#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name) \ + static variant_op_registry_fn_registration::UnaryVariantDecodeRegistration< \ + T> \ + register_unary_variant_op_decoder_fn_##ctr(type_name) + +// ****** NOTE ****** +// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE. +// ****** NOTE ****** +// +// Register a device copy variant function for the given copy +// direction and type; where direction is the enum +// VariantDeviceCopyDirection, and the device_copy_fn has signature: +// +// Status device_copy_fn( +// const T& t, T* t_out, +// const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier); +// +// And device_copy_fn calls copier 0 or more times. For details on +// the behavior of the copier function, see the comments at the +// declaration of UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn. +// +// Note, the device_copy_fn may choose to keep some tensors +// on host, e.g. by assigning to->tensor = from.tensor (assuming +// from.tensor is already on host); or by setting +// to->tensor = Tensor(cpu_allocator(), ...) +// and manually updating its values. +// +// If this is the case, the CopyFns for HOST_TO_DEVICE, +// DEVICE_TO_HOST, and DEVICE_TO_DEVICE must perform host-to-host +// copies in a consistent manner. For example, one must always +// manually copy any "always on host" tensors in all directions instead of e.g. +// - performing a host-to-host copy in one direction, +// - using the provided copier function in the reverse direction. +// Doing the latter will cause program failures. +// +// ****** NOTE ****** +// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE. 
+// ****** NOTE ****** +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ + T, direction, type_name, device_copy_fn) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, T, direction, type_name, device_copy_fn) + +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ + ctr, T, direction, type_name, device_copy_fn) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ + ctr, T, direction, type_name, device_copy_fn) + +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ + ctr, T, direction, type_name, device_copy_fn) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantDeviceCopyRegistration \ + register_unary_variant_op_device_copy_fn_##ctr(direction, type_name, \ + device_copy_fn) + +// Register a unary unary_op variant function with the signature: +// Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out); +// to Variants having TypeName type_name, for device string device, +// for UnaryVariantOp enum op. +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \ + unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, type_name, unary_op_function) + +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + ctr, op, device, T, type_name, unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \ + unary_op_function) + +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_name, unary_op_function) \ + static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \ + T> \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ + unary_op_function) + +// Register a binary_op variant function with the signature: +// Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out); +// to Variants having TypeName type_name, for device string device, +// for BinaryVariantOp enum OP. +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \ + binary_op_function) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, type_name, binary_op_function) + +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ + ctr, op, device, T, type_name, binary_op_function) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_name, binary_op_function) + +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_name, binary_op_function) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantBinaryOpRegistration \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ + binary_op_function) + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_ diff --git a/variant_op_registry_test.cc b/variant_op_registry_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..06ca211c762748b1dacd4eb9623ffd2d72762cca --- /dev/null +++ b/variant_op_registry_test.cc @@ -0,0 +1,355 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif + +#include "tensorflow/core/framework/variant_op_registry.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace { + +struct VariantValue { + string TypeName() const { return "TEST VariantValue"; } + static Status ShapeFn(const VariantValue& v, TensorShape* s) { + if (v.early_exit) { + return errors::InvalidArgument("early exit!"); + } + *s = TensorShape({-0xdeadbeef}); + return Status::OK(); + } + static Status CPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v, + VariantValue* v_out) { + if (v.early_exit) { + return errors::InvalidArgument("early exit zeros_like!"); + } + v_out->value = 1; // CPU + return Status::OK(); + } + static Status GPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v, + VariantValue* v_out) { + if (v.early_exit) { + return errors::InvalidArgument("early exit zeros_like!"); + } + v_out->value = 2; // GPU + return Status::OK(); + } + static Status CPUAddFn(OpKernelContext* ctx, const VariantValue& a, + const VariantValue& b, VariantValue* out) { + if (a.early_exit) { + return errors::InvalidArgument("early exit add!"); + } + out->value = a.value + b.value; // CPU + return Status::OK(); + } + static Status GPUAddFn(OpKernelContext* ctx, const VariantValue& a, + const VariantValue& b, VariantValue* out) { + if (a.early_exit) { + return errors::InvalidArgument("early exit add!"); + } + out->value = -(a.value + b.value); // GPU + return Status::OK(); + } + static Status CPUToGPUCopyFn( + const VariantValue& from, VariantValue* to, + const std::function& copier) { + TF_RETURN_IF_ERROR(copier(Tensor(), nullptr)); + to->value = 0xdeadbeef; + return Status::OK(); + } + bool early_exit; + int value; +}; + +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue", + VariantValue::ShapeFn); + +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue"); + +INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( + VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE, + "TEST VariantValue", VariantValue::CPUToGPUCopyFn); + +REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, + DEVICE_CPU, VariantValue, + "TEST VariantValue", + VariantValue::CPUZerosLikeFn); + +REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, + DEVICE_GPU, VariantValue, + "TEST VariantValue", + VariantValue::GPUZerosLikeFn); + +REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, + VariantValue, "TEST VariantValue", + VariantValue::CPUAddFn); + +REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU, + VariantValue, "TEST VariantValue", + VariantValue::GPUAddFn); + +} // namespace + 
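+// The registrations above wire VariantValue into the global
+// UnaryVariantOpRegistry.  As a rough sketch, a user-defined Variant type
+// typically provides the same pieces (the MyVariant/MyShapeFn names here are
+// purely illustrative, not part of the API above):
+//
+//   struct MyVariant {
+//     string TypeName() const { return "MyVariant"; }
+//     void Encode(VariantTensorData* data) const;
+//     bool Decode(const VariantTensorData& data);
+//   };
+//   REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MyVariant, "MyVariant");
+//   REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(MyVariant, "MyVariant", MyShapeFn);
+//
+// The tests below look the registered functions up by type_name and
+// exercise them.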
+TEST(VariantOpShapeRegistryTest, TestBasic) { + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"), + nullptr); + + auto* shape_fn = + UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue"); + EXPECT_NE(shape_fn, nullptr); + TensorShape shape; + + VariantValue vv_early_exit{true /* early_exit */}; + Variant v = vv_early_exit; + Status s0 = (*shape_fn)(v, &shape); + EXPECT_FALSE(s0.ok()); + EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit!")); + + VariantValue vv_ok{false /* early_exit */}; + v = vv_ok; + TF_EXPECT_OK((*shape_fn)(v, &shape)); + EXPECT_EQ(shape, TensorShape({-0xdeadbeef})); +} + +TEST(VariantOpShapeRegistryTest, TestDuplicate) { + UnaryVariantOpRegistry registry; + UnaryVariantOpRegistry::VariantShapeFn f; + string kTypeName = "fjfjfj"; + registry.RegisterShapeFn(kTypeName, f); + EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f), + "fjfjfj already registered"); +} + +TEST(VariantOpDecodeRegistryTest, TestBasic) { + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDecodeFn("YOU SHALL NOT PASS"), + nullptr); + + auto* decode_fn = + UnaryVariantOpRegistry::Global()->GetDecodeFn("TEST VariantValue"); + EXPECT_NE(decode_fn, nullptr); + + VariantValue vv{true /* early_exit */}; + Variant v = vv; + VariantTensorData data; + v.Encode(&data); + VariantTensorDataProto proto; + data.ToProto(&proto); + Variant encoded = proto; + EXPECT_TRUE((*decode_fn)(&encoded)); + VariantValue* decoded = encoded.get(); + EXPECT_NE(decoded, nullptr); + EXPECT_EQ(decoded->early_exit, true); +} + +TEST(VariantOpDecodeRegistryTest, TestDuplicate) { + UnaryVariantOpRegistry registry; + UnaryVariantOpRegistry::VariantDecodeFn f; + string kTypeName = "fjfjfj"; + registry.RegisterDecodeFn(kTypeName, f); + EXPECT_DEATH(registry.RegisterDecodeFn(kTypeName, f), + "fjfjfj already registered"); +} + +TEST(VariantOpCopyToGPURegistryTest, TestBasic) { + // No registered copy fn for GPU<->GPU. 
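+  // (Only a HOST_TO_DEVICE copy function was registered for VariantValue
+  // above, so the DEVICE_TO_DEVICE lookup should come back as nullptr.)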
+ EXPECT_EQ( + UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( + VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"), + nullptr); + + auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( + VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue"); + EXPECT_NE(copy_to_gpu_fn, nullptr); + + VariantValue vv{true /* early_exit */}; + Variant v = vv; + Variant v_out; + bool dummy_executed = false; + auto dummy_copy_fn = [&dummy_executed](const Tensor& from, + Tensor* to) -> Status { + dummy_executed = true; + return Status::OK(); + }; + TF_EXPECT_OK((*copy_to_gpu_fn)(v, &v_out, dummy_copy_fn)); + EXPECT_TRUE(dummy_executed); + VariantValue* copied_value = v_out.get(); + EXPECT_NE(copied_value, nullptr); + EXPECT_EQ(copied_value->value, 0xdeadbeef); +} + +TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) { + UnaryVariantOpRegistry registry; + UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f; + string kTypeName = "fjfjfj"; + registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE, + kTypeName, f); + EXPECT_DEATH(registry.RegisterDeviceCopyFn( + VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f), + "fjfjfj already registered"); +} + +TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), + nullptr); + + VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; + Variant v = vv_early_exit; + Variant v_out = VariantValue(); + + OpKernelContext* null_context_pointer = nullptr; + Status s0 = UnaryOpVariant(null_context_pointer, + ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out); + EXPECT_FALSE(s0.ok()); + EXPECT_TRUE( + StringPiece(s0.error_message()).contains("early exit zeros_like")); + + VariantValue vv_ok{false /* early_exit */, 0 /* value */}; + v = vv_ok; + TF_EXPECT_OK(UnaryOpVariant( + null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out)); + VariantValue* vv_out = CHECK_NOTNULL(v_out.get()); + EXPECT_EQ(vv_out->value, 1); // CPU +} + +#if GOOGLE_CUDA +TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), + nullptr); + + VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; + Variant v = vv_early_exit; + Variant v_out = VariantValue(); + + OpKernelContext* null_context_pointer = nullptr; + Status s0 = UnaryOpVariant(null_context_pointer, + ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out); + EXPECT_FALSE(s0.ok()); + EXPECT_TRUE( + StringPiece(s0.error_message()).contains("early exit zeros_like")); + + VariantValue vv_ok{false /* early_exit */, 0 /* value */}; + v = vv_ok; + TF_EXPECT_OK(UnaryOpVariant( + null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out)); + VariantValue* vv_out = CHECK_NOTNULL(v_out.get()); + EXPECT_EQ(vv_out->value, 2); // GPU +} +#endif // GOOGLE_CUDA + +TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) { + UnaryVariantOpRegistry registry; + UnaryVariantOpRegistry::VariantUnaryOpFn f; + string kTypeName = "fjfjfj"; + + registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName, + f); + EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, + DEVICE_CPU, kTypeName, f), + "fjfjfj already registered"); + + registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName, + f); + EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, + DEVICE_GPU, kTypeName, f), + "fjfjfj already 
registered"); +} + +TEST(VariantOpAddRegistryTest, TestBasicCPU) { + return; + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( + ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), + nullptr); + + VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; + VariantValue vv_other{true /* early_exit */, 4 /* value */}; + Variant v_a = vv_early_exit; + Variant v_b = vv_other; + Variant v_out = VariantValue(); + + OpKernelContext* null_context_pointer = nullptr; + Status s0 = BinaryOpVariants( + null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out); + EXPECT_FALSE(s0.ok()); + EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add")); + + VariantValue vv_ok{false /* early_exit */, 3 /* value */}; + v_a = vv_ok; + TF_EXPECT_OK(BinaryOpVariants( + null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out)); + VariantValue* vv_out = CHECK_NOTNULL(v_out.get()); + EXPECT_EQ(vv_out->value, 7); // CPU +} + +#if GOOGLE_CUDA +TEST(VariantOpAddRegistryTest, TestBasicGPU) { + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( + ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), + nullptr); + + VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; + VariantValue vv_other{true /* early_exit */, 4 /* value */}; + Variant v_a = vv_early_exit; + Variant v_b = vv_other; + Variant v_out = VariantValue(); + + OpKernelContext* null_context_pointer = nullptr; + Status s0 = BinaryOpVariants( + null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out); + EXPECT_FALSE(s0.ok()); + EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add")); + + VariantValue vv_ok{false /* early_exit */, 3 /* value */}; + v_a = vv_ok; + TF_EXPECT_OK(BinaryOpVariants( + null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out)); + VariantValue* vv_out = CHECK_NOTNULL(v_out.get()); + EXPECT_EQ(vv_out->value, -7); // GPU +} +#endif // GOOGLE_CUDA + +TEST(VariantOpAddRegistryTest, TestDuplicate) { + UnaryVariantOpRegistry registry; + UnaryVariantOpRegistry::VariantBinaryOpFn f; + string kTypeName = "fjfjfj"; + + registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f); + EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, + kTypeName, f), + "fjfjfj already registered"); + + registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f); + EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, + kTypeName, f), + "fjfjfj already registered"); +} + +} // namespace tensorflow diff --git a/variant_tensor_data.cc b/variant_tensor_data.cc new file mode 100644 index 0000000000000000000000000000000000000000..82479193d2a3464897b0fff6c8feaf6c487a23c4 --- /dev/null +++ b/variant_tensor_data.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +VariantTensorData::VariantTensorData() {} + +VariantTensorData::VariantTensorData(const VariantTensorDataProto& proto) { + FromProto(proto); +} + +VariantTensorData::~VariantTensorData() {} + +int VariantTensorData::tensors_size() const { return tensors_.size(); } + +const Tensor& VariantTensorData::tensors(int index) const { + return tensors_[index]; +} + +std::vector VariantTensorData::tensors() { return tensors_; } + +Tensor* VariantTensorData::add_tensors() { + tensors_.emplace_back(); + return &(tensors_[tensors_.size() - 1]); +} + +void VariantTensorData::ToProto(VariantTensorDataProto* proto) const { + proto->set_type_name(type_name()); + proto->set_metadata(metadata_); + proto->clear_tensors(); + for (const auto& tensor : tensors_) { + tensor.AsProtoField(proto->mutable_tensors()->Add()); + } +} + +bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) { + set_type_name(proto.type_name()); + set_metadata(proto.metadata()); + for (const auto& tensor : proto.tensors()) { + Tensor tmp; + if (!tmp.FromProto(tensor)) return false; + tensors_.push_back(tmp); + } + return true; +} + +string VariantTensorData::SerializeAsString() const { + VariantTensorDataProto proto; + ToProto(&proto); + return proto.SerializeAsString(); +} + +bool VariantTensorData::SerializeToString(string* buf) { + VariantTensorDataProto proto; + ToProto(&proto); + return proto.SerializeToString(buf); +} + +bool VariantTensorData::ParseFromString(const string& s) { + VariantTensorDataProto proto; + const bool status = proto.ParseFromString(s); + if (status) FromProto(proto); + return status; +} + +string VariantTensorData::DebugString() const { + string repeated_field = ""; + for (const auto& t : tensors_) { + repeated_field = + strings::StrCat(repeated_field, " tensors: ", t.DebugString()); + } + return strings::StrCat("type_name: ", type_name(), " metadata: ", metadata_, + repeated_field); +} + +string ProtoDebugString(const VariantTensorData& object) { + return object.DebugString(); +} + +} // namespace tensorflow diff --git a/variant_tensor_data.h b/variant_tensor_data.h new file mode 100644 index 0000000000000000000000000000000000000000..6e04879494af447e620f6737bc749f68d9e1394d --- /dev/null +++ b/variant_tensor_data.h @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_VARIANT_TENSOR_DATA_H +#define TENSORFLOW_FRAMEWORK_VARIANT_TENSOR_DATA_H + +#include +#include + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class VariantTensorDataProto; +class Tensor; + +// The serialization format for Variant objects. Objects with references to +// other Tensors can simply store those tensors in the `tensors` field, and +// serialize other metadata content in to the `metadata` field. Objects can +// optionally set the `type_name` for type-checking before deserializing an +// object. +// +// This is the native C++ class equivalent of VariantTensorDataProto. They are +// separate so that kernels do not need to depend on protos. +class VariantTensorData { + public: + VariantTensorData(); + VariantTensorData(const VariantTensorDataProto& proto); + ~VariantTensorData(); + + // Name of the type of objects being serialized. + const string& type_name() const { return type_name_; } + void set_type_name(const string& type_name) { type_name_ = type_name; } + + template ::type>::value> + struct PODResolver {}; + + // Portions of the object that are not Tensors. + // Directly supported types include string POD types. + template + void set_metadata(const T& value) { + SetMetadata(value, PODResolver()); + } + + template + bool get_metadata(T* value) const { + return GetMetadata(value, PODResolver()); + } + + // Tensors contained within objects being serialized. + int tensors_size() const; + const Tensor& tensors(int index) const; + std::vector tensors(); + Tensor* add_tensors(); + + // Conversion to and from VariantTensorDataProto + void ToProto(VariantTensorDataProto* proto) const; + bool FromProto(const VariantTensorDataProto& proto); + + // Serialization via VariantTensorDataProto + string SerializeAsString() const; + bool SerializeToString(string* buf); + bool ParseFromString(const string& s); + + string DebugString() const; + + public: + string type_name_; + string metadata_; + std::vector tensors_; + + private: + template + void SetMetadata(const string& value, PODResolver) { + metadata_ = value; + } + + template + bool GetMetadata(string* value, PODResolver) const { + *value = metadata_; + return true; + } + + template + void SetMetadata(const T& value, PODResolver) { + metadata_.assign(reinterpret_cast(&value), sizeof(T)); + } + + template + bool GetMetadata(T* value, PODResolver) const { + if (metadata_.size() != sizeof(T)) return false; + std::copy_n(metadata_.data(), sizeof(T), reinterpret_cast(value)); + return true; + } +}; + +// For backwards compatibility for when this was a proto +string ProtoDebugString(const VariantTensorData& object); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_VARIANT_TENSOR_DATA_H diff --git a/variant_test.cc b/variant_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..eef5c47d15b65725d80d2c1e19a33dbf50f8aa93 --- /dev/null +++ b/variant_test.cc @@ -0,0 +1,284 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_tensor_data.h" + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +template +struct Wrapper { + T value; + string TypeName() const { return "POD"; } +}; + +using Int = Wrapper; +using Float = Wrapper; + +} // end namespace + +TEST(VariantTest, Int) { + Variant x; + EXPECT_EQ(x.get(), nullptr); + x = 3; + EXPECT_NE(x.get(), nullptr); + EXPECT_EQ(*x.get(), 3); + EXPECT_EQ(x.TypeName(), "int"); +} + +TEST(VariantTest, Basic) { + Variant x; + EXPECT_EQ(x.get(), nullptr); + + x = Int{42}; + + EXPECT_NE(x.get(), nullptr); + EXPECT_NE(x.get(), nullptr); + EXPECT_EQ(x.get()->value, 42); + EXPECT_EQ(x.TypeName(), "POD"); +} + +TEST(VariantTest, ConstGet) { + Variant x; + EXPECT_EQ(x.get(), nullptr); + + x = Int{42}; + + const Variant y = x; + + EXPECT_NE(y.get(), nullptr); + EXPECT_NE(y.get(), nullptr); + EXPECT_EQ(y.get()->value, 42); +} + +TEST(VariantTest, Clear) { + Variant x; + EXPECT_EQ(x.get(), nullptr); + + x = Int{42}; + + EXPECT_NE(x.get(), nullptr); + EXPECT_NE(x.get(), nullptr); + EXPECT_EQ(x.get()->value, 42); + + x.clear(); + EXPECT_EQ(x.get(), nullptr); +} + +TEST(VariantTest, Tensor) { + Variant x; + Tensor t(DT_FLOAT, {}); + t.flat()(0) = 42.0f; + x = t; + + EXPECT_NE(x.get(), nullptr); + EXPECT_EQ(x.get()->flat()(0), 42.0f); + x.get()->flat()(0) += 1.0f; + EXPECT_EQ(x.get()->flat()(0), 43.0f); + EXPECT_EQ(x.TypeName(), "tensorflow::Tensor"); +} + +TEST(VariantTest, TensorProto) { + Variant x; + TensorProto t; + t.set_dtype(DT_FLOAT); + t.mutable_tensor_shape()->set_unknown_rank(true); + x = t; + + EXPECT_EQ(x.TypeName(), "tensorflow.TensorProto"); + EXPECT_NE(x.get(), nullptr); + EXPECT_EQ(x.get()->dtype(), DT_FLOAT); + EXPECT_EQ(x.get()->tensor_shape().unknown_rank(), true); +} + +TEST(VariantTest, CopyValue) { + Variant x, y; + x = Int{10}; + y = x; + + EXPECT_EQ(x.get()->value, 10); + EXPECT_EQ(x.get()->value, y.get()->value); +} + +TEST(VariantTest, MoveValue) { + Variant x; + x = []() -> Variant { + Variant y; + y = Int{10}; + return y; + }(); + EXPECT_EQ(x.get()->value, 10); +} + +TEST(VariantTest, TypeMismatch) { + Variant x; + x = Int{10}; + EXPECT_EQ(x.get(), nullptr); + EXPECT_EQ(x.get(), nullptr); + EXPECT_NE(x.get(), nullptr); +} + +struct TensorList { + void Encode(VariantTensorData* data) const { data->tensors_ = vec; } + + bool Decode(const VariantTensorData& data) { + vec = data.tensors_; + return true; + } + + string TypeName() const { return "TensorList"; } + + std::vector vec; +}; + +TEST(VariantTest, TensorListTest) { + Variant x; + + TensorList vec; + for (int i = 0; i < 4; ++i) { + Tensor elem(DT_INT32, {1}); + elem.flat()(0) = i; + vec.vec.push_back(elem); + } + + for (int i = 0; i < 4; 
+    Tensor elem(DT_FLOAT, {1});
+    elem.flat<float>()(0) = 2 * i;
+    vec.vec.push_back(elem);
+  }
+
+  x = vec;
+
+  EXPECT_EQ(x.TypeName(), "TensorList");
+  EXPECT_EQ(x.DebugString(), "Variant<TensorList>");
+  const TensorList& stored_vec = *x.get<TensorList>();
+  for (int i = 0; i < 4; ++i) {
+    EXPECT_EQ(stored_vec.vec[i].flat<int>()(0), i);
+  }
+  for (int i = 0; i < 4; ++i) {
+    EXPECT_EQ(stored_vec.vec[i + 4].flat<float>()(0), 2 * i);
+  }
+
+  VariantTensorData serialized;
+  x.Encode(&serialized);
+
+  Variant y = TensorList();
+  y.Decode(serialized);
+
+  const TensorList& decoded_vec = *y.get<TensorList>();
+  for (int i = 0; i < 4; ++i) {
+    EXPECT_EQ(decoded_vec.vec[i].flat<int>()(0), i);
+  }
+  for (int i = 0; i < 4; ++i) {
+    EXPECT_EQ(decoded_vec.vec[i + 4].flat<float>()(0), 2 * i);
+  }
+
+  VariantTensorDataProto data;
+  serialized.ToProto(&data);
+  const Variant y_unknown = data;
+  EXPECT_EQ(y_unknown.TypeName(), "TensorList");
+  EXPECT_EQ(y_unknown.TypeId(), MakeTypeIndex<VariantTensorDataProto>());
+  EXPECT_EQ(y_unknown.DebugString(),
+            strings::StrCat("Variant<", y_unknown.TypeName(), ">"));
+
+  TensorList unknown_decoded_vec;
+  EXPECT_TRUE(y_unknown.MaybeDecodeAndCopy(&unknown_decoded_vec));
+  for (int i = 0; i < 4; ++i) {
+    EXPECT_EQ(unknown_decoded_vec.vec[i].flat<int>()(0), i);
+  }
+  for (int i = 0; i < 4; ++i) {
+    EXPECT_EQ(unknown_decoded_vec.vec[i + 4].flat<float>()(0), 2 * i);
+  }
+}
+
+TEST(VariantTest, VariantArray) {
+  Variant x[2];
+  x[0] = Int{2};
+  x[1] = Float{2.0f};
+
+  EXPECT_EQ(x[0].get<Int>()->value, 2);
+  EXPECT_EQ(x[1].get<Float>()->value, 2.0f);
+}
+
+TEST(VariantTest, PodUpdate) {
+  struct Pod {
+    int x;
+    float y;
+
+    string TypeName() const { return "POD"; }
+  };
+
+  Variant x = Pod{10, 20.f};
+  EXPECT_NE(x.get<Pod>(), nullptr);
+  EXPECT_EQ(x.TypeName(), "POD");
+  EXPECT_EQ(x.DebugString(), "Variant<POD>");
+
+  x.get<Pod>()->x += x.get<Pod>()->y;
+  EXPECT_EQ(x.get<Pod>()->x, 30);
+}
+
+TEST(VariantTest, EncodeDecodePod) {
+  struct Pod {
+    int x;
+    float y;
+
+    string TypeName() const { return "POD"; }
+  };
+
+  Variant x;
+  Pod p{10, 20.0f};
+  x = p;
+
+  VariantTensorData serialized;
+  x.Encode(&serialized);
+
+  Variant y;
+  y = Pod();
+  y.Decode(serialized);
+
+  EXPECT_EQ(p.x, y.get<Pod>()->x);
+  EXPECT_EQ(p.y, y.get<Pod>()->y);
+}
+
+TEST(VariantTest, EncodeDecodeTensor) {
+  Variant x;
+  Tensor t(DT_INT32, {});
+  t.flat<int>()(0) = 42;
+  x = t;
+
+  VariantTensorData serialized;
+  x.Encode(&serialized);
+
+  Variant y = Tensor();
+  y.Decode(serialized);
+  EXPECT_EQ(y.DebugString(),
+            "Variant<Tensor<type: int32 shape: [] values: 42>>");
+  EXPECT_EQ(x.get<Tensor>()->flat<int>()(0), y.get<Tensor>()->flat<int>()(0));
+}
+
+}  // end namespace tensorflow
diff --git a/versions.cc b/versions.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3ff0723ceec2576948f7e840ab0b45d2a741f215
--- /dev/null
+++ b/versions.cc
@@ -0,0 +1,56 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/versions.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+Status CheckVersions(const VersionDef& versions, int consumer, int min_producer,
+                     const char* upper_name, const char* lower_name) {
+  // Guard against the caller misordering the arguments.
+  if (consumer < min_producer) {
+    return errors::Internal(upper_name, " version check has consumer ",
+                            consumer, " < min_producer ", min_producer, ".");
+  }
+
+  // Check that the serialized data is compatible with this consumer.
+  if (versions.producer() < min_producer) {
+    return errors::InvalidArgument(
+        upper_name, " producer version ", versions.producer(),
+        " below min producer ", min_producer, " supported by TensorFlow ",
+        TF_VERSION_STRING, ". Please regenerate your ", lower_name, ".");
+  }
+  if (versions.min_consumer() > consumer) {
+    return errors::InvalidArgument(
+        upper_name, " min consumer version ", versions.min_consumer(),
+        " above current version ", consumer, " for TensorFlow ",
+        TF_VERSION_STRING, ". Please upgrade TensorFlow.");
+  }
+  for (const int bad_consumer : versions.bad_consumers()) {
+    if (bad_consumer == consumer) {
+      return errors::InvalidArgument(
+          upper_name, " disallows consumer version ", bad_consumer,
+          ". Please upgrade TensorFlow: this version is likely buggy.");
+    }
+  }
+
+  // All good!
+  return Status::OK();
+}
+
+}  // namespace tensorflow
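For context, here is a minimal call-site sketch of CheckVersions. It is illustrative only: the consumer/min_producer values (24 and 10) and the helper name VersionCheckExamples are assumptions for this sketch, not part of the patch above.

#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace {

// Hypothetical driver exercising CheckVersions with a few VersionDef values.
void VersionCheckExamples() {
  VersionDef versions;
  versions.set_producer(12);     // data written at producer version 12
  versions.set_min_consumer(0);  // any consumer version may read it

  // producer 12 >= min_producer 10, consumer 24 >= min_consumer 0, and 24 is
  // not listed in bad_consumers, so the check passes.
  CHECK(CheckVersions(versions, /*consumer=*/24, /*min_producer=*/10,
                      "GraphDef", "graph")
            .ok());

  // Data produced by a version older than this binary supports is rejected.
  versions.set_producer(7);
  CHECK(!CheckVersions(versions, 24, 10, "GraphDef", "graph").ok());

  // Data that explicitly disallows consumer version 24 is rejected.
  versions.set_producer(12);
  versions.add_bad_consumers(24);
  CHECK(!CheckVersions(versions, 24, 10, "GraphDef", "graph").ok());
}

}  // namespace
}  // namespace tensorflow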
diff --git a/versions.h b/versions.h
new file mode 100644
index 0000000000000000000000000000000000000000..9676667b8f6058b93535707eecac5cae223615fd
--- /dev/null
+++ b/versions.h
@@ -0,0 +1,39 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_VERSIONS_H_
+#define TENSORFLOW_FRAMEWORK_VERSIONS_H_
+
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class VersionDef;
+
+// Check whether data with the given versions is compatible with the given
+// consumer and min producer. upper_name and lower_name are used to form
+// error messages upon failure. Example usage:
+//
+//   #include "tensorflow/core/public/version.h"
+//
+//   TF_RETURN_IF_ERROR(CheckVersions(versions, TF_GRAPH_DEF_VERSION,
+//                                    TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
+//                                    "GraphDef", "graph"));
+Status CheckVersions(const VersionDef& versions, int consumer, int min_producer,
+                     const char* upper_name, const char* lower_name);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_VERSIONS_H_
diff --git a/versions.proto b/versions.proto
new file mode 100644
index 0000000000000000000000000000000000000000..7d5e58ae7d42307da310e1e878bbe00efc16b417
--- /dev/null
+++ b/versions.proto
@@ -0,0 +1,31 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "VersionsProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+// Version information for a piece of serialized data
+//
+// There are different types of versions for each type of data
+// (GraphDef, etc.), but they all have the same common shape
+// described here.
+//
+// Each consumer has "consumer" and "min_producer" versions (specified
+// elsewhere). A consumer is allowed to consume this data if
+//
+//   producer >= min_producer
+//   consumer >= min_consumer
+//   consumer not in bad_consumers
+//
+message VersionDef {
+  // The version of the code that produced this data.
+  int32 producer = 1;
+
+  // Any consumer below this version is not allowed to consume this data.
+  int32 min_consumer = 2;
+
+  // Specific consumer versions which are disallowed (e.g. due to bugs).
+  repeated int32 bad_consumers = 3;
+};
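As a worked example of the compatibility rule above (all numbers invented for illustration), suppose a serialized GraphDef carries:

  versions {
    producer: 22
    min_consumer: 12
    bad_consumers: 15
    bad_consumers: 18
  }

A TensorFlow build with consumer = 20 and min_producer = 17 may load it: producer (22) >= min_producer (17), consumer (20) >= min_consumer (12), and 20 is not in {15, 18}. A build with consumer = 15 is rejected by the bad_consumers list even though the other two checks pass, and a build with min_producer = 25 rejects the data because producer (22) < 25.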