File size: 16,448 Bytes
0bae6cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
import base64
import io
import json
import logging
import os
import pathlib
import tempfile
import time
from datetime import datetime

import requests
import tiktoken
from PIL import Image

from modules.config import retrieve_proxy
from modules.models.XMChat import XMChat

# Environment-driven configuration; each value is None when the variable is unset.
mj_proxy_api_base = os.environ.get("MIDJOURNEY_PROXY_API_BASE")  # base URL of the Midjourney proxy API
mj_discord_proxy_url = os.environ.get("MIDJOURNEY_DISCORD_PROXY_URL")  # replaces the Discord CDN host in image URLs
mj_temp_folder = os.environ.get("MIDJOURNEY_TEMP_FOLDER")  # any non-empty value enables local image splitting


class Midjourney_Client(XMChat):
    """
    Chat model that drives Midjourney through a Midjourney proxy HTTP API.

    Chat messages starting with ``/mj`` are parsed into an action (see
    ``SUPPORTED_ACTIONS``), submitted to ``MIDJOURNEY_PROXY_API_BASE`` and then
    polled until the task finishes. Any other message is answered with the help text.
    """

    # Actions accepted after "/mj"; IMAGINE is used when the message has no "ACTION::" prefix.
    SUPPORTED_ACTIONS = ("IMAGINE", "DESCRIBE", "UPSCALE", "VARIATION", "BLEND", "REROLL")
    # Actions sent to "submit/change"; their command form is "ACTION::index::taskId".
    CHANGE_ACTIONS = ("UPSCALE", "VARIATION", "REROLL")

    class FetchDataPack:
        """
        Mutable polling state for one submitted Midjourney task.
        """

        action: str  # current action, e.g. "IMAGINE", "UPSCALE", "VARIATION"
        prefix_content: str  # prefix content, task description and process hint
        task_id: str  # task id
        start_time: float  # task start timestamp
        timeout: int  # task timeout in seconds
        finished: bool  # whether the task is finished
        prompt: str  # prompt for the task

        def __init__(self, action, prefix_content, task_id, timeout=900, prompt=""):
            self.action = action
            self.prefix_content = prefix_content
            self.task_id = task_id
            self.start_time = time.time()
            self.timeout = timeout
            self.finished = False
            # Always initialised: fetch_status() reads it when reporting a timeout.
            self.prompt = prompt

    def __init__(self, model_name, api_key, user_name=""):
        super().__init__(api_key, user_name)
        self.model_name = model_name
        self.history = []
        self.api_key = api_key
        self.headers = {
            "Content-Type": "application/json",
            "mj-api-secret": f"{api_key}"
        }
        self.proxy_url = mj_proxy_api_base
        self.command_splitter = "::"

        if mj_temp_folder:
            # NOTE(review): MIDJOURNEY_TEMP_FOLDER only switches this feature on; images
            # are always written under ./tmp[/<user_name>]. Confirm whether the configured
            # folder was meant to be used instead.
            temp = "./tmp"
            if user_name:
                temp = os.path.join(temp, user_name)
            os.makedirs(temp, exist_ok=True)
            self.temp_path = tempfile.mkdtemp(dir=temp)
            logging.info("mj temp folder: " + self.temp_path)
        else:
            self.temp_path = None

    def use_mj_self_proxy_url(self, img_url):
        """
        Replace the Discord CDN host in ``img_url`` with MIDJOURNEY_DISCORD_PROXY_URL.

        The URL is returned unchanged when no proxy URL is configured.
        """
        return img_url.replace(
            "https://cdn.discordapp.com/",
            mj_discord_proxy_url or "https://cdn.discordapp.com/"
        )

    def split_image(self, image_url):
        """
        Download the image at ``image_url`` and cut it into four equal quadrants.

        :return: list of four PIL images ordered top-left, top-right, bottom-left,
            bottom-right.
        :raises requests.HTTPError: if the download returns an error status.
        """
        with retrieve_proxy():
            response = requests.get(image_url)
        response.raise_for_status()
        img = Image.open(io.BytesIO(response.content))
        width, height = img.size
        half_width = width // 2
        half_height = height // 2
        # crop boxes as (left, top, right, bottom)
        coordinates = [(0, 0, half_width, half_height),
                       (half_width, 0, width, half_height),
                       (0, half_height, half_width, height),
                       (half_width, half_height, width, height)]
        return [img.crop(c) for c in coordinates]

    def auth_mj(self):
        """
        Check Midjourney API credentials. Currently always reports success.
        """
        # TODO: check if secret is valid
        return {'status': 'ok'}

    def request_mj(self, path: str, action: str, data: str, retries=3):
        """
        Send one request to the Midjourney proxy API.

        :param path: endpoint path relative to the proxy base URL, e.g. "submit/imagine".
        :param action: HTTP method, e.g. "GET" or "POST".
        :param data: request body (a JSON string, or '' for none).
        :param retries: attempts made when the request itself raises; at least one is made.
        :return: the ``requests.Response`` (status code 200).
        :raises Exception: if the base URL is not configured, every attempt fails,
            or the response status is not 200.
        """
        mj_proxy_url = self.proxy_url
        if mj_proxy_url is None or not mj_proxy_url.startswith(("http://", "https://")):
            raise Exception('please set MIDJOURNEY_PROXY_API_BASE in ENV or in config.json')

        auth_ = self.auth_mj()
        if auth_.get('error'):
            raise Exception('auth not set')

        fetch_url = f"{mj_proxy_url}/{path}"
        # The body is not logged: it may carry a base64-encoded image.
        logging.debug("[MJ Proxy] %s %s", action, fetch_url)

        attempts = max(1, retries)
        res = None
        last_error = None
        for _ in range(attempts):
            try:
                with retrieve_proxy():
                    res = requests.request(method=action, url=fetch_url, headers=self.headers, data=data)
                break
            except requests.RequestException as e:
                last_error = e
                logging.warning("[MJ Proxy] %s %s failed: %s", action, fetch_url, e)

        if res is None:
            # Every attempt raised; report the last network error instead of using an unset response.
            raise Exception(f"request failed after {attempts} attempts: {last_error}")

        if res.status_code != 200:
            raise Exception(f'{res.status_code} - {res.content}')

        return res

    def fetch_status(self, fetch_data: FetchDataPack):
        """
        Poll the proxy once for the task in ``fetch_data`` and render a status message.

        Sleeps 3 seconds before polling. Sets ``fetch_data.finished`` to True when the
        task succeeded, failed, or exceeded ``fetch_data.timeout``.

        :return: markdown text for the chat window: a progress line while the task
            runs, the result on success, or an error message on failure/timeout.
        :raises Exception: if the status request fails.
        """
        if fetch_data.start_time + fetch_data.timeout < time.time():
            fetch_data.finished = True
            return "任务超时,请检查 dc 输出。描述:" + (fetch_data.prompt or "")

        time.sleep(3)
        status_res = self.request_mj(f"task/{fetch_data.task_id}/fetch", "GET", '')
        status_res_json = status_res.json()
        if not (200 <= status_res.status_code < 300):
            # Parenthesised so the fallbacks apply to the reason, not to the whole message.
            raise Exception("任务状态获取失败:" + (
                status_res_json.get('error') or status_res_json.get('description') or '未知错误'))

        status = status_res_json['status']
        if status == "SUCCESS":
            fetch_data.finished = True
            return self._render_success(fetch_data, status_res_json)
        if status in ("FAILED", "FAILURE"):
            # A failed task has no usable image, so report the reason instead of rendering a result.
            fetch_data.finished = True
            return "任务处理失败,原因:" + (status_res_json.get('failReason') or '未知原因')

        fetch_data.finished = False
        if status == "NOT_START":
            content = f'任务未开始,已等待 {time.time() - fetch_data.start_time:.2f} 秒'
        elif status == "IN_PROGRESS":
            content = '任务正在运行'
            if status_res_json.get('progress'):
                content += f",进度:{status_res_json['progress']}"
        elif status == "SUBMITTED":
            content = '任务已提交处理'
        else:
            content = status
        content = f"**任务状态:** [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] - {content}"
        content += f"\n\n花费时间:{time.time() - fetch_data.start_time:.2f} 秒"
        # An IN_PROGRESS response may already carry an imageUrl; show it when present.
        img_url = status_res_json.get('imageUrl')
        if status == 'IN_PROGRESS' and img_url:
            return f"{content}\n[![{fetch_data.task_id}]({img_url})]({img_url})"
        return content

    def _render_success(self, fetch_data, status_res_json):
        """
        Build the chat message for a task whose status is SUCCESS.

        DESCRIBE returns the generated prompt text. Other actions return the image
        (split into four local files when a temp folder is enabled for IMAGINE and
        VARIATION) plus ready-to-copy UPSCALE/VARIATION commands.
        """
        if fetch_data.action == "DESCRIBE":
            # Checked first: the DESCRIBE result is text, so no image URL is needed.
            return f"\n{status_res_json['prompt']}"

        img_url = self.use_mj_self_proxy_url(status_res_json['imageUrl'])
        time_cost_str = f"\n\n{fetch_data.action} 花费时间:{time.time() - fetch_data.start_time:.2f} 秒"
        upscale_str = ""
        variation_str = ""
        if fetch_data.action in ["IMAGINE", "UPSCALE", "VARIATION"]:
            sp = self.command_splitter
            upscale = [f'/mj UPSCALE{sp}{i}{sp}{fetch_data.task_id}' for i in range(1, 5)]
            upscale_str = '\n放大图片:\n\n' + '\n\n'.join(upscale)
            variation = [f'/mj VARIATION{sp}{i}{sp}{fetch_data.task_id}' for i in range(1, 5)]
            variation_str = '\n图片变体:\n\n' + '\n\n'.join(variation)

        if self.temp_path and fetch_data.action in ["IMAGINE", "VARIATION"]:
            try:
                images = self.split_image(img_url)
                temp_dir = pathlib.Path(self.temp_path)
                for i, image in enumerate(images):
                    image.save(temp_dir / f"{fetch_data.task_id}_{i}.png")
                img_str = '\n'.join(
                    f"![{fetch_data.task_id}](/file={self.temp_path}/{fetch_data.task_id}_{i}.png)"
                    for i in range(len(images)))
                return fetch_data.prefix_content + f"{time_cost_str}\n\n{img_str}{upscale_str}{variation_str}"
            except Exception as e:
                # Best effort: fall back to linking the full image below.
                logging.error("failed to split image %s: %s", img_url, e)
        return fetch_data.prefix_content + \
            f"{time_cost_str}[![{fetch_data.task_id}]({img_url})]({img_url}){upscale_str}{variation_str}"

    def handle_file_upload(self, files, chatbot, language):
        """
        Load uploaded files as the input image via ``try_read_image``.

        :return: (None, chatbot, None); ``chatbot`` gets the image path appended
            when an image path was set.
        """
        if files:
            for file in files:
                if file.name:
                    logging.info(f"尝试读取图像: {file.name}")
                    self.try_read_image(file.name)
            if self.image_path is not None:
                chatbot = chatbot + [((self.image_path,), None)]
            if self.image_bytes is not None:
                logging.info("使用图片作为输入")
        return None, chatbot, None

    def reset(self, remain_system_prompt=False):
        """
        Clear the uploaded image and reset the parent state.

        NOTE(review): ``remain_system_prompt`` is not forwarded to the parent reset.
        """
        self.image_bytes = None
        self.image_path = None
        return super().reset()

    def _parse_command(self, content):
        """
        Split a "/mj ..." chat message into (action, prompt).

        ``prompt`` is everything after "/mj", stripped, still including any
        "ACTION::" prefix. ``action`` is the text before the first command splitter,
        or "IMAGINE" when there is no non-empty text before a splitter.
        """
        prompt = content[3:].strip()
        head, sep, _ = prompt.partition(self.command_splitter)
        action = head if sep and head else "IMAGINE"
        return action, prompt

    def _require_image(self):
        """Return ``self.image_bytes``, raising if no image has been uploaded."""
        if self.image_bytes is None:
            raise ValueError("请先上传图片")
        return self.image_bytes

    def _post_task(self, action, prompt):
        """
        Send the submit request for ``action`` and return the raw response.

        :raises ValueError: if the command is malformed, the action is unknown, or a
            required image is missing.
        """
        sp = self.command_splitter
        if action == "IMAGINE":
            # Drop an explicit "IMAGINE::" prefix so only the drawing text is submitted.
            explicit_prefix = action + sp
            text = prompt[len(explicit_prefix):].strip() if prompt.startswith(explicit_prefix) else prompt
            data = {"prompt": text}
            if self.image_bytes is not None:
                data["base64"] = 'data:image/png;base64,' + self.image_bytes
            return self.request_mj("submit/imagine", "POST", json.dumps(data))
        if action == "DESCRIBE":
            return self.request_mj("submit/describe", "POST", json.dumps(
                {"base64": 'data:image/png;base64,' + self._require_image()}))
        if action == "BLEND":
            # NOTE(review): blends the single uploaded image with itself, without the data-URI
            # prefix used by the other actions — confirm this is what the proxy expects.
            image = self._require_image()
            return self.request_mj("submit/blend", "POST", json.dumps({"base64Array": [image, image]}))
        if action in self.CHANGE_ACTIONS:
            parts = prompt.split(sp, 2)
            if len(parts) < 3 or not parts[1].strip().isdigit():
                raise ValueError(f"命令格式错误,示例:/mj {action}{sp}1{sp}123456789")
            return self.request_mj("submit/change", "POST", json.dumps(
                {"action": action, "index": int(parts[1]), "taskId": parts[2].strip()}))
        raise ValueError("任务提交失败:未知的任务类型")

    def _submit_task(self, action, prompt):
        """
        Submit a task and check the proxy's answer.

        :return: (response_json, error_message); ``error_message`` is None when the
            submission was accepted (response code 1 or 22), in which case
            ``response_json['result']`` is the task id.
        """
        res = self._post_task(action, prompt)
        res_json = res.json()
        if not (200 <= res.status_code < 300) or res_json.get('code') not in (1, 22):
            reason = res_json.get('error') or res_json.get('description') or '未知错误'
            return res_json, "任务提交失败:" + reason
        return res_json, None

    def _new_fetch_data(self, action, prompt, task_id):
        """Create the polling state for a freshly submitted task."""
        return Midjourney_Client.FetchDataPack(
            action=action,
            prefix_content=f"**画面描述:** {prompt}\n**任务ID:** {task_id}\n",
            task_id=task_id,
            prompt=prompt,
        )

    @staticmethod
    def _format_submit_error(e):
        """Render an exception raised while submitting or polling as a chat message."""
        return "任务提交错误:" + (str(e.args[0]) if e.args else '未知错误')

    def get_answer_at_once(self):
        """
        Handle the latest message without streaming; blocks until the task finishes.

        :return: (answer, count). For non-"/mj" messages the answer is the help text
            and count is ``len(content)``; otherwise count is the cl100k_base token
            count of the message.
        :raises Exception: if the action is not one of ``SUPPORTED_ACTIONS``.
        """
        content = self.history[-1]['content']

        if not content.lower().startswith("/mj"):
            return self.get_help(), len(content)

        action, prompt = self._parse_command(content)
        if action not in self.SUPPORTED_ACTIONS:
            raise Exception("任务提交失败:未知的任务类型")

        try:
            res_json, answer = self._submit_task(action, prompt)
            if answer is None:
                fetch_data = self._new_fetch_data(action, prompt, res_json['result'])
                while not fetch_data.finished:
                    answer = self.fetch_status(fetch_data)
        except Exception as e:
            logging.exception("submit failed: %s", e)
            answer = self._format_submit_error(e)

        # Return a count, not the token list produced by encode().
        return answer, len(tiktoken.get_encoding("cl100k_base").encode(content))

    def get_answer_stream_iter(self):
        """
        Handle the latest message, yielding progress messages while the task runs.

        Non-"/mj" messages yield the help text once. Errors are yielded as messages.
        """
        content = self.history[-1]['content']

        if not content.lower().startswith("/mj"):
            yield self.get_help()
            return

        action, prompt = self._parse_command(content)
        if action not in self.SUPPORTED_ACTIONS:
            yield "任务提交失败:未知的任务类型"
            return

        try:
            res_json, error = self._submit_task(action, prompt)
            if error is not None:
                yield error
                return
            task_id = res_json['result']
            yield (f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] - 任务提交成功:"
                   + (res_json.get('description') or '请稍等片刻'))

            fetch_data = self._new_fetch_data(action, prompt, task_id)
            while not fetch_data.finished:
                yield self.fetch_status(fetch_data)
        except Exception as e:
            logging.exception("submit failed: %s", e)
            yield self._format_submit_error(e)

    def get_help(self):
        """Return the user-facing help text for the /mj commands."""
        return """```
【绘图帮助】
所有命令都需要以 /mj 开头,如:/mj a dog
IMAGINE - 绘图,可以省略该命令,后面跟上绘图内容
    /mj a dog
    /mj IMAGINE::a cat
DESCRIBE - 描述图片,需要在右下角上传需要描述的图片内容
    /mj DESCRIBE::
UPSCALE - 确认后放大图片,第一个数值为需要放大的图片(1~4),第二参数为任务ID
    /mj UPSCALE::1::123456789
    请使用SD进行UPSCALE
VARIATION - 图片变体,第一个数值为需要放大的图片(1~4),第二参数为任务ID
    /mj VARIATION::1::123456789

【绘图参数】
所有命令默认会带上参数--v 5.2
其他参数参照 https://docs.midjourney.com/docs/parameter-list
长宽比 --aspect/--ar
    --ar 1:2
    --ar 16:9
负面tag --no
    --no plants
    --no hands
随机种子 --seed
    --seed 1
生成动漫风格(NijiJourney) --niji
    --niji
```
"""