| | """
|
| | TensorFlow Decoder using stegastamp_pretrained model
|
| | This model contains BOTH encoder and decoder in one checkpoint
|
| | """
|
| | import tensorflow as tf
|
| | import numpy as np
|
| | from PIL import Image as PILImage
|
| |
|
| |
|
class TFStegaStampDecoderPretrained:
    """TensorFlow StegaStamp decoder backed by the ``stegastamp_pretrained`` SavedModel.

    The checkpoint's ``serving_default`` signature is called to recover a
    secret bit string (100 bits for this model) from an image. All image
    preprocessing happens in NumPy/PIL; TensorFlow is only used to load the
    model and run the signature.
    """

    # (height, width) the decoder is fed; other sizes are resized to this.
    _IMAGE_HW = (400, 400)

    def __init__(self, model_path='saved_models/stegastamp_pretrained'):
        """Load the pretrained SavedModel (encoder and decoder in one checkpoint).

        Args:
            model_path: Path to the pretrained SavedModel directory.
        """
        print(f"Loading TensorFlow pretrained model from: {model_path}")

        # Keep the loaded object alive for the decoder's lifetime: a signature
        # function can stop working once the object owning its variables has
        # been garbage-collected.
        self._model = tf.saved_model.load(model_path)
        self._infer_fn = self._model.signatures['serving_default']

        # structured_input_signature is an (args, kwargs) pair; the named
        # input specs live in the kwargs half.
        self._input_specs = dict(self._infer_fn.structured_input_signature[1])
        self._input_names = list(self._input_specs.keys())
        self._output_names = list(self._infer_fn.structured_outputs.keys())

        print(f" Input tensors: {self._input_names}")
        print(f" Output tensors: {self._output_names}")
        print("TensorFlow decoder loaded successfully!")

    def decode(self, image):
        """Decode the secret bits from an image.

        Args:
            image: numpy array of shape (H, W, 3) or (B, H, W, 3), or
                channel-first (3, H, W) / (B, 3, H, W); a 2-D grayscale array
                or a trailing alpha channel is also accepted.
                OR a PIL Image (any mode; converted to RGB).
                Values should be in [0, 1] range for numpy, [0, 255] for PIL.

        Returns:
            float32 numpy array of decoded bits (0.0 or 1.0) for the FIRST
            image of the batch; shape (100,) for this model.

        Raises:
            RuntimeError: if ``close()`` has already been called.
            ValueError: if the array layout cannot be interpreted as RGB.
        """
        if getattr(self, '_infer_fn', None) is None:
            raise RuntimeError("TFStegaStampDecoderPretrained has been closed")

        batch = self._prepare_image(image)
        image_tensor = tf.convert_to_tensor(batch)
        result = self._infer_fn(**self._build_feeds(image_tensor, batch.shape[0]))

        if 'decoded' in self._output_names:
            decoded = result['decoded'].numpy()
        else:
            # No output named 'decoded': fall back to the first output.
            decoded = next(iter(result.values())).numpy()

        # Threshold to hard bits; only the first batch element is returned.
        return (decoded[0] > 0.5).astype(np.float32)

    @classmethod
    def _prepare_image(cls, image):
        """Convert *image* into a float32 (B, H, W, 3) batch of size ``_IMAGE_HW``."""
        if hasattr(image, 'mode'):
            # PIL image: force RGB so 'L', 'RGBA', 'P', ... inputs yield 3 channels.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = np.asarray(image, dtype=np.float32) / 255.0
        else:
            image = np.asarray(image, dtype=np.float32)

        if image.ndim == 2:
            # Grayscale (H, W): replicate into three channels.
            image = np.repeat(image[..., np.newaxis], 3, axis=-1)
        if image.ndim == 3:
            image = image[np.newaxis]
        if image.ndim != 4:
            raise ValueError(f"Expected a 2-D, 3-D or 4-D image, got shape {image.shape}")

        if image.shape[-1] != 3:
            if image.shape[1] == 3:
                # Channel-first (B, 3, H, W) -> channel-last (B, H, W, 3).
                image = np.transpose(image, (0, 2, 3, 1))
            elif image.shape[-1] == 4:
                # Channel-last with alpha: drop the alpha channel.
                image = image[..., :3]
            else:
                raise ValueError(f"Cannot interpret image of shape {image.shape} as RGB")

        if image.shape[1:3] != cls._IMAGE_HW:
            height, width = cls._IMAGE_HW
            # Resize every frame (PIL takes (width, height)).
            image = np.stack([
                np.asarray(
                    PILImage.fromarray(cls._to_uint8(frame)).resize((width, height)),
                    dtype=np.float32,
                ) / 255.0
                for frame in image
            ])

        return image.astype(np.float32, copy=False)

    @staticmethod
    def _to_uint8(frame):
        """Map a [0, 1] float frame to uint8, rounding and clipping out-of-range values."""
        scaled = np.rint(np.asarray(frame, dtype=np.float32) * 255.0)
        # Clip before casting: a raw float->uint8 cast wraps values outside 0..255.
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def _build_feeds(self, image_tensor, batch_size):
        """Return keyword arguments covering every input of the signature.

        Signature functions require all declared inputs. Inputs other than
        the image get zero tensors. The assumption is that they do not affect
        the decoder output; verify this for models other than
        stegastamp_pretrained.
        """
        # Prefer an input literally named 'image'; otherwise use the first one.
        image_key = 'image' if 'image' in self._input_names else self._input_names[0]
        feeds = {}
        for name in self._input_names:
            if name == image_key:
                feeds[name] = image_tensor
            else:
                feeds[name] = self._zeros_for(self._input_specs[name], batch_size)
        return feeds

    @staticmethod
    def _zeros_for(spec, batch_size):
        """Return a zero tensor for *spec*; assumes dim 0 is the batch axis."""
        if spec.shape.rank is None:
            dims = [batch_size]
        else:
            dims = [
                (batch_size if axis == 0 else 1) if size is None else size
                for axis, size in enumerate(spec.shape.as_list())
            ]
        return tf.zeros(dims, dtype=spec.dtype)

    def __call__(self, image):
        """Make the decoder callable (alias for ``decode``)."""
        return self.decode(image)

    def close(self):
        """Drop references to the TF model so its resources can be freed.

        Idempotent; ``decode`` raises RuntimeError after this is called.
        """
        self._infer_fn = None
        self._model = None

    def __del__(self):
        """Best-effort cleanup on garbage collection."""
        try:
            self.close()
        except Exception:
            # Finalizers must never raise (e.g. during interpreter shutdown).
            pass
|
| |
|
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: load the pretrained model and decode a uniform grey frame.
    import traceback

    print("\n" + "=" * 80)
    print("Testing TF Decoder with Pretrained Model")
    print("=" * 80 + "\n")

    decoder = None
    try:
        decoder = TFStegaStampDecoderPretrained()

        # Mid-grey 400x400 RGB image with values in [0, 1].
        test_image = np.full((400, 400, 3), 0.5, dtype=np.float32)

        print("\nTesting decoding...")
        # decode() returns an array or raises; it never returns None.
        decoded = decoder.decode(test_image)
        print(f"✓ Decoded {len(decoded)} bits")
        print(f" Sample values: {decoded[:20]}")
        print(f" Mean: {decoded.mean():.3f}")
        print("\n✓ Test complete!")

    except Exception as e:
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
    finally:
        # Release the model even when loading or decoding failed.
        if decoder is not None:
            decoder.close()