import base64
import json

import torch
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
import uvicorn

from stegno import generate, decrypt
from utils import load_model
from seed_scheme_factory import SeedSchemeFactory
from model_factory import ModelFactory
from global_config import GlobalConfig
from schemes import DecryptionBody, EncryptionBody

app = FastAPI()

# Pre-baked request/response examples shown in the generated OpenAPI docs.
with open("resources/examples.json", "r") as f:
    examples = json.load(f)

@app.post(
    "/encrypt",
    responses={
        200: {
            "content": {
                "application/json": {"example": examples["encrypt"]["response"]}
            }
        }
    },
)
async def encrypt_api(body: EncryptionBody):
    # The message arrives base64-encoded; decode it to raw bytes before hiding it.
    byte_msg = base64.b64decode(body.msg)
    model, tokenizer = ModelFactory.load_model(body.gen_model)
    texts, msgs_rates, tokens_infos = generate(
        tokenizer=tokenizer,
        model=model,
        prompt=body.prompt,
        msg=byte_msg,
        start_pos_p=[body.start_pos],
        delta=body.delta,
        msg_base=body.msg_base,
        seed_scheme=body.seed_scheme,
        window_length=body.window_length,
        private_key=body.private_key,
        min_new_tokens_ratio=body.min_new_tokens_ratio,
        max_new_tokens_ratio=body.max_new_tokens_ratio,
        do_sample=body.do_sample,
        num_beams=body.num_beams,
        repetition_penalty=body.repetition_penalty,
    )
    return {
        "texts": texts,
        "msgs_rates": msgs_rates,
        "tokens_infos": tokens_infos,
    }
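
# A minimal client-side sketch of calling /encrypt. This is not part of the
# server: it assumes the API is running locally on the default port 8000, that
# the `requests` package is available, and that any EncryptionBody fields not
# shown here fall back to the defaults reported by GET /configs.
#
#     import base64, requests
#
#     payload = {
#         "prompt": "Write a short story about the sea.",
#         "msg": base64.b64encode(b"attack at dawn").decode("ascii"),
#         "gen_model": "<one of the names returned by GET /configs>",
#     }
#     resp = requests.post("http://localhost:8000/encrypt", json=payload)
#     print(resp.json()["texts"])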

@app.post(
    "/decrypt",
    responses={
        200: {
            "content": {
                "application/json": {"example": examples["decrypt"]["response"]}
            }
        }
    },
)
async def decrypt_api(body: DecryptionBody):
    model, tokenizer = ModelFactory.load_model(body.gen_model)
    msgs = decrypt(
        tokenizer=tokenizer,
        device=model.device,
        text=body.text,
        msg_base=body.msg_base,
        seed_scheme=body.seed_scheme,
        window_length=body.window_length,
        private_key=body.private_key,
    )
    # Recovered messages are raw bytes; base64-encode them so they can be
    # returned safely inside the JSON response.
    msg_b64 = {}
    for i, s_msg in enumerate(msgs):
        msg_b64[i] = []
        for msg in s_msg:
            msg_b64[i].append(base64.b64encode(msg))
    return msg_b64
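
# A matching sketch for /decrypt (same assumptions as above): post the
# generated text together with the parameters that were used for encryption,
# then base64-decode each candidate message in the response.
#
#     resp = requests.post("http://localhost:8000/decrypt", json={
#         "text": "<text returned by /encrypt>",
#         "gen_model": "<same model as used for /encrypt>",
#     })
#     for i, candidates in resp.json().items():
#         recovered = [base64.b64decode(m) for m in candidates]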

@app.get(
    "/configs",
    responses={
        200: {
            "content": {
                "application/json": {"example": examples["configs"]["response"]}
            },
        }
    },
)
async def default_config():
    # Expose the default generation/decryption settings plus the available
    # seed schemes and models, so clients can pre-fill their parameters.
    configs = {
        "default": {
            "encrypt": {
                "gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
                "start_pos": GlobalConfig.get("encrypt.default", "start_pos"),
                "delta": GlobalConfig.get("encrypt.default", "delta"),
                "msg_base": GlobalConfig.get("encrypt.default", "msg_base"),
                "seed_scheme": GlobalConfig.get(
                    "encrypt.default", "seed_scheme"
                ),
                "window_length": GlobalConfig.get(
                    "encrypt.default", "window_length"
                ),
                "private_key": GlobalConfig.get(
                    "encrypt.default", "private_key"
                ),
                "min_new_tokens_ratio": GlobalConfig.get(
                    "encrypt.default", "min_new_tokens_ratio"
                ),
                "max_new_tokens_ratio": GlobalConfig.get(
                    "encrypt.default", "max_new_tokens_ratio"
                ),
                "do_sample": GlobalConfig.get("encrypt.default", "do_sample"),
                "num_beams": GlobalConfig.get("encrypt.default", "num_beams"),
                "repetition_penalty": GlobalConfig.get(
                    "encrypt.default", "repetition_penalty"
                ),
            },
            # Decryption defaults are read from the same encrypt.default section.
            "decrypt": {
                "gen_model": GlobalConfig.get("encrypt.default", "gen_model"),
                "msg_base": GlobalConfig.get("encrypt.default", "msg_base"),
                "seed_scheme": GlobalConfig.get(
                    "encrypt.default", "seed_scheme"
                ),
                "window_length": GlobalConfig.get(
                    "encrypt.default", "window_length"
                ),
                "private_key": GlobalConfig.get(
                    "encrypt.default", "private_key"
                ),
            },
        },
        "seed_schemes": SeedSchemeFactory.get_schemes_name(),
        "models": ModelFactory.get_models_names(),
    }
    return configs

if __name__ == "__main__":
    # Coerce config values to concrete types and fall back to sane defaults
    # (this also keeps the linter happy about the types passed to uvicorn).
    host = GlobalConfig.get("server", "host")
    host = str(host) if host is not None else "0.0.0.0"
    port = GlobalConfig.get("server", "port")
    port = int(port) if port is not None else 8000
    workers = GlobalConfig.get("server", "workers")
    workers = int(workers) if workers is not None else 1
    uvicorn.run("api:app", host=host, port=port, workers=workers)