|
|
import streamlit as st |
|
|
import numpy as np |
|
|
import plotly.graph_objects as go |
|
|
|
|
|
|
|
|
def safe_eval(func_str, x_val):
    """Evaluate the expression *func_str* with the name ``x`` bound to *x_val*.

    Only two names are visible to the expression: ``x`` and ``np`` (NumPy).
    Python builtins (``abs``, ``open``, ``__import__`` ...) are removed.

    Args:
        func_str: A Python expression in ``x``, e.g. ``"x**2 + np.sin(x)"``.
        x_val: The value bound to ``x`` (the callers in this file pass scalars).

    Returns:
        Whatever the expression evaluates to.

    Raises:
        ValueError: If compiling or evaluating the expression fails for any
            reason. The original exception is chained as ``__cause__``.
    """
    # SECURITY: despite the name, this is NOT a sandbox. Removing builtins does
    # not stop attribute-walking tricks (e.g. via ``().__class__``), and ``np``
    # exposes a large API. ``func_str`` comes straight from a text box, so only
    # expose this app to trusted users.
    allowed_names = {"x": x_val, "np": np}
    try:
        # An empty dict (the documented way to restrict builtins) makes a
        # builtin lookup fail with a readable NameError ("name 'abs' is not
        # defined") instead of "'NoneType' object is not subscriptable".
        return eval(func_str, {"__builtins__": {}}, allowed_names)
    except Exception as e:
        raise ValueError(f"Error evaluating the function: {e}") from e
|
|
|
|
|
|
|
|
def derivative(func_str, x_val, h=1e-5):
    """Approximate f'(x_val) with a central (symmetric) difference quotient.

    f is the expression in *func_str*, evaluated through ``safe_eval``:

        f'(x) ~= (f(x + h) - f(x - h)) / (2 * h)

    Errors raised by ``safe_eval`` propagate unchanged.
    """
    f_ahead = safe_eval(func_str, x_val + h)
    f_behind = safe_eval(func_str, x_val - h)
    rise = f_ahead - f_behind
    run = 2 * h
    return rise / run
|
|
|
|
|
|
|
|
def tangent_line(func_str, x_val, x_range):
    """Sample the tangent to f at ``x = x_val`` over the points in *x_range*.

    Uses the point-slope form ``y = f(x_val) + f'(x_val) * (x - x_val)``. The
    point comes from ``safe_eval`` and the slope from ``derivative``.
    *x_range* may be a NumPy array; the arithmetic is then element-wise.
    """
    anchor_y = safe_eval(func_str, x_val)
    slope = derivative(func_str, x_val)
    offsets = x_range - x_val
    return slope * offsets + anchor_y
|
|
|
|
|
|
|
|
def reset_state():
    """Restart gradient descent from the current "Starting Point" widget value.

    Used as the ``on_change`` callback of the function, starting-point and
    learning-rate widgets, and called by the "Set Up Function" button. It
    resets the iteration counter and shrinks the recorded path to the
    starting point alone.

    If the expression cannot be evaluated at the starting point, f(x0) is
    recorded as NaN instead of letting the error escape the widget callback.
    The session state therefore stays consistent: ``x_vals`` and ``y_vals``
    always have the same length. The expression error surfaces where the plot
    re-evaluates the function.
    """
    start = st.session_state.starting_point

    # Evaluate before touching any state. Previously a failure here happened
    # after x / iteration / x_vals had already been reset, leaving y_vals
    # stale and a different length from x_vals.
    try:
        start_y = safe_eval(st.session_state.func_input, start)
    except ValueError:
        start_y = float("nan")

    st.session_state.x = start
    st.session_state.iteration = 0
    st.session_state.x_vals = [start]
    st.session_state.y_vals = [start_y]
|
|
|
|
|
|
|
|
# First run of the session: seed the expression box with a default function.
if "func_input" not in st.session_state:
    st.session_state["func_input"] = "x**2 + x"

# First run of the session: begin the descent at x = 4.0 (the same value the
# "Starting Point" widget defaults to) with an empty iteration history.
if "x" not in st.session_state:
    _initial_x = 4.0
    st.session_state["x"] = _initial_x
    st.session_state["iteration"] = 0
    st.session_state["x_vals"] = [_initial_x]
    st.session_state["y_vals"] = [
        safe_eval(st.session_state["func_input"], _initial_x)
    ]
|
|
|
|
|
|
|
|
# Let the page use the full browser width.
st.set_page_config(layout="wide")

# Global look-and-feel: font, centered headings, bordered inputs/buttons, and
# the hover tooltip used for the function-input help text.
# NOTE(review): ``.css-1d391kg`` / ``.css-1lcbvhc`` look like auto-generated
# class names that change between Streamlit releases, and ``.stInfo`` /
# ``.stSuccess`` may not match the classes Streamlit puts on alert boxes.
# Confirm these rules still apply to the Streamlit version in use.
# NOTE(review): the ``*`` font override uses !important and may also replace
# any icon font Streamlit relies on — verify icons still render.
# The ``.stPlotlyChart`` rule centers the chart container itself (its CSS
# comment mentions the legend).
_CUSTOM_CSS = """
    <style>
    * {
        font-family: Cambria, Arial, sans-serif !important;
    }
    h1, h2, h3, h4, h5 {
        text-align: center;
        margin-top: 0;
    }
    input, .stButton button, .stDownloadButton button {
        border: 2px solid #ea445a;
        border-radius: 5px;
        padding: 10px;
    }
    .stInfo, .stSuccess {
        border: 2px solid #ea445a;
        border-radius: 5px;
        padding: 10px;
    }
    .stButton {
        margin-top: 10px;
    }
    /* Reduced Padding at the top */
    .css-1d391kg {
        padding-top: 0.5rem;
    }
    /* Centering the legend in the plot */
    .stPlotlyChart {
        display: block;
        margin: 0 auto;
    }
    /* Adjusting for full width without scrolling */
    .css-1lcbvhc {
        padding-left: 0;
        padding-right: 0;
    }
    /* Custom borders for input fields */
    .stTextInput input, .stNumberInput input {
        border: 2px solid #001A6E;
        border-radius: 5px;
        padding: 10px;
    }
    /* Tooltip styling */
    .tooltip {
        position: relative;
        display: inline-block;
        cursor: pointer;
    }
    .tooltip .tooltiptext {
        visibility: hidden;
        opacity: 0;
        width: 300px;
        background-color: #001A6E;
        color: #fff;
        text-align: center;
        border-radius: 5px;
        padding: 5px;
        position: absolute;
        z-index: 1;
        bottom: 125%; /* Position the tooltip above */
        left: 50%;
        margin-left: -150px;
        transition: opacity 0.3s;
    }
    .tooltip:hover .tooltiptext {
        visibility: visible;
        opacity: 1;
    }
    </style>
    """
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
# Page heading.
# NOTE(review): the emoji in this label (and other labels in the file) look
# mis-encoded (UTF-8 read through a legacy codepage) — confirm the file's
# encoding.
st.title("π Gradient Descent Visualization Tool π")

# Two columns in a 1:2 width ratio: inputs on the left (col1), plot on the
# right (col2).
_controls_weight, _plot_weight = 1, 2
col1, col2 = st.columns([_controls_weight, _plot_weight])
|
|
|
|
|
|
|
|
with col1:
    st.subheader("π§ Define Your Function")

    # Input help shown on hover (styled by the .tooltip rules in the page CSS).
    # The Markdown renderer treats this whole chunk as one raw HTML block, so
    # Markdown syntax inside it (``**bold**``, ``- item``) would be displayed
    # literally. The help text is therefore written with HTML tags.
    st.markdown(
        """
        <div class="tooltip">
        <label for="func_input">Enter a function of 'x':</label>
        <span class="tooltiptext">
        <b>How to input your function:</b><br>
        Please give the inputs as mentioned below:<br>
        &bull; x^n as x**n<br>
        &bull; sin(x) as np.sin(x)<br>
        &bull; log(x) as np.log(x)<br>
        &bull; e^x or exp(x) as np.exp(x)
        </span>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # Changing any of these inputs restarts the descent (on_change=reset_state).
    func_input = st.text_input(
        "π",
        key="func_input",
        on_change=reset_state,
    )

    st.subheader("βοΈ Gradient Descent Parameters")
    starting_point = st.number_input(
        "Starting Point (Xβ)",
        value=4.0,
        step=0.1,
        format="%.2f",
        key="starting_point",
        on_change=reset_state,
    )
    learning_rate = st.number_input(
        "Learning Rate (Ε)",
        value=0.25,
        step=0.01,
        format="%.2f",
        key="learning_rate",
        on_change=reset_state,
    )

    col3, col4 = st.columns(2)
    with col3:
        if st.button("π Set Up Function"):
            reset_state()
    with col4:
        if st.button("βΆοΈ Next Iteration"):
            # One gradient-descent step: x_new = x - learning_rate * f'(x).
            try:
                x_now = st.session_state.x
                grad = derivative(st.session_state.func_input, x_now)
                next_x = x_now - learning_rate * grad
                next_y = safe_eval(st.session_state.func_input, next_x)
            except Exception as e:
                # UI boundary: report any evaluation failure instead of crashing.
                st.error(f"β οΈ Error: {str(e)}")
            else:
                # Commit only after every evaluation succeeded. Previously a
                # failure evaluating f at the new x left x advanced and
                # x_vals one entry longer than y_vals.
                st.session_state.x = next_x
                st.session_state.iteration += 1
                st.session_state.x_vals.append(next_x)
                st.session_state.y_vals.append(next_y)
|
|
|
|
|
|
|
|
def _y_axis_range(y_values, pad=5.0, cap=1000.0):
    """Return a padded ``[low, high]`` y-axis range for *y_values*, or None.

    Non-finite samples (NaN/inf, e.g. ``np.log`` or ``np.sqrt`` of negative x)
    are ignored, so they cannot turn the whole range into NaN. The upper bound
    is capped at *cap* so steep functions do not flatten the plot, unless the
    entire curve lies above the cap (capping would then invert the axis).
    Returns None, letting Plotly autorange, when no sample is finite.
    """
    ys = np.asarray(y_values, dtype=float)
    finite = ys[np.isfinite(ys)]
    if finite.size == 0:
        return None
    low = float(finite.min()) - pad
    high = float(finite.max()) + pad
    if low < cap:
        high = min(high, cap)
    return [low, high]


with col2:
    st.subheader("π Gradient Descent Visualization")

    try:
        # Sample the function on a fixed grid covering the visible x-range.
        x_plot = np.linspace(-10, 10, 400)
        y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]

        fig = go.Figure()

        # The function itself.
        fig.add_trace(go.Scatter(
            x=x_plot,
            y=y_plot,
            mode="lines+markers",
            line=dict(color="blue", width=2),
            marker=dict(size=4, color="blue", symbol="circle"),
            name="Function",
        ))

        # Every point visited so far, starting point included.
        fig.add_trace(go.Scatter(
            x=st.session_state.x_vals,
            y=st.session_state.y_vals,
            mode="markers",
            marker=dict(color="red", size=10),
            name="Gradient Descent Points",
        ))

        # Tangent at the current iterate; its slope is the derivative that the
        # next step will use.
        current_x = st.session_state.x
        tangent_x = np.linspace(-10, 10, 200)
        tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
        fig.add_trace(go.Scatter(
            x=tangent_x,
            y=tangent_y,
            mode="lines",
            line=dict(color="orange", width=3),
            name="Tangent Line",
        ))

        fig.update_layout(
            xaxis=dict(
                title="x-axis",
                range=[-10, 10],
                showline=True,
                linecolor="white",
                tickcolor="white",
                tickfont=dict(color="white"),
                ticks="outside",
            ),
            yaxis=dict(
                title="y-axis",
                range=_y_axis_range(y_plot),
                showline=True,
                linecolor="white",
                tickcolor="white",
                tickfont=dict(color="white"),
                ticks="outside",
            ),
            plot_bgcolor="black",
            paper_bgcolor="black",
            title="",
            margin=dict(l=10, r=10, t=10, b=10),
            width=800,
            height=400,
            showlegend=True,
            legend=dict(
                x=1.1,
                y=0.5,
                xanchor="left",
                yanchor="middle",
                orientation="v",
                font=dict(size=12, color="white"),
                bgcolor="black",
                bordercolor="white",
                borderwidth=2,
            ),
        )

        # Coordinate axes through the origin. "paper" coordinates (0..1) make
        # each line span the whole plot whatever the axis ranges are. A fixed
        # y0=-100..y1=100 cut the vertical axis short whenever the y-range
        # extended beyond +/-100.
        fig.add_shape(type="line", xref="paper", x0=0, x1=1, y0=0, y1=0,
                      line=dict(color="white", width=2))
        fig.add_shape(type="line", yref="paper", x0=0, x1=0, y0=0, y1=1,
                      line=dict(color="white", width=2))

        st.plotly_chart(fig, use_container_width=True)

    except Exception as e:
        # UI boundary: an invalid expression should not crash the whole page.
        st.error(f"β οΈ Error in visualization: {str(e)}")
|
|
|
|
|
|
|
|
# Status row under both columns: iteration count and the current iterate.
col5, col6 = st.columns(2)
col5.info(f"π§βπ» Iteration: {st.session_state.iteration}")
# NOTE(review): in the reviewed copy, the emoji before "Current x" had been
# mis-decoded into characters that split this f-string across two lines (a
# syntax error). It is restored as ✅, which matches the surviving byte
# pattern — confirm. Other emoji labels in this file look similarly
# mis-encoded.
col6.success(f"✅ Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")