sharktide commited on
Commit
50bea56
·
verified ·
1 Parent(s): 5d0187f

Update custom_objects.py

Browse files
Files changed (1) hide show
  1. custom_objects.py +54 -0
custom_objects.py CHANGED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import layers, models
3
+ from tensorflow.keras.saving import register_keras_serializable
4
+
5
+ @register_keras_serializable()
6
+ class StressAmplifier(tf.keras.layers.Layer):
7
+ def __init__(self, **kwargs):
8
+ super().__init__(**kwargs)
9
+
10
+ def call(self, inputs):
11
+ stress = inputs[:, 2]
12
+ slip = inputs[:, 4]
13
+ stress_boost = tf.sigmoid((stress - 400) * 0.01)
14
+ slip_boost = tf.sigmoid((slip - 8) * 0.5)
15
+ modulation = 1.0 + 0.4 * stress_boost * slip_boost
16
+ return tf.expand_dims(modulation, axis=-1)
17
+
18
+ @register_keras_serializable()
19
+ class DepthSuppressor(tf.keras.layers.Layer):
20
+ def __init__(self, **kwargs):
21
+ super().__init__(**kwargs)
22
+
23
+ def call(self, inputs):
24
+ depth = inputs[:, 3]
25
+ suppression = tf.sigmoid((depth - 25) * 0.15)
26
+ modulation = 1.0 - 0.3 * suppression
27
+ return tf.expand_dims(modulation, axis=-1)
28
+
29
+ @register_keras_serializable()
30
+ class DisplacementActivator(tf.keras.layers.Layer):
31
+ def __init__(self, **kwargs):
32
+ super().__init__(**kwargs)
33
+
34
+ def call(self, inputs):
35
+ displacement = inputs[:, 1]
36
+ activation = tf.sigmoid((displacement - 30) * 0.08)
37
+ modulation = 1.0 + 0.3 * activation
38
+ return tf.expand_dims(modulation, axis=-1)
39
+
40
+ @register_keras_serializable()
41
+ class SoftScale(tf.keras.layers.Layer):
42
+ def __init__(self, factor=0.25, **kwargs):
43
+ super().__init__(**kwargs)
44
+ self.factor = factor
45
+
46
+ def call(self, inputs):
47
+ return 1.0 + self.factor * tf.tanh(inputs - 1.0)
48
+
49
+ CUSTOM_OBJECTS = {
50
+ 'StressAmplifier': StressAmplifier,
51
+ 'DepthSuppressor': DepthSuppressor,
52
+ 'DisplacementActivator': DisplacementActivator,
53
+ 'SoftScale': SoftScale
54
+ }