|
import gradio as gr |
|
from PIL import Image, ImageDraw, ImageFont, ImageOps |
|
import base64 |
|
import io |
|
import json |
|
import logging |
|
import os |
|
import requests |
|
import struct |
|
import tempfile |
|
import numpy as np |
|
from cryptography.hazmat.primitives import serialization |
|
from cryptography.hazmat.primitives.asymmetric import rsa, padding |
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM |
|
from cryptography.hazmat.primitives import hashes |
|
from cryptography.exceptions import InvalidTag |
|
from gradio_client import Client |
|
from huggingface_hub import InferenceClient |
|
|
|
# Root logging: INFO and above, each record stamped with time, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-scoped logger shared by the rest of this file.
logger = logging.getLogger(name=__name__)
|
|
|
# JSON file that persists the list of saved client endpoints.
ENDPOINTS_FILE = "endpoints.json"

# Framing constants for the LSB steganography and the hybrid encryption.
HEADER_BITS = 32  # bit width of the length prefix read back from the image LSBs
AES_GCM_NONCE_SIZE = 12  # nonce bytes stored right after the RSA-wrapped AES key

# Embedded-server key material; the key-loading code below fills these in.
KEYLOCK_PRIV_KEY_PEM = os.getenv('KEYLOCK_PRIV_KEY')
PRIVATE_KEY_OBJECT = None
PUBLIC_KEY_PEM_STRING = ""
KEYLOCK_STATUS_MESSAGE = ""

# Demo credential store consulted by the auth endpoint: API key -> user record.
MOCK_USER_DATABASE = {
    "sk-12345-abcde": {"user": "demo-user", "permissions": "read"},
    "sk-67890-fghij": {"user": "admin-user", "permissions": "read,write,delete"},
}
|
|
|
# Bootstrap the embedded server's RSA key pair.
# Without the KEYLOCK_PRIV_KEY secret, a throwaway 2048-bit key is generated, so
# images produced for this server stop decrypting after a restart.
if not KEYLOCK_PRIV_KEY_PEM:
    logger.warning("No KEYLOCK_PRIV_KEY secret found. Generating a temporary key pair for this session.")
    temp_priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    KEYLOCK_PRIV_KEY_PEM = temp_priv_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode('utf-8')
    KEYLOCK_STATUS_MESSAGE = "⚠️ No secret found. Using a temporary key for this session. Keys will be lost on restart."
else:
    # The key is only *found* here; whether it parses is decided in the try below.
    logger.info("Found private key in environment variable 'KEYLOCK_PRIV_KEY'.")
    KEYLOCK_STATUS_MESSAGE = "✅ Loaded successfully from secrets/environment variable."

try:
    PRIVATE_KEY_OBJECT = serialization.load_pem_private_key(KEYLOCK_PRIV_KEY_PEM.encode(), password=None)
    # The auth endpoint unwraps the AES key with RSA-OAEP, so any other key type
    # (EC, Ed25519, ...) would only fail later at decrypt time; reject it up front.
    if not isinstance(PRIVATE_KEY_OBJECT, rsa.RSAPrivateKey):
        raise TypeError(f"expected an RSA private key, got {type(PRIVATE_KEY_OBJECT).__name__}")
    PUBLIC_KEY_PEM_STRING = PRIVATE_KEY_OBJECT.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    ).decode('utf-8')
    KEYLOCK_STATUS_MESSAGE += "\n✅ Public key derived successfully."
except Exception as e:
    # Best-effort: the UI still starts, but server-side operations refuse to run
    # because PRIVATE_KEY_OBJECT is None.
    logger.error("Failed to parse KEYLOCK private key: %s", e)
    PRIVATE_KEY_OBJECT = None
    PUBLIC_KEY_PEM_STRING = "Error: Key could not be processed."
    KEYLOCK_STATUS_MESSAGE += f"\n❌ Failed to parse key: {e}"
|
|
|
def _parse_secret_data(secret_data_str: str) -> dict: |
|
stripped_input = secret_data_str.strip() |
|
try: |
|
data_dict = json.loads(stripped_input) |
|
if isinstance(data_dict, dict): return data_dict |
|
except json.JSONDecodeError: |
|
pass |
|
data_dict = {} |
|
for line in stripped_input.splitlines(): |
|
line = line.strip() |
|
if not line or line.startswith('#'): continue |
|
separator = ':' if ':' in line else '=' |
|
if separator not in line: continue |
|
parts = line.split(separator, 1) |
|
if len(parts) == 2: |
|
key = parts[0].strip().strip("'\"") |
|
value = parts[1].strip().strip(",").strip().strip("'\"") |
|
if key: data_dict[key] = value |
|
return data_dict |
|
|
|
def prepare_base_image(uploaded_image: Image.Image | None, progress) -> Image.Image:
    """Return a 600x600 image to carry the encrypted payload.

    Sources are tried in order:
    1. the user's uploaded image;
    2. a fixed Unsplash background;
    3. an AI-generated image from the HF Inference API.

    Args:
        uploaded_image: Optional user-supplied image.
        progress: Gradio progress callback, called as ``progress(0, desc=...)``.

    Returns:
        The chosen image, center-cropped and resized to 600x600 with LANCZOS.

    Raises:
        gr.Error: if neither the download nor the AI generation succeeds.
    """
    size = 600
    if uploaded_image is not None:
        progress(0, desc="✅ Using uploaded image...")
        return ImageOps.fit(uploaded_image, (size, size), Image.Resampling.LANCZOS)

    try:
        progress(0, desc="⏳ Fetching default background...")
        # Unsplash (imgix) query parameters must be joined with '&'; the previous
        # "auto=format=fit=crop" sent a single malformed `auto` value.
        response = requests.get(
            "https://images.unsplash.com/photo-1506318137071-a8e063b4bec0?q=80&w=1200&auto=format&fit=crop",
            timeout=10,
        )
        response.raise_for_status()
        img = Image.open(io.BytesIO(response.content)).convert("RGB")
        return ImageOps.fit(img, (size, size), Image.Resampling.LANCZOS)
    except Exception as e:
        # Best-effort: any network/decoding problem just moves on to the AI fallback.
        logger.warning("Default image fetch failed: %s. Falling back to AI.", e)

    try:
        progress(0, desc="⏳ Generating image with SDXL-Lightning...")
        client = InferenceClient()
        image_bytes = client.text_to_image(
            "A stunning view of a distant galaxy, nebulae, and constellations, digital art, vibrant colors",
            model="sd-community/sdxl-lightning",
        )
        generated = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        return ImageOps.fit(generated, (size, size), Image.Resampling.LANCZOS)
    except Exception as e:
        raise gr.Error(f"All image sources failed. AI error: {e}") from e
|
|
|
def _keylock_encrypt_payload(payload_dict: dict, public_key_pem: str) -> bytes:
    """Hybrid-encrypt ``payload_dict`` (as JSON) for the holder of ``public_key_pem``.

    Encryption uses AES-256-GCM for the data and RSA-OAEP(SHA-256) to wrap the AES key.
    Returns ``len(wrapped_key)`` (4 bytes, big-endian) | wrapped_key | nonce | ciphertext,
    which is the layout ``api_decode_and_auth`` parses.
    """
    json_bytes = json.dumps(payload_dict).encode('utf-8')
    public_key = serialization.load_pem_public_key(public_key_pem.encode('utf-8'))
    aes_key = os.urandom(32)
    nonce = os.urandom(AES_GCM_NONCE_SIZE)  # same constant the decoder uses to slice the nonce
    ciphertext = AESGCM(aes_key).encrypt(nonce, json_bytes, None)
    rsa_encrypted_key = public_key.encrypt(
        aes_key,
        padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
    )
    return struct.pack('>I', len(rsa_encrypted_key)) + rsa_encrypted_key + nonce + ciphertext


def _keylock_draw_overlay(img: Image.Image, payload_dict: dict, overlay_option: str, server_url: str) -> None:
    """Draw the title banner and, unless ``overlay_option`` is "None", a label box onto ``img`` in place."""
    width, height = img.size
    draw = ImageDraw.Draw(img, "RGBA")
    try:
        font_bold = ImageFont.truetype("DejaVuSans-Bold.ttf", 30)
        font_regular = ImageFont.truetype("DejaVuSans.ttf", 15)
        font_small = ImageFont.truetype("DejaVuSans.ttf", 12)
    except IOError:
        # DejaVu fonts are not installed: fall back to Pillow's built-in font.
        font_bold = ImageFont.load_default(size=28)
        font_regular = ImageFont.load_default(size=14)
        font_small = ImageFont.load_default(size=12)

    overlay_color = (15, 23, 42, 190)  # translucent; blended because the draw mode is RGBA
    title_color = (226, 232, 240)
    key_color = (148, 163, 184)
    value_color = (241, 245, 249)

    draw.rectangle([0, 20, width, 100], fill=overlay_color)
    draw.text((width / 2, 45), "KeyLock Secure Data", fill=title_color, font=font_bold, anchor="ms")
    draw.text((width / 2, 75), server_url.replace("https://", ""), fill=key_color, font=font_small, anchor="ms")

    # Nothing to label: skip the bottom box instead of drawing an empty one.
    if overlay_option == "None" or not payload_dict:
        return

    keys_only = overlay_option == "Keys Only"
    line_spacing = 6
    lines = [str(k) if keys_only else f"{k}: {v}" for k, v in payload_dict.items()]
    line_heights = [draw.textbbox((0, 0), line, font=font_regular)[3] for line in lines]
    # Anchor the box to the bottom edge: 20px margin, 15px inner padding top and bottom.
    box_y0 = height - (sum(line_heights) + (len(lines) - 1) * line_spacing + 30) - 20
    draw.rectangle([20, box_y0, width - 20, height - 20], fill=overlay_color)

    current_y = box_y0 + 15
    for (key, value), line_height in zip(payload_dict.items(), line_heights):
        if keys_only:
            draw.text((35, current_y), str(key), fill=key_color, font=font_regular)
        else:
            key_text = f"{key}:"
            draw.text((35, current_y), key_text, fill=key_color, font=font_regular)
            key_bbox = draw.textbbox((35, current_y), key_text, font=font_regular)
            draw.text((key_bbox[2] + 8, current_y), str(value), fill=value_color, font=font_regular)
        current_y += line_height + line_spacing


def _keylock_embed_bits(img: Image.Image, data: bytes) -> Image.Image:
    """Hide a 4-byte big-endian length prefix plus ``data`` in the LSBs of ``img``'s RGB bytes.

    Raises:
        ValueError: if the framed data needs more bits than the image has channel bytes.
    """
    width, height = img.size
    pixel_data = np.array(img, dtype=np.uint8).ravel()
    framed = struct.pack('>I', len(data)) + data
    # unpackbits emits bits MSB-first, the order the decoder reassembles them in.
    bits = np.unpackbits(np.frombuffer(framed, dtype=np.uint8))
    if bits.size > pixel_data.size:
        raise ValueError("Payload is too large for the image.")
    pixel_data[:bits.size] = (pixel_data[:bits.size] & 0xFE) | bits
    return Image.fromarray(pixel_data.reshape((height, width, 3)), 'RGB')


def create_encrypted_image(payload_dict: dict, public_key_pem: str, base_image: Image.Image, overlay_option: str, server_url: str) -> Image.Image:
    """Build a KeyLock image: visible labels plus the encrypted payload hidden in pixel LSBs.

    Args:
        payload_dict: Secret key/value data to encrypt (serialised as JSON).
        public_key_pem: PEM-encoded RSA public key of the target server.
        base_image: Background image. It is copied and converted to RGB, not modified.
        overlay_option: "Keys and Values", "Keys Only" or "None"; controls the label box.
        server_url: Shown under the title, with any "https://" prefix removed.

    Returns:
        A new RGB image. It must be saved losslessly (e.g. PNG) to keep the payload intact.

    Raises:
        ValueError: if the public key cannot be parsed or the payload does not fit.
    """
    encrypted_payload = _keylock_encrypt_payload(payload_dict, public_key_pem)
    img = base_image.copy().convert("RGB")
    # Labels are drawn before embedding so that the LSB data is written last.
    _keylock_draw_overlay(img, payload_dict, overlay_option, server_url)
    return _keylock_embed_bits(img, encrypted_payload)
|
|
|
def api_get_info():
    """Describe the embedded KeyLock server and the payload keys it expects (served as /keylock-info)."""
    required_keys = [
        {"key_name": "API_KEY", "description": "Your unique API Key.", "example": "sk-12345-abcde"},
        {"key_name": "USER", "description": "The user ID for the key.", "example": "demo-user"},
    ]
    return {
        "name": "Embedded KeyLock Server",
        "version": "2.1",
        "documentation": "This server can generate and authenticate KeyLock images.",
        "required_payload_keys": required_keys,
    }
|
|
|
def api_get_public_key():
    """Return the embedded server's PEM public key, or its error placeholder text (served as /keylock-pub)."""
    pem_text = PUBLIC_KEY_PEM_STRING
    return pem_text
|
|
|
def api_decode_and_auth(image_base64_string: str) -> dict:
    """Extract, decrypt and authenticate the KeyLock payload hidden in an image.

    The image LSBs hold a ``HEADER_BITS``-bit big-endian byte count, followed by that
    many bytes laid out as ``len(wrapped_key)`` (4 bytes) | wrapped_key | nonce |
    AES-GCM ciphertext. The decrypted JSON object is checked against
    ``MOCK_USER_DATABASE``.

    Args:
        image_base64_string: Base64 of a losslessly encoded image file (e.g. PNG).

    Returns:
        ``{"status": "Success" | "Failed" | "Error", "message": str, "details": dict}``.

    Raises:
        gr.Error: if the server has no usable private key.
    """
    if PRIVATE_KEY_OBJECT is None:
        raise gr.Error("Server is not configured with a private key.")
    try:
        image = Image.open(io.BytesIO(base64.b64decode(image_base64_string))).convert("RGB")
        lsb_bits = np.array(image, dtype=np.uint8).ravel() & 1
        if lsb_bits.size < HEADER_BITS:
            raise ValueError("Image is too small to contain a KeyLock header.")
        data_length = int("".join(str(bit) for bit in lsb_bits[:HEADER_BITS].tolist()), 2)

        payload_bit_count = data_length * 8
        # The length comes from untrusted pixels. Reject values the image cannot hold
        # rather than allocating a zero-padded buffer of up to 4 GiB.
        if data_length == 0 or HEADER_BITS + payload_bit_count > lsb_bits.size:
            raise ValueError(f"Declared payload length {data_length} does not fit in the image.")
        crypto_payload = np.packbits(lsb_bits[HEADER_BITS:HEADER_BITS + payload_bit_count]).tobytes()

        offset = 4
        encrypted_aes_key_len = struct.unpack('>I', crypto_payload[:offset])[0]
        encrypted_aes_key = crypto_payload[offset:offset + encrypted_aes_key_len]
        offset += encrypted_aes_key_len
        nonce = crypto_payload[offset:offset + AES_GCM_NONCE_SIZE]
        offset += AES_GCM_NONCE_SIZE
        ciphertext_with_tag = crypto_payload[offset:]

        recovered_aes_key = PRIVATE_KEY_OBJECT.decrypt(
            encrypted_aes_key,
            padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
        )
        plaintext = AESGCM(recovered_aes_key).decrypt(nonce, ciphertext_with_tag, None)
        decrypted_payload = json.loads(plaintext.decode('utf-8'))
        if not isinstance(decrypted_payload, dict):
            raise ValueError("Decrypted payload is not a JSON object.")

        # NOTE(review): "details" echoes the decrypted secrets to any caller who submits
        # the image. Confirm this is intended beyond the demo.
        db_entry = MOCK_USER_DATABASE.get(decrypted_payload.get('API_KEY'))
        if db_entry and db_entry.get("user") == decrypted_payload.get('USER'):
            return {"status": "Success", "message": f"User '{decrypted_payload.get('USER')}' authenticated.", "details": decrypted_payload}
        return {"status": "Failed", "message": "Invalid credentials.", "details": decrypted_payload}
    except Exception as e:
        # API boundary: report every decode/decrypt failure as a structured error.
        return {"status": "Error", "message": f"Decryption/Processing Failed: {e}", "details": {}}
|
|
|
def load_endpoints(): |
|
try: |
|
if os.path.exists(ENDPOINTS_FILE): |
|
with open(ENDPOINTS_FILE, 'r') as f: |
|
return json.load(f) |
|
except (FileNotFoundError, json.JSONDecodeError): |
|
pass |
|
return [] |
|
def save_endpoints(endpoints_list): |
|
with open(ENDPOINTS_FILE, 'w') as f: |
|
json.dump(endpoints_list, f, indent=2) |
|
|
|
theme = gr.themes.Soft(primary_hue="sky", secondary_hue="blue", neutral_hue="slate")

with gr.Blocks(theme=theme, title="KeyLock Showcase") as demo:
    # Per-session state: all saved servers (seeded from ENDPOINTS_FILE) and the active one.
    all_servers_state = gr.State(value=load_endpoints())
    active_server_state = gr.State({})

    # Example payload for the embedded server, shown read-only and pre-filled for generation.
    embedded_example_payload = {k['key_name']: k['example'] for k in api_get_info()["required_payload_keys"]}

    gr.Markdown("# 🔐 KeyLock Showcase")
    gr.Markdown("A comprehensive toolkit for generating and testing KeyLock authentication images against live servers.")
    gr.Markdown("This server's address: https://agents-mcp-hackathon-keylock-auth-system.hf.space")

    with gr.Tabs():
        with gr.TabItem("Client Operations"):
            gr.Markdown("### 1. Connect to a Target Server")
            with gr.Row():
                saved_servers_dropdown = gr.Dropdown(label="Load Saved Server", interactive=True)
                with gr.Column():
                    server_url_input = gr.Textbox(label="Or Add New Server by URL", placeholder="https://your-server.hf.space")
                    connect_button = gr.Button("Connect New Server", variant="primary")
            client_status_display = gr.Markdown("**Status:** Not Connected")

            with gr.Accordion("2. Create an Encrypted Image for the Connected Server", open=False) as client_generate_accordion:
                with gr.Row():
                    with gr.Column(scale=2):
                        client_payload_input = gr.Textbox(label="Secret Data (key:value or JSON)", lines=5)
                        client_overlay_radio = gr.Radio(label="Show Labels on Image", choices=["Keys and Values", "Keys Only", "None"], value="Keys and Values")
                        client_base_image_input = gr.Image(label="Optional Base Image", type="pil", height=200)
                        client_generate_button = gr.Button("Create Image", variant="secondary")
                    with gr.Column(scale=3):
                        client_generated_image_preview = gr.Image(label="Generated Image Preview", interactive=False)
                        client_generated_file_output = gr.File(label="Download Uncorrupted PNG", interactive=False, file_count="single")

            with gr.Accordion("3. Test an Existing Image with the Connected Server", open=False) as client_test_accordion:
                client_test_image_input = gr.Image(type="filepath", label="Upload Encrypted Image")
                client_auth_status_display = gr.Markdown(visible=False)
                client_auth_result_output = gr.JSON(label="Server Authentication Response")
                # Hide any previous result whenever the test image changes or is cleared.
                client_test_image_input.change(lambda: (gr.update(visible=False), None), outputs=[client_auth_status_display, client_auth_result_output])

        with gr.TabItem("Server Showcase & Admin"):
            gr.Markdown("## Embedded Server Details")
            gr.Markdown("This Gradio app is also running its own KeyLock server. You can use its details to test the client.")
            gr.Textbox(label="Embedded Server Status", value=KEYLOCK_STATUS_MESSAGE, interactive=False, lines=3)
            # PEM text is not Python source; language=None avoids bogus syntax highlighting.
            gr.Code(label="Embedded Server Public Key", value=PUBLIC_KEY_PEM_STRING, language=None)
            gr.JSON(label="Embedded Server Required Payload", value=dict(embedded_example_payload))

            with gr.Accordion("Generate Image with Embedded Server", open=False):
                server_payload_input = gr.JSON(label="Payload to Encrypt", value=dict(embedded_example_payload))
                server_generate_button = gr.Button("Generate Image", variant="secondary")
                server_generated_file_output = gr.File(label="Download Uncorrupted PNG", interactive=False, file_count="single")

            with gr.Accordion("Admin: Generate New Key Pair", open=False):
                gen_keys_button = gr.Button("⚙️ Generate New 2048-bit Key Pair")
                with gr.Row():
                    output_private_key = gr.Textbox(lines=8, label="Generated Private Key", interactive=False, show_copy_button=True)
                    output_public_key = gr.Textbox(lines=8, label="Generated Public Key", interactive=False, show_copy_button=True)

    def save_png_to_tempfile(img):
        """Save ``img`` losslessly to a new temporary .png file and return its path."""
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            output_path = f.name
        # Write after the handle is closed; reopening an open temp file fails on Windows.
        img.save(output_path, "PNG", compress_level=1)
        return output_path

    def initialize_ui(all_servers_list):
        """Fill the saved-server dropdown with the names in the session's server list."""
        return gr.update(choices=[s['name'] for s in all_servers_list] if all_servers_list else [])

    def process_server_connection(server_data, all_servers):
        """Make ``server_data`` the active server and return the matching UI updates.

        ``server_data`` has keys ``name``, ``url``, ``pubkey`` and ``info``. The payload
        textbox placeholder is built from ``info["required_payload_keys"]``.
        """
        placeholder = "\n".join(f"{k['key_name']}: {k['example']}" for k in server_data['info'].get('required_payload_keys', []))
        status_md = f"**Status:** ✅ Connected to **{server_data['name']}**"
        return {
            active_server_state: server_data,
            client_status_display: status_md,
            client_payload_input: gr.update(placeholder=placeholder),
            all_servers_state: all_servers,
        }

    def load_server_from_dropdown(server_name, all_servers):
        """Activate the saved server named ``server_name``. Returns no updates if none matches."""
        server_from_list = next((s for s in all_servers if s['name'] == server_name), None)
        if server_from_list is None:
            return {}
        active_server_data = {
            'name': server_from_list['name'],
            'url': server_from_list['link'],
            'pubkey': server_from_list['public_key'],
            'info': server_from_list.get('info', {}),
        }
        return process_server_connection(active_server_data, all_servers)

    def add_new_server_from_url(url, all_servers):
        """Connect to a KeyLock server by URL, saving it to ENDPOINTS_FILE if it is new.

        Raises:
            gr.Error: if no URL is given, or querying /keylock-info, querying
                /keylock-pub or saving the list fails.
        """
        if not url:
            raise gr.Error("Please provide a server URL.")
        url = url.strip().rstrip('/')

        existing_server = next((s for s in all_servers if s['link'] == url), None)
        if existing_server:
            gr.Info(f"Server already exists. Loading '{existing_server['name']}'.")
            updates = load_server_from_dropdown(existing_server['name'], all_servers)
            updates[saved_servers_dropdown] = gr.update(value=existing_server['name'])
            return updates

        try:
            client = Client(url, verbose=False)
            info = client.predict(api_name="/keylock-info")
            pubkey = client.predict(api_name="/keylock-pub")
            server_name = info.get('name', url)
            server_for_list = {'name': server_name, 'link': url, 'public_key': pubkey, 'info': info}
            # Build a new list so the session state is unchanged if saving fails.
            updated_servers = [*all_servers, server_for_list]
            save_endpoints(updated_servers)
        except Exception as e:
            # gr.Error must be *raised* for Gradio to show it; constructing it does nothing.
            raise gr.Error(f"Connection Failed: {e}") from e

        gr.Info(f"Successfully added and saved '{server_name}'!")
        server_for_state = {'name': server_name, 'url': url, 'pubkey': pubkey, 'info': info}
        updates = process_server_connection(server_for_state, updated_servers)
        updates[saved_servers_dropdown] = gr.update(choices=[s['name'] for s in updated_servers], value=server_name)
        return updates

    def client_generate_image_wrapper(active_server, payload_str, overlay, base_img, progress=gr.Progress(track_tqdm=True)):
        """Encrypt the parsed payload for the active server and return (preview path, download path)."""
        if not active_server:
            raise gr.Error("Not connected to a server.")
        payload_dict = _parse_secret_data(payload_str)
        if not payload_dict:
            raise gr.Error("Invalid payload format. Please provide key:value pairs or a valid JSON object.")
        base_image = prepare_base_image(base_img, progress)
        try:
            img = create_encrypted_image(payload_dict, active_server['pubkey'], base_image, overlay, active_server['url'])
        except ValueError as e:
            # Unparseable stored public key, or a payload too large for the image.
            raise gr.Error(f"Image generation failed: {e}") from e
        output_path = save_png_to_tempfile(img)
        return output_path, output_path

    def client_authenticate_wrapper(active_server, image_path):
        """Send the uploaded image to the active server's /keylock-auth and show the verdict."""
        if not active_server:
            raise gr.Error("Not connected to a server.")
        if not image_path:
            raise gr.Error("Please upload an image.")
        try:
            with open(image_path, "rb") as f:
                b64_img = base64.b64encode(f.read()).decode('utf-8')
            client = Client(active_server['url'])
            response = client.predict(b64_img, api_name="/keylock-auth")
            status = response.get("status")
            if status == "Success":
                status_md = "### ✅ Authentication Successful"
            elif status == "Failed":
                status_md = "### ❌ Authentication Failed"
            else:
                status_md = f"### ⚠️ Server Error: {response.get('message')}"
            return gr.update(value=status_md, visible=True), response
        except Exception as e:
            # Show a toast and keep the inline status. A bare gr.Error(...) call displayed nothing.
            gr.Warning(f"Authentication request failed: {e}")
            return gr.update(value=f"### ⚠️ Request Error: {e}", visible=True), None

    def server_generate_image_wrapper(payload, progress=gr.Progress(track_tqdm=True)):
        """Encrypt ``payload`` for this app's embedded server and return the PNG path.

        ``progress`` is a default parameter because Gradio only injects a live
        progress tracker that way.
        """
        if PRIVATE_KEY_OBJECT is None:
            raise gr.Error("Server is not configured with a private key.")
        if not isinstance(payload, dict) or not payload:
            raise gr.Error("Payload must be a non-empty JSON object.")
        base_image = prepare_base_image(None, progress)
        try:
            img = create_encrypted_image(payload, PUBLIC_KEY_PEM_STRING, base_image, "Keys and Values", "Embedded Server")
        except ValueError as e:
            raise gr.Error(f"Image generation failed: {e}") from e
        return save_png_to_tempfile(img)

    def generate_pem_keys():
        """Return a fresh RSA-2048 pair as (PKCS#8 private PEM, SubjectPublicKeyInfo public PEM)."""
        pk = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        priv = pk.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption()).decode()
        pub = pk.public_key().public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo).decode()
        return priv, pub

    # Event wiring (must stay inside the Blocks context).
    demo.load(initialize_ui, inputs=all_servers_state, outputs=saved_servers_dropdown)
    saved_servers_dropdown.change(load_server_from_dropdown, inputs=[saved_servers_dropdown, all_servers_state], outputs=[active_server_state, client_status_display, client_payload_input, all_servers_state])
    connect_button.click(add_new_server_from_url, inputs=[server_url_input, all_servers_state], outputs=[active_server_state, client_status_display, client_payload_input, all_servers_state, saved_servers_dropdown])
    client_generate_button.click(client_generate_image_wrapper, inputs=[active_server_state, client_payload_input, client_overlay_radio, client_base_image_input], outputs=[client_generated_image_preview, client_generated_file_output])
    client_test_image_input.upload(client_authenticate_wrapper, inputs=[active_server_state, client_test_image_input], outputs=[client_auth_status_display, client_auth_result_output])
    server_generate_button.click(server_generate_image_wrapper, inputs=[server_payload_input], outputs=[server_generated_file_output])
    gen_keys_button.click(generate_pem_keys, outputs=[output_private_key, output_public_key])

    # Hidden interfaces that expose the embedded server's named API endpoints to other clients.
    with gr.Row(visible=False):
        gr.Interface(fn=api_get_info, inputs=None, outputs=gr.JSON(), api_name="keylock-info")
        gr.Interface(fn=api_get_public_key, inputs=None, outputs=gr.Textbox(), api_name="keylock-pub")
        gr.Interface(fn=api_decode_and_auth, inputs=gr.Textbox(), outputs=gr.JSON(), api_name="keylock-auth")
|
|
|
if __name__ == "__main__":
    # Serve the UI, including the hidden keylock-* API endpoints, when run as a script.
    # NOTE(review): mcp_server=True needs a Gradio release with MCP support; confirm the pinned version.
    demo.launch(mcp_server=True)