tnk2908 commited on
Commit
c231729
1 Parent(s): 0852a55

Add parameters to RestAPI

Browse files
Files changed (3) hide show
  1. api.py +3 -1
  2. schemes.py +5 -1
  3. stegno.py +3 -2
api.py CHANGED
@@ -45,14 +45,16 @@ async def encrypt_api(
45
  seed_scheme=body.seed_scheme,
46
  window_length=body.window_length,
47
  private_key=body.private_key,
 
48
  max_new_tokens_ratio=body.max_new_tokens_ratio,
 
49
  num_beams=body.num_beams,
50
  repetition_penalty=body.repetition_penalty,
51
  )
52
  return {
53
  "texts": texts,
54
  "msgs_rates": msgs_rates,
55
- "tokens_info": tokens_infos,
56
  }
57
 
58
 
 
45
  seed_scheme=body.seed_scheme,
46
  window_length=body.window_length,
47
  private_key=body.private_key,
48
+ min_new_tokens_ratio=body.min_new_tokens_ratio,
49
  max_new_tokens_ratio=body.max_new_tokens_ratio,
50
+ do_sample=body.do_sample,
51
  num_beams=body.num_beams,
52
  repetition_penalty=body.repetition_penalty,
53
  )
54
  return {
55
  "texts": texts,
56
  "msgs_rates": msgs_rates,
57
+ "tokens_infos": tokens_infos,
58
  }
59
 
60
 
schemes.py CHANGED
@@ -49,7 +49,7 @@ class EncryptionBody(BaseModel):
49
  title="Private key used to compute the seed for PRF",
50
  ge=0,
51
  )
52
- max_new_tokens_ratio: float = Field(
53
  default=GlobalConfig.get("encrypt.default", "min_new_tokens_ratio"),
54
  title="Min length of generated text compared to the minimum length required to hide the message",
55
  ge=1,
@@ -64,6 +64,10 @@ class EncryptionBody(BaseModel):
64
  title="Number of beams used in beam search",
65
  ge=1,
66
  )
 
 
 
 
67
 
68
  repetition_penalty: float = Field(
69
  default=GlobalConfig.get("encrypt.default", "repetition_penalty"),
 
49
  title="Private key used to compute the seed for PRF",
50
  ge=0,
51
  )
52
+ min_new_tokens_ratio: float = Field(
53
  default=GlobalConfig.get("encrypt.default", "min_new_tokens_ratio"),
54
  title="Min length of generated text compared to the minimum length required to hide the message",
55
  ge=1,
 
64
  title="Number of beams used in beam search",
65
  ge=1,
66
  )
67
+ do_sample: bool = Field(
68
+ default=GlobalConfig.get("encrypt.default", "do_sample"),
69
+ title="Whether to use greedy or sampling generating"
70
+ )
71
 
72
  repetition_penalty: float = Field(
73
  default=GlobalConfig.get("encrypt.default", "repetition_penalty"),
stegno.py CHANGED
@@ -78,18 +78,19 @@ def generate(
78
  salt_key=salt_key,
79
  private_key=private_key,
80
  )
81
- min_length = (
82
  prompt_size
83
  + start_pos
84
  + logits_processor.get_message_len() * min_new_tokens_ratio
85
  )
86
- max_length = (
87
  prompt_size
88
  + start_pos
89
  + logits_processor.get_message_len() * max_new_tokens_ratio
90
  )
91
  max_length = min(max_length, tokenizer.model_max_length)
92
  min_length = min(min_length, max_length)
 
93
  output_tokens = model.generate(
94
  **tokenized_input,
95
  logits_processor=transformers.LogitsProcessorList([logits_processor]),
 
78
  salt_key=salt_key,
79
  private_key=private_key,
80
  )
81
+ min_length = int(
82
  prompt_size
83
  + start_pos
84
  + logits_processor.get_message_len() * min_new_tokens_ratio
85
  )
86
+ max_length = int(
87
  prompt_size
88
  + start_pos
89
  + logits_processor.get_message_len() * max_new_tokens_ratio
90
  )
91
  max_length = min(max_length, tokenizer.model_max_length)
92
  min_length = min(min_length, max_length)
93
+
94
  output_tokens = model.generate(
95
  **tokenized_input,
96
  logits_processor=transformers.LogitsProcessorList([logits_processor]),