kadirnar commited on
Commit
ac79ccb
1 Parent(s): b66ebdd

added line fitting module

Browse files
Files changed (2) hide show
  1. app.py +71 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Example showing how to fit a 2d line with kornia / pytorch
2
+ import matplotlib.pyplot as plt
3
+ import torch
4
+ import matplotlib
5
+ matplotlib.use('Agg')
6
+ import matplotlib.pyplot as plt
7
+ import gradio as gr
8
+ from kornia.geometry.line import ParametrizedLine, fit_line
9
+
10
+
11
+ def inference(point1, point2, point3, point4):
12
+ std = 1.2 # standard deviation for the points
13
+ num_points = 50 # total number of points
14
+
15
+ # create a baseline
16
+ p0 = torch.tensor([point1, point2], dtype=torch.float32)
17
+ p1 = torch.tensor([point3, point4], dtype=torch.float32)
18
+
19
+ l1 = ParametrizedLine.through(p0, p1)
20
+
21
+ # sample some points and weights
22
+ pts, w = [], []
23
+ for t in torch.linspace(-10, 10, num_points):
24
+ p2 = l1.point_at(t)
25
+ p2_noise = torch.rand_like(p2) * std
26
+ p2 += p2_noise
27
+ pts.append(p2)
28
+ w.append(1 - p2_noise.mean())
29
+ pts = torch.stack(pts)
30
+ w = torch.stack(w)
31
+
32
+ l2 = fit_line(pts, w)
33
+
34
+ # project some points along the estimated line
35
+ p3 = l2.point_at(-10)
36
+ p4 = l2.point_at(10)
37
+
38
+ X = torch.stack((p3, p4)).detach().numpy()
39
+ X_pts = pts.detach().numpy()
40
+
41
+ fig = plt.figure()
42
+ plt.plot(X_pts[:, 0], X_pts[:, 1], 'ro')
43
+ plt.plot(X[:, 0], X[:, 1])
44
+ return fig
45
+
46
+ inputs = [
47
+ gr.inputs.Slider(0.0, 10.0, default=0.0, label="Point 1"),
48
+ gr.inputs.Slider(0.0, 10.0, default=0.0, label="Point 2"),
49
+ gr.inputs.Slider(0.0, 10.0, default=0.0, label="Point 3"),
50
+ gr.inputs.Slider(0.0, 10.0, default=0.0, label="Point 4"),
51
+ ]
52
+ outputs = gr.Plot()
53
+
54
+ examples = [
55
+ [[0.0, 0.0, 1.0, 1.0]],
56
+ [[0.0, 0.0, 1.0, 2.0]],
57
+ ]
58
+
59
+ title = 'Line Fitting'
60
+
61
+ demo = gr.Interface(
62
+ fn=inference,
63
+ inputs=inputs,
64
+ outputs=outputs,
65
+ title=title,
66
+ cache_examples=True,
67
+ theme='huggingface',
68
+ live=True,
69
+
70
+ )
71
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ kornia
2
+ kornia_rs
3
+ matplotlib
4
+ torch