Spaces:
Runtime error
Runtime error
File size: 1,997 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import tempfile

import numpy as np
import pytest

from mmocr.datasets.base_dataset import BaseDataset
def _create_dummy_ann_file(ann_file):
ann_info1 = 'sample1.jpg hello'
ann_info2 = 'sample2.jpg world'
with open(ann_file, 'w') as fw:
for ann_info in [ann_info1, ann_info2]:
fw.write(ann_info + '\n')
def _create_dummy_loader():
loader = dict(
type='HardDiskLoader',
repeat=1,
parser=dict(type='LineStrParser', keys=['file_name', 'text']))
return loader
def test_custom_dataset():
    """Smoke-test ``BaseDataset`` on a two-line text annotation file.

    The dataset is built once with ``test_mode=True`` and once with
    ``test_mode=False``. Each run checks:

    - ``len``, group flags and ``prepare_train_img`` / ``prepare_test_img``;
    - ``__getitem__`` and ``_get_next_index``;
    - that ``format_results`` leaves its input unchanged;
    - that ``evaluate`` raises ``NotImplementedError``.
    """
    # The context manager removes the directory even when an assertion
    # fails; an explicit cleanup() call at the end would be skipped then.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy data
        ann_file = osp.join(tmp_dir, 'fake_data.txt')
        _create_dummy_ann_file(ann_file)

        loader = _create_dummy_loader()
        for mode in [True, False]:
            dataset = BaseDataset(
                ann_file, loader, pipeline=[], test_mode=mode)

            # test len
            assert len(dataset) == len(dataset.data_infos)

            # test set group flag
            assert np.allclose(dataset.flag, [0, 0])

            # test prepare_train_img
            expect_results = {
                'img_info': {
                    'file_name': 'sample1.jpg',
                    'text': 'hello'
                },
                'img_prefix': ''
            }
            assert dataset.prepare_train_img(0) == expect_results

            # test prepare_test_img
            assert dataset.prepare_test_img(0) == expect_results

            # test __getitem__
            assert dataset[0] == expect_results

            # test _get_next_index
            assert dataset._get_next_index(0) == 1

            # test format_results: it must not mutate its input. A deep copy
            # is required so in-place edits to the nested 'img_info' dict are
            # caught too; a shallow copy would share that inner dict.
            expect_results_copy = copy.deepcopy(expect_results)
            dataset.format_results(expect_results)
            assert expect_results_copy == expect_results

            # test evaluate
            with pytest.raises(NotImplementedError):
                dataset.evaluate(expect_results)
|