Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import os | |
import os.path as osp | |
import shutil | |
import warnings | |
import mmcv | |
from mmocr import digit_version | |
from mmocr.utils import list_from_file | |
class LmdbAnnFileBackend:
    """Lmdb storage backend for annotation file.

    Each annotation line is looked up under the utf-8 encoded string of its
    index, and the number of lines is read from the key ``total_number``.

    Args:
        lmdb_path (str): Lmdb file path.
        encoding (str): Encoding used to decode the values stored in lmdb.
            Defaults to 'utf8'.
    """

    def __init__(self, lmdb_path, encoding='utf8'):
        self.lmdb_path = lmdb_path
        self.encoding = encoding

        # Only a short-lived environment is needed to read the size. Close
        # it right away so that the lazily created one in ``__getitem__`` is
        # never a second open handle on the same lmdb path in this process.
        env = self._get_env()
        try:
            with env.begin(write=False) as txn:
                raw_total = txn.get('total_number'.encode('utf-8'))
        finally:
            env.close()

        if raw_total is None:
            raise ValueError(
                f'"total_number" is not found in lmdb file: {lmdb_path}')
        self.total_number = int(raw_total.decode(self.encoding))

    def __getitem__(self, index):
        """Retrieve one line from lmdb file by index.

        Args:
            index (int): Index of the requested line.

        Returns:
            str: The decoded line.

        Raises:
            IndexError: If no entry is stored for ``index``. This also lets
                plain iteration over the backend terminate properly.
        """
        # only attach env to self when __getitem__ is called
        # because env object cannot be pickle
        if not hasattr(self, 'env'):
            self.env = self._get_env()

        with self.env.begin(write=False) as txn:
            raw_line = txn.get(str(index).encode('utf-8'))
        if raw_line is None:
            raise IndexError(
                f'index {index!r} is not found in lmdb file: '
                f'{self.lmdb_path}')
        return raw_line.decode(self.encoding)

    def __len__(self):
        """Return the number of lines recorded under ``total_number``."""
        return self.total_number

    def __getstate__(self):
        """Return the picklable state, i.e. everything except the env."""
        state = self.__dict__.copy()
        # The env is re-created on demand by ``__getitem__``.
        state.pop('env', None)
        return state

    def _get_env(self):
        """Open the lmdb file at ``self.lmdb_path`` in read-only mode.

        Raises:
            ImportError: If ``lmdb`` is not installed.
        """
        try:
            import lmdb
        except ImportError as err:
            raise ImportError(
                'Please install lmdb to enable LmdbAnnFileBackend.') from err

        return lmdb.open(
            self.lmdb_path,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

    def close(self):
        """Close the lazily opened env, if there is one.

        Safe to call when no item has been read yet. The env attribute is
        removed so that a later ``__getitem__`` opens a fresh one.
        """
        env = self.__dict__.pop('env', None)
        if env is not None:
            env.close()
class HardDiskAnnFileBackend:
    """Annotation loader for files that live on the local hard disk.

    Args:
        file_format (str): Format of the annotation file, either 'txt' or
            'lmdb'. Defaults to 'txt'.
    """

    def __init__(self, file_format='txt'):
        assert file_format in ('txt', 'lmdb')
        self.file_format = file_format

    def __call__(self, ann_file):
        """Load the annotations stored at ``ann_file``.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            The result of ``list_from_file`` for the 'txt' format, or an
            ``LmdbAnnFileBackend`` wrapping ``ann_file`` for 'lmdb'.
        """
        if self.file_format != 'lmdb':
            return list_from_file(ann_file)

        return LmdbAnnFileBackend(ann_file)
class PetrelAnnFileBackend:
    """Load annotation file with petrel storage backend.

    Args:
        file_format (str): Format of the annotation file, either 'txt' or
            'lmdb'. Defaults to 'txt'.
        save_dir (str): Local directory that lmdb annotation files are
            copied into before being opened. Defaults to 'tmp_dir'.
    """

    def __init__(self, file_format='txt', save_dir='tmp_dir'):
        assert file_format in ['txt', 'lmdb']
        self.file_format = file_format
        self.save_dir = save_dir

    def __call__(self, ann_file):
        """Load the annotations stored at ``ann_file`` on petrel.

        Args:
            ann_file (str): Petrel path of the annotation file. For the
                'lmdb' format it must be a directory.

        Returns:
            LmdbAnnFileBackend | list[str]: For 'lmdb', a backend opened on
            the local copy of the directory; for 'txt', the non-blank lines
            of the file.

        Raises:
            Exception: If the 'lmdb' format is requested with mmcv older
                than 1.3.16.
        """
        file_client = mmcv.FileClient(backend='petrel')

        if self.file_format == 'lmdb':
            mmcv_version = digit_version(mmcv.__version__)
            if mmcv_version < digit_version('1.3.16'):
                raise Exception('Please update mmcv to 1.3.16 or higher '
                                'to enable "get_local_path" of "FileClient".')
            assert file_client.isdir(ann_file)
            files = file_client.list_dir_or_file(ann_file)

            # Mirror the remote layout (without the 's3://' prefix) under
            # ``save_dir``.
            ann_file_rel_path = ann_file.split('s3://')[-1]
            ann_file_dir = osp.dirname(ann_file_rel_path)
            ann_file_name = osp.basename(ann_file_rel_path)
            local_dir = osp.join(self.save_dir, ann_file_dir, ann_file_name)

            if osp.exists(local_dir):
                warnings.warn(
                    f'local_ann_file: {local_dir} is already existed and '
                    'will be used. If it is not the correct ann_file '
                    f'corresponding to {ann_file}, please remove it or '
                    'change "save_dir" first then try again.')
            else:
                os.makedirs(local_dir, exist_ok=True)
                print(f'Fetching {ann_file} to {local_dir}...')
                try:
                    for each_file in files:
                        tmp_file_path = file_client.join_path(
                            ann_file, each_file)
                        with file_client.get_local_path(
                                tmp_file_path) as local_path:
                            shutil.copy(local_path,
                                        osp.join(local_dir, each_file))
                except BaseException:
                    # Do not leave a half-fetched directory behind: the next
                    # call would find it and silently reuse it.
                    shutil.rmtree(local_dir, ignore_errors=True)
                    raise

            return LmdbAnnFileBackend(local_dir)

        lines = str(file_client.get(ann_file), encoding='utf-8').split('\n')

        return [x for x in lines if x.strip() != '']
class HTTPAnnFileBackend:
    """Load annotation file with http storage backend.

    Args:
        file_format (str): Format of the annotation file, either 'txt' or
            'lmdb'. Only 'txt' can actually be loaded. Defaults to 'txt'.
    """

    def __init__(self, file_format='txt'):
        assert file_format in ['txt', 'lmdb']
        self.file_format = file_format

    def __call__(self, ann_file):
        """Load the annotations stored at ``ann_file`` over http.

        Args:
            ann_file (str): URL of the annotation file.

        Returns:
            list[str]: The non-blank lines of the file.

        Raises:
            NotImplementedError: If the backend was created with the 'lmdb'
                format.
        """
        # Reject the unsupported format before building a client for it.
        if self.file_format == 'lmdb':
            raise NotImplementedError(
                'Loading lmdb file on http is not supported yet.')

        file_client = mmcv.FileClient(backend='http')
        lines = str(file_client.get(ann_file), encoding='utf-8').split('\n')

        return [x for x in lines if x.strip() != '']