Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import re
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
# Initialize session state for storing history and iteration count
|
| 8 |
+
if "history" not in st.session_state:
|
| 9 |
+
st.session_state.history = []
|
| 10 |
+
if "iteration" not in st.session_state:
|
| 11 |
+
st.session_state.iteration = 0
|
| 12 |
+
|
| 13 |
+
# Function for gradient calculation
|
| 14 |
+
def gradient(func, x):
|
| 15 |
+
return (func(x + 1e-5) - func(x)) / 1e-5 # Numerical derivative approximation
|
| 16 |
+
|
| 17 |
+
# Sidebar for user input
|
| 18 |
+
st.sidebar.title("Gradient Descent Visualizer")
|
| 19 |
+
|
| 20 |
+
# Function input
|
| 21 |
+
function_str = st.sidebar.text_input("Enter the function (in terms of x):", "x**2 + 2*x + 1")
|
| 22 |
+
start_point = st.sidebar.number_input("Enter the starting point:", value=0.0)
|
| 23 |
+
learning_rate = st.sidebar.number_input("Enter the learning rate:", value=0.1, step=0.01)
|
| 24 |
+
|
| 25 |
+
# Replace common symbols with Python-friendly syntax
|
| 26 |
+
function_str = re.sub(r'\^', '**', function_str)
|
| 27 |
+
function_str = re.sub(r'([0-9])([a-zA-Z])', r'\1*\2', function_str) # 2x -> 2*x
|
| 28 |
+
function_str = re.sub(r'([a-zA-Z])([0-9])', r'\1*\2', function_str) # x2 -> x*2
|
| 29 |
+
|
| 30 |
+
# Setup button to initialize history
|
| 31 |
+
if st.sidebar.button("Set Up"):
|
| 32 |
+
st.session_state.history = [start_point]
|
| 33 |
+
st.session_state.iteration = 0
|
| 34 |
+
|
| 35 |
+
# Button to perform the next iteration
|
| 36 |
+
if st.sidebar.button("Next Iteration"):
|
| 37 |
+
if st.session_state.history:
|
| 38 |
+
try:
|
| 39 |
+
# Convert string function to Python function, supporting more functions
|
| 40 |
+
def func(x):
|
| 41 |
+
return eval(function_str, {"x": x, "sin": np.sin, "cos": np.cos, "tan": np.tan, "log": np.log, "exp": np.exp, "sqrt": np.sqrt, "pi": np.pi, "e": np.e})
|
| 42 |
+
|
| 43 |
+
# Calculate the gradient and the next point
|
| 44 |
+
current_x = st.session_state.history[-1]
|
| 45 |
+
grad = gradient(func, current_x)
|
| 46 |
+
next_x = current_x - learning_rate * grad
|
| 47 |
+
|
| 48 |
+
# Append next point to history
|
| 49 |
+
st.session_state.history.append(next_x)
|
| 50 |
+
st.session_state.iteration += 1
|
| 51 |
+
|
| 52 |
+
# Plot the function and the gradient descent path
|
| 53 |
+
x_vals = np.linspace(min(st.session_state.history) - 2, max(st.session_state.history) + 2, 100)
|
| 54 |
+
y_vals = func(x_vals)
|
| 55 |
+
|
| 56 |
+
plt.figure(figsize=(10, 6))
|
| 57 |
+
plt.plot(x_vals, y_vals, label=f'{function_str}')
|
| 58 |
+
plt.scatter(st.session_state.history, [func(x) for x in st.session_state.history], color='red', s=50)
|
| 59 |
+
plt.plot(st.session_state.history, [func(x) for x in st.session_state.history], linestyle='--', color='gray')
|
| 60 |
+
|
| 61 |
+
# Plot the slope (tangent line) at the current point
|
| 62 |
+
tangent_x_vals = np.linspace(current_x - 2, current_x + 2, 10)
|
| 63 |
+
tangent_y_vals = func(current_x) + grad * (tangent_x_vals - current_x)
|
| 64 |
+
plt.plot(tangent_x_vals, tangent_y_vals, color='orange', label=f"Slope at x={current_x:.2f}")
|
| 65 |
+
|
| 66 |
+
plt.xlabel("x")
|
| 67 |
+
plt.ylabel("f(x)")
|
| 68 |
+
plt.title(f"Iteration {st.session_state.iteration}")
|
| 69 |
+
plt.legend()
|
| 70 |
+
st.pyplot(plt)
|
| 71 |
+
|
| 72 |
+
except Exception as e:
|
| 73 |
+
st.sidebar.error(f"Error in function: {e}")
|
| 74 |
+
else:
|
| 75 |
+
st.sidebar.error("Please click 'Set Up' first to initialize the starting point.")
|
| 76 |
+
|
| 77 |
+
# Display current point and iteration
|
| 78 |
+
if st.session_state.history:
|
| 79 |
+
st.sidebar.write(f"**Current Point:** {st.session_state.history[-1]:.4f}")
|
| 80 |
+
st.sidebar.write(f"**Iteration:** {st.session_state.iteration}")
|