/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
namespace tensorflow { | |
namespace { | |
string SummarizeString(const string& str) { | |
return strings::StrCat("\"", str_util::CEscape(str), "\""); | |
} | |
string SummarizeTensor(const TensorProto& tensor_proto) { | |
Tensor t; | |
if (!t.FromProto(tensor_proto)) { | |
return strings::StrCat( | |
"<Invalid TensorProto: ", ProtoShortDebugString(tensor_proto), ">"); | |
} | |
return t.DebugString(); | |
} | |
string SummarizeFunc(const NameAttrList& func) { | |
std::vector<string> entries; | |
for (auto p : func.attr()) { | |
entries.push_back( | |
strings::StrCat(p.first, "=", SummarizeAttrValue(p.second))); | |
} | |
std::sort(entries.begin(), entries.end()); | |
return strings::StrCat(func.name(), "[", str_util::Join(entries, ", "), "]"); | |
} | |
} // namespace | |
// Returns a human-readable, single-line summary of `attr_value`:
//   - strings are quoted and C-escaped;
//   - ints, floats and bools are printed directly;
//   - DataTypes are printed by enum name;
//   - shapes use PartialTensorShape::DebugString();
//   - tensors use SummarizeTensor() and functions use SummarizeFunc();
//   - placeholders are printed as "$name".
// For a list, the elements of the first non-empty repeated field are printed
// inside "[...]"; a list with no elements prints "[]".
string SummarizeAttrValue(const AttrValue& attr_value) {
  switch (attr_value.value_case()) {
    case AttrValue::kS:
      return SummarizeString(attr_value.s());
    case AttrValue::kI:
      return strings::StrCat(attr_value.i());
    case AttrValue::kF:
      return strings::StrCat(attr_value.f());
    case AttrValue::kB:
      return attr_value.b() ? "true" : "false";
    case AttrValue::kType:
      return EnumName_DataType(attr_value.type());
    case AttrValue::kShape:
      return PartialTensorShape::DebugString(attr_value.shape());
    case AttrValue::kTensor:
      return SummarizeTensor(attr_value.tensor());
    case AttrValue::kList: {
      string ret = "[";
      if (attr_value.list().s_size() > 0) {
        for (int i = 0; i < attr_value.list().s_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret, SummarizeString(attr_value.list().s(i)));
        }
      } else if (attr_value.list().i_size() > 0) {
        for (int i = 0; i < attr_value.list().i_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret, attr_value.list().i(i));
        }
      } else if (attr_value.list().f_size() > 0) {
        for (int i = 0; i < attr_value.list().f_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret, attr_value.list().f(i));
        }
      } else if (attr_value.list().b_size() > 0) {
        for (int i = 0; i < attr_value.list().b_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false");
        }
      } else if (attr_value.list().type_size() > 0) {
        for (int i = 0; i < attr_value.list().type_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret,
                             EnumName_DataType(attr_value.list().type(i)));
        }
      } else if (attr_value.list().shape_size() > 0) {
        for (int i = 0; i < attr_value.list().shape_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          // Render list shapes exactly like the scalar `shape` case above:
          // they are the same TensorShapeProto type, and may likewise carry
          // unknown dimensions or an unknown rank, which a fully-defined
          // TensorShape rendering does not represent.
          strings::StrAppend(
              &ret, PartialTensorShape::DebugString(attr_value.list().shape(i)));
        }
      } else if (attr_value.list().tensor_size() > 0) {
        for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret,
                             SummarizeTensor(attr_value.list().tensor(i)));
        }
      } else if (attr_value.list().func_size() > 0) {
        for (int i = 0; i < attr_value.list().func_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret, SummarizeFunc(attr_value.list().func(i)));
        }
      }
      strings::StrAppend(&ret, "]");
      return ret;
    }
    case AttrValue::kFunc: {
      return SummarizeFunc(attr_value.func());
    }
    case AttrValue::kPlaceholder:
      return strings::StrCat("$", attr_value.placeholder());
    case AttrValue::VALUE_NOT_SET:
      return "<Unknown AttrValue type>";
  }
  return "<Unknown AttrValue type>";  // Prevent missing return warning
}
// Checks that `attr_value` holds a value of the attr type named by `type`
// (for example "int", "shape" or "list(type)"). Returns an InvalidArgument
// error describing the first problem found, otherwise Status::OK().
//
// NOTE(review): VALIDATE_FIELD is a macro defined outside this view. It
// presumably increments `num_set` (nothing else here does before it is first
// read) and reads `attr_value` / `type` by name. Keep those three identifiers
// unchanged.
Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
  int num_set = 0;
  VALIDATE_FIELD(s, "string", kS);
  VALIDATE_FIELD(i, "int", kI);
  VALIDATE_FIELD(f, "float", kF);
  VALIDATE_FIELD(b, "bool", kB);
  VALIDATE_FIELD(type, "type", kType);
  VALIDATE_FIELD(shape, "shape", kShape);
  VALIDATE_FIELD(tensor, "tensor", kTensor);
  VALIDATE_FIELD(func, "func", kFunc);

  // A placeholder value is rejected regardless of the expected type.
  if (attr_value.value_case() == AttrValue::kPlaceholder) {
    return errors::InvalidArgument(
        "AttrValue had value with unexpected type 'placeholder'");
  }

  const bool expects_list = StringPiece(type).starts_with("list(");

  // For a 'list' attr we would expect attr_value.has_list() to be true, but
  // proto3 reports has_list() == false for an empty list in GraphDef
  // versions <= 4. So a missing list is only an error when some other field
  // is set; otherwise it is treated as an empty list. This can be made
  // stricter once GraphDef versions <= 4 are no longer supported.
  if (expects_list && !attr_value.has_list()) {
    if (num_set != 0) {
      return errors::InvalidArgument(
          "AttrValue missing value with expected type '", type, "'");
    }
    ++num_set;  // Count the (empty) list as the value that is present.
  }

  // An empty list is acceptable; a missing non-list value is not.
  if (!expects_list && num_set == 0) {
    return errors::InvalidArgument(
        "AttrValue missing value with expected type '", type, "'");
  }

  // DataTypes must be valid enum values, and neither reference types nor
  // DT_INVALID are allowed.
  if (type == "list(type)") {
    for (const int as_int : attr_value.list().type()) {
      const DataType dtype = static_cast<DataType>(as_int);
      if (!DataType_IsValid(dtype)) {
        return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                       as_int);
      }
      if (IsRefType(dtype)) {
        return errors::InvalidArgument(
            "AttrValue must not have reference type value of ",
            DataTypeString(dtype));
      }
      if (dtype == DT_INVALID) {
        return errors::InvalidArgument("AttrValue contains invalid DataType");
      }
    }
  } else if (type == "type") {
    const DataType dtype = attr_value.type();
    if (!DataType_IsValid(dtype)) {
      return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                     dtype);
    }
    if (IsRefType(dtype)) {
      return errors::InvalidArgument(
          "AttrValue must not have reference type value of ",
          DataTypeString(dtype));
    }
    if (dtype == DT_INVALID) {
      return errors::InvalidArgument("AttrValue has invalid DataType");
    }
  }
  return Status::OK();
}
// Parses `text` as a value of the attr type named by `type` (e.g. "int" or
// "list(shape)") into `*out`. Returns false if `type` is not recognized, if a
// list value is not written in "[...]" form, or if the text-proto parse fails.
bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
  // Attr type names and the AttrValue text-proto field each one is stored in.
  // Entries are tried in order; the first one that prefixes `type` is
  // consumed.
  static const struct {
    const char* type_name;
    const char* field;
  } kTypeFields[] = {
      {"string", "s"},
      {"int", "i"},
      {"float", "f"},
      {"bool", "b"},
      {"type", "type"},
      {"shape", "shape"},
      {"tensor", "tensor"},
      {"func", "func"},
      {"placeholder", "placeholder"},
  };

  const bool is_list = type.Consume("list(");
  const char* field_name = nullptr;
  for (const auto& entry : kTypeFields) {
    if (type.Consume(entry.type_name)) {
      field_name = entry.field;
      break;
    }
  }
  if (field_name == nullptr) return false;
  if (is_list && !type.Consume(")")) return false;

  if (!is_list) {
    return ProtoParseFromString(strings::StrCat(field_name, ": ", text), out);
  }

  // The TextFormat parser treats "i: 7" the same as "i: [7]", but list values
  // are only accepted when written with brackets.
  StringPiece cleaned = text;
  str_util::RemoveLeadingWhitespace(&cleaned);
  str_util::RemoveTrailingWhitespace(&cleaned);
  const bool bracketed = cleaned.size() >= 2 && cleaned[0] == '[' &&
                         cleaned[cleaned.size() - 1] == ']';
  if (!bracketed) return false;

  cleaned.remove_prefix(1);
  str_util::RemoveLeadingWhitespace(&cleaned);
  if (cleaned.size() == 1) {
    // Only the closing ']' is left, i.e. the user wrote "[]". TextFormat
    // rejects "i: []", so build the empty list directly.
    out->Clear();
    out->mutable_list();
    return true;
  }
  return ProtoParseFromString(
      strings::StrCat("list { ", field_name, ": ", text, " }"), out);
}
void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; } | |
// SetAttrValue() overloads stamped out by the DEFINE_SET_ATTR_VALUE_* macros,
// which are defined outside this view. Judging by the names, _ONE presumably
// defines the single-value overload, _LIST the list overload, and _BOTH both,
// each writing the AttrValue field named by the second argument -- TODO
// confirm against the macro definitions.
DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
// Both integer widths target field `i`; both floating-point widths target `f`.
DEFINE_SET_ATTR_VALUE_BOTH(int64, i)
DEFINE_SET_ATTR_VALUE_BOTH(int32, i)
DEFINE_SET_ATTR_VALUE_BOTH(float, f)
DEFINE_SET_ATTR_VALUE_BOTH(double, f)
DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
// NOTE(review): bool gets dedicated list overloads for std::vector<bool> and
// std::initializer_list<bool>, presumably because the bit-packed
// std::vector<bool> cannot be viewed as a gtl::ArraySlice<bool> -- confirm.
DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
// Stores the bytes of `value` in the string field `s` of `*out`.
void SetAttrValue(StringPiece value, AttrValue* out) {
  out->mutable_s()->assign(value.data(), value.size());
}
// Sets the `list` field of `*out` to hold exactly the strings in `value`
// (field `s`). The list is created even when `value` is empty.
void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
  auto* list = out->mutable_list();
  list->Clear();
  for (const StringPiece& piece : value) {
    list->add_s(piece.data(), piece.size());
  }
}
void SetAttrValue(const TensorShape& value, AttrValue* out) { | |
value.AsProto(out->mutable_shape()); | |
} | |
void SetAttrValue(const TensorShapeProto& value, AttrValue* out) { | |
*out->mutable_shape() = value; | |
} | |
void SetAttrValue(const PartialTensorShape& value, AttrValue* out) { | |
value.AsProto(out->mutable_shape()); | |
} | |
void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) { | |
out->mutable_list()->Clear(); // Create list() even if value empty. | |
for (const auto& v : value) { | |
v.AsProto(out->mutable_list()->add_shape()); | |
} | |
} | |
void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) { | |
out->mutable_list()->Clear(); // Create list() even if value empty. | |
for (const auto& v : value) { | |
*out->mutable_list()->add_shape() = v; | |
} | |
} | |
void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value, | |
AttrValue* out) { | |
out->mutable_list()->Clear(); // Create list() even if value empty. | |
for (const auto& v : value) { | |
v.AsProto(out->mutable_list()->add_shape()); | |
} | |
} | |
void SetAttrValue(const Tensor& value, AttrValue* out) { | |
if (value.NumElements() > 1) { | |
value.AsProtoTensorContent(out->mutable_tensor()); | |
} else { | |
value.AsProtoField(out->mutable_tensor()); | |
} | |
} | |
// Sets the `list` field of `*out` to hold exactly the tensors in `value`
// (field `tensor`). Each tensor is serialized like the single-Tensor overload:
// AsProtoTensorContent() above one element, AsProtoField() otherwise. The list
// is created even when `value` is empty.
void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
  auto* list = out->mutable_list();
  list->Clear();
  for (const Tensor& tensor : value) {
    auto* proto = list->add_tensor();
    if (tensor.NumElements() <= 1) {
      tensor.AsProtoField(proto);
    } else {
      tensor.AsProtoTensorContent(proto);
    }
  }
}
void SetAttrValue(const TensorProto& value, AttrValue* out) { | |
*out->mutable_tensor() = value; | |
} | |
void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) { | |
out->mutable_list()->Clear(); // Create list() even if value empty. | |
for (const auto& v : value) { | |
*out->mutable_list()->add_tensor() = v; | |
} | |
} | |
void SetAttrValue(const NameAttrList& value, AttrValue* out) { | |
*out->mutable_func() = value; | |
} | |
void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) { | |
out->mutable_list()->Clear(); // Create list() even if value empty. | |
for (const auto& v : value) { | |
*out->mutable_list()->add_func() = v; | |
} | |
} | |
bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) { | |
// There are multiple equivalent representations of attr values containing | |
// TensorProtos. Compare them by constructing Tensors and serializing them | |
// back. Comparing Tensor objects is pretty tricky. | |
if (a.has_tensor() != b.has_tensor()) { | |
return false; | |
} else if (a.has_tensor() && b.has_tensor()) { | |
Tensor at(a.tensor().dtype()); | |
bool success = at.FromProto(a.tensor()); | |
DCHECK(success); | |
Tensor bt(b.tensor().dtype()); | |
success = bt.FromProto(b.tensor()); | |
DCHECK(success); | |
TensorProto ap; | |
at.AsProtoTensorContent(&ap); | |
TensorProto bp; | |
bt.AsProtoTensorContent(&bp); | |
string a_str, b_str; | |
SerializeToStringDeterministic(ap, &a_str); | |
SerializeToStringDeterministic(bp, &b_str); | |
return a_str == b_str; | |
} | |
// `func` field contains a nested AttrValue. Compare such AttrValues | |
// recursively. | |
if (a.has_func() != b.has_func()) { | |
return false; | |
} else if (a.has_func() && b.has_func()) { | |
const NameAttrList& af = a.func(); | |
const NameAttrList& bf = b.func(); | |
if (af.name() != bf.name()) return false; | |
std::unordered_map<string, AttrValue> am(af.attr().begin(), | |
af.attr().end()); | |
for (const auto& bm_pair : bf.attr()) { | |
const auto& iter = am.find(bm_pair.first); | |
if (iter == am.end()) return false; | |
if (!AreAttrValuesEqual(iter->second, bm_pair.second)) return false; | |
am.erase(iter); | |
} | |
if (!am.empty()) return false; | |
return true; | |
} | |
// All other fields in AttrValue have deterministic representations. | |
// It is safe to compare their serialized strings. | |
string a_str, b_str; | |
SerializeToStringDeterministic(a, &a_str); | |
SerializeToStringDeterministic(b, &b_str); | |
return a_str == b_str; | |
} | |
// Returns a 64-bit hash of `a`.
//
// - Tensor values: the TensorProto is parsed and re-serialized first, so
//   different encodings of the same tensor hash alike.
// - Function values: the name is hashed, then each attr (key and recursive
//   value hash) in sorted key order.
// - Everything else: the deterministic serialization is hashed.
uint64 AttrValueHash(const AttrValue& a) {
  if (a.has_tensor()) {
    // Deal with multiple representations by parsing TensorProto to
    // Tensor and serializing it back. This is slow, but current use cases
    // don't need high efficiency.
    Tensor tensor(a.tensor().dtype());
    bool success = tensor.FromProto(a.tensor());
    DCHECK(success);
    TensorProto p;
    tensor.AsProtoTensorContent(&p);
    string s;
    SerializeToStringDeterministic(p, &s);
    return Hash64(s);
  }
  if (a.has_func()) {
    const NameAttrList& func = a.func();
    uint64 h = Hash64(func.name());
    // Visit attrs in sorted key order so the hash is independent of the attr
    // map's iteration order. Keep pointers rather than copies: a copied
    // AttrValue would deep-copy any nested tensors just to be hashed.
    std::map<string, const AttrValue*> sorted_attrs;
    for (const auto& entry : func.attr()) {
      sorted_attrs.emplace(entry.first, &entry.second);
    }
    for (const auto& pair : sorted_attrs) {
      h = Hash64(pair.first.data(), pair.first.size(), h);
      h = Hash64Combine(AttrValueHash(*pair.second), h);
    }
    return h;
  }
  // If `a` is not a tensor or func, get a hash of serialized string.
  string s;
  SerializeToStringDeterministic(a, &s);
  return Hash64(s);
}
bool HasPlaceHolder(const AttrValue& val) { | |
switch (val.value_case()) { | |
case AttrValue::kList: { | |
for (const NameAttrList& func : val.list().func()) { | |
for (const auto& p : func.attr()) { | |
if (HasPlaceHolder(p.second)) { | |
return true; | |
} | |
} | |
} | |
break; | |
} | |
case AttrValue::kFunc: | |
for (const auto& p : val.func().attr()) { | |
if (HasPlaceHolder(p.second)) { | |
return true; | |
} | |
} | |
break; | |
case AttrValue::kPlaceholder: | |
return true; | |
default: | |
break; | |
} | |
return false; | |
} | |
bool SubstitutePlaceholders(const SubstituteFunc& substitute, | |
AttrValue* value) { | |
switch (value->value_case()) { | |
case AttrValue::kList: { | |
for (NameAttrList& func : *value->mutable_list()->mutable_func()) { | |
for (auto& p : *func.mutable_attr()) { | |
if (!SubstitutePlaceholders(substitute, &p.second)) { | |
return false; | |
} | |
} | |
} | |
break; | |
} | |
case AttrValue::kFunc: | |
for (auto& p : *(value->mutable_func()->mutable_attr())) { | |
if (!SubstitutePlaceholders(substitute, &p.second)) { | |
return false; | |
} | |
} | |
break; | |
case AttrValue::kPlaceholder: | |
return substitute(value->placeholder(), value); | |
case AttrValue::VALUE_NOT_SET: | |
return false; | |
default: | |
break; | |
} | |
return true; | |
} | |
} // namespace tensorflow | |