import os
import random
import traceback
from http import HTTPStatus
from typing import Union

import dashscope
from dashscope import Generation

from ..agent_types import AgentType
from .base import LLM
from .utils import DEFAULT_MESSAGE, CustomOutputWrapper

dashscope.api_key = os.getenv('DASHSCOPE_API_KEY')
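# Note: the key is read from the environment at import time, so
# dashscope.api_key will be None if DASHSCOPE_API_KEY is not set before this
# module is imported.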


class DashScopeLLM(LLM):
    """LLM backend that calls the DashScope Generation API."""

    name = 'dashscope_llm'

    def __init__(self, cfg):
        super().__init__(cfg)
        self.model = self.cfg.get('model', 'modelscope-agent-llm-v1')
        self.model_id = self.model
        self.generate_cfg = self.cfg.get('generate_cfg', {})
        self.agent_type = self.cfg.get('agent_type', AgentType.DEFAULT)
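    # Illustrative cfg shape, pieced together from the reads in __init__ above;
    # the sampling values are placeholders, not library defaults:
    #
    #   cfg = {
    #       'model': 'modelscope-agent-llm-v1',
    #       'generate_cfg': {'temperature': 0.3, 'top_p': 0.8},
    #       'agent_type': AgentType.Messages,
    #   }
    #
    # Everything in 'generate_cfg' is forwarded verbatim to Generation.call().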

    def generate(self,
                 llm_artifacts: Union[str, dict],
                 functions=[],
                 **kwargs):
        """Run a single, non-streaming completion call against DashScope."""
        # TODO: add retries and message handling
        try:
            if self.agent_type == AgentType.Messages:
                messages = llm_artifacts if len(
                    llm_artifacts) > 0 else DEFAULT_MESSAGE
                self.generate_cfg['use_raw_prompt'] = False
                response = dashscope.Generation.call(
                    model=self.model,
                    messages=messages,
                    # optional random seed; defaults to 1234 if not set
                    seed=random.randint(1, 10000),
                    result_format='message',  # return a message-format result
                    stream=False,
                    **self.generate_cfg)
                llm_result = CustomOutputWrapper.handle_message_chat_completion(
                    response)
            else:
                response = Generation.call(
                    model=self.model,
                    prompt=llm_artifacts,
                    stream=False,
                    **self.generate_cfg)
                llm_result = CustomOutputWrapper.handle_message_text_completion(
                    response)
        except Exception as e:
            error = traceback.format_exc()
            error_msg = f'LLM error with input {llm_artifacts} \n dashscope error: {str(e)} with traceback {error}'
            print(error_msg)
            raise RuntimeError(error)

        # Post-process the result according to the agent type.
        if self.agent_type == AgentType.MS_AGENT:
            # plain text: keep everything up to and including the first
            # <|endofthink|> marker
            idx = llm_result.find('<|endofthink|>')
            if idx != -1:
                llm_result = llm_result[:idx + len('<|endofthink|>')]
            return llm_result
        elif self.agent_type == AgentType.Messages:
            # message dict
            return llm_result
        else:
            # plain text
            return llm_result

    def stream_generate(self,
                        llm_artifacts: Union[str, dict],
                        functions=[],
                        **kwargs):
        """Yield the completion incrementally as text chunks via DashScope streaming."""
        total_response = ''
        try:
            if self.agent_type == AgentType.Messages:
                self.generate_cfg['use_raw_prompt'] = False
                responses = Generation.call(
                    model=self.model,
                    messages=llm_artifacts,
                    stream=True,
                    result_format='message',
                    **self.generate_cfg)
            else:
                responses = Generation.call(
                    model=self.model,
                    prompt=llm_artifacts,
                    stream=True,
                    **self.generate_cfg)
        except Exception as e:
            error = traceback.format_exc()
            error_msg = f'LLM error with input {llm_artifacts} \n dashscope error: {str(e)} with traceback {error}'
            print(error_msg)
            raise RuntimeError(error)
        # Each streamed response carries the full text generated so far, so
        # yield only the newly appended suffix on every frame.
        for response in responses:
            if response.status_code == HTTPStatus.OK:
                if self.agent_type == AgentType.Messages:
                    llm_result = CustomOutputWrapper.handle_message_chat_completion(
                        response)
                    frame_text = llm_result['content'][len(total_response):]
                else:
                    llm_result = CustomOutputWrapper.handle_message_text_completion(
                        response)
                    frame_text = llm_result[len(total_response):]
                yield frame_text

                if self.agent_type == AgentType.Messages:
                    total_response = llm_result['content']
                else:
                    total_response = llm_result
            else:
                err_msg = 'Error Request id: %s, Code: %d, status: %s, message: %s' % (
                    response.request_id, response.status_code, response.code,
                    response.message)
                print(err_msg)
                raise RuntimeError(err_msg)
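

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API):
# assumes DASHSCOPE_API_KEY is exported and that the LLM base class stores
# `cfg` as a plain dict, as the `self.cfg.get(...)` calls above imply. The
# model id and messages below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    cfg = {
        'model': 'modelscope-agent-llm-v1',  # placeholder DashScope model id
        'generate_cfg': {},
        'agent_type': AgentType.Messages,
    }
    llm = DashScopeLLM(cfg)
    messages = [{'role': 'user', 'content': 'Hello, what can you do?'}]

    # One-shot call: returns a message dict in Messages mode.
    print(llm.generate(messages))

    # Streaming call: frames are the newly generated text suffixes.
    for frame in llm.stream_generate(messages):
        print(frame, end='', flush=True)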