jmartinezot committed on
Commit
3cc2cf7
1 Parent(s): d5f73fd

Upload line_models_sampling.py

Browse files
Files changed (1) hide show
  1. line_models_sampling.py +123 -0
line_models_sampling.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
def compute_line_parameters(slope, intercept, mean, std, num_points, x_min, x_max, num_iterations):
    """Estimate line parameters from random two-point samples of a noisy line.

    On every iteration a fresh noisy realisation of the line
    ``y = slope * x + intercept`` is generated with
    ``predict_line_with_noise``, two distinct points are drawn from it, and
    the slope and intercept of the line through those two points are stored.

    Args:
        slope: Slope of the underlying line.
        intercept: Intercept of the underlying line.
        mean: Mean of the noise, forwarded to ``predict_line_with_noise``.
        std: Standard deviation of the noise, forwarded likewise.
        num_points: Number of points generated per iteration (at least 2).
        x_min: Lower end of the x range.
        x_max: Upper end of the x range (must differ from ``x_min``).
        num_iterations: Number of two-point samples to draw.

    Returns:
        A ``(num_iterations, 2)`` array; column 0 holds the fitted slopes and
        column 1 the fitted intercepts.

    Raises:
        ValueError: If ``num_points`` is below 2, or if ``x_min == x_max``.
    """
    if num_points < 2:
        raise ValueError("num_points must be at least 2 to sample two distinct points")
    if x_min == x_max:
        # Every generated point would share the same x, so the slope below
        # would divide by zero and fill the result with inf/NaN.
        raise ValueError("x_min and x_max must differ to define a slope")

    line_parameters = np.zeros((num_iterations, 2))
    for i in range(num_iterations):
        # Generate points on the line with noise
        x_values, y_values_with_noise = predict_line_with_noise(
            slope, intercept, mean, std, num_points, x_min, x_max
        )

        # Randomly sample two distinct points (replace=False forbids repeats)
        first, second = np.random.choice(num_points, 2, replace=False)
        x1, y1 = x_values[first], y_values_with_noise[first]
        x2, y2 = x_values[second], y_values_with_noise[second]

        # Slope and intercept of the line passing through the sampled points
        line_slope = (y2 - y1) / (x2 - x1)
        line_intercept = y1 - line_slope * x1

        line_parameters[i, 0] = line_slope
        line_parameters[i, 1] = line_intercept

    return line_parameters
23
+
24
+ from sklearn.cluster import KMeans
25
+ # import matplotlib.pyplot as plt
26
+
27
+ import gradio as gr
28
+ import matplotlib.pyplot as plt
29
+ import numpy as np
30
+ from sklearn.cluster import KMeans
31
+
32
def cluster_line_parameters(line_parameters, num_clusters):
    """Cluster (slope, intercept) pairs with KMeans and render a scatter plot.

    Args:
        line_parameters: 2-D array whose column 0 is plotted as the slope and
            column 1 as the intercept.
        num_clusters: Number of clusters requested from KMeans.

    Returns:
        A tuple ``(labels, centroids, counts, img)``:
        ``labels`` is the cluster index assigned to each row,
        ``centroids`` the fitted cluster centres,
        ``counts`` the number of rows per cluster (always ``num_clusters``
        entries), and ``img`` a ``(height, width, 3)`` uint8 RGB rendering of
        the scatter plot.
    """
    # Cluster line parameters using KMeans
    kmeans = KMeans(n_clusters=num_clusters)
    kmeans.fit(line_parameters)
    labels = kmeans.labels_
    centroids = kmeans.cluster_centers_

    # How many points are in each cluster. minlength keeps one entry per
    # centroid even when the highest-numbered clusters received no points.
    counts = np.bincount(labels, minlength=num_clusters)

    # Create a scatter plot of the data points, coloured by cluster
    fig, ax = plt.subplots()
    try:
        ax.scatter(line_parameters[:, 0], line_parameters[:, 1], c=labels, cmap='viridis')
        ax.set_xlabel('Slope')
        ax.set_ylabel('Intercept')

        # Rasterise the figure. buffer_rgba() replaces tostring_rgb(), which
        # was deprecated in Matplotlib 3.8 and later removed. Drop the alpha
        # channel and copy so the array outlives the closed figure.
        fig.canvas.draw()
        img = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Close the figure to free up memory, even if plotting failed
        plt.close(fig)

    return labels, centroids, counts, img
66
+
67
+ import gradio as gr
68
+
69
def predict_line_with_noise(slope, intercept, mean, std, num_points, x_min, x_max):
    """Sample a straight line on an even grid and add Gaussian noise to y.

    Args:
        slope: Slope of the line.
        intercept: Intercept of the line.
        mean: Mean of the normal noise added to each y value.
        std: Standard deviation of that noise.
        num_points: Number of evenly spaced x positions.
        x_min: First x position.
        x_max: Last x position.

    Returns:
        A tuple ``(x_values, y_values)`` of 1-D arrays of length
        ``num_points``; ``y_values`` already includes the noise.
    """
    xs = np.linspace(x_min, x_max, num_points)
    clean_ys = slope * xs + intercept
    # One independent noise draw per point, taken from NumPy's global RNG.
    noisy_ys = clean_ys + np.random.normal(loc=mean, scale=std, size=num_points)
    return xs, noisy_ys
74
+
75
def cluster_line_params(slope, intercept, mean, std, num_points, x_min, x_max, num_iterations, num_clusters):
    """Run the sample-then-cluster pipeline and package the results for display.

    Two-point line estimates are produced by ``compute_line_parameters`` and
    then grouped by ``cluster_line_parameters``.

    Args:
        slope, intercept: Parameters of the underlying line.
        mean, std: Noise parameters passed on to the sampling step.
        num_points: Points generated per iteration (coerced to ``int``).
        x_min, x_max: Range of x values.
        num_iterations: Number of two-point samples (coerced to ``int``).
        num_clusters: Number of clusters (coerced to ``int``).

    Returns:
        A tuple ``(img, df)``: the scatter-plot image from the clustering step
        and a DataFrame with one row per centroid and the columns
        ``Slope``, ``Intercept`` and ``Count``.
    """
    # The count-like arguments must be integers for the steps below.
    num_points, num_iterations, num_clusters = (
        int(value) for value in (num_points, num_iterations, num_clusters)
    )

    sampled = compute_line_parameters(
        slope, intercept, mean, std, num_points, x_min, x_max, num_iterations
    )
    _labels, centers, sizes, img = cluster_line_parameters(sampled, num_clusters)

    # Append the cluster sizes as a third column next to each centroid.
    table = np.column_stack((centers, sizes))
    df = pd.DataFrame(table, columns=["Slope", "Intercept", "Count"])
    return img, df
87
+
88
# Input widgets, listed in the positional order of cluster_line_params' parameters.
inputs = [
    gr.inputs.Slider(label="Slope", minimum=-5.0, maximum=5.0, default=1.0),
    gr.inputs.Slider(label="Intercept", minimum=-10.0, maximum=10.0, default=0.0),
    gr.inputs.Slider(label="Mean", minimum=0.0, maximum=5.0, default=1.0),
    gr.inputs.Slider(label="Standard Deviation", minimum=0.0, maximum=5.0, default=1.0),
    gr.inputs.Number(label="Number of Points", default=100),
    gr.inputs.Number(label="Minimum Value of x", default=-5),
    gr.inputs.Number(label="Maximum Value of x", default=5),
    gr.inputs.Number(label="Number of Iterations", default=100),
    gr.inputs.Number(label="Number of Clusters", default=3),
]

# Output widgets, matching the (img, df) pair returned by cluster_line_params.
outputs = [
    gr.outputs.Image(type="pil", label="Image"),
    gr.outputs.Dataframe(type="pandas", label="Centroids and Counts"),
]

# Wire the function and its widgets into a Gradio interface.
interface = gr.Interface(
    fn=cluster_line_params,
    inputs=inputs,
    outputs=outputs,
    title="Line Parameter Clustering",
    description="Cluster line parameters with Gaussian noise",
    allow_flagging=False,
)

# Start the app as soon as this module is executed.
interface.launch(share=True)
121
+
122
+
123
+