File size: 4,191 Bytes
6f02b1f
 
 
 
 
 
 
26f0716
6f02b1f
 
 
 
 
26f0716
6f02b1f
 
 
 
 
 
26f0716
6f02b1f
 
 
 
 
 
 
 
 
 
26f0716
6f02b1f
 
 
 
 
26f0716
6f02b1f
 
 
 
 
 
26f0716
6f02b1f
 
 
26f0716
6f02b1f
 
 
26f0716
6f02b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26f0716
6f02b1f
 
 
26f0716
6f02b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26f0716
6f02b1f
 
 
26f0716
6f02b1f
 
 
26f0716
0262264
 
26f0716
6f02b1f
 
0262264
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import streamlit as st
import numpy as np
import plotly.graph_objs as go
import sympy as sp

st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")

# inputs
st.sidebar.header("Gradient Descent Settings")
func_input = st.sidebar.text_input("Enter a function (use 'x'):", "x**2")
learning_rate = st.sidebar.number_input("Learning Rate", min_value=0.001, max_value=1.0, value=0.1, step=0.01)
initial_x = st.sidebar.number_input("Initial X", min_value=-10.0, max_value=10.0, value=5.0, step=0.1)

# Restart the descent when the function OR the starting point changes.
# Previously only a function change triggered a reset, so editing "Initial X"
# had no visible effect. Changing just the learning rate deliberately keeps
# the current run so its effect shows on the following steps.
settings_changed = (
    st.session_state.get("previous_func") != func_input
    or st.session_state.get("previous_initial_x") != initial_x
)
if settings_changed:
    st.session_state.current_x = initial_x
    st.session_state.iteration = 0
    # Path entries are (x, y) pairs; the start's y is only a placeholder —
    # the plot re-evaluates f(x) for every path point.
    st.session_state.path = [(initial_x, 0)]
    st.session_state.previous_func = func_input
    st.session_state.previous_initial_x = initial_x

# Symbolic computation: parse the user's expression, differentiate it with
# respect to x, and build NumPy-callable versions of f and f'.
# SECURITY NOTE(review): sympy.sympify evaluates string input with eval(), and
# func_input comes straight from the user. Do not expose this app to untrusted
# users without sanitising the expression first.
x = sp.symbols('x')
try:
    func = sp.sympify(func_input)
    if not isinstance(func, sp.Expr):
        # e.g. a list, tuple or relational such as "x > 2"
        raise TypeError(f"expected an expression in x, got {type(func).__name__}")
    derivative = sp.diff(func, x)
    func_np = sp.lambdify(x, func, 'numpy')
    derivative_np = sp.lambdify(x, derivative, 'numpy')
except Exception as e:
    st.error(f"Invalid function: {e}")
    st.stop()

# lambdify only binds 'x'. Any other free symbol (e.g. "a*x**2", or "e**x"
# where 'e' parses as a plain symbol rather than Euler's number) would make
# the generated functions raise NameError when evaluated, so reject it here.
unknown_symbols = func.free_symbols - {x}
if unknown_symbols:
    names = ", ".join(sorted(str(sym) for sym in unknown_symbols))
    st.error(f"Invalid function: unknown symbol(s) {names}; only 'x' is allowed.")
    st.stop()

# One step of plain gradient descent on the user's function.
def step_gradient_descent(current_x, lr):
    """Move one gradient-descent step away from ``current_x``.

    The slope comes from the module-level ``derivative_np`` (the lambdified
    derivative of the entered function).

    Args:
        current_x: Position to step from.
        lr: Learning rate, i.e. the multiplier applied to the slope.

    Returns:
        A ``(next_x, grad)`` tuple where ``grad`` is f'(current_x) and
        ``next_x`` equals ``current_x - lr * grad``.
    """
    slope = derivative_np(current_x)
    return current_x - lr * slope, slope

# Advance one gradient-descent step per "Next Iteration" press.
if st.sidebar.button("Next Iteration"):
    # Work in float64: with plain Python floats an exploding step (e.g. x**4
    # with a large learning rate) raises OverflowError from `float ** int`,
    # whereas NumPy returns inf, which is rejected below.
    with np.errstate(over='ignore', invalid='ignore', divide='ignore'):
        next_x, _ = step_gradient_descent(np.float64(st.session_state.current_x), learning_rate)
        next_y = func_np(next_x)
    if np.isfinite(next_x):
        # Record the NEW position so the path reads [x0, x1, ..., current]
        # (the old code appended the previous x, duplicating the start and
        # omitting the current point).
        st.session_state.path.append((next_x, next_y))
        st.session_state.current_x = next_x
        st.session_state.iteration += 1
    else:
        # Keep the last good state instead of storing NaN/inf forever.
        st.sidebar.warning(
            "Next step is not finite (the gradient overflowed or is undefined here); "
            "try a smaller learning rate or a different starting point."
        )

# Calculate actual minima: real critical points (f'(p) = 0, as returned by
# sympy.solve) whose second derivative is positive. Values are plain floats.
#
# Evaluation is done with sympy's numeric evaluation rather than the NumPy
# lambdified derivative: NumPy ufuncs cannot act on symbolic roots such as
# pi/2 (TypeError for sin/cos), and complex roots cannot be ordered with `>`.
second_derivative = sp.diff(derivative, x)
try:
    critical_points = sp.solve(derivative, x)
except Exception as exc:
    # Best effort: sympy raises NotImplementedError (among others) for
    # equations it cannot solve, e.g. 2*x + cos(x) = 0. Markers are optional,
    # so report it instead of crashing the app.
    st.info(f"Could not solve f'(x) = 0 symbolically ({exc}); actual minima are not marked.")
    critical_points = []

actual_minima = []
for point in critical_points:
    try:
        point_value = complex(sp.N(point))
        curvature = complex(sp.N(second_derivative.subs(x, point)))
    except (TypeError, ValueError):
        continue  # not reducible to a number (e.g. complex infinity)
    # Small imaginary residues come from round-off in radical root formulas.
    if abs(point_value.imag) > 1e-9 or abs(curvature.imag) > 1e-9:
        continue
    if curvature.real > 0:
        actual_minima.append(point_value.real)

# Generate graph: sample f on a fixed grid over [-15, 15] for the curve.
x_vals = np.linspace(-15, 15, 1000)
# Domain errors (e.g. log(x) for x <= 0) yield NaN/inf; silence NumPy's
# warnings for them — presumably Plotly draws such points as gaps (verify).
# lambdify returns a bare scalar for a constant function (e.g. "5"), so
# broadcast the result to the grid's shape to always get one y per x.
with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
    y_vals = np.broadcast_to(func_np(x_vals), x_vals.shape)


# Assemble the Plotly figure: function curve, descent path, current point,
# symbolic minima and dashed cross-axes through the origin.
fig = go.Figure()

# Function Plot
fig.add_trace(go.Scatter(
    x=x_vals, y=y_vals, mode='lines',
    line=dict(color='blue', width=2),
    hoverinfo='none'
))

# Evaluate individual points in float64 so very large x values give inf
# (filtered below) instead of Python's OverflowError from `float ** int`.
with np.errstate(over='ignore', invalid='ignore', divide='ignore'):
    # y is recomputed from x so every marker lies on the curve, regardless
    # of the y value stored alongside it in the path.
    path = st.session_state.path
    x_path = [pt[0] for pt in path]
    y_path = [func_np(np.float64(px)) for px in x_path]
    current_y = func_np(np.float64(st.session_state.current_x))
    minima_x = [float(p) for p in actual_minima]
    minima_y = [func_np(np.float64(mx)) for mx in minima_x]

# Gradient Descent Path
fig.add_trace(go.Scatter(
    x=x_path, y=y_path, mode='markers+lines',
    marker=dict(color='red', size=8),
    line=dict(color='red', width=2),
    hoverinfo='none'
))

# Highlight the current point on graph
fig.add_trace(go.Scatter(
    x=[st.session_state.current_x], y=[current_y],
    mode='markers', marker=dict(color='orange', size=12),
    name="Current Point", hoverinfo='none'))

# Highlight Actual Minima
if minima_x:
    fig.add_trace(go.Scatter(
        x=minima_x, y=minima_y,
        mode='markers', marker=dict(color='green', size=14, symbol='star'),
        name="Actual Minima", hoverinfo='text',
        text=[f"x = {x_val:.4f}, f(x) = {y_val:.4f}" for x_val, y_val in zip(minima_x, minima_y)]
    ))

# y-axis window: start from the [-15, 15] view and stretch it so the path,
# the current point and the minima stay on screen. A fixed window clipped the
# default setup (x**2 started at x = 5 sits at f(x) = 25).
finite_y = [float(v) for v in (*y_path, current_y, *minima_y) if np.isfinite(v)]
if finite_y:
    margin = 0.05 * (max(finite_y) - min(finite_y)) + 1.0
    y_range = [min(-15.0, min(finite_y) - margin), max(15.0, max(finite_y) + margin)]
else:
    y_range = [-15.0, 15.0]
# Keep the fixed 5-unit ticks for the default window; let Plotly choose
# ticks once the window has grown.
y_tickvals = np.arange(-15, 16, 5) if y_range == [-15.0, 15.0] else None

# Add Cross-Axes (X and Y lines)
fig.add_trace(go.Scatter(
    x=[-15, 15], y=[0, 0], mode='lines',
    line=dict(color='black', width=1, dash='dash'),
    hoverinfo='none'
))
fig.add_trace(go.Scatter(
    x=[0, 0], y=y_range, mode='lines',
    line=dict(color='black', width=1, dash='dash'),
    hoverinfo='none'
))

# Layout Configuration
fig.update_layout(
    title="Gradient Descent Visualization",
    xaxis=dict(
        title="X",
        zeroline=True, zerolinewidth=1, zerolinecolor='black',
        tickvals=np.arange(-15, 16, 5),
        range=[-15, 15]
    ),
    yaxis=dict(
        title="f(X)",
        zeroline=True, zerolinewidth=1, zerolinecolor='black',
        tickvals=y_tickvals,
        range=y_range
    ),
    showlegend=False,
    hovermode="closest",
    dragmode="pan",
    autosize=True,
)


st.markdown("### Gradient Descent Visualization")
st.plotly_chart(fig, use_container_width=True)

# Current point plus the step counter (previously tracked but never shown).
st.write(f"**Current Point (x):** {st.session_state.current_x:.4f}")
st.write(f"**Iterations completed:** {st.session_state.iteration}")

# Iteration history. Path entry 0 is the starting point, so it is labelled
# "Start" and later entries are numbered from 1, matching the step counter
# in st.session_state.iteration (which starts at 0).
st.write("### Iteration History:")
for step, (x_val, _) in enumerate(st.session_state.path):
    label = "Start" if step == 0 else f"Iteration {step}"
    st.write(f"{label}: x = {x_val:.4f}")