File size: 3,708 Bytes
205a7af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import logging
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

logger = logging.getLogger(__name__)

# flake8: noqa
# mypy: ignore-errors


def download_and_extract_benchmark(name: str, url: str, output: Path) -> None:
    """Download a zipped benchmark and unpack it inside ``output``.

    The benchmark counts as already installed when ``output / name`` exists;
    in that case nothing is downloaded or extracted. Otherwise the archive is
    stored as ``output / f"{name}.zip"`` (an existing file at that path is
    reused instead of being downloaded again), unpacked into ``output`` and
    then deleted.

    Args:
        name: Benchmark identifier. Used for the archive file name and for the
            directory whose existence marks the benchmark as present. The
            value ``"stanford2d3d"`` additionally triggers an interactive
            terms-of-use confirmation on stdin.
        url: Location of the zip archive, forwarded unchanged to
            ``torch.hub.download_url_to_file``.
        output: Directory that receives the archive and its contents. It is
            created (including parents) when missing.

    Raises:
        ValueError: If ``name`` is ``"stanford2d3d"`` and the user does not
            answer ``y`` to the terms-of-use prompt.
    """
    benchmark_dir = output / name
    # exist_ok avoids the check-then-create race of a separate exists() test.
    output.mkdir(parents=True, exist_ok=True)

    if benchmark_dir.exists():
        logger.info(
            "Benchmark %s already exists at %s, skipping download.", name, benchmark_dir
        )
        return

    if name == "stanford2d3d":
        _confirm_stanford2d3d_terms()

    zip_file = output / f"{name}.zip"

    if not zip_file.exists():
        logger.info("Downloading benchmark %s to %s from %s.", name, zip_file, url)
        torch.hub.download_url_to_file(url, zip_file)

    logger.info("Extracting benchmark %s in %s.", name, output)
    try:
        shutil.unpack_archive(zip_file, output, format="zip")
    except BaseException:
        # benchmark_dir did not exist when we started, so anything there now is
        # a partial extraction. Remove it: otherwise the next call would see
        # the directory and wrongly skip the download.
        shutil.rmtree(benchmark_dir, ignore_errors=True)
        raise
    zip_file.unlink()


def _confirm_stanford2d3d_terms() -> None:
    """Ask the user on stdin to confirm the Stanford2D3D terms of use.

    Raises:
        ValueError: If the (whitespace-stripped, case-insensitive) answer is
            anything other than ``y``.
    """
    banner = "#" * 108
    txt = "\n" + banner + "\n\n"
    txt += "To download the Stanford2D3D dataset, you must agree to the terms of use:\n\n"
    txt += (
        "https://docs.google.com/forms/d/e/"
        + "1FAIpQLScFR0U8WEUtb7tgjOhhnl31OrkEs73-Y8bQwPeXgebqVKNMpQ/viewform?c=0&w=1\n\n"
    )
    txt += banner + "\n\n"
    txt += "Did you fill out the data sharing and usage terms? [y/n] "
    choice = input(txt)
    # Strip so that an answer such as "y " or " Y" is not rejected.
    if choice.strip().lower() != "y":
        raise ValueError(
            "You must agree to the terms of use to download the Stanford2D3D dataset."
        )


def check_keys_recursive(d, pattern):
    """Verify that ``d`` contains the keys described by ``pattern``.

    Args:
        d: Mapping to inspect. It must support ``d[key]`` and ``d.keys()``.
        pattern: Either a dict or an iterable of keys. For a dict, each key is
            looked up in ``d`` and the associated value is used as the pattern
            for ``d[key]``, recursively. For any other iterable, every element
            must be a key of ``d``.

    Raises:
        AssertionError: If an element of a non-dict ``pattern`` is not a key
            of ``d``.
        KeyError: If ``d`` is a plain dict and a key of a dict ``pattern`` is
            missing from it (raised by the ``d[key]`` lookup).
    """
    if isinstance(pattern, dict):
        # Plain loop: the recursion is executed for its checks only, so there
        # is no value worth collecting.
        for key, sub_pattern in pattern.items():
            check_keys_recursive(d[key], sub_pattern)
        return

    for key in pattern:
        # Raised explicitly (rather than via ``assert``) so the check is still
        # performed when Python runs with -O. ``d.keys()`` is kept so that
        # non-mapping inputs keep failing instead of passing a membership test.
        if key not in d.keys():
            raise AssertionError(f"Missing key {key!r}")


def plot_scatter_grid(
    results, x_keys, y_keys, name=None, diag=False, ax=None, line_idx=0, show_means=True
):  # sourcery skip: low-code-quality
    """Scatter every ``y_keys`` entry against every ``x_keys`` entry on a grid.

    Panel ``ax[i, j]`` shows ``results[x_keys[j]]`` on the x-axis and
    ``results[y_keys[i]]`` on the y-axis. Axis labels are the keys with
    underscores replaced by spaces and title-cased.

    Args:
        results: Object indexed by the entries of ``x_keys`` and ``y_keys``;
            the looked-up values are passed to ``scatter`` and to
            ``np.mean`` / ``np.min`` / ``np.max``.
        x_keys: Keys plotted along the columns of the grid.
        y_keys: Keys plotted along the rows of the grid.
        name: Optional legend label for the scatter points. A falsy value
            means no label.
        diag: ``"all"`` draws a dashed red ``y=x`` line on every panel; any
            other truthy value draws it only on panels whose row index equals
            their column index; a falsy value draws none.
        ax: Existing grid of axes, indexable as ``ax[i, j]``. When ``None`` a
            new figure with ``len(y_keys)`` rows and ``len(x_keys)`` columns is
            created.
        line_idx: Index into the ``tab10`` colormap selecting the colour of
            the mean lines.
        show_means: When truthy, mark the mean of each x/y series with a
            vertical/horizontal line spanning the range of the other series.

    Returns:
        Tuple ``(fig, ax)``. ``fig`` is ``None`` when ``ax`` was supplied by
        the caller; ``ax`` is the grid that was drawn on.
    """
    if ax is None:
        n_rows, n_cols = len(y_keys), len(x_keys)
        # squeeze=False always returns a 2-D array of axes, so single-row,
        # single-column and 1x1 grids need no manual reshaping.
        fig, ax = plt.subplots(
            n_rows, n_cols, figsize=(n_cols * 6, n_rows * 5), squeeze=False
        )
    else:
        fig = None

    for j, kx in enumerate(x_keys):
        x_label = " ".join(kx.split("_")).title()
        for i, ky in enumerate(y_keys):
            panel = ax[i, j]
            panel.scatter(
                results[kx],
                results[ky],
                s=1,
                alpha=0.5,
                label=name or None,
            )

            panel.set_xlabel(x_label)
            panel.set_ylabel(" ".join(ky.split("_")).title())

            draw_diag = diag == "all" or (i == j and diag)
            if draw_diag:
                # Span the union of both axis ranges so the line crosses the
                # whole panel.
                low = min(panel.get_xlim()[0], panel.get_ylim()[0])
                high = max(panel.get_xlim()[1], panel.get_ylim()[1])
                panel.plot([low, high], [low, high], ls="--", c="red", label="y=x")

            # A legend is only useful when at least one artist has a label.
            if name or draw_diag:
                panel.legend()

    # Without both x and y keys there is no panel to annotate.
    if not show_means or len(x_keys) == 0 or len(y_keys) == 0:
        return fig, ax

    # Per-key statistics and the line colour are the same for every panel that
    # uses the key, so compute them once instead of once per panel.
    x_stats = {kx: _mean_min_max(results[kx]) for kx in x_keys}
    y_stats = {ky: _mean_min_max(results[ky]) for ky in y_keys}
    color = plt.cm.tab10(line_idx)

    for j, kx in enumerate(x_keys):
        x_mean, x_min, x_max = x_stats[kx]
        for i, ky in enumerate(y_keys):
            y_mean, y_min, y_max = y_stats[ky]
            # Vertical line at the x mean covers the y data range, and the
            # horizontal line at the y mean covers the x data range.
            ax[i, j].vlines([x_mean], y_min, y_max, colors=[color])
            ax[i, j].hlines([y_mean], x_min, x_max, colors=[color])

    return fig, ax


def _mean_min_max(values):
    """Return ``(mean, min, max)`` of ``values`` as computed by NumPy."""
    return np.mean(values), np.min(values), np.max(values)