File size: 3,095 Bytes
e679d69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import re
from typing import Any, Dict, Union

from pydantic import BaseModel, ValidationError

from lagent.prompts.parsers.str_parser import StrParser


class CustomFormatParser(StrParser):
    """Render pydantic models as tag templates and parse tagged responses.

    ``format_to_string`` turns a model into lines of the form
    ``<field type="T">default</field>`` (optionally preceded by an
    ``<!-- description -->`` comment). ``parse_response`` extracts such tags
    from text and validates them against a model.

    NOTE(review): ``default_format``, ``unknown_format`` and (optionally)
    ``fields`` are read from ``self`` but are not assigned in this class;
    presumably they are provided by ``StrParser`` or the caller — confirm.
    """

    def _extract_fields_with_metadata(
            self, model: BaseModel) -> Dict[str, Dict[str, Any]]:
        """Collect annotation, default and description for every model field.

        Args:
            model: a pydantic model class (or instance) exposing
                ``model_fields``.

        Returns:
            Mapping of field name to a dict with keys:
            ``'annotation'`` (the field's type annotation),
            ``'default'`` (``'<required>'`` for fields without a default,
            ``''`` for fields using a ``default_factory``, otherwise the
            default value) and ``'comment'`` (the description or ``''``).
        """
        fields_metadata = {}
        for field_name, field in model.model_fields.items():
            # In pydantic v2 a required field's ``default`` is the
            # ``PydanticUndefined`` sentinel, not ``None``, so ask the field
            # itself instead of comparing the default against ``None``.
            if field.is_required():
                default = '<required>'
            elif field.default_factory is not None:
                # Do not call the factory here; it may have side effects.
                default = ''
            else:
                default = field.default
            fields_metadata[field_name] = {
                'annotation': field.annotation,
                'default': default,
                'comment': field.description or '',
            }
        return fields_metadata

    def format_to_string(self, format_model: BaseModel) -> str:
        """Render ``format_model`` as one tag line per field.

        Each field becomes ``<name type="T">default</name>\\n``; required
        fields and fields defaulting to ``None`` get an empty body. A field
        with a description is preceded by ``<!-- description -->\\n``.

        Args:
            format_model: a pydantic model class (or instance).

        Returns:
            The concatenated template text.
        """
        fields = self._extract_fields_with_metadata(format_model)
        lines = []
        for field_name, metadata in fields.items():
            comment = metadata.get('comment', '')
            annotation = metadata['annotation']
            # Generic aliases (e.g. ``Optional[int]``) may not have
            # ``__name__`` on every Python version; fall back to their repr.
            if annotation is None:
                type_name = 'Any'
            else:
                type_name = getattr(annotation, '__name__', str(annotation))
            default = metadata['default']
            if default is None or (isinstance(default, str)
                                   and default == '<required>'):
                default_text = ''
            else:
                default_text = default
            if comment:
                lines.append(f'<!-- {comment} -->\n')
            lines.append(f'<{field_name} type="{type_name}">'
                         f'{default_text}</{field_name}>\n')
        return ''.join(lines)

    def parse_response(self, data: str) -> Union[dict, BaseModel]:
        """Parse tag-formatted text into a validated pydantic model.

        Every ``<tag ...>value</tag>`` pair in ``data`` becomes
        ``{tag: value}``; values are kept verbatim (whitespace included). An
        HTML comment directly preceding a tag is captured as that tag's
        comment.

        The collected values are validated against ``self.default_format``.
        If that fails and ``self.unknown_format`` is set, they are validated
        against ``self.unknown_format`` instead.

        Args:
            data: raw response text.

        Returns:
            An instance of whichever model was selected.

        Raises:
            pydantic.ValidationError: if validation against the selected model
                fails.
        """
        # The comment body may not contain ``-->``; otherwise a stray comment
        # could be merged with a later one that happens to precede a tag.
        pattern = re.compile(
            r'(<!--\s*((?:(?!-->).)*?)\s*-->)?\s*<(\w+)[^>]*>(.*?)</\3>',
            re.DOTALL)
        # ``fields`` is not defined in this class. Only update entries that
        # already exist rather than raising AttributeError/KeyError on any
        # commented response.
        registry = getattr(self, 'fields', None)

        data_dict = {}
        for _, comment_text, key, value in pattern.findall(data):
            if (comment_text and isinstance(registry, dict)
                    and isinstance(registry.get(key), dict)):
                registry[key]['comment'] = comment_text.strip()
            data_dict[key] = value

        if not self.unknown_format:
            return self.default_format.model_validate(data_dict)
        try:
            return self.default_format.model_validate(data_dict)
        except ValidationError:
            # Fall back to the alternative schema only when the primary one
            # rejects the data (validates the primary model just once).
            return self.unknown_format.model_validate(data_dict)

    def _is_valid_format(self, data: Dict, format_model: BaseModel) -> bool:
        """Return True if ``data`` validates against ``format_model``.

        Only ``pydantic.ValidationError`` is treated as "invalid"; any other
        exception propagates.
        """
        try:
            format_model.model_validate(data)
            return True
        except ValidationError:
            return False


if __name__ == '__main__':
    # Example usage: parse a tag-formatted response into ``DefaultFormat``,
    # falling back to ``UnknownFormat`` when it does not validate.
    class DefaultFormat(BaseModel):
        name: str
        age: int

    class UnknownFormat(BaseModel):
        title: str
        year: int

    # Prompt (Chinese): "If you understand the question, reply in the
    # following format ... otherwise reply with ...".
    # The placeholder names must match the keyword arguments passed to the
    # parser below (``default_format`` / ``unknown_format``).
    # NOTE(review): assumes StrParser fills template placeholders from those
    # keyword arguments — confirm.
    template = """如果了解该问题请按照以下格式回复
                    ```html
                    {default_format}
                    ```
                    否则请回复
                    ```html
                        {unknown_format}
                        ```
                        """
    parser = CustomFormatParser(
        template, default_format=DefaultFormat, unknown_format=UnknownFormat)

    # Example data
    response = '''
    <!-- User's full name -->
    <name type="str">John Doe</name>
    <!-- User's age -->
    <age type="int">30</age>
    '''

    result = parser.parse_response(response)
    print(result)