ecommurz-talent-search-engine / cpu_unpickler.py
Elvan Selvano
Upload cpu_unpickler.py
02757a2
raw history blame
No virus
441 Bytes
import io
import pickle
import torch
class cpu_unpickler(pickle.Unpickler):
    """Unpickler that loads pickled ``torch.storage`` payloads onto the CPU.

    When the pickle stream refers to ``torch.storage._load_from_bytes``,
    this class hands back a replacement loader that calls ``torch.load``
    with ``map_location='cpu'`` instead of the function named in the
    stream. All other globals are resolved by ``pickle.Unpickler``.

    NOTE: unpickling can execute arbitrary code; only feed this class data
    from a trusted source.
    """

    @staticmethod
    def _load_bytes_on_cpu(payload):
        """Deserialize the byte string *payload* via ``torch.load`` on CPU."""
        return torch.load(io.BytesIO(payload), map_location='cpu')

    def find_class(self, module, name):
        """Resolve the global ``module.name`` requested by the pickle stream.

        Returns the CPU-mapping loader for ``torch.storage._load_from_bytes``
        and defers to the parent implementation for everything else.
        """
        # Guard clause: anything other than the storage loader is untouched.
        if (module, name) != ('torch.storage', '_load_from_bytes'):
            return super().find_class(module, name)
        return self._load_bytes_on_cpu