import os
import random
import traceback
from http import HTTPStatus
from typing import Union

import dashscope
from dashscope import Generation

from ..agent_types import AgentType
from .base import LLM
from .utils import DEFAULT_MESSAGE, CustomOutputWrapper

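# Read the API key from the environment at import time; Generation.call
# will fail to authenticate if DASHSCOPE_API_KEY is not set.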
dashscope.api_key = os.getenv('DASHSCOPE_API_KEY')


class DashScopeLLM(LLM):
    name = 'dashscope_llm'

    def __init__(self, cfg):
        super().__init__(cfg)
        self.model = self.cfg.get('model', 'modelscope-agent-llm-v1')
        self.model_id = self.model
        self.generate_cfg = self.cfg.get('generate_cfg', {})
        self.agent_type = self.cfg.get('agent_type', AgentType.DEFAULT)

    def generate(self,
                 llm_artifacts: Union[str, dict],
                 functions=[],
                 **kwargs):
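        """Run one non-streaming completion against DashScope.

        llm_artifacts is a prompt string, or a list of message dicts when
        self.agent_type is AgentType.Messages. functions is accepted for
        interface compatibility and is not forwarded to the API here.
        """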
        # TODO: add retries and message handling
        try:
            if self.agent_type == AgentType.Messages:
                messages = llm_artifacts if len(
                    llm_artifacts) > 0 else DEFAULT_MESSAGE
                self.generate_cfg['use_raw_prompt'] = False
                response = dashscope.Generation.call(
                    model=self.model,
                    messages=messages,
                    # set the random seed; optional, defaults to 1234 if unset
                    seed=random.randint(1, 10000),
                    result_format='message',  # return "message"-format results
                    stream=False,
                    **self.generate_cfg)
                llm_result = CustomOutputWrapper.handle_message_chat_completion(
                    response)
            else:
                response = Generation.call(
                    model=self.model,
                    prompt=llm_artifacts,
                    stream=False,
                    **self.generate_cfg)
                llm_result = CustomOutputWrapper.handle_message_text_completion(
                    response)
        except Exception as e:
            error = traceback.format_exc()
            error_msg = (f'LLM error with input {llm_artifacts} \n'
                         f'dashscope error: {str(e)} with traceback {error}')
            print(error_msg)
            raise RuntimeError(error)

        if self.agent_type == AgentType.MS_AGENT:
            # text output: truncate everything after the <|endofthink|> marker
            idx = llm_result.find('<|endofthink|>')
            if idx != -1:
                llm_result = llm_result[:idx + len('<|endofthink|>')]
            return llm_result
        elif self.agent_type == AgentType.Messages:
            # message-dict output
            return llm_result
        else:
            # plain-text output
            return llm_result

    def stream_generate(self,
                        llm_artifacts: Union[str, dict],
                        functions=[],
                        **kwargs):
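        """Stream a completion, yielding only the newly generated text.

        DashScope streams cumulative output, so each yielded chunk is the
        delta between the latest response and what was already emitted.
        """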
        total_response = ''
        try:
            if self.agent_type == AgentType.Messages:
                self.generate_cfg['use_raw_prompt'] = False
                responses = Generation.call(
                    model=self.model,
                    messages=llm_artifacts,
                    stream=True,
                    result_format='message',
                    **self.generate_cfg)
            else:
                responses = Generation.call(
                    model=self.model,
                    prompt=llm_artifacts,
                    stream=True,
                    **self.generate_cfg)
        except Exception as e:
            error = traceback.format_exc()
            error_msg = (f'LLM error with input {llm_artifacts} \n'
                         f'dashscope error: {str(e)} with traceback {error}')
            print(error_msg)
            raise RuntimeError(error)

        for response in responses:
            if response.status_code == HTTPStatus.OK:
                if self.agent_type == AgentType.Messages:
                    llm_result = CustomOutputWrapper.handle_message_chat_completion(
                        response)
                    # responses are cumulative: keep only the new suffix
                    frame_text = llm_result['content'][len(total_response):]
                else:
                    llm_result = CustomOutputWrapper.handle_message_text_completion(
                        response)
                    frame_text = llm_result[len(total_response):]
                yield frame_text

                if self.agent_type == AgentType.Messages:
                    total_response = llm_result['content']
                else:
                    total_response = llm_result
            else:
                err_msg = 'Error Request id: %s, Code: %d, status: %s, message: %s' % (
                    response.request_id, response.status_code, response.code,
                    response.message)
                print(err_msg)
                raise RuntimeError(err_msg)
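
# Minimal usage sketch (illustrative; this module uses relative imports, so
# construct the class through the surrounding package). The 'qwen-turbo'
# model name and the cfg shape are assumptions, and DASHSCOPE_API_KEY must
# be set in the environment:
#
#     llm = DashScopeLLM({'model': 'qwen-turbo'})
#     text = llm.generate('Say hello in one sentence.')
#     for chunk in llm.stream_generate('Count from 1 to 5.'):
#         print(chunk, end='', flush=True)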