wayandadang commited on
Commit
01f0a3d
·
1 Parent(s): 10de2bb

first commit

Browse files
Files changed (4) hide show
  1. .gitignore +148 -0
  2. app.py +80 -0
  3. kan_linear.py +91 -0
  4. requirements.txt +78 -0
.gitignore ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Local folder
10
+ local_folder
11
+ project_demo
12
+ project_demo/
13
+ logs/
14
+ local_folder/
15
+ /demo.py
16
+ demo.py
17
+ runs/
18
+
19
+ # Large folders
20
+ weights/
21
+ videos/
22
+ images/
23
+
24
+
25
+ # Distribution / packaging
26
+ .Python
27
+ build/
28
+ develop-eggs/
29
+ dist/
30
+ downloads/
31
+ eggs/
32
+ .eggs/
33
+ lib/
34
+ lib64/
35
+ parts/
36
+ sdist/
37
+ var/
38
+ wheels/
39
+ pip-wheel-metadata/
40
+ share/python-wheels/
41
+ *.egg-info/
42
+ .installed.cfg
43
+ *.egg
44
+ MANIFEST
45
+
46
+ # PyInstaller
47
+ # Usually these files are written by a python script from a template
48
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
49
+ *.manifest
50
+ *.spec
51
+
52
+ .vscode
53
+
54
+ # Installer logs
55
+ pip-log.txt
56
+ pip-delete-this-directory.txt
57
+
58
+ # Unit test / coverage reports
59
+ htmlcov/
60
+ .tox/
61
+ .nox/
62
+ .coverage
63
+ .coverage.*
64
+ .cache
65
+ nosetests.xml
66
+ coverage.xml
67
+ *.cover
68
+ *.py,cover
69
+ .hypothesis/
70
+ .pytest_cache/
71
+
72
+ # Translations
73
+ *.mo
74
+ *.pot
75
+
76
+ # Django stuff:
77
+ *.log
78
+ local_settings.py
79
+ db.sqlite3
80
+ db.sqlite3-journal
81
+
82
+ # Flask stuff:
83
+ instance/
84
+ .webassets-cache
85
+
86
+ # Scrapy stuff:
87
+ .scrapy
88
+
89
+ # Sphinx documentation
90
+ docs/_build/
91
+
92
+ # PyBuilder
93
+ target/
94
+
95
+ # Jupyter Notebook
96
+ .ipynb_checkpoints
97
+
98
+ # IPython
99
+ profile_default/
100
+ ipython_config.py
101
+
102
+ # pyenv
103
+ .python-version
104
+
105
+ # pipenv
106
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
107
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
108
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
109
+ # install all needed dependencies.
110
+ #Pipfile.lock
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ venv_/
128
+ ENV/
129
+ env.bak/
130
+ venv.bak/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import streamlit as st
7
+ import numpy as np
8
+ import requests
9
+ from io import BytesIO
10
+ from kan_linear import KANLinear
11
+
12
+ class CNNKAN(nn.Module):
13
+ def __init__(self):
14
+ super(CNNKAN, self).__init__()
15
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
16
+ self.pool1 = nn.MaxPool2d(2)
17
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
18
+ self.pool2 = nn.MaxPool2d(2)
19
+ self.kan1 = KANLinear(64 * 50 * 50, 256)
20
+ self.kan2 = KANLinear(256, 1)
21
+
22
+ def forward(self, x):
23
+ x = F.selu(self.conv1(x))
24
+ x = self.pool1(x)
25
+ x = F.selu(self.conv2(x))
26
+ x = self.pool2(x)
27
+ x = x.view(x.size(0), -1)
28
+ x = self.kan1(x)
29
+ x = self.kan2(x)
30
+ return x
31
+
32
+ # Assuming the model weights are saved in 'model.pth'
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ model = CNNKAN().to(device)
35
+ model.load_state_dict(torch.load('weights/model_weights_KAN1.pth', map_location=device))
36
+ model.eval()
37
+
38
+ # Define image transformations
39
+ transform = transforms.Compose([
40
+ transforms.Resize((200, 200)),
41
+ transforms.ToTensor()
42
+ ])
43
+
44
+ # Streamlit app
45
+ st.title("Image Classification with CNN-KAN")
46
+
47
+ st.sidebar.title("Upload Images")
48
+ uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
49
+ image_url = st.sidebar.text_input("Or enter image URL...")
50
+
51
+ def load_image_from_url(url):
52
+ response = requests.get(url)
53
+ img = Image.open(BytesIO(response.content)).convert('RGB')
54
+ return img
55
+
56
+ img = None
57
+
58
+ if uploaded_file is not None:
59
+ img = Image.open(uploaded_file).convert('RGB')
60
+ elif image_url:
61
+ try:
62
+ img = load_image_from_url(image_url)
63
+ except Exception as e:
64
+ st.sidebar.error(f"Error loading image from URL: {e}")
65
+
66
+ if img is not None:
67
+ st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
68
+ if st.button('Predict'):
69
+ img_tensor = transform(img).unsqueeze(0).to(device)
70
+
71
+ with torch.no_grad():
72
+ output = model(img_tensor)
73
+ prob = torch.sigmoid(output).item()
74
+
75
+ st.write(f"Prediction: {prob:.4f}")
76
+
77
+ if prob < 0.5:
78
+ st.write("This image is classified as a dandelion flower.")
79
+ else:
80
+ st.write("This image is classified as grass.")
kan_linear.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+ class KANLinear(nn.Module):
7
+ def __init__(self, in_features, out_features, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, enable_standalone_scale_spline=True, base_activation=nn.SiLU, grid_eps=0.02, grid_range=[-1, 1]):
8
+ super(KANLinear, self).__init__()
9
+ self.in_features = in_features
10
+ self.out_features = out_features
11
+ self.grid_size = grid_size
12
+ self.spline_order = spline_order
13
+
14
+ h = (grid_range[1] - grid_range[0]) / grid_size
15
+ grid = ((torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0]).expand(in_features, -1).contiguous())
16
+ self.register_buffer("grid", grid)
17
+
18
+ self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
19
+ self.spline_weight = nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order))
20
+ if enable_standalone_scale_spline:
21
+ self.spline_scaler = nn.Parameter(torch.Tensor(out_features, in_features))
22
+
23
+ self.scale_noise = scale_noise
24
+ self.scale_base = scale_base
25
+ self.scale_spline = scale_spline
26
+ self.enable_standalone_scale_spline = enable_standalone_scale_spline
27
+ self.base_activation = base_activation()
28
+ self.grid_eps = grid_eps
29
+
30
+ self.reset_parameters()
31
+
32
+ def reset_parameters(self):
33
+ nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
34
+ with torch.no_grad():
35
+ noise = ((torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 1 / 2) * self.scale_noise / self.grid_size)
36
+ self.spline_weight.data.copy_((self.scale_spline if not self.enable_standalone_scale_spline else 1.0) * self.curve2coeff(self.grid.T[self.spline_order : -self.spline_order], noise))
37
+ if self.enable_standalone_scale_spline:
38
+ nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
39
+
40
+ def b_splines(self, x: torch.Tensor):
41
+ assert x.dim() == 2 and x.size(1) == self.in_features
42
+ grid = self.grid
43
+ x = x.unsqueeze(-1)
44
+ bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
45
+ for k in range(1, self.spline_order + 1):
46
+ bases = ((x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1]) + ((grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:(-k)]) * bases[:, :, 1:])
47
+ assert bases.size() == (x.size(0), self.in_features, self.grid_size + self.spline_order)
48
+ return bases.contiguous()
49
+
50
+ def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
51
+ assert x.dim() == 2 and x.size(1) == self.in_features
52
+ assert y.size() == (x.size(0), self.in_features, self.out_features)
53
+ A = self.b_splines(x).transpose(0, 1)
54
+ B = y.transpose(0, 1)
55
+ solution = torch.linalg.lstsq(A, B).solution
56
+ result = solution.permute(2, 0, 1)
57
+ assert result.size() == (self.out_features, self.in_features, self.grid_size + self.spline_order)
58
+ return result.contiguous()
59
+
60
+ @property
61
+ def scaled_spline_weight(self):
62
+ return self.spline_weight * (self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0)
63
+
64
+ def forward(self, x: torch.Tensor):
65
+ assert x.dim() == 2 and x.size(1) == self.in_features
66
+ base_output = F.linear(self.base_activation(x), self.base_weight)
67
+ spline_output = F.linear(self.b_splines(x).view(x.size(0), -1), self.scaled_spline_weight.view(self.out_features, -1))
68
+ return base_output + spline_output
69
+
70
+ @torch.no_grad()
71
+ def update_grid(self, x: torch.Tensor, margin=0.01):
72
+ assert x.dim() == 2 and x.size(1) == self.in_features
73
+ batch = x.size(0)
74
+ splines = self.b_splines(x).permute(1, 0, 2)
75
+ orig_coeff = self.scaled_spline_weight.permute(1, 2, 0)
76
+ unreduced_spline_output = torch.bmm(splines, orig_coeff).permute(1, 0, 2)
77
+ x_sorted = torch.sort(x, dim=0)[0]
78
+ grid_adaptive = x_sorted[torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)]
79
+ uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
80
+ grid_uniform = (torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device).unsqueeze(1) * uniform_step + x_sorted[0] - margin)
81
+ grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
82
+ grid = torch.cat([grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), grid, grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1)], dim=0)
83
+ self.grid.copy_(grid.T)
84
+ self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
85
+
86
+ def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
87
+ l1_fake = self.spline_weight.abs().mean(-1)
88
+ regularization_loss_activation = l1_fake.sum()
89
+ p = l1_fake / regularization_loss_activation
90
+ regularization_loss_entropy = -torch.sum(p * p.log())
91
+ return regularize_activation * regularization_loss_activation + regularize_entropy * regularization_loss_entropy
requirements.txt ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair==5.2.0
2
+ attrs==23.2.0
3
+ blinker==1.7.0
4
+ cachetools==5.3.3
5
+ certifi==2024.2.2
6
+ charset-normalizer==3.3.2
7
+ click==8.1.7
8
+ contourpy==1.2.0
9
+ cycler==0.12.1
10
+ filelock==3.13.1
11
+ fonttools==4.50.0
12
+ fsspec==2024.3.1
13
+ gitdb==4.0.11
14
+ GitPython==3.1.42
15
+ idna==3.6
16
+ Jinja2==3.1.3
17
+ jsonschema==4.21.1
18
+ jsonschema-specifications==2023.12.1
19
+ kiwisolver==1.4.5
20
+ markdown-it-py==3.0.0
21
+ MarkupSafe==2.1.5
22
+ matplotlib==3.8.3
23
+ mdurl==0.1.2
24
+ mpmath==1.3.0
25
+ networkx==3.2.1
26
+ numpy==1.26.4
27
+ nvidia-cublas-cu12==12.1.3.1
28
+ nvidia-cuda-cupti-cu12==12.1.105
29
+ nvidia-cuda-nvrtc-cu12==12.1.105
30
+ nvidia-cuda-runtime-cu12==12.1.105
31
+ nvidia-cudnn-cu12==8.9.2.26
32
+ nvidia-cufft-cu12==11.0.2.54
33
+ nvidia-curand-cu12==10.3.2.106
34
+ nvidia-cusolver-cu12==11.4.5.107
35
+ nvidia-cusparse-cu12==12.1.0.106
36
+ nvidia-nccl-cu12==2.19.3
37
+ nvidia-nvjitlink-cu12==12.4.99
38
+ nvidia-nvtx-cu12==12.1.105
39
+ opencv-python==4.9.0.80
40
+ packaging==23.2
41
+ pandas==2.2.1
42
+ pillow==10.2.0
43
+ protobuf==4.25.3
44
+ psutil==5.9.8
45
+ py-cpuinfo==9.0.0
46
+ pyarrow==15.0.2
47
+ pydeck==0.8.1b0
48
+ Pygments==2.17.2
49
+ pyparsing==3.1.2
50
+ python-dateutil==2.9.0.post0
51
+ pytz==2024.1
52
+ PyYAML==6.0.1
53
+ referencing==0.34.0
54
+ requests==2.31.0
55
+ rich==13.7.1
56
+ rpds-py==0.18.0
57
+ scipy==1.12.0
58
+ seaborn==0.13.2
59
+ six==1.16.0
60
+ smmap==5.0.1
61
+ streamlit==1.32.2
62
+ sympy==1.12
63
+ tenacity==8.2.3
64
+ thop==0.1.1.post2209072238
65
+ toml==0.10.2
66
+ toolz==0.12.1
67
+ torch==2.2.1
68
+ torchvision==0.17.1
69
+ tornado==6.4
70
+ tqdm==4.66.2
71
+ triton==2.2.0
72
+ typing_extensions==4.10.0
73
+ tzdata==2024.1
74
+ ultralytics==8.1.30
75
+ urllib3==2.2.1
76
+ watchdog==4.0.0
77
+ pafy
78
+ youtube-dl