File size: 16,448 Bytes
0bae6cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 |
import base64
import io
import json
import logging
import os
import pathlib
import tempfile
import time
from datetime import datetime
import requests
import tiktoken
from PIL import Image
from modules.config import retrieve_proxy
from modules.models.XMChat import XMChat
# Midjourney settings, read from the environment once at import time.
# Each name is None when the corresponding variable is not set.
mj_proxy_api_base = os.environ.get("MIDJOURNEY_PROXY_API_BASE")  # base URL of the proxy API
mj_discord_proxy_url = os.environ.get("MIDJOURNEY_DISCORD_PROXY_URL")  # replaces the Discord CDN prefix
mj_temp_folder = os.environ.get("MIDJOURNEY_TEMP_FOLDER")  # truthy value enables local tile splitting
class Midjourney_Client(XMChat):
    """Chat model that runs Midjourney tasks through a midjourney-proxy HTTP API.

    User messages starting with ``/mj`` are parsed as drawing commands (see
    ``get_help``), submitted to the proxy at ``self.proxy_url`` and then polled
    with ``fetch_status`` until the task finishes, fails or times out. Any other
    message is answered with the help text.
    """

    # Actions accepted after "/mj ACTION::". Shared by the streaming and the
    # non-streaming entry points so both accept exactly the same commands.
    SUPPORTED_ACTIONS = ("IMAGINE", "DESCRIBE", "UPSCALE", "VARIATION", "BLEND", "REROLL")
    # Actions written as "ACTION::<index>::<task id>" and sent to submit/change.
    CHANGE_ACTIONS = ("UPSCALE", "VARIATION", "REROLL")

    class FetchDataPack:
        """Polling state of one submitted Midjourney task."""

        action: str  # current action, e.g. "IMAGINE", "UPSCALE", "VARIATION"
        prefix_content: str  # prefix content: task description and task id
        task_id: str  # task id returned by the proxy on submission
        start_time: float  # task start timestamp (time.time())
        timeout: int  # task timeout in seconds
        finished: bool  # whether polling should stop
        prompt: str  # prompt for the task, echoed in the timeout message

        def __init__(self, action, prefix_content, task_id, timeout=900, prompt=""):
            self.action = action
            self.prefix_content = prefix_content
            self.task_id = task_id
            self.start_time = time.time()
            self.timeout = timeout
            self.finished = False
            # Always set, so the timeout message in fetch_status never hits an
            # AttributeError when the caller does not assign a prompt.
            self.prompt = prompt

    def __init__(self, model_name, api_key, user_name=""):
        super().__init__(api_key, user_name)
        self.model_name = model_name
        self.history = []
        self.api_key = api_key
        # The API secret is sent as the "mj-api-secret" header on every proxy request.
        self.headers = {
            "Content-Type": "application/json",
            "mj-api-secret": f"{api_key}",
        }
        self.proxy_url = mj_proxy_api_base
        self.command_splitter = "::"
        if mj_temp_folder:
            # NOTE(review): MIDJOURNEY_TEMP_FOLDER only switches local tile
            # splitting on; files always go below "./tmp" (per user when a user
            # name is known). Confirm whether the env value itself should be used.
            temp = "./tmp"
            if user_name:
                temp = os.path.join(temp, user_name)
            os.makedirs(temp, exist_ok=True)
            self.temp_path = tempfile.mkdtemp(dir=temp)
            logging.info("mj temp folder: " + self.temp_path)
        else:
            self.temp_path = None

    def use_mj_self_proxy_url(self, img_url):
        """Replace the Discord CDN prefix in ``img_url`` with MIDJOURNEY_DISCORD_PROXY_URL.

        When MIDJOURNEY_DISCORD_PROXY_URL is unset or empty, the URL is returned unchanged.
        """
        return img_url.replace(
            "https://cdn.discordapp.com/",
            mj_discord_proxy_url or "https://cdn.discordapp.com/",
        )

    def split_image(self, image_url, timeout=60):
        """Download the grid image at ``image_url`` and cut it into four quadrants.

        Args:
            image_url: URL of the image to download.
            timeout: seconds to wait for the download before giving up.

        Returns:
            Four PIL images ordered top-left, top-right, bottom-left, bottom-right.

        Raises:
            requests.RequestException: if the download fails or returns an HTTP error.
        """
        with retrieve_proxy():
            response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        img = Image.open(io.BytesIO(response.content))
        width, height = img.size
        half_width = width // 2
        half_height = height // 2
        # Crop boxes: (left, upper, right, lower).
        coordinates = [(0, 0, half_width, half_height),
                       (half_width, 0, width, half_height),
                       (0, half_height, half_width, height),
                       (half_width, half_height, width, height)]
        return [img.crop(c) for c in coordinates]

    def auth_mj(self):
        """Return the auth state for the proxy API.

        Currently always ``{'status': 'ok'}``; the secret is not validated.
        """
        # TODO: check if secret is valid
        return {'status': 'ok'}

    def request_mj(self, path: str, action: str, data: str, retries=3, timeout=60):
        """Send one HTTP request to the midjourney-proxy API.

        Args:
            path: API path relative to the proxy base URL, e.g. "submit/imagine".
            action: HTTP method, e.g. "GET" or "POST".
            data: request body (a JSON string, or '' for none).
            retries: attempts made on network errors; at least one attempt is made.
            timeout: per-attempt timeout in seconds, so a stuck connection cannot
                block the caller forever.

        Returns:
            The ``requests.Response``; its status code is always 200.

        Raises:
            Exception: if the proxy base URL is not configured, auth fails, every
                attempt fails with a network error, or the status code is not 200.
        """
        mj_proxy_url = self.proxy_url
        if mj_proxy_url is None or not mj_proxy_url.startswith(("http://", "https://")):
            raise Exception('please set MIDJOURNEY_PROXY_API_BASE in ENV or in config.json')
        auth_ = self.auth_mj()
        if auth_.get('error'):
            raise Exception('auth not set')
        fetch_url = f"{mj_proxy_url}/{path}"
        attempts = max(1, retries)
        res = None
        last_error = None
        for _ in range(attempts):
            try:
                with retrieve_proxy():
                    res = requests.request(method=action, url=fetch_url, headers=self.headers,
                                           data=data, timeout=timeout)
                break
            except requests.RequestException as e:
                last_error = e
                logging.warning(f"[MJ Proxy] {action} {fetch_url} failed: {e}")
        if res is None:
            # Every attempt failed; without this check `res` would be unbound.
            raise Exception(f'{action} {fetch_url} failed after {attempts} attempts: {last_error}')
        if res.status_code != 200:
            raise Exception(f'{res.status_code} - {res.content}')
        return res

    def fetch_status(self, fetch_data: FetchDataPack):
        """Poll the proxy once for the task in ``fetch_data`` and render its status.

        Sleeps 3 seconds before each request. Sets ``fetch_data.finished`` once
        the task succeeded, failed, or exceeded ``fetch_data.timeout``.

        Returns:
            Markdown text: a progress message while the task runs, the final
            images and follow-up commands on success, or an error description.

        Raises:
            Exception: propagated from ``request_mj`` or on a non-2xx status.
        """
        if fetch_data.start_time + fetch_data.timeout < time.time():
            fetch_data.finished = True
            return "任务超时,请检查 dc 输出。描述:" + fetch_data.prompt
        time.sleep(3)
        status_res = self.request_mj(f"task/{fetch_data.task_id}/fetch", "GET", '')
        status_res_json = status_res.json()
        if not (200 <= status_res.status_code < 300):
            raise Exception("任务状态获取失败:" + (
                status_res_json.get('error') or status_res_json.get('description') or '未知错误'))
        status = status_res_json.get('status')
        if status in ("FAILED", "FAILURE"):
            # A failed task has no image, so it must not reach the success renderer.
            fetch_data.finished = True
            return "任务处理失败,原因:" + (status_res_json.get('failReason') or '未知原因')
        if status == "SUCCESS":
            fetch_data.finished = True
            return self._render_finished_task(fetch_data, status_res_json)
        fetch_data.finished = False
        return self._render_progress(fetch_data, status_res_json)

    def _render_finished_task(self, fetch_data, status_res_json):
        """Build the markdown shown for a successfully finished task."""
        if fetch_data.action == "DESCRIBE":
            return f"\n{status_res_json['prompt']}"
        img_url = self.use_mj_self_proxy_url(status_res_json.get('imageUrl') or '')
        time_cost_str = f"\n\n{fetch_data.action} 花费时间:{time.time() - fetch_data.start_time:.2f} 秒"
        upscale_str = ""
        variation_str = ""
        if fetch_data.action in ["IMAGINE", "UPSCALE", "VARIATION"]:
            upscale = [f'/mj UPSCALE{self.command_splitter}{i+1}{self.command_splitter}{fetch_data.task_id}'
                       for i in range(4)]
            upscale_str = '\n放大图片:\n\n' + '\n\n'.join(upscale)
            variation = [f'/mj VARIATION{self.command_splitter}{i+1}{self.command_splitter}{fetch_data.task_id}'
                         for i in range(4)]
            variation_str = '\n图片变体:\n\n' + '\n\n'.join(variation)
        if self.temp_path and fetch_data.action in ["IMAGINE", "VARIATION"]:
            try:
                images = self.split_image(img_url)
                # Save the four tiles so they can be served via /file=...
                for i, image in enumerate(images):
                    image.save(pathlib.Path(self.temp_path) / f"{fetch_data.task_id}_{i}.png")
                img_str = '\n'.join(
                    f"![{fetch_data.task_id}](/file={self.temp_path}/{fetch_data.task_id}_{i}.png)"
                    for i in range(len(images)))
                return fetch_data.prefix_content + f"{time_cost_str}\n\n{img_str}{upscale_str}{variation_str}"
            except Exception as e:
                # Best effort: fall back to linking the whole grid image below.
                logging.error(f"failed to split mj image: {e}")
        return fetch_data.prefix_content + \
            f"{time_cost_str}[![{fetch_data.task_id}]({img_url})]({img_url}){upscale_str}{variation_str}"

    def _render_progress(self, fetch_data, status_res_json):
        """Build the markdown shown while a task is still pending or running."""
        status = status_res_json.get('status')
        elapsed = time.time() - fetch_data.start_time
        if status == "NOT_START":
            content = f'任务未开始,已等待 {elapsed:.2f} 秒'
        elif status == "IN_PROGRESS":
            content = '任务正在运行'
            if status_res_json.get('progress'):
                content += f",进度:{status_res_json['progress']}"
        elif status == "SUBMITTED":
            content = '任务已提交处理'
        else:
            content = status
        content = f"**任务状态:** [{(datetime.now()).strftime('%Y-%m-%d %H:%M:%S')}] - {content}"
        content += f"\n\n花费时间:{elapsed:.2f} 秒"
        if status == 'IN_PROGRESS' and status_res_json.get('imageUrl'):
            # Route the preview through the same proxy as the final image.
            img_url = self.use_mj_self_proxy_url(status_res_json.get('imageUrl'))
            return f"{content}\n[![{fetch_data.task_id}]({img_url})]({img_url})"
        return content

    def handle_file_upload(self, files, chatbot, language):
        """Load uploaded images as input for the next /mj command.

        Each named file is passed to the inherited ``try_read_image``; the image
        is then appended to the chatbot history. ``language`` is not used here.

        Returns:
            (None, updated chatbot, None)
        """
        if files:
            for file in files:
                if file.name:
                    logging.info(f"尝试读取图像: {file.name}")
                    self.try_read_image(file.name)
            if self.image_path is not None:
                chatbot = chatbot + [((self.image_path,), None)]
            if self.image_bytes is not None:
                logging.info("使用图片作为输入")
        return None, chatbot, None

    def reset(self, remain_system_prompt=False):
        """Drop the uploaded image and reset the parent model state.

        NOTE(review): ``remain_system_prompt`` is accepted but not forwarded to
        ``super().reset()``; confirm the parent's signature before passing it on.
        """
        self.image_bytes = None
        self.image_path = None
        return super().reset()

    def _split_command(self, content):
        """Split a "/mj ..." message into (action, prompt).

        The text before the first command splitter is the action; without a
        splitter the action defaults to IMAGINE. An explicit "IMAGINE::" prefix
        is removed from the prompt so it is not sent to Midjourney verbatim.
        """
        prompt = content[3:].strip()
        first_split_index = prompt.find(self.command_splitter)
        if first_split_index <= 0:
            return "IMAGINE", prompt
        action = prompt[:first_split_index]
        if action == "IMAGINE":
            prompt = prompt[first_split_index + len(self.command_splitter):].strip()
        return action, prompt

    def _parse_change_args(self, action, prompt):
        """Parse "ACTION::<index>::<task id>" into (int index, task id).

        Raises:
            ValueError: when the index is not an integer or the task id is missing.
        """
        usage = (f"命令格式错误,示例:/mj {action}{self.command_splitter}1"
                 f"{self.command_splitter}123456789")
        parts = prompt.split(self.command_splitter, 2)
        if len(parts) < 3 or not parts[2].strip():
            raise ValueError(usage)
        try:
            action_index = int(parts[1])
        except ValueError:
            raise ValueError(usage) from None
        return action_index, parts[2].strip()

    def _require_image(self):
        """Raise ValueError when no image has been uploaded."""
        if self.image_bytes is None:
            raise ValueError("请先上传图片")

    def _submit_task(self, action, prompt):
        """Submit one task to the proxy.

        Returns:
            (res_json, error): the decoded response, and ``None`` on success or a
            user-facing "任务提交失败:..." message when the proxy rejected the task.

        Raises:
            ValueError: when a required image is missing or change-action
                arguments are malformed; anything raised by ``request_mj``.
        """
        if action == "IMAGINE":
            data = {"prompt": prompt}
            if self.image_bytes is not None:
                data["base64"] = 'data:image/png;base64,' + self.image_bytes
            res = self.request_mj("submit/imagine", "POST", json.dumps(data))
        elif action == "DESCRIBE":
            self._require_image()
            res = self.request_mj("submit/describe", "POST", json.dumps(
                {"base64": 'data:image/png;base64,' + self.image_bytes}))
        elif action == "BLEND":
            self._require_image()
            # NOTE(review): blends the uploaded image with itself and sends raw
            # base64 without the data-URL prefix used by IMAGINE/DESCRIBE;
            # confirm against the proxy's blend API.
            res = self.request_mj("submit/blend", "POST", json.dumps(
                {"base64Array": [self.image_bytes, self.image_bytes]}))
        elif action in self.CHANGE_ACTIONS:
            action_index, action_use_task_id = self._parse_change_args(action, prompt)
            res = self.request_mj(
                "submit/change", "POST",
                json.dumps({"action": action, "index": action_index, "taskId": action_use_task_id}))
        else:
            raise ValueError(f"unsupported action: {action}")
        res_json = res.json()
        # Response codes 1 and 22 are treated as accepted submissions.
        if not (200 <= res.status_code < 300) or res_json.get('code') not in (1, 22):
            reason = res_json.get('error') or res_json.get('description') or '未知错误'
            return res_json, "任务提交失败:" + str(reason)
        return res_json, None

    def _run_command(self, content):
        """Yield messages for one "/mj" command; the last one is the final answer.

        Yields the submission result, then one status message per poll until the
        task is finished. Errors are yielded as messages instead of raised.
        """
        action, prompt = self._split_command(content)
        if action not in self.SUPPORTED_ACTIONS:
            yield "任务提交失败:未知的任务类型"
            return
        try:
            res_json, error = self._submit_task(action, prompt)
            if error is not None:
                yield error
                return
            task_id = res_json['result']
            prefix_content = f"**画面描述:** {prompt}\n**任务ID:** {task_id}\n"
            yield f"[{(datetime.now()).strftime('%Y-%m-%d %H:%M:%S')}] - 任务提交成功:" + \
                (res_json.get('description') or '请稍等片刻')
            fetch_data = Midjourney_Client.FetchDataPack(
                action=action,
                prefix_content=prefix_content,
                task_id=task_id,
                prompt=prompt,
            )
            while not fetch_data.finished:
                yield self.fetch_status(fetch_data)
        except Exception as e:
            logging.error("submit failed: %s", e)
            yield "任务提交错误:" + (str(e.args[0]) if e.args else '未知错误')

    def get_answer_at_once(self):
        """Run the last user message to completion.

        Returns:
            (answer, count): the help text and the character count of the message
            for non-"/mj" input; otherwise the final task message and the
            cl100k_base token count of the message.
        """
        content = self.history[-1]['content']
        answer = self.get_help()
        if not content.lower().startswith("/mj"):
            return answer, len(content)
        # Drain the generator; only the last message is the answer.
        for answer in self._run_command(content):
            pass
        return answer, len(tiktoken.get_encoding("cl100k_base").encode(content))

    def get_answer_stream_iter(self):
        """Yield progress messages for the last user message, ending with the result.

        Messages that do not start with "/mj" yield the help text once.
        """
        content = self.history[-1]['content']
        if not content.lower().startswith("/mj"):
            yield self.get_help()
            return
        yield from self._run_command(content)

    def get_help(self):
        """Return the markdown help text describing the /mj commands and parameters."""
        return """```
【绘图帮助】
所有命令都需要以 /mj 开头,如:/mj a dog
IMAGINE - 绘图,可以省略该命令,后面跟上绘图内容
/mj a dog
/mj IMAGINE::a cat
DESCRIBE - 描述图片,需要在右下角上传需要描述的图片内容
/mj DESCRIBE::
UPSCALE - 确认后放大图片,第一个数值为需要放大的图片(1~4),第二参数为任务ID
/mj UPSCALE::1::123456789
请使用SD进行UPSCALE
VARIATION - 图片变体,第一个数值为需要放大的图片(1~4),第二参数为任务ID
/mj VARIATION::1::123456789
【绘图参数】
所有命令默认会带上参数--v 5.2
其他参数参照 https://docs.midjourney.com/docs/parameter-list
长宽比 --aspect/--ar
--ar 1:2
--ar 16:9
负面tag --no
--no plants
--no hands
随机种子 --seed
--seed 1
生成动漫风格(NijiJourney) --niji
--niji
```
"""
|