import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from layers.swin_blocks import SwinTransformer
from utils.model_utils import *
from utils.patch import PatchEmbedding, PatchExtract, PatchMerging
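
# NOTE: the star import above is assumed to provide the module-level
# hyperparameters used below: params (with .image_size and .class_number),
# patch_size, num_patch_x, num_patch_y, embed_dim, shift_size, num_heads,
# window_size, num_mlp, qkv_bias, and dropout_rate. A minimal sketch of what
# utils/model_utils.py might define (illustrative values, not the originals):
#
#   patch_size = 2        # size of square patches cut from the feature map
#   embed_dim = 64        # patch embedding dimension
#   num_heads = 8         # attention heads per Swin block
#   window_size = 2       # local attention window size
#   shift_size = 2        # also the number of stacked Swin blocks (see below)
#   num_mlp = 256         # hidden units in each block's MLP
#   qkv_bias = True
#   dropout_rate = 0.03
#   # "block6a_expand_activation" is a stride-16 feature map in EfficientNetB0,
#   # so the patch grid is derived from image_size // 16:
#   num_patch_x = (params.image_size // 16) // patch_size
#   num_patch_y = (params.image_size // 16) // patch_size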


class HybridSwinTransformer(keras.Model):
    def __init__(self, model_name, **kwargs):
        super().__init__(name=model_name, **kwargs)
        # CNN backbone: EfficientNetB0 without its classification head.
        base = keras.applications.EfficientNetB0(
            include_top=False,
            weights=None,
            input_tensor=keras.Input((params.image_size, params.image_size, 3)),
        )
        # Re-wrap the backbone so it exposes an intermediate feature map
        # (the input to the transformer branch) alongside its final output.
        self.new_base = keras.Model(
            [base.inputs],
            [base.get_layer("block6a_expand_activation").output, base.output],
            name="efficientnet",
        )
        # Swin Transformer components.
        self.patch_extract = PatchExtract(patch_size)
        self.patch_embedds = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)
        self.patch_merging = PatchMerging(
            (num_patch_x, num_patch_y), embed_dim=embed_dim
        )
        # Stack of Swin blocks: block i uses shift_size=i, so the first block
        # attends within regular windows and later blocks within shifted ones.
        self.swin_sequences = keras.Sequential(name="swin_blocks")
        for i in range(shift_size):
            self.swin_sequences.add(
                SwinTransformer(
                    dim=embed_dim,
                    num_patch=(num_patch_x, num_patch_y),
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=i,
                    num_mlp=num_mlp,
                    qkv_bias=qkv_bias,
                    dropout_rate=dropout_rate,
                )
            )
        # Head for the transformer branch.
        self.swin_head = keras.Sequential(
            [
                layers.GlobalAveragePooling1D(),
                layers.AlphaDropout(0.5),
                layers.BatchNormalization(),
            ],
            name="swin_head",
        )
        # Head for the CNN branch.
        self.conv_head = keras.Sequential(
            [
                layers.GlobalAveragePooling2D(),
                layers.AlphaDropout(0.5),
            ],
            name="conv_head",
        )
        # Classifier over the concatenated branch embeddings; logits stay in
        # float32 for numerical stability under mixed precision.
        self.classifier = layers.Dense(
            params.class_number, activation=None, dtype="float32"
        )
        # Trace the model once so that all weights are created eagerly.
        self.build_graph()

    def call(self, inputs, training=None, **kwargs):
        x, base_gcam_top = self.new_base(inputs)
        x = self.patch_extract(x)
        x = self.patch_embedds(x)
        x = self.swin_sequences(tf.cast(x, dtype=tf.float32))
        x, swin_gcam_top = self.patch_merging(x)
        swin_top = self.swin_head(x)
        conv_top = self.conv_head(base_gcam_top)
        preds = self.classifier(tf.concat([swin_top, conv_top], axis=-1))
        if training:  # training phase: predictions only
            return preds
        # inference phase: also return the feature maps used for Grad-CAM
        return preds, base_gcam_top, swin_gcam_top

    def build_graph(self):
        x = keras.Input(shape=(params.image_size, params.image_size, 3))
        return keras.Model(inputs=[x], outputs=self.call(x))
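
# A minimal usage sketch (hypothetical names; `images` is a batch shaped
# (N, params.image_size, params.image_size, 3)):
#
#   model = HybridSwinTransformer(model_name="hybrid_swin")
#   logits = model(images, training=True)        # -> (N, params.class_number)
#   preds, cnn_fmap, swin_fmap = model(images)   # inference also returns the
#                                                # Grad-CAM feature maps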


class GradientAccumulation(HybridSwinTransformer):
    """Hybrid model with gradient accumulation across training steps.

    ref: https://gist.github.com/innat/ba6740293e7b7b227829790686f2119c
    """

    def __init__(self, n_gradients, **kwargs):
        super().__init__(**kwargs)
        self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
        # One accumulator per trainable variable; the parent __init__ has
        # already built the model, so the variables exist at this point.
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
            for v in self.trainable_variables
        ]

    def train_step(self, data):
        # Track how many batches have been accumulated so far.
        self.n_acum_step.assign_add(1)
        # Unpack the data. Its structure depends on the model and on what is
        # passed to `fit()`.
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # forward pass
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Compute this batch's gradients and add them to the accumulators.
        gradients = tape.gradient(loss, self.trainable_variables)
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign_add(gradients[i])
        # Once n_acum_step reaches n_gradients, apply the accumulated
        # gradients to update the variables; otherwise do nothing.
        tf.cond(
            tf.equal(self.n_acum_step, self.n_gradients),
            self.apply_accu_gradients,
            lambda: None,
        )
        # Return a dict mapping metric names to their current values.
        # Note that it will include the loss (tracked in self.metrics).
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def apply_accu_gradients(self):
        # Apply the summed (not averaged) gradients to update the weights.
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.trainable_variables)
        )
        # Reset the accumulation step counter and zero the accumulators.
        self.n_acum_step.assign(0)
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign(
                tf.zeros_like(self.trainable_variables[i], dtype=tf.float32)
            )

    def test_step(self, data):
        # Unpack the data.
        x, y = data
        # Compute predictions; inference mode also returns the Grad-CAM
        # feature maps, which are unused here.
        y_pred, base_gcam_top, swin_gcam_top = self(x, training=False)
        # Update the metric tracking the loss.
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Update the remaining metrics.
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to their current values.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}
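

if __name__ == "__main__":
    # A minimal training sketch on synthetic data, assuming the hyperparameters
    # imported above are in scope; the one-hot labels and categorical loss here
    # are illustrative, not necessarily the original training setup. Gradients
    # are accumulated over 4 batches, so the effective batch size is 4x the
    # per-step batch size.
    model = GradientAccumulation(n_gradients=4, model_name="hybrid_swin")
    model.compile(
        optimizer=keras.optimizers.Adam(1e-4),
        # the classifier emits raw logits (activation=None)
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    x = tf.random.normal((8, params.image_size, params.image_size, 3))
    y = tf.one_hot(
        tf.random.uniform((8,), 0, params.class_number, dtype=tf.int32),
        params.class_number,
    )
    model.fit(x, y, batch_size=2, epochs=1)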