|
|
|
|
|
import numpy as np |
|
import hashlib |
|
import random |
|
import secrets |
|
|
|
|
|
# Framing for the LSB watermark: a fixed magic tag, then the payload length in
# bytes as a big-endian unsigned integer, then the UTF-8 payload itself. The
# explicit length tells the detector exactly where the payload ends, and the
# magic tag lets it report "no watermark" for unmarked models.
_LSB_MAGIC = b'LSBW'
_LSB_LENGTH_BYTES = 4
_LSB_HEADER_BITS = 8 * (len(_LSB_MAGIC) + _LSB_LENGTH_BYTES)


def _lsb_view(arr):
    """Return a flat unsigned-integer view sharing memory with *arr*.

    *arr* must be C-contiguous. The view has the same item size and byte
    order as *arr*, so bit 0 of each view element is the least significant
    bit of the corresponding weight's binary representation (for floats,
    the lowest mantissa bit).

    Raises:
        TypeError: if the dtype is not a 1/2/4/8-byte int, uint or float.
    """
    dtype = arr.dtype
    if dtype.kind not in 'iuf' or dtype.itemsize not in (1, 2, 4, 8):
        raise TypeError(f"Unsupported weight dtype for LSB watermarking: {dtype}")
    uint_dtype = np.dtype(f'{dtype.byteorder}u{dtype.itemsize}')
    return arr.reshape(-1).view(uint_dtype)


def _read_lsbs(views, count):
    """Return the first *count* LSBs across *views*, in order, as a 0/1 uint8 array.

    The caller must ensure the views hold at least *count* elements in total.
    """
    parts = []
    remaining = count
    for view in views:
        if remaining == 0:
            break
        part = view[:remaining]
        parts.append((part & 1).astype(np.uint8))
        remaining -= part.size
    return np.concatenate(parts) if parts else np.zeros(0, dtype=np.uint8)


def embed_watermark_LSB(model, watermark_data):
    """
    Embeds a watermark into the provided model using Least Significant Bit (LSB) technique.

    The weights returned by ``model.get_weights()`` are walked in order (each
    array flattened in C order). The least significant bit of each weight's
    binary representation is overwritten with one bit of the framed message:
    ``b'LSBW'`` + 4-byte big-endian payload length + UTF-8 payload, most
    significant bit of each byte first. Each float weight therefore changes
    by at most one unit in the last place. Note that setting the LSB of an
    infinite float weight turns it into NaN.

    Arguments:
        model : object
            Object exposing ``get_weights()`` (list of numeric arrays) and
            ``set_weights(list)``, e.g. a TensorFlow/Keras model.
        watermark_data : str
            The watermark data to be embedded into the model.

    Returns:
        model : object
            The same model object, after ``set_weights`` with the watermarked weights.

    Raises:
        ValueError: if the framed watermark needs more bits than the model has weights.
        TypeError: if a weight array has a dtype that cannot carry LSB bits.
    """
    payload = watermark_data.encode('utf-8')

    # Work on private C-contiguous copies so the flat uint views below alias
    # the very arrays that are handed back to set_weights().
    weights = [np.array(w, order='C', copy=True) for w in model.get_weights()]
    views = [_lsb_view(w) for w in weights]

    total_capacity = sum(view.size for view in views)
    required_capacity = _LSB_HEADER_BITS + 8 * len(payload)
    if (len(payload) >= 1 << (8 * _LSB_LENGTH_BYTES)
            or required_capacity > total_capacity):
        raise ValueError("Watermark size exceeds model capacity")

    header = _LSB_MAGIC + len(payload).to_bytes(_LSB_LENGTH_BYTES, 'big')
    # MSB-first bit order, i.e. the same order as format(byte, '08b').
    bits = np.unpackbits(np.frombuffer(header + payload, dtype=np.uint8))

    offset = 0
    for view in views:
        if offset == bits.size:
            break
        chunk = bits[offset:offset + view.size]
        n = chunk.size
        # The mask must carry the view's own dtype: a negative Python int such
        # as ~1 cannot be represented in an unsigned dtype.
        view[:n] &= ~view.dtype.type(1)
        view[:n] |= chunk.astype(view.dtype)
        offset += n

    model.set_weights(weights)
    return model


def detect_watermark_LSB(model):
    """
    Detects and extracts the watermark from the provided model using Least Significant Bit (LSB) technique.

    Reads the weight LSBs in the same order that ``embed_watermark_LSB``
    writes them. It then checks for the ``b'LSBW'`` magic tag, reads the
    4-byte big-endian payload length, and decodes that many bytes as UTF-8.

    Arguments:
        model : object
            Object exposing ``get_weights()`` (list of numeric arrays), e.g. a
            TensorFlow/Keras model.

    Returns:
        detected_watermark : str or None
            Extracted watermark if detected, else None. None is returned when
            the model is too small, the magic tag is absent, the stored length
            exceeds the model's capacity, or the payload is not valid UTF-8.

    Raises:
        TypeError: if a weight array has a dtype that cannot carry LSB bits.
    """
    views = [_lsb_view(np.ascontiguousarray(w)) for w in model.get_weights()]
    total_capacity = sum(view.size for view in views)
    if total_capacity < _LSB_HEADER_BITS:
        return None

    header = np.packbits(_read_lsbs(views, _LSB_HEADER_BITS)).tobytes()
    if header[:len(_LSB_MAGIC)] != _LSB_MAGIC:
        return None
    length = int.from_bytes(header[len(_LSB_MAGIC):], 'big')

    required_capacity = _LSB_HEADER_BITS + 8 * length
    if required_capacity > total_capacity:
        return None

    bits = _read_lsbs(views, required_capacity)
    payload = np.packbits(bits[_LSB_HEADER_BITS:]).tobytes()
    try:
        return payload.decode('utf-8')
    except UnicodeDecodeError:
        return None