Spaces:
Runtime error
Runtime error
File size: 4,541 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import shutil
import warnings
import mmcv
from mmocr import digit_version
from mmocr.utils import list_from_file
class LmdbAnnFileBackend:
    """Lmdb storage backend for annotation file.

    Each line is stored under the key ``str(index)`` and the number of
    lines is stored under the key ``'total_number'``.

    Args:
        lmdb_path (str): Lmdb file path.
        encoding (str): Encoding used to decode the stored values.
            Defaults to ``'utf8'``.
    """

    def __init__(self, lmdb_path, encoding='utf8'):
        self.lmdb_path = lmdb_path
        self.encoding = encoding
        # Open a short-lived environment only to read the line count and
        # close it deterministically instead of relying on garbage
        # collection; ``__getitem__`` opens its own environment lazily, and
        # LMDB advises against holding the same environment open twice in
        # one process.
        env = self._get_env()
        try:
            with env.begin(write=False) as txn:
                self.total_number = int(
                    txn.get('total_number'.encode('utf-8')).decode(
                        self.encoding))
        finally:
            env.close()

    def __getitem__(self, index):
        """Retrieve one line from lmdb file by index.

        Args:
            index (int): Index of the line; looked up as ``str(index)``.

        Returns:
            str: The decoded line.

        Raises:
            IndexError: If no entry is stored under ``str(index)``.
        """
        # only attach env to self when __getitem__ is called
        # because env object cannot be pickled
        if not hasattr(self, 'env'):
            self.env = self._get_env()
        with self.env.begin(write=False) as txn:
            value = txn.get(str(index).encode('utf-8'))
            if value is None:
                # ``txn.get`` returns None for a missing key. Raising
                # IndexError (instead of failing on ``None.decode``) also
                # lets the legacy ``__getitem__`` iteration protocol stop.
                raise IndexError(f'index {index} is not found in lmdb '
                                 f'file {self.lmdb_path}')
            return value.decode(self.encoding)

    def __len__(self):
        return self.total_number

    def _get_env(self):
        """Open the lmdb environment at ``self.lmdb_path`` read-only.

        Raises:
            ImportError: If the ``lmdb`` package is not installed.
        """
        try:
            import lmdb
        except ImportError:
            raise ImportError(
                'Please install lmdb to enable LmdbAnnFileBackend.')
        return lmdb.open(
            self.lmdb_path,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

    def close(self):
        """Close the environment opened by ``__getitem__``, if any.

        Safe to call before any lookup has happened. After closing, the
        attribute is dropped so a later ``__getitem__`` reopens it instead
        of using a closed environment.
        """
        env = getattr(self, 'env', None)
        if env is not None:
            env.close()
            del self.env
class HardDiskAnnFileBackend:
    """Load annotation files that live on the local file system.

    Args:
        file_format (str): ``'txt'`` to read a plain-text file through
            ``list_from_file``, or ``'lmdb'`` to open an lmdb database via
            ``LmdbAnnFileBackend``. Defaults to ``'txt'``.
    """

    def __init__(self, file_format='txt'):
        supported_formats = ('txt', 'lmdb')
        assert file_format in supported_formats
        self.file_format = file_format

    def __call__(self, ann_file):
        """Return the loaded annotations for ``ann_file``.

        Returns:
            list | LmdbAnnFileBackend: Whatever ``list_from_file`` returns
            for ``'txt'``; an ``LmdbAnnFileBackend`` for ``'lmdb'``.
        """
        if self.file_format != 'lmdb':
            return list_from_file(ann_file)
        return LmdbAnnFileBackend(ann_file)
class PetrelAnnFileBackend:
    """Load annotation file with petrel storage backend.

    Args:
        file_format (str): ``'txt'`` or ``'lmdb'``. Defaults to ``'txt'``.
        save_dir (str): Local directory under which remote lmdb annotation
            directories are copied before being opened. Defaults to
            ``'tmp_dir'``.
    """

    def __init__(self, file_format='txt', save_dir='tmp_dir'):
        assert file_format in ['txt', 'lmdb']
        self.file_format = file_format
        self.save_dir = save_dir

    def __call__(self, ann_file):
        """Load ``ann_file`` through the petrel file client.

        Args:
            ann_file (str): Remote annotation file (``'txt'``) or remote
                lmdb directory (``'lmdb'``).

        Returns:
            list[str] | LmdbAnnFileBackend: The non-blank lines of the file
            for ``'txt'``; a backend over the local copy for ``'lmdb'``.

        Raises:
            Exception: If ``file_format`` is ``'lmdb'`` and the installed
                mmcv is older than 1.3.16.
        """
        file_client = mmcv.FileClient(backend='petrel')
        if self.file_format == 'lmdb':
            return self._load_lmdb(file_client, ann_file)
        lines = str(file_client.get(ann_file), encoding='utf-8').split('\n')
        return [x for x in lines if x.strip() != '']

    def _load_lmdb(self, file_client, ann_file):
        """Copy the remote lmdb dir under ``save_dir`` (unless present) and
        open it with ``LmdbAnnFileBackend``."""
        mmcv_version = digit_version(mmcv.__version__)
        if mmcv_version < digit_version('1.3.16'):
            raise Exception('Please update mmcv to 1.3.16 or higher '
                            'to enable "get_local_path" of "FileClient".')
        assert file_client.isdir(ann_file)

        # Mirror the remote path (with everything up to ``s3://`` removed)
        # under ``save_dir``.
        ann_file_rel_path = ann_file.split('s3://')[-1]
        ann_file_dir = osp.dirname(ann_file_rel_path)
        ann_file_name = osp.basename(ann_file_rel_path)
        local_dir = osp.join(self.save_dir, ann_file_dir, ann_file_name)
        if osp.exists(local_dir):
            # Reuse an existing local copy; its contents are not verified.
            warnings.warn(
                f'local_ann_file: {local_dir} is already existed and '
                'will be used. If it is not the correct ann_file '
                f'corresponding to {ann_file}, please remove it or '
                'change "save_dir" first then try again.')
        else:
            os.makedirs(local_dir, exist_ok=True)
            print(f'Fetching {ann_file} to {local_dir}...')
            # The remote listing is only needed when actually downloading.
            for each_file in file_client.list_dir_or_file(ann_file):
                tmp_file_path = file_client.join_path(ann_file, each_file)
                with file_client.get_local_path(tmp_file_path) as local_path:
                    shutil.copy(local_path, osp.join(local_dir, each_file))
        return LmdbAnnFileBackend(local_dir)
class HTTPAnnFileBackend:
    """Read annotation files through the mmcv HTTP file client.

    Only plain-text annotations can actually be loaded this way.

    Args:
        file_format (str): ``'txt'`` or ``'lmdb'``. ``'lmdb'`` is accepted
            here but rejected when a file is loaded. Defaults to ``'txt'``.
    """

    def __init__(self, file_format='txt'):
        supported_formats = ('txt', 'lmdb')
        assert file_format in supported_formats
        self.file_format = file_format

    def __call__(self, ann_file):
        """Fetch ``ann_file`` and return its non-blank lines.

        Raises:
            NotImplementedError: If ``file_format`` is ``'lmdb'``.
        """
        client = mmcv.FileClient(backend='http')
        if self.file_format == 'lmdb':
            raise NotImplementedError(
                'Loading lmdb file on http is not supported yet.')
        content = str(client.get(ann_file), encoding='utf-8')
        return [line for line in content.split('\n') if line.strip()]
|