wayandadang committed
Commit fc24292
1 Parent(s): c9035d3

first commit

Files changed (4)
  1. .gitignore +148 -0
  2. app.py +193 -0
  3. kan_linear.py +91 -0
  4. requirements.txt +79 -0
.gitignore ADDED
@@ -0,0 +1,148 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Local folder
+ local_folder
+ project_demo
+ project_demo/
+ logs/
+ local_folder/
+ /demo.py
+ demo.py
+ runs/
+
+ # Large folders
+ weights/
+ videos/
+ images/
+
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ .vscode
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ venv_/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
app.py ADDED
@@ -0,0 +1,193 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import transforms, models
+ from PIL import Image, UnidentifiedImageError
+ import streamlit as st
+ import numpy as np
+ import requests
+ from io import BytesIO
+ from kan_linear import KANLinear
+ import logging
+ import os
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+
+ # Define the model
+ class KANVGG16(nn.Module):
+     def __init__(self, num_classes=1):  # one output logit for binary classification (cats vs. dogs)
+         super(KANVGG16, self).__init__()
+         self.features = nn.Sequential(
+             nn.Conv2d(3, 64, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(64, 64, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(64),  # added batch normalization
+
+             nn.Conv2d(64, 128, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(128, 128, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(128),  # added batch normalization
+
+             nn.Conv2d(128, 256, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(256, 256, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(256, 256, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(256),  # added batch normalization
+
+             nn.Conv2d(256, 512, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(512, 512, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(512, 512, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(512),  # added batch normalization
+
+             nn.Conv2d(512, 512, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(512, 512, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(512, 512, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(512),  # added batch normalization
+         )
+         self.classifier = nn.Sequential(
+             KANLinear(512 * 7 * 7, 2048),  # 224x224 input -> 512-channel 7x7 map after five poolings
+             nn.ReLU(inplace=True),
+             nn.Dropout(0.5),  # increased dropout
+             KANLinear(2048, 2048),
+             nn.ReLU(inplace=True),
+             nn.Dropout(0.5),  # increased dropout
+             KANLinear(2048, num_classes)
+         )
+
+     def forward(self, x):
+         x = self.features(x)
+         x = torch.flatten(x, 1)
+         x = self.classifier(x)
+         return x
+
+ def load_model(weights_path, device):
+     model = KANVGG16().to(device)
+     state_dict = torch.load(weights_path, map_location=device)
+
+     # Strip the 'module.' prefix that nn.DataParallel checkpoints add to keys
+     new_state_dict = {}
+     for k, v in state_dict.items():
+         if k.startswith('module.'):
+             new_state_dict[k[len('module.'):]] = v
+         else:
+             new_state_dict[k] = v
+
+     model.load_state_dict(new_state_dict)
+     model.eval()
+     return model
+
+ class CustomImageLoadingError(Exception):
+     """Custom exception for image loading errors."""
+     pass
+
+ def load_image_from_url(url):
+     try:
+         logging.info(f"Loading image from URL: {url}")
+
+         # Check the file extension (ignoring any query string)
+         valid_extensions = ['jpg', 'jpeg', 'png', 'webp']
+         file_extension = os.path.splitext(url.split('?')[0])[1][1:].lower()
+         if file_extension not in valid_extensions:
+             raise CustomImageLoadingError(f"URL does not point to an image with a valid extension: {file_extension}")
+
+         response = requests.get(url)
+         response.raise_for_status()  # Raise an HTTPError on a non-2xx status code
+
+         content_type = response.headers.get('Content-Type', '')
+         logging.info(f"Content-Type: {content_type}")
+
+         # Check that the content type is an image
+         if 'image' not in content_type:
+             raise CustomImageLoadingError(f"URL does not point to an image: {content_type}")
+
+         img = Image.open(BytesIO(response.content)).convert('RGB')
+         logging.info("Image successfully loaded and converted to RGB")
+         return img
+     except requests.HTTPError as e:
+         logging.error(f"HTTPError while loading image: {e}")
+         raise CustomImageLoadingError(f"Error loading image from URL: {e}")
+     except UnidentifiedImageError as e:
+         logging.error(f"UnidentifiedImageError while loading image: {e}")
+         raise CustomImageLoadingError(f"Cannot identify image file: {e}")
+     except requests.RequestException as e:
+         logging.error(f"RequestException while loading image: {e}")
+         raise CustomImageLoadingError(f"Error loading image from URL: {e}")
+     except Exception as e:
+         logging.error(f"Unexpected error while loading image: {e}")
+         raise CustomImageLoadingError(f"Error loading image from URL: {e}")
+
+ def preprocess_image(image):
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor()
+     ])
+     return transform(image).unsqueeze(0)
+
+ # Streamlit app
+ st.title("Cat and Dog Classification with VGG16-KAN")
+
+ st.sidebar.title("Upload Images")
+ uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
+ image_url = st.sidebar.text_input("Or enter image URL...")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = load_model('weights/best_model_vgg16_KAN.pth', device)
+
+ img = None
+
+ if uploaded_file is not None:
+     logging.info("Image uploaded via file uploader")
+     img = Image.open(uploaded_file).convert('RGB')
+ elif image_url:
+     try:
+         img = load_image_from_url(image_url)
+     except CustomImageLoadingError as e:
+         st.sidebar.error(str(e))
+     except Exception as e:
+         st.sidebar.error(f"Unexpected error: {e}")
+
+ st.sidebar.write("-----")
+
+ # Information for the footer
+ name = "Wayan Dadang"
+
+ st.sidebar.write("Follow me on:")
+ # Footer section with links and copyright information
+ st.sidebar.markdown(f"""
+ [LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
+ [GitHub](https://github.com/Wayan123)
+ [Resume](https://wayan123.github.io/)
+ © {name} - 2024
+ """, unsafe_allow_html=True)
+
+ if img is not None:
+     st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
+     if st.button('Predict'):
+         img_tensor = preprocess_image(img).to(device)
+
+         with torch.no_grad():
+             output = model(img_tensor)
+             prob = torch.sigmoid(output).item()  # single-logit output read as P(dog)
+
+         st.write(f"Prediction: {prob:.4f}")
+
+         if prob < 0.5:
+             st.write("This image is classified as a Cat.")
+         else:
+             st.write("This image is classified as a Dog.")
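
The size of the first classifier layer follows from the architecture above: the padded 3x3 convolutions preserve spatial size, and each of the five MaxPool2d(kernel_size=2, stride=2) stages halves it, so a 224x224 input reaches the classifier as a 512-channel 7x7 map, i.e. 512 * 7 * 7 flattened features. A minimal sanity-check sketch (not part of the committed files):

    import torch
    import torch.nn as nn

    # 224 -> 112 -> 56 -> 28 -> 14 -> 7 across the five pooling stages
    x = torch.randn(1, 3, 224, 224)
    pool = nn.MaxPool2d(kernel_size=2, stride=2)
    for _ in range(5):
        x = pool(x)
    print(x.shape)  # torch.Size([1, 3, 7, 7]) -- 7x7, matching KANLinear(512 * 7 * 7, 2048)

Because num_classes=1, the model emits a single logit; the app passes it through a sigmoid and reads the result as the probability of "Dog", so values below 0.5 are reported as "Cat".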
kan_linear.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+ class KANLinear(nn.Module):
+     def __init__(self, in_features, out_features, grid_size=5, spline_order=3,
+                  scale_noise=0.1, scale_base=1.0, scale_spline=1.0,
+                  enable_standalone_scale_spline=True, base_activation=nn.SiLU,
+                  grid_eps=0.02, grid_range=[-1, 1]):
+         super(KANLinear, self).__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.grid_size = grid_size
+         self.spline_order = spline_order
+
+         # Uniform knot grid over grid_range, extended by spline_order knots on each side
+         h = (grid_range[1] - grid_range[0]) / grid_size
+         grid = ((torch.arange(-spline_order, grid_size + spline_order + 1) * h
+                  + grid_range[0]).expand(in_features, -1).contiguous())
+         self.register_buffer("grid", grid)
+
+         self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
+         self.spline_weight = nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order))
+         if enable_standalone_scale_spline:
+             self.spline_scaler = nn.Parameter(torch.Tensor(out_features, in_features))
+
+         self.scale_noise = scale_noise
+         self.scale_base = scale_base
+         self.scale_spline = scale_spline
+         self.enable_standalone_scale_spline = enable_standalone_scale_spline
+         self.base_activation = base_activation()
+         self.grid_eps = grid_eps
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
+         with torch.no_grad():
+             # Initialize the spline weights from small random curves fitted at the grid points
+             noise = ((torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 1 / 2)
+                      * self.scale_noise / self.grid_size)
+             self.spline_weight.data.copy_(
+                 (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
+                 * self.curve2coeff(self.grid.T[self.spline_order : -self.spline_order], noise))
+             if self.enable_standalone_scale_spline:
+                 nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
+
+     def b_splines(self, x: torch.Tensor):
+         """Evaluate the B-spline basis functions at x via the Cox-de Boor recursion."""
+         assert x.dim() == 2 and x.size(1) == self.in_features
+         grid = self.grid
+         x = x.unsqueeze(-1)
+         bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
+         for k in range(1, self.spline_order + 1):
+             bases = ((x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1]) \
+                 + ((grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:(-k)]) * bases[:, :, 1:])
+         assert bases.size() == (x.size(0), self.in_features, self.grid_size + self.spline_order)
+         return bases.contiguous()
+
+     def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
+         """Least-squares fit of spline coefficients so the splines pass through the points (x, y)."""
+         assert x.dim() == 2 and x.size(1) == self.in_features
+         assert y.size() == (x.size(0), self.in_features, self.out_features)
+         A = self.b_splines(x).transpose(0, 1)
+         B = y.transpose(0, 1)
+         solution = torch.linalg.lstsq(A, B).solution
+         result = solution.permute(2, 0, 1)
+         assert result.size() == (self.out_features, self.in_features, self.grid_size + self.spline_order)
+         return result.contiguous()
+
+     @property
+     def scaled_spline_weight(self):
+         return self.spline_weight * (self.spline_scaler.unsqueeze(-1)
+                                      if self.enable_standalone_scale_spline else 1.0)
+
+     def forward(self, x: torch.Tensor):
+         assert x.dim() == 2 and x.size(1) == self.in_features
+         # Sum of a standard linear path on the activated input and a learned spline path
+         base_output = F.linear(self.base_activation(x), self.base_weight)
+         spline_output = F.linear(self.b_splines(x).view(x.size(0), -1),
+                                  self.scaled_spline_weight.view(self.out_features, -1))
+         return base_output + spline_output
+
+     @torch.no_grad()
+     def update_grid(self, x: torch.Tensor, margin=0.01):
+         """Re-fit the knot grid to the distribution of the incoming activations."""
+         assert x.dim() == 2 and x.size(1) == self.in_features
+         batch = x.size(0)
+         splines = self.b_splines(x).permute(1, 0, 2)
+         orig_coeff = self.scaled_spline_weight.permute(1, 2, 0)
+         unreduced_spline_output = torch.bmm(splines, orig_coeff).permute(1, 0, 2)
+         # Blend a uniform grid with an adaptive (quantile-based) grid, then re-extend the ends
+         x_sorted = torch.sort(x, dim=0)[0]
+         grid_adaptive = x_sorted[torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)]
+         uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
+         grid_uniform = (torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device).unsqueeze(1)
+                         * uniform_step + x_sorted[0] - margin)
+         grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
+         grid = torch.cat([
+             grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
+             grid,
+             grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
+         ], dim=0)
+         self.grid.copy_(grid.T)
+         self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
+
+     def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
+         """Approximate L1 + entropy regularization on the spline weights."""
+         l1_fake = self.spline_weight.abs().mean(-1)
+         regularization_loss_activation = l1_fake.sum()
+         p = l1_fake / regularization_loss_activation
+         regularization_loss_entropy = -torch.sum(p * p.log())
+         return (regularize_activation * regularization_loss_activation
+                 + regularize_entropy * regularization_loss_entropy)
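
This KANLinear appears to follow the compact "efficient-kan" formulation, where each layer adds a standard linear path on a SiLU-activated input to a learned B-spline path. A minimal interface sketch (not part of the committed files; it assumes kan_linear.py is importable from the working directory):

    import torch
    from kan_linear import KANLinear

    layer = KANLinear(in_features=64, out_features=32)
    x = torch.rand(8, 64) * 2 - 1  # (batch, in_features) samples inside the default grid_range [-1, 1]
    y = layer(x)
    print(y.shape)  # torch.Size([8, 32]) -- a drop-in for nn.Linear at the call site
    print(layer.regularization_loss().item())  # optional penalty to add to a training loss
    layer.update_grid(x)  # optionally re-fit the spline grid to observed activations while training

Only the forward pass is exercised by app.py at inference time; update_grid and regularization_loss are training-time utilities.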
requirements.txt ADDED
@@ -0,0 +1,79 @@
+ altair==5.2.0
+ attrs==23.2.0
+ blinker==1.7.0
+ cachetools==5.3.3
+ certifi==2024.2.2
+ charset-normalizer==3.3.2
+ click==8.1.7
+ contourpy==1.2.0
+ cycler==0.12.1
+ filelock==3.13.1
+ fonttools==4.50.0
+ fsspec==2024.3.1
+ gitdb==4.0.11
+ GitPython==3.1.42
+ idna==3.6
+ Jinja2==3.1.3
+ jsonschema==4.21.1
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.8.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ networkx==3.2.1
+ numpy==1.26.4
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.19.3
+ nvidia-nvjitlink-cu12==12.4.99
+ nvidia-nvtx-cu12==12.1.105
+ opencv-python==4.9.0.80
+ packaging==23.2
+ pandas==2.2.1
+ pillow==10.2.0
+ protobuf==4.25.3
+ psutil==5.9.8
+ py-cpuinfo==9.0.0
+ pyarrow==15.0.2
+ pydeck==0.8.1b0
+ Pygments==2.17.2
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ pytz==2024.1
+ PyYAML==6.0.1
+ referencing==0.34.0
+ requests==2.31.0
+ rich==13.7.1
+ rpds-py==0.18.0
+ scipy==1.12.0
+ seaborn==0.13.2
+ six==1.16.0
+ smmap==5.0.1
+ streamlit==1.32.2
+ sympy==1.12
+ tenacity==8.2.3
+ thop==0.1.1.post2209072238
+ toml==0.10.2
+ toolz==0.12.1
+ torch==2.2.1
+ torchvision==0.17.1
+ tornado==6.4
+ tqdm==4.66.2
+ triton==2.2.0
+ typing_extensions==4.10.0
+ tzdata==2024.1
+ ultralytics==8.1.30
+ urllib3==2.2.1
+ watchdog==4.0.0
+ pafy
+ youtube-dl
+ optuna