Spaces:
Running
Running
File size: 3,095 Bytes
e679d69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import re
from typing import Any, Dict, Union
from pydantic import BaseModel, ValidationError
from lagent.prompts.parsers.str_parser import StrParser
class CustomFormatParser(StrParser):
    """Render pydantic models as tag templates and parse tagged replies.

    A model is rendered as one ``<field type="...">default</field>`` line per
    field, optionally preceded by an ``<!-- description -->`` comment line.
    Replies written in that same tag syntax are parsed back into a dict and
    validated against a pydantic model.

    NOTE(review): ``self.fields``, ``self.default_format`` and
    ``self.unknown_format`` are read by ``parse_response`` but are not
    assigned anywhere in this class; presumably they are provided by
    ``StrParser`` or set by the caller -- confirm.
    """

    # Optional ``<!-- comment -->`` followed by ``<tag ...>value</tag>``.
    # Groups: (whole comment, comment text, tag name, tag body). DOTALL lets
    # comments and tag bodies span several lines.
    _FIELD_PATTERN = re.compile(
        r'(<!--\s*(.*?)\s*-->)?\s*<(\w+)[^>]*>(.*?)</\3>', re.DOTALL)

    def _extract_fields_with_metadata(
            self, model: BaseModel) -> Dict[str, Dict[str, Any]]:
        """Collect annotation, default and description for each model field.

        Args:
            model: A pydantic model (its ``model_fields`` mapping is read).

        Returns:
            Mapping of field name to a dict with the keys ``'annotation'``,
            ``'default'`` and ``'comment'``. ``'default'`` is the string
            ``'<required>'`` when the field has no usable static default.
        """
        fields_metadata = {}
        for field_name, field in model.model_fields.items():
            # Required fields carry pydantic's "undefined" sentinel rather
            # than None, so ask the field itself instead of comparing to
            # None. Fields that only have a default_factory have no static
            # default to display either.
            lacks_static_default = (
                field.is_required() or field.default_factory is not None
                or field.default is None)
            fields_metadata[field_name] = {
                'annotation': field.annotation,
                'default':
                '<required>' if lacks_static_default else field.default,
                'comment': field.description or '',
            }
        return fields_metadata

    def format_to_string(self, format_model: BaseModel) -> str:
        """Render ``format_model`` as a tag template.

        Args:
            format_model: A pydantic model describing the expected fields.

        Returns:
            One ``<name type="T">default</name>`` line per field, each
            terminated by a newline and preceded by an ``<!-- ... -->`` line
            when the field has a description. Fields without a static
            default are rendered with an empty body.
        """
        fields = self._extract_fields_with_metadata(format_model)
        lines = []
        for field_name, metadata in fields.items():
            comment = metadata.get('comment', '')
            annotation = metadata['annotation']
            if annotation is None:
                type_name = 'Any'
            else:
                # Plain classes expose __name__; typing constructs may not,
                # so fall back to their string form instead of raising.
                type_name = (
                    getattr(annotation, '__name__', None) or str(annotation))
            default = metadata['default']
            body = '' if default == '<required>' else default
            if comment:
                lines.append(f'<!-- {comment} -->\n')
            lines.append(
                f'<{field_name} type="{type_name}">{body}</{field_name}>\n')
        return ''.join(lines)

    def parse_response(self, data: str) -> Union[dict, BaseModel]:
        """Parse a tagged reply and validate it against a format model.

        Every ``<tag ...>value</tag>`` pair found in ``data`` becomes a
        ``tag -> value`` entry (values are kept as raw, unstripped strings;
        a repeated tag overwrites the earlier value). The entries are
        validated with ``self.default_format`` unless ``self.unknown_format``
        is truthy and the default format rejects them, in which case
        ``self.unknown_format`` is used.

        Args:
            data: Raw reply text containing the tagged fields.

        Returns:
            The result of ``model_validate`` on the selected model.

        Raises:
            ValidationError: If the selected model rejects the parsed data.
        """
        # Comment capture is auxiliary metadata: a tag that is not a known
        # field (or a missing ``fields`` attribute) must not abort parsing
        # before validation gets a chance to run.
        known_fields = getattr(self, 'fields', None) or {}
        data_dict = {}
        for _, comment_text, key, value in self._FIELD_PATTERN.findall(data):
            if comment_text and key in known_fields:
                known_fields[key]['comment'] = comment_text.strip()
            data_dict[key] = value
        model = self.default_format
        if self.unknown_format and not self._is_valid_format(
                data_dict, self.default_format):
            model = self.unknown_format
        return model.model_validate(data_dict)

    def _is_valid_format(self, data: Dict, format_model: BaseModel) -> bool:
        """Return True when ``format_model`` validates ``data`` without error.

        Args:
            data: Parsed ``tag -> value`` mapping.
            format_model: A pydantic model to validate against.

        Returns:
            False if validation raises ``ValidationError``, True otherwise.
        """
        try:
            format_model.model_validate(data)
        except ValidationError:
            return False
        return True
if __name__ == '__main__':
    # Demo: build a parser from two pydantic formats and parse a sample reply.

    class DefaultFormat(BaseModel):
        # Format used when the reply validates against it.
        name: str
        age: int

    class UnknownFormat(BaseModel):
        # Fallback format tried when DefaultFormat rejects the reply.
        title: str
        year: int

    # Prompt template (Chinese): reply in the first format if the question
    # is understood, otherwise reply in the second format.
    # NOTE(review): the placeholders are {format} and {unknown_format} while
    # the keyword arguments below are default_format and unknown_format; how
    # StrParser maps them is not visible here -- confirm.
    prompt_template = """如果了解该问题请按照一下格式回复
```html
{format}
```
否则请回复
```html
{unknown_format}
```
"""
    demo_parser = CustomFormatParser(
        prompt_template,
        default_format=DefaultFormat,
        unknown_format=UnknownFormat)

    # Sample reply written in the tag syntax the parser understands.
    sample_reply = '''
<!-- User's full name -->
<name type="str">John Doe</name>
<!-- User's age -->
<age type="int">30</age>
'''
    parsed_model = demo_parser.parse_response(sample_reply)
    print(parsed_model)
|