cast to float
Browse files
app.py
CHANGED
@@ -33,7 +33,7 @@ def create_model(input_shape):
|
|
33 |
def soft_quantized_influence_measure_sigmoid(
|
34 |
y_true: tf.Tensor, y_pred: tf.Tensor, threshold: float = 0.5
|
35 |
) -> tf.Tensor:
|
36 |
-
#
|
37 |
y_true = tf.cast(y_true, tf.float32)
|
38 |
y_pred = tf.cast(y_pred, tf.float32)
|
39 |
|
@@ -56,11 +56,16 @@ def soft_quantized_influence_measure_sigmoid(
|
|
56 |
true_error_loss = tf.square(error) * tf.square(n1)
|
57 |
false_error_loss = tf.square(error) * tf.square(n2)
|
58 |
|
59 |
-
#
|
|
|
|
|
|
|
|
|
60 |
final_loss = tf.where(is_small_error, true_error_loss, false_error_loss)
|
61 |
-
|
62 |
-
|
63 |
-
) #
|
|
|
64 |
|
65 |
return final_loss
|
66 |
|
|
|
33 |
def soft_quantized_influence_measure_sigmoid(
|
34 |
y_true: tf.Tensor, y_pred: tf.Tensor, threshold: float = 0.5
|
35 |
) -> tf.Tensor:
|
36 |
+
# Ensure y_true and y_pred are of type float32
|
37 |
y_true = tf.cast(y_true, tf.float32)
|
38 |
y_pred = tf.cast(y_pred, tf.float32)
|
39 |
|
|
|
56 |
true_error_loss = tf.square(error) * tf.square(n1)
|
57 |
false_error_loss = tf.square(error) * tf.square(n2)
|
58 |
|
59 |
+
# Ensure the losses are of the same type as the condition in tf.where
|
60 |
+
true_error_loss = tf.cast(true_error_loss, tf.float32)
|
61 |
+
false_error_loss = tf.cast(false_error_loss, tf.float32)
|
62 |
+
|
63 |
+
# Apply weights to errors based on whether they are small or large
|
64 |
final_loss = tf.where(is_small_error, true_error_loss, false_error_loss)
|
65 |
+
|
66 |
+
# Normalize the final loss - ensure division is properly handled with float types
|
67 |
+
num_elements = tf.cast(tf.size(y_true), tf.float32) # Cast number of elements to float
|
68 |
+
final_loss = tf.reduce_mean(final_loss) / tf.square(y_std) ** 2 / num_elements
|
69 |
|
70 |
return final_loss
|
71 |
|