"""Random-access readers for TSV datasets backed by byte-offset index files.

A ``.tsv`` data file is paired with a ``.lineidx`` file holding one byte
offset per row, enabling O(1) seeks.  ``TSVFileNew`` additionally supports an
optional ``.linelist`` (subset of row indices) and ``.chunks`` (per-class row
ranges) companion file.  ``CompositeTSVFile`` concatenates several TSV files
into one logical dataset, keeping at most a few underlying files open via a
small LRU cache.
"""

import json
import logging
import os
import os.path as op
from collections import OrderedDict
from itertools import accumulate
from typing import List, Union

import numpy as np  # noqa: F401  -- kept: original file imported it; callers may rely on it


def generate_lineidx(filein, idxout):
    """Write a ``.lineidx`` companion for *filein*: one byte offset per line.

    The index is written to a ``.tmp`` file and atomically renamed into
    place, so an interrupted run never leaves a truncated index behind.

    Args:
        filein: path of the TSV file to index.
        idxout: path of the index file to create.
    """
    idxout_tmp = idxout + '.tmp'
    with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
        fsize = os.fstat(tsvin.fileno()).st_size
        fpos = 0
        # Record the offset of every line start until we reach EOF.
        while fpos != fsize:
            tsvout.write(str(fpos) + "\n")
            tsvin.readline()
            fpos = tsvin.tell()
    os.rename(idxout_tmp, idxout)


def read_to_character(fp, c):
    """Read from *fp* up to (but not including) the first occurrence of *c*.

    Reads in small fixed-size chunks so only the prefix of a (potentially
    long) row is pulled in.  Asserts that *c* appears before EOF.
    """
    result = []
    while True:
        s = fp.read(32)
        assert s != ''
        if c in s:
            result.append(s[: s.index(c)])
            break
        else:
            result.append(s)
    return ''.join(result)


class TSVFile(object):
    """Random access to the rows of a TSV file via its ``.lineidx`` index."""

    def __init__(self, tsv_file, generate_lineidx=False):
        """
        Args:
            tsv_file: path to the ``.tsv`` file; the index is expected at the
                same path with a ``.lineidx`` extension.
            generate_lineidx: if True and the index file is missing, build it.
        """
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
        self._fp = None
        self._lineidx = None
        # Pid of the process that opened self._fp.  After a fork (e.g.
        # DataLoader workers) the file is re-opened rather than sharing the
        # parent's file position.
        self.pid = None

        if not op.isfile(self.lineidx) and generate_lineidx:
            # BUG FIX: the boolean parameter shadows the module-level
            # function of the same name; the original called the bool and
            # raised TypeError.  Reach the real helper through globals()
            # (parameter name kept for backward compatibility).
            globals()['generate_lineidx'](self.tsv_file, self.lineidx)

    def __del__(self):
        if self._fp:
            self._fp.close()

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def num_rows(self):
        """Return the number of rows (== number of index entries)."""
        self._ensure_lineidx_loaded()
        return len(self._lineidx)

    def seek(self, idx):
        """Return row *idx* as a list of stripped column strings."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            pos = self._lineidx[idx]
        except Exception:
            # Log which file/row failed, then propagate the original error.
            logging.info('%s-%s', self.tsv_file, idx)
            raise
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split('\t')]

    def seek_first_column(self, idx):
        """Return only the first column of row *idx* (cheap key lookup)."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        pos = self._lineidx[idx]
        self._fp.seek(pos)
        return read_to_character(self._fp, '\t')

    def get_key(self, idx):
        """Alias for :meth:`seek_first_column`."""
        return self.seek_first_column(idx)

    def __getitem__(self, index):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        # Lazily parse the .lineidx file into a list of int byte offsets.
        if self._lineidx is None:
            with open(self.lineidx, 'r') as fp:
                self._lineidx = [int(i.strip()) for i in fp.readlines()]

    def _ensure_tsv_opened(self):
        if self._fp is None:
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()

        if self.pid != os.getpid():
            # Re-open after fork: the inherited handle shares its offset with
            # the parent.  Close the stale handle so it is not leaked.
            self._fp.close()
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()


class TSVFileNew(object):
    """TSVFile variant with optional row-subset and per-class chunk support.

    Companion files (all derived from the tsv path unless overridden):
      * ``.lineidx``  -- byte offset of each row (required);
      * ``.linelist`` -- optional subset of row indices to expose;
      * ``.chunks``   -- optional JSON mapping class name -> [first, last]
        row range, used together with *class_selector* to expose only the
        rows of selected classes.
    """

    def __init__(self,
                 tsv_file: str,
                 if_generate_lineidx: bool = False,
                 lineidx: str = None,
                 class_selector: List[str] = None):
        """
        Args:
            tsv_file: path to the ``.tsv`` file.
            if_generate_lineidx: build the ``.lineidx`` file if it is missing.
            lineidx: explicit index path; defaults to ``<tsv>.lineidx``.
            class_selector: class names to keep when a ``.chunks`` file exists.
        """
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' \
            if not lineidx else lineidx
        self.linelist = op.splitext(tsv_file)[0] + '.linelist'
        self.chunks = op.splitext(tsv_file)[0] + '.chunks'
        self._fp = None
        self._lineidx = None
        # Indices (into the full file) of the rows this view exposes.
        self._sample_indices = None
        # List of (start, end) spans into _sample_indices, one per class in
        # the .chunks file (only populated on the class-selector path).
        self._class_boundaries = None
        self._class_selector = class_selector
        # Pid that opened self._fp; re-opened after fork (see
        # _ensure_tsv_opened).
        self.pid = None

        if not op.isfile(self.lineidx) and if_generate_lineidx:
            generate_lineidx(self.tsv_file, self.lineidx)

    def __del__(self):
        if self._fp:
            self._fp.close()

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def get_class_boundaries(self):
        """Return the per-class (start, end) spans, or None if not built."""
        return self._class_boundaries

    def num_rows(self):
        """Return the number of exposed rows (after linelist/class filters)."""
        self._ensure_lineidx_loaded()
        return len(self._sample_indices)

    def seek(self, idx: int):
        """Return exposed row *idx* as a list of stripped column strings."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            pos = self._lineidx[self._sample_indices[idx]]
        except Exception:
            logging.info('=> %s-%s', self.tsv_file, idx)
            raise
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split('\t')]

    def seek_first_column(self, idx: int):
        """Return the first column of raw (unfiltered) row *idx*."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        # NOTE: indexes _lineidx directly (not _sample_indices), mirroring
        # the original behavior.
        pos = self._lineidx[idx]
        self._fp.seek(pos)
        return read_to_character(self._fp, '\t')

    def get_key(self, idx: int):
        """Alias for :meth:`seek_first_column`."""
        return self.seek_first_column(idx)

    def __getitem__(self, index: int):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        if self._lineidx is None:
            with open(self.lineidx, 'r') as fp:
                self._lineidx = [int(line.strip()) for line in fp.readlines()]

            # Optional .linelist restricts which raw rows are exposed.
            linelist = None
            if op.isfile(self.linelist):
                with open(self.linelist, 'r') as fp:
                    linelist = sorted(
                        int(line.strip()) for line in fp.readlines()
                    )

            if op.isfile(self.chunks) and self._class_selector:
                self._sample_indices = []
                self._class_boundaries = []
                with open(self.chunks, 'r') as fp:
                    class_boundaries = json.load(fp)
                # PERF FIX: O(1) set membership instead of the original
                # O(n) list scan per row.
                lineset = set(linelist) if linelist else None
                for class_name, boundary in class_boundaries.items():
                    start = len(self._sample_indices)
                    if class_name in self._class_selector:
                        for idx in range(boundary[0], boundary[1] + 1):
                            if lineset is not None and idx not in lineset:
                                continue
                            self._sample_indices.append(idx)
                    end = len(self._sample_indices)
                    # (start, end) recorded for every class, selected or not,
                    # so boundaries align with the .chunks ordering.
                    self._class_boundaries.append((start, end))
            else:
                if linelist:
                    self._sample_indices = linelist
                else:
                    self._sample_indices = list(range(len(self._lineidx)))

    def _ensure_tsv_opened(self):
        if self._fp is None:
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()

        if self.pid != os.getpid():
            logging.debug('=> re-open %s because the process id changed',
                          self.tsv_file)
            # Close the inherited handle before re-opening (see TSVFile).
            self._fp.close()
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()


class LRU(OrderedDict):
    """Limit size, evicting the least recently looked-up key when full.

    https://docs.python.org/3/library/collections.html#collections.OrderedDict
    """

    def __init__(self, maxsize=4, *args, **kwds):
        self.maxsize = maxsize
        super().__init__(*args, **kwds)

    def __getitem__(self, key):
        # A successful lookup refreshes the key's recency.
        value = super().__getitem__(key)
        self.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        if key in self:
            self.move_to_end(key)
        super().__setitem__(key, value)
        if len(self) > self.maxsize:
            # The first key is the least recently used one.
            oldest = next(iter(self))
            del self[oldest]


class CompositeTSVFile:
    """Several TSV files concatenated into one logical, globally-indexed
    dataset.

    Only a handful of underlying files are kept open at a time, through a
    small LRU cache of :class:`TSVFileNew` objects.
    """

    def __init__(self,
                 file_list: Union[str, list],
                 root: str = '.',
                 class_selector: List[str] = None):
        """
        Args:
            file_list: list of tsv paths relative to *root*, or the path of a
                text file holding one such path per line.
            root: directory the entries of *file_list* are relative to.
            class_selector: forwarded to each underlying TSVFileNew.
        """
        if isinstance(file_list, str):
            self.file_list = load_list_file(file_list)
        else:
            assert isinstance(file_list, list)
            self.file_list = file_list

        self.root = root
        self.cache = LRU()
        # Kept for interface compatibility; initialize() deliberately drops
        # its local tsv handles to save memory, so lookups go through
        # self.cache instead.
        self.tsvs = None
        self.chunk_sizes = None
        self.accum_chunk_sizes = None
        self._class_selector = class_selector
        self._class_boundaries = None
        self.initialized = False
        self.initialize()

    def _open_chunk(self, idx_source: int) -> TSVFileNew:
        """Return the TSVFileNew for chunk *idx_source*, via the LRU cache."""
        if idx_source not in self.cache:
            self.cache[idx_source] = TSVFileNew(
                op.join(self.root, self.file_list[idx_source]),
                class_selector=self._class_selector
            )
        return self.cache[idx_source]

    def get_key(self, index: int):
        """Return '<source file>_<first column>' for global row *index*."""
        idx_source, idx_row = self._calc_chunk_idx_row(index)
        # BUG FIX: the original indexed self.tsvs, which initialize() leaves
        # as None (it deletes its local handles); use the LRU cache exactly
        # like __getitem__ does.
        k = self._open_chunk(idx_source).get_key(idx_row)
        return '_'.join([self.file_list[idx_source], k])

    def get_class_boundaries(self):
        """Return combined (start, end) class spans, or None if not built."""
        return self._class_boundaries

    def get_chunk_size(self):
        """Return the per-file row counts."""
        return self.chunk_sizes

    def num_rows(self):
        """Return the total number of rows across all files."""
        return sum(self.chunk_sizes)

    def _calc_chunk_idx_row(self, index: int):
        """Map a global row *index* to (chunk index, row index within chunk)."""
        idx_chunk = 0
        idx_row = index
        while index >= self.accum_chunk_sizes[idx_chunk]:
            idx_chunk += 1
            idx_row = index - self.accum_chunk_sizes[idx_chunk - 1]
        return idx_chunk, idx_row

    def __getitem__(self, index: int):
        idx_source, idx_row = self._calc_chunk_idx_row(index)
        return self._open_chunk(idx_source).seek(idx_row)

    def __len__(self):
        return sum(self.chunk_sizes)

    def initialize(self):
        """Compute chunk sizes and class boundaries.

        This function has to be called in the init function if cache_policy
        is enabled.  Thus, let's always call it in the init function to make
        it simple.  Idempotent: subsequent calls are no-ops.
        """
        if self.initialized:
            return
        tsvs = [
            TSVFileNew(
                op.join(self.root, f),
                class_selector=self._class_selector
            ) for f in self.file_list
        ]
        logging.info("Calculating chunk sizes ...")
        self.chunk_sizes = [len(tsv) for tsv in tsvs]
        # Inclusive running totals: accum[i] is the number of rows in
        # chunks 0..i.
        self.accum_chunk_sizes = list(accumulate(self.chunk_sizes))
        if (
            self._class_selector and
            all([tsv.get_class_boundaries() for tsv in tsvs])
        ):
            """
            Note: When using CompositeTSVFile, make sure that the classes
            contained in each tsv file do not overlap. Otherwise, the class
            boundaries won't be correct.
            """
            self._class_boundaries = []
            offset = 0
            for tsv in tsvs:
                boundaries = tsv.get_class_boundaries()
                for bound in boundaries:
                    self._class_boundaries.append(
                        (bound[0] + offset, bound[1] + offset))
                offset += len(tsv)
        # NOTE: in current setting, get_key is not used during training, so
        # we remove tsvs for saving memory cost; row access re-opens files
        # lazily through self.cache.
        del tsvs
        self.initialized = True


def load_list_file(fname: str) -> List[str]:
    """Read *fname* as a list of stripped lines, dropping one trailing
    empty line if present."""
    with open(fname, 'r') as fp:
        lines = fp.readlines()
    result = [line.strip() for line in lines]
    if len(result) > 0 and result[-1] == '':
        result = result[:-1]
    return result