import warnings
from typing import Dict, List, Optional, Type, Union

from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM


class APITemplateParser:
    """Intermidate prompt template parser, specifically for API models.

    Args:
        meta_template (Dict): The meta template for the model.
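
    Example:

        .. code-block:: python

            # A minimal illustrative template: the role names below are
            # assumptions, but the keys ('role', 'api_role', 'begin',
            # 'end', 'fallback_role') are the ones this parser reads.
            meta_template = [
                dict(role='system', api_role='system'),
                dict(role='user', api_role='user'),
                dict(role='assistant', api_role='assistant'),
            ]
            parser = APITemplateParser(meta_template)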
    """

    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        # Check meta template
        if meta_template:
            assert isinstance(meta_template, list)
            self.roles: Dict[str, dict] = dict()  # maps role name to config
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, \
                    'role in meta prompt must be unique!'
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog: Union[str, List]) -> Union[str, List]:
        """Parse the intermediate prompt template, and wrap it with the
        meta template if applicable. When the meta template is set and the
        input is a list, the return value will be a list containing the
        full conversation history. Each item looks like:

        .. code-block:: python

            {'role': 'user', 'content': '...'}

        Args:
            dialog (str or List): An intermediate prompt template
                (potentially before being wrapped by the meta template).

        Returns:
            str or List[dict]: The finalized prompt string, or the
            conversation as a list of API-style role messages.
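
        Example:

            .. code-block:: python

                # Illustrative: consecutive messages from the same role
                # are merged with a newline, per the merge step below.
                parser = APITemplateParser([
                    dict(role='user', api_role='user'),
                    dict(role='assistant', api_role='assistant'),
                ])
                parser([dict(role='user', content='hi'),
                        dict(role='user', content='there')])
                # -> [{'role': 'user', 'content': 'hi\nthere'}]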
        """
        assert isinstance(dialog, (str, list))
        if isinstance(dialog, str):
            return dialog
        if self.meta_template:
            prompt = []
            for item in dialog:
                if isinstance(item, str):
                    if item.strip():
                        # TODO: logger
                        warnings.warn('Non-empty string in prompt template '
                                      'will be ignored in API models.')
                else:
                    api_prompts = self._prompt2api(item)
                    prompt.append(api_prompts)

            # Merge consecutive prompts assigned to the same role.
            if not prompt:
                return prompt
            new_prompt = [prompt[0]]
            last_role = prompt[0]['role']
            for item in prompt[1:]:
                if item['role'] == last_role:
                    new_prompt[-1]['content'] += '\n' + item['content']
                else:
                    last_role = item['role']
                    new_prompt.append(item)
            prompt = new_prompt

        else:
            # in case the model does not have any meta template
            prompt = ''
            last_sep = ''
            for item in dialog:
                if isinstance(item, str):
                    if item:
                        prompt += last_sep + item
                elif item.get('content', ''):
                    prompt += last_sep + item.get('content', '')
                last_sep = '\n'
        return prompt

    def _prompt2api(
            self, prompts: Union[str, Dict, List]) -> Union[str, Dict, List]:
        """Convert prompts to API-style prompts using the configured role
        mapping.

        Args:
            prompts (Union[str, Dict, List]): The prompts to be converted.

        Returns:
            Union[str, Dict, List]: The converted prompt(s), with each role
            prompt mapped to its API-style counterpart.
        """
        if isinstance(prompts, str):
            return prompts
        elif isinstance(prompts, dict):
            api_role = self._role2api_role(prompts)
            return api_role

        res = []
        for prompt in prompts:
            if isinstance(prompt, str):
                raise TypeError('Mixing str without explicit role is not '
                                'allowed in API models!')
            else:
                api_role = self._role2api_role(prompt)
                res.append(api_role)
        return res

    def _role2api_role(self, role_prompt: Dict) -> Dict:
        """Map a single role prompt to its API-style counterpart, applying
        the role's ``begin``/``end`` affixes and its ``fallback_role`` if
        one is configured."""
        merged_prompt = self.roles[role_prompt['role']]
        if merged_prompt.get('fallback_role'):
            merged_prompt = self.roles[merged_prompt['fallback_role']]
        res = role_prompt.copy()
        res['role'] = merged_prompt['api_role']
        res['content'] = merged_prompt.get('begin', '')
        res['content'] += role_prompt.get('content', '')
        res['content'] += merged_prompt.get('end', '')
        return res
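
    # Illustrative mapping (assumed role config): with
    #   roles = {'tool': dict(role='tool', api_role='user',
    #                         begin='<tool>', end='</tool>')}
    # a call like ``_role2api_role({'role': 'tool', 'content': 'result'})``
    # yields ``{'role': 'user', 'content': '<tool>result</tool>'}``.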


class BaseAPILLM(BaseLLM):
    """Base class for API model wrapper.

    Args:
        model_type (str): The type of model.
        retry (int): Number of retires if the API call fails. Defaults to 2.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
    """

    is_api: bool = True

    def __init__(self,
                 model_type: str,
                 retry: int = 2,
                 template_parser: Type[APITemplateParser] = APITemplateParser,
                 meta_template: Optional[List[Dict]] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: int = 40,
                 temperature: float = 0.8,
                 repetition_penalty: float = 0.0,
                 stop_words: Optional[Union[List[str], str]] = None):
        self.model_type = model_type
        self.meta_template = meta_template
        self.retry = retry
        if template_parser:
            self.template_parser = template_parser(meta_template)

        if isinstance(stop_words, str):
            stop_words = [stop_words]
        self.gen_params = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words,
            skip_special_tokens=False)


class AsyncBaseAPILLM(AsyncLLMMixin, BaseAPILLM):
    """Asynchronous counterpart of :class:`BaseAPILLM`."""
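

# Illustrative sketch (assumption, not part of this module): a minimal,
# hypothetical subclass showing how ``BaseAPILLM`` is typically
# specialized. The class name and ``generate`` body are invented for
# demonstration; a real wrapper performs the actual API request, retrying
# up to ``self.retry`` times on failure.
#
# class EchoAPILLM(BaseAPILLM):
#
#     def generate(self, inputs, **gen_params):
#         # Wrap the dialog with the meta template before calling the
#         # backend; ``self.template_parser`` was built in ``__init__``.
#         prompt = self.template_parser(inputs)
#         # A real implementation would send ``prompt`` to its endpoint
#         # here and return the completion text.
#         return f'echo: {prompt}'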