rynmurdock commited on
Commit
b670aa6
1 Parent(s): 7836f28
Files changed (1) hide show
  1. safety_checker_improved.py +4 -3
safety_checker_improved.py CHANGED
@@ -11,9 +11,10 @@ import tensorflow as tf
11
  from tensorflow.keras import mixed_precision
12
  physical_devices = tf.config.list_physical_devices('GPU')
13
 
14
- tf.config.experimental.set_memory_growth(
15
- physical_devices[0], True
16
- )
 
17
 
18
  model = tf.keras.models.load_model('nsfweffnetv2-b02-3epochs.h5',custom_objects={"KerasLayer":hub.KerasLayer})
19
  # "The image classifier had been trained on 682550 images from the 5 classes "Drawing" (39026), "Hentai" (28134), "Neutral" (369507), "Porn" (207969) & "Sexy" (37914).
 
11
  from tensorflow.keras import mixed_precision
12
  physical_devices = tf.config.list_physical_devices('GPU')
13
 
14
+ if len(physical_devices) > 0:
15
+ tf.config.experimental.set_memory_growth(
16
+ physical_devices[0], True
17
+ )
18
 
19
  model = tf.keras.models.load_model('nsfweffnetv2-b02-3epochs.h5',custom_objects={"KerasLayer":hub.KerasLayer})
20
  # "The image classifier had been trained on 682550 images from the 5 classes "Drawing" (39026), "Hentai" (28134), "Neutral" (369507), "Porn" (207969) & "Sexy" (37914).