File size: 441 Bytes
02757a2
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import io
import pickle

import torch

class cpu_unpickler(pickle.Unpickler):
    """
    Unpickler that restores pickled ``torch`` storages onto the CPU.

    When a storage is pickled, its payload is a byte string produced by
    ``torch.save`` and is rebuilt through ``torch.storage._load_from_bytes``.
    This class intercepts that one global and substitutes a loader that calls
    ``torch.load`` with ``map_location='cpu'``. Storages serialized from
    another device (e.g. CUDA) are therefore remapped to the CPU on load.
    Every other global is resolved by the standard ``pickle.Unpickler``.

    Security note: like any pickle unpickler, this can execute arbitrary code
    while loading. Only use it on data from a trusted source.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, swapping in a CPU-only storage loader."""
        # Anything other than torch's storage deserializer is left to the
        # default resolution logic.
        if (module, name) != ('torch.storage', '_load_from_bytes'):
            return super().find_class(module, name)

        def _load_storage_on_cpu(payload):
            # ``payload`` is the byte string the pickled storage carries.
            # Wrap it in a file-like object so torch.load can read it, and
            # force every storage onto the CPU.
            stream = io.BytesIO(payload)
            return torch.load(stream, map_location='cpu')

        return _load_storage_on_cpu