| |
| |
import base64
import json
import os
import struct
import uuid

# Third-party imports keep their original relative order in case the native
# libraries (torch / ezkl / concrete) are sensitive to load order.
import numpy as np
from torch import nn
import ezkl
import torch
from concrete.ml.deployment import FHEModelServer
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from config import private_key, rpc_url
|
|
# FastAPI application that exposes the proof-generation endpoint.
app = FastAPI()

# Decoded FHE evaluation key from the most recent request. AIModel.forward
# reads this module-level global; get_zk_proof assigns it before running the
# model. It stays None until the first request arrives.
evaluation_key = None
|
|
|
|
| |
class AIModel(nn.Module):
    """``nn.Module`` wrapper around the Concrete-ML FHE sentiment model server.

    Wrapping the FHE call in a module lets it be invoked like a regular torch
    model and passed to ``torch.onnx.export``. The evaluation key is not an
    argument of ``forward``; it is read from the module-level
    ``evaluation_key`` global, which the request handler assigns first.
    """

    def __init__(self):
        super().__init__()
        # Relative path: it depends on the working directory the server is
        # launched from.
        self.fhe_model = FHEModelServer("../deployment/sentiment_fhe_model")

    def forward(self, x):
        """Run encrypted inference on the first row of *x*.

        Args:
            x: uint8 tensor whose element ``x[0]`` holds the serialized
                ciphertext bytes (assumed shape ``(1, n)`` — TODO confirm).

        Returns:
            1-D uint8 tensor holding the byte values of the FHE prediction.

        Raises:
            RuntimeError: if the module-level ``evaluation_key`` is still None.
        """
        # Fail fast with a clear message instead of handing None to the FHE
        # server and getting an obscure error from deep inside it.
        if evaluation_key is None:
            raise RuntimeError("evaluation_key must be set before running AIModel")

        ciphertext = x[0].numpy().tobytes()
        prediction = self.fhe_model.run(ciphertext, evaluation_key)

        # `prediction` is iterated as a sequence of byte values (presumably a
        # bytes object) and exposed as a uint8 tensor.
        # NOTE(review): under torch.onnx.export tracing, `.numpy()` and
        # `torch.tensor(list(...))` are not recorded as graph ops, so the
        # exported graph presumably embeds this prediction as a constant —
        # confirm that is the statement the ZK proof is meant to cover.
        return torch.tensor(list(prediction), dtype=torch.uint8)
|
|
|
|
class ZKProofRequest(BaseModel):
    """JSON body accepted by ``POST /get_zk_proof``.

    Both fields carry base64 text; the endpoint decodes them to raw bytes
    before use.
    """

    # Base64-encoded encrypted input; once decoded it is handed to the FHE model.
    encrypted_encoding: str
    # Base64-encoded FHE evaluation key; once decoded it is passed to
    # FHEModelServer.run together with the encrypted input.
    evaluation_key: str
|
|
|
|
# One shared model instance, created at import time. Constructing it builds the
# FHEModelServer, so every request reuses this single object.
circuit = AIModel()
|
|
|
|
def _decode_b64_field(value, field_name):
    """Decode one base64 request field to bytes.

    Malformed client input (bad padding, non-ASCII characters) is reported as
    HTTP 400 rather than escaping as an unhandled 500.

    Raises:
        HTTPException: status 400 when *value* cannot be base64-decoded.
    """
    try:
        return base64.b64decode(value)
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise HTTPException(
            status_code=400, detail=f"{field_name} is not valid base64"
        ) from err


def _require(condition, message):
    """Raise ``RuntimeError(message)`` when *condition* is false.

    Used instead of ``assert`` so the pipeline checks still run under
    ``python -O``.
    """
    if not condition:
        raise RuntimeError(message)


def _write_json(path, payload):
    """Write *payload* to *path* as JSON and close the file before returning.

    ezkl reads these files right after they are written, so the handle is
    closed deterministically instead of being left to the garbage collector.
    """
    with open(path, "w") as f:
        json.dump(payload, f)


def _artifact_paths(folder_path):
    """Map each pipeline artifact key to its file path inside *folder_path*."""
    filenames = {
        "model": "network.onnx",
        "compiled_model": "network.compiled",
        "pk": "test.pk",
        "vk": "test.vk",
        "settings": "settings.json",
        "witness": "witness.json",
        "input_data": "input.json",
        "calibration": "calibration.json",
        "srs": "kzg14.srs",
        "output": "output.json",
        "proof": "test.pf",
        "verify_sol_code": "verify.sol",
        "verify_sol_abi": "verify.abi",
        "verify_contract_addr": "addr.txt",
    }
    return {key: os.path.join(folder_path, name) for key, name in filenames.items()}


async def _generate_and_verify_proof(paths):
    """Build the EZKL circuit for the exported model, prove it, and verify locally.

    Expects ``paths["model"]``, ``paths["input_data"]`` and
    ``paths["calibration"]`` to exist already; writes settings, compiled
    circuit, SRS, witness, keys and proof next to them.

    Raises:
        RuntimeError: if an EZKL step reports failure or omits its output file.
    """
    py_run_args = ezkl.PyRunArgs()
    py_run_args.input_visibility = "public"
    py_run_args.output_visibility = "public"
    py_run_args.param_visibility = "fixed"

    res = ezkl.gen_settings(paths["model"], paths["settings"], py_run_args=py_run_args)
    _require(res is True, "ezkl.gen_settings failed")

    await ezkl.calibrate_settings(
        paths["calibration"], paths["model"], paths["settings"], "resources"
    )

    res = ezkl.compile_circuit(paths["model"], paths["compiled_model"], paths["settings"])
    _require(res is True, "ezkl.compile_circuit failed")

    res = await ezkl.get_srs(paths["settings"], srs_path=paths["srs"])
    _require(res is True, "ezkl.get_srs failed")

    await ezkl.gen_witness(paths["input_data"], paths["compiled_model"], paths["witness"])
    _require(os.path.isfile(paths["witness"]), "ezkl.gen_witness produced no witness file")

    res = ezkl.setup(paths["compiled_model"], paths["vk"], paths["pk"], paths["srs"])
    _require(res is True, "ezkl.setup failed")
    for key in ("vk", "pk", "settings"):
        _require(os.path.isfile(paths[key]), f"{key} file missing after ezkl.setup")

    ezkl.prove(
        paths["witness"],
        paths["compiled_model"],
        paths["pk"],
        paths["proof"],
        "single",
        paths["srs"],
    )
    _require(os.path.isfile(paths["proof"]), "ezkl.prove produced no proof file")

    res = ezkl.verify(paths["proof"], paths["settings"], paths["vk"], paths["srs"])
    _require(res is True, "local proof verification failed")
    print("verified on local")


async def _deploy_evm_verifier(paths):
    """Generate the Solidity verifier, deploy it, and return its address.

    Returns:
        The address text from ``paths["verify_contract_addr"]`` with
        surrounding whitespace stripped, or None if deployment wrote no
        address file.

    Raises:
        RuntimeError: if ezkl fails to generate the verifier contract.
    """
    res = await ezkl.create_evm_verifier(
        paths["vk"],
        paths["settings"],
        paths["verify_sol_code"],
        paths["verify_sol_abi"],
        paths["srs"],
    )
    _require(res is True, "ezkl.create_evm_verifier failed")

    await ezkl.deploy_evm(
        addr_path=paths["verify_contract_addr"],
        rpc_url=rpc_url,
        private_key=private_key,
        sol_code_path=paths["verify_sol_code"],
    )
    try:
        with open(paths["verify_contract_addr"], "r") as file:
            return file.read().strip()
    except FileNotFoundError:
        return None


@app.post("/get_zk_proof")
async def get_zk_proof(request: ZKProofRequest):
    """Run FHE inference on the request and return a ZK proof of it.

    Decodes the base64 fields and runs ``circuit`` on the ciphertext. The
    module is then exported to ONNX, and an EZKL circuit is built, proved and
    verified locally. Finally an EVM verifier contract is deployed. All
    artifacts are written to a fresh ``zkml_encrypted/<uuid>`` folder.

    Returns:
        Dict with the hex-encoded model output (first 1000 chars), its file
        path, the base64 proof (first 500 chars), the proof file path and the
        deployed verifier address; or ``{"error": ...}`` when deployment left
        no contract address file.

    Raises:
        HTTPException: 400 if either request field is not valid base64.
        RuntimeError: if a step of the EZKL pipeline fails.
    """
    global evaluation_key

    # Decode into locals instead of overwriting the str-typed model fields.
    encrypted_encoding = _decode_b64_field(request.encrypted_encoding, "encrypted_encoding")
    decoded_key = _decode_b64_field(request.evaluation_key, "evaluation_key")

    # AIModel.forward reads this global. There is no `await` between this
    # assignment and the last forward pass (the ONNX export tracing below), so
    # another request on the same event loop cannot swap the key mid-inference.
    evaluation_key = decoded_key

    folder_path = os.path.join("zkml_encrypted", str(uuid.uuid4()))
    os.makedirs(folder_path, exist_ok=True)
    paths = _artifact_paths(folder_path)

    # Batch of one: AIModel.forward reads row 0 as the ciphertext bytes.
    x = torch.tensor([encrypted_encoding], dtype=torch.uint8)

    circuit.eval()
    with torch.no_grad():
        output = circuit(x)
    output_data = output.detach().numpy().tolist()
    _write_json(paths["output"], output_data)

    torch.onnx.export(
        circuit,
        x,
        paths["model"],
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    )

    # The same single sample serves as both the proof input and calibration data.
    data = dict(input_data=x.tolist())
    _write_json(paths["input_data"], data)
    _write_json(paths["calibration"], data)

    # NOTE(review): several ezkl calls (compile_circuit, setup, prove, verify)
    # are synchronous and CPU-heavy, so they block the event loop while running.
    await _generate_and_verify_proof(paths)

    verify_contract_addr = await _deploy_evm_verifier(paths)
    if verify_contract_addr is None:
        print(f"error: File {paths['verify_contract_addr']} does not exist.")
        return {"error": "Contract address file not found"}

    with open(paths["proof"], "rb") as f:
        proof_content = base64.b64encode(f.read()).decode("utf-8")

    return {
        "output": array_to_hex_string(output_data)[:1000],
        "output_path": paths["output"],
        "proof": proof_content[:500],
        "proof_path": paths["proof"],
        "verify_contract_addr": verify_contract_addr,
    }
|
|
|
|
def array_to_hex_string(array):
    """Concatenate the lowercase hex form of every integer in *array*.

    Each value is zero-padded to at least two digits. Values above 255 take as
    many digits as they need, because ``02x`` only sets a minimum width.

    >>> array_to_hex_string([0, 15, 255])
    '000fff'
    """
    return "".join(map("{:02x}".format, array))
|
|