jwkirchenbauer committed on
Commit
5b3e92c
1 Parent(s): c3f4b90

API timeout incr to 60 sec

Browse files
Files changed (1) hide show
  1. demo_watermark.py +19 -10
demo_watermark.py CHANGED
@@ -210,12 +210,13 @@ def load_model(args):
210
 
211
 
212
  from text_generation import InferenceAPIClient
 
213
  def generate_with_api(prompt, args):
214
  hf_api_key = os.environ.get("HF_API_KEY")
215
  if hf_api_key is None:
216
  raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
217
 
218
- client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key)
219
 
220
  assert args.n_beams == 1, "HF API models do not support beam search."
221
  generation_params = {
@@ -226,14 +227,22 @@ def generate_with_api(prompt, args):
226
  generation_params["temperature"] = args.sampling_temp
227
  generation_params["seed"] = args.generation_seed
228
 
229
- generation_params["watermark"] = False
230
- output = client.generate(prompt, **generation_params)
231
- output_text_without_watermark = output.generated_text
232
-
233
- generation_params["watermark"] = True
234
- output = client.generate(prompt, **generation_params)
235
- output_text_with_watermark = output.generated_text
236
-
 
 
 
 
 
 
 
 
237
  return (output_text_without_watermark,
238
  output_text_with_watermark)
239
 
@@ -737,7 +746,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
737
  generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
738
  detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
739
  model_selector.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
740
- # When the parameters change, display the update and fire detection, since some detection params dont change the model output.
741
  delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
742
  gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
743
  gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])
 
210
 
211
 
212
  from text_generation import InferenceAPIClient
213
+ from requests.exceptions import ReadTimeout
214
  def generate_with_api(prompt, args):
215
  hf_api_key = os.environ.get("HF_API_KEY")
216
  if hf_api_key is None:
217
  raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
218
 
219
+ client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60)
220
 
221
  assert args.n_beams == 1, "HF API models do not support beam search."
222
  generation_params = {
 
227
  generation_params["temperature"] = args.sampling_temp
228
  generation_params["seed"] = args.generation_seed
229
 
230
+ timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"
231
+ try:
232
+ generation_params["watermark"] = False
233
+ output = client.generate(prompt, **generation_params)
234
+ output_text_without_watermark = output.generated_text
235
+ except ReadTimeout as e:
236
+ print(e)
237
+ output_text_without_watermark = timeout_msg
238
+ try:
239
+ generation_params["watermark"] = True
240
+ output = client.generate(prompt, **generation_params)
241
+ output_text_with_watermark = output.generated_text
242
+ except ReadTimeout as e:
243
+ print(e)
244
+ output_text_with_watermark = timeout_msg
245
+
246
  return (output_text_without_watermark,
247
  output_text_with_watermark)
248
 
 
746
  generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
747
  detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
748
  model_selector.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
749
+ # When the parameters change, display the update and also fire detection, since some detection params dont change the model output.
750
  delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
751
  gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
752
  gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])