abdiharyadi's picture
style: use Column for grouping the object
24c312c
raw
history blame
5.81 kB
import gdown
from git import Repo
import gradio as gr
from huggingface_hub import snapshot_download
import os
import penman
import sys
import time
import torch
from transformers import pipeline
# The project-specific pipeline components live in the amr-tst-indo
# repository. Clone it on first start (skipped when the directory is already
# there) and put it on sys.path so its modules can be imported below.
if not os.path.exists("amr-tst-indo"):
    Repo.clone_from(
        url="https://github.com/AbdiHaryadi/amr-tst-indo.git",
        to_path="amr-tst-indo",
    )
sys.path.append("./amr-tst-indo")

# These imports only resolve after the repository has been added to sys.path.
from text_to_amr import TextToAMR  # noqa: E402
from style_detector import StyleDetector  # noqa: E402
from style_rewriting import StyleRewriting  # noqa: E402
from amr_to_text import AMRToTextWithTaufiqMethod  # noqa: E402
# Text-to-AMR parser. TextToAMR is given only the model name, so the weights
# are downloaded into amr-tst-indo/AMRBART-id/models/<name>, presumably where
# it resolves that name (confirm in text_to_amr). Files matching "*log*" or
# "*checkpoint*" are not downloaded.
amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"
snapshot_download(
    repo_id="abdiharyadi/" + amr_parsing_model_name,
    local_dir="./amr-tst-indo/AMRBART-id/models/" + amr_parsing_model_name,
    ignore_patterns=["*log*", "*checkpoint*"],
)
t2a = TextToAMR(model_name=amr_parsing_model_name)
# Style detector. Its config lives under indonesian-aste-generative, so this is
# presumably an aspect-sentiment triplet extraction model (run() calls its
# get_triplets / get_style_words_from_triplets).
# The weights are fetched from Google Drive only when they are not already on
# disk. This is the same first-start-only policy used for the git clone above,
# and it avoids re-downloading the file on every restart.
style_detector_model_path = "./model-best.pt"
if not os.path.exists(style_detector_model_path):
    gdown.download(
        "https://drive.google.com/uc?id=1J_6PbYsQ6Kl4Qfs1wBVwd52_r9uTpIxx",
        style_detector_model_path,
    )
sd = StyleDetector(
    config_path="./amr-tst-indo/indonesian-aste-generative/resources/exp-v2/exp-m0.yaml",
    model_path=style_detector_model_path,
)
# Text-classification pipeline handed to the style rewriter below. The UI's
# style choices map to its labels LABEL_1 / LABEL_0. Runs on the GPU when CUDA
# is available, otherwise on the CPU.
device_type = "cpu"
if torch.cuda.is_available():
    device_type = "cuda"
clf_pipeline = pipeline(
    task="text-classification",
    model="abdiharyadi/roberta-base-indonesian-522M-with-sa-william-dataset",
    device=device_type,
)
# Style rewriter. It needs the classifier pipeline above plus an Indonesian
# fastText skip-gram model. The .bin file is downloaded only when it is not
# already on disk, so restarts don't fetch it again.
fasttext_model_path = "./fasttext_skipgram_indo.bin"
if not os.path.exists(fasttext_model_path):
    gdown.download(
        "https://drive.google.com/uc?id=15KctCcsHgTFMUh_tWNBNUiCyX56fq6p-",
        fasttext_model_path,
    )
sr = StyleRewriting(
    clf_pipeline=clf_pipeline,
    fasttext_model_path=fasttext_model_path,
    # Rewriting strategy switches; see StyleRewriting in amr-tst-indo for
    # what each one controls.
    position_aware_concatenation=False,
    reset_sense_strategy=False,
    max_score_strategy=True,
    maximize_style_words_expansion=False,
)
# AMR-to-text generator. Only one checkpoint directory of the model repo is
# used, so only that directory is downloaded.
amr_gen_model_name = "taufiq-indo-amr-generation-gold-uncased"
model_path = f"./{amr_gen_model_name}"
amr_gen_checkpoint = "checkpoint-3"
snapshot_download(
    repo_id=f"abdiharyadi/{amr_gen_model_name}",
    local_dir=model_path,
    # Anchor the glob to the directory. An unanchored "*checkpoint-3*" would
    # also match other checkpoints such as "checkpoint-30/...".
    allow_patterns=[f"{amr_gen_checkpoint}/*"],
)
a2t = AMRToTextWithTaufiqMethod(
    model_path=os.path.join(model_path, amr_gen_checkpoint),
    # Presumably matches the "uncased" model variant named above.
    lowercase=True,
)
# Placeholder texts shown for the stage being processed / stages not started.
_PROCESSING_TEXT = "(Memproses ...)"
_WAITING_TEXT = "(Menunggu ...)"
# Number of output textboxes run() fills, one per pipeline stage.
_NUM_STAGES = 5


def _progress_outputs(*finished):
    """Return the 5-tuple of textbox values for the current pipeline state.

    ``finished`` holds the display texts of the stages completed so far, in
    order. The first unfinished stage gets the "processing" placeholder and
    every later stage gets the "waiting" placeholder.
    """
    remaining = _NUM_STAGES - len(finished)
    if remaining <= 0:
        return tuple(finished)
    return (*finished, _PROCESSING_TEXT) + (_WAITING_TEXT,) * (remaining - 1)


def run(text, source_style):
    """Run the AMR-based text style transfer pipeline, streaming progress.

    This generator is the Gradio click handler. Each yield updates the five
    output textboxes: source AMR graph, triplets, style words, target AMR
    graph and final text. Every finished stage's text ends with the time
    elapsed since the run started. The time is cumulative, not per stage.

    Args:
        text: The input sentence.
        source_style: Label of the input's current style. The UI sends
            "LABEL_1" for positive and "LABEL_0" for negative.

    Yields:
        A 5-tuple of strings, one per output textbox.
    """
    yield _progress_outputs()
    # perf_counter is monotonic, so the elapsed time can't go negative or
    # jump if the system clock is adjusted during a request.
    start_time = time.perf_counter()

    def with_elapsed(display):
        """Append the cumulative elapsed time to a stage's display text."""
        return f"{display}\n\n({time.perf_counter() - start_time:.2f} s)"

    # Stage 1: parse the input text into a source AMR graph.
    source_amr, *_ = t2a([text])
    # Clear the metadata so the encoded graph shown to the user contains
    # only the graph itself.
    source_amr.metadata = {}
    source_amr_display = with_elapsed(penman.encode(source_amr))
    yield _progress_outputs(source_amr_display)

    # Stage 2: extract triplets from the text.
    triplets = sd.get_triplets(text)
    triplets_display = with_elapsed(
        "\n".join(f"({x[0]}, {x[1]}, {x[2]})" for x in triplets)
    )
    yield _progress_outputs(source_amr_display, triplets_display)

    # Stage 3: derive the style-bearing words from those triplets.
    style_words = sd.get_style_words_from_triplets(triplets)
    style_words_display = with_elapsed(", ".join(style_words))
    yield _progress_outputs(
        source_amr_display, triplets_display, style_words_display
    )

    # Stage 4: rewrite the source AMR into the target-style AMR.
    target_amr = sr(text, source_amr, source_style, style_words)
    target_amr_display = with_elapsed(penman.encode(target_amr))
    yield _progress_outputs(
        source_amr_display,
        triplets_display,
        style_words_display,
        target_amr_display,
    )

    # Stage 5: generate the final text from the target AMR.
    result, *_ = a2t([target_amr])
    yield _progress_outputs(
        source_amr_display,
        triplets_display,
        style_words_display,
        target_amr_display,
        with_elapsed(result),
    )
def _grid_textbox(label):
    """Return an output textbox for the two-by-two grid of stage outputs."""
    return gr.Textbox(label=label, min_width=320)


with gr.Blocks() as demo:
    with gr.Row():
        # Left column: the input sentence, its current style, and the trigger.
        with gr.Column():
            input_textbox = gr.Textbox(label="Teks (Text)")
            style_choices = gr.Radio(
                label="Gaya sumber (Source style)",
                choices=[
                    ("Positif (Positive)", "LABEL_1"),
                    ("Negatif (Negative)", "LABEL_0"),
                ],
                value="LABEL_1",
            )
            submit_btn = gr.Button("Submit")

        # Right column: one textbox per pipeline stage, in pipeline order.
        with gr.Column():
            with gr.Row():
                src_amr_graph_output = _grid_textbox(
                    "Graf AMR sumber (Source AMR graph)"
                )
                triplets_output = _grid_textbox("Triplet (Triplets)")
            with gr.Row():
                style_words_output = _grid_textbox("Kata bergaya (Style words)")
                tgt_amr_graph_output = _grid_textbox(
                    "Graf AMR target (Target AMR graph)"
                )
            result_output = gr.Textbox(label="Hasil (Result)")

    # Acknowledgement section below the main row.
    with gr.Column():
        gr.Markdown(
            "\n"
            "# Pengakuan\n"
            "Demo ini disiapkan untuk Program Penelitian dan Pengabdian Masyarakat STEI ITB 2024.\n"
            "**Tim Peneliti**:\n"
            "- Masayu Leylia Khodra (masayu@staff.stei.itb.ac.id)\n"
            "- M. Abdi Haryadi. H (abdiharyadi.ah@gmail.com)\n"
        )

    # run() yields a 5-tuple after each stage. The outputs below are listed
    # in the same order as the tuple's elements.
    submit_btn.click(
        fn=run,
        inputs=[input_textbox, style_choices],
        outputs=[
            src_amr_graph_output,
            triplets_output,
            style_words_output,
            tgt_amr_graph_output,
            result_output,
        ],
    )
# run() is a generator that streams partial results. Gradio 3.x only supports
# generator handlers when the queue is enabled. Gradio 4 enables the queue by
# default, so calling queue() explicitly is harmless there.
demo.queue().launch()