Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import copy | |
import os.path as osp | |
import tempfile | |
import pytest | |
import torch | |
from mmocr.models import build_detector | |
def _create_dummy_vocab_file(vocab_file): | |
with open(vocab_file, 'w') as fw: | |
for char in list(map(chr, range(ord('a'), ord('z') + 1))): | |
fw.write(char + '\n') | |
def _get_config_module(fname): | |
"""Load a configuration as a python module.""" | |
from mmcv import Config | |
config_mod = Config.fromfile(fname) | |
return config_mod | |
def _get_detector_cfg(fname): | |
"""Grab configs necessary to create a detector. | |
These are deep copied to allow for safe modification of parameters without | |
influencing other tests. | |
""" | |
config = _get_config_module(fname) | |
model = copy.deepcopy(config.model) | |
return model | |
def test_bert_softmax(cfg_file): | |
# prepare data | |
texts = ['中'] * 47 | |
img = [31] * 47 | |
labels = [31] * 128 | |
input_ids = [0] * 128 | |
attention_mask = [0] * 128 | |
token_type_ids = [0] * 128 | |
img_metas = { | |
'texts': texts, | |
'labels': torch.tensor(labels).unsqueeze(0), | |
'img': img, | |
'input_ids': torch.tensor(input_ids).unsqueeze(0), | |
'attention_masks': torch.tensor(attention_mask).unsqueeze(0), | |
'token_type_ids': torch.tensor(token_type_ids).unsqueeze(0) | |
} | |
# create dummy data | |
tmp_dir = tempfile.TemporaryDirectory() | |
vocab_file = osp.join(tmp_dir.name, 'fake_vocab.txt') | |
_create_dummy_vocab_file(vocab_file) | |
model = _get_detector_cfg(cfg_file) | |
model['label_convertor']['vocab_file'] = vocab_file | |
detector = build_detector(model) | |
losses = detector.forward(img, img_metas) | |
assert isinstance(losses, dict) | |
model['loss']['type'] = 'MaskedFocalLoss' | |
detector = build_detector(model) | |
losses = detector.forward(img, img_metas) | |
assert isinstance(losses, dict) | |
tmp_dir.cleanup() | |
# Test forward test | |
with torch.no_grad(): | |
batch_results = [] | |
result = detector.forward(None, img_metas, return_loss=False) | |
batch_results.append(result) | |