File size: 4,522 Bytes
4121bec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) Facebook, Inc. and its affiliates.
import io
import numpy as np
import torch

from detectron2 import model_zoo
from detectron2.data import DatasetCatalog
from detectron2.data.detection_utils import read_image
from detectron2.modeling import build_model
from detectron2.structures import Boxes, Instances, ROIMasks
from detectron2.utils.file_io import PathManager


"""
Internal utilities for tests. Don't use except for writing tests.
"""


def get_model_no_weights(config_path):
    """
    Build the model described by a model zoo config, leaving its weights
    uninitialized by any checkpoint (no pretrained weights are loaded here).

    Args:
        config_path: config identifier forwarded to ``model_zoo.get_config``.

    Returns:
        the object returned by ``build_model`` for that config.
    """
    cfg = model_zoo.get_config(config_path)
    cuda_available = torch.cuda.is_available()
    if not cuda_available:
        # No GPU on this machine: force the model onto the CPU.
        cfg.MODEL.DEVICE = "cpu"
    model = build_model(cfg)
    return model


def random_boxes(num_boxes, max_coord=100, device="cpu"):
    """
    Generate a random Nx4 tensor of boxes whose coordinates are < max_coord.

    Args:
        num_boxes (int): number of boxes (rows) to generate.
        max_coord: upper bound for the generated coordinates.
        device: device on which the tensor is created.

    Returns:
        Tensor: Nx4 boxes where columns 2 and 3 are never smaller than
        columns 0 and 1 respectively.
    """
    half_range = max_coord * 0.5
    # Every sampled value is kept >= 1 because
    # tiny boxes cause numerical instability in box regression.
    boxes = (torch.rand(num_boxes, 4, device=device) * half_range).clamp_(min=1.0)
    # torchvision's version of this helper does
    #   boxes[:, 2:] += torch.rand(N, 2) * 100
    # which cannot guarantee non-negative widths/heights, i.e.
    #   boxes[:, 2] >= boxes[:, 0] and boxes[:, 3] >= boxes[:, 1].
    # Here the last two columns start out as sizes and are shifted by the
    # first two columns, so the constraint holds by construction.
    top_left = boxes[:, :2]
    bottom_right = boxes[:, 2:]
    bottom_right.add_(top_left)  # in-place on a view: updates `boxes` itself
    return boxes


def get_sample_coco_image(tensor=True):
    """
    Load one sample COCO image, preferring the first entry of the
    "coco_2017_val_100" dataset and falling back to a public URL.

    Args:
        tensor (bool): if True, returns 3xHxW tensor.
            else, returns a HxWx3 numpy array.

    Returns:
        an image, in BGR color.
    """
    try:
        image_path = DatasetCatalog.get("coco_2017_val_100")[0]["file_name"]
        if not PathManager.exists(image_path):
            raise FileNotFoundError()
    except IOError:
        # The dataset is not available locally (e.g. on public CI), so use
        # a publicly hosted image instead.
        image_path = "http://images.cocodataset.org/train2017/000000000009.jpg"

    image = read_image(image_path, format="BGR")
    if not tensor:
        return image
    # HxWxC -> CxHxW; the transposed view is made contiguous before wrapping.
    channels_first = np.ascontiguousarray(image.transpose(2, 0, 1))
    return torch.from_numpy(channels_first)


def convert_scripted_instances(instances):
    """
    Build a regular :class:`Instances` object from a scripted Instances object.

    Each name in ``instances._field_names`` is looked up as the attribute
    ``"_" + name``; fields whose value is missing or None are skipped.

    Args:
        instances: object exposing ``image_size`` and ``_field_names``.

    Returns:
        Instances: a new object with the same image size and the non-None fields.
    """
    converted = Instances(instances.image_size)
    for field_name in instances._field_names:
        value = getattr(instances, "_" + field_name, None)
        if value is None:
            continue
        converted.set(field_name, value)
    return converted


def assert_instances_allclose(input, other, *, rtol=1e-5, msg="", size_as_tensor=False):
    """
    Assert that two Instances have the same image size, the same set of
    fields, and (nearly) equal values in every field.

    Args:
        input, other (Instances): objects to compare. An argument that is not
            an :class:`Instances` is first converted with
            :func:`convert_scripted_instances`.
        rtol (float): scale of the allowed difference. Boxes / ROIMasks are
            compared with ``atol=100 * rtol``; other floating point tensors
            with ``atol=max(abs(field of input)) * rtol``. Non-floating
            tensors must be exactly equal.
        msg (str): prefix of the assertion messages. A default prefix is used
            when empty.
        size_as_tensor: compare image_size of the Instances as tensors (instead of tuples).
             Useful for comparing outputs of tracing.

    Raises:
        AssertionError: if image sizes, field names or field values differ.
        ValueError: if a field of ``input`` is neither Boxes, ROIMasks nor a tensor.
    """
    if not isinstance(input, Instances):
        input = convert_scripted_instances(input)
    if not isinstance(other, Instances):
        other = convert_scripted_instances(other)

    # Normalize the prefix so it always ends with exactly one space.
    if not msg:
        msg = "Two Instances are different! "
    else:
        msg = msg.rstrip() + " "

    size_error_msg = msg + f"image_size is {input.image_size} vs. {other.image_size}!"
    if size_as_tensor:
        assert torch.equal(
            torch.tensor(input.image_size), torch.tensor(other.image_size)
        ), size_error_msg
    else:
        assert input.image_size == other.image_size, size_error_msg
    fields = sorted(input.get_fields().keys())
    fields_other = sorted(other.get_fields().keys())
    assert fields == fields_other, msg + f"Fields are {fields} vs {fields_other}!"

    for f in fields:
        val1, val2 = input.get(f), other.get(f)
        if isinstance(val1, (Boxes, ROIMasks)):
            # boxes in the range of O(100) and can have a larger tolerance
            assert torch.allclose(val1.tensor, val2.tensor, atol=100 * rtol), (
                msg + f"Field {f} differs too much!"
            )
        elif isinstance(val1, torch.Tensor):
            if not val1.dtype.is_floating_point:
                assert torch.equal(val1, val2), msg + f"Field {f} is different!"
            elif val1.numel() == 0:
                # max() over a zero-element tensor raises, and there are no
                # values to compare (e.g. no detections): only the shapes
                # can disagree.
                assert val1.shape == val2.shape, msg + f"Field {f} differs too much!"
            else:
                # Scale the absolute tolerance by the largest magnitude in
                # the field so that rtol acts relative to the field's range.
                mag = torch.abs(val1).max().cpu().item()
                assert torch.allclose(val1, val2, atol=mag * rtol), (
                    msg + f"Field {f} differs too much!"
                )
        else:
            raise ValueError(f"Don't know how to compare type {type(val1)}")


def reload_script_model(module):
    """
    Round-trip a jit module through an in-memory save followed by a load.
    Similar to the `getExportImportCopy` function in torch/testing/

    Args:
        module: the jit module to serialize with ``torch.jit.save``.

    Returns:
        the module produced by ``torch.jit.load`` from the saved bytes.
    """
    stream = io.BytesIO()
    torch.jit.save(module, stream)
    # Rewind so that loading starts from the first serialized byte.
    stream.seek(0)
    reloaded = torch.jit.load(stream)
    return reloaded