cstr committed on
Commit
0861973
1 Parent(s): 923de84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -10
app.py CHANGED
@@ -519,6 +519,7 @@ logging.basicConfig(level=logging.INFO,
519
  ])
520
  logger = logging.getLogger(__name__)
521
 
 
522
  # Main function to handle the translation workflow
523
  def main(dataset_url, model_type, output_dataset_name, range_specification, token: gr.OAuthToken | None, profile: gr.OAuthProfile | None):
524
  try:
@@ -527,24 +528,24 @@ def main(dataset_url, model_type, output_dataset_name, range_specification, toke
527
  return "### You must be logged in to use this service."
528
 
529
  if token:
530
- logging.info("Logged in to Hugging Face")
531
 
532
  # Configuration and paths
533
  tokenizer_name = "facebook/wmt21-dense-24-wide-en-x"
534
  model_repo_name = "cstr/wmt21ct2_int8" # Repository to download the model from
535
 
536
  # Download the model snapshot from Hugging Face
537
- model_path = snapshot_download(repo_id=model_repo_name, token=token)
538
- logging.info(f"Model downloaded to: {model_path}")
539
 
540
  # Load the CTranslate2 model
541
  translator = ctranslate2.Translator(model_path, device="auto")
542
- logging.info("CTranslate2 model loaded successfully.")
543
 
544
  # Load the tokenizer
545
  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
546
  tokenizer.src_lang = "en"
547
- logging.info("Tokenizer loaded successfully.")
548
 
549
  # Define the task based on user input
550
  task = {
@@ -566,17 +567,19 @@ def main(dataset_url, model_type, output_dataset_name, range_specification, toke
566
  output_dir=".",
567
  output_repo_name=output_dataset_name,
568
  raw_file_path=task["raw_file"],
569
- token=token,
570
  range_specification=task["range_spec"],
571
  model_type=task["model_type"],
572
  translator=translator,
573
  tokenizer=tokenizer,
574
  )
575
- return "Dataset translation completed!"
 
576
  else:
577
  return "Login failed. Please try again."
578
  except Exception as e:
579
- logging.error(f"An error occurred in the main function: {e}")
 
580
  return f"An error occurred: {e}\n\n### Logs:\n{log_stream.getvalue()}"
581
 
582
  # Gradio interface setup
@@ -622,10 +625,10 @@ with gr.Blocks(theme=theme) as demo:
622
  range_specification = gr.Textbox(label="Range Specification", lines=1, placeholder="e.g., 1-100")
623
 
624
  with gr.Column():
625
- output = gr.Textbox(label="Output", lines=1)
626
 
627
  submit_btn = gr.Button("Translate Dataset", variant="primary")
628
- submit_btn.click(main, inputs=[dataset_url, model_type, output_dataset_name, range_specification], outputs=output)
629
 
630
  gr.Markdown(datasets_desc)
631
 
 
519
  ])
520
  logger = logging.getLogger(__name__)
521
 
522
+ # Main function to handle the translation workflow
523
  # Main function to handle the translation workflow
524
  def main(dataset_url, model_type, output_dataset_name, range_specification, token: gr.OAuthToken | None, profile: gr.OAuthProfile | None):
525
  try:
 
528
  return "### You must be logged in to use this service."
529
 
530
  if token:
531
+ logger.info("Logged in to Hugging Face")
532
 
533
  # Configuration and paths
534
  tokenizer_name = "facebook/wmt21-dense-24-wide-en-x"
535
  model_repo_name = "cstr/wmt21ct2_int8" # Repository to download the model from
536
 
537
  # Download the model snapshot from Hugging Face
538
+ model_path = snapshot_download(repo_id=model_repo_name, token=token.token)
539
+ logger.info(f"Model downloaded to: {model_path}")
540
 
541
  # Load the CTranslate2 model
542
  translator = ctranslate2.Translator(model_path, device="auto")
543
+ logger.info("CTranslate2 model loaded successfully.")
544
 
545
  # Load the tokenizer
546
  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
547
  tokenizer.src_lang = "en"
548
+ logger.info("Tokenizer loaded successfully.")
549
 
550
  # Define the task based on user input
551
  task = {
 
567
  output_dir=".",
568
  output_repo_name=output_dataset_name,
569
  raw_file_path=task["raw_file"],
570
+ token=token.token,
571
  range_specification=task["range_spec"],
572
  model_type=task["model_type"],
573
  translator=translator,
574
  tokenizer=tokenizer,
575
  )
576
+ logger.info("Dataset translation completed!")
577
+ return "Dataset translation completed!\n\n### Logs:\n" + log_stream.getvalue()
578
  else:
579
  return "Login failed. Please try again."
580
  except Exception as e:
581
+ logger.error(f"An error occurred in the main function: {e}")
582
+ # Ensure logs are flushed and captured
583
  return f"An error occurred: {e}\n\n### Logs:\n{log_stream.getvalue()}"
584
 
585
  # Gradio interface setup
 
625
  range_specification = gr.Textbox(label="Range Specification", lines=1, placeholder="e.g., 1-100")
626
 
627
  with gr.Column():
628
+ output = gr.Markdown(label="Output")
629
 
630
  submit_btn = gr.Button("Translate Dataset", variant="primary")
631
+ submit_btn.click(main, inputs=[dataset_url, model_type, output_dataset_name, range_specification, token_input], outputs=output)
632
 
633
  gr.Markdown(datasets_desc)
634