Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -519,6 +519,7 @@ logging.basicConfig(level=logging.INFO,
|
|
519 |
])
|
520 |
logger = logging.getLogger(__name__)
|
521 |
|
|
|
522 |
# Main function to handle the translation workflow
|
523 |
def main(dataset_url, model_type, output_dataset_name, range_specification, token: gr.OAuthToken | None, profile: gr.OAuthProfile | None):
|
524 |
try:
|
@@ -527,24 +528,24 @@ def main(dataset_url, model_type, output_dataset_name, range_specification, toke
|
|
527 |
return "### You must be logged in to use this service."
|
528 |
|
529 |
if token:
|
530 |
-
|
531 |
|
532 |
# Configuration and paths
|
533 |
tokenizer_name = "facebook/wmt21-dense-24-wide-en-x"
|
534 |
model_repo_name = "cstr/wmt21ct2_int8" # Repository to download the model from
|
535 |
|
536 |
# Download the model snapshot from Hugging Face
|
537 |
-
model_path = snapshot_download(repo_id=model_repo_name, token=token)
|
538 |
-
|
539 |
|
540 |
# Load the CTranslate2 model
|
541 |
translator = ctranslate2.Translator(model_path, device="auto")
|
542 |
-
|
543 |
|
544 |
# Load the tokenizer
|
545 |
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
|
546 |
tokenizer.src_lang = "en"
|
547 |
-
|
548 |
|
549 |
# Define the task based on user input
|
550 |
task = {
|
@@ -566,17 +567,19 @@ def main(dataset_url, model_type, output_dataset_name, range_specification, toke
|
|
566 |
output_dir=".",
|
567 |
output_repo_name=output_dataset_name,
|
568 |
raw_file_path=task["raw_file"],
|
569 |
-
token=token,
|
570 |
range_specification=task["range_spec"],
|
571 |
model_type=task["model_type"],
|
572 |
translator=translator,
|
573 |
tokenizer=tokenizer,
|
574 |
)
|
575 |
-
|
|
|
576 |
else:
|
577 |
return "Login failed. Please try again."
|
578 |
except Exception as e:
|
579 |
-
|
|
|
580 |
return f"An error occurred: {e}\n\n### Logs:\n{log_stream.getvalue()}"
|
581 |
|
582 |
# Gradio interface setup
|
@@ -622,10 +625,10 @@ with gr.Blocks(theme=theme) as demo:
|
|
622 |
range_specification = gr.Textbox(label="Range Specification", lines=1, placeholder="e.g., 1-100")
|
623 |
|
624 |
with gr.Column():
|
625 |
-
output = gr.
|
626 |
|
627 |
submit_btn = gr.Button("Translate Dataset", variant="primary")
|
628 |
-
submit_btn.click(main, inputs=[dataset_url, model_type, output_dataset_name, range_specification], outputs=output)
|
629 |
|
630 |
gr.Markdown(datasets_desc)
|
631 |
|
|
|
519 |
])
|
520 |
logger = logging.getLogger(__name__)
|
521 |
|
522 |
+
# Click handler for the "Translate Dataset" button; the returned string is rendered in the Markdown output panel
|
523 |
# Main function to handle the translation workflow
|
524 |
def main(dataset_url, model_type, output_dataset_name, range_specification, token: gr.OAuthToken | None, profile: gr.OAuthProfile | None):
|
525 |
try:
|
|
|
528 |
return "### You must be logged in to use this service."
|
529 |
|
530 |
if token:
|
531 |
+
logger.info("Logged in to Hugging Face")
|
532 |
|
533 |
# Configuration and paths
|
534 |
tokenizer_name = "facebook/wmt21-dense-24-wide-en-x"
|
535 |
model_repo_name = "cstr/wmt21ct2_int8" # Repository to download the model from
|
536 |
|
537 |
# Download the model snapshot from Hugging Face
|
538 |
+
model_path = snapshot_download(repo_id=model_repo_name, token=token.token)
|
539 |
+
logger.info(f"Model downloaded to: {model_path}")
|
540 |
|
541 |
# Load the CTranslate2 model
|
542 |
translator = ctranslate2.Translator(model_path, device="auto")
|
543 |
+
logger.info("CTranslate2 model loaded successfully.")
|
544 |
|
545 |
# Load the tokenizer
|
546 |
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
|
547 |
tokenizer.src_lang = "en"
|
548 |
+
logger.info("Tokenizer loaded successfully.")
|
549 |
|
550 |
# Define the task based on user input
|
551 |
task = {
|
|
|
567 |
output_dir=".",
|
568 |
output_repo_name=output_dataset_name,
|
569 |
raw_file_path=task["raw_file"],
|
570 |
+
token=token.token,
|
571 |
range_specification=task["range_spec"],
|
572 |
model_type=task["model_type"],
|
573 |
translator=translator,
|
574 |
tokenizer=tokenizer,
|
575 |
)
|
576 |
+
logger.info("Dataset translation completed!")
|
577 |
+
return "Dataset translation completed!\n\n### Logs:\n" + log_stream.getvalue()
|
578 |
else:
|
579 |
return "Login failed. Please try again."
|
580 |
except Exception as e:
|
581 |
+
logger.error(f"An error occurred in the main function: {e}")
|
582 |
+
# No explicit flush is performed here; the response includes whatever log_stream has captured so far
|
583 |
return f"An error occurred: {e}\n\n### Logs:\n{log_stream.getvalue()}"
|
584 |
|
585 |
# Gradio interface setup
|
|
|
625 |
range_specification = gr.Textbox(label="Range Specification", lines=1, placeholder="e.g., 1-100")
|
626 |
|
627 |
with gr.Column():
|
628 |
+
output = gr.Markdown(label="Output")
|
629 |
|
630 |
submit_btn = gr.Button("Translate Dataset", variant="primary")
|
631 |
+
submit_btn.click(main, inputs=[dataset_url, model_type, output_dataset_name, range_specification, token_input], outputs=output)
|
632 |
|
633 |
gr.Markdown(datasets_desc)
|
634 |
|