# Hugging Face Space status at capture time: Runtime error
# -*- coding: utf-8 -*-
"""Interactive demo: how the learning rate (eta) and the number of iterations
affect gradient descent on a simple linear-regression problem.

Exported from the Colaboratory notebook "effect of eta and iterations on
sgd.ipynb"; the original notebook is at
https://colab.research.google.com/drive/1Lso8y1XapdHGJOHnY0pL4ZeSGaa-6lOz
"""
import numpy as np
import matplotlib.pyplot as plt

# Default (width, height) in inches for every figure the demo draws.
plt.rcParams.update({"figure.figsize": (10, 5)})
def generate_data(n_samples=100):
    """Draw a noisy synthetic linear-regression data set.

    Samples follow ``y = 4 + 3 * X + noise``, with ``X`` uniform on [0, 2)
    and standard-normal noise, both drawn from NumPy's global random state.

    Args:
        n_samples: number of rows to generate (defaults to 100).

    Returns:
        A tuple ``(X, y)`` of arrays, each of shape ``(n_samples, 1)``.
    """
    X = 2 * np.random.rand(n_samples, 1)
    y = 4 + 3 * X + np.random.randn(n_samples, 1)
    return X, y
def get_norm_eqn(X, y):
    """Fit ordinary least squares in closed form (the normal equation).

    Args:
        X: feature values, one row per sample.
        y: target values, one row per sample.

    Returns:
        A tuple ``(X_b, y_norm)``. ``X_b`` is ``X`` with a leading column of
        ones (the bias term) prepended. ``y_norm`` is the vector of fitted
        values ``X_b @ theta_best``.
    """
    # Size the bias column from the data instead of assuming 100 samples.
    X_b = np.c_[np.ones((len(X), 1)), X]
    # lstsq solves min ||X_b @ theta - y|| directly. For full-column-rank X_b
    # this equals inv(X_b.T @ X_b) @ X_b.T @ y, but it never forms an
    # explicit (possibly ill-conditioned or singular) inverse.
    theta_best = np.linalg.lstsq(X_b, y, rcond=None)[0]
    y_norm = X_b.dot(theta_best)
    return X_b, y_norm
def generate_sgd_plot(eta, n_iterations):
    """Run gradient descent on fresh synthetic data and plot its progress.

    Every iteration's fitted line is drawn as a thin dashed line. The
    closed-form normal-equation fit is drawn on top for comparison.

    Note: each update uses the gradient over all samples (batch gradient
    descent), not a single random sample as in true SGD.

    Args:
        eta: learning rate (step size) of each update.
        n_iterations: number of gradient steps. It is rounded to ``int`` in
            case the UI slider delivers it as a float.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the plot.
    """
    n_iterations = int(round(n_iterations))

    # Random initial [intercept, slope]; drawn before the data, as before.
    theta = np.random.randn(2, 1)
    X, y = generate_data()
    X_b, y_norm = get_norm_eqn(X, y)
    m = len(X_b)  # number of samples, derived instead of hard-coded

    # Start from an empty figure. Otherwise the points, lines and legend
    # entries from earlier calls accumulate on pyplot's shared current figure.
    plt.clf()

    # Plot how the fitted line evolves relative to the normal-equation line
    # as the algorithm learns.
    plt.scatter(X, y, c='#7678ed', label="data points")
    plt.axis([0, 2.0, 0, 14])
    for _ in range(n_iterations):
        # Gradient of the mean-squared-error cost over all m samples.
        gradients = 2 / m * X_b.T.dot(X_b.dot(theta) - y)
        theta = theta - eta * gradients
        y_new = X_b.dot(theta)
        plt.plot(X, y_new, color='#f18701', linestyle='dashed', linewidth=0.2)
    plt.plot(X, y_norm, '#3d348b', label="Normal Equation line")
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend(loc='best')
    return plt
import gradio as gr

# Gradio UI: two sliders feed generate_sgd_plot when "Run" is clicked, and the
# resulting matplotlib plot is shown in the Plot component.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        "# How learning rate and number of iterations affect SGD\n\n"
        "Move the sliders to change the values of eta and the number of "
        "iterations, and see how they affect the convergence rate of the "
        "algorithm."
    )
    inputs = [
        gr.Slider(0.02, 0.5, label="learning rate, eta"),
        # The initial value must lie within [minimum, maximum]. Whole-number
        # steps because the value is used as an iteration count.
        gr.Slider(500, 1000, value=500, step=1, label="number of iterations"),
    ]
    output = gr.Plot()
    btn = gr.Button("Run")
    btn.click(fn=generate_sgd_plot, inputs=inputs, outputs=output)
demo.launch()