Spaces:
Running
Running
| import json | |
| from typing import Any, Dict, List, Union, get_args, get_origin | |
| from pydantic import BaseModel, Field | |
| from pydantic_core import PydanticUndefined | |
| from lagent.prompts.parsers.str_parser import StrParser | |
def get_field_type_name(field_type) -> str:
    """Return a readable, Python-like name for a type annotation.

    Composite annotations are rendered recursively from their origin and
    arguments, e.g. ``List[str]`` -> ``'list[str]'`` and ``Dict[str, int]``
    -> ``'dict[str, int]'``. Plain classes render as their ``__name__``;
    anything else falls back to ``str()``.

    Args:
        field_type: A type or typing construct. Nested arguments may also be
            a list (the parameter list of ``Callable[[...], R]``) or
            ``Ellipsis``.

    Returns:
        str: The rendered type name.
    """
    # A list appears as the argument list of Callable[[A, B], R].
    if isinstance(field_type, list):
        return '[' + ', '.join(
            get_field_type_name(arg) for arg in field_type) + ']'
    if field_type is Ellipsis:
        return '...'
    # Origin of a composite annotation (list for List[...], Union, ...).
    origin = get_origin(field_type)
    if origin is not None:
        args = get_args(field_type)
        # typing special forms such as Union have no __name__ before
        # Python 3.10, so fall back to their _name, then to str().
        origin_name = (getattr(origin, '__name__', None)
                       or getattr(origin, '_name', None) or str(origin))
        args_str = ', '.join(get_field_type_name(arg) for arg in args)
        return f'{origin_name}[{args_str}]'
    if hasattr(field_type, '__name__'):
        return field_type.__name__
    # Special cases, e.g. typing objects or types from unknown libraries.
    return str(field_type)
| # class JSONParser(BaseParser): | |
class JSONParser(StrParser):
    """Render pydantic models as commented JSON skeletons and parse JSON
    replies back into validated dictionaries.

    NOTE(review): ``self.format_field`` (read in ``parse_response``) is not
    set in this class. It presumably comes from ``StrParser`` as a mapping of
    format name -> pydantic model class; confirm against the base class.
    """

    # Bare type names that a reply may echo back instead of a real value;
    # such values are replaced by the field default in parse_response.
    _TYPE_PLACEHOLDERS = ('str', 'int', 'float', 'bool', 'list', 'dict')

    @staticmethod
    def _find_nested_model(annotation):
        """Return the BaseModel subclass held by ``annotation``, or None.

        A plain annotation is checked directly. For a composite annotation
        (``List[Sub]``, ``Optional[Sub]``, ...), the first direct type
        argument that is a model is returned.
        """
        if get_origin(annotation) is None:
            candidates = (annotation, )
        else:
            candidates = get_args(annotation)
        for candidate in candidates:
            # get_origin guard: parameterized generics such as list[int]
            # pass isinstance(..., type) on Python 3.9/3.10 but make
            # issubclass raise TypeError.
            if (isinstance(candidate, type)
                    and get_origin(candidate) is None
                    and issubclass(candidate, BaseModel)):
                return candidate
        return None

    def _extract_fields_with_metadata(
            self, model: BaseModel) -> Dict[str, Dict[str, Any]]:
        """Collect per-field metadata for a pydantic model class.

        Args:
            model: A pydantic ``BaseModel`` subclass (read through its
                ``model_fields``).

        Returns:
            Dict[str, Dict[str, Any]]: Maps each field name to a dict with
            ``'annotation'``, ``'default'`` (``'<required>'`` when the
            field's default is ``PydanticUndefined``) and ``'comment'`` (the
            field description, or ``''``). Fields whose annotation is, or
            directly contains, a ``BaseModel`` subclass also get a recursive
            ``'fields'`` entry.
        """
        fields_metadata = {}
        for field_name, field in model.model_fields.items():
            metadata = {
                'annotation': field.annotation,
                'default': field.default
                if field.default is not PydanticUndefined else '<required>',
                'comment': field.description or '',
            }
            nested = self._find_nested_model(field.annotation)
            if nested is not None:
                metadata['fields'] = self._extract_fields_with_metadata(nested)
            fields_metadata[field_name] = metadata
        return fields_metadata

    def _format_field(self,
                      field_name: str,
                      metadata: Dict[str, Any],
                      indent: int = 1,
                      trailing_comma: bool = True) -> str:
        """Render one field (recursively for nested models) as commented
        JSON-like lines.

        Args:
            field_name: Key to emit.
            metadata: One entry produced by ``_extract_fields_with_metadata``.
            indent: Nesting depth used to build the line prefix.
            trailing_comma: Whether a comma follows the value. Pass ``False``
                for the last member of an object so the skeleton has no
                dangling comma.

        Returns:
            str: The rendered lines joined with newlines. A non-empty
            description is emitted as a ``//`` comment line above the field.
        """
        comment = metadata.get('comment', '')
        annotation = metadata['annotation']
        field_type = (get_field_type_name(annotation)
                      if annotation is not None else 'Any')
        default_value = metadata['default']
        indent_str = ' ' * indent
        comma = ',' if trailing_comma else ''

        formatted_lines = []
        if comment:
            formatted_lines.append(f'{indent_str}// {comment}')
        if 'fields' in metadata:
            formatted_lines.append(f'{indent_str}"{field_name}": {{')
            sub_fields = list(metadata['fields'].items())
            last_index = len(sub_fields) - 1
            for position, (sub_name, sub_metadata) in enumerate(sub_fields):
                formatted_lines.append(
                    self._format_field(
                        sub_name,
                        sub_metadata,
                        indent + 1,
                        trailing_comma=position < last_index))
            formatted_lines.append(f'{indent_str}}}{comma}')
        elif default_value == '<required>':
            formatted_lines.append(
                f'{indent_str}"{field_name}": "{field_type}"{comma} '
                '// required')
        else:
            formatted_lines.append(
                f'{indent_str}"{field_name}": "{field_type}"{comma} '
                f'// default: {default_value}')
        return '\n'.join(formatted_lines)

    def format_to_string(self, format_model) -> str:
        """Render ``format_model`` as a commented JSON skeleton.

        Args:
            format_model: A pydantic ``BaseModel`` subclass.

        Returns:
            str: ``'{'``, one rendered block per field, then ``'}'``. Every
            member except the last (at every nesting level) is followed by a
            comma.
        """
        fields = self._extract_fields_with_metadata(format_model)
        last_index = len(fields) - 1
        formatted_lines = [
            self._format_field(
                field_name, metadata, trailing_comma=index < last_index)
            for index, (field_name, metadata) in enumerate(fields.items())
        ]
        return '{\n' + '\n'.join(formatted_lines) + '\n}'

    def parse_response(self, data: str) -> Union[dict, BaseModel]:
        """Parse a JSON reply into a validated dict for the matching format.

        If the reply contains a Markdown code fence (for example a ```json
        block), only the content of the first fenced block is decoded. Lines
        whose first non-blank characters are ``//`` are dropped before
        decoding. Trailing ``//`` comments after a value are NOT removed.

        The first model in ``self.format_field`` that validates the decoded
        object is selected, and ``self.fields`` is set to its field metadata.
        Values equal to a bare type placeholder (``'str'``, ``'int'``, ...)
        are replaced by the field default.

        Args:
            data: Raw reply text.

        Returns:
            dict: ``model_dump()`` of the validated model instance.

        Raises:
            ValueError: If the text is not valid JSON, no format matches, or
                a required field was left as a placeholder. Pydantic
                validation errors from the final validation also propagate.
        """
        import re  # only needed to unwrap fenced replies

        fenced = re.search(r'```[^\n]*\n(.*?)```', data, re.DOTALL)
        if fenced:
            data = fenced.group(1)
        data_no_comments = '\n'.join(
            line for line in data.split('\n')
            if not line.strip().startswith('//'))
        try:
            data_dict = json.loads(data_no_comments)
        except json.JSONDecodeError as err:
            raise ValueError('Input string is not a valid JSON.') from err

        model = next((candidate for candidate in self.format_field.values()
                      if self._is_valid_format(data_dict, candidate)), None)
        if model is None:
            # Previously this path crashed with UnboundLocalError.
            raise ValueError(
                'Input JSON does not match any of the expected formats.')
        self.fields = self._extract_fields_with_metadata(model)

        parsed_data = {}
        for field_name, value in data_dict.items():
            metadata = self.fields.get(field_name)
            if metadata is None:
                # Keys unknown to the selected model are dropped.
                continue
            if value in self._TYPE_PLACEHOLDERS:
                if metadata['default'] == '<required>':
                    raise ValueError(
                        f"Field '{field_name}' is required but not provided")
                parsed_data[field_name] = metadata['default']
            else:
                parsed_data[field_name] = value
        # model_dump() replaces pydantic v2's deprecated .dict().
        return model.model_validate(parsed_data).model_dump()

    def _is_valid_format(self, data: dict, format_model: BaseModel) -> bool:
        """Return True if ``format_model.model_validate(data)`` succeeds.

        Any exception (a validation error, or ``format_model`` lacking
        ``model_validate``) is treated as "does not match".
        """
        try:
            format_model.model_validate(data)
        except Exception:  # deliberate: any failure just means "no match"
            return False
        return True
if __name__ == '__main__':
    # Example usage: two candidate reply formats. parse_response selects the
    # first format that the reply validates against.
    class DefaultFormat(BaseModel):
        name: List[str] = Field(description='Name of the person')
        age: int = Field(description='Age of the person')

    class UnknownFormat(BaseModel):
        title: str
        year: int

    # Prompt (Chinese): "If you know about this question, reply in the
    # following format ... otherwise reply with ...". The {format} and
    # {unknown_format} slots are presumably filled by parser.format() from
    # the formats passed below; TODO confirm against StrParser.
    TEMPLATE = """如果了解该问题请按照以下格式回复
```json
{format}
```
否则请回复
```json
{unknown_format}
```
"""

    parser = JSONParser(
        template=TEMPLATE,
        default_format=DefaultFormat,
        unknown_format=UnknownFormat,
    )

    # Example reply matching DefaultFormat.
    data = '''
{
"name": ["John Doe"],
"age": 30
}
'''
    print(parser.format())
    result = parser.parse_response(data)
    print(result)