File size: 3,275 Bytes
09321b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os

import json
import requests
from ..agent_types import AgentType

from .base import LLM
from .utils import DEFAULT_MESSAGE


class CustomLLM(LLM):
    '''
        LLM wrapper for a service that provides LLM serving over HTTP.

        All connection details are read from environment variables:
        ``HTTP_LLM_TOKEN`` (required; sent as a Bearer token),
        ``HTTP_LLM_MODEL`` (sent as the ``model`` field) and
        ``HTTP_LLM_URL`` (the endpoint that requests are POSTed to).
        Users may override ``http_request`` or the result parsing in
        ``generate`` if their service uses a different format.
    '''
    name = 'custom_llm'

    def __init__(self, cfg):
        """Read connection settings from the environment.

        Args:
            cfg: configuration forwarded to the base ``LLM``. Its
                ``agent_type`` entry (default ``AgentType.DEFAULT``)
                selects the prompt/response format used by ``generate``.

        Raises:
            ValueError: if ``HTTP_LLM_TOKEN`` is not set.
        """
        super().__init__(cfg)
        self.token = os.getenv('HTTP_LLM_TOKEN', None)
        self.model = os.getenv('HTTP_LLM_MODEL', None)
        self.model_id = self.model
        self.url = os.getenv('HTTP_LLM_URL', None)

        if self.token is None:
            raise ValueError('HTTP_LLM_TOKEN is not set')
        self.agent_type = self.cfg.get('agent_type', AgentType.DEFAULT)

    def http_request(self, data):
        """POST ``data`` as JSON to ``self.url`` and return the decoded body.

        Args:
            data: JSON-serialisable request payload.

        Returns:
            The response body parsed with ``json.loads``.
        """
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.token}'
        }
        # NOTE(review): no timeout is passed, so a hung server blocks this
        # call indefinitely and the retry loop in ``generate`` never runs.
        response = requests.post(self.url, json=data, headers=headers)
        return json.loads(response.content)

    def generate(self,
                 llm_artifacts,
                 functions=None,
                 function_call='none',
                 **kwargs):
        """Send a prompt to the HTTP service and return the model output.

        Args:
            llm_artifacts: for ``AgentType.Messages`` a list of message
                dicts (an empty list is replaced by ``DEFAULT_MESSAGE``).
                For every other agent type it is used as the ``content``
                of a single user message.
            functions: optional list of function schemas. When non-empty
                they are sent together with ``function_call='auto'``.
            function_call: only relevant when ``functions`` is non-empty,
                and in that case it is overridden with ``'auto'``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            - ``AgentType.MS_AGENT``: the content string, truncated just
              after the first ``<|endofthink|>`` marker if present.
            - ``AgentType.Messages``: a dict with ``content`` and ``role``,
              plus ``function_call`` when the service returned a non-empty one.
            - otherwise: the content string.
            If every attempt gets a non-200 ``code``, an empty content is
            returned.

        Raises:
            AssertionError: if ``functions`` is not a list.
            Exception: whatever ``http_request`` raised on the final attempt.
        """
        # A None sentinel avoids sharing one mutable default list across calls.
        if functions is None:
            functions = []

        if self.agent_type != AgentType.Messages:
            messages = [{'role': 'user', 'content': llm_artifacts}]
        else:
            messages = llm_artifacts if len(
                llm_artifacts) > 0 else DEFAULT_MESSAGE

        data = {'model': self.model, 'messages': messages, 'n': 1}

        assert isinstance(functions, list)
        if len(functions) > 0:
            # NOTE(review): this overrides any function_call the caller passed.
            function_call = 'auto'
            data['functions'] = functions
            data['function_call'] = function_call

        # Up to ``max_retries + 1`` attempts in total. Transport errors and
        # non-200 codes both consume an attempt.
        retry_count = 0
        max_retries = 3
        message = {'content': ''}
        while retry_count <= max_retries:
            try:
                response = self.http_request(data)
            except Exception as e:
                retry_count += 1
                if retry_count > max_retries:
                    import traceback
                    traceback.print_exc()
                    print(f'input: {messages}, original error: {str(e)}')
                    raise
                # Retry immediately; falling through would read an unbound
                # (first attempt) or stale (later attempts) ``response``.
                continue

            # Expected service schema: {'code': 200,
            #                           'data': {'response': [message, ...]}}
            if response['code'] == 200:
                message = response['data']['response'][0]
                break
            retry_count += 1
            if retry_count > max_retries:
                print('maximum retry reached, return default message')

        # A reply that carries only a function_call may omit 'content'.
        content = message.get('content', '')

        if self.agent_type == AgentType.MS_AGENT:
            # Keep everything up to and including the end-of-think marker.
            idx = content.find('<|endofthink|>')
            if idx != -1:
                content = content[:idx + len('<|endofthink|>')]
            return content
        elif self.agent_type == AgentType.Messages:
            new_message = {
                'content': content,
                'role': message.get('response_role', 'assistant')
            }
            if 'function_call' in message and message['function_call'] != {}:
                new_message['function_call'] = message.get('function_call')
            return new_message
        else:
            return content