|
|
|
|
|
import io |
|
import pickle |
|
import collections |
|
import sys |
|
import traceback |
|
|
|
import torch |
|
import numpy |
|
import _codecs |
|
import zipfile |
|
import re |
|
|
|
|
|
|
|
# Recent PyTorch exposes the public name; older releases only have the private `_TypedStorage`.
try:
    TypedStorage = torch.storage.TypedStorage
except AttributeError:
    TypedStorage = torch.storage._TypedStorage
|
|
|
|
|
def encode(*args):
    """Forward all arguments to ``_codecs.encode`` and return its result.

    RestrictedUnpickler hands this wrapper out when a pickle asks for the
    ``_codecs.encode`` global.
    """
    return _codecs.encode(*args)
|
|
|
|
|
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that resolves only an allow-list of globals.

    check_pt() drives it over a checkpoint's pickle stream. Every global the
    stream references must either be accepted by ``extra_handler`` or appear
    in the allow-list in find_class(); anything else raises an Exception.
    Persistent (storage) references are replaced by empty placeholder
    storages, so no tensor data is read.
    """

    # Optional callable(module, name) -> object | None. A non-None result is
    # used as the resolved global; it is consulted before the built-in list.
    extra_handler = None

    def persistent_load(self, saved_id):
        """Return an empty placeholder storage for a persistent-id reference.

        Raises:
            Exception: if the persistent id is not of the 'storage' kind.
        """
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would silently disable this safety check.
        if saved_id[0] != 'storage':
            raise Exception(f"persistent id '{saved_id[0]}' is forbidden")

        try:
            # `_internal=True` avoids the TypedStorage deprecation warning
            # emitted on every construction by PyTorch >= 2.0.
            return TypedStorage(_internal=True)
        except TypeError:
            # Older PyTorch versions do not accept the `_internal` argument.
            return TypedStorage()

    def find_class(self, module, name):
        """Resolve the global ``module.name`` only if it is explicitly allowed.

        Raises:
            Exception: if the global is neither accepted by ``extra_handler``
                nor present in the allow-list below.
        """
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ['dtype', 'ndarray']:
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            # Hand out the module-level wrapper rather than _codecs.encode itself.
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            # Imported lazily so pytorch_lightning is only needed for checkpoints that reference it.
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            # Pickles written with protocol < 3 name the builtins module `__builtin__`;
            # overriding find_class bypasses the stdlib's automatic name mapping.
            return set

        raise Exception(f"global '{module}/{name}' is forbidden")
|
|
|
|
|
|
|
# Member names accepted inside a zip-format checkpoint: entries under a
# top-level directory, namely data.pkl, numbered tensor records (data/<n>) and
# bookkeeping records. `byteorder` and `.data/serialization_id` are written by
# newer PyTorch releases; rejecting them would flag ordinary checkpoints as bad.
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|\.data/serialization_id|(data\.pkl))$")

# Matches the pickled object graph (data.pkl) one directory below the archive root.
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")


def check_zip_filenames(filename, names):
    """Raise if any archive member name falls outside the allowed layout.

    Args:
        filename: path of the archive; used only in the error message.
        names: iterable of member names, e.g. from ``ZipFile.namelist()``.

    Raises:
        Exception: for the first name that does not match allowed_zip_names_re.
    """
    for name in names:
        if not allowed_zip_names_re.match(name):
            raise Exception(f"bad file inside {filename}: {name}")
|
|
|
|
|
def check_pt(filename, extra_handler):
    """Verify that a checkpoint only references allowed pickle globals.

    Zip archives have their member names checked with check_zip_filenames(),
    and their single data.pkl is run through RestrictedUnpickler. Files that
    are not zip archives are read as a raw stream of five consecutive pickles.
    Presumably this is the legacy torch.save layout (magic number, protocol
    version, sys info, object, storage keys) — TODO confirm.

    Args:
        filename: path of the checkpoint file.
        extra_handler: optional callable(module, name) passed on to
            RestrictedUnpickler to allow additional globals.

    Raises:
        Exception: on unexpected archive members, a missing or duplicated
            data.pkl, or a forbidden global.
        pickle.UnpicklingError / EOFError: on malformed pickle data.
    """
    try:
        with zipfile.ZipFile(filename) as z:
            names = z.namelist()
            check_zip_filenames(filename, names)

            # Exactly one data.pkl is required; the regex ties it to the archive root dir.
            data_pkl_filenames = [f for f in names if data_pkl_re.match(f)]
            if not data_pkl_filenames:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")

            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipFile:
        # Not a zip archive: fall back to scanning a plain pickle stream.
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for _ in range(5):
                unpickler.load()
|
|
|
|
|
def load(filename, *args, **kwargs):
    """Drop-in replacement for torch.load that verifies the file before loading.

    Uses the handler installed by an active ``Extra`` block (None otherwise).
    All remaining positional and keyword arguments are forwarded unchanged to
    the original torch.load.
    """
    # Pass the handler positionally. With `extra_handler=` given as a keyword,
    # any extra positional argument (e.g. torch.load(path, 'cpu')) would bind to
    # load_with_extra's `extra_handler` slot and raise "got multiple values".
    return load_with_extra(filename, global_extra_handler, *args, **kwargs)
|
|
|
|
|
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    Verify a checkpoint and then load it, optionally allowing extra globals.

    Extensions use this to load models that contain classes the default
    restricted unpickler would reject as suspicious.

    extra_handler is a function that receives the module and field name as text.
    It returns that field's value, or None to defer to the built-in rules:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    Returns the result of the original torch.load(filename, *args, **kwargs).
    Returns None if verification fails; the reason is printed to stderr.
    The check is skipped when shared.cmd_opts.disable_safe_unpickle is true
    (the --disable-safe-unpickle flag).

    The alternative is safe.unsafe_torch_load('model.pt'), which, as the name
    implies, is definitely unsafe.
    """
    from modules import shared

    def report_failure(*advice):
        # Only called from inside an except clause, so format_exc() sees the active exception.
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        for line in advice:
            print(line, file=sys.stderr)

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename, extra_handler)
    except pickle.UnpicklingError:
        report_failure(
            "-----> !!!! The file is most likely corrupted !!!! <-----",
            "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
        )
        return None
    except Exception:
        report_failure(
            "\nThe file may be malicious, so the program is not going to read it.",
            "You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
        )
        return None

    return unsafe_torch_load(filename, *args, **kwargs)
|
|
|
|
|
class Extra:
    """
    Temporarily install a global extra handler for when you can't call load_with_extra
    yourself (because it's not your code making the torch.load call). Intended use:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```

    Blocks cannot be nested: entering an Extra while another one is active
    raises AssertionError.
    """

    def __init__(self, handler):
        # callable(module, name) -> object | None, installed globally on __enter__.
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        # Explicit raise instead of `assert` so the nesting guard survives
        # `python -O`; AssertionError is kept for existing callers.
        if global_extra_handler is not None:
            raise AssertionError('already inside an Extra() block')

        global_extra_handler = self.handler
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        global_extra_handler = None
|
|
|
|
|
# Handler installed by an active `Extra` block; None when no block is active.
global_extra_handler = None

# Keep the original torch.load reachable as unsafe_torch_load, then route every
# torch.load call in the process through this module's verifying `load`.
unsafe_torch_load, torch.load = torch.load, load
|
|
|
|