File size: 3,479 Bytes
7ae68fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import zipfile
import os.path as osp
# import lmdb
import logging
from PIL import Image
import pickle
import io
import glob
import os
from pathlib import Path
import time
from threading import Thread
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Module-level configuration.
home = os.fspath(Path.home())  # current user's home directory, as a str
# Canonical (symlink-resolved, no trailing slash) location of the blob mount.
abs_blob_path = osp.realpath("/mnt/blob/")
CACHE_FOLDER = osp.join(home, "caching")
USE_CACHE = True

def norm(path):
    """Return ``path`` as an absolute, symlink-resolved path.

    Glob patterns are rejected: a "*" anywhere in ``path`` trips an assertion
    (NOTE: being an ``assert``, the check is skipped under ``python -O``).
    The path is first made absolute with ``abspath`` -- which also collapses
    ".." components lexically -- and only then handed to ``realpath`` to
    resolve symlinks.
    """
    assert "*" not in path
    absolute = osp.abspath(path)
    return osp.realpath(absolute)

def in_blob(file, blob_root=None):
    """Return True if ``file`` lies at or below the blob mount directory.

    Both paths are resolved with ``os.path.realpath`` and compared on whole
    path components. As a result, "/mnt/blobfoo/x" and "/data/mnt/blob/x" do
    not count as inside "/mnt/blob", while a path reached through a symlink
    to the mount does.

    Args:
        file: path to test (absolute or relative to the current directory).
        blob_root: directory to test against; defaults to the module-level
            ``abs_blob_path``.

    Returns:
        bool: whether ``file`` is ``blob_root`` itself or a descendant of it.
    """
    root = abs_blob_path if blob_root is None else os.path.realpath(blob_root)
    path = os.path.realpath(file)
    # join(root, "") appends exactly one separator, so the prefix test only
    # matches whole path components (and still works when root is "/").
    return path == root or path.startswith(os.path.join(root, ""))

def map_name(file):
    """Flatten a file path into a single flat file name.

    The path is canonicalized with ``norm``. If it lies under the blob mount
    (``abs_blob_path``), that prefix is removed. Any remaining leading "/" is
    dropped and every other "/" becomes "_". For example,
    "/mnt/blob/data/a.zip" becomes "data_a.zip".

    Raises:
        AssertionError: if ``file`` contains "*" (via ``norm``) or the mapped
            name is 250 characters or longer.
    """
    path = norm(file)
    # Strip the exact directory prefix. str.lstrip() would instead strip any
    # leading characters from the *set* "/mntblo", eating letters of the first
    # component (e.g. "/mnt/blob/bar/x" -> "ar/x").
    blob_prefix = os.path.join(abs_blob_path, "")
    if path.startswith(blob_prefix):
        path = path[len(blob_prefix):]
    path = path.lstrip("/").replace("/", "_")
    assert len(path) < 250, f"mapped name too long ({len(path)} chars): {path}"
    return path


def preload(db, sync=False):
    """Call ``db.initialize()`` either inline or on a fresh background thread.

    Args:
        db: any object exposing a no-argument ``initialize`` method.
        sync: when True, run ``initialize`` in the calling thread before
            returning. Otherwise start a ``Thread`` for it and return at once;
            the thread handle is not kept.
    """
    if not sync:
        Thread(target=db.initialize).start()
        return
    db.initialize()

def get_keys_from_lmdb(db):
    """Collect every key of an LMDB-style environment into a list.

    A read-only transaction is opened with ``db.begin(write=False)``. Its
    cursor is iterated with ``values=False`` so only keys are produced, and
    the list is fully built before the transaction closes.
    """
    with db.begin(write=False) as txn:
        cursor = txn.cursor()
        key_list = list(cursor.iternext(values=False))
        return key_list

def decode_img(byteflow):
    """Decode encoded image bytes into an in-memory RGB ``PIL.Image``.

    If the bytes cannot be decoded, the placeholder image ``white.jpeg`` is
    returned instead, so a corrupt sample does not abort the caller. The
    placeholder path is relative, i.e. resolved against the current working
    directory. Only ``Exception`` subclasses trigger the fallback;
    ``KeyboardInterrupt``/``SystemExit`` still propagate.

    Raises:
        Whatever ``Image.open`` raises if ``white.jpeg`` itself cannot be read.
    """
    try:
        # convert() loads the source image and returns a new in-memory image,
        # so an extra load() call is unnecessary.
        img = Image.open(io.BytesIO(byteflow)).convert("RGB")
    except Exception:
        logging.debug("image decode failed; using white.jpeg placeholder", exc_info=True)
        # The context manager closes the placeholder's file handle promptly.
        with Image.open("white.jpeg") as placeholder:
            img = placeholder.convert("RGB")
    return img

def decode_text(byteflow):
    """Unpickle ``byteflow`` and return the resulting Python object.

    NOTE(review): ``pickle.loads`` can execute arbitrary code embedded in the
    payload -- only feed it archives from a trusted source.
    """
    obj = pickle.loads(byteflow)
    return obj
    
# Maps a ``data_type`` name to the function that turns raw archive bytes
# into a usable value.
decode_funcs = dict(
    image=decode_img,  # encoded image bytes -> RGB PIL image
    text=decode_text,  # pickled bytes -> unpickled Python object
)


class ZipManager:
    """Read-only access to the members of one zip archive.

    On construction, the archive's member names are listed on a background
    thread (``preload`` with its default ``sync=False``). ``keys`` blocks until
    that listing is available. ``get`` opens a persistent handle on first use
    and decodes member bytes with the decoder chosen by ``data_type``.
    """

    def __init__(self, zip_path, data_type, prefix=None) -> None:
        """
        Args:
            zip_path: path of the zip archive.
            data_type: key into ``decode_funcs`` selecting the decoder.
            prefix: accepted for interface compatibility; currently unused.
        """
        self.decode_func = decode_funcs[data_type]
        self.zip_path = zip_path
        self._init = False
        # Must exist before the preload thread starts: initialize() records
        # its failure here so ``keys`` can raise instead of waiting forever.
        self._load_error = None
        preload(self)

    def deinitialze(self):
        """Close the persistent handle opened by ``get``; a no-op if none is open.

        The misspelled name is kept for existing callers; ``deinitialize`` is
        a correctly spelled alias.
        """
        zip_fd = getattr(self, "zip_fd", None)
        if zip_fd is not None:
            zip_fd.close()
            del self.zip_fd
        self._init = False

    deinitialize = deinitialze

    def initialize(self, close=True):
        """Record the archive's member names, optionally keeping it open.

        Args:
            close: when True (the mode used by the background preload), the
                archive is opened in a private handle only long enough to read
                its member list. ``self.zip_fd``, which ``get`` may be using
                concurrently, is never touched. When False, a persistent
                handle is stored in ``self.zip_fd`` (reused if already open)
                and the manager is marked initialized.

        Raises:
            Whatever ``zipfile.ZipFile`` raises (e.g. ``FileNotFoundError``);
            the error is also recorded for ``keys``.
        """
        try:
            if close:
                with zipfile.ZipFile(self.zip_path, mode="r") as zip_fd:
                    if not hasattr(self, "_keys"):
                        self._keys = zip_fd.namelist()
            else:
                if getattr(self, "zip_fd", None) is None:
                    self.zip_fd = zipfile.ZipFile(self.zip_path, mode="r")
                if not hasattr(self, "_keys"):
                    self._keys = self.zip_fd.namelist()
                self._init = True
        except Exception as err:
            self._load_error = err
            raise

    @property
    def keys(self):
        """Member names of the archive; blocks (polling every 0.1 s) until listed.

        Raises:
            RuntimeError: if listing the archive failed (original error chained).
        """
        while not hasattr(self, "_keys"):
            err = getattr(self, "_load_error", None)
            if err is not None:
                raise RuntimeError(f"failed to list members of {self.zip_path}") from err
            time.sleep(0.1)
        return self._keys

    def get(self, name):
        """Read member ``name`` and return it passed through ``decode_func``.

        Raises:
            KeyError: if ``name`` is not a member (from ``ZipFile.read``).
        """
        if not self._init:
            self.initialize(close=False)
        return self.decode_func(self.zip_fd.read(name))


class MultipleZipManager:
    """Member-name lookup spanning several zip archives.

    Each archive is wrapped in a ``ZipManager``. ``initialize`` builds an index
    from member name to the archive holding it. When several archives contain
    the same name, the one listed last in ``files`` wins.
    """

    def __init__(self, files: list, data_type, sync=True):
        """
        Args:
            files: paths of the zip archives to index.
            data_type: decoder key passed to every ``ZipManager``.
            sync: build the index before returning (True), or on a background
                thread via ``preload`` (False). ``keys`` and ``get`` wait for
                the index either way.
        """
        self.files = files
        self._is_init = False
        # Set by a failed initialize() so waiting callers raise instead of hanging.
        self._init_error = None
        self.data_type = data_type
        if sync:
            logging.info("sync %s", files)
            self.initialize()
        else:
            logging.info("async %s", files)
            preload(self)

    def initialize(self):
        """Open every archive and build the member-name -> archive index.

        The index is assembled in locals and published only when complete, so
        readers never see a half-built mapping.

        Raises:
            Whatever the underlying managers raise; the error is also recorded.
        """
        try:
            # Creating all managers first lets their background preloads list
            # the archives concurrently before we block on each one in turn.
            managers = {file: ZipManager(file, self.data_type) for file in self.files}
            mapping = {}
            for file, manager in managers.items():
                logging.info("%s loading", file)
                keys = manager.keys
                mapping.update(dict.fromkeys(keys, file))
                logging.info("%s loaded, size = %d", file, len(keys))
        except Exception as err:
            self._init_error = err
            raise
        self.managers = managers
        self.mapping = mapping
        self._keys = list(mapping)
        self._init_error = None
        self._is_init = True
        logging.info("initialize over")

    def _wait_until_initialized(self):
        """Poll every 0.1 s until ``initialize`` has finished; re-raise its failure."""
        while not self._is_init:
            err = getattr(self, "_init_error", None)
            if err is not None:
                raise RuntimeError("MultipleZipManager initialization failed") from err
            time.sleep(0.1)

    @property
    def keys(self):
        """All indexed member names, waiting for initialization if needed."""
        self._wait_until_initialized()
        return self._keys

    def get(self, name):
        """Return member ``name`` decoded by the archive that holds it.

        Waits for initialization first, so it is safe to call right after an
        async (``sync=False``) construction.

        Raises:
            KeyError: if ``name`` is not in any archive.
        """
        self._wait_until_initialized()
        return self.managers[self.mapping[name]].get(name)