HybridModel-GradCAM / models /hybrid_model.py
innat
init
0f09377
raw
history blame
6.01 kB
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from layers.swin_blocks import SwinTransformer
from utils.model_utils import *
from utils.patch import PatchEmbedding
from utils.patch import PatchExtract
from utils.patch import PatchMerging
class HybridSwinTransformer(keras.Model):
def __init__(self, model_name, **kwargs):
super().__init__(name=model_name, **kwargs)
# base models
base = keras.applications.EfficientNetB0(
include_top=False,
weights=None,
input_tensor=keras.Input((params.image_size, params.image_size, 3)),
)
# base model with compatible output which will be an input of transformer model
self.new_base = keras.Model(
[base.inputs],
[base.get_layer("block6a_expand_activation").output, base.output],
name="efficientnet",
)
# stuff of swin transformers
self.patch_extract = PatchExtract(patch_size)
self.patch_embedds = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)
self.patch_merging = PatchMerging(
(num_patch_x, num_patch_y), embed_dim=embed_dim
)
# swin blocks containers
self.swin_sequences = keras.Sequential(name="swin_blocks")
for i in range(shift_size):
self.swin_sequences.add(
SwinTransformer(
dim=embed_dim,
num_patch=(num_patch_x, num_patch_y),
num_heads=num_heads,
window_size=window_size,
shift_size=i,
num_mlp=num_mlp,
qkv_bias=qkv_bias,
dropout_rate=dropout_rate,
)
)
# swin block's head
self.swin_head = keras.Sequential(
[
layers.GlobalAveragePooling1D(),
layers.AlphaDropout(0.5),
layers.BatchNormalization(),
],
name="swin_head",
)
# base model's (cnn model) head
self.conv_head = keras.Sequential(
[
layers.GlobalAveragePooling2D(),
layers.AlphaDropout(0.5),
],
name="conv_head",
)
# classifier
self.classifier = layers.Dense(
params.class_number, activation=None, dtype="float32"
)
self.build_graph()
def call(self, inputs, training=None, **kwargs):
x, base_gcam_top = self.new_base(inputs)
x = self.patch_extract(x)
x = self.patch_embedds(x)
x = self.swin_sequences(tf.cast(x, dtype=tf.float32))
x, swin_gcam_top = self.patch_merging(x)
swin_top = self.swin_head(x)
conv_top = self.conv_head(base_gcam_top)
preds = self.classifier(tf.concat([swin_top, conv_top], axis=-1))
if training: # training phase
return preds
else: # inference phase
return preds, base_gcam_top, swin_gcam_top
def build_graph(self):
x = keras.Input(shape=(params.image_size, params.image_size, 3))
return keras.Model(inputs=[x], outputs=self.call(x))
class GradientAccumulation(HybridSwinTransformer):
"""ref: https://gist.github.com/innat/ba6740293e7b7b227829790686f2119c"""
def __init__(self, n_gradients, **kwargs):
super().__init__(**kwargs)
self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
self.gradient_accumulation = [
tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
for v in self.trainable_variables
]
def train_step(self, data):
# track accumulation step update
self.n_acum_step.assign_add(1)
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
# Calculate batch gradients
gradients = tape.gradient(loss, self.trainable_variables)
# Accumulate batch gradients
for i in range(len(self.gradient_accumulation)):
self.gradient_accumulation[i].assign_add(gradients[i])
# If n_acum_step reach the n_gradients then we apply accumulated gradients to -
# update the variables otherwise do nothing
tf.cond(
tf.equal(self.n_acum_step, self.n_gradients),
self.apply_accu_gradients,
lambda: None,
)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
self.compiled_metrics.update_state(y, y_pred)
return {m.name: m.result() for m in self.metrics}
def apply_accu_gradients(self):
# Update weights
self.optimizer.apply_gradients(
zip(self.gradient_accumulation, self.trainable_variables)
)
# reset accumulation step
self.n_acum_step.assign(0)
for i in range(len(self.gradient_accumulation)):
self.gradient_accumulation[i].assign(
tf.zeros_like(self.trainable_variables[i], dtype=tf.float32)
)
def test_step(self, data):
# Unpack the data
x, y = data
# Compute predictions
y_pred, base_gcam_top, swin_gcam_top = self(x, training=False)
# Updates the metrics tracking the loss
self.compiled_loss(y, y_pred, regularization_losses=self.losses)
# Update the metrics.
self.compiled_metrics.update_state(y, y_pred)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
return {m.name: m.result() for m in self.metrics}