innat committed on
Commit
0f09377
1 Parent(s): f1deb8a
.gitignore ADDED
@@ -0,0 +1,132 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # Pycharm
+ .idea/
README.md CHANGED
@@ -1,12 +1,16 @@
  ---
- title: HybridModel GradCAM
- emoji:
+ title: Demo
+ emoji: 🔥
  colorFrom: purple
- colorTo: purple
+ colorTo: yellow
  sdk: gradio
- sdk_version: 3.0.19
+ sdk_version: 3.0.15
  app_file: app.py
  pinned: false
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## Visual Interpretation of a Hybrid Model
+
+ Building a hybrid model from *EfficientNet* and *Swin Transformer*, we inspect the visual interpretations of the CNN and Transformer blocks of a hybrid (CNN + Swin Transformer) model with the GradCAM technique. The Transformer blocks appear to refine feature activations globally across the relevant object, whereas the CNN operates more locally. Note that the approach shown here is experimental; the workflow could likely be developed into a more principled modeling approach. The model is trained on the [tf_flowers](https://www.tensorflow.org/datasets/catalog/tf_flowers) dataset, a multi-class classification problem.
+
+ ![](./Presentation2.png)
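For reference, the GradCAM computation described above boils down to a few lines. The following is a minimal sketch mirroring `utils/viz_utils.py` further down in this commit; the `model` here is assumed to return predictions plus the feature maps of interest at inference time, as the hybrid model below does:

    import tensorflow as tf

    def gradcam(model, image):
        # score of the predicted class, recorded on a gradient tape
        with tf.GradientTape() as tape:
            preds, feature_maps, _ = model(image)
            class_score = preds[:, tf.argmax(preds[0])]
        # weight each feature-map channel by its mean gradient, then ReLU + normalize
        grads = tape.gradient(class_score, feature_maps)
        pooled = tf.reduce_mean(grads, axis=(0, 1, 2))
        heatmap = tf.squeeze(feature_maps[0] @ pooled[..., tf.newaxis])
        return tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)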
app.py ADDED
@@ -0,0 +1,95 @@
+ import os
+
+ import gdown
+ import gradio as gr
+ import tensorflow as tf
+
+ from config import Parameters
+ from models.hybrid_model import GradientAccumulation
+ from utils.model_utils import *  # also brings in `params` used below
+ from utils.viz_utils import make_gradcam_heatmap
+ from utils.viz_utils import save_and_display_gradcam
+
+ image_size = Parameters().image_size
+ str_labels = [
+     "daisy",
+     "dandelion",
+     "roses",
+     "sunflowers",
+     "tulips",
+ ]
+
+
+ def get_model():
+     """Build the model and create its weights with a dummy forward pass."""
+     model = GradientAccumulation(
+         n_gradients=params.num_grad_accumulation, model_name="HybridModel"
+     )
+     _ = model(tf.ones((1, params.image_size, params.image_size, 3)))[0].shape
+     return model
+
+
+ def get_model_weight(model_id):
+     """Download the trained weights (cached as model.h5 after the first run)."""
+     if not os.path.exists("model.h5"):
+         model_weight = gdown.download(id=model_id, quiet=False)
+     else:
+         model_weight = "model.h5"
+     return model_weight
+
+
+ def load_model(model_id):
+     """Load the trained model."""
+     weight = get_model_weight(model_id)
+     model = get_model()
+     model.load_weights(weight)
+     return model
+
+
+ def image_process(image):
+     """Preprocess an image for model input."""
+     image = tf.cast(image, dtype=tf.float32)
+     original_shape = image.shape
+     image = tf.image.resize(image, [image_size, image_size])
+     image = image[tf.newaxis, ...]
+     return image, original_shape
+
+
+ def predict_fn(image):
+     """Prediction function invoked by gradio."""
+     loaded_model = load_model(model_id="1y6tseN0194T6d-4iIh5wo7RL9ttQERe0")
+     loaded_image, original_shape = image_process(image)
+
+     heatmap_a, heatmap_b, preds = make_gradcam_heatmap(loaded_image, loaded_model)
+     int_label = tf.argmax(preds, axis=-1).numpy()[0]
+     str_label = str_labels[int_label]
+
+     overlay_a = save_and_display_gradcam(
+         loaded_image[0], heatmap_a, image_shape=original_shape[:2]
+     )
+     overlay_b = save_and_display_gradcam(
+         loaded_image[0], heatmap_b, image_shape=original_shape[:2]
+     )
+
+     return [f"Predicted: {str_label}", overlay_a, overlay_b]
+
+
+ iface = gr.Interface(
+     fn=predict_fn,
+     inputs=gr.inputs.Image(label="Input Image"),
+     outputs=[
+         gr.outputs.Label(label="Prediction"),
+         gr.outputs.Image(label="CNN GradCAM"),
+         gr.outputs.Image(label="Transformer GradCAM"),
+     ],
+     title="Hybrid EfficientNet Swin Transformer Demo",
+     description="The model is trained on the tf_flowers dataset.",
+     examples=[
+         ["examples/dandelion.jpg"],
+         ["examples/sunflower.jpg"],
+         ["examples/tulip.jpg"],
+         ["examples/daisy.jpg"],
+         ["examples/rose.jpg"],
+     ],
+ )
+ iface.launch()
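Assuming the pinned dependencies from requirements.txt below are installed, the same script runs locally with `python app.py`; the trained weights are fetched from Google Drive on the first prediction and cached as `model.h5`.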
config.py ADDED
@@ -0,0 +1,30 @@
+ import numpy as np
+ import tensorflow as tf
+
+
+ class Parameters:
+     # data level
+     image_count = 3670
+     image_size = 384
+     batch_size = 12
+     num_grad_accumulation = 8
+     label_smooth = 0.05
+     class_number = 5
+     val_split = 0.2
+     autotune = tf.data.AUTOTUNE
+
+     # hparams
+     epochs = 10
+     lr_sched = "cosine_restart"
+     lr_base = 0.016
+     lr_min = 0
+     lr_decay_epoch = 2.4
+     lr_warmup_epoch = 5
+     lr_decay_factor = 0.97
+
+     scaled_lr = lr_base * (batch_size / 256.0)
+     scaled_lr_min = lr_min * (batch_size / 256.0)
+     num_validation_sample = int(image_count * val_split)
+     num_training_sample = image_count - num_validation_sample
+     train_step = int(np.ceil(num_training_sample / float(batch_size)))
+     total_steps = train_step * epochs
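With these defaults, the linear scaling rule gives `scaled_lr = 0.016 × (12 / 256) = 0.00075`; likewise `num_training_sample = 3670 − 734 = 2936`, `train_step = ceil(2936 / 12) = 245`, and `total_steps = 245 × 10 = 2450`.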
examples/daisy.jpg ADDED
examples/dandelion.jpg ADDED
examples/rose.jpg ADDED
examples/sunflower.jpg ADDED
examples/tulip.jpg ADDED
layers/__init__.py ADDED
File without changes
layers/swin_blocks.py ADDED
@@ -0,0 +1,139 @@
+ try:
+     from jax import numpy as jnp
+ except ModuleNotFoundError:
+     # jax doesn't support windows os yet.
+     import numpy as jnp
+
+ import tensorflow as tf
+ from tensorflow import keras
+ from tensorflow.keras import layers
+
+ from layers.window_attention import WindowAttention
+ from utils.drop_path import DropPath
+ from utils.swin_window import window_partition
+ from utils.swin_window import window_reverse
+
+
+ class SwinTransformer(layers.Layer):
+     def __init__(
+         self,
+         dim,
+         num_patch,
+         num_heads,
+         window_size=7,
+         shift_size=0,
+         num_mlp=1024,
+         qkv_bias=True,
+         dropout_rate=0.0,
+         **kwargs,
+     ):
+         super(SwinTransformer, self).__init__(**kwargs)
+
+         self.dim = dim
+         self.num_patch = num_patch
+         self.num_heads = num_heads
+         self.window_size = window_size
+         self.shift_size = shift_size
+         self.num_mlp = num_mlp
+
+         self.norm1 = layers.LayerNormalization(epsilon=1e-5)
+         self.attn = WindowAttention(
+             dim,
+             window_size=(self.window_size, self.window_size),
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             dropout_rate=dropout_rate,
+         )
+         self.drop_path = DropPath(dropout_rate) if dropout_rate > 0.0 else tf.identity
+         self.norm2 = layers.LayerNormalization(epsilon=1e-5)
+
+         self.mlp = keras.Sequential(
+             [
+                 layers.Dense(num_mlp),
+                 layers.Activation(keras.activations.gelu),
+                 layers.Dropout(dropout_rate),
+                 layers.Dense(dim),
+                 layers.Dropout(dropout_rate),
+             ]
+         )
+
+         if min(self.num_patch) < self.window_size:
+             self.shift_size = 0
+             self.window_size = min(self.num_patch)
+
+     def build(self, input_shape):
+         if self.shift_size == 0:
+             self.attn_mask = None
+         else:
+             height, width = self.num_patch
+             h_slices = (
+                 slice(0, -self.window_size),
+                 slice(-self.window_size, -self.shift_size),
+                 slice(-self.shift_size, None),
+             )
+             w_slices = (
+                 slice(0, -self.window_size),
+                 slice(-self.window_size, -self.shift_size),
+                 slice(-self.shift_size, None),
+             )
+             mask_array = jnp.zeros((1, height, width, 1))
+             count = 0
+             for h in h_slices:
+                 for w in w_slices:
+                     # jax arrays are immutable, so use a functional update there;
+                     # the numpy fallback supports plain item assignment.
+                     if hasattr(mask_array, "at"):
+                         mask_array = mask_array.at[:, h, w, :].set(count)
+                     else:
+                         mask_array[:, h, w, :] = count
+                     count += 1
+             mask_array = tf.convert_to_tensor(mask_array)
+
+             # mask array to windows
+             mask_windows = window_partition(mask_array, self.window_size)
+             mask_windows = tf.reshape(
+                 mask_windows, shape=[-1, self.window_size * self.window_size]
+             )
+             attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
+                 mask_windows, axis=2
+             )
+             attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
+             attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
+             self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)
+
+     def call(self, x):
+         height, width = self.num_patch
+         _, num_patches_before, channels = x.shape
+         x_skip = x
+         x = self.norm1(x)
+         x = tf.reshape(x, shape=(-1, height, width, channels))
+         if self.shift_size > 0:
+             # cyclically shift so tokens near window borders can interact
+             shifted_x = tf.roll(
+                 x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
+             )
+         else:
+             shifted_x = x
+
+         x_windows = window_partition(shifted_x, self.window_size)
+         x_windows = tf.reshape(
+             x_windows, shape=(-1, self.window_size * self.window_size, channels)
+         )
+         attn_windows = self.attn(x_windows, mask=self.attn_mask)
+
+         attn_windows = tf.reshape(
+             attn_windows, shape=(-1, self.window_size, self.window_size, channels)
+         )
+         shifted_x = window_reverse(
+             attn_windows, self.window_size, height, width, channels
+         )
+         if self.shift_size > 0:
+             # undo the cyclic shift
+             x = tf.roll(
+                 shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
+             )
+         else:
+             x = shifted_x
+
+         x = tf.reshape(x, shape=(-1, height * width, channels))
+         x = self.drop_path(x)
+         x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
+         x_skip = x
+         x = self.norm2(x)
+         x = self.mlp(x)
+         x = self.drop_path(x)
+         x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
+         return x
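A note on the mask built in `build` above: the −100.0 added to attention logits between tokens that came from different spatial regions (due to the cyclic shift) drives those weights to effectively zero after the softmax, so each token only attends within its own region of the shifted window.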
layers/window_attention.py ADDED
@@ -0,0 +1,111 @@
+ import tensorflow as tf
+ from tensorflow.keras import layers
+
+
+ class WindowAttention(layers.Layer):
+     def __init__(
+         self,
+         dim,
+         window_size,
+         num_heads,
+         qkv_bias=True,
+         dropout_rate=0.0,
+         return_attention_scores=False,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.dim = dim
+         self.window_size = window_size
+         self.num_heads = num_heads
+         self.scale = (dim // num_heads) ** -0.5
+         self.return_attention_scores = return_attention_scores
+         self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
+         self.dropout = layers.Dropout(dropout_rate)
+         self.proj = layers.Dense(dim)
+
+     def build(self, input_shape):
+         self.relative_position_bias_table = self.add_weight(
+             shape=(
+                 (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
+                 self.num_heads,
+             ),
+             initializer="zeros",
+             trainable=True,
+             name="relative_position_bias_table",
+         )
+
+         self.relative_position_index = self.get_relative_position_index(
+             self.window_size[0], self.window_size[1]
+         )
+         super().build(input_shape)
+
+     def get_relative_position_index(self, window_height, window_width):
+         # pairwise relative coordinates of all positions within a window,
+         # flattened into a single lookup index into the bias table
+         x_x, y_y = tf.meshgrid(range(window_height), range(window_width))
+         coords = tf.stack([y_y, x_x], axis=0)
+         coords_flatten = tf.reshape(coords, [2, -1])
+
+         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+         relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0])
+
+         x_x = (relative_coords[:, :, 0] + window_height - 1) * (2 * window_width - 1)
+         y_y = relative_coords[:, :, 1] + window_width - 1
+         relative_coords = tf.stack([x_x, y_y], axis=-1)
+
+         return tf.reduce_sum(relative_coords, axis=-1)
+
+     def call(self, x, mask=None):
+         _, size, channels = x.shape
+         head_dim = channels // self.num_heads
+         x_qkv = self.qkv(x)
+         x_qkv = tf.reshape(x_qkv, shape=(-1, size, 3, self.num_heads, head_dim))
+         x_qkv = tf.transpose(x_qkv, perm=(2, 0, 3, 1, 4))
+         q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
+         q = q * self.scale
+         k = tf.transpose(k, perm=(0, 1, 3, 2))
+         attn = q @ k
+
+         relative_position_bias = tf.gather(
+             self.relative_position_bias_table,
+             self.relative_position_index,
+             axis=0,
+         )
+         relative_position_bias = tf.transpose(relative_position_bias, [2, 0, 1])
+         attn = attn + tf.expand_dims(relative_position_bias, axis=0)
+
+         if mask is not None:
+             nW = mask.get_shape()[0]
+             mask_float = tf.cast(
+                 tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32
+             )
+             attn = (
+                 tf.reshape(attn, shape=(-1, nW, self.num_heads, size, size))
+                 + mask_float
+             )
+             attn = tf.reshape(attn, shape=(-1, self.num_heads, size, size))
+             attn = tf.nn.softmax(attn, axis=-1)
+         else:
+             attn = tf.nn.softmax(attn, axis=-1)
+         attn = self.dropout(attn)
+
+         x_qkv = attn @ v
+         x_qkv = tf.transpose(x_qkv, perm=(0, 2, 1, 3))
+         x_qkv = tf.reshape(x_qkv, shape=(-1, size, channels))
+         x_qkv = self.proj(x_qkv)
+         x_qkv = self.dropout(x_qkv)
+
+         if self.return_attention_scores:
+             return x_qkv, attn
+         else:
+             return x_qkv
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "dim": self.dim,
+                 "window_size": self.window_size,
+                 "num_heads": self.num_heads,
+                 "scale": self.scale,
+             }
+         )
+         return config
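With the `window_size = (2, 2)` used in this repo, the bias table above has (2·2 − 1) · (2·2 − 1) = 9 rows per head — one per distinct relative offset (dx, dy) with dx, dy ∈ {−1, 0, 1} — and `get_relative_position_index` maps each of the 4 × 4 = 16 token pairs inside a window to one of those 9 rows.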
models/__init__.py ADDED
@@ -0,0 +1 @@
+
models/hybrid_model.py ADDED
@@ -0,0 +1,170 @@
+ import tensorflow as tf
+ from tensorflow import keras
+ from tensorflow.keras import layers
+
+ from layers.swin_blocks import SwinTransformer
+ from utils.model_utils import *
+ from utils.patch import PatchEmbedding
+ from utils.patch import PatchExtract
+ from utils.patch import PatchMerging
+
+
+ class HybridSwinTransformer(keras.Model):
+     def __init__(self, model_name, **kwargs):
+         super().__init__(name=model_name, **kwargs)
+         # base model
+         base = keras.applications.EfficientNetB0(
+             include_top=False,
+             weights=None,
+             input_tensor=keras.Input((params.image_size, params.image_size, 3)),
+         )
+
+         # expose an intermediate feature map that feeds the transformer branch,
+         # alongside the final convolutional output
+         self.new_base = keras.Model(
+             [base.inputs],
+             [base.get_layer("block6a_expand_activation").output, base.output],
+             name="efficientnet",
+         )
+
+         # swin transformer components
+         self.patch_extract = PatchExtract(patch_size)
+         self.patch_embedds = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)
+         self.patch_merging = PatchMerging(
+             (num_patch_x, num_patch_y), embed_dim=embed_dim
+         )
+
+         # swin blocks container
+         self.swin_sequences = keras.Sequential(name="swin_blocks")
+         for i in range(shift_size):
+             self.swin_sequences.add(
+                 SwinTransformer(
+                     dim=embed_dim,
+                     num_patch=(num_patch_x, num_patch_y),
+                     num_heads=num_heads,
+                     window_size=window_size,
+                     shift_size=i,
+                     num_mlp=num_mlp,
+                     qkv_bias=qkv_bias,
+                     dropout_rate=dropout_rate,
+                 )
+             )
+
+         # swin branch head
+         self.swin_head = keras.Sequential(
+             [
+                 layers.GlobalAveragePooling1D(),
+                 layers.AlphaDropout(0.5),
+                 layers.BatchNormalization(),
+             ],
+             name="swin_head",
+         )
+
+         # cnn branch head
+         self.conv_head = keras.Sequential(
+             [
+                 layers.GlobalAveragePooling2D(),
+                 layers.AlphaDropout(0.5),
+             ],
+             name="conv_head",
+         )
+
+         # classifier
+         self.classifier = layers.Dense(
+             params.class_number, activation=None, dtype="float32"
+         )
+         self.build_graph()  # trace the model once so all weights are created
+
+     def call(self, inputs, training=None, **kwargs):
+         x, base_gcam_top = self.new_base(inputs)
+         x = self.patch_extract(x)
+         x = self.patch_embedds(x)
+         x = self.swin_sequences(tf.cast(x, dtype=tf.float32))
+         x, swin_gcam_top = self.patch_merging(x)
+
+         swin_top = self.swin_head(x)
+         conv_top = self.conv_head(base_gcam_top)
+         preds = self.classifier(tf.concat([swin_top, conv_top], axis=-1))
+
+         if training:  # training phase
+             return preds
+         else:  # inference phase: also return feature maps for GradCAM
+             return preds, base_gcam_top, swin_gcam_top
+
+     def build_graph(self):
+         x = keras.Input(shape=(params.image_size, params.image_size, 3))
+         return keras.Model(inputs=[x], outputs=self.call(x))
+
+
+ class GradientAccumulation(HybridSwinTransformer):
+     """ref: https://gist.github.com/innat/ba6740293e7b7b227829790686f2119c"""
+
+     def __init__(self, n_gradients, **kwargs):
+         super().__init__(**kwargs)
+         self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
+         self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
+         self.gradient_accumulation = [
+             tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
+             for v in self.trainable_variables
+         ]
+
+     def train_step(self, data):
+         # track the accumulation step
+         self.n_acum_step.assign_add(1)
+
+         # Unpack the data. Its structure depends on your model and
+         # on what you pass to `fit()`.
+         x, y = data
+
+         with tf.GradientTape() as tape:
+             y_pred = self(x, training=True)  # Forward pass
+             loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
+
+         # Calculate batch gradients
+         gradients = tape.gradient(loss, self.trainable_variables)
+
+         # Accumulate batch gradients
+         for i in range(len(self.gradient_accumulation)):
+             self.gradient_accumulation[i].assign_add(gradients[i])
+
+         # If n_acum_step reaches n_gradients, apply the accumulated gradients
+         # to update the variables; otherwise do nothing
+         tf.cond(
+             tf.equal(self.n_acum_step, self.n_gradients),
+             self.apply_accu_gradients,
+             lambda: None,
+         )
+
+         # Return a dict mapping metric names to current value.
+         # Note that it will include the loss (tracked in self.metrics).
+         self.compiled_metrics.update_state(y, y_pred)
+         return {m.name: m.result() for m in self.metrics}
+
+     def apply_accu_gradients(self):
+         # Update weights
+         self.optimizer.apply_gradients(
+             zip(self.gradient_accumulation, self.trainable_variables)
+         )
+
+         # Reset the accumulation step and the accumulators
+         self.n_acum_step.assign(0)
+         for i in range(len(self.gradient_accumulation)):
+             self.gradient_accumulation[i].assign(
+                 tf.zeros_like(self.trainable_variables[i], dtype=tf.float32)
+             )
+
+     def test_step(self, data):
+         # Unpack the data
+         x, y = data
+
+         # Compute predictions
+         y_pred, base_gcam_top, swin_gcam_top = self(x, training=False)
+
+         # Update the metrics tracking the loss
+         self.compiled_loss(y, y_pred, regularization_losses=self.losses)
+
+         # Update the metrics.
+         self.compiled_metrics.update_state(y, y_pred)
+
+         # Return a dict mapping metric names to current value.
+         # Note that it will include the loss (tracked in self.metrics).
+         return {m.name: m.result() for m in self.metrics}
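With the defaults from `utils/model_utils.py` (`batch_size = 12`, `n_gradients = 8`), the optimizer therefore steps once every 8 mini-batches, for an effective batch size of 12 × 8 = 96.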
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ tensorflow==2.6.4
+ jax==0.3.13
+ jaxlib
+ numpy
+ matplotlib==3.5.2
+ gradio==3.0.15
+ gdown==4.4.0
utils/__init__.py ADDED
File without changes
utils/drop_path.py ADDED
@@ -0,0 +1,31 @@
+ import tensorflow as tf
+ from tensorflow.keras import backend
+ from tensorflow.keras import layers
+
+
+ class DropPath(layers.Layer):
+     def __init__(self, drop_prob=None, **kwargs):
+         super(DropPath, self).__init__(**kwargs)
+         self.drop_prob = drop_prob
+
+     def call(self, inputs, training=None):
+         if self.drop_prob == 0.0 or not training:
+             return inputs
+         else:
+             batch_size = tf.shape(inputs)[0]
+             keep_prob = 1 - self.drop_prob
+             # one Bernoulli draw per sample, broadcast over the remaining axes
+             path_mask_shape = (batch_size,) + (1,) * (len(tf.shape(inputs)) - 1)
+             path_mask = tf.floor(backend.random_bernoulli(path_mask_shape, p=keep_prob))
+             outputs = (
+                 tf.math.divide(tf.cast(inputs, dtype=tf.float32), keep_prob) * path_mask
+             )
+             return outputs
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "drop_prob": self.drop_prob,
+             }
+         )
+         return config
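Dividing the surviving inputs by `keep_prob` keeps the expected value of the output equal to the input, so the layer can simply be skipped at inference time (`training=False` returns the inputs unchanged).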
utils/model_utils.py ADDED
@@ -0,0 +1,46 @@
+ import numpy as np
+ import tensorflow as tf
+
+
+ class Parameters:
+     # data level
+     image_count = 3670
+     image_size = 384
+     batch_size = 12
+     num_grad_accumulation = 8
+     class_number = 5
+     val_split = 0.2
+     autotune = tf.data.AUTOTUNE
+
+     # hparams
+     epochs = 10
+     lr_sched = "cosine_restart"
+     lr_base = 0.016
+     lr_min = 0
+     lr_decay_epoch = 2.4
+     lr_warmup_epoch = 5
+     lr_decay_factor = 0.97
+
+     scaled_lr = lr_base * (batch_size / 256.0)
+     scaled_lr_min = lr_min * (batch_size / 256.0)
+     num_validation_sample = int(image_count * val_split)
+     num_training_sample = image_count - num_validation_sample
+     train_step = int(np.ceil(num_training_sample / float(batch_size)))
+     total_steps = train_step * epochs
+
+
+ params = Parameters()
+
+
+ patch_size = (2, 2)  # 2-by-2 sized patches
+ dropout_rate = 0.5  # Dropout rate
+ num_heads = 8  # Attention heads
+ embed_dim = 64  # Embedding dimension
+ num_mlp = 128  # MLP layer size
+ qkv_bias = True  # Learnable bias in the query, key, and value projections
+ window_size = 2  # Size of attention window
+ shift_size = 1  # Size of shifting window
+ image_dimension = 24  # Input spatial size of the transformer branch (feature-map side length)
+
+ num_patch_x = image_dimension // patch_size[0]
+ num_patch_y = image_dimension // patch_size[1]
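The value `image_dimension = 24` corresponds to the `block6a_expand_activation` output of EfficientNetB0 at a 384-pixel input (a stride-16 feature map: 384 / 16 = 24); with 2 × 2 patches this gives `num_patch_x = num_patch_y = 12`, i.e. 144 tokens entering the transformer branch.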
utils/patch.py ADDED
@@ -0,0 +1,80 @@
+ import tensorflow as tf
+ from tensorflow.keras import layers
+
+
+ class PatchExtract(layers.Layer):
+     def __init__(self, patch_size, **kwargs):
+         super().__init__(**kwargs)
+         self.patch_size_x = patch_size[0]
+         self.patch_size_y = patch_size[1]
+
+     def call(self, images):
+         batch_size = tf.shape(images)[0]
+         patches = tf.image.extract_patches(
+             images=images,
+             sizes=(1, self.patch_size_x, self.patch_size_y, 1),
+             strides=(1, self.patch_size_x, self.patch_size_y, 1),
+             rates=(1, 1, 1, 1),
+             padding="VALID",
+         )
+         patch_dim = patches.shape[-1]
+         patch_num = patches.shape[1]
+         return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "patch_size_y": self.patch_size_y,
+                 "patch_size_x": self.patch_size_x,
+             }
+         )
+         return config
+
+
+ class PatchEmbedding(layers.Layer):
+     def __init__(self, num_patch, embed_dim, **kwargs):
+         super().__init__(**kwargs)
+         self.num_patch = num_patch
+         self.proj = layers.Dense(embed_dim)
+         self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)
+
+     def call(self, patch):
+         # linear projection plus a learned positional embedding
+         pos = tf.range(start=0, limit=self.num_patch, delta=1)
+         return self.proj(patch) + self.pos_embed(pos)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_patch": self.num_patch,
+             }
+         )
+         return config
+
+
+ class PatchMerging(layers.Layer):
+     def __init__(self, num_patch, embed_dim):
+         super().__init__()
+         self.num_patch = num_patch
+         self.embed_dim = embed_dim
+         self.linear_trans = layers.Dense(2 * embed_dim, use_bias=False)
+
+     def call(self, x):
+         height, width = self.num_patch
+         _, _, C = x.get_shape().as_list()
+         x = tf.reshape(x, shape=(-1, height, width, C))
+         feat_maps = x  # kept for GradCAM on the transformer branch
+
+         # gather the four 2x2 neighbors and concatenate along channels
+         x0 = x[:, 0::2, 0::2, :]
+         x1 = x[:, 1::2, 0::2, :]
+         x2 = x[:, 0::2, 1::2, :]
+         x3 = x[:, 1::2, 1::2, :]
+         x = tf.concat((x0, x1, x2, x3), axis=-1)
+         x = tf.reshape(x, shape=(-1, (height // 2) * (width // 2), 4 * C))
+         return self.linear_trans(x), feat_maps
+
+     def get_config(self):
+         config = super().get_config()
+         config.update({"num_patch": self.num_patch, "embed_dim": self.embed_dim})
+         return config
utils/swin_window.py ADDED
@@ -0,0 +1,25 @@
+ import tensorflow as tf
+
+
+ def window_partition(x, window_size):
+     _, height, width, channels = x.shape
+     patch_num_y = height // window_size
+     patch_num_x = width // window_size
+     x = tf.reshape(
+         x, shape=(-1, patch_num_y, window_size, patch_num_x, window_size, channels)
+     )
+     x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
+     windows = tf.reshape(x, shape=(-1, window_size, window_size, channels))
+     return windows
+
+
+ def window_reverse(windows, window_size, height, width, channels):
+     patch_num_y = height // window_size
+     patch_num_x = width // window_size
+     x = tf.reshape(
+         windows,
+         shape=(-1, patch_num_y, patch_num_x, window_size, window_size, channels),
+     )
+     x = tf.transpose(x, perm=(0, 1, 3, 2, 4, 5))
+     x = tf.reshape(x, shape=(-1, height, width, channels))
+     return x
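As a quick sanity check of the two helpers above (a sketch, assuming they are imported from `utils.swin_window`): partitioning a 4×4 feature map with `window_size=2` produces four 2×2 windows, and `window_reverse` is its exact inverse:

    import tensorflow as tf
    from utils.swin_window import window_partition, window_reverse

    x = tf.reshape(tf.range(4 * 4 * 8, dtype=tf.float32), (1, 4, 4, 8))
    windows = window_partition(x, 2)         # shape (4, 2, 2, 8): four 2x2 windows
    y = window_reverse(windows, 2, 4, 4, 8)  # shape (1, 4, 4, 8)
    print(tf.reduce_all(x == y).numpy())     # True: an exact round trip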
utils/viz_utils.py ADDED
@@ -0,0 +1,64 @@
+ import matplotlib.cm as cm
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow import keras
+
+
+ def make_gradcam_heatmap(img_array, grad_model, pred_index=None):
+     # a persistent tape so gradients can be taken w.r.t. both branches
+     with tf.GradientTape(persistent=True) as tape:
+         preds, base_top, swin_top = grad_model(img_array)
+         if pred_index is None:
+             pred_index = tf.argmax(preds[0])
+         class_channel = preds[:, pred_index]
+
+     # CNN branch: weight each channel by its pooled gradient
+     grads = tape.gradient(class_channel, base_top)
+     pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
+     base_top = base_top[0]
+     heatmap_a = base_top @ pooled_grads[..., tf.newaxis]
+     heatmap_a = tf.squeeze(heatmap_a)
+     heatmap_a = tf.maximum(heatmap_a, 0) / tf.math.reduce_max(heatmap_a)
+     heatmap_a = heatmap_a.numpy()
+
+     # Transformer branch: the same computation on the swin feature maps
+     grads = tape.gradient(class_channel, swin_top)
+     pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
+     swin_top = swin_top[0]
+     heatmap_b = swin_top @ pooled_grads[..., tf.newaxis]
+     heatmap_b = tf.squeeze(heatmap_b)
+     heatmap_b = tf.maximum(heatmap_b, 0) / tf.math.reduce_max(heatmap_b)
+     heatmap_b = heatmap_b.numpy()
+     return heatmap_a, heatmap_b, preds
+
+
+ def save_and_display_gradcam(
+     img,
+     heatmap,
+     target=None,
+     pred=None,
+     cam_path="cam.jpg",
+     cmap="jet",  # inferno, viridis
+     alpha=0.6,
+     plot=None,
+     image_shape=None,
+ ):
+     # Rescale heatmap to a range 0-255
+     heatmap = np.uint8(255 * heatmap)
+
+     # Colorize the heatmap with the chosen colormap
+     jet = cm.get_cmap(cmap)
+
+     # Use RGB values of the colormap
+     jet_colors = jet(np.arange(256))[:, :3]
+     jet_heatmap = jet_colors[heatmap]
+
+     # Create an image with RGB colorized heatmap
+     jet_heatmap = keras.utils.array_to_img(jet_heatmap)
+     jet_heatmap = jet_heatmap.resize((img.shape[0], img.shape[1]))
+     jet_heatmap = keras.utils.img_to_array(jet_heatmap)
+
+     # Superimpose the heatmap on the original image
+     superimposed_img = img + jet_heatmap * alpha
+     superimposed_img = keras.utils.array_to_img(superimposed_img)
+
+     size_w, size_h = image_shape[:2]
+     superimposed_img = superimposed_img.resize((size_h, size_w))
+     return superimposed_img