File size: 5,105 Bytes
436c4c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# import numpy as np
# import plotly.graph_objects as go
# from scipy.interpolate import griddata

# def gen_three_D_plot(detectability_val, distortion_val, euclidean_val):
#     detectability = np.array(detectability_val)
#     distortion = np.array(distortion_val)
#     euclidean = np.array(euclidean_val)

#     # Find the closest point to the origin
#     distances_to_origin = np.linalg.norm(np.array([distortion, detectability, euclidean]).T, axis=1)
#     closest_point_index = np.argmin(distances_to_origin)

#     # Determine the closest points to each axis
#     closest_to_x_axis = np.argmin(distortion)
#     closest_to_y_axis = np.argmin(detectability)
#     closest_to_z_axis = np.argmin(euclidean)

#     # Use the detected closest point as the "sweet spot"
#     sweet_spot_detectability = detectability[closest_point_index]
#     sweet_spot_distortion = distortion[closest_point_index]
#     sweet_spot_euclidean = euclidean[closest_point_index]

#     # Create a meshgrid from the data
#     x_grid, y_grid = np.meshgrid(np.linspace(min(detectability), max(detectability), 30),
#                                  np.linspace(min(distortion), max(distortion), 30))

#     # Interpolate z values (Euclidean distances) to fit the grid
#     z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear')

#     if z_grid is None:
#         raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

#     # Create the 3D contour plot with the Plasma color scale
#     fig = go.Figure(data=go.Surface(
#         z=z_grid, 
#         x=x_grid, 
#         y=y_grid, 
#         contours={
#             "z": {"show": True, "start": min(euclidean), "end": max(euclidean), "size": 0.1, "usecolormap": True}
#         },
#         colorscale='Plasma'
#     ))

#     # Add a marker for the sweet spot
#     fig.add_trace(go.Scatter3d(
#         x=[sweet_spot_detectability],
#         y=[sweet_spot_distortion],
#         z=[sweet_spot_euclidean],
#         mode='markers+text',
#         marker=dict(size=10, color='red', symbol='circle'),
#         text=["Sweet Spot"],
#         textposition="top center"
#     ))

#     # Set axis labels
#     fig.update_layout(
#         scene=dict(
#             xaxis_title='Detectability Score',
#             yaxis_title='Distortion Score',
#             zaxis_title='Euclidean Distance'
#         ),
#         margin=dict(l=0, r=0, b=0, t=0)
#     )

#     return fig


import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata

def _min_max_normalize(values):
    """Linearly rescale a 1-D float array onto [0, 1].

    A constant array (zero range) maps to all zeros rather than being divided
    by zero, which would yield NaN everywhere and make ``np.argmax`` fall back
    to index 0 regardless of the other metrics.
    """
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        # Every sample ties on this metric, so it must not influence the ranking.
        return np.zeros_like(values, dtype=float)
    return (values - lowest) / span


def _find_sweet_spot_index(detectability, distortion, euclidean):
    """Return the index of the sample with the best composite score.

    Each metric is min-max normalized; the score rewards high detectability
    and penalizes distortion and Euclidean distance with equal weight.
    """
    norm_detectability = _min_max_normalize(detectability)
    norm_distortion = _min_max_normalize(distortion)
    norm_euclidean = _min_max_normalize(euclidean)
    composite_score = norm_detectability - (norm_distortion + norm_euclidean)
    return int(np.argmax(composite_score))


def gen_three_D_plot(detectability_val, distortion_val, euclidean_val, *,
                     grid_points=30, contour_size=0.1):
    """Plot Euclidean distance as a 3-D surface over (detectability, distortion).

    The three inputs are parallel per-sample metrics. Euclidean distance is
    linearly interpolated onto a regular ``grid_points`` x ``grid_points`` grid
    spanning the observed detectability/distortion ranges and drawn as a Plotly
    surface with z-contours (Plasma color scale). A red "Sweet Spot" marker is
    placed at the original sample with the best composite score: high
    normalized detectability, low normalized distortion and Euclidean distance.

    Args:
        detectability_val: Detectability score per sample (plotted on x).
        distortion_val: Distortion score per sample (plotted on y).
        euclidean_val: Euclidean distance per sample (plotted on z).
        grid_points: Number of grid samples along each horizontal axis.
        contour_size: Spacing between z-contour lines, in the units of
            ``euclidean_val``. Increase it for metrics with a large range to
            avoid drawing an excessive number of contour levels.

    Returns:
        A ``plotly.graph_objects.Figure`` containing the surface and marker.

    Raises:
        ValueError: If the inputs are not 1-D sequences of equal, non-zero
            length, or if the interpolation produces no finite value.
        Qhull errors from ``scipy.interpolate.griddata`` propagate unchanged
        when the (detectability, distortion) points cannot be triangulated
        (e.g. fewer than three points, or all collinear).
    """
    detectability = np.asarray(detectability_val, dtype=float)
    distortion = np.asarray(distortion_val, dtype=float)
    euclidean = np.asarray(euclidean_val, dtype=float)

    metrics = (detectability, distortion, euclidean)
    if any(m.ndim != 1 for m in metrics) or len({m.size for m in metrics}) != 1:
        raise ValueError(
            "detectability_val, distortion_val and euclidean_val must be "
            "1-D sequences of equal length."
        )
    if detectability.size == 0:
        raise ValueError("Input sequences must not be empty.")

    # The marker sits on the actual sample, not on the interpolated surface.
    sweet_spot_index = _find_sweet_spot_index(detectability, distortion, euclidean)
    sweet_spot_detectability = detectability[sweet_spot_index]
    sweet_spot_distortion = distortion[sweet_spot_index]
    sweet_spot_euclidean = euclidean[sweet_spot_index]

    # Regular grid covering the observed x (detectability) and y (distortion) ranges.
    x_grid, y_grid = np.meshgrid(
        np.linspace(detectability.min(), detectability.max(), grid_points),
        np.linspace(distortion.min(), distortion.max(), grid_points),
    )

    # Interpolate z values (Euclidean distances) onto the grid.
    z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear')

    # griddata never returns None: points outside the convex hull are filled
    # with NaN, so "no valid interpolation" means every grid value is NaN.
    if np.isnan(z_grid).all():
        raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {
                "show": True,
                "start": float(euclidean.min()),
                "end": float(euclidean.max()),
                "size": contour_size,
                "usecolormap": True,
            }
        },
        colorscale='Plasma',
    ))

    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center",
    ))

    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance',
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )

    return fig