MemeMaster / src /generation.py
承弱
add tracked files
2d9bfd2
raw history blame
No virus
3.37 kB
import json
import os
import time
import gradio as gr
import requests
from src.log import logger
from src.util import download_images
# DashScope endpoints used by call_generation.
_CREATE_TASK_URL = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/anytext/generation'
_QUERY_TASK_URL = 'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
# Per-HTTP-request timeout in seconds; without one a stalled connection blocks the worker forever.
_HTTP_TIMEOUT = 30
# Task states after which the task will never produce a result. Previously only FAILED stopped
# polling, so e.g. UNKNOWN (reported for expired / unknown task ids) made the loop spin forever.
_FAILED_TASK_STATUSES = frozenset({"FAILED", "CANCELED", "UNKNOWN"})


def _create_task(data, headers):
    """Submit one async generation task and return its task id.

    Raises:
        gr.Error: if the request cannot be sent or the service does not answer HTTP 200.
    """
    try:
        response = requests.post(_CREATE_TASK_URL, data=json.dumps(data), headers=headers, timeout=_HTTP_TIMEOUT)
    except requests.RequestException as err:
        logger.error(f'Fail to create Generation task: {err}')
        raise gr.Error("Fail to create Generation task.") from err
    if response.status_code != 200:
        logger.error(f'Fail to create Generation task: {response.content}')
        raise gr.Error("Fail to create Generation task.")
    task_id = response.json()['output']['task_id']
    logger.info(f"task_id: {task_id}: Create Generation task request success. Params: {data}")
    return task_id


def _wait_for_result_urls(task_id, headers, max_wait, poll_interval):
    """Poll an async task until it succeeds and return its ``output.result_url`` value.

    Raises:
        gr.Error: on a network/HTTP error, a terminal failure status, or when the task has not
            succeeded within ``max_wait`` seconds.
    """
    url_query = _QUERY_TASK_URL.format(task_id=task_id)
    deadline = time.monotonic() + max_wait
    # Asynchronous query: keep asking for the task status until it reaches a final state.
    while True:
        try:
            # The task-query endpoint is documented as GET (the original used POST).
            response = requests.get(url_query, headers=headers, timeout=_HTTP_TIMEOUT)
        except requests.RequestException as err:
            logger.error(f'task_id: {task_id}: Fail to query task result: {err}')
            raise gr.Error("Fail to query task result.") from err
        if response.status_code != 200:
            logger.error(f'task_id: {task_id}: Fail to query task result: {response.content}')
            raise gr.Error("Fail to query task result.")

        res = response.json()
        status = res['output']['task_status']
        if status == "SUCCEEDED":
            logger.info(f"task_id: {task_id}: Generation task query success.")
            logger.info(f"task_id: {task_id}: {res}")
            return res['output']['result_url']
        if status in _FAILED_TASK_STATUSES:
            # Log the full response so the service's failure reason is not lost.
            logger.error(f"task_id: {task_id}: Generation task ended with status {status}: {res}")
            raise gr.Error('Fail to get results from Generation task.')
        if time.monotonic() >= deadline:
            logger.error(f"task_id: {task_id}: Generation task still {status} after {max_wait}s, giving up.")
            raise gr.Error('Timed out waiting for Generation task results.')
        logger.debug(f"task_id: {task_id}: query result...")
        time.sleep(poll_interval)


def call_generation(prompt, mask_image_url, lora_path_ratio="0 1.0", image_width=512, image_height=512,
                    BATCH_SIZE=1, max_wait=600.0, poll_interval=1.0):
    """Generate images with the DashScope ``jinshu-emoji`` async generation service.

    Submits a generation task, polls it until it finishes, then downloads the resulting images.

    Args:
        prompt: Sent as ``input.prompt``.
        mask_image_url: Sent as ``input.mask_image_url``.
        lora_path_ratio: Sent verbatim as ``input.lora_path_ratio``; its format is defined by
            the service (NOTE(review): confirm the expected "<index> <ratio>" syntax).
        image_width: Sent as ``parameters.image_width`` (presumably pixels).
        image_height: Sent as ``parameters.image_height`` (presumably pixels).
        BATCH_SIZE: Images requested per task (``parameters.n``); also the number expected back.
        max_wait: Maximum seconds to keep polling a task before giving up.
        poll_interval: Seconds to sleep between status queries.

    Returns:
        The concatenation of what ``download_images`` returns for each task.

    Raises:
        gr.Error: if the API key is missing, a request fails, the task fails or times out, or
            the number of downloaded images differs from what was requested.
    """
    api_key = os.getenv("API_KEY_GENERATION")
    if not api_key:
        # Without this check the request goes out as "Bearer None" and fails with an opaque error.
        logger.error("API_KEY_GENERATION environment variable is not set.")
        raise gr.Error("API_KEY_GENERATION is not set.")
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": f"Bearer {api_key}",
        "X-DashScope-Async": "enable",
    }
    data = {
        "model": "jinshu-emoji",
        "input": {
            "prompt": prompt,
            "mask_image_url": mask_image_url,
            "lora_path_ratio": lora_path_ratio,
            "base_model_path": 0,
        },
        "parameters": {
            "n": BATCH_SIZE,
            "image_width": image_width,
            "image_height": image_height,
            "text_position_revise": True,
        },
    }

    repeat = 1  # number of independent tasks submitted; each is expected to yield BATCH_SIZE images
    task_ids = [_create_task(data, headers) for _ in range(repeat)]

    all_image_data = []
    for task_id in task_ids:
        img_urls = _wait_for_result_urls(task_id, headers, max_wait, poll_interval)
        logger.info(f"task_id: {task_id}: download generated images.")
        img_data = download_images(img_urls, BATCH_SIZE)
        logger.info(f"task_id: {task_id}: Generate done.")
        all_image_data += img_data

    if len(all_image_data) != repeat * BATCH_SIZE:
        raise gr.Error("Fail to Generation.")
    return all_image_data
if __name__ == "__main__":
    # call_generation() requires a prompt and a mask image URL, so invoking it with no arguments
    # always raised TypeError; take them (and the optional settings) from the command line instead.
    import argparse

    parser = argparse.ArgumentParser(description="Submit one generation task and download its images.")
    parser.add_argument("prompt", help="text sent as input.prompt")
    parser.add_argument("mask_image_url", help="URL sent as input.mask_image_url")
    parser.add_argument("--lora-path-ratio", default="0 1.0", help="value sent as input.lora_path_ratio")
    parser.add_argument("--width", type=int, default=512, help="parameters.image_width")
    parser.add_argument("--height", type=int, default=512, help="parameters.image_height")
    parser.add_argument("--batch-size", type=int, default=1, help="parameters.n")
    args = parser.parse_args()

    images = call_generation(
        args.prompt,
        args.mask_image_url,
        lora_path_ratio=args.lora_path_ratio,
        image_width=args.width,
        image_height=args.height,
        BATCH_SIZE=args.batch_size,
    )
    logger.info(f"Generated {len(images)} image(s).")