wayandadang committed on
Commit
37e9e4f
1 Parent(s): 5d50376
Files changed (6)
  1. .gitignore +148 -0
  2. ROC Curve.jpg +0 -0
  3. Report Training.jpg +0 -0
  4. app.py +133 -0
  5. kan_linear.py +91 -0
  6. requirements.txt +79 -0
.gitignore ADDED
@@ -0,0 +1,148 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Local folder
+ local_folder
+ project_demo
+ project_demo/
+ logs/
+ local_folder/
+ /demo.py
+ demo.py
+ runs/
+
+ # Large folders
+ weights/
+ videos/
+ images/
+
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ .vscode
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ venv_/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
ROC Curve.jpg ADDED
Report Training.jpg ADDED
app.py ADDED
@@ -0,0 +1,133 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import transforms, models
+ from PIL import Image, UnidentifiedImageError
+ import streamlit as st
+ import numpy as np
+ import requests
+ from io import BytesIO
+ from kan_linear import KANLinear
+ import logging
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+
+ # Define the ResNet-KAN model: ResNet-50 with its final fc layer replaced by KANLinear
+ class CustomResNetKAN(nn.Module):
+     def __init__(self, num_classes=1):  # num_classes=1 for binary classification
+         super(CustomResNetKAN, self).__init__()
+         self.model = models.resnet50(pretrained=False)
+         self.model.fc = KANLinear(self.model.fc.in_features, num_classes)
+
+     def forward(self, x):
+         return self.model(x)
+
+ def load_model(weights_path, device):
+     model = CustomResNetKAN().to(device)
+     state_dict = torch.load(weights_path, map_location=device)
+
+     # Remove the 'module.' prefix (added by nn.DataParallel) from keys
+     new_state_dict = {}
+     for k, v in state_dict.items():
+         if k.startswith('module.'):
+             new_state_dict[k[len('module.'):]] = v
+         else:
+             new_state_dict[k] = v
+
+     model.load_state_dict(new_state_dict)
+     model.eval()
+     return model
+
+ class CustomImageLoadingError(Exception):
+     """Custom exception for image loading errors."""
+     pass
+
+ def load_image_from_url(url):
+     try:
+         logging.info(f"Loading image from URL: {url}")
+         response = requests.get(url)
+         response.raise_for_status()  # Raise if the request was unsuccessful
+
+         content_type = response.headers['Content-Type']
+         logging.info(f"Content-Type: {content_type}")
+
+         # Check that the content type is an image
+         if 'image' not in content_type:
+             raise CustomImageLoadingError(f"URL does not point to an image: {content_type}")
+
+         img = Image.open(BytesIO(response.content)).convert('RGB')
+         logging.info("Image successfully loaded and converted to RGB")
+         return img
+     except requests.HTTPError as e:
+         logging.error(f"HTTPError while loading image: {e}")
+         raise CustomImageLoadingError(f"Error loading image from URL: {e}")
+     except UnidentifiedImageError as e:
+         logging.error(f"UnidentifiedImageError while loading image: {e}")
+         raise CustomImageLoadingError(f"Cannot identify image file: {e}")
+     except requests.RequestException as e:
+         logging.error(f"RequestException while loading image: {e}")
+         raise CustomImageLoadingError(f"Error loading image from URL: {e}")
+     except Exception as e:
+         logging.error(f"Unexpected error while loading image: {e}")
+         raise CustomImageLoadingError(f"Error loading image from URL: {e}")
+
+ def preprocess_image(image):
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor()
+     ])
+     return transform(image).unsqueeze(0)
+
+ # Streamlit app
+ st.title("Cat and Dog Classification with ResNet-KAN")
+
+ st.sidebar.title("Upload Images")
+ uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
+ image_url = st.sidebar.text_input("Or enter image URL...")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = load_model('weights/best_model_resnet50_KAN.pth', device)
+
+ img = None
+
+ if uploaded_file is not None:
+     logging.info("Image uploaded via file uploader")
+     img = Image.open(uploaded_file).convert('RGB')
+ elif image_url:
+     try:
+         img = load_image_from_url(image_url)
+     except CustomImageLoadingError as e:
+         st.sidebar.error(str(e))
+     except Exception as e:
+         st.sidebar.error(f"Unexpected error: {e}")
+
+ st.sidebar.write("-----")
+
+ # Information for the footer
+ name = "Wayan Dadang"
+
+ st.sidebar.write("Follow me on:")
+ # Footer section with links and copyright information
+ st.sidebar.markdown(f"""
+ [LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
+ [GitHub](https://github.com/Wayan123)
+ [Resume](https://wayan123.github.io/)
+ © {name} - {2024}
+ """, unsafe_allow_html=True)
+
+ if img is not None:
+     st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
+     if st.button('Predict'):
+         img_tensor = preprocess_image(img).to(device)
+
+         with torch.no_grad():
+             output = model(img_tensor)
+             prob = torch.sigmoid(output).item()
+
+         st.write(f"Prediction: {prob:.4f}")
+
+         if prob < 0.5:
+             st.write("This image is classified as a Cat.")
+         else:
+             st.write("This image is classified as a Dog.")
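For quick verification outside Streamlit, a minimal headless-inference sketch (illustrative, not part of the commit): it assumes `load_model` and `preprocess_image` from app.py above are in scope (for example, copied into a standalone script), and `sample.jpg` is a placeholder path.

```python
# Illustrative headless-inference sketch (not part of the commit).
# Assumes load_model() and preprocess_image() from app.py are in scope,
# and that 'sample.jpg' (a placeholder path) exists on disk.
import torch
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model('weights/best_model_resnet50_KAN.pth', device)  # weights path from app.py

img = Image.open('sample.jpg').convert('RGB')
with torch.no_grad():
    prob = torch.sigmoid(model(preprocess_image(img).to(device))).item()

# app.py maps prob < 0.5 to Cat and prob >= 0.5 to Dog
print(f"P(dog) = {prob:.4f} -> {'Dog' if prob >= 0.5 else 'Cat'}")
```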
kan_linear.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+ class KANLinear(nn.Module):
+     def __init__(self, in_features, out_features, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, enable_standalone_scale_spline=True, base_activation=nn.SiLU, grid_eps=0.02, grid_range=[-1, 1]):
+         super(KANLinear, self).__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.grid_size = grid_size
+         self.spline_order = spline_order
+
+         # Extended knot grid: spline_order extra knots beyond each end of grid_range
+         h = (grid_range[1] - grid_range[0]) / grid_size
+         grid = ((torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0]).expand(in_features, -1).contiguous())
+         self.register_buffer("grid", grid)
+
+         self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
+         self.spline_weight = nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order))
+         if enable_standalone_scale_spline:
+             self.spline_scaler = nn.Parameter(torch.Tensor(out_features, in_features))
+
+         self.scale_noise = scale_noise
+         self.scale_base = scale_base
+         self.scale_spline = scale_spline
+         self.enable_standalone_scale_spline = enable_standalone_scale_spline
+         self.base_activation = base_activation()
+         self.grid_eps = grid_eps
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
+         with torch.no_grad():
+             # Initialize spline weights by fitting small random noise at the grid points
+             noise = ((torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 1 / 2) * self.scale_noise / self.grid_size)
+             self.spline_weight.data.copy_((self.scale_spline if not self.enable_standalone_scale_spline else 1.0) * self.curve2coeff(self.grid.T[self.spline_order : -self.spline_order], noise))
+             if self.enable_standalone_scale_spline:
+                 nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
+
+     def b_splines(self, x: torch.Tensor):
+         # Cox-de Boor recursion: order-0 indicator bases, refined up to spline_order
+         assert x.dim() == 2 and x.size(1) == self.in_features
+         grid = self.grid
+         x = x.unsqueeze(-1)
+         bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
+         for k in range(1, self.spline_order + 1):
+             bases = ((x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1]) + ((grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:(-k)]) * bases[:, :, 1:])
+         assert bases.size() == (x.size(0), self.in_features, self.grid_size + self.spline_order)
+         return bases.contiguous()
+
+     def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
+         # Least-squares fit of spline coefficients to the sample points (x, y)
+         assert x.dim() == 2 and x.size(1) == self.in_features
+         assert y.size() == (x.size(0), self.in_features, self.out_features)
+         A = self.b_splines(x).transpose(0, 1)
+         B = y.transpose(0, 1)
+         solution = torch.linalg.lstsq(A, B).solution
+         result = solution.permute(2, 0, 1)
+         assert result.size() == (self.out_features, self.in_features, self.grid_size + self.spline_order)
+         return result.contiguous()
+
+     @property
+     def scaled_spline_weight(self):
+         return self.spline_weight * (self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0)
+
+     def forward(self, x: torch.Tensor):
+         assert x.dim() == 2 and x.size(1) == self.in_features
+         base_output = F.linear(self.base_activation(x), self.base_weight)
+         spline_output = F.linear(self.b_splines(x).view(x.size(0), -1), self.scaled_spline_weight.view(self.out_features, -1))
+         return base_output + spline_output
+
+     @torch.no_grad()
+     def update_grid(self, x: torch.Tensor, margin=0.01):
+         # Blend a uniform grid with the batch's quantiles, then refit the coefficients
+         assert x.dim() == 2 and x.size(1) == self.in_features
+         batch = x.size(0)
+         splines = self.b_splines(x).permute(1, 0, 2)
+         orig_coeff = self.scaled_spline_weight.permute(1, 2, 0)
+         unreduced_spline_output = torch.bmm(splines, orig_coeff).permute(1, 0, 2)
+         x_sorted = torch.sort(x, dim=0)[0]
+         grid_adaptive = x_sorted[torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)]
+         uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
+         grid_uniform = (torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device).unsqueeze(1) * uniform_step + x_sorted[0] - margin)
+         grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
+         grid = torch.cat([grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), grid, grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1)], dim=0)
+         self.grid.copy_(grid.T)
+         self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
+
+     def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
+         # Approximate L1 penalty plus an entropy term, computed on the spline weights
+         l1_fake = self.spline_weight.abs().mean(-1)
+         regularization_loss_activation = l1_fake.sum()
+         p = l1_fake / regularization_loss_activation
+         regularization_loss_entropy = -torch.sum(p * p.log())
+         return regularize_activation * regularization_loss_activation + regularize_entropy * regularization_loss_entropy
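A small smoke test can make the layer above easier to verify; this is an illustrative sketch (not part of the commit) with arbitrary shapes and hyperparameters, assuming kan_linear.py is on the import path.

```python
# Illustrative smoke test for KANLinear (not part of the commit); the batch
# size, feature counts, and hyperparameters here are arbitrary.
import torch
from kan_linear import KANLinear

layer = KANLinear(in_features=16, out_features=4, grid_size=5, spline_order=3)

x = torch.randn(8, 16)              # batch of 8 samples, 16 features each
y = layer(x)                        # SiLU base path + B-spline path, summed
assert y.shape == (8, 4)

reg = layer.regularization_loss()   # approximate L1 + entropy penalty
layer.update_grid(x)                # re-fit the spline grid to this batch
print(y.shape, float(reg))
```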
requirements.txt ADDED
@@ -0,0 +1,79 @@
+ altair==5.2.0
+ attrs==23.2.0
+ blinker==1.7.0
+ cachetools==5.3.3
+ certifi==2024.2.2
+ charset-normalizer==3.3.2
+ click==8.1.7
+ contourpy==1.2.0
+ cycler==0.12.1
+ filelock==3.13.1
+ fonttools==4.50.0
+ fsspec==2024.3.1
+ gitdb==4.0.11
+ GitPython==3.1.42
+ idna==3.6
+ Jinja2==3.1.3
+ jsonschema==4.21.1
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.8.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ networkx==3.2.1
+ numpy==1.26.4
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.19.3
+ nvidia-nvjitlink-cu12==12.4.99
+ nvidia-nvtx-cu12==12.1.105
+ opencv-python==4.9.0.80
+ packaging==23.2
+ pandas==2.2.1
+ pillow==10.2.0
+ protobuf==4.25.3
+ psutil==5.9.8
+ py-cpuinfo==9.0.0
+ pyarrow==15.0.2
+ pydeck==0.8.1b0
+ Pygments==2.17.2
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ pytz==2024.1
+ PyYAML==6.0.1
+ referencing==0.34.0
+ requests==2.31.0
+ rich==13.7.1
+ rpds-py==0.18.0
+ scipy==1.12.0
+ seaborn==0.13.2
+ six==1.16.0
+ smmap==5.0.1
+ streamlit==1.32.2
+ sympy==1.12
+ tenacity==8.2.3
+ thop==0.1.1.post2209072238
+ toml==0.10.2
+ toolz==0.12.1
+ torch==2.2.1
+ torchvision==0.17.1
+ tornado==6.4
+ tqdm==4.66.2
+ triton==2.2.0
+ typing_extensions==4.10.0
+ tzdata==2024.1
+ ultralytics==8.1.30
+ urllib3==2.2.1
+ watchdog==4.0.0
+ pafy
+ youtube-dl
+ optuna