Spaces:
Running
Running
import gdown | |
from git import Repo | |
import gradio as gr | |
from huggingface_hub import snapshot_download | |
import os | |
import penman | |
import sys | |
import time | |
import torch | |
from transformers import pipeline | |
# Fetch the research code base on first start; later starts reuse whatever
# checkout is already on disk (no pull/update is attempted).
if not os.path.exists("amr-tst-indo"):
    Repo.clone_from(
        "https://github.com/AbdiHaryadi/amr-tst-indo.git",
        "amr-tst-indo",
    )

# Make the checkout's top-level modules importable (appended, so it is
# searched after the existing sys.path entries).
sys.path.append("./amr-tst-indo")

# Pipeline stages provided by the cloned repository.
from text_to_amr import TextToAMR  # noqa: E402
from style_detector import StyleDetector  # noqa: E402
from style_rewriting import StyleRewriting  # noqa: E402
from amr_to_text import AMRToTextWithTaufiqMethod  # noqa: E402
# Stage 1 (text -> AMR): download the parser from the Hub into the cloned
# repo's AMRBART-id/models/<name> folder, skipping files whose paths match
# *log* or *checkpoint*. TextToAMR is then given only the model name —
# presumably it resolves that folder itself (TODO confirm in text_to_amr).
amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"
snapshot_download(
    local_dir=f"./amr-tst-indo/AMRBART-id/models/{amr_parsing_model_name}",
    repo_id=f"abdiharyadi/{amr_parsing_model_name}",
    ignore_patterns=["*log*", "*checkpoint*"],
)
t2a = TextToAMR(model_name=amr_parsing_model_name)
# Stage 2 (style detection): the detector checkpoint is hosted on Google
# Drive. Download it only when it is missing locally — the same "fetch once"
# approach used for the repository checkout — so restarts do not re-fetch it.
# NOTE: an existing local copy is reused as-is; delete ./model-best.pt to
# force a fresh download if the Drive file changes.
if not os.path.exists("./model-best.pt"):
    # Older gdown releases may report a failed download by printing an error
    # and returning None instead of raising; stop here with a clear message
    # rather than failing later inside StyleDetector on a missing file.
    if gdown.download(
        "https://drive.google.com/uc?id=1J_6PbYsQ6Kl4Qfs1wBVwd52_r9uTpIxx",
        "./model-best.pt",
    ) is None:
        raise RuntimeError(
            "Could not download ./model-best.pt from Google Drive"
        )
sd = StyleDetector(
    config_path="./amr-tst-indo/indonesian-aste-generative/resources/exp-v2/exp-m0.yaml",
    model_path="./model-best.pt",
)
# Sentiment classifier used by the style-rewriting stage; placed on the GPU
# when torch reports CUDA as available, otherwise on the CPU.
device_type = "cpu"
if torch.cuda.is_available():
    device_type = "cuda"
clf_pipeline = pipeline(
    "text-classification",
    device=device_type,
    model="abdiharyadi/roberta-base-indonesian-522M-with-sa-william-dataset",
)
# Stage 3 (style rewriting): needs a fastText skip-gram model hosted on
# Google Drive. Download it only when it is missing locally — the same
# "fetch once" approach used for the repository checkout — so restarts do
# not re-fetch it. NOTE: an existing local copy is reused as-is; delete
# ./fasttext_skipgram_indo.bin to force a fresh download.
if not os.path.exists("./fasttext_skipgram_indo.bin"):
    # Older gdown releases may report a failed download by printing an error
    # and returning None instead of raising; stop here with a clear message
    # rather than failing later inside StyleRewriting on a missing file.
    if gdown.download(
        "https://drive.google.com/uc?id=15KctCcsHgTFMUh_tWNBNUiCyX56fq6p-",
        "./fasttext_skipgram_indo.bin",
    ) is None:
        raise RuntimeError(
            "Could not download ./fasttext_skipgram_indo.bin from Google Drive"
        )
sr = StyleRewriting(
    clf_pipeline=clf_pipeline,
    fasttext_model_path="./fasttext_skipgram_indo.bin",
    # Strategy switches passed through unchanged to StyleRewriting; their
    # exact meaning is defined in the cloned repository's style_rewriting.
    position_aware_concatenation=False,
    reset_sense_strategy=False,
    max_score_strategy=True,
    maximize_style_words_expansion=False,
)
# Stage 4 (AMR -> text): only repository files whose paths match
# *checkpoint-3* are downloaded, and the generator is loaded from the
# checkpoint-3 sub-folder. Output text is lowercased (lowercase=True).
amr_gen_model_name = "taufiq-indo-amr-generation-gold-uncased"
model_path = f"./{amr_gen_model_name}"
snapshot_download(
    local_dir=model_path,
    repo_id=f"abdiharyadi/{amr_gen_model_name}",
    allow_patterns=["*checkpoint-3*"],
)
a2t = AMRToTextWithTaufiqMethod(
    lowercase=True,
    model_path=os.path.join(model_path, "checkpoint-3"),
)
def run(text, source_style):
    """Run the style-transfer pipeline on *text*, streaming each stage.

    Used as a Gradio event handler: every ``yield`` produces a 5-tuple of
    strings for the output boxes, in this order: source AMR, triplets,
    style words, target AMR, result. The stage currently running shows
    "(Memproses ...)" ("processing"). Stages not yet reached show
    "(Menunggu ...)" ("waiting"). A finished stage shows its output followed
    by the seconds elapsed since the pipeline started. The elapsed time is
    cumulative, not a per-stage duration.

    Args:
        text: The input sentence.
        source_style: Value selected in the source-style radio (the UI
            offers "LABEL_1" for positive and "LABEL_0" for negative).
    """
    waiting = "(Menunggu ...)"
    processing = "(Memproses ...)"
    panels = [processing, waiting, waiting, waiting, waiting]
    yield tuple(panels)

    started = time.time()

    def with_elapsed(body):
        # Append the wall-clock time since `started`, e.g. "\n\n(1.23 s)".
        return body + f"\n\n({time.time() - started:.2f} s)"

    def finish(index, body):
        # Fill one panel, mark the next one (if any) as in progress, and
        # return a snapshot for Gradio.
        panels[index] = body
        if index + 1 < len(panels):
            panels[index + 1] = processing
        return tuple(panels)

    # Parse the input into an AMR graph; metadata is cleared before it is
    # encoded for display and before it is handed to the rewriter.
    parsed, *_ = t2a([text])
    parsed.metadata = {}
    yield finish(0, with_elapsed(penman.encode(parsed)))

    # Extract triplets from the raw text, one "(a, b, c)" per line.
    triplets = sd.get_triplets(text)
    triplet_lines = [f"({t[0]}, {t[1]}, {t[2]})" for t in triplets]
    yield finish(1, with_elapsed("\n".join(triplet_lines)))

    # Derive the style words from those triplets.
    words = sd.get_style_words_from_triplets(triplets)
    yield finish(2, with_elapsed(", ".join(words)))

    # Rewrite the source AMR toward the other style.
    target = sr(text, parsed, source_style, words)
    yield finish(3, with_elapsed(penman.encode(target)))

    # Generate the final sentence from the rewritten AMR.
    generated, *_ = a2t([target])
    yield finish(4, with_elapsed(generated))
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: the user's input sentence and its current style.
        with gr.Column():
            input_textbox = gr.Textbox(label="Teks (Text)")
            style_choices = gr.Radio(
                value="LABEL_1",
                label="Gaya sumber (Source style)",
                choices=[
                    ("Positif (Positive)", "LABEL_1"),
                    ("Negatif (Negative)", "LABEL_0"),
                ],
            )
            submit_btn = gr.Button("Submit")

        # Right column: one box per pipeline stage, filled in order by run().
        with gr.Column():
            with gr.Row():
                src_amr_graph_output = gr.Textbox(
                    min_width=320,
                    label="Graf AMR sumber (Source AMR graph)",
                )
                triplets_output = gr.Textbox(
                    min_width=320,
                    label="Triplet (Triplets)",
                )
            with gr.Row():
                style_words_output = gr.Textbox(
                    min_width=320,
                    label="Kata bergaya (Style words)",
                )
                tgt_amr_graph_output = gr.Textbox(
                    min_width=320,
                    label="Graf AMR target (Target AMR graph)",
                )
            result_output = gr.Textbox(label="Hasil (Result)")

    # Acknowledgement section.
    with gr.Column():
        gr.Markdown("""
# Pengakuan
Demo ini disiapkan untuk Program Penelitian dan Pengabdian Masyarakat STEI ITB 2024.
**Tim Peneliti**:
- Masayu Leylia Khodra (masayu@staff.stei.itb.ac.id)
- M. Abdi Haryadi. H (abdiharyadi.ah@gmail.com)
""")

    # run() yields 5-tuples; the output order below must match that order.
    submit_btn.click(
        fn=run,
        inputs=[input_textbox, style_choices],
        outputs=[
            src_amr_graph_output,
            triplets_output,
            style_words_output,
            tgt_amr_graph_output,
            result_output,
        ],
    )

demo.launch()