Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import base64 | |
import io | |
import json | |
import logging | |
import os | |
import pathlib | |
import tempfile | |
import time | |
from datetime import datetime | |
import requests | |
import tiktoken | |
from PIL import Image | |
from modules.config import retrieve_proxy | |
from modules.models.XMChat import XMChat | |
# Midjourney proxy settings, read once at import time from the environment.
# Each value is None when the corresponding variable is not set.
mj_proxy_api_base = os.environ.get("MIDJOURNEY_PROXY_API_BASE")  # base URL of the proxy REST API
mj_discord_proxy_url = os.environ.get("MIDJOURNEY_DISCORD_PROXY_URL")  # replaces "https://cdn.discordapp.com/" in image URLs
mj_temp_folder = os.environ.get("MIDJOURNEY_TEMP_FOLDER")  # when set, finished grids are split and saved locally
class Midjourney_Client(XMChat):
    """Chat model that drives Midjourney through a midjourney-proxy REST API.

    Messages starting with ``/mj`` are parsed into an action, submitted to
    ``MIDJOURNEY_PROXY_API_BASE`` and polled until the task finishes. Any
    other message is answered with the help text from ``get_help``.
    """

    class FetchDataPack:
        """Polling state for one submitted Midjourney task."""

        action: str  # current action, e.g. "IMAGINE", "UPSCALE", "VARIATION"
        prefix_content: str  # prefix content: task description and progress hint
        task_id: str  # task id returned by the proxy
        start_time: float  # task start timestamp (time.time())
        timeout: int  # task timeout in seconds
        finished: bool  # whether polling should stop
        prompt: str  # prompt for the task; shown in the timeout message

        def __init__(self, action, prefix_content, task_id, timeout=900, prompt=""):
            self.action = action
            self.prefix_content = prefix_content
            self.task_id = task_id
            self.start_time = time.time()
            self.timeout = timeout
            self.finished = False
            # Always initialised so fetch_status can build its timeout message
            # even when a caller does not pass a prompt.
            self.prompt = prompt

    def __init__(self, model_name, api_key, user_name=""):
        """Create a client.

        Args:
            model_name: stored as ``self.model_name``.
            api_key: secret sent to the proxy in the ``mj-api-secret`` header.
            user_name: when non-empty, used as a sub-directory of the temp
                folder so each user's split images are kept apart.
        """
        super().__init__(api_key, user_name)
        self.model_name = model_name
        self.history = []
        self.api_key = api_key
        self.headers = {
            "Content-Type": "application/json",
            "mj-api-secret": f"{api_key}",
        }
        self.proxy_url = mj_proxy_api_base
        # Separator between command parts, e.g. "/mj UPSCALE::1::<task id>".
        self.command_splitter = "::"
        if mj_temp_folder:
            # Use the configured folder as the base directory (it used to be
            # ignored in favour of a hard-coded "./tmp").
            temp = mj_temp_folder
            if user_name:
                temp = os.path.join(temp, user_name)
            # exist_ok avoids the race between an existence check and creation.
            os.makedirs(temp, exist_ok=True)
            self.temp_path = tempfile.mkdtemp(dir=temp)
            logging.info("mj temp folder: " + self.temp_path)
        else:
            # No temp folder: finished grids are linked by URL instead of split.
            self.temp_path = None

    def use_mj_self_proxy_url(self, img_url):
        """Replace the Discord CDN prefix of ``img_url`` with MIDJOURNEY_DISCORD_PROXY_URL.

        Returns ``img_url`` unchanged when no proxy URL is configured or when
        ``img_url`` is empty/None.
        """
        if not img_url:
            return img_url
        return img_url.replace(
            "https://cdn.discordapp.com/",
            mj_discord_proxy_url or "https://cdn.discordapp.com/",
        )

    def split_image(self, image_url):
        """Download the grid image at ``image_url`` and cut it into four quadrants.

        Returns:
            A list of four PIL images: top-left, top-right, bottom-left,
            bottom-right.

        Raises:
            requests.HTTPError: if the download returns an error status.
        """
        with retrieve_proxy():
            response = requests.get(image_url)
        response.raise_for_status()
        img = Image.open(io.BytesIO(response.content))
        width, height = img.size
        half_width = width // 2
        half_height = height // 2
        # (left, upper, right, lower) boxes for Image.crop
        coordinates = [
            (0, 0, half_width, half_height),
            (half_width, 0, width, half_height),
            (0, half_height, half_width, height),
            (half_width, half_height, width, height),
        ]
        return [img.crop(box) for box in coordinates]

    def auth_mj(self):
        """Authenticate against the Midjourney proxy.

        Currently a stub that always reports success.
        """
        # TODO: check if secret is valid
        return {'status': 'ok'}

    def request_mj(self, path: str, action: str, data: str, retries=3):
        """Send one request to the Midjourney proxy, retrying on network errors.

        Args:
            path: path relative to MIDJOURNEY_PROXY_API_BASE.
            action: HTTP method, e.g. "GET" or "POST".
            data: request body (a JSON string, or '' for none).
            retries: number of attempts for connection-level failures (at least 1).

        Returns:
            The ``requests.Response`` for a 200 reply.

        Raises:
            Exception: if the proxy URL is not configured, auth fails, every
                attempt raised a network error, or the status is not 200.
        """
        mj_proxy_url = self.proxy_url
        if mj_proxy_url is None or not mj_proxy_url.startswith(("http://", "https://")):
            raise Exception('please set MIDJOURNEY_PROXY_API_BASE in ENV or in config.json')
        auth_ = self.auth_mj()
        if auth_.get('error'):
            raise Exception('auth not set')
        fetch_url = f"{mj_proxy_url}/{path}"
        attempts = max(1, retries)
        res = None
        last_error = None
        for _ in range(attempts):
            try:
                with retrieve_proxy():
                    res = requests.request(method=action, url=fetch_url, headers=self.headers, data=data)
                break
            except requests.RequestException as e:
                last_error = e
                logging.warning("[MJ Proxy] %s %s failed: %s", action, fetch_url, e)
        if res is None:
            # Every attempt failed; previously this surfaced as UnboundLocalError.
            raise Exception(
                f"[MJ Proxy] {action} {fetch_url} failed after {attempts} attempts: {last_error}"
            ) from last_error
        if res.status_code != 200:
            raise Exception(f'{res.status_code} - {res.content}')
        return res

    def fetch_status(self, fetch_data: FetchDataPack):
        """Poll the proxy once for the task described by ``fetch_data``.

        Sleeps 3 seconds, fetches ``task/<task_id>/fetch`` and returns a
        markdown message describing the task. ``fetch_data.finished`` is set
        to True when the task succeeded, failed, or exceeded
        ``fetch_data.timeout`` seconds.

        Raises:
            Exception: propagated from ``request_mj`` when the poll request fails.
        """
        if fetch_data.start_time + fetch_data.timeout < time.time():
            fetch_data.finished = True
            return "任务超时,请检查 dc 输出。描述:" + fetch_data.prompt
        time.sleep(3)
        status_res = self.request_mj(f"task/{fetch_data.task_id}/fetch", "GET", '')
        status_res_json = status_res.json()
        if not (200 <= status_res.status_code < 300):
            # Parenthesised so a missing 'error' falls back instead of raising TypeError.
            raise Exception("任务状态获取失败:" + (
                status_res_json.get('error') or status_res_json.get('description') or '未知错误'))

        status = status_res_json.get('status')
        if status in ("FAILED", "FAILURE"):
            # Failed tasks have no image; report the reason instead of
            # falling into the success formatting.
            fetch_data.finished = True
            return "任务处理失败,原因:" + (status_res_json.get('failReason') or '未知原因')
        if status == "SUCCESS":
            fetch_data.finished = True
            return self._format_success(fetch_data, status_res_json)

        fetch_data.finished = False
        if status == "NOT_START":
            content = f'任务未开始,已等待 {time.time() - fetch_data.start_time:.2f} 秒'
        elif status == "IN_PROGRESS":
            content = '任务正在运行'
            if status_res_json.get('progress'):
                content += f",进度:{status_res_json['progress']}"
        elif status == "SUBMITTED":
            content = '任务已提交处理'
        else:
            content = status
        content = f"**任务状态:** [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] - {content}"
        content += f"\n\n花费时间:{time.time() - fetch_data.start_time:.2f} 秒"
        progress_url = status_res_json.get('imageUrl')
        if status == 'IN_PROGRESS' and progress_url:
            # Route preview images through the same proxy as final images.
            progress_url = self.use_mj_self_proxy_url(progress_url)
            return f"{content}\n[![{fetch_data.task_id}]({progress_url})]({progress_url})"
        return content

    def _format_success(self, fetch_data, status_res_json):
        """Build the final markdown answer for a task whose status is SUCCESS."""
        if fetch_data.action == "DESCRIBE":
            return f"\n{status_res_json.get('prompt') or ''}"
        img_url = self.use_mj_self_proxy_url(status_res_json.get('imageUrl'))
        time_cost_str = f"\n\n{fetch_data.action} 花费时间:{time.time() - fetch_data.start_time:.2f} 秒"
        upscale_str = ""
        variation_str = ""
        if fetch_data.action in ("IMAGINE", "UPSCALE", "VARIATION"):
            upscale = [f'/mj UPSCALE{self.command_splitter}{i+1}{self.command_splitter}{fetch_data.task_id}'
                       for i in range(4)]
            upscale_str = '\n放大图片:\n\n' + '\n\n'.join(upscale)
            variation = [f'/mj VARIATION{self.command_splitter}{i+1}{self.command_splitter}{fetch_data.task_id}'
                         for i in range(4)]
            variation_str = '\n图片变体:\n\n' + '\n\n'.join(variation)
        if self.temp_path and fetch_data.action in ("IMAGINE", "VARIATION"):
            try:
                images = self.split_image(img_url)
                # save images to temp path
                for i, image in enumerate(images):
                    image.save(pathlib.Path(self.temp_path) / f"{fetch_data.task_id}_{i}.png")
                img_str = '\n'.join(
                    [f"![{fetch_data.task_id}](/file={self.temp_path}/{fetch_data.task_id}_{i}.png)"
                     for i in range(len(images))])
                return fetch_data.prefix_content + f"{time_cost_str}\n\n{img_str}{upscale_str}{variation_str}"
            except Exception as e:
                # Best effort: fall back to linking the remote grid image.
                logging.error("failed to split Midjourney image: %s", e)
        return fetch_data.prefix_content + \
            f"{time_cost_str}[![{fetch_data.task_id}]({img_url})]({img_url}){upscale_str}{variation_str}"

    def handle_file_upload(self, files, chatbot, language):
        """Read uploaded files as the input image and show it in the chat.

        ``try_read_image`` comes from XMChat; presumably it sets
        ``self.image_path``/``self.image_bytes`` (TODO confirm).

        Returns:
            ``(None, chatbot, None)`` with the image appended to ``chatbot``
            when ``self.image_path`` is set.
        """
        if files:
            for file in files:
                if file.name:
                    logging.info(f"尝试读取图像: {file.name}")
                    self.try_read_image(file.name)
            if self.image_path is not None:
                chatbot = chatbot + [((self.image_path,), None)]
            if self.image_bytes is not None:
                logging.info("使用图片作为输入")
        return None, chatbot, None

    def reset(self, remain_system_prompt=False):
        """Forget the uploaded image, then delegate to the parent reset."""
        self.image_bytes = None
        self.image_path = None
        # NOTE(review): remain_system_prompt is not forwarded; confirm whether
        # XMChat.reset accepts it before passing it through.
        return super().reset()

    def _split_command(self, prompt):
        """Split the text after ``/mj`` into ``(action, arguments)``.

        The action is the non-empty text before the first command splitter;
        without one (or with a leading splitter) the whole prompt is an
        implicit IMAGINE and is returned unchanged as the arguments.
        """
        first_split_index = prompt.find(self.command_splitter)
        if first_split_index > 0:
            return (prompt[:first_split_index],
                    prompt[first_split_index + len(self.command_splitter):])
        return "IMAGINE", prompt

    def _require_image(self):
        """Raise a readable error when an action needs an uploaded image but none exists."""
        if self.image_bytes is None:
            raise Exception("请先上传图片")

    def _submit_task(self, action, args):
        """Submit ``action`` with its ``args`` to the proxy and return the response.

        For UPSCALE/VARIATION/REROLL, ``args`` is ``"<index>::<task id>"``;
        a non-integer index raises ValueError.
        """
        if action == "IMAGINE":
            # Send only the text after an explicit "IMAGINE::" prefix.
            data = {"prompt": args}
            if self.image_bytes is not None:
                data["base64"] = 'data:image/png;base64,' + self.image_bytes
            return self.request_mj("submit/imagine", "POST", json.dumps(data))
        if action == "DESCRIBE":
            self._require_image()
            return self.request_mj("submit/describe", "POST", json.dumps(
                {"base64": 'data:image/png;base64,' + self.image_bytes}))
        if action == "BLEND":
            self._require_image()
            # NOTE(review): sends the same image twice and without the data-URI
            # prefix used above — confirm against the proxy's blend API.
            return self.request_mj("submit/blend", "POST", json.dumps(
                {"base64Array": [self.image_bytes, self.image_bytes]}))
        if action in ("UPSCALE", "VARIATION", "REROLL"):
            index_str, _, task_id = args.partition(self.command_splitter)
            return self.request_mj(
                "submit/change", "POST",
                json.dumps({"action": action, "index": int(index_str), "taskId": task_id.strip()}))
        raise Exception("任务提交失败:未知的任务类型")

    @staticmethod
    def _submission_error(res, res_json):
        """Return a user-facing error if the proxy rejected the submission, else None."""
        # Result codes 1 and 22 are treated as accepted (presumably "submitted"
        # and "queued" — confirm against the proxy docs).
        if 200 <= res.status_code < 300 and res_json.get('code') in (1, 22):
            return None
        reason = res_json.get('error') or res_json.get('description') or '未知错误'
        return "任务提交失败:" + str(reason)

    def get_answer_at_once(self):
        """Run the ``/mj`` command in the last history message to completion.

        Returns:
            ``(answer, token_count)``. Non-command messages return the help
            text and ``len(content)``.

        Raises:
            Exception: for an action not enabled in this non-streaming mode.
        """
        content = self.history[-1]['content']
        answer = self.get_help()
        if not content.lower().startswith("/mj"):
            return answer, len(content)
        prompt = content[3:].strip()
        action, args = self._split_command(prompt)
        if action not in ["IMAGINE", "DESCRIBE", "UPSCALE",
                          # "VARIATION", "BLEND", "REROLL"
                          ]:
            raise Exception("任务提交失败:未知的任务类型")
        try:
            res = self._submit_task(action, args)
            res_json = res.json()
            error = self._submission_error(res, res_json)
            if error:
                answer = error
            else:
                task_id = res_json['result']
                fetch_data = Midjourney_Client.FetchDataPack(
                    action=action,
                    prefix_content=f"**画面描述:** {prompt}\n**任务ID:** {task_id}\n",
                    task_id=task_id,
                    prompt=prompt,
                )
                while not fetch_data.finished:
                    answer = self.fetch_status(fetch_data)
        except Exception as e:
            logging.exception("submit failed")
            answer = "任务提交错误:" + (str(e.args[0]) if e.args else '未知错误')
        # Return a token count, not the token list itself.
        return answer, len(tiktoken.get_encoding("cl100k_base").encode(content))

    def get_answer_stream_iter(self):
        """Run the ``/mj`` command in the last history message, yielding progress.

        Yields the submission confirmation, then one status message per poll
        until the task finishes. Errors are yielded as messages, not raised.
        """
        content = self.history[-1]['content']
        if not content.lower().startswith("/mj"):
            yield self.get_help()
            return
        prompt = content[3:].strip()
        action, args = self._split_command(prompt)
        if action not in ["IMAGINE", "DESCRIBE", "UPSCALE",
                          "VARIATION", "BLEND", "REROLL"]:
            yield "任务提交失败:未知的任务类型"
            return
        try:
            res = self._submit_task(action, args)
            res_json = res.json()
            error = self._submission_error(res, res_json)
            if error:
                yield error
                return
            task_id = res_json['result']
            yield f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] - 任务提交成功:" + \
                (res_json.get('description') or '请稍等片刻')
            fetch_data = Midjourney_Client.FetchDataPack(
                action=action,
                prefix_content=f"**画面描述:** {prompt}\n**任务ID:** {task_id}\n",
                task_id=task_id,
                prompt=prompt,
            )
            while not fetch_data.finished:
                yield self.fetch_status(fetch_data)
        except Exception as e:
            logging.exception("submit failed")
            yield "任务提交错误:" + (str(e.args[0]) if e.args else '未知错误')

    def get_help(self):
        """Return the markdown help text for the ``/mj`` commands."""
        return """```
【绘图帮助】
所有命令都需要以 /mj 开头,如:/mj a dog
IMAGINE - 绘图,可以省略该命令,后面跟上绘图内容
/mj a dog
/mj IMAGINE::a cat
DESCRIBE - 描述图片,需要在右下角上传需要描述的图片内容
/mj DESCRIBE::
UPSCALE - 确认后放大图片,第一个数值为需要放大的图片(1~4),第二参数为任务ID
/mj UPSCALE::1::123456789
请使用SD进行UPSCALE
VARIATION - 图片变体,第一个数值为需要放大的图片(1~4),第二参数为任务ID
/mj VARIATION::1::123456789
【绘图参数】
所有命令默认会带上参数--v 5.2
其他参数参照 https://docs.midjourney.com/docs/parameter-list
长宽比 --aspect/--ar
--ar 1:2
--ar 16:9
负面tag --no
--no plants
--no hands
随机种子 --seed
--seed 1
生成动漫风格(NijiJourney) --niji
--niji
```
"""