Spaces:
Running
Running
File size: 5,105 Bytes
436c4c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# import numpy as np
# import plotly.graph_objects as go
# from scipy.interpolate import griddata
# def gen_three_D_plot(detectability_val, distortion_val, euclidean_val):
# detectability = np.array(detectability_val)
# distortion = np.array(distortion_val)
# euclidean = np.array(euclidean_val)
# # Find the closest point to the origin
# distances_to_origin = np.linalg.norm(np.array([distortion, detectability, euclidean]).T, axis=1)
# closest_point_index = np.argmin(distances_to_origin)
# # Determine the closest points to each axis
# closest_to_x_axis = np.argmin(distortion)
# closest_to_y_axis = np.argmin(detectability)
# closest_to_z_axis = np.argmin(euclidean)
# # Use the detected closest point as the "sweet spot"
# sweet_spot_detectability = detectability[closest_point_index]
# sweet_spot_distortion = distortion[closest_point_index]
# sweet_spot_euclidean = euclidean[closest_point_index]
# # Create a meshgrid from the data
# x_grid, y_grid = np.meshgrid(np.linspace(min(detectability), max(detectability), 30),
# np.linspace(min(distortion), max(distortion), 30))
# # Interpolate z values (Euclidean distances) to fit the grid
# z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear')
# if z_grid is None:
# raise ValueError("griddata could not generate a valid interpolation. Check your input data.")
# # Create the 3D contour plot with the Plasma color scale
# fig = go.Figure(data=go.Surface(
# z=z_grid,
# x=x_grid,
# y=y_grid,
# contours={
# "z": {"show": True, "start": min(euclidean), "end": max(euclidean), "size": 0.1, "usecolormap": True}
# },
# colorscale='Plasma'
# ))
# # Add a marker for the sweet spot
# fig.add_trace(go.Scatter3d(
# x=[sweet_spot_detectability],
# y=[sweet_spot_distortion],
# z=[sweet_spot_euclidean],
# mode='markers+text',
# marker=dict(size=10, color='red', symbol='circle'),
# text=["Sweet Spot"],
# textposition="top center"
# ))
# # Set axis labels
# fig.update_layout(
# scene=dict(
# xaxis_title='Detectability Score',
# yaxis_title='Distortion Score',
# zaxis_title='Euclidean Distance'
# ),
# margin=dict(l=0, r=0, b=0, t=0)
# )
# return fig
import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata
def _min_max_normalize(values):
    """Linearly rescale a 1-D float array onto [0, 1].

    A constant array (zero range) maps to all zeros rather than the NaNs a
    plain ``(v - min) / (max - min)`` would produce, so a flat metric simply
    contributes nothing to the composite score instead of poisoning argmax.
    """
    lo = values.min()
    span = values.max() - lo
    if span == 0:
        return np.zeros_like(values, dtype=float)
    return (values - lo) / span


def gen_three_D_plot(detectability_val, distortion_val, euclidean_val,
                     *, grid_resolution=30, contour_step=0.1):
    """Plot Euclidean distance as a surface over (detectability, distortion).

    The three inputs are parallel sequences: element ``i`` of each describes
    the same sample. A "sweet spot" sample is chosen and marked in red. It is
    the sample with the highest composite score

        norm(detectability) - (norm(distortion) + norm(euclidean))

    where ``norm`` is min-max scaling to [0, 1]. Higher detectability is
    rewarded, and higher distortion and Euclidean distance are penalised.

    Args:
        detectability_val: Detectability score per sample (plotted on x).
        distortion_val: Distortion score per sample (plotted on y).
        euclidean_val: Euclidean distance per sample (plotted on z).
        grid_resolution: Number of grid points along each of x and y used
            for the interpolated surface. Defaults to 30.
        contour_step: Spacing between z contour lines. Defaults to 0.1.

    Returns:
        A ``plotly.graph_objects.Figure`` holding the surface and the
        sweet-spot marker.

    Raises:
        ValueError: If the inputs are empty or of different lengths, or if
            interpolation yields no finite value anywhere on the grid.
            ``griddata`` may raise its own error for degenerate point sets,
            e.g. fewer than three non-collinear points.
    """
    detectability = np.asarray(detectability_val, dtype=float)
    distortion = np.asarray(distortion_val, dtype=float)
    euclidean = np.asarray(euclidean_val, dtype=float)

    if detectability.size == 0:
        raise ValueError("Input sequences must not be empty.")
    if not (detectability.shape == distortion.shape == euclidean.shape):
        raise ValueError("detectability, distortion and euclidean must have the same length.")

    # Composite score: maximise detectability, minimise distortion and
    # Euclidean distance, with each metric rescaled so none dominates.
    composite_score = _min_max_normalize(detectability) - (
        _min_max_normalize(distortion) + _min_max_normalize(euclidean)
    )
    sweet_spot_index = np.argmax(composite_score)

    sweet_spot_detectability = detectability[sweet_spot_index]
    sweet_spot_distortion = distortion[sweet_spot_index]
    sweet_spot_euclidean = euclidean[sweet_spot_index]

    # Regular grid spanning the observed detectability (x) / distortion (y) range.
    x_grid, y_grid = np.meshgrid(
        np.linspace(detectability.min(), detectability.max(), grid_resolution),
        np.linspace(distortion.min(), distortion.max(), grid_resolution),
    )

    # Interpolate Euclidean distance onto the grid. griddata never returns
    # None; grid cells it cannot interpolate come back as NaN, so "no valid
    # interpolation" means every cell is NaN.
    z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear')
    if np.all(np.isnan(z_grid)):
        raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

    # 3D surface with z contour lines, using the Plasma colour scale.
    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {
                "show": True,
                "start": float(euclidean.min()),
                "end": float(euclidean.max()),
                "size": contour_step,
                "usecolormap": True,
            }
        },
        colorscale='Plasma',
    ))

    # Marker for the sweet spot.
    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center",
    ))

    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance',
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig
|