eagle0504 commited on
Commit
dffa879
·
1 Parent(s): 72f8238

cast to float

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -33,7 +33,7 @@ def create_model(input_shape):
33
  def soft_quantized_influence_measure_sigmoid(
34
  y_true: tf.Tensor, y_pred: tf.Tensor, threshold: float = 0.5
35
  ) -> tf.Tensor:
36
- # Cast y_true and y_pred to float32 to ensure compatibility with tf operations
37
  y_true = tf.cast(y_true, tf.float32)
38
  y_pred = tf.cast(y_pred, tf.float32)
39
 
@@ -56,11 +56,16 @@ def soft_quantized_influence_measure_sigmoid(
56
  true_error_loss = tf.square(error) * tf.square(n1)
57
  false_error_loss = tf.square(error) * tf.square(n2)
58
 
59
- # Apply weights to errors based on whether they are small or large, and normalize the final loss
 
 
 
 
60
  final_loss = tf.where(is_small_error, true_error_loss, false_error_loss)
61
- final_loss = (
62
- tf.reduce_mean(final_loss) / tf.square(y_std) ** 2 / len(y_true)
63
- ) # Normalize by the square of the standard deviation of true values
 
64
 
65
  return final_loss
66
 
 
33
  def soft_quantized_influence_measure_sigmoid(
34
  y_true: tf.Tensor, y_pred: tf.Tensor, threshold: float = 0.5
35
  ) -> tf.Tensor:
36
+ # Ensure y_true and y_pred are of type float32
37
  y_true = tf.cast(y_true, tf.float32)
38
  y_pred = tf.cast(y_pred, tf.float32)
39
 
 
56
  true_error_loss = tf.square(error) * tf.square(n1)
57
  false_error_loss = tf.square(error) * tf.square(n2)
58
 
59
+ # Ensure the losses are of the same type as the condition in tf.where
60
+ true_error_loss = tf.cast(true_error_loss, tf.float32)
61
+ false_error_loss = tf.cast(false_error_loss, tf.float32)
62
+
63
+ # Apply weights to errors based on whether they are small or large
64
  final_loss = tf.where(is_small_error, true_error_loss, false_error_loss)
65
+
66
+ # Normalize the final loss - ensure division is properly handled with float types
67
+ num_elements = tf.cast(tf.size(y_true), tf.float32) # Cast number of elements to float
68
+ final_loss = tf.reduce_mean(final_loss) / tf.square(y_std) ** 2 / num_elements
69
 
70
  return final_loss
71