|
|
|
|
|
import streamlit as st |
|
import hashlib |
|
import uuid |
|
from streamlit_card import card |
|
import streamlit.components.v1 as components |
|
import time |
|
import json |
|
|
|
def generate_mock_hash():
    """Return a mock SHA-256 digest as a 64-character hexadecimal string.

    The value is a placeholder for a real content hash. It is derived from
    the current time combined with a random UUID.

    Returns:
        str: 64 lowercase hexadecimal characters.
    """
    # The timestamp alone is not a sufficient seed: calls made within the
    # same clock tick would hash the same string and yield identical
    # "hashes". The random UUID makes every call distinct.
    seed = f"{time.time()}-{uuid.uuid4()}"
    return hashlib.sha256(seed.encode()).hexdigest()
|
|
|
|
|
from utils import ( |
|
CLIENT_DIR, |
|
CURRENT_DIR, |
|
DEPLOYMENT_DIR, |
|
KEYS_DIR, |
|
INPUT_BROWSER_LIMIT, |
|
clean_directory, |
|
SERVER_DIR, |
|
) |
|
|
|
from concrete.ml.deployment import FHEModelClient |
|
|
|
# Page-level setup.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first Streamlit command of the script — keep it above the other st.* calls.
st.set_page_config(layout="wide")

# Sidebar: a "Contact" title followed by an info box listing names as a
# markdown bullet list.
st.sidebar.title("Contact")
st.sidebar.info(
    """
    - Reda Bellafqira
    - Mehdi Ben Ghali
    - Pierre-Elisée Flory
    - Mohammed Lansari
    - Thomas Winninger
    """
)

# Main page title.
st.title("Secure Watermarking Service")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def todo():
    """Display a Streamlit warning telling the user the feature is not implemented yet."""
    notice = "Not implemented yet"
    st.warning(notice, icon="⚠️")
|
|
|
|
|
def key_gen_fn(client_id):
    """
    Generate keys for a given user. The keys are saved in KEYS_DIR

    A ``FHEModelClient`` is loaded from ``DEPLOYMENT_DIR`` with
    ``KEYS_DIR / client_id`` as its key directory. The serialized evaluation
    key is written to ``KEYS_DIR / client_id / "evaluation_key"``, and its
    size plus a truncated hex preview are shown in a Streamlit expander.

    !!! needs a model in DEPLOYMENT_DIR as "client.zip" !!!

    Args:
        client_id (str): The client_id, retrieved from streamlit
    """
    # Helper imported from utils; what it removes is not visible from here.
    clean_directory()

    # NOTE(review): client_id comes straight from a text input and is used as
    # a path component without validation (e.g. a value starting with "/"
    # would still escape KEYS_DIR) — consider sanitising it upstream.
    client_key_dir = KEYS_DIR / f"{client_id}"

    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=client_key_dir)
    client.load()

    client.generate_private_and_evaluation_keys()

    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)

    # Join the file name as its own component. Building the tail as
    # f"{client_id}/evaluation_key" yields the absolute "/evaluation_key" when
    # client_id is empty, and pathlib then throws KEYS_DIR away entirely.
    evaluation_key_path = client_key_dir / "evaluation_key"
    # Make sure the per-client directory exists before writing into it.
    evaluation_key_path.parent.mkdir(parents=True, exist_ok=True)
    evaluation_key_path.write_bytes(serialized_evaluation_keys)

    # Only a prefix of the hex dump is displayed, capped at
    # INPUT_BROWSER_LIMIT characters.
    serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[
        :INPUT_BROWSER_LIMIT
    ]

    with st.expander("Generated keys"):
        st.write(f"{len(serialized_evaluation_keys) / (10**6):.2f} MB")
        st.code(serialized_evaluation_keys_shorten_hex)

    st.success("Keys have been generated!", icon="✅")
|
|
|
|
|
def gen_trigger_set(client_id, hf_id):
    """Derive per-watermark seeds and the client trigger set (work in progress).

    A fresh UUID is generated for the watermark and mixed with each identifier
    to produce two SHA-256 seeds. The client trigger set pairs a constant
    input with each bit of the encoded ``client_id``. Nothing is returned or
    stored yet: the function ends by calling ``todo()``.

    Args:
        client_id (str): The client identifier string
        hf_id (str): A second identifier mixed into its own seed
            (presumably a Hugging Face id — TODO confirm)
    """
    watermark_uuid = uuid.uuid1()
    salt = str(watermark_uuid)

    # hashlib digests only accept bytes-like input, so the concatenated
    # strings must be encoded before hashing (passing str raises TypeError).
    client_seed = hashlib.sha256((client_id + salt).encode()).digest()
    hf_seed = hashlib.sha256((hf_id + salt).encode()).digest()

    trigger_set_size = 128

    # One entry per bit of the encoded client id; the "input" is a constant
    # placeholder for now.
    trigger_set_client = [
        {"input": 1, "label": digit} for digit in encode_id(client_id, trigger_set_size)
    ]

    # The seeds and the trigger set are not consumed yet; the remainder of
    # the generation is still to be written.
    todo()
|
|
|
|
|
def encode_id(ascii_rep, size=128):
    """Encode a string id to a string of bits

    Every byte of the UTF-8 encoding of ``ascii_rep`` is written as 8 bits,
    most significant bit first, and the concatenation is truncated to at most
    ``size`` characters. For pure ASCII input this is one 8-bit group per
    character. The result is not padded, so a short id gives fewer than
    ``size`` bits.

    Args:
        ascii_rep (str): The id string
        size (int): The maximum length of the output bit string

    Returns:
        str: a string of "0" and "1" characters
    """
    # Encode the argument itself (not a module-level variable), and go through
    # bytes so every unit is exactly 8 bits wide, whatever the character.
    return "".join(format(byte, "08b") for byte in ascii_rep.encode())[:size]
|
|
|
|
|
def decode_id(binary_rep):
    """Decode a string of bits to an ascii string

    The bit string is parsed as one big-endian integer, converted to the
    smallest number of bytes that can hold it, and decoded as text with the
    default UTF-8 codec. Because the byte count comes from the integer's bit
    length, leading all-zero bytes are not preserved.

    Args:
        binary_rep (str): the binary string

    Returns:
        str: the decoded string, or "" when ``binary_rep`` is empty

    Raises:
        ValueError: if ``binary_rep`` is not a valid base-2 literal
        UnicodeDecodeError: if the resulting bytes are not valid UTF-8
    """
    # int("", 2) would raise; an empty bit string simply decodes to nothing.
    if not binary_rep:
        return ""

    binary_int = int(binary_rep, 2)

    # Round the bit length up to whole bytes. The parentheses matter:
    # `bit_length() + 7 // 8` is just `bit_length()`, which over-allocates
    # and pads the decoded text with NUL characters.
    byte_number = (binary_int.bit_length() + 7) // 8

    binary_array = binary_int.to_bytes(byte_number, "big")

    return binary_array.decode()
|
|
|
|
|
def compare_id(client_id, binary_triggert_set_result):
    """Compares the string id with the labels of the trigger set on the tested API

    The id is encoded to at most 128 bits with ``encode_id`` and compared
    position by position with the observed bit string. Positions of the
    result that lie beyond the encoded id cannot match and therefore lower
    the score, because the denominator is the length of the result.

    Args:
        client_id (str): the ascii string
        binary_triggert_set_result (str): the binary string

    Returns:
        float: the fraction of bits in ``binary_triggert_set_result`` that
        equal the expected bit, in [0, 1]; 0.0 when the result is empty
    """
    ground_truth = encode_id(client_id, 128)

    # Nothing to compare against: avoid dividing by zero.
    if not binary_triggert_set_result:
        return 0.0

    # Count the positions where the observed bit equals the expected one.
    correct_bit = sum(
        true_bit == real_bit
        for true_bit, real_bit in zip(ground_truth, binary_triggert_set_result)
    )

    return correct_bit / len(binary_triggert_set_result)
|
|
|
|
|
def watermark(model, trigger_set):
    """Watermarking function

    The watermarking itself is not implemented yet: ``model`` and
    ``trigger_set`` are currently unused. The function shows the "not
    implemented" warning, makes sure two placeholder files exist in
    ``SERVER_DIR`` and offers each of them through a download button.

    Args:
        model (_type_): The model to watermark
        trigger_set (_type_): the trigger set
    """
    todo()

    # (file to serve, label of its download button)
    downloads = (
        (SERVER_DIR / "watermarked_model", "Download the watermarked file"),
        (SERVER_DIR / "trigger_set", "Download the trigger set"),
    )

    # touch() does not create missing parent directories.
    SERVER_DIR.mkdir(parents=True, exist_ok=True)

    # Create the placeholder files if they do not exist yet.
    for file_path, _ in downloads:
        file_path.touch()

    for file_path, button_label in downloads:
        with open(file_path, "rb") as served_file:
            st.download_button(
                label=button_label,
                data=served_file,
                mime="application/octet-stream",
            )
|
|
|
|
|
# --- Client configuration section ---
st.header("Client Configuration", divider=True)

# Module-level identifier typed by the user; defaults to "team-8-uuid".
client_id = st.text_input("Identification string", "team-8-uuid")

# Key generation only runs on the rerun triggered by clicking the button.
if st.button("Generate keys"):
    key_gen_fn(client_id)

# --- Model watermarking section ---
st.header("Model Watermarking", divider=True)

# NOTE(review): the uploaded file is stored here but never used below —
# watermark() is called with (None, None).
encrypted_model = st.file_uploader("Upload your encrypted model")

if st.button("Start Watermarking"):
    watermark(None, None)

# --- Watermarking verification section (header only, no widgets yet) ---
st.header("Watermarking Verification", divider=True)
|
|
|
|
|
# --- Blockchain update section ---
st.header("Update Blockchain", divider=True)

# Make sure the session-state slot holding the latest mock block exists.
if "block_data" not in st.session_state:
    st.session_state["block_data"] = None

# Clicking the button builds a new mock block from freshly generated hashes.
if st.button("Update Blockchain"):
    previous_hash = generate_mock_hash()
    timestamp = int(time.time() * 1000)  # milliseconds
    watermarked_model_hash = generate_mock_hash()
    trigger_set_hash = generate_mock_hash()

    transactions = [
        {"type": entry_type, "hash": entry_hash}
        for entry_type, entry_hash in (
            ("Watermarked Model Hash", watermarked_model_hash),
            ("Trigger Set Hash", trigger_set_hash),
        )
    ]

    st.session_state["block_data"] = {
        "blockNumber": 42,
        "previousHash": previous_hash,
        "timestamp": timestamp,
        "transactions": transactions,
    }

    st.success("Blockchain updated successfully!")

# Show the stored block, if there is one, as pretty-printed JSON.
if st.session_state["block_data"]:
    st.subheader("Latest Block Data (JSON)")
    block_json = json.dumps(st.session_state["block_data"], indent=2)
    st.code(block_json, language="json")
|
|
|
|