#!/usr/bin/env python
# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Code generator for proto descriptors used for on-device model execution.

This script generates a C++ source file containing the proto descriptors.
"""
from __future__ import annotations

import dataclasses
import functools
from io import StringIO
import optparse
import os
import collections
import re
import sys

_HERE_PATH = os.path.dirname(__file__)
_SRC_PATH = os.path.normpath(os.path.join(_HERE_PATH, '..', '..', '..'))
# Make the protobuf python package under third_party importable.
sys.path.insert(0, os.path.join(_SRC_PATH, 'third_party', 'protobuf', 'python'))

from google.protobuf import descriptor_pb2


class Error(Exception):
  """Base error raised by this generator."""
  pass


class Type:
  """Aliases for FieldDescriptorProto::Type(s)."""
  DOUBLE = 1
  FLOAT = 2
  INT64 = 3
  UINT64 = 4
  INT32 = 5
  FIXED64 = 6
  FIXED32 = 7
  BOOL = 8
  STRING = 9
  GROUP = 10
  MESSAGE = 11
  BYTES = 12
  UINT32 = 13
  ENUM = 14
  SFIXED32 = 15
  SFIXED64 = 16
  SINT32 = 17
  SINT64 = 18


@dataclasses.dataclass(frozen=True)
class BaseValueType:
  """Describes how one kind of base::Value payload is read in emitted C++."""
  # C++ type used to declare the local that receives the getter's result.
  cpptype: str
  # Suffix appended to `GetIf` to form the accessor name (e.g. `GetIfBool`).
  getIfFn: str


class VType:
  """Base::Value types."""
  # The optional-returning getters are spelled with explicit template
  # arguments so the emitted declaration does not rely on class template
  # argument deduction.
  # NOTE(review): argument types follow base::Value::GetIf{Double,Bool,Int} -
  # confirm against base/values.h.
  DOUBLE = BaseValueType("std::optional<double>", "Double")
  BOOL = BaseValueType("std::optional<bool>", "Bool")
  INT = BaseValueType("std::optional<int>", "Int")
  STRING = BaseValueType("std::string*", "String")
  BLOB = BaseValueType("BlobStorage*", "Blob")
  DICT = BaseValueType("Dict*", "Dict")
  LIST = BaseValueType("List*", "List")


# Maps each proto field type to the base::Value representation that is used
# when converting a base::Value into that field.
BASE_VALUE_TYPES = {
    Type.DOUBLE: VType.DOUBLE,
    Type.FLOAT: VType.DOUBLE,
    Type.INT64: VType.INT,
    Type.UINT64: VType.INT,
    Type.INT32: VType.INT,
    Type.FIXED64: VType.INT,
    Type.FIXED32: VType.INT,
    Type.BOOL: VType.BOOL,
    Type.STRING: VType.STRING,
    Type.GROUP: VType.STRING,  # Not handled
    Type.MESSAGE: VType.DICT,  # Not handled
    Type.BYTES: VType.STRING,  # Not handled
    Type.UINT32: VType.INT,
    Type.ENUM: VType.INT,  # Not handled
    Type.SFIXED32: VType.INT,
    Type.SFIXED64: VType.INT,
    Type.SINT32: VType.INT,
    Type.SINT64: VType.INT,
}

@dataclasses.dataclass(frozen=True)
class Message:
  """A message descriptor plus the package and enclosing messages it is in."""
  desc: descriptor_pb2.DescriptorProto
  package: str
  # Names of the enclosing messages, outermost first; empty for top level.
  parent_names: tuple[str, ...] = ()

  @functools.cached_property
  def type_name(self) -> str:
    """Returns the value returned for MessageLite::GetTypeName()."""
    return '.'.join((self.package, *self.parent_names, self.desc.name))

  @functools.cached_property
  def cpp_name(self) -> str:
    """Returns the fully qualified c++ type name."""
    namespace = self.package.replace('.', '::')
    classname = '_'.join((*self.parent_names, self.desc.name))
    return f'{namespace}::{classname}'

  @functools.cached_property
  def iname(self) -> str:
    """Returns the identifier piece for generated function names."""
    return '_' + self.type_name.replace('.', '_')

  @functools.cached_property
  def fields(self):
    """Returns the message's fields wrapped as a tuple of `Field`."""
    return tuple(Field(fdesc) for fdesc in self.desc.field)


@dataclasses.dataclass(frozen=True)
class Field:
  """Thin read-only wrapper around a field descriptor."""
  desc: descriptor_pb2.FieldDescriptorProto

  # FieldDescriptorProto label value that marks a repeated field.
  _LABEL_REPEATED = 3

  @property
  def tag_number(self):
    """The field's tag number."""
    return self.desc.number

  @property
  def name(self):
    """The field's name as declared in the descriptor."""
    return self.desc.name

  @property
  def type(self):
    """The field's type, comparable against the `Type` aliases."""
    return self.desc.type

  @property
  def is_repeated(self):
    """Whether the field carries the repeated label."""
    return self.desc.label == self._LABEL_REPEATED

  @property
  def typename(self):
    """The field's type name with '.' replaced by '_'.

    For a message-typed field this equals the `Message.iname` of its type, so
    it can be appended to `Convert` to name the matching converter.
    """
    return self.desc.type_name.replace('.', '_')


@dataclasses.dataclass()
class KnownMessages:
  """Registry of messages keyed by '.'-prefixed fully qualified type name."""
  _known: dict[str, Message] = dataclasses.field(default_factory=dict)

  def _AddMessage(self, msg: Message) -> None:
    """Registers `msg` and, recursively, every message nested inside it."""
    self._known['.' + msg.type_name] = msg
    nested_parents = (*msg.parent_names, msg.desc.name)
    for nested_type in msg.desc.nested_type:
      self._AddMessage(
          Message(desc=nested_type,
                  package=msg.package,
                  parent_names=nested_parents))

  def AddFileDescriptorSet(self,
                           fds: descriptor_pb2.FileDescriptorSet) -> None:
    """Registers every message declared by the files in `fds`."""
    for f in fds.file:
      for m in f.message_type:
        self._AddMessage(Message(desc=m, package=f.package))

  def GetMessages(self, message_types: set[str]) -> list[Message]:
    """Returns the messages for `message_types`, ordered by type name.

    Raises:
      KeyError: if a requested type name was never registered.
    """
    return [self._known[t] for t in sorted(message_types)]

  def GetAllTransitiveDeps(self, message_types: set[str]) -> list[Message]:
    """Returns `message_types` plus every message reachable via their fields.

    The caller's set is left untouched; the traversal works on a copy.

    Raises:
      KeyError: if a visited type name was never registered.
    """
    seen = set(message_types)
    stack = list(seen)
    while stack:
      msg = self._known[stack.pop()]
      field_types = {
          field.desc.type_name
          for field in msg.fields if field.type == Type.MESSAGE
      }
      # Only schedule types that have not been visited or queued yet.
      stack.extend(field_types - seen)
      seen.update(field_types)
    return self.GetMessages(seen)


def GenerateProtoDescriptors(out, includes: set[str], messages: KnownMessages,
                             requests: set[str], responses: set[str]):
  """Generate the on_device_model_execution_proto_descriptors.cc content.

  Args:
    out: text sink with a `write(str)` method that receives the C++ source.
    includes: include operands (already quoted or bracketed) to emit; not
      modified by this function.
    messages: registry used to resolve the type names below.
    requests: '.'-prefixed type names that only need to be read.
    responses: '.'-prefixed type names that need to be read and written.
  """
  readable_messages = messages.GetAllTransitiveDeps(requests | responses)
  writable_messages = messages.GetAllTransitiveDeps(responses)

  out.write(
      '// DO NOT MODIFY. GENERATED BY gen_on_device_proto_descriptors.py\n')
  out.write('\n')

  out.write(
      '#include "components/optimization_guide/core/model_execution/on_device_model_execution_proto_descriptors.h"\n'  # pylint: disable=line-too-long
      '#include "components/optimization_guide/core/optimization_guide_util.h"\n'  # pylint: disable=line-too-long
  )
  out.write('\n')

  # base/values.h is always emitted; build a new set so the caller's
  # `includes` is not changed as a side effect.
  for include in sorted(includes | {'"base/values.h"'}):
    out.write(f'#include {include}\n')
  out.write('\n')

  out.write('namespace optimization_guide {\n')
  out.write('\n')
  out.write('namespace {\n')
  _GetProtoValue.GenPrivate(out, readable_messages)
  _GetProtoRepeated.GenPrivate(out, readable_messages)
  _SetProtoValue.GenPrivate(out, writable_messages)
  _ConvertValue.GenPrivate(out, writable_messages)
  out.write('} // namespace\n\n')
  _GetProtoValue.GenPublic(out)
  _GetProtoRepeated.GenPublic(out)
  _GetProtoFromAny.GenPublic(out, readable_messages)
  _SetProtoValue.GenPublic(out)
  _NestedMessageIteratorGet.GenPublic(out, readable_messages)
  _ConvertValue.GenPublic(out, writable_messages)
  out.write("""\
      NestedMessageIterator::NestedMessageIterator(
          const google::protobuf::MessageLite* parent,
          int32_t tag_number,
          int32_t field_size,
          int32_t offset) :
          parent_(parent),
          tag_number_(tag_number),
          field_size_(field_size),
          offset_(offset) {}
      """)
  out.write('} // namespace optimization_guide\n')
  out.write('\n')


class _GetProtoValue:
  """Namespace class for GetProtoValue method builders."""

  @classmethod
  def GenPublic(cls, out):
    """Emits the public overload, which forwards with index 0."""
    out.write("""
      std::optional<proto::Value> GetProtoValue(
          const google::protobuf::MessageLite& msg,
          const proto::ProtoField& proto_field) {
        return GetProtoValue(msg, proto_field, /*index=*/0);
      }
      """)

  @classmethod
  def GenPrivate(cls, out, messages: list[Message]):
    """Emits the indexed overload with one `if` block per message."""
    out.write("""
      std::optional<proto::Value> GetProtoValue(
          const google::protobuf::MessageLite& msg,
          const proto::ProtoField& proto_field, int32_t index) {
        if (index >= proto_field.proto_descriptors_size()) {
          return std::nullopt;
        }
        int32_t tag_number =
            proto_field.proto_descriptors(index).tag_number();
      """)

    for msg in messages:
      cls._IfMsg(out, msg)
    out.write('return std::nullopt;\n')
    out.write('}\n\n')  # End function

  @classmethod
  def _IfMsg(cls, out, msg: Message):
    """Emits the type-name check and tag switch for one message."""
    if all(field.is_repeated for field in msg.fields):
      # Omit the empty case to avoid unused variable warnings.
      return
    out.write(f'if (msg.GetTypeName() == "{msg.type_name}") {{\n')
    out.write(f'const {msg.cpp_name}& casted_msg = ')
    # The cast target mirrors the declared type of `casted_msg`.
    out.write(f' static_cast<const {msg.cpp_name}&>(msg);\n')
    out.write('switch (tag_number) {\n')
    for field in msg.fields:
      if field.is_repeated:
        continue
      cls._FieldCase(out, field)
    out.write('}\n')  # End switch
    out.write('}\n\n')  # End if statement

  @classmethod
  def _FieldCase(cls, out, field: Field):
    """Emits the `case` that reads one singular field.

    Raises:
      Error: if the field's type has no proto::Value representation here.
    """
    out.write(f'case {field.tag_number}: {{\n')
    name = f'casted_msg.{field.name}()'
    if field.type == Type.MESSAGE:
      # Descend into the sub-message using the next descriptor in the path.
      out.write(f'return GetProtoValue({name}, proto_field, index+1);\n')
    else:
      out.write('proto::Value value;\n')
      if field.type in {Type.DOUBLE, Type.FLOAT}:
        # Cast to double so no precision is dropped before the setter's own
        # parameter conversion.
        # NOTE(review): the original cast target is not visible in this
        # file - confirm `double` against upstream.
        out.write(
            f'value.set_float_value(static_cast<double>({name}));\n')
      elif field.type in {Type.INT64, Type.UINT64}:
        out.write(
            f'value.set_int64_value(static_cast<int64_t>({name}));\n')
      elif field.type in {Type.INT32, Type.UINT32, Type.ENUM}:
        out.write(
            f'value.set_int32_value(static_cast<int32_t>({name}));\n')
      elif field.type == Type.BOOL:
        out.write(f'value.set_boolean_value({name});\n')
      elif field.type == Type.STRING:
        out.write(f'value.set_string_value({name});\n')
      else:
        raise Error(
            f'GetProtoValue: unsupported type {field.type} for field '
            f'"{field.name}"')
      out.write('return value;\n')
    out.write('}\n')  # End case


class _GetProtoFromAny:
  """Namespace class for GetProtoFromAny method builders."""

  @classmethod
  def GenPublic(cls, out, messages: list[Message]):
    """Emits GetProtoFromAny with one type_url check per message."""
    out.write("""
      std::unique_ptr<google::protobuf::MessageLite> GetProtoFromAny(
          const proto::Any& msg) {
      """)

    for msg in messages:
      cls._IfMsg(out, msg)
    out.write('return nullptr;\n')
    out.write('}\n\n')  # End function

  @classmethod
  def _IfMsg(cls, out, msg: Message):
    """Emits the block that parses the Any as `msg` and returns a copy."""
    out.write(
        f"""if (msg.type_url() == "type.googleapis.com/{msg.type_name}") {{
      """)
    out.write(
        f'auto casted_msg = ParsedAnyMetadata<{msg.cpp_name}>(msg);\n')
    out.write("""
      std::unique_ptr<google::protobuf::MessageLite> copy(
          casted_msg->New());\n
      """)
    out.write('copy->CheckTypeAndMergeFrom(*casted_msg);\n')
    out.write('return copy;\n')
    out.write('}\n\n')  # End if statement


class _NestedMessageIteratorGet:
  """Namespace class for NestedMessageIterator::Get method builders."""

  @classmethod
  def GenPublic(cls, out, messages: list[Message]):
    """Emits NestedMessageIterator::Get with one block per message."""
    out.write('const google::protobuf::MessageLite* '
              'NestedMessageIterator::Get() const {\n')
    for msg in messages:
      cls._IfMsg(out, msg)
    out.write(' NOTREACHED_IN_MIGRATION();\n')
    out.write(' return nullptr;\n')
    out.write('}\n')

  @classmethod
  def _IfMsg(cls, out, msg: Message):
    """Emits the switch over the repeated message fields of `msg`."""
    out.write(f'if (parent_->GetTypeName() == "{msg.type_name}") {{\n')
    out.write('switch (tag_number_) {\n')
    for field in msg.fields:
      if field.type == Type.MESSAGE and field.is_repeated:
        cls._FieldCase(out, msg, field)
    out.write('}\n')  # End switch
    out.write('}\n\n')  # End if statement

  @classmethod
  def _FieldCase(cls, out, msg: Message, field: Field):
    """Emits the `case` returning the element at `offset_` of one field."""
    # `parent_` is a MessageLite pointer; downcast it to the concrete type.
    cast_msg = f'static_cast<const {msg.cpp_name}*>(parent_)'
    out.write(f'case {field.tag_number}: {{\n')
    out.write(f'return &{cast_msg}->{field.name}(offset_);\n')
    out.write('}\n')  # End case


class _GetProtoRepeated:
  """Namespace class for GetProtoRepeated method builders."""

  @classmethod
  def GenPublic(cls, out):
    """Emits the public overload, which forwards with index 0."""
    out.write("""
      std::optional<NestedMessageIterator> GetProtoRepeated(
          const google::protobuf::MessageLite* msg,
          const proto::ProtoField& proto_field) {
        return GetProtoRepeated(msg, proto_field, /*index=*/0);
      }
      """)

  @classmethod
  def GenPrivate(cls, out, messages: list[Message]):
    """Emits the indexed overload with one `if` block per message."""
    out.write("""\
      std::optional<NestedMessageIterator> GetProtoRepeated(
          const google::protobuf::MessageLite* msg,
          const proto::ProtoField& proto_field,
          int32_t index) {
        if (index >= proto_field.proto_descriptors_size()) {
          return std::nullopt;
        }
        int32_t tag_number =
            proto_field.proto_descriptors(index).tag_number();
      """)

    for msg in messages:
      cls._IfMsg(out, msg)
    out.write('return std::nullopt;\n')
    out.write('}\n\n')  # End function

  @classmethod
  def _IfMsg(cls, out, msg: Message):
    """Emits the switch over the message-typed fields of `msg`."""
    out.write(f'if (msg->GetTypeName() == "{msg.type_name}") {{\n')
    out.write('switch (tag_number) {\n')
    for field in msg.fields:
      if field.type == Type.MESSAGE:
        cls._FieldCase(out, msg, field)
    out.write('}\n')  # End switch
    out.write('}\n\n')  # End if statement

  @classmethod
  def _FieldCase(cls, out, msg: Message, field: Field):
    """Emits the `case` for one message-typed field.

    A repeated field yields an iterator over it; a singular field recurses
    into the sub-message with the next descriptor index.
    """
    # `msg` is a MessageLite pointer; downcast it to the concrete type.
    field_expr = (
        f'static_cast<const {msg.cpp_name}*>(msg)->{field.name}()')
    out.write(f'case {field.tag_number}: {{\n')
    if field.is_repeated:
      out.write(f'return NestedMessageIterator('
                f'msg, tag_number, {field_expr}.size(), 0);\n')
    else:
      out.write(f'return GetProtoRepeated('
                f'&{field_expr}, proto_field, index+1);\n')
    out.write('}\n')  # End case


class _SetProtoValue:
  """Namespace class for SetProtoValue method builders."""

  @classmethod
  def GenPublic(cls, out):
    """Emits the public overload, which forwards with index 0."""
    out.write("""
      std::optional<proto::Any> SetProtoValue(
          const std::string& proto_name,
          const proto::ProtoField& proto_field,
          const std::string& value) {
        return SetProtoValue(proto_name, proto_field, value, /*index=*/0);
      }
      """)

  @classmethod
  def GenPrivate(cls, out, messages: list[Message]):
    """Emits the indexed overload with one `if` block per message."""
    out.write("""
      std::optional<proto::Any> SetProtoValue(
          const std::string& proto_name,
          const proto::ProtoField& proto_field,
          const std::string& value,
          int32_t index) {
        if (index >= proto_field.proto_descriptors_size()) {
          return std::nullopt;
        }
      """)

    for msg in messages:
      cls._IfMsg(out, msg)
    out.write("""
        return std::nullopt;
      }
      """)

  @classmethod
  def _IfMsg(cls, out, msg: Message):
    """Emits the proto_name check and tag switch for one message."""
    out.write(f'if (proto_name == "{msg.type_name}") {{\n')
    out.write(
        'switch(proto_field.proto_descriptors(index).tag_number()) {\n')
    for field in msg.fields:
      cls._FieldCase(out, msg, field)
    out.write("""
      default:
        return std::nullopt;\n
      """)
    out.write('}')
    out.write('}\n')  # End if statement

  @classmethod
  def _FieldCase(cls, out, msg: Message, field: Field):
    """Emits the `case` for one field; only singular strings get one."""
    if field.type == Type.STRING and not field.is_repeated:
      out.write(f'case {field.tag_number}: {{\n')
      out.write('proto::Any any;\n')
      out.write(
          f'any.set_type_url("type.googleapis.com/{msg.type_name}");\n')
      out.write(f'{msg.cpp_name} response_value;\n')
      out.write(f'response_value.set_{field.name}(value);')
      out.write('response_value.SerializeToString(any.mutable_value());')
      out.write('return any;')
      out.write('}\n')


class _ConvertValue:
  """Namespace class for base::Value->Message method builders."""

  @classmethod
  def GenPublic(cls, out, messages: list[Message]):
    """Emits ConvertToAnyWrappedProto, dispatching on the type name."""
    out.write("""
      std::optional<proto::Any> ConvertToAnyWrappedProto(
          const base::Value& object, const std::string& type_name) {
        proto::Any any;
        any.set_type_url("type.googleapis.com/" + type_name);
      """)
    for msg in messages:
      out.write(f"""
        if (type_name == "{msg.type_name}") {{
          {msg.cpp_name} msg;
          if (Convert{msg.iname}(object, msg)) {{
            msg.SerializeToString(any.mutable_value());
            return any;
          }}
        }}
        """)
    out.write("""
        return std::nullopt;
      }
      """)

  @classmethod
  def GenPrivate(cls, out, messages: list[Message]):
    """Emits the per-message converters.

    All prototypes are written before any definition because a converter
    may call the converter of another message.
    """
    for msg in messages:
      out.write(f"""
        bool Convert{msg.iname}(
            const base::Value& object, {msg.cpp_name}& proto);
        """)
    for msg in messages:
      cls._DefineConvert(out, msg)

  @classmethod
  def _DefineConvert(cls, out, msg: Message):
    """Emits the converter body for one message.

    GROUP and ENUM fields are skipped, so they are never populated.
    """
    out.write(f"""
      bool Convert{msg.iname}(
          const base::Value& object, {msg.cpp_name}& proto) {{
        const base::Value::Dict* asdict = object.GetIfDict();
        if (!asdict) {{
          return false;
        }}
      """)
    for field in msg.fields:
      if field.type in (Type.GROUP, Type.ENUM):
        continue
      out.write('if (const base::Value* field_value =\n')
      out.write(f' asdict->Find("{field.desc.json_name}")) {{')
      cls._FieldCase(out, msg, field)
      out.write('}')
    out.write("""
        return true;
      }
      """)

  @classmethod
  def _FieldCase(cls, out, msg: Message, field: Field):
    """Emits the statements that copy `field_value` into one field."""
    if field.is_repeated:
      out.write("""
        const auto* lst = field_value->GetIfList();
        if (!lst) {
          return false;
        }
        for (const base::Value& entry_value : *lst) {
        """)
      if field.type == Type.MESSAGE:
        out.write(f"""
          if (!Convert{field.typename}(
              entry_value, *proto.add_{field.name}())) {{
            return false;
          }}
          """)
      else:
        vtype = BASE_VALUE_TYPES[field.type]
        out.write(f"""
          const {vtype.cpptype} v = entry_value.GetIf{vtype.getIfFn}();
          if (!v) {{
            return false;
          }}
          proto.add_{field.name}(*v);
          """)
      out.write("}")  # end for loop
      return

    if field.type == Type.MESSAGE:
      out.write(f"""
        if (!Convert{field.typename}(
            *field_value, *proto.mutable_{field.name}())) {{
          return false;
        }}
        """)
      return

    vtype = BASE_VALUE_TYPES[field.type]
    out.write(f"""
      const {vtype.cpptype} v = field_value->GetIf{vtype.getIfFn}();
      if (!v) {{
        return false;
      }}
      proto.set_{field.name}(*v);
      """)


def main(argv):
  """Parses `argv`, generates the C++ source and writes it out.

  Args:
    argv: command-line arguments without the program name.

  Returns:
    0 on success.
  """
  parser = optparse.OptionParser()
  parser.add_option('--input_file', action='append', default=[])
  parser.add_option('--output_cc')
  parser.add_option('--include', action='append', default=[])
  parser.add_option('--request', action='append', default=[])
  parser.add_option('--response', action='append', default=[])
  options, _ = parser.parse_args(argv)

  input_files = list(options.input_file)
  includes = set(options.include)
  requests = set(options.request)
  responses = set(options.response)

  messages = KnownMessages()
  for input_file in input_files:
    fds = descriptor_pb2.FileDescriptorSet()
    with open(input_file, 'rb') as fp:
      fds.ParseFromString(fp.read())
    messages.AddFileDescriptorSet(fds)

  out_cc_str = StringIO()
  GenerateProtoDescriptors(out_cc_str, includes, messages, requests,
                           responses)
  generated = out_cc_str.getvalue().encode('utf-8')

  # Write to standard output or file specified by --output_cc. The file is
  # opened only after generation succeeded, so a failure above neither
  # truncates an existing output nor leaks an open handle.
  if options.output_cc:
    with open(options.output_cc, 'wb') as out_cc:
      out_cc.write(generated)
  else:
    getattr(sys.stdout, 'buffer', sys.stdout).write(generated)

  return 0


if __name__ == '__main__':
  sys.exit(main(sys.argv[1:]))