import os
import random
import traceback
from http import HTTPStatus
from typing import Union

import dashscope
from dashscope import Generation

from ..agent_types import AgentType
from .base import LLM
from .utils import DEFAULT_MESSAGE, CustomOutputWrapper

dashscope.api_key = os.getenv('DASHSCOPE_API_KEY')


class DashScopeLLM(LLM):
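    """LLM wrapper around the DashScope ``Generation`` API.

    ``model``, ``generate_cfg`` and ``agent_type`` are read from the config;
    ``AgentType.Messages`` uses the chat-messages endpoint, other agent types
    use the plain-prompt endpoint.
    """
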
    name = 'dashscope_llm'

    def __init__(self, cfg):
        super().__init__(cfg)
        self.model = self.cfg.get('model', 'modelscope-agent-llm-v1')
        self.model_id = self.model
        self.generate_cfg = self.cfg.get('generate_cfg', {})
        self.agent_type = self.cfg.get('agent_type', AgentType.DEFAULT)

    def generate(self,
                 llm_artifacts: Union[str, dict],
                 functions=[],
                 **kwargs):
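        """Run a single blocking completion.

        ``llm_artifacts`` holds the chat messages when ``agent_type`` is
        ``AgentType.Messages`` (the result is a message dict) and a prompt
        string otherwise (the result is plain text).
        """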

        # TODO retry and handle message
        try:
            if self.agent_type == AgentType.Messages:
                messages = llm_artifacts if len(
                    llm_artifacts) > 0 else DEFAULT_MESSAGE
                self.generate_cfg['use_raw_prompt'] = False
                response = dashscope.Generation.call(
                    model=self.model,
                    messages=messages,
                    # set the random seed, optional, default to 1234 if not set
                    seed=random.randint(1, 10000),
                    result_format='message',  # return results in "message" format
                    stream=False,
                    **self.generate_cfg)
                llm_result = CustomOutputWrapper.handle_message_chat_completion(
                    response)
            else:
                response = Generation.call(
                    model=self.model,
                    prompt=llm_artifacts,
                    stream=False,
                    **self.generate_cfg)
                llm_result = CustomOutputWrapper.handle_message_text_completion(
                    response)
        except Exception as e:
            error = traceback.format_exc()
            error_msg = f'LLM error with input {llm_artifacts} \n dashscope error: {str(e)} with traceback {error}'
            print(error_msg)
            raise RuntimeError(error_msg) from e

        if self.agent_type == AgentType.MS_AGENT:
            # in the form of text: keep everything up to the first <|endofthink|>
            idx = llm_result.find('<|endofthink|>')
            if idx != -1:
                llm_result = llm_result[:idx + len('<|endofthink|>')]
        # AgentType.Messages results stay as message dicts; other agent types
        # return plain text unchanged
        return llm_result

    def stream_generate(self,
                        llm_artifacts: Union[str, dict],
                        functions=[],
                        **kwargs):
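        """Run a streaming completion, yielding the newly generated text of
        each frame."""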
        total_response = ''
        try:
            if self.agent_type == AgentType.Messages:
                self.generate_cfg['use_raw_prompt'] = False
                responses = Generation.call(
                    model=self.model,
                    messages=llm_artifacts,
                    stream=True,
                    result_format='message',
                    **self.generate_cfg)
            else:
                responses = Generation.call(
                    model=self.model,
                    prompt=llm_artifacts,
                    stream=True,
                    **self.generate_cfg)
        except Exception as e:
            error = traceback.format_exc()
            error_msg = f'LLM error with input {llm_artifacts} \n dashscope error: {str(e)} with traceback {error}'
            print(error_msg)
            raise RuntimeError(error_msg) from e

        for response in responses:
            if response.status_code == HTTPStatus.OK:
                # dashscope streams the cumulative output so far, so only the
                # newly generated suffix is yielded as this frame
                if self.agent_type == AgentType.Messages:
                    llm_result = CustomOutputWrapper.handle_message_chat_completion(
                        response)
                    new_total = llm_result['content']
                else:
                    llm_result = CustomOutputWrapper.handle_message_text_completion(
                        response)
                    new_total = llm_result
                yield new_total[len(total_response):]
                total_response = new_total
            else:
                err_msg = (f'Error Request id: {response.request_id}, '
                           f'Code: {response.status_code}, '
                           f'status: {response.code}, '
                           f'message: {response.message}')
                print(err_msg)
                raise RuntimeError(err_msg)
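

if __name__ == '__main__':
    # Minimal usage sketch. Assumptions (not guaranteed by this module):
    # DASHSCOPE_API_KEY is exported, the base ``LLM`` class accepts a plain
    # dict config, and 'qwen-turbo' is only an illustrative model name.
    llm = DashScopeLLM({
        'model': 'qwen-turbo',
        'generate_cfg': {},
        'agent_type': AgentType.Messages,
    })
    history = [{'role': 'user', 'content': 'Hello!'}]
    print(llm.generate(history))
    for frame in llm.stream_generate(history):
        print(frame, end='', flush=True)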