# trycry / app.py — Hugging Face Space by 9Dome
# Last commit: 8a4038f "Update app.py" (verified)
import gc
import logging
import os
import re
import torch
from cleantext import clean
import gradio as gr
from tqdm.auto import tqdm
from transformers import pipeline
logging.basicConfig(level=logging.INFO)
logging.info(f"torch version:\t{torch.__version__}")

# --- 1. Model identifiers. These must be defined here, before the pipelines
# below are built from them (do not move them further down). ---
checker_model_name = "textattack/roberta-base-CoLA"
corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"

# --- 2. Device selection: GPU index 0 when CUDA is available, otherwise the
# CPU (-1), so the app does not fail with an NVIDIA/CUDA RuntimeError on
# machines without a GPU. ---
_cuda_available = torch.cuda.is_available()
device = 0 if _cuda_available else -1
logging.info(f"Using device: {'cuda' if device == 0 else 'cpu'}")


def _build_pipeline(task: str, model_name: str):
    """Create a transformers pipeline for *task* using *model_name* on ``device``."""
    return pipeline(task, model=model_name, device=device)


# --- 3. Pipelines built from the identifiers in step 1: a text classifier
# used to score input batches, and a text2text model used to rewrite them. ---
checker = _build_pipeline("text-classification", checker_model_name)
corrector = _build_pipeline("text2text-generation", corrector_model_name)
# --- Other functions ---
def split_text(text: str, batch_size: int = 2) -> list:
    """Split *text* into sentences and group consecutive sentences into batches.

    A split happens on a run of spaces that follows a "." or "?" and precedes
    a capital letter. The character two positions before the punctuation must
    not be a capital letter. This appears intended to avoid splitting after
    short capitalised abbreviations such as "Mr.".

    Consecutive sentences are grouped into lists of *batch_size* items. The
    last batch may be shorter.

    Args:
        text: Text to split.
        batch_size: Maximum number of sentences per batch (default 2, which
            matches the previous behaviour).

    Returns:
        A list of sentence lists. Empty or whitespace-only text gives [].

    Raises:
        ValueError: If *batch_size* is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # Drop blank pieces (e.g. from empty input) so the models never receive
    # an empty string.
    sentences = [
        piece
        for piece in re.split(r"(?<=[^A-Z].[.?]) +(?=[A-Z])", text)
        if piece.strip()
    ]
    # Group by position. The old code compared each sentence's *value* to the
    # last one, so an earlier duplicate of the final sentence flushed a batch
    # too early.
    return [
        sentences[start:start + batch_size]
        for start in range(0, len(sentences), batch_size)
    ]
def correct_text(text: str, separator: str = " ", min_score: float = 0.9) -> str:
    """Grammar-correct *text* batch by batch.

    The text is split into sentence batches with ``split_text``. Each batch
    is joined with a space and scored by the ``checker`` pipeline.

    - A batch whose top prediction is "LABEL_1" with a score of at least
      *min_score* is kept verbatim.
    - Every other batch is replaced by the ``corrector`` pipeline's
      ``generated_text``.
    - Blank batches are skipped without calling either model.

    Args:
        text: Text to correct.
        separator: String used to join the processed batches.
        min_score: Minimum checker confidence for a "LABEL_1" prediction
            to leave a batch unchanged (default 0.9, as before).

    Returns:
        The processed batches joined with *separator*.
    """
    corrected_batches = []
    for batch in tqdm(split_text(text), desc="correcting text.."):
        raw_text = " ".join(batch)
        if not raw_text.strip():
            continue
        top = checker(raw_text)[0]
        # NOTE(review): assumes "LABEL_1" is the "grammatically acceptable"
        # class of the checker model — confirm against its config.
        if top["label"] == "LABEL_1" and top["score"] >= min_score:
            corrected_batches.append(raw_text)
        else:
            corrected_batches.append(corrector(raw_text)[0]["generated_text"])
    return separator.join(corrected_batches)
def update(text: str):
    """Gradio callback: clean the user's input and return its corrected form.

    Blank input (None, empty, or whitespace-only) returns "" immediately,
    without invoking the models. Otherwise only the first 4000 characters
    are used. The limit is presumably there to bound model work per request.
    That prefix is cleaned with ``cleantext.clean`` (keeping the original
    casing) and passed to ``correct_text``.
    """
    if not text or not text.strip():
        return ""
    max_chars = 4000
    cleaned = clean(text[:max_chars], lower=False)
    if not cleaned.strip():
        return ""
    return correct_text(cleaned)
# --- 4. Interface ---
# Layout: a title, the input/output textboxes side by side, and a button
# that runs `update` on the input and writes the result to the output box.
with gr.Blocks() as demo:
    gr.Markdown("# <center>Robust Grammar Correction</center>")
    with gr.Row():
        inp = gr.Textbox(label="Input", placeholder="Enter text here...")
        out = gr.Textbox(label="Output", interactive=False)
    btn = gr.Button("Process")
    btn.click(fn=update, inputs=inp, outputs=out)

# Launch only when run as a script (e.g. `python app.py`), so that importing
# this module does not start a blocking server.
if __name__ == "__main__":
    demo.launch()