Add parameters to RestAPI
Files changed:
- api.py     +3 -1
- schemes.py +5 -1
- stegno.py  +3 -2
api.py CHANGED
@@ -45,14 +45,16 @@ async def encrypt_api(
         seed_scheme=body.seed_scheme,
         window_length=body.window_length,
         private_key=body.private_key,
+        min_new_tokens_ratio=body.min_new_tokens_ratio,
         max_new_tokens_ratio=body.max_new_tokens_ratio,
+        do_sample=body.do_sample,
         num_beams=body.num_beams,
         repetition_penalty=body.repetition_penalty,
     )
     return {
         "texts": texts,
         "msgs_rates": msgs_rates,
-        "
+        "tokens_infos": tokens_infos,
     }
 
 
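A rough sketch of how a client might exercise the new request fields. The /encrypt route, host, and any body fields other than the ones touched by this commit are assumptions, not taken from the diff:

# Hypothetical client call; the route, host/port, and any body fields beyond
# those visible in this commit are assumptions for illustration only.
import requests

body = {
    # ...plus whatever other fields EncryptionBody requires (prompt, message, ...),
    # omitted here because they do not appear in this diff.
    "seed_scheme": "dummy",          # assumed placeholder value
    "window_length": 1,
    "private_key": 0,
    "min_new_tokens_ratio": 1.0,     # new field: lower bound on generated length
    "max_new_tokens_ratio": 2.0,
    "do_sample": True,               # new field: sample instead of purely greedy decoding
    "num_beams": 4,
    "repetition_penalty": 1.0,
}

resp = requests.post("http://localhost:8000/encrypt", json=body, timeout=120)
data = resp.json()
# The response now carries per-generation token information alongside the texts.
print(data["texts"], data["msgs_rates"], data["tokens_infos"])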
schemes.py CHANGED
@@ -49,7 +49,7 @@ class EncryptionBody(BaseModel):
         title="Private key used to compute the seed for PRF",
         ge=0,
     )
-
+    min_new_tokens_ratio: float = Field(
         default=GlobalConfig.get("encrypt.default", "min_new_tokens_ratio"),
         title="Min length of generated text compared to the minimum length required to hide the message",
         ge=1,
@@ -64,6 +64,10 @@
         title="Number of beams used in beam search",
         ge=1,
     )
+    do_sample: bool = Field(
+        default=GlobalConfig.get("encrypt.default", "do_sample"),
+        title="Whether to use greedy or sampling generating"
+    )
 
     repetition_penalty: float = Field(
         default=GlobalConfig.get("encrypt.default", "repetition_penalty"),
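A minimal, self-contained sketch of how the two new pydantic fields behave, with literal defaults standing in for the GlobalConfig lookups (GlobalConfig itself is not part of this diff):

# Standalone model mirroring only the fields added in this commit.
from pydantic import BaseModel, Field, ValidationError

class EncryptionBodySketch(BaseModel):
    min_new_tokens_ratio: float = Field(
        default=1.0,  # assumed literal default in place of GlobalConfig
        title="Min length of generated text compared to the minimum length required to hide the message",
        ge=1,
    )
    do_sample: bool = Field(
        default=True,  # assumed literal default in place of GlobalConfig
        title="Whether to use greedy or sampling generating",
    )

# ge=1 rejects ratios below 1, so the generated text can never be allowed to be
# shorter than the minimum length needed to hide the message.
try:
    EncryptionBodySketch(min_new_tokens_ratio=0.5)
except ValidationError as e:
    print(e)

# Booleans coerce from JSON true/false as usual.
print(EncryptionBodySketch(do_sample=False))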
stegno.py CHANGED
@@ -78,18 +78,19 @@ def generate(
         salt_key=salt_key,
         private_key=private_key,
     )
-    min_length = (
+    min_length = int(
         prompt_size
         + start_pos
         + logits_processor.get_message_len() * min_new_tokens_ratio
     )
-    max_length = (
+    max_length = int(
         prompt_size
         + start_pos
         + logits_processor.get_message_len() * max_new_tokens_ratio
     )
     max_length = min(max_length, tokenizer.model_max_length)
     min_length = min(min_length, max_length)
+
     output_tokens = model.generate(
         **tokenized_input,
         logits_processor=transformers.LogitsProcessorList([logits_processor]),
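The int() casts matter because multiplying the message length by a float ratio produces a float, while min_length and max_length are handed to model.generate, which works with integer token counts. A toy version of the same arithmetic, with made-up numbers:

# Illustration only: all values below are invented examples, not taken from the repo.
prompt_size = 12            # tokens in the tokenized prompt (example value)
start_pos = 0               # offset where hiding starts (example value)
message_len = 40            # stands in for logits_processor.get_message_len()
min_new_tokens_ratio = 1.0
max_new_tokens_ratio = 1.5
model_max_length = 1024     # stands in for tokenizer.model_max_length

# Ratio multiplication yields floats (52.0, 72.0); truncate to ints before generate().
min_length = int(prompt_size + start_pos + message_len * min_new_tokens_ratio)
max_length = int(prompt_size + start_pos + message_len * max_new_tokens_ratio)

max_length = min(max_length, model_max_length)  # never exceed the model context
min_length = min(min_length, max_length)        # keep the bounds consistent

print(min_length, max_length)  # 52 72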