pantat88 commited on
Commit
20bcbb8
1 Parent(s): 3cd677c

Create encrypt_image.py

Browse files
Files changed (1) hide show
  1. encrypt_image.py +206 -0
encrypt_image.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ from pathlib import Path
4
+ from modules import shared,script_callbacks,scripts as md_scripts,images
5
+ from modules.api import api
6
+ from modules.shared import opts
7
+ from scripts.core.core import get_sha256,dencrypt_image,dencrypt_image_v2,encrypt_image_v2
8
+ from PIL import PngImagePlugin,_util,ImagePalette
9
+ from PIL import Image as PILImage
10
+ from io import BytesIO
11
+ from typing import Optional
12
+ from fastapi import FastAPI
13
+ from gradio import Blocks
14
+ from fastapi import FastAPI, Request, Response
15
+ import sys
16
+ from urllib.parse import unquote
17
+ from colorama import Fore, Back, Style
18
+
19
+ repo_dir = md_scripts.basedir()
20
+ password = getattr(shared.cmd_opts, 'encrypt_pass', None)
21
+
22
+
23
+ def hook_http_request(app: FastAPI):
24
+ @app.middleware("http")
25
+ async def image_dencrypt(req: Request, call_next):
26
+ endpoint:str = req.scope.get('path', 'err')
27
+ endpoint='/'+endpoint.strip('/')
28
+ # 兼容无边浏览器
29
+ if endpoint.startswith('/infinite_image_browsing/image-thumbnail') or endpoint.startswith('/infinite_image_browsing/file'):
30
+ query_string:str = req.scope.get('query_string').decode('utf-8')
31
+ query_string = unquote(query_string)
32
+ if query_string and query_string.index('path=')>=0:
33
+ query = query_string.split('&')
34
+ path = ''
35
+ for sub in query:
36
+ if sub.startswith('path='):
37
+ path = sub[sub.index('=')+1:]
38
+ if path:
39
+ endpoint = '/file=' + path
40
+ # 模型预览图
41
+ if endpoint.startswith('/sd_extra_networks/thumb'):
42
+ query_string:str = req.scope.get('query_string').decode('utf-8')
43
+ query_string = unquote(query_string)
44
+ if query_string and query_string.index('filename=')>=0:
45
+ query = query_string.split('&')
46
+ path = ''
47
+ for sub in query:
48
+ if sub.startswith('filename='):
49
+ path = sub[sub.index('=')+1:]
50
+ if path:
51
+ endpoint = '/file=' + path
52
+ if endpoint.startswith('/file='):
53
+ file_path = endpoint[6:] or ''
54
+ if not file_path: return await call_next(req)
55
+ if file_path.rfind('.') == -1: return await call_next(req)
56
+ if not file_path[file_path.rfind('.'):]: return await call_next(req)
57
+ if file_path[file_path.rfind('.'):].lower() in ['.png','.jpg','.jpeg','.webp','.abcd']:
58
+ image = PILImage.open(file_path)
59
+ pnginfo = image.info or {}
60
+ if 'Encrypt' in pnginfo:
61
+ buffered = BytesIO()
62
+ info = PngImagePlugin.PngInfo()
63
+ for key in pnginfo.keys():
64
+ if pnginfo[key]:
65
+ info.add_text(key,pnginfo[key])
66
+ image.save(buffered, format=PngImagePlugin.PngImageFile.format, pnginfo=info)
67
+ decrypted_image_data = buffered.getvalue()
68
+ response: Response = Response(content=decrypted_image_data, media_type="image/png")
69
+ return response
70
+
71
+ return await call_next(req)
72
+
73
+ def set_shared_options():
74
+ # 传递插件状态到前端
75
+ section = ("encrypt_image_is_enable",'图片加密' if shared.opts.localization == 'zh_CN' else "encrypt image" )
76
+ option = shared.OptionInfo(
77
+ default="是",
78
+ label='是否启用了加密插件' if shared.opts.localization == 'zh_CN' else "Whether the encryption plug-in is enabled",
79
+ section=section,
80
+ )
81
+ option.do_not_save = True
82
+ shared.opts.add_option(
83
+ "encrypt_image_is_enable",
84
+ option,
85
+ )
86
+ shared.opts.data['encrypt_image_is_enable'] = "是"
87
+
88
+ def app_started_callback(_: Blocks, app: FastAPI):
89
+ set_shared_options()
90
+
91
+
92
+ if PILImage.Image.__name__ != 'EncryptedImage':
93
+ super_open = PILImage.open
94
+ super_encode_pil_to_base64 = api.encode_pil_to_base64
95
+ super_modules_images_save_image = images.save_image
96
+ super_api_middleware = api.api_middleware
97
+ class EncryptedImage(PILImage.Image):
98
+ __name__ = "EncryptedImage"
99
+
100
+ @staticmethod
101
+ def from_image(image:PILImage.Image):
102
+ image = image.copy()
103
+ img = EncryptedImage()
104
+ img.im = image.im
105
+ img._mode = image.mode
106
+ if image.im.mode:
107
+ try:
108
+ img.mode = image.im.mode
109
+ except Exception as e:
110
+ ''
111
+ img._size = image.size
112
+ img.format = image.format
113
+ if image.mode in ("P", "PA"):
114
+ if image.palette:
115
+ img.palette = image.palette.copy()
116
+ else:
117
+ img.palette = ImagePalette.ImagePalette()
118
+ img.info = image.info.copy()
119
+ return img
120
+
121
+ def save(self, fp, format=None, **params):
122
+ filename = ""
123
+ if isinstance(fp, Path):
124
+ filename = str(fp)
125
+ elif _util.is_path(fp):
126
+ filename = fp
127
+ elif fp == sys.stdout:
128
+ try:
129
+ fp = sys.stdout.buffer
130
+ except AttributeError:
131
+ pass
132
+ if not filename and hasattr(fp, "name") and _util.is_path(fp.name):
133
+ # only set the name for metadata purposes
134
+ filename = fp.name
135
+
136
+ if not filename or not password:
137
+ # 如果没有密码或不保存到硬盘,直接保存
138
+ super().save(fp, format = format, **params)
139
+ return
140
+
141
+ if 'Encrypt' in self.info and (self.info['Encrypt'] == 'pixel_shuffle' or self.info['Encrypt'] == 'pixel_shuffle_2'):
142
+ super().save(fp, format = format, **params)
143
+ return
144
+
145
+ encrypt_image_v2(self, get_sha256(password))
146
+ self.format = PngImagePlugin.PngImageFile.format
147
+ pnginfo = params.get('pnginfo', PngImagePlugin.PngInfo())
148
+ if not pnginfo:
149
+ pnginfo = PngImagePlugin.PngInfo()
150
+ pnginfo.add_text('Encrypt', 'pixel_shuffle_2')
151
+ pnginfo.add_text('EncryptPwdSha', get_sha256(f'{get_sha256(password)}Encrypt'))
152
+ for key in (self.info or {}).keys():
153
+ if self.info[key]:
154
+ pnginfo.add_text(key,str(self.info[key]))
155
+ params.update(pnginfo=pnginfo)
156
+ super().save(fp, format=self.format, **params)
157
+ # 保存到文件后解密内存内的图片,让直接在内存内使用时图片正常
158
+ dencrypt_image_v2(self, get_sha256(password))
159
+
160
+
161
+
162
+ def open(fp,*args, **kwargs):
163
+ image = super_open(fp,*args, **kwargs)
164
+ if password and image.format.lower() == PngImagePlugin.PngImageFile.format.lower():
165
+ pnginfo = image.info or {}
166
+ if 'Encrypt' in pnginfo and pnginfo["Encrypt"] == 'pixel_shuffle':
167
+ dencrypt_image(image, get_sha256(password))
168
+ pnginfo["Encrypt"] = None
169
+ image = EncryptedImage.from_image(image=image)
170
+ return image
171
+ if 'Encrypt' in pnginfo and pnginfo["Encrypt"] == 'pixel_shuffle_2':
172
+ dencrypt_image_v2(image, get_sha256(password))
173
+ pnginfo["Encrypt"] = None
174
+ image = EncryptedImage.from_image(image=image)
175
+ return image
176
+ return EncryptedImage.from_image(image=image)
177
+
178
+ def encode_pil_to_base64(image:PILImage.Image):
179
+ with io.BytesIO() as output_bytes:
180
+ image.save(output_bytes, format="PNG", quality=opts.jpeg_quality)
181
+ pnginfo = image.info or {}
182
+ if 'Encrypt' in pnginfo and pnginfo["Encrypt"] == 'pixel_shuffle':
183
+ dencrypt_image(image, get_sha256(password))
184
+ pnginfo["Encrypt"] = None
185
+ if 'Encrypt' in pnginfo and pnginfo["Encrypt"] == 'pixel_shuffle_2':
186
+ dencrypt_image_v2(image, get_sha256(password))
187
+ pnginfo["Encrypt"] = None
188
+ bytes_data = output_bytes.getvalue()
189
+ return base64.b64encode(bytes_data)
190
+
191
+ def api_middleware(app: FastAPI):
192
+ super_api_middleware(app)
193
+ hook_http_request(app)
194
+
195
+ if password:
196
+ PILImage.Image = EncryptedImage
197
+ PILImage.open = open
198
+ api.encode_pil_to_base64 = encode_pil_to_base64
199
+ api.api_middleware = api_middleware
200
+
201
+ if password:
202
+ script_callbacks.on_app_started(app_started_callback)
203
+ print(f'{Fore.GREEN}[-] Image Encryption started.{Style.RESET_ALL}')
204
+
205
+ else:
206
+ print(f'{Fore.RED}[-] Image Encryption DISABLED.{Style.RESET_ALL}')