# ImageSynthesizer/src/generation.py
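"""Poster-imitation image generation via the DashScope asynchronous task API."""
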
import json
import os
import time

import gradio as gr
import requests

from src.log import logger
from src.util import download_images


def call_generation(base_image_url="", layout_image_url="", color_image_url="", style_image_url="",
                    strict_edge=0, layout_scale=8, edge_consistency=8, color_scale=8,
                    style_scale=8, prompt="", negative_prompt="", output_aspect_ratio=1.0):
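    """Create DashScope poster-imitation generation tasks and return the images.

    Submits one or more asynchronous generation requests, polls each task until it
    succeeds or fails, downloads the resulting images, and returns them as a single
    list. Raises gr.Error if any request or task fails.
    """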
API_KEY = os.getenv("API_KEY_BG_GENERATION")
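    # For aspect ratio >= 1, request 2 images twice; otherwise request 4 images once.
    # Either way, 4 images are produced in total.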
if output_aspect_ratio >= 1:
BATCH_SIZE = 2
REPEAT = 2
else:
BATCH_SIZE = 4
REPEAT = 1
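    # "X-DashScope-Async: enable" requests asynchronous processing; the response
    # carries a task_id that is polled below instead of the finished result.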
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {API_KEY}",
"X-DashScope-Async": "enable",
}
data = {
"model": "wanx-poster-imitation-v1",
"input": {
"base_image_url": base_image_url,
"layout_image_url": layout_image_url,
"color_image_url": color_image_url,
"style_image_url": style_image_url,
"prompt": prompt,
"negative_prompt": negative_prompt,
},
"parameters": {
"strict_layout": strict_edge,
"layout_scale": layout_scale,
"layout_spatial_consistency": edge_consistency,
"color_scale": color_scale,
"style_scale": style_scale,
"n": BATCH_SIZE
}
}
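    # Submit the creation request REPEAT times; each call starts an independent task.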
url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/poster-imitation/generation'
all_res_ = []
for _ in range(REPEAT):
res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
all_res_.append(res_)
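    # For every submitted task: poll until it completes, then download its images.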
all_image_data = []
for res_ in all_res_:
        response_code = res_.status_code
        if response_code == 200:
res = json.loads(res_.content.decode())
request_id = res['request_id']
task_id = res['output']['task_id']
logger.info(f"task_id: {task_id}: Create Poster Imitation request success. Params: {data}")
            # Poll the asynchronous task until it finishes.
            while True:
url_query = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
res_ = requests.post(url_query, headers=headers)
                response_code = res_.status_code
                if response_code == 200:
                    res = json.loads(res_.content.decode())
                    if res['output']['task_status'] == "SUCCEEDED":
logger.info(f"task_id: {task_id}: Generation task query success.")
results = res['output']['results']
img_urls = [x['url'] for x in results]
logger.info(f"task_id: {task_id}: {res}")
break
elif "FAILED" != res['output']['task_status']:
logger.debug(f"task_id: {task_id}: query result...")
time.sleep(1)
else:
raise gr.Error('Fail to get results from Generation task.')
else:
                    logger.error(f'task_id: {task_id}: Failed to query task result: {res_.content}')
                    raise gr.Error("Failed to query task result.")
logger.info(f"task_id: {task_id}: download generated images.")
img_data = download_images(img_urls, BATCH_SIZE)
logger.info(f"task_id: {task_id}: Generate done.")
all_image_data += img_data
else:
            logger.error(f'Failed to create generation task: {res_.content}')
            raise gr.Error("Failed to create generation task.")
    if len(all_image_data) != REPEAT * BATCH_SIZE:
        raise gr.Error("Generation failed: not all images were returned.")
return all_image_data
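

# Minimal smoke test; requires the API_KEY_BG_GENERATION environment variable and,
# in practice, valid input image URLs.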
if __name__ == "__main__":
call_generation()