# Zamark / app.py
# Author: Sckathach — commit d4d76e3 ("mb"), 6.9 kB
# (Header recovered from the hosting page: raw / history / blame links.)
import streamlit as st
import hashlib
import uuid
from streamlit_card import card
import streamlit.components.v1 as components
import time
import json
def generate_mock_hash():
    """Return a random SHA-256 hex digest used as a placeholder block hash.

    Random bytes are hashed instead of the current time: several mock hashes
    are generated back to back, and on platforms with a coarse clock
    ``time.time()`` can return the same value, which made the "different"
    mock hashes identical.

    Returns:
        str: 64 lowercase hexadecimal characters.
    """
    import secrets  # local import: only this mock helper needs it

    return hashlib.sha256(secrets.token_bytes(32)).hexdigest()
from utils import (
CLIENT_DIR,
CURRENT_DIR,
DEPLOYMENT_DIR,
KEYS_DIR,
INPUT_BROWSER_LIMIT,
clean_directory,
SERVER_DIR,
)
from concrete.ml.deployment import FHEModelClient
st.set_page_config(layout="wide")

# Project contributors, rendered as a markdown list in the sidebar.
_CONTACTS = """
- Reda Bellafqira
- Mehdi Ben Ghali
- Pierre-Elisée Flory
- Mohammed Lansari
- Thomas Winninger
"""
st.sidebar.title("Contact")
st.sidebar.info(_CONTACTS)

st.title("Secure Watermarking Service")

# Illustration currently disabled:
# st.image(
#     "llm_watermarking.png",
#     caption="A Watermark for Large Language Models (https://doi.org/10.48550/arXiv.2301.10226)",
# )
def todo():
    """Show a Streamlit warning flagging a feature that is not implemented yet."""
    message = "Not implemented yet"
    st.warning(message, icon="⚠️")
def key_gen_fn(client_id):
    """Generate FHE keys for one client and show a preview of the evaluation key.

    The keys are created by ``FHEModelClient`` with ``KEYS_DIR / client_id`` as
    key directory, and the serialized evaluation key is also written to
    ``KEYS_DIR / client_id / "evaluation_key"``.

    !!! needs a model in DEPLOYMENT_DIR as "client.zip" !!!

    Args:
        client_id (str): The client_id, retrieved from streamlit. It is used as
            a directory name, so it must be a single non-empty path component.

    Raises:
        TypeError: If the serialized evaluation keys are not ``bytes``.
    """
    # client_id comes straight from a text box and is joined onto KEYS_DIR;
    # refuse anything that could point outside that directory.
    if (
        not client_id
        or client_id in (".", "..")
        or any(sep in client_id for sep in ("/", "\\"))
    ):
        st.error(
            "Invalid identification string: it must be non-empty and "
            "contain no '/' or '\\'."
        )
        return

    clean_directory()

    key_dir = KEYS_DIR / client_id
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=key_dir)
    client.load()

    # Creates the private and evaluation keys on the client side
    client.generate_private_and_evaluation_keys()

    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not isinstance(serialized_evaluation_keys, bytes):
        raise TypeError(
            "Expected serialized evaluation keys as bytes, got "
            f"{type(serialized_evaluation_keys).__name__}"
        )

    # Save the evaluation key, making sure its directory exists first.
    key_dir.mkdir(parents=True, exist_ok=True)
    (key_dir / "evaluation_key").write_bytes(serialized_evaluation_keys)

    # Only a prefix of the key is shown; the full key can be very large.
    key_preview_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]
    with st.expander("Generated keys"):
        st.write(f"{len(serialized_evaluation_keys) / (10**6):.2f} MB")
        st.code(key_preview_hex)

    st.success("Keys have been generated!", icon="✅")
def gen_trigger_set(client_id, hf_id):
    """Start building the watermark trigger set for a client (work in progress).

    Intended design:
      * input : random images seeded by client_id
      * labels: binary array of the id (see ``encode_id``)

    Two seeds are derived, each as the SHA-256 digest of an id concatenated
    with a fresh UUID. Only placeholder label entries are built so far; the
    function then shows the "Not implemented yet" warning and returns None.

    Args:
        client_id (str): The client identification string.
        hf_id (str): Second identifier used for the second seed
            (presumably a Hugging Face id — TODO confirm).
    """
    # NOTE(review): uuid1 embeds the host MAC address; uuid4 may be preferable.
    watermark_uuid = str(uuid.uuid1())

    # hashlib only accepts bytes, so encode the concatenated strings first.
    client_seed = hashlib.sha256((client_id + watermark_uuid).encode("utf-8")).digest()
    hf_seed = hashlib.sha256((hf_id + watermark_uuid).encode("utf-8")).digest()

    trigger_set_size = 128
    # Inputs are a constant placeholder until seeded random images exist;
    # client_seed / hf_seed are reserved for that step.
    trigger_set_client = [
        {"input": 1, "label": digit}
        for digit in encode_id(client_id, trigger_set_size)
    ]
    todo()
def encode_id(ascii_rep, size=128):
    """Encode a string id to a string of bits (8 bits per byte, MSB first).

    The id is encoded as UTF-8, which matches its ASCII codes for ASCII ids,
    so ``decode_id`` can invert the result when it is not truncated.

    Args:
        ascii_rep (str): The id string.
        size (int): Maximum length of the output; longer encodings keep only
            their first ``size`` bits, shorter ones are returned unpadded.

    Returns:
        str: A string of '0'/'1' characters.
    """
    bits = "".join(format(byte, "08b") for byte in ascii_rep.encode("utf-8"))
    return bits[:size]
def decode_id(binary_rep):
    """Decode a string of bits back to the text it encodes.

    Inverse of ``encode_id`` for untruncated input: the bits are read as a
    big-endian integer, laid out over ceil(len(binary_rep) / 8) bytes and
    decoded as UTF-8.

    Args:
        binary_rep (str): The bit string, e.g. ``"01100001"``.

    Returns:
        str: The decoded text; ``""`` for an empty bit string.

    Raises:
        ValueError: If ``binary_rep`` contains characters other than '0'/'1'.
        UnicodeDecodeError: If the bytes are not valid UTF-8.
    """
    if not binary_rep:
        return ""
    # int(..., 2) would also accept "0b", "_", signs and whitespace, which
    # would throw off the byte count below.
    if not set(binary_rep) <= {"0", "1"}:
        raise ValueError(f"not a binary string: {binary_rep!r}")

    binary_int = int(binary_rep, 2)
    # Round the bit count up to whole bytes. Using the string length (rather
    # than bit_length()) keeps leading zero bytes such as "\x00".
    byte_number = (len(binary_rep) + 7) // 8
    return binary_int.to_bytes(byte_number, "big").decode()
def compare_id(client_id, binary_triggert_set_result):
    """Compare the string id with the labels of the trigger set on the tested API.

    The id is encoded with ``encode_id(client_id, 128)`` and compared position
    by position with the bits returned by the tested API.

    Args:
        client_id (str): The ascii id string.
        binary_triggert_set_result (str): The bit string ('0'/'1') predicted by
            the tested API on the trigger set.

    Returns:
        float: Fraction of compared positions whose bits match, in [0, 1].
        Only positions present in both bit strings are compared; 0.0 when
        there is nothing to compare.
    """
    ground_truth = encode_id(client_id, 128)
    pairs = list(zip(ground_truth, binary_triggert_set_result))
    if not pairs:
        return 0.0
    correct_bits = sum(expected == observed for expected, observed in pairs)
    return correct_bits / len(pairs)
def watermark(model, trigger_set):
    """Watermark ``model`` with ``trigger_set`` and offer the results for download.

    Work in progress: the watermarking itself is not implemented. The function
    shows the "Not implemented yet" warning, makes sure placeholder output files
    exist in SERVER_DIR, and renders one download button per file.

    Args:
        model: The model to watermark (currently unused).
        trigger_set: The trigger set (currently unused).
    """
    todo()

    model_file_path = SERVER_DIR / "watermarked_model"
    trigger_set_file_path = SERVER_DIR / "trigger_set"

    # TODO: remove once model correctly watermarked (placeholder empty files).
    SERVER_DIR.mkdir(parents=True, exist_ok=True)
    model_file_path.touch()
    trigger_set_file_path.touch()

    # Once the model is watermarked and dumped to files (model + trigger set),
    # the user can download them under their on-disk names.
    st.download_button(
        label="Download the watermarked file",
        data=model_file_path.read_bytes(),
        file_name=model_file_path.name,
        mime="application/octet-stream",
    )
    st.download_button(
        label="Download the trigger set",
        data=trigger_set_file_path.read_bytes(),
        file_name=trigger_set_file_path.name,
        mime="application/octet-stream",
    )
# --- Client configuration: choose an id and generate the FHE keys for it. ---
st.header("Client Configuration", divider=True)
client_id = st.text_input("Identification string", "team-8-uuid")
keys_requested = st.button("Generate keys")
if keys_requested:
    key_gen_fn(client_id)

# --- Model watermarking: upload an encrypted model and watermark it. ---
st.header("Model Watermarking", divider=True)
# NOTE(review): the uploaded file is not passed to watermark() yet.
encrypted_model = st.file_uploader("Upload your encrypted model")
watermark_requested = st.button("Start Watermarking")
if watermark_requested:
    watermark(None, None)

# --- Watermarking verification: section header only, no widgets yet. ---
st.header("Watermarking Verification", divider=True)
st.header("Update Blockchain", divider=True)


def _build_mock_block():
    """Assemble a placeholder block recording the model and trigger-set hashes.

    All hashes come from ``generate_mock_hash``; the block number is fixed at 42
    and the timestamp is the current time in milliseconds.
    """
    previous_hash = generate_mock_hash()
    timestamp = int(time.time() * 1000)  # milliseconds since the epoch
    model_hash = generate_mock_hash()
    trigger_hash = generate_mock_hash()
    return {
        "blockNumber": 42,
        "previousHash": previous_hash,
        "timestamp": timestamp,
        "transactions": [
            {"type": "Watermarked Model Hash", "hash": model_hash},
            {"type": "Trigger Set Hash", "hash": trigger_hash},
        ],
    }


# Session state survives Streamlit reruns; create the slot only once.
if "block_data" not in st.session_state:
    st.session_state.block_data = None

if st.button("Update Blockchain"):
    st.session_state.block_data = _build_mock_block()
    st.success("Blockchain updated successfully!")

# Render the most recent block, if one has been produced, as pretty JSON.
latest_block = st.session_state.block_data
if latest_block:
    st.subheader("Latest Block Data (JSON)")
    st.code(json.dumps(latest_block, indent=2), language="json")