Jasmine45 commited on
Commit
6de069e
·
verified ·
1 Parent(s): 621eb03

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import re
5
+ import math
6
+
7
+ # Initialize session state for storing history and iteration count
8
+ if "history" not in st.session_state:
9
+ st.session_state.history = []
10
+ if "iteration" not in st.session_state:
11
+ st.session_state.iteration = 0
12
+
13
+ # Function for gradient calculation
14
+ def gradient(func, x):
15
+ return (func(x + 1e-5) - func(x)) / 1e-5 # Numerical derivative approximation
16
+
17
+ # Sidebar for user input
18
+ st.sidebar.title("Gradient Descent Visualizer")
19
+
20
+ # Function input
21
+ function_str = st.sidebar.text_input("Enter the function (in terms of x):", "x**2 + 2*x + 1")
22
+ start_point = st.sidebar.number_input("Enter the starting point:", value=0.0)
23
+ learning_rate = st.sidebar.number_input("Enter the learning rate:", value=0.1, step=0.01)
24
+
25
+ # Replace common symbols with Python-friendly syntax
26
+ function_str = re.sub(r'\^', '**', function_str)
27
+ function_str = re.sub(r'([0-9])([a-zA-Z])', r'\1*\2', function_str) # 2x -> 2*x
28
+ function_str = re.sub(r'([a-zA-Z])([0-9])', r'\1*\2', function_str) # x2 -> x*2
29
+
30
+ # Setup button to initialize history
31
+ if st.sidebar.button("Set Up"):
32
+ st.session_state.history = [start_point]
33
+ st.session_state.iteration = 0
34
+
35
+ # Button to perform the next iteration
36
+ if st.sidebar.button("Next Iteration"):
37
+ if st.session_state.history:
38
+ try:
39
+ # Convert string function to Python function, supporting more functions
40
+ def func(x):
41
+ return eval(function_str, {"x": x, "sin": np.sin, "cos": np.cos, "tan": np.tan, "log": np.log, "exp": np.exp, "sqrt": np.sqrt, "pi": np.pi, "e": np.e})
42
+
43
+ # Calculate the gradient and the next point
44
+ current_x = st.session_state.history[-1]
45
+ grad = gradient(func, current_x)
46
+ next_x = current_x - learning_rate * grad
47
+
48
+ # Append next point to history
49
+ st.session_state.history.append(next_x)
50
+ st.session_state.iteration += 1
51
+
52
+ # Plot the function and the gradient descent path
53
+ x_vals = np.linspace(min(st.session_state.history) - 2, max(st.session_state.history) + 2, 100)
54
+ y_vals = func(x_vals)
55
+
56
+ plt.figure(figsize=(10, 6))
57
+ plt.plot(x_vals, y_vals, label=f'{function_str}')
58
+ plt.scatter(st.session_state.history, [func(x) for x in st.session_state.history], color='red', s=50)
59
+ plt.plot(st.session_state.history, [func(x) for x in st.session_state.history], linestyle='--', color='gray')
60
+
61
+ # Plot the slope (tangent line) at the current point
62
+ tangent_x_vals = np.linspace(current_x - 2, current_x + 2, 10)
63
+ tangent_y_vals = func(current_x) + grad * (tangent_x_vals - current_x)
64
+ plt.plot(tangent_x_vals, tangent_y_vals, color='orange', label=f"Slope at x={current_x:.2f}")
65
+
66
+ plt.xlabel("x")
67
+ plt.ylabel("f(x)")
68
+ plt.title(f"Iteration {st.session_state.iteration}")
69
+ plt.legend()
70
+ st.pyplot(plt)
71
+
72
+ except Exception as e:
73
+ st.sidebar.error(f"Error in function: {e}")
74
+ else:
75
+ st.sidebar.error("Please click 'Set Up' first to initialize the starting point.")
76
+
77
+ # Display current point and iteration
78
+ if st.session_state.history:
79
+ st.sidebar.write(f"**Current Point:** {st.session_state.history[-1]:.4f}")
80
+ st.sidebar.write(f"**Iteration:** {st.session_state.iteration}")