Spaces:
Sleeping
Sleeping
import gdown | |
from git import Repo | |
import gradio as gr | |
from huggingface_hub import snapshot_download | |
import os | |
import penman | |
import sys | |
# Fetch the project source (text-to-AMR parser and style detector) on first
# start-up. The clone is skipped when the directory is already present, so a
# restart that keeps the working directory does not clone again.
_AMR_TST_DIR = "amr-tst-indo"
_AMR_TST_URL = "https://github.com/AbdiHaryadi/amr-tst-indo.git"

repo_already_cloned = os.path.exists(_AMR_TST_DIR)
if not repo_already_cloned:
    Repo.clone_from(_AMR_TST_URL, _AMR_TST_DIR)

# Make the cloned repository importable. The two modules below live inside
# it, so they can only be imported after this path entry has been added.
sys.path.append("./" + _AMR_TST_DIR)

from text_to_amr import TextToAMR  # noqa: E402  (needs the sys.path entry above)
from style_detector import StyleDetector  # noqa: E402
# Download the Indonesian AMR-parsing model from the Hugging Face Hub into
# ./amr-tst-indo/AMRBART-id/models/<model name>. This is presumably where
# TextToAMR resolves `model_name` from (confirm in the cloned repo). Log
# files and intermediate checkpoints are excluded from the download.
amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"

_amr_parsing_repo_id = "abdiharyadi/" + amr_parsing_model_name
_amr_parsing_local_dir = "./amr-tst-indo/AMRBART-id/models/" + amr_parsing_model_name
_amr_parsing_skip = ["*log*", "*checkpoint*"]

snapshot_download(
    repo_id=_amr_parsing_repo_id,
    local_dir=_amr_parsing_local_dir,
    ignore_patterns=_amr_parsing_skip,
)

# Text -> AMR parser built on the model downloaded above.
t2a = TextToAMR(model_name=amr_parsing_model_name)
# Fetch the trained style-detector checkpoint from Google Drive, then build
# the detector from the experiment config shipped in the cloned repository.
#
# The download is skipped when the checkpoint is already on disk. This is the
# same guard used for the repository clone. Without it, every restart
# re-downloads the file and start-up depends on Google Drive being reachable,
# even though a usable copy already exists locally.
style_detector_model_path = "./model-best.pt"
if not os.path.exists(style_detector_model_path):
    gdown.download(
        "https://drive.google.com/uc?id=1J_6PbYsQ6Kl4Qfs1wBVwd52_r9uTpIxx",
        style_detector_model_path,
    )

sd = StyleDetector(
    config_path="./amr-tst-indo/indonesian-aste-generative/resources/exp-v2/exp-m0.yaml",
    model_path=style_detector_model_path,
)
def run(text, source_style):
    """Run the (partially stubbed) style-transfer pipeline, streaming progress.

    This function is a generator. The Gradio interface below maps each
    yielded 5-tuple onto its five output boxes. The string "..." marks a
    stage that has not produced output yet.

    Args:
        text: Input sentence entered by the user.
        source_style: Value from the source-style radio button
            ("LABEL_1" for positive, "LABEL_0" for negative).

    Yields:
        Tuples of (source AMR, triplets, style words, target AMR, result),
        all as display strings.
    """
    # TODO: AMR parsing is stubbed out. The intended implementation is:
    #   source_amr, *_ = t2a([text])
    #   source_amr.metadata = {}
    #   source_amr_display = penman.encode(source_amr)
    source_amr_display = "(z0 / halo)"
    yield source_amr_display, "...", "...", "...", "..."

    # The style detector reports the extracted triplets through a callback.
    # The callback must be a plain function with no `yield`. A `yield` would
    # turn it into a generator function, whose body never executes when `sd`
    # calls it, and the triplets would silently be lost.
    triplets_display = "..."

    def triplets_callback(triplets: list):
        nonlocal triplets_display
        triplets_display = "\n".join(f"({x[0]}, {x[1]}, {x[2]})" for x in triplets)

    style_words = sd(text, triplets_callback=triplets_callback)
    style_words_display = ", ".join(style_words)
    yield source_amr_display, triplets_display, style_words_display, "...", "..."

    # Placeholder target AMR. It is round-tripped through penman so the
    # display uses penman's canonical formatting.
    target_amr = penman.decode("(z0 / dunia)")
    target_amr_display = penman.encode(target_amr)
    yield source_amr_display, triplets_display, style_words_display, target_amr_display, "..."

    # Placeholder surface realisation: echo the inputs so the UI shows what
    # was received.
    result = f"dunia ({text=}, {source_style=})"
    yield source_amr_display, triplets_display, style_words_display, target_amr_display, result
# Web UI with two inputs and five outputs:
# - Inputs: the sentence, and the source style. The radio passes "LABEL_1" or
#   "LABEL_0" to `run` as `source_style`.
# - Outputs: one text box per value in the tuples streamed by `run`, in the
#   same order.
# All labels are bilingual, Indonesian first, English in parentheses.
_style_choices = [
    ("Positif (Positive)", "LABEL_1"),
    ("Negatif (Negative)", "LABEL_0"),
]
_input_components = [
    gr.Textbox(label="Teks (Text)"),
    gr.Radio(label="Gaya sumber (Source style)", choices=_style_choices, value="LABEL_1"),
]

_output_labels = [
    "Graf AMR sumber (Source AMR graph)",
    "Triplet (Triplets)",
    "Kata bergaya (Style words)",
    "Graf AMR target (Target AMR graph)",
    "Hasil (Result)",
]
_output_components = [gr.Textbox(label=caption) for caption in _output_labels]

demo = gr.Interface(fn=run, inputs=_input_components, outputs=_output_components)
demo.launch()