Spaces:
Running
Running
import json | |
from typing import Any, Dict, List, Union, get_args, get_origin | |
from pydantic import BaseModel, Field | |
from pydantic_core import PydanticUndefined | |
from lagent.prompts.parsers.str_parser import StrParser | |
def get_field_type_name(field_type):
    """Return a readable name for a type annotation.

    Parameterized generics are rendered recursively from their origin and
    arguments, e.g. ``List[str]`` becomes ``'list[str]'``. An unparameterized
    generic such as ``typing.List`` is rendered as just its origin name
    instead of a dangling ``'list[]'``.

    Args:
        field_type: The annotation to describe (a class, a ``typing``
            construct, or any other object used as an annotation).

    Returns:
        The rendered type name as a string.
    """
    # Origin of a composite type (List, Dict, ...); None for plain types.
    origin = get_origin(field_type)
    if origin is not None:
        # Not every origin object is guaranteed to expose ``__name__``;
        # fall back to ``_name`` and finally to its string form so that
        # rendering a template never fails with AttributeError.
        origin_name = (getattr(origin, '__name__', None)
                       or getattr(origin, '_name', None) or str(origin))
        args = get_args(field_type)
        if not args:
            return origin_name
        # Rebuild the type name, e.g. list[str] or dict[str, int].
        args_str = ', '.join(get_field_type_name(arg) for arg in args)
        return f'{origin_name}[{args_str}]'
    # Not a composite type: use the type's own name when it has one.
    if hasattr(field_type, '__name__'):
        return field_type.__name__
    # Anything else (e.g. objects that are not classes) is stringified.
    return str(field_type)
# class JSONParser(BaseParser): | |
class JSONParser(StrParser):
    """Render pydantic models as JSON-like templates and parse JSON replies.

    ``format_to_string`` turns a pydantic model class into a commented,
    JSON-shaped description of its fields. ``parse_response`` parses a JSON
    string and validates it against the candidate models found in
    ``self.format_field`` (not defined in this class; presumably set up by
    the ``StrParser`` base class -- confirm there).
    """

    # Type names that ``_format_field`` emits as placeholder values. A reply
    # that echoes one of them back is treated as "value not provided".
    # Kept as a tuple (not a set) so unhashable JSON values can be tested.
    _TYPE_PLACEHOLDERS = ('str', 'int', 'float', 'bool', 'list', 'dict')

    def _extract_fields_with_metadata(
            self, model: BaseModel) -> Dict[str, Dict[str, Any]]:
        """Collect annotation, default and description for each model field.

        Args:
            model: A pydantic model class (its ``model_fields`` is read).

        Returns:
            A dict keyed by field name. Each value holds ``'annotation'``,
            ``'default'`` (the string ``'<required>'`` when the field has no
            default) and ``'comment'``. When the annotation is a
            ``BaseModel`` subclass, or a composite type with such a subclass
            among its direct arguments, a nested ``'fields'`` dict built the
            same way is added.
        """
        fields_metadata = {}
        for field_name, field in model.model_fields.items():
            fields_metadata[field_name] = {
                'annotation': field.annotation,
                'default': field.default
                if field.default is not PydanticUndefined else '<required>',
                'comment': field.description if field.description else ''
            }

            # Type inspection to support nested BaseModel subclasses.
            origin = get_origin(field.annotation)
            args = get_args(field.annotation)
            if origin is None:
                # Plain type: recurse if it is itself a BaseModel subclass.
                if isinstance(field.annotation, type) and issubclass(
                        field.annotation, BaseModel):
                    fields_metadata[field_name][
                        'fields'] = self._extract_fields_with_metadata(
                            field.annotation)
            else:
                # Composite type: recurse into the first direct argument
                # that is a BaseModel subclass, if any.
                for arg in args:
                    if isinstance(arg, type) and issubclass(arg, BaseModel):
                        fields_metadata[field_name][
                            'fields'] = self._extract_fields_with_metadata(
                                arg)
                        break
        return fields_metadata

    def _format_field(self,
                      field_name: str,
                      metadata: Dict[str, Any],
                      indent: int = 1) -> str:
        """Render one field (and any nested fields) as template lines.

        Args:
            field_name: Name of the field, used as the JSON key.
            metadata: Entry produced by ``_extract_fields_with_metadata``.
            indent: Nesting depth; the line prefix is ``' ' * indent``.

        Returns:
            The rendered lines joined with newlines. A non-empty comment is
            emitted on its own ``//`` line before the field.
        """
        comment = metadata.get('comment', '')
        field_type = get_field_type_name(
            metadata['annotation']
        ) if metadata['annotation'] is not None else 'Any'
        default_value = metadata['default']
        indent_str = ' ' * indent
        formatted_lines = []

        if comment:
            formatted_lines.append(f'{indent_str}// {comment}')

        if 'fields' in metadata:
            # Nested model: render its fields one level deeper.
            formatted_lines.append(f'{indent_str}"{field_name}": {{')
            for sub_field_name, sub_metadata in metadata['fields'].items():
                formatted_lines.append(
                    self._format_field(sub_field_name, sub_metadata,
                                       indent + 1))
            formatted_lines.append(f'{indent_str}}},')
        elif default_value == '<required>':
            formatted_lines.append(
                f'{indent_str}"{field_name}": "{field_type}", // required')
        else:
            formatted_lines.append(
                f'{indent_str}"{field_name}": "{field_type}", // default: {default_value}'
            )
        return '\n'.join(formatted_lines)

    def format_to_string(self, format_model) -> str:
        """Render a pydantic model class as a brace-wrapped template string.

        Args:
            format_model: The pydantic model class to describe.

        Returns:
            A string that starts with ``'{'`` and ends with ``'}'``, with
            one rendered entry per field in between.
        """
        fields = self._extract_fields_with_metadata(format_model)
        formatted_lines = [
            self._format_field(field_name, metadata)
            for field_name, metadata in fields.items()
        ]
        # Remove the trailing comma from the last entry.
        # NOTE(review): leaf entries end with a ``// ...`` comment, so this
        # only takes effect when the last field is a nested model. Left
        # unchanged so that existing prompt text stays the same.
        if formatted_lines and formatted_lines[-1].endswith(','):
            formatted_lines[-1] = formatted_lines[-1].rstrip(',')
        return '{\n' + '\n'.join(formatted_lines) + '\n}'

    def parse_response(self, data: str) -> Union[dict, BaseModel]:
        """Parse a JSON reply and validate it against a known format.

        Lines whose first non-blank characters are ``//`` are dropped before
        decoding. The first model in ``self.format_field`` that validates
        the decoded data is selected; its field metadata is stored on
        ``self.fields``.

        Args:
            data: The raw JSON text.

        Returns:
            The validated data as a plain dict.

        Raises:
            ValueError: If ``data`` is not valid JSON, if it matches none of
                the expected formats, or if a required field only contains
                a type-name placeholder.
        """
        # Remove whole-line comments; JSON itself has no comment syntax.
        data_no_comments = '\n'.join(
            line for line in data.split('\n')
            if not line.strip().startswith('//'))
        # Keep the try body minimal: only decoding raises JSONDecodeError.
        try:
            data_dict = json.loads(data_no_comments)
        except json.JSONDecodeError as exc:
            raise ValueError('Input string is not a valid JSON.') from exc

        # Pick the first candidate model that accepts the decoded data.
        model = next((candidate
                      for candidate in self.format_field.values()
                      if self._is_valid_format(data_dict, candidate)), None)
        if model is None:
            # Previously this fell through with ``model`` unbound.
            raise ValueError(
                'Input JSON does not match any of the expected formats.')

        self.fields = self._extract_fields_with_metadata(model)
        parsed_data = {}
        for field_name, value in data_dict.items():
            if field_name not in self.fields:
                # Keys the model does not declare are dropped.
                continue
            metadata = self.fields[field_name]
            if isinstance(value, str) and value in self._TYPE_PLACEHOLDERS:
                # The reply echoed the template placeholder: use the default.
                if metadata['default'] == '<required>':
                    raise ValueError(
                        f"Field '{field_name}' is required but not provided"
                    )
                parsed_data[field_name] = metadata['default']
            else:
                parsed_data[field_name] = value
        # ``model_dump`` is the pydantic v2 replacement for ``dict``.
        return model.model_validate(parsed_data).model_dump()

    def _is_valid_format(self, data: dict, format_model: BaseModel) -> bool:
        """Return True if ``format_model.model_validate(data)`` succeeds.

        Args:
            data: The decoded JSON value to check.
            format_model: The pydantic model class to validate against.

        Returns:
            ``True`` on successful validation, ``False`` if validation
            raised any exception.
        """
        try:
            format_model.model_validate(data)
        except Exception:
            # Deliberate best-effort probe: any failure means "no match".
            return False
        return True
if __name__ == '__main__':
    # Demo: build a parser from two reply schemas, print the rendered
    # prompt, then parse a sample JSON reply.

    # Schema for the first reply format in the template.
    class DefaultFormat(BaseModel):
        name: List[str] = Field(description='Name of the person')
        age: int = Field(description='Age of the person')

    # Schema for the alternative ("otherwise") reply format.
    class UnknownFormat(BaseModel):
        title: str
        year: int

    # Prompt template with one fenced JSON block per schema. Roughly: "If
    # you know the question, reply in the following format ... otherwise
    # reply ...".
    TEMPLATE = """如果了解该问题请按照一下格式回复
```json
{format}
```
否则请回复
```json
{unknown_format}
```
"""

    json_parser = JSONParser(
        template=TEMPLATE,
        default_format=DefaultFormat,
        unknown_format=UnknownFormat,
    )

    # A sample reply that matches DefaultFormat.
    sample_response = '''
{
"name": ["John Doe"],
"age": 30
}
'''

    print(json_parser.format())
    parsed = json_parser.parse_response(sample_response)
    print(parsed)