jianuo's picture
first
09321b6
raw
history blame
No virus
6.34 kB
import os
import time
import json
import pandas as pd
import requests
from ..tools.tool import Tool, ToolSchema
from pydantic import ValidationError
from requests.exceptions import RequestException, Timeout
# Maximum number of attempts made for each remote HTTP request before the
# tool gives up; only timeouts consume an attempt and trigger a retry.
MAX_RETRY_TIMES = 3
class WordArtTexture(Tool):
    """Tool that generates WordArt texture images through the DashScope API.

    Calling the tool submits an asynchronous task (``X-DashScope-Async:
    enable``) to the ``wordart/texture`` endpoint and then polls the task
    endpoint once per second until the task reports ``SUCCEEDED`` or
    ``FAILED``.
    """

    description = '生成艺术字纹理图片'
    name = 'wordart_texture_generation'
    # Dotted parameter names are expanded into nested dictionaries of the
    # request body by `_remote_parse_input`.
    parameters: list = [{
        'name': 'input.text.text_content',
        'description': 'text that the user wants to convert to WordArt',
        'required': True
    }, {
        'name': 'input.prompt',
        'description':
        'Users’ style requirements for word art may be requirements in terms of shape, color, entity, etc.',
        'required': True
    }, {
        'name': 'input.texture_style',
        'description':
        'Type of texture style;Default is "material";If not provided by the user, '
        'defaults to "material".Another value is scene.',
        'required': True
    }, {
        'name': 'input.text.output_image_ratio',
        'description':
        'The aspect ratio of the text input image; the default is "1:1", '
        'the available ratios are: "1:1", "16:9", "9:16";',
        'required': True
    }]

    def __init__(self, cfg=None):
        """Read the tool configuration and build the tool schema.

        Args:
            cfg: Optional mapping keyed by tool name. The section stored
                under ``wordart_texture_generation`` may provide ``token``
                (API key; falls back to the ``DASHSCOPE_API_KEY``
                environment variable) and ``timeout`` (passed to
                ``requests``; ``None``, the default, means no timeout).

        Raises:
            AssertionError: If no API token is configured.
            ValueError: If the tool parameters fail schema validation.
        """
        # A fresh dict per call avoids sharing one mutable default object.
        cfg = {} if cfg is None else cfg
        self.cfg = cfg.get(self.name, {})
        # remote call
        self.url = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/wordart/texture'
        self.token = self.cfg.get('token',
                                  os.environ.get('DASHSCOPE_API_KEY', ''))
        # Kept as an assert so callers still see the same exception type.
        assert self.token != '', 'dashscope api token must be acquired with wordart'
        # Opt-in request timeout; without one, `requests` waits indefinitely
        # and the Timeout-based retry below can hardly ever trigger.
        self.timeout = self.cfg.get('timeout')

        all_param = {
            'name': self.name,
            'description': self.description,
            'parameters': self.parameters
        }
        try:
            self.tool_schema = ToolSchema(**all_param)
        except ValidationError as err:
            raise ValueError(
                f'Error when parsing parameters of {self.name}') from err

        self._str = self.tool_schema.model_dump_json()
        self._function = self.parse_pydantic_model_to_openai_function(
            all_param)

    def __call__(self, *args, **kwargs):
        """Submit a texture-generation task and wait for its outcome.

        Keyword arguments use the dotted names listed in ``parameters``.

        Returns:
            The parsed result of the finished task, or ``None`` when polling
            failed (see `get_wordart_result`).

        Raises:
            ValueError: If the submit request fails or every attempt times
                out.
        """
        payload = json.dumps(self._remote_parse_input(*args, **kwargs))
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.token}',
            'X-DashScope-Async': 'enable'
        }
        origin_result = self._request_json(
            'POST', self.url, headers, data=payload)
        # Parsing and polling happen outside the retry loop so that an error
        # raised there can never cause the task to be submitted twice.
        self.final_result = self._parse_output(origin_result, remote=True)
        return self.get_wordart_result()

    def _request_json(self, method, url, headers, data=None):
        """Send one HTTP request with timeout retries and decode the JSON body.

        Only `Timeout` errors are retried, up to ``MAX_RETRY_TIMES`` attempts
        in total; any other request error is converted into `ValueError`.
        """
        for _ in range(MAX_RETRY_TIMES):
            try:
                response = requests.request(
                    method,
                    url=url,
                    headers=headers,
                    data=data,
                    timeout=getattr(self, 'timeout', None))
                if response.status_code != requests.codes.ok:
                    response.raise_for_status()
            except Timeout:
                continue
            except RequestException as err:
                raise ValueError(self._format_request_error(err)) from err
            return json.loads(response.content.decode('utf-8'))
        raise ValueError(
            'Remote call max retry times exceeded! Please try to use local call.'
        )

    @staticmethod
    def _format_request_error(err):
        """Build the error message for a failed request.

        ``err.response`` is ``None`` for failures that never produced an HTTP
        response (for example connection errors), so it is checked first.
        """
        response = getattr(err, 'response', None)
        if response is None:
            return f'Remote call failed without a response: {err}'
        return (f'Remote call failed with error code: {response.status_code},'
                f'error message: {response.content.decode("utf-8")}')

    def _remote_parse_input(self, *args, **kwargs):
        """Convert dotted keyword arguments into the nested request body.

        For example ``{'input.text.text_content': 'hi'}`` becomes
        ``{'input': {'text': {'text_content': 'hi'}}}``. Keys without a dot
        are stored at the top level, and ``model`` is always set to
        ``'wordart-texture'``.
        """
        restored_dict = {}
        for key, value in kwargs.items():
            # Every part but the last names a nested dict; a key without "."
            # has no parents and lands directly in the top-level dict.
            *parents, leaf = key.split('.')
            node = restored_dict
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = value
        restored_dict['model'] = 'wordart-texture'
        print('传给tool的参数:', restored_dict)
        return restored_dict

    def get_result(self):
        """Fetch the current state of the task submitted by `__call__`.

        Returns:
            The task-query response after `_parse_output`, or ``None`` when
            the stored submit response contains no ``task_id``.

        Raises:
            ValueError: If the query request fails or every attempt times
                out.
        """
        # The JSON round trip yields a detached copy of the submit response.
        result_data = json.loads(json.dumps(self.final_result['result']))
        if 'task_id' not in result_data['output']:
            return None
        task_id = result_data['output']['task_id']
        get_url = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
        get_header = {'Authorization': f'Bearer {self.token}'}
        origin_result = self._request_json('GET', get_url, get_header)
        return self._parse_output(origin_result, remote=True)

    def get_wordart_result(self):
        """Poll the task once per second until it succeeds or fails.

        Returns:
            The parsed task result once its status is ``SUCCEEDED``. Any
            error, including a ``FAILED`` task, is printed and ``None`` is
            returned instead of raising.
        """
        # NOTE(review): polling is unbounded; any status other than
        # SUCCEEDED / FAILED keeps the loop running.
        try:
            result = self.get_result()
            print(result)
            while True:
                if result is None:
                    # get_result() returns None when no task_id is available.
                    raise ValueError(
                        'Remote call returned no task result to poll.')
                result_data = result.get('result', {})
                output = result_data.get('output', {})
                task_status = output.get('task_status', '')
                if task_status == 'SUCCEEDED':
                    print('任务已完成')
                    return result
                if task_status == 'FAILED':
                    # Raise a real exception; raising a bare string would
                    # itself fail with TypeError and hide this message.
                    raise RuntimeError('任务失败')
                # Task not finished yet: wait one second, then query again.
                time.sleep(1)
                result = self.get_result()
        except Exception as e:
            # Deliberate best-effort boundary: report the error and fall
            # through to return None.
            print('get Remote Error:', str(e))
        return None