# jianuo's picture
# first
# 09321b6
import os
import time
import json
import pandas as pd
import requests
from ..tools.tool import Tool, ToolSchema
from pydantic import ValidationError
from requests.exceptions import RequestException, Timeout
# Maximum number of attempts made for each remote HTTP request. Only attempts
# that end in `requests.exceptions.Timeout` are retried; any other request
# error aborts immediately.
MAX_RETRY_TIMES = 3
class WordArtTexture(Tool):
    """Tool that turns text into a textured WordArt image via DashScope.

    Calling the tool POSTs an asynchronous task (``X-DashScope-Async: enable``)
    to the DashScope wordart texture endpoint and then polls the DashScope
    task endpoint once per second until the task reaches a terminal state.

    Configuration is read from ``cfg[self.name]``:

    * ``token``: DashScope API key; falls back to the ``DASHSCOPE_API_KEY``
      environment variable.
    * ``timeout``: optional timeout in seconds passed to every HTTP request.
      Defaults to ``None`` (no timeout), which matches the previous behaviour.
    """

    description = '生成艺术字纹理图片'
    name = 'wordart_texture_generation'
    # Dotted parameter names are expanded into nested dicts by
    # `_remote_parse_input` before being sent to the service.
    parameters: list = [{
        'name': 'input.text.text_content',
        'description': 'text that the user wants to convert to WordArt',
        'required': True
    }, {
        'name': 'input.prompt',
        'description':
        'Users’ style requirements for word art may be requirements in terms of shape, color, entity, etc.',
        'required': True
    }, {
        'name': 'input.texture_style',
        'description':
        'Type of texture style;Default is "material";If not provided by the user, '
        'defaults to "material".Another value is scene.',
        'required': True
    }, {
        'name': 'input.text.output_image_ratio',
        'description':
        'The aspect ratio of the text input image; the default is "1:1", '
        'the available ratios are: "1:1", "16:9", "9:16";',
        'required': True
    }]

    def __init__(self, cfg=None):
        """Read the tool configuration and build the tool schema.

        Args:
            cfg: Mapping of tool name to per-tool settings. ``None`` is
                treated as an empty mapping.

        Raises:
            AssertionError: If no DashScope token is configured.
            ValueError: If the tool parameters fail schema validation.
        """
        # `None` sentinel instead of a shared mutable `{}` default.
        cfg = {} if cfg is None else cfg
        self.cfg = cfg.get(self.name, {})

        # remote call
        self.url = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/wordart/texture'
        self.token = self.cfg.get('token',
                                  os.environ.get('DASHSCOPE_API_KEY', ''))
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`; the exception type is kept for existing callers.
        if self.token == '':
            raise AssertionError(
                'dashscope api token must be acquired with wordart')
        # Optional per-request timeout in seconds; None keeps the previous
        # behaviour of waiting indefinitely.
        self.timeout = self.cfg.get('timeout')

        all_param = {
            'name': self.name,
            'description': self.description,
            'parameters': self.parameters
        }
        try:
            self.tool_schema = ToolSchema(**all_param)
        except ValidationError as e:
            raise ValueError(
                f'Error when parsing parameters of {self.name}') from e

        self._str = self.tool_schema.model_dump_json()
        # NOTE(review): `parse_pydantic_model_to_openai_function` is not
        # defined in this class; presumably inherited from `Tool`.
        self._function = self.parse_pydantic_model_to_openai_function(
            all_param)

    def __call__(self, *args, **kwargs):
        """Submit a WordArt texture task and wait for its result.

        Keyword arguments use the dotted names listed in ``parameters``.

        Returns:
            Whatever `get_wordart_result` returns: the parsed result of the
            finished task, or ``None`` if the task failed.

        Raises:
            ValueError: If submitting the task fails or every attempt
                times out.
        """
        remote_parsed_input = json.dumps(
            self._remote_parse_input(*args, **kwargs))
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.token}',
            'X-DashScope-Async': 'enable'
        }
        origin_result = self._request_with_retry(
            'POST', self.url, headers=headers, data=remote_parsed_input)
        # NOTE(review): `_parse_output` is not defined in this class;
        # presumably inherited from `Tool`. The code below expects it to
        # return a mapping with a 'result' key.
        self.final_result = self._parse_output(origin_result, remote=True)
        return self.get_wordart_result()

    def _request_with_retry(self, method, url, **request_kwargs):
        """Send one HTTP request, retrying on timeout, and return parsed JSON.

        Up to ``MAX_RETRY_TIMES`` attempts are made; only `Timeout` triggers
        another attempt.

        Raises:
            ValueError: On any other request error (including 4xx/5xx
                responses) or when every attempt timed out.
        """
        for _ in range(MAX_RETRY_TIMES):
            try:
                response = requests.request(
                    method, url=url, timeout=self.timeout, **request_kwargs)
                # Raises HTTPError for 4xx/5xx and is a no-op otherwise.
                response.raise_for_status()
            except Timeout:
                continue
            except RequestException as e:
                raise ValueError(self._format_request_error(e)) from e
            return json.loads(response.content.decode('utf-8'))
        raise ValueError(
            'Remote call max retry times exceeded! Please try to use local call.'
        )

    @staticmethod
    def _format_request_error(error):
        """Build the error message for a failed request."""
        response = error.response
        # Compare with None explicitly: a `requests.Response` is falsy for
        # 4xx/5xx statuses, and connection-level errors carry no response.
        if response is None:
            return f'Remote call failed with error: {error}'
        body = response.content.decode('utf-8', errors='replace')
        return (f'Remote call failed with error code: {response.status_code}, '
                f'error message: {body}')

    def _remote_parse_input(self, *args, **kwargs):
        """Expand dotted keyword names into the nested request payload.

        For example ``{'input.text.text_content': 'hi'}`` becomes
        ``{'input': {'text': {'text_content': 'hi'}}}``. The ``model`` key is
        always set to ``'wordart-texture'``. Positional arguments are ignored.
        """
        restored_dict = {}
        for key, value in kwargs.items():
            # A key without "." has no parents and is stored at the top level.
            *parents, leaf = key.split('.')
            node = restored_dict
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = value
        restored_dict['model'] = 'wordart-texture'
        print('传给tool的参数:', restored_dict)
        return restored_dict

    def get_result(self):
        """Query the current state of the task stored in ``self.final_result``.

        Returns:
            The parsed task-query response, or ``None`` when the stored
            submission result contains no ``task_id``.

        Raises:
            ValueError: If the query fails or every attempt times out.
        """
        output = self.final_result['result']['output']
        if 'task_id' not in output:
            return None
        task_id = output['task_id']
        get_url = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
        get_header = {'Authorization': f'Bearer {self.token}'}
        origin_result = self._request_with_retry(
            'GET', get_url, headers=get_header)
        return self._parse_output(origin_result, remote=True)

    def get_wordart_result(self):
        """Poll the task once per second until it reaches a terminal state.

        Returns:
            The parsed result once ``task_status`` is ``'SUCCEEDED'``.
            ``None`` if the task failed or any error occurred; the error is
            printed rather than raised (best-effort behaviour kept for
            existing callers).
        """
        # Terminal non-success states. FAILED was already handled; CANCELED
        # and UNKNOWN are added so polling cannot spin forever on them.
        # NOTE(review): CANCELED/UNKNOWN come from the DashScope async-task
        # documentation, not from this file - confirm.
        failed_statuses = ('FAILED', 'CANCELED', 'UNKNOWN')
        try:
            result = self.get_result()
            print(result)
            while True:
                if result is None:
                    raise ValueError(
                        'no task_id found in the submission result')
                result_data = result.get('result', {})
                output = result_data.get('output', {})
                task_status = output.get('task_status', '')
                if task_status == 'SUCCEEDED':
                    print('任务已完成')
                    return result
                if task_status in failed_statuses:
                    # Must raise a real exception: raising a bare string is
                    # itself a TypeError and hides the actual failure.
                    raise RuntimeError(
                        f'任务失败: task_status={task_status}')
                # Still pending/running: wait 1 second, then query again.
                time.sleep(1)
                result = self.get_result()
        except Exception as e:
            print('get Remote Error:', str(e))
            return None