Spaces:
Build error
Build error
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
==============================================================================*/ | |
namespace tensorflow { | |
namespace { | |
class FakeInputImpl { | |
public: | |
FakeInputImpl(const OpDef* op_def, int in_index, const NodeDef* node_def, | |
NodeDefBuilder* builder); | |
void SetN(int n); | |
void SetDataType(DataType dt); | |
void SetTypeList(DataTypeSlice dts); | |
Status AddInputToBuilder(); | |
private: | |
static string FakeNodeName(int in_index); | |
Status GetN(int* n) const; | |
Status GetDataType(DataType* dt) const; | |
void NSources(int n, DataType dt) const; | |
void SourceList(DataTypeSlice dts) const; | |
const OpDef* const op_def_; | |
const OpDef::ArgDef* const arg_; | |
const string in_node_; | |
const NodeDef* const node_def_; | |
NodeDefBuilder* const builder_; | |
bool n_specified_; | |
int n_; | |
bool dt_specified_; | |
DataType dt_; | |
bool dts_specified_; | |
DataTypeSlice dts_; | |
}; | |
FakeInputImpl::FakeInputImpl(const OpDef* op_def, int in_index, | |
const NodeDef* node_def, NodeDefBuilder* builder) | |
: op_def_(op_def), | |
arg_(&op_def->input_arg(in_index)), | |
in_node_(FakeNodeName(in_index)), | |
node_def_(node_def), | |
builder_(builder), | |
n_specified_(false), | |
dt_specified_(false), | |
dts_specified_(false) {} | |
void FakeInputImpl::SetN(int n) { | |
n_specified_ = true; | |
n_ = n; | |
} | |
void FakeInputImpl::SetDataType(DataType dt) { | |
dt_specified_ = true; | |
dt_ = dt; | |
} | |
void FakeInputImpl::SetTypeList(DataTypeSlice dts) { | |
dts_specified_ = true; | |
dts_ = dts; | |
} | |
Status FakeInputImpl::AddInputToBuilder() { | |
if (dts_specified_) { | |
SourceList(dts_); | |
} else if (n_specified_ || !arg_->number_attr().empty()) { | |
int n; | |
TF_RETURN_IF_ERROR(GetN(&n)); | |
DataType dt; | |
if (n > 0) { | |
TF_RETURN_IF_ERROR(GetDataType(&dt)); | |
} else { | |
dt = DT_FLOAT; | |
} | |
NSources(n, dt); | |
} else { | |
if (!dt_specified_ && !arg_->type_list_attr().empty()) { | |
DataTypeVector dts; | |
Status status = GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts); | |
if (!status.ok()) { | |
return errors::InvalidArgument( | |
"Could not infer list of types for input '", arg_->name(), "': ", | |
status.error_message()); | |
} | |
SourceList(dts); | |
return Status::OK(); | |
} | |
DataType dt; | |
TF_RETURN_IF_ERROR(GetDataType(&dt)); | |
builder_->Input(in_node_, 0, dt); | |
} | |
return Status::OK(); | |
} | |
// static | |
string FakeInputImpl::FakeNodeName(int in_index) { | |
char c = 'a' + (in_index % 26); | |
return string(&c, 1); | |
} | |
Status FakeInputImpl::GetN(int* n) const { | |
if (n_specified_) { | |
*n = n_; | |
} else { | |
Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n); | |
if (!status.ok()) { | |
return errors::InvalidArgument("Could not infer length of input '", | |
arg_->name(), "': ", | |
status.error_message()); | |
} | |
} | |
return Status::OK(); | |
} | |
Status FakeInputImpl::GetDataType(DataType* dt) const { | |
if (dt_specified_) { | |
*dt = dt_; | |
return Status::OK(); // Ignore is_ref field of arg_. | |
} else if (arg_->type() != DT_INVALID) { | |
*dt = arg_->type(); | |
} else if (!arg_->type_attr().empty()) { | |
Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt); | |
if (!status.ok()) { | |
// Check if the type attr has a default | |
const OpDef::AttrDef* attr = FindAttr(arg_->type_attr(), *op_def_); | |
if (attr && attr->has_default_value()) { | |
*dt = attr->default_value().type(); | |
} else { | |
return errors::InvalidArgument("Could not infer type for input '", | |
arg_->name(), "': ", | |
status.error_message()); | |
} | |
} | |
} else { | |
return errors::InvalidArgument("No type or type_attr field in arg '", | |
arg_->name(), "'"); | |
} | |
if (arg_->is_ref()) { | |
*dt = MakeRefType(*dt); | |
} | |
return Status::OK(); | |
} | |
void FakeInputImpl::NSources(int n, DataType dt) const { | |
std::vector<NodeDefBuilder::NodeOut> srcs; | |
srcs.reserve(n); | |
for (int i = 0; i < n; ++i) { | |
srcs.emplace_back(in_node_, i, dt); | |
} | |
builder_->Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs)); | |
} | |
void FakeInputImpl::SourceList(DataTypeSlice dts) const { | |
std::vector<NodeDefBuilder::NodeOut> srcs; | |
srcs.reserve(dts.size()); | |
for (size_t i = 0; i < dts.size(); ++i) { | |
srcs.emplace_back(in_node_, i, dts[i]); | |
} | |
builder_->Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs)); | |
} | |
} // namespace | |
// Public interface ------------------------------------------------------------ | |
FakeInputFunctor FakeInput() { | |
return [](const OpDef& op_def, int in_index, const NodeDef& node_def, | |
NodeDefBuilder* builder) { | |
FakeInputImpl impl(&op_def, in_index, &node_def, builder); | |
return impl.AddInputToBuilder(); | |
}; | |
} | |
FakeInputFunctor FakeInput(DataType dt) { | |
return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def, | |
NodeDefBuilder* builder) { | |
FakeInputImpl impl(&op_def, in_index, &node_def, builder); | |
impl.SetDataType(dt); | |
return impl.AddInputToBuilder(); | |
}; | |
} | |
FakeInputFunctor FakeInput(int n) { | |
return [n](const OpDef& op_def, int in_index, const NodeDef& node_def, | |
NodeDefBuilder* builder) { | |
FakeInputImpl impl(&op_def, in_index, &node_def, builder); | |
impl.SetN(n); | |
return impl.AddInputToBuilder(); | |
}; | |
} | |
FakeInputFunctor FakeInput(int n, DataType dt) { | |
return [n, dt](const OpDef& op_def, int in_index, const NodeDef& node_def, | |
NodeDefBuilder* builder) { | |
FakeInputImpl impl(&op_def, in_index, &node_def, builder); | |
impl.SetN(n); | |
impl.SetDataType(dt); | |
return impl.AddInputToBuilder(); | |
}; | |
} | |
FakeInputFunctor FakeInput(DataTypeSlice dts) { | |
// Make a copy to ensure the data will still be around when the lambda is | |
// called. | |
DataTypeVector dtv(dts.begin(), dts.end()); | |
return [dtv](const OpDef& op_def, int in_index, const NodeDef& node_def, | |
NodeDefBuilder* builder) { | |
FakeInputImpl impl(&op_def, in_index, &node_def, builder); | |
impl.SetTypeList(dtv); | |
return impl.AddInputToBuilder(); | |
}; | |
} | |
} // namespace tensorflow | |