Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| # CRITICAL: Redirect cache to temporary storage | |
| os.environ['TORCH_HOME'] = '/tmp/torch_cache' | |
| os.environ['HUB_DIR'] = '/tmp/torch_hub' | |
| os.environ['TMPDIR'] = '/tmp' | |
| torch.hub.set_dir('/tmp/torch_hub') | |
| import gradio as gr | |
| import torch | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSeq2SeqLM, | |
| T5Tokenizer, | |
| T5ForConditionalGeneration | |
| ) | |
| import re | |
| HF_USERNAME = "Tin113" | |
| # ----------------------------------------- | |
| BART_MODEL_REPO = f"{HF_USERNAME}/bart_model" | |
| VIT5_MODEL_REPO = f"{HF_USERNAME}/vit5_model" | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"Thiết bị sử dụng: {DEVICE}") | |
| # Tải các model | |
| try: | |
| print(f"Đang tải model BART từ {BART_MODEL_REPO}...") | |
| tokenizer_bart = AutoTokenizer.from_pretrained(BART_MODEL_REPO) | |
| model_bart = AutoModelForSeq2SeqLM.from_pretrained(BART_MODEL_REPO).to(DEVICE) | |
| model_bart.eval() | |
| print("Tải model BART thành công.") | |
| except Exception as e: | |
| print(f"Lỗi khi tải model BART: {e}") | |
| model_bart, tokenizer_bart = None, None | |
| try: | |
| print(f"Đang tải model ViT5 từ {VIT5_MODEL_REPO}...") | |
| tokenizer_vit5 = T5Tokenizer.from_pretrained(VIT5_MODEL_REPO) | |
| model_vit5 = T5ForConditionalGeneration.from_pretrained(VIT5_MODEL_REPO).to(DEVICE) | |
| model_vit5.eval() | |
| print("Tải model ViT5 thành công.") | |
| except Exception as e: | |
| print(f"Lỗi khi tải model ViT5: {e}") | |
| model_vit5, tokenizer_vit5 = None, None | |
| def clean_text(text): | |
| if not isinstance(text, str): return "" | |
| return re.sub(r'\s+', ' ', text).strip() | |
| def correct_grammar(sentence, model_choice): | |
| if not sentence.strip(): return "Vui lòng nhập một câu." | |
| model, tokenizer, prefix = None, None, "" | |
| if model_choice == "BARTpho-syllable": | |
| if model_bart: | |
| model, tokenizer, prefix = model_bart, tokenizer_bart, "Fix: " | |
| else: | |
| return "Lỗi: Model BART không khả dụng. Vui lòng kiểm tra lại Space." | |
| elif model_choice == "ViT5-base": | |
| if model_vit5: | |
| model, tokenizer, prefix = model_vit5, tokenizer_vit5, "sửa lỗi: " | |
| else: | |
| return "Lỗi: Model ViT5 không khả dụng. Vui lòng kiểm tra lại Space." | |
| input_text = prefix + sentence | |
| input_ids = tokenizer(input_text, return_tensors="pt", max_length=256, truncation=True, padding=True).input_ids.to(DEVICE) | |
| with torch.no_grad(): | |
| outputs = model.generate(input_ids, max_length=276, num_beams=2, early_stopping=True, repetition_penalty=1.05, no_repeat_ngram_size=2) | |
| return clean_text(tokenizer.decode(outputs[0], skip_special_tokens=True)) | |
| description = """ | |
| Demo sửa lỗi chính tả tiếng Việt sử dụng hai model: BARTpho-syllable và ViT5-base. | |
| 1. Nhập câu lỗi vào ô bên dưới. | |
| 2. Chọn model bạn muốn dùng. | |
| 3. Nhấn "Submit" để xem kết quả. | |
| """ | |
| demo = gr.Interface( | |
| fn=correct_grammar, | |
| inputs=[ | |
| gr.Textbox(lines=5, label="Nhập câu tiếng Việt bị lỗi"), | |
| gr.Radio(choices=["BARTpho-syllable", "ViT5-base"], value="ViT5-base", label="Chọn Model") | |
| ], | |
| outputs=gr.Textbox(label="Câu đã được sửa"), | |
| title="Sửa lỗi chính tả Tiếng Việt", | |
| description=description, | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() |