# line_models_sampling / backup1.py
# Commit 164569e by jmartinezot: "Adding circular noise" (file size 4.46 kB).
import numpy as np
import pandas as pd
def compute_line_parameters(slope, intercept, mean, std, num_points, x_min, x_max, num_iterations):
    """Estimate line parameters by repeated two-point sampling of noisy data.

    On every iteration a fresh noisy sample of ``y = slope * x + intercept``
    is drawn with ``predict_line_with_noise``, two distinct sample indices are
    picked at random, and the slope and intercept of the line through those
    two points are recorded.

    Args:
        slope, intercept: parameters of the underlying line.
        mean, std: mean and standard deviation of the Gaussian noise on y.
        num_points: number of evenly spaced x values per noisy sample
            (``np.random.choice`` needs at least 2 to pick a distinct pair).
        x_min, x_max: range of the x values. If they are equal, every pair
            shares the same x and the slope is infinite/NaN.
        num_iterations: number of estimates to produce.

    Returns:
        A ``(num_iterations, 2)`` float array. Column 0 holds the estimated
        slopes and column 1 the estimated intercepts.
    """
    estimates = np.zeros((num_iterations, 2))
    # Each row is a writable view into `estimates`, filled in place.
    for row in estimates:
        xs, ys = predict_line_with_noise(slope, intercept, mean, std, num_points, x_min, x_max)
        first, second = np.random.choice(num_points, 2, replace=False)
        rise = ys[second] - ys[first]
        run = xs[second] - xs[first]
        fitted_slope = rise / run
        row[:] = fitted_slope, ys[first] - fitted_slope * xs[first]
    return estimates
from sklearn.cluster import KMeans
# import matplotlib.pyplot as plt
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
def cluster_line_parameters(line_parameters, num_clusters, random_state=None):
    """Cluster (slope, intercept) pairs with k-means and render a scatter plot.

    Args:
        line_parameters: array of shape ``(n_samples, 2)``. Column 0 is the
            slope and column 1 the intercept, matching the axis labels below.
        num_clusters: number of k-means clusters.
        random_state: optional seed forwarded to ``KMeans`` so the clustering
            can be made reproducible. ``None`` (the default) leaves it unseeded.

    Returns:
        ``(labels, centroids, counts, img)``:
        - ``labels`` is the cluster index of every row.
        - ``centroids`` holds the cluster centres, one ``(slope, intercept)``
          row per cluster.
        - ``counts[k]`` is the number of rows assigned to cluster ``k``, with
          exactly one entry per centroid.
        - ``img`` is the scatter plot as a ``(height, width, 3)`` uint8 RGB
          array.
    """
    # n_init is pinned explicitly: its default changed from 10 to "auto" in
    # scikit-learn 1.4 (with a FutureWarning in 1.2/1.3).
    kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=random_state)
    kmeans.fit(line_parameters)
    labels = kmeans.labels_
    centroids = kmeans.cluster_centers_

    # minlength guarantees one count per centroid even if the highest-numbered
    # clusters receive no members (e.g. fewer distinct points than clusters).
    # Without it, counts could not be stacked next to the centroids.
    counts = np.bincount(labels, minlength=len(centroids))

    fig, ax = plt.subplots()
    try:
        ax.scatter(line_parameters[:, 0], line_parameters[:, 1], c=labels, cmap='viridis')
        ax.set_xlabel('Slope')
        ax.set_ylabel('Intercept')

        # Render, then copy the RGB channels out of the canvas buffer.
        # buffer_rgba() replaces tostring_rgb(), which was deprecated in
        # Matplotlib 3.8 and removed in 3.10. It also already has the real
        # pixel shape, so no reshape via get_width_height() (logical size,
        # wrong on HiDPI canvases) is needed.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Slicing off alpha is non-contiguous, so this makes a real copy
        # before the figure (and its buffer) is released.
        img = np.ascontiguousarray(rgba[..., :3])
    finally:
        # Always free the figure, even if drawing fails.
        plt.close(fig)

    return labels, centroids, counts, img
import gradio as gr
def predict_line_with_noise(slope, intercept, mean, std, num_points, x_min, x_max):
    """Sample ``y = slope * x + intercept`` on a grid and add Gaussian noise.

    Args:
        slope, intercept: parameters of the line.
        mean, std: mean and standard deviation of the additive noise. It is
            drawn with ``np.random.normal``, i.e. from NumPy's global random
            state.
        num_points: number of evenly spaced x values.
        x_min, x_max: first and last x value (both included).

    Returns:
        ``(x_values, y_values)``: the x grid and the noisy y values.
    """
    grid = np.linspace(x_min, x_max, num_points)
    return grid, slope * grid + intercept + np.random.normal(mean, std, num_points)
def cluster_line_params(slope, intercept, mean, std, num_points, x_min, x_max, num_iterations, num_clusters):
    """Gradio callback: sample line estimates, cluster them, and summarise.

    Args:
        slope, intercept: parameters of the true line.
        mean, std: mean and standard deviation of the Gaussian noise on y.
        num_points: points per noisy sample (converted to int).
        x_min, x_max: range of the x values.
        num_iterations: number of (slope, intercept) estimates (converted to int).
        num_clusters: number of k-means clusters (converted to int).

    Returns:
        ``(img, df)``: the scatter-plot image of the estimates and a DataFrame
        with columns ``Slope``, ``Intercept`` (float) and ``Count`` (int),
        one row per cluster centroid.

    Raises:
        ValueError: if fewer than two points are requested, if ``x_min``
            equals ``x_max`` (every sampled pair would share one x, giving
            infinite slopes), or if ``num_clusters`` is not between 1 and
            ``num_iterations``.
    """
    # UI number fields may deliver floats; the sampling/clustering code needs ints.
    num_points = int(num_points)
    num_iterations = int(num_iterations)
    num_clusters = int(num_clusters)

    # Reject degenerate settings up front with a readable message instead of
    # failing deep inside numpy / scikit-learn.
    if num_points < 2:
        raise ValueError("Number of Points must be at least 2 to sample a pair of points.")
    if x_min == x_max:
        raise ValueError("Minimum and Maximum Value of x must differ; otherwise every slope is infinite.")
    if not 1 <= num_clusters <= num_iterations:
        raise ValueError("Number of Clusters must be between 1 and Number of Iterations.")

    line_parameters = compute_line_parameters(slope, intercept, mean, std, num_points, x_min, x_max, num_iterations)
    _labels, centroids, counts, img = cluster_line_parameters(line_parameters, num_clusters)

    # One integer count per centroid, zero-padded defensively. Stacking the
    # counts into the float centroid array (np.c_) would show them as 33.0 etc.
    count_column = np.zeros(len(centroids), dtype=np.int64)
    count_column[:len(counts)] = counts[:len(centroids)]
    df = pd.DataFrame({
        "Slope": centroids[:, 0],
        "Intercept": centroids[:, 1],
        "Count": count_column,
    })
    return img, df
# Input and output components of the demo. The gr.inputs / gr.outputs
# namespaces and their `default=` keyword were deprecated in Gradio 3 and
# removed in Gradio 4. The top-level components with `value=` work in both.
inputs = [
    gr.Slider(minimum=-5.0, maximum=5.0, value=1.0, label="Slope"),
    gr.Slider(minimum=-10.0, maximum=10.0, value=0.0, label="Intercept"),
    gr.Slider(minimum=0.0, maximum=5.0, value=1.0, label="Mean"),
    gr.Slider(minimum=0.0, maximum=5.0, value=1.0, label="Standard Deviation"),
    gr.Number(value=100, label="Number of Points"),
    gr.Number(value=-5, label="Minimum Value of x"),
    gr.Number(value=5, label="Maximum Value of x"),
    gr.Number(value=100, label="Number of Iterations"),
    gr.Number(value=3, label="Number of Clusters"),
]
outputs = [
    # Scatter plot of the sampled (slope, intercept) pairs, coloured by cluster.
    gr.Image(label="Image", type="pil"),
    # Table of cluster centroids with the number of members in each cluster.
    gr.Dataframe(label="Centroids and Counts", type="pandas"),
]

# Create the Gradio interface.
interface = gr.Interface(
    fn=cluster_line_params,
    inputs=inputs,
    outputs=outputs,
    title="Line Parameter Clustering",
    description="Cluster line parameters with Gaussian noise",
    # Flagging mode is a string ("never" / "auto" / "manual"); a bare
    # boolean is the deprecated form.
    allow_flagging="never",
)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    interface.launch()