# metisllm-dashboard/domain/domain_protocol.py
# Last commit: 9041389 ("first real commit") by Gateston Johns
from __future__ import annotations
from typing import Optional, Protocol, Tuple, Type, TypeVar, get_args
from google.protobuf import json_format, message
# Type variable standing for the concrete protobuf message class a domain
# type is parameterized with (e.g. DomainProtocol[SomeProto]).
MessageType = TypeVar(
    "MessageType",
    bound=message.Message,
)

# Type variable for the concrete DomainProtocol implementer, used so that
# classmethods return the subclass type. The bound is a string forward
# reference because DomainProtocol is defined further down in this module.
DomainProtocolType = TypeVar(
    "DomainProtocolType",
    bound="DomainProtocol",
)
class ProtoDeserializationError(Exception):
    """Raised when a protobuf message or JSON string cannot be turned into a domain object."""
class DomainProtocol(Protocol[MessageType]):
    """Structural interface for domain objects backed by a protobuf message.

    Implementers subclass ``DomainProtocol[SomeProto]`` and provide ``id``,
    ``_from_proto`` and ``to_proto``. This class supplies the shared
    conversion helpers (``from_proto``, ``from_json``, ``to_json``) and the
    proto emptiness checks used during deserialization.
    """

    @property
    def id(self) -> str:
        """Identifier of this domain object (provided by implementers)."""
        ...

    @classmethod
    def _from_proto(cls: Type[DomainProtocolType], proto: MessageType) -> DomainProtocolType:
        """Build an instance from ``proto`` (provided by implementers).

        Called by ``from_proto``, which performs the emptiness check and
        wraps any failure in ``ProtoDeserializationError``.
        """
        ...

    def to_proto(self) -> MessageType:
        """Serialize this object into its protobuf message (provided by implementers)."""
        ...

    @classmethod
    def message_cls(cls: Type[DomainProtocolType]) -> Type[MessageType]:
        """Return the protobuf message class this domain class is parameterized with.

        The type is read from the class's single generic base, e.g.
        ``class Foo(DomainProtocol[FooProto])`` yields ``FooProto``.

        Raises:
            ValueError: if the class has no ``__orig_bases__``, has more or
                fewer than one generic base, or that base has no type argument.
        """
        orig_bases: Optional[Tuple[Type[MessageType], ...]] = getattr(cls, "__orig_bases__", None)
        if not orig_bases:
            raise ValueError(f"Class {cls} does not have __orig_bases__")
        if len(orig_bases) != 1:
            raise ValueError(f"Class {cls} has unexpected number of bases: {orig_bases}")
        type_args = get_args(orig_bases[0])
        if not type_args:
            # Previously surfaced as an opaque IndexError.
            raise ValueError(f"Class {cls} base {orig_bases[0]} has no type arguments")
        return type_args[0]

    @classmethod
    def from_proto(cls: Type[DomainProtocolType],
                   proto: MessageType,
                   allow_empty: bool = False) -> DomainProtocolType:
        """Convert ``proto`` into an instance of this domain class.

        Args:
            proto: The message to convert.
            allow_empty: When False (the default), a message for which
                ``is_empty`` returns True is rejected.

        Raises:
            ProtoDeserializationError: wrapping any exception raised by the
                emptiness check or by ``_from_proto``; the original exception
                is chained as ``__cause__``.
        """
        try:
            if not allow_empty:
                cls.validate_proto_not_empty(proto)
            return cls._from_proto(proto)
        except Exception as e:
            # Deliberately broad: every conversion failure is reported to
            # callers as a single deserialization error type.
            error_str = f"Failed to convert {cls} - {e}"
            raise ProtoDeserializationError(error_str) from e

    @classmethod
    def from_json(cls: Type[DomainProtocolType], json_str: str) -> DomainProtocolType:
        """Parse ``json_str`` into this class's message type, then convert it.

        Raises:
            ProtoDeserializationError: if the JSON cannot be parsed into the
                message, or (via ``from_proto``) if conversion fails.
            ValueError: from ``message_cls`` if the message type cannot be
                determined.
        """
        try:
            proto_cls = cls.message_cls()
            proto = proto_cls()
            json_format.Parse(json_str, proto)
            return cls.from_proto(proto)
        except json_format.ParseError as e:
            error_str = f"{cls} failed to parse json string: {json_str} - {e}"
            raise ProtoDeserializationError(error_str) from e

    def to_json(self) -> str:
        """Serialize ``to_proto()`` to JSON collapsed onto a single line.

        ``MessageToJson`` pretty-prints across lines; newlines are replaced by
        spaces (indentation spaces are left in place).
        """
        return json_format.MessageToJson(self.to_proto()).replace("\n", " ")

    @classmethod
    def validate_proto_not_empty(cls, proto: message.Message) -> None:
        """Raise ``ValueError("Proto is empty")`` if ``is_empty(proto)`` is True."""
        if cls.is_empty(proto):
            raise ValueError("Proto is empty")

    @classmethod
    def is_empty(cls, proto: message.Message) -> bool:
        """Return True if ``proto`` carries no meaningful data.

        A message is considered empty when, for every declared field:
          * a singular scalar/enum field equals its default value;
          * a singular message/group field is unset, or set but itself empty
            (checked recursively);
          * a repeated message/group field contains only empty messages;
          * a repeated scalar field or a map field has no elements.
        Objects without a ``DESCRIPTOR`` attribute are treated as empty.
        Extensions and unknown fields are not inspected.
        """
        descriptor = getattr(proto, "DESCRIPTOR", None)
        if descriptor is None:
            return True

        for field in descriptor.fields:
            is_message = field.type in (field.TYPE_MESSAGE, field.TYPE_GROUP)

            if field.label == field.LABEL_REPEATED:
                values = getattr(proto, field.name)
                is_map = is_message and field.message_type.GetOptions().map_entry
                if is_map or not is_message:
                    # Iterating a map yields keys, not entry messages. For
                    # repeated scalars, FieldDescriptor.default_value is [],
                    # so any element counts as data.
                    if len(values) > 0:
                        return False
                elif not all(cls.is_empty(item) for item in values):
                    return False
            elif is_message:
                # Check presence before reading. Reading an unset sub-message
                # returns a fresh default instance, so recursing without this
                # check never terminates on self-referential message types.
                if proto.HasField(field.name) and not cls.is_empty(getattr(proto, field.name)):
                    return False
            elif getattr(proto, field.name) != field.default_value:
                return False

        return True