| | |
| | |
import base64
import json
import os
import struct
import uuid

import ezkl
import numpy as np
import torch
from concrete.ml.deployment import FHEModelServer
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from torch import nn
| |
|
# FastAPI application; the /get_zk_proof endpoint is registered on it.
app = FastAPI()

# Decoded evaluation key from the most recent /get_zk_proof request.
# AIModel.forward reads this module-level value directly, and get_zk_proof
# assigns it (via `global`) before running the model.
# NOTE(review): this is shared across all requests; it is only consistent while
# the handler assigns it and runs the model with no intervening await/thread switch.
evaluation_key = None
| |
|
| |
|
| | |
class AIModel(nn.Module):
    """Torch module wrapping a Concrete ML ``FHEModelServer``; used as the ezkl circuit.

    ``forward`` hands the raw bytes of its input's first row to the FHE server and
    returns the server's result bytes as a 1-D ``uint8`` tensor. The evaluation key
    is taken from the module-level ``evaluation_key`` global.
    """

    def __init__(self, model_dir="deployment/sentiment_fhe_model"):
        """Construct the FHE model server.

        Args:
            model_dir: directory passed to ``FHEModelServer``; defaults to the
                sentiment model deployment directory used previously.
        """
        super().__init__()
        self.fhe_model = FHEModelServer(model_dir)

    def forward(self, x):
        """Run the FHE model on the bytes stored in ``x[0]``.

        Args:
            x: ``uint8`` tensor whose first row holds the encrypted input bytes
               (presumably shape (1, n) as built by the request handler — confirm).

        Returns:
            1-D ``torch.uint8`` tensor containing the bytes returned by
            ``FHEModelServer.run``.

        Raises:
            RuntimeError: if the module-level ``evaluation_key`` has not been set.
        """
        print(f"forward input: {x}")

        # Fail with a clear message instead of passing None into the FHE server.
        if evaluation_key is None:
            raise RuntimeError("evaluation_key is not set; it must be assigned before running the model")

        # Row 0 carries the encrypted payload; uint8 elements map 1:1 back to bytes.
        encrypted_input = x[0].numpy().tobytes()
        prediction = self.fhe_model.run(encrypted_input, evaluation_key)
        print(f"forward prediction hex: {prediction.hex()}")

        byte_tensor = torch.tensor(list(prediction), dtype=torch.uint8)
        print(f"tensor_output: {byte_tensor}")

        return byte_tensor
| |
|
| |
|
class ZKProofRequest(BaseModel):
    # Request body for POST /get_zk_proof. Both fields arrive base64-encoded;
    # the get_zk_proof handler decodes them to bytes before use.
    encrypted_encoding: str  # base64 of the encrypted input bytes fed to the FHE model server
    evaluation_key: str  # base64 of the evaluation key passed to FHEModelServer.run
| |
|
| |
|
# Module-level AIModel shared by every request. Instantiating it here constructs
# the FHEModelServer once, when this module is imported.
circuit = AIModel()
| |
|
| |
|
# Default JSON-RPC endpoint for deploying/verifying the EVM verifier contract;
# override with the EZKL_RPC_URL environment variable.
_DEFAULT_RPC_URL = "http://103.231.86.33:10219"


def _require(condition, message):
    """Raise ``RuntimeError(message)`` unless ``condition`` is truthy.

    Used instead of ``assert`` so these checks are not stripped under ``python -O``.
    """
    if not condition:
        raise RuntimeError(message)


def _decode_b64_field(value, field_name):
    """Decode one base64 request field, turning malformed input into an HTTP 400."""
    try:
        return base64.b64decode(value)
    except ValueError as err:  # binascii.Error is a subclass of ValueError
        raise HTTPException(
            status_code=400, detail=f"{field_name} is not valid base64"
        ) from err


def _write_json(path, obj):
    """Write ``obj`` as JSON to ``path``, closing (and flushing) the file before returning."""
    with open(path, "w") as f:
        json.dump(obj, f)


@app.post("/get_zk_proof")
async def get_zk_proof(request: ZKProofRequest):
    """Evaluate the FHE model on an encrypted input, prove it with ezkl, and verify it.

    Steps:
      1. Decode the base64 fields and publish the evaluation key for ``AIModel``.
      2. Run ``circuit`` to obtain the output, then export it to ONNX.
      3. ezkl: gen settings, calibrate, compile, fetch SRS, gen witness, setup,
         prove, and verify the proof locally.
      4. Create and deploy an EVM verifier contract, then verify the proof on-chain.

    All artifacts are written to a fresh ``zkml_encrypted/<uuid4>`` folder, which is
    not removed afterwards.

    NOTE(review): the torch/ezkl calls that are not awaited run synchronously and
    block the event loop for their duration.

    Args:
        request: base64-encoded encrypted input and evaluation key.

    Returns:
        ``{"output", "proof", "verify_contract_addr"}`` on success, where ``proof``
        is the base64 of the proof file; or ``{"error": ...}`` if deployment did not
        produce a contract address file.

    Raises:
        HTTPException: 400 if either request field is not valid base64.
        RuntimeError: if an ezkl step reports failure or an expected artifact is missing.
    """
    global evaluation_key

    encrypted_encoding = _decode_b64_field(request.encrypted_encoding, "encrypted_encoding")
    decoded_key = _decode_b64_field(request.evaluation_key, "evaluation_key")

    # AIModel.forward reads the module-level key. Keep this assignment and both
    # model invocations below (eval run + ONNX tracing) free of any intervening
    # await so concurrent requests cannot swap the key in between.
    evaluation_key = decoded_key

    folder_path = os.path.join("zkml_encrypted", str(uuid.uuid4()))
    os.makedirs(folder_path, exist_ok=True)

    model_path = os.path.join(folder_path, "network.onnx")
    compiled_model_path = os.path.join(folder_path, "network.compiled")
    pk_path = os.path.join(folder_path, "test.pk")
    vk_path = os.path.join(folder_path, "test.vk")
    settings_path = os.path.join(folder_path, "settings.json")
    witness_path = os.path.join(folder_path, "witness.json")
    input_data_path = os.path.join(folder_path, "input.json")
    srs_path = os.path.join(folder_path, "kzg14.srs")
    output_path = os.path.join(folder_path, "output.json")
    cal_path = os.path.join(folder_path, "calibration.json")
    proof_path = os.path.join(folder_path, "test.pf")
    verify_sol_code_path = os.path.join(folder_path, "verify.sol")
    verify_sol_abi_path = os.path.join(folder_path, "verify.abi")
    verify_contract_addr_file = os.path.join(folder_path, "addr.txt")

    # A single row holding the raw encrypted bytes; AIModel.forward reads row 0.
    x = torch.tensor([list(encrypted_encoding)], dtype=torch.uint8)

    circuit.eval()
    with torch.no_grad():
        output = circuit(x)

    output_data = output.detach().numpy().tolist()
    _write_json(output_path, output_data)

    print("start")
    # Export traces circuit.forward again, i.e. the FHE server runs a second time.
    # NOTE(review): if FHE evaluation is not deterministic, the exported (proven)
    # output could differ from output_data above — confirm.
    torch.onnx.export(
        circuit,
        x,
        model_path,
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )
    print("end")

    data = dict(input_data=x.tolist())
    _write_json(input_data_path, data)

    py_run_args = ezkl.PyRunArgs()
    py_run_args.input_visibility = "public"
    py_run_args.output_visibility = "public"
    py_run_args.param_visibility = "fixed"

    res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)
    _require(res is True, "ezkl.gen_settings failed")

    # Calibrate on the same input that will be proven.
    _write_json(cal_path, data)
    await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")

    res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
    _require(res is True, "ezkl.compile_circuit failed")

    res = await ezkl.get_srs(settings_path, srs_path=srs_path)
    _require(res is True, "ezkl.get_srs failed")

    await ezkl.gen_witness(input_data_path, compiled_model_path, witness_path)
    _require(os.path.isfile(witness_path), "ezkl.gen_witness did not produce a witness file")

    res = ezkl.setup(compiled_model_path, vk_path, pk_path, srs_path)
    _require(res is True, "ezkl.setup failed")
    _require(os.path.isfile(vk_path), "verification key file missing after setup")
    _require(os.path.isfile(pk_path), "proving key file missing after setup")
    _require(os.path.isfile(settings_path), "settings file missing after setup")

    ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, "single", srs_path)
    _require(os.path.isfile(proof_path), "ezkl.prove did not produce a proof file")

    res = ezkl.verify(proof_path, settings_path, vk_path, srs_path)
    _require(res is True, "local proof verification failed")
    print("verified on local")

    res = await ezkl.create_evm_verifier(
        vk_path, settings_path, verify_sol_code_path, verify_sol_abi_path, srs_path
    )
    _require(res is True, "ezkl.create_evm_verifier failed")

    rpc_url = os.environ.get("EZKL_RPC_URL", _DEFAULT_RPC_URL)
    await ezkl.deploy_evm(
        addr_path=verify_contract_addr_file,
        rpc_url=rpc_url,
        sol_code_path=verify_sol_code_path,
    )
    if not os.path.exists(verify_contract_addr_file):
        print(f"error: File {verify_contract_addr_file} does not exist.")
        return {"error": "Contract address file not found"}
    with open(verify_contract_addr_file, "r") as file:
        # Strip so a trailing newline cannot corrupt the address.
        verify_contract_addr = file.read().strip()

    res = await ezkl.verify_evm(
        addr_verifier=verify_contract_addr,
        proof_path=proof_path,
        rpc_url=rpc_url,
    )
    _require(res is True, "on-chain proof verification failed")
    print("verified on chain")

    with open(proof_path, "rb") as f:
        proof_content = base64.b64encode(f.read()).decode("utf-8")

    return {"output": output_data, "proof": proof_content, "verify_contract_addr": verify_contract_addr}
| |
|