# Captured from a hosted Space file view (file size 6,689 bytes, revision e679d69).
import json
from typing import Any, Dict, List, Union, get_args, get_origin
from pydantic import BaseModel, Field
from pydantic_core import PydanticUndefined
from lagent.prompts.parsers.str_parser import StrParser
def get_field_type_name(field_type) -> str:
    """Return a readable name for a (possibly parameterised) type annotation.

    Examples: ``int`` -> ``'int'``, ``List[str]`` -> ``'list[str]'``,
    ``Dict[str, int]`` -> ``'dict[str, int]'``, bare ``List`` -> ``'list'``,
    ``Callable[[int], str]`` -> ``'Callable[[int], str]'``.

    Args:
        field_type: A class, a ``typing`` generic alias, or any other object
            that can appear inside an annotation (e.g. a ``Literal`` value).

    Returns:
        str: The rendered type name.
    """
    # `Callable[[A, B], R]` stores its parameter types as a plain list.
    if isinstance(field_type, list):
        return '[' + ', '.join(
            get_field_type_name(arg) for arg in field_type) + ']'
    # `Callable[..., R]` uses Ellipsis to mean "any arguments".
    if field_type is Ellipsis:
        return '...'

    # Origin of a composite type (e.g. `list` for `List[str]`).
    origin = get_origin(field_type)
    if origin is not None:
        # Some typing special forms (e.g. `Union` before Python 3.10) have no
        # `__name__` but do expose `_name`.
        origin_name = (getattr(origin, '__name__', None)
                       or getattr(origin, '_name', None) or str(origin))
        args = get_args(field_type)
        if not args:
            # Bare generics such as `typing.List` have an origin but no args.
            return origin_name
        args_str = ', '.join(get_field_type_name(arg) for arg in args)
        return f'{origin_name}[{args_str}]'

    # Plain (non-composite) types: use the type's own name.
    if hasattr(field_type, '__name__'):
        return field_type.__name__
    # Fallback for objects without a name (Literal values, forward refs, ...).
    return str(field_type)
# class JSONParser(BaseParser):
class JSONParser(StrParser):
    """Describe pydantic models as commented JSON templates and parse replies.

    ``format_to_string`` renders a model class as a JSON-like template in
    which every field shows its type name followed by a ``// required`` or
    ``// default: ...`` note. ``parse_response`` decodes a reply and validates
    it against the first candidate model in ``self.format_field`` that
    accepts it.

    NOTE(review): ``self.format_field`` is not assigned in this class; it is
    presumably populated by ``StrParser`` from format keyword arguments such
    as ``default_format=...``. Confirm against ``StrParser``.
    """

    # String values that merely echo a type placeholder from the template;
    # parse_response treats them as "value not provided".
    _TYPE_PLACEHOLDERS = ('str', 'int', 'float', 'bool', 'list', 'dict')

    @staticmethod
    def _describe_default(field) -> Any:
        """Return the value to show as a field's default, or ``'<required>'``.

        Args:
            field: A pydantic ``FieldInfo``.

        Returns:
            ``'<required>'`` when the field has neither a default nor a
            default factory. Otherwise returns the default value. For a
            default factory, the factory is called for an example value;
            ``'<factory>'`` is returned if that call fails.
        """
        if field.is_required():
            return '<required>'
        if field.default is not PydanticUndefined:
            return field.default
        # Only a default_factory is configured: the field is optional, so it
        # must not be labelled as required.
        try:
            return field.get_default(call_default_factory=True)
        except Exception:
            # Best effort: this value is only displayed in the template.
            return '<factory>'

    @staticmethod
    def _find_nested_model(annotation: Any):
        """Return the pydantic model class to expand for ``annotation``.

        Returns ``annotation`` itself when it is a ``BaseModel`` subclass.
        For a parameterised generic (e.g. ``List[Sub]`` or ``Optional[Sub]``),
        returns the first ``BaseModel`` subclass among its type arguments.
        Otherwise returns ``None``.
        """

        def is_model_class(candidate: Any) -> bool:
            # Exclude parameterised aliases such as `list[int]` first. On
            # Python 3.9/3.10 they pass `isinstance(x, type)`, but
            # `issubclass` then raises TypeError.
            return (get_origin(candidate) is None
                    and isinstance(candidate, type)
                    and issubclass(candidate, BaseModel))

        if get_origin(annotation) is None:
            return annotation if is_model_class(annotation) else None
        return next(
            (arg for arg in get_args(annotation) if is_model_class(arg)),
            None)

    def _extract_fields_with_metadata(
            self,
            model: BaseModel,
            _visiting: tuple = ()) -> Dict[str, Dict[str, Any]]:
        """Collect annotation, default and description for each model field.

        Nested pydantic models are expanded recursively under a ``'fields'``
        key. A nested model is either the annotation itself or the first
        ``BaseModel`` subclass among a generic's arguments.

        Args:
            model: A pydantic model class (v2 API: ``model_fields``).
            _visiting: Internal; models currently being expanded. It stops
                infinite recursion on self-referential models.

        Returns:
            A mapping from field name to a dict with these keys:
            ``'annotation'``; ``'default'`` (the default value, or
            ``'<required>'`` when the caller must supply it); ``'comment'``
            (the description, or ``''``); and, for nested models,
            ``'fields'``.
        """
        visiting = _visiting + (model, )
        fields_metadata = {}
        for field_name, field in model.model_fields.items():
            entry = {
                'annotation': field.annotation,
                'default': self._describe_default(field),
                'comment': field.description or '',
            }
            nested = self._find_nested_model(field.annotation)
            # A model that is already being expanded (e.g. a tree node with a
            # list of child nodes) is not expanded again; the recursion would
            # otherwise never terminate.
            if nested is not None and nested not in visiting:
                entry['fields'] = self._extract_fields_with_metadata(
                    nested, visiting)
            fields_metadata[field_name] = entry
        return fields_metadata

    def _format_fields(self, fields: Dict[str, Dict[str, Any]],
                       indent: int) -> List[str]:
        """Render every field in ``fields``; all but the last get a comma."""
        last_index = len(fields) - 1
        return [
            self._format_field(
                name, meta, indent, trailing_comma=index < last_index)
            for index, (name, meta) in enumerate(fields.items())
        ]

    def _format_field(self,
                      field_name: str,
                      metadata: Dict[str, Any],
                      indent: int = 1,
                      trailing_comma: bool = True) -> str:
        """Render one field as one or more template lines.

        Args:
            field_name: Key to emit.
            metadata: One entry produced by ``_extract_fields_with_metadata``.
            indent: Nesting depth; each line is prefixed with
                ``' ' * indent``.
            trailing_comma: Emit a comma after the value. Pass ``False`` for
                the last member of an object.

        Returns:
            Newline-separated lines, starting with optional ``//``
            description lines. A leaf field then yields
            ``"name": "type", // note``. A nested model instead yields an
            opening ``"name": {`` line, its sub-fields and a closing ``}``.
        """
        indent_str = ' ' * indent
        comma = ',' if trailing_comma else ''

        # One `//` line per description line. Multi-line descriptions then
        # stay comments, since parse_response drops lines starting with `//`.
        formatted_lines = [
            f'{indent_str}// {comment_line}'
            for comment_line in (metadata.get('comment') or '').splitlines()
        ]
        if 'fields' in metadata:
            formatted_lines.append(f'{indent_str}"{field_name}": {{')
            formatted_lines.extend(
                self._format_fields(metadata['fields'], indent + 1))
            formatted_lines.append(f'{indent_str}}}{comma}')
        else:
            annotation = metadata['annotation']
            field_type = ('Any' if annotation is None else
                          get_field_type_name(annotation))
            default_value = metadata['default']
            note = ('required' if default_value == '<required>' else
                    f'default: {default_value}')
            formatted_lines.append(
                f'{indent_str}"{field_name}": "{field_type}"{comma} // {note}')
        return '\n'.join(formatted_lines)

    def format_to_string(self, format_model) -> str:
        """Render ``format_model`` as a commented JSON template.

        Args:
            format_model: A pydantic model class.

        Returns:
            An opening ``{`` line, one entry per field, and a closing ``}``
            line. The last entry at every nesting level has no trailing comma.
        """
        fields = self._extract_fields_with_metadata(format_model)
        return '{\n' + '\n'.join(self._format_fields(fields, 1)) + '\n}'

    def parse_response(self, data: str) -> Union[dict, BaseModel]:
        """Parse a JSON reply against the first matching registered format.

        Lines that start with ``//`` are removed before decoding; inline
        comments after a value are not. The first model in
        ``self.format_field`` that validates the decoded object is selected.
        Values that merely repeat a type placeholder (``'str'``, ``'int'``,
        ...) count as "not provided".

        Args:
            data: Raw JSON text, possibly containing ``//`` comment lines.

        Returns:
            dict: The selected model's validated data (``model_dump()``).

        Raises:
            ValueError: If ``data`` is not valid JSON, if it matches none of
                the registered formats, or if a required field holds only a
                type placeholder. Errors from the final ``model_validate``
                call propagate unchanged.
        """
        data_no_comments = '\n'.join(
            line for line in data.split('\n')
            if not line.strip().startswith('//'))
        try:
            data_dict = json.loads(data_no_comments)
        except json.JSONDecodeError as err:
            raise ValueError('Input string is not a valid JSON.') from err

        # Select the first registered model that accepts the decoded data.
        model = next((candidate for candidate in self.format_field.values()
                      if self._is_valid_format(data_dict, candidate)), None)
        if model is None:
            raise ValueError(
                'Input does not match any of the expected formats.')
        self.fields = self._extract_fields_with_metadata(model)

        parsed_data = {}
        for field_name, value in data_dict.items():
            metadata = self.fields.get(field_name)
            if metadata is None:
                # Keys the selected model does not declare are dropped.
                continue
            if value in self._TYPE_PLACEHOLDERS:
                if metadata['default'] == '<required>':
                    raise ValueError(
                        f"Field '{field_name}' is required but not provided"
                    )
                # Omit the key so the model applies its own default
                # (including default factories).
                continue
            parsed_data[field_name] = value
        return model.model_validate(parsed_data).model_dump()

    def _is_valid_format(self, data: dict, format_model: BaseModel) -> bool:
        """Return True if ``format_model.model_validate(data)`` succeeds."""
        try:
            format_model.model_validate(data)
        except Exception:
            # Any failure just means "not this format"; the caller tries the
            # next candidate.
            return False
        return True
if __name__ == '__main__':
    # Example usage: register two candidate reply formats and parse a reply.

    class DefaultFormat(BaseModel):
        name: List[str] = Field(description='Name of the person')
        age: int = Field(description='Age of the person')

    class UnknownFormat(BaseModel):
        title: str
        year: int

    # Prompt (Chinese): "If you know about this question, please reply in the
    # following format ... otherwise please reply with ...". The {format} and
    # {unknown_format} placeholders are presumably filled by StrParser.format()
    # from the keyword arguments below. TODO: confirm.
    TEMPLATE = """如果了解该问题请按照以下格式回复
```json
{format}
```
否则请回复
```json
{unknown_format}
```
"""
    parser = JSONParser(
        template=TEMPLATE,
        default_format=DefaultFormat,
        unknown_format=UnknownFormat,
    )

    # Example reply that matches DefaultFormat.
    data = '''
{
"name": ["John Doe"],
"age": 30
}
'''
    print(parser.format())
    result = parser.parse_response(data)
    print(result)