File size: 878 Bytes
58627fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import numpy as np

from bitarray import bitarray


class IndexManager:
    """Persistence helper for index artifacts.

    Remembers the ``dim`` given at construction and offers two writers:
    one for objects handled by ``torch.save`` and one for objects that
    expose a ``tofile(file)`` method (such as a ``bitarray``).
    """

    def __init__(self, dim):
        # Stored as a public attribute; neither writer below reads it.
        self.dim = dim

    def save(self, tensor, path_prefix):
        """Serialize *tensor* to the path *path_prefix* via ``torch.save``."""
        torch.save(obj=tensor, f=path_prefix)

    def save_bitarray(self, bitarray, path_prefix):
        """Write *bitarray* to *path_prefix*, opened in binary write mode.

        The parameter name shadows the module-level ``bitarray`` import inside
        this method; it is kept unchanged so keyword-argument callers still work.
        """
        with open(path_prefix, "wb") as out_file:
            bitarray.tofile(out_file)


def load_index_part(filename, verbose=True):
    """Load one saved index part from *filename* as a single tensor.

    Args:
        filename: Path to a file previously written with ``torch.save``.
        verbose: Unused; kept so existing callers that pass it keep working.

    Returns:
        The stored object. If the file holds a list (the older on-disk
        format), its elements are concatenated with ``torch.cat``.
    """
    # NOTE(review): torch.load is pickle-based; depending on the installed
    # torch version's ``weights_only`` default it may execute arbitrary code,
    # so only load index parts from trusted locations.
    part = torch.load(filename)

    if isinstance(part, list):  # backward compatibility: older parts were lists of tensors
        part = torch.cat(part)

    return part


def load_compressed_index_part(filename, dim, bits):
    """Load a bit-packed index part as a 2-D ``uint8`` tensor.

    The file's bytes (as written by ``bitarray.tofile``) are reshaped so
    that each row holds ``ceil(dim * bits / 8)`` packed bytes.

    Args:
        filename: Path to the packed file.
        dim: Number of dimensions per row.
        bits: Number of bits used per dimension.

    Returns:
        A ``torch.uint8`` tensor of shape ``(n, ceil(dim * bits / 8))`` with
        ``n = (8 * file_size_in_bytes) // (dim * bits)``.

    Raises:
        RuntimeError: from ``reshape`` when ``n`` rows of that width do not
            account for exactly the bytes in the file (e.g. a truncated file,
            or ``dim * bits`` not a multiple of 8).
    """
    # bitarray.fromfile() followed by tobytes() returns the file's bytes
    # unchanged, so reading them directly yields the same buffer without the
    # extra bit-level parse and copy.
    with open(filename, "rb") as f:
        raw = f.read()

    row_bytes = (dim * bits + 7) // 8  # integer ceil(dim * bits / 8); avoids a float round-trip
    n = (len(raw) * 8) // (dim * bits)

    # np.frombuffer over immutable bytes gives a read-only array. torch.tensor()
    # copies it into a writable tensor, whereas torch.from_numpy() would share
    # the read-only memory, so the copy is intentional.
    part = torch.tensor(np.frombuffer(raw, dtype=np.uint8))
    part = part.reshape((n, row_bytes))

    return part