nyanko7 commited on
Commit
86562e4
1 Parent(s): 4b30d84

Create modules/safe.py

Browse files
Files changed (1) hide show
  1. modules/safe.py +188 -0
modules/safe.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this code is adapted from the script contributed by anon from /h/
2
+ # modified, from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/safe.py
3
+
4
+ import io
5
+ import pickle
6
+ import collections
7
+ import sys
8
+ import traceback
9
+
10
+ import torch
11
+ import numpy
12
+ import _codecs
13
+ import zipfile
14
+ import re
15
+
16
+
17
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage.
# Resolve whichever name this torch build provides so persistent_load()
# below can construct an empty storage on both old and new versions.
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
19
+
20
+
21
def encode(*args):
    """Module-level wrapper around ``_codecs.encode``.

    Exists so the restricted unpickler can hand out a stable, whitelisted
    reference for the ``_codecs/encode`` global found in pickled checkpoints.
    """
    return _codecs.encode(*args)
24
+
25
+
26
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that refuses to resolve any global not on an explicit whitelist.

    ``extra_handler`` may be set to a callable ``(module, name) -> object | None``;
    it gets first refusal, and returning ``None`` falls through to the built-in
    whitelist below. Anything not whitelisted raises.
    """

    # Optional per-instance hook; None means "whitelist only".
    extra_handler = None

    def persistent_load(self, saved_id):
        # torch.save emits persistent ids of the form ('storage', ...); anything
        # else is not a torch checkpoint, so refuse it outright.
        assert saved_id[0] == 'storage'
        return TypedStorage()

    def find_class(self, module, name):
        # Caller-supplied handler is consulted before the built-in whitelist.
        handler = self.extra_handler
        if handler is not None:
            resolved = handler(module, name)
            if resolved is not None:
                return resolved

        # Globals that are a plain attribute lookup on an already-imported
        # module: module path -> (allowed attribute names, lazy module thunk).
        # The thunk keeps attribute access on e.g. numpy.core off the path
        # of non-matching lookups, exactly like the original per-case checks.
        simple = {
            'collections': (('OrderedDict',), lambda: collections),
            'torch._utils': (('_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy'), lambda: torch._utils),
            'torch': (('FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32'), lambda: torch),
            'torch.nn.modules.container': (('ParameterDict',), lambda: torch.nn.modules.container),
            'numpy.core.multiarray': (('scalar', '_reconstruct'), lambda: numpy.core.multiarray),
            'numpy': (('dtype', 'ndarray'), lambda: numpy),
        }
        entry = simple.get(module)
        if entry is not None:
            allowed_names, owner = entry
            if name in allowed_names:
                return getattr(owner(), name)

        # Cases that are not a simple getattr on a top-level import.
        if module == '_codecs' and name == 'encode':
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")
64
+
65
+
66
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")


def check_zip_filenames(filename, names):
    """Raise if any archive member name is not one torch.save would write.

    :param filename: path of the checkpoint being verified (used in the
        error message only — previously the message printed a literal
        "(unknown)" because the placeholder was lost from the f-string).
    :param names: member names from the checkpoint's zip directory.
    :raises Exception: on the first member that does not match the
        'dir/version', 'dir/data.pkl', or 'dir/data/<n>' layout.
    """
    for name in names:
        if allowed_zip_names_re.match(name):
            continue

        raise Exception(f"bad file inside {filename}: {name}")
76
+
77
+
78
def check_pt(filename, extra_handler):
    """Scan a torch checkpoint with the restricted unpickler, without loading tensors.

    :param filename: path to the checkpoint file.
    :param extra_handler: optional ``(module, name) -> object | None`` hook
        forwarded to :class:`RestrictedUnpickler`.
    :raises Exception: if the archive layout or any pickled global is not
        whitelisted (error messages now include *filename* — the original
        f-strings had lost the placeholder and printed "(unknown)").
    :raises pickle.UnpicklingError: if the pickle stream itself is corrupt.
    """
    try:

        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            check_zip_filenames(filename, z.namelist())

            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
            if len(data_pkl_filenames) == 0:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")
            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipfile:

        # if it's not a zip file, it's the old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for i in range(5):
                unpickler.load()
104
+
105
+
106
def load(filename, *args, **kwargs):
    """Drop-in replacement for torch.load that verifies the file first.

    Delegates to :func:`load_with_extra`, forwarding whatever handler the
    :class:`Extra` context manager has installed globally.
    """
    return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
108
+
109
+
110
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    this function is intended to be used by extensions that want to load models with
    some extra classes in them that the usual unpickler would find suspicious.

    Use the extra_handler argument to specify a function that takes module and field name as text,
    and returns that field's value:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
    definitely unsafe.

    Returns None (after printing diagnostics to stderr) when verification fails;
    otherwise returns the result of the real torch.load.
    """

    try:
        check_pt(filename, extra_handler)

    except pickle.UnpicklingError:
        # Corrupt pickle stream: report which file and give up.
        # (The messages below previously printed a literal "(unknown)" because
        # the {filename} placeholder had been lost from the f-strings.)
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print("The file is most likely corrupted.", file=sys.stderr)
        return None

    except Exception:
        # Whitelist violation or bad archive layout: treat as potentially malicious.
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
        print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
        return None

    # Verification passed; do the actual load with the original torch.load.
    return unsafe_torch_load(filename, *args, **kwargs)
149
+
150
+
151
class Extra:
    """
    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
    (because it's not your code making the torch.load call). The intended use is like this:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```

    The handler is stored in the module-level ``global_extra_handler`` on entry
    and cleared again on exit; nesting two Extra() blocks is an error.
    """

    def __init__(self, handler):
        # Kept until __enter__ installs it globally.
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        # Refuse to nest: a second block would silently clobber the first handler.
        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        # Always clear, even if the body raised.
        global_extra_handler = None
184
+
185
+
186
# Keep a handle on the original, unchecked loader before monkeypatching.
unsafe_torch_load = torch.load
# Route every torch.load in this process through the safety-checked loader above.
torch.load = load
# Handler consulted by load(); set and cleared by the Extra context manager.
global_extra_handler = None