Tin113's picture
Update app.py
4e5872e verified
import os
import torch
# CRITICAL: Redirect cache to temporary storage
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
os.environ['HUB_DIR'] = '/tmp/torch_hub'
os.environ['TMPDIR'] = '/tmp'
torch.hub.set_dir('/tmp/torch_hub')
import gradio as gr
import torch
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
T5Tokenizer,
T5ForConditionalGeneration
)
import re
HF_USERNAME = "Tin113"
# -----------------------------------------
BART_MODEL_REPO = f"{HF_USERNAME}/bart_model"
VIT5_MODEL_REPO = f"{HF_USERNAME}/vit5_model"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Thiết bị sử dụng: {DEVICE}")
# Tải các model
try:
print(f"Đang tải model BART từ {BART_MODEL_REPO}...")
tokenizer_bart = AutoTokenizer.from_pretrained(BART_MODEL_REPO)
model_bart = AutoModelForSeq2SeqLM.from_pretrained(BART_MODEL_REPO).to(DEVICE)
model_bart.eval()
print("Tải model BART thành công.")
except Exception as e:
print(f"Lỗi khi tải model BART: {e}")
model_bart, tokenizer_bart = None, None
try:
print(f"Đang tải model ViT5 từ {VIT5_MODEL_REPO}...")
tokenizer_vit5 = T5Tokenizer.from_pretrained(VIT5_MODEL_REPO)
model_vit5 = T5ForConditionalGeneration.from_pretrained(VIT5_MODEL_REPO).to(DEVICE)
model_vit5.eval()
print("Tải model ViT5 thành công.")
except Exception as e:
print(f"Lỗi khi tải model ViT5: {e}")
model_vit5, tokenizer_vit5 = None, None
def clean_text(text):
if not isinstance(text, str): return ""
return re.sub(r'\s+', ' ', text).strip()
def correct_grammar(sentence, model_choice):
if not sentence.strip(): return "Vui lòng nhập một câu."
model, tokenizer, prefix = None, None, ""
if model_choice == "BARTpho-syllable":
if model_bart:
model, tokenizer, prefix = model_bart, tokenizer_bart, "Fix: "
else:
return "Lỗi: Model BART không khả dụng. Vui lòng kiểm tra lại Space."
elif model_choice == "ViT5-base":
if model_vit5:
model, tokenizer, prefix = model_vit5, tokenizer_vit5, "sửa lỗi: "
else:
return "Lỗi: Model ViT5 không khả dụng. Vui lòng kiểm tra lại Space."
input_text = prefix + sentence
input_ids = tokenizer(input_text, return_tensors="pt", max_length=256, truncation=True, padding=True).input_ids.to(DEVICE)
with torch.no_grad():
outputs = model.generate(input_ids, max_length=276, num_beams=2, early_stopping=True, repetition_penalty=1.05, no_repeat_ngram_size=2)
return clean_text(tokenizer.decode(outputs[0], skip_special_tokens=True))
description = """
Demo sửa lỗi chính tả tiếng Việt sử dụng hai model: BARTpho-syllable và ViT5-base.
1. Nhập câu lỗi vào ô bên dưới.
2. Chọn model bạn muốn dùng.
3. Nhấn "Submit" để xem kết quả.
"""
demo = gr.Interface(
fn=correct_grammar,
inputs=[
gr.Textbox(lines=5, label="Nhập câu tiếng Việt bị lỗi"),
gr.Radio(choices=["BARTpho-syllable", "ViT5-base"], value="ViT5-base", label="Chọn Model")
],
outputs=gr.Textbox(label="Câu đã được sửa"),
title="Sửa lỗi chính tả Tiếng Việt",
description=description,
)
if __name__ == "__main__":
demo.launch()