abdiharyadi's picture
feat: integrate StyleDetector, disable TextToAMR for fast test
d4efbb2
raw
history blame
2.96 kB
import gdown
from git import Repo
import gradio as gr
from huggingface_hub import snapshot_download
import os
import penman
import sys
# The pipeline code is not an installed package: it is used from a local
# checkout of the amr-tst-indo repository. Clone it only when no checkout
# (or anything else with that name) exists in the working directory yet.
if not os.path.exists("amr-tst-indo"):
    Repo.clone_from(
        url="https://github.com/AbdiHaryadi/amr-tst-indo.git",
        to_path="amr-tst-indo",
    )

# Put the checkout on the module search path so the imports that follow
# can resolve modules located at its top level.
sys.path.append("./amr-tst-indo")
from text_to_amr import TextToAMR
from style_detector import StyleDetector
# Name of the AMR parsing checkpoint. The same string is used as the Hub
# repository name, the local folder name, and the TextToAMR model name.
amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"

# Fetch the checkpoint from the Hugging Face Hub into the cloned repository's
# model folder, skipping files that match the log / checkpoint patterns.
snapshot_download(
    repo_id="abdiharyadi/" + amr_parsing_model_name,
    local_dir="./amr-tst-indo/AMRBART-id/models/" + amr_parsing_model_name,
    ignore_patterns=["*log*", "*checkpoint*"],
)

# Text-to-AMR parser built on the checkpoint downloaded above.
t2a = TextToAMR(model_name=amr_parsing_model_name)
# Local path of the StyleDetector weights fetched from Google Drive.
style_detector_model_path = "./model-best.pt"

# Download the weights only when they are not on disk yet, mirroring the
# existence check used for the repository clone; without it every start of
# the app re-downloads the same file.
if not os.path.exists(style_detector_model_path):
    gdown.download(
        "https://drive.google.com/uc?id=1J_6PbYsQ6Kl4Qfs1wBVwd52_r9uTpIxx",
        style_detector_model_path,
    )

# Style detector configured from the YAML file shipped in the cloned
# repository and the weights located at the path above.
sd = StyleDetector(
    config_path="./amr-tst-indo/indonesian-aste-generative/resources/exp-v2/exp-m0.yaml",
    model_path=style_detector_model_path,
)
def run(text, source_style):
    """Run the demo pipeline on *text*, streaming partial results.

    This is a generator function. Every yielded value is a 5-tuple of
    strings in the order (source AMR graph, triplets, style words, target
    AMR graph, result). Stages that have not produced a value yet are shown
    as ``"..."``.

    Only the style-detection stage is live. The source AMR graph, the target
    AMR graph and the final result are fixed placeholders.

    Args:
        text: Input text handed to the style detector.
        source_style: Source style label; currently only echoed in the
            placeholder result string.

    Yields:
        The 5-tuple described above, once after each stage.
    """
    pending = "..."

    # Text-to-AMR parsing is disabled for fast testing. To re-enable it,
    # replace the placeholder below with:
    # source_amr, *_ = t2a([text])
    # source_amr.metadata = {}
    # source_amr_display = penman.encode(source_amr)
    source_amr_display = "(z0 / halo)"
    yield source_amr_display, pending, pending, pending, pending

    triplets_display = pending

    def triplets_callback(triplets: list) -> None:
        # This must stay a plain function. With a ``yield`` in its body it
        # would become a generator function, and calling it would only
        # create a generator object without running the body, so the
        # triplets would never be recorded.
        nonlocal triplets_display
        triplets_display = "\n".join(
            f"({x[0]}, {x[1]}, {x[2]})" for x in triplets
        )

    # The detector is handed the callback so it can report the triplets it
    # extracts (presumably while it runs; verify against StyleDetector).
    style_words = sd(text, triplets_callback=triplets_callback)
    style_words_display = ", ".join(style_words)
    yield source_amr_display, triplets_display, style_words_display, pending, pending

    # Placeholder target graph, round-tripped through penman.
    target_amr = penman.decode("(z0 / dunia)")
    target_amr_display = penman.encode(target_amr)
    yield source_amr_display, triplets_display, style_words_display, target_amr_display, pending

    # Placeholder result that echoes the inputs.
    result = f"dunia ({text=}, {source_style=})"
    yield source_amr_display, triplets_display, style_words_display, target_amr_display, result
# Labels of the five read-only result boxes, in the same order as the
# tuples yielded by ``run``.
output_labels = [
    "Graf AMR sumber (Source AMR graph)",
    "Triplet (Triplets)",
    "Kata bergaya (Style words)",
    "Graf AMR target (Target AMR graph)",
    "Hasil (Result)",
]

# Gradio UI: a free-text box and a source-style selector feed ``run``, whose
# streamed tuples fill the five output textboxes.
demo = gr.Interface(
    fn=run,
    inputs=[
        gr.Textbox(label="Teks (Text)"),
        gr.Radio(
            label="Gaya sumber (Source style)",
            # Each choice is a (displayed name, value passed to ``run``) pair.
            choices=[
                ("Positif (Positive)", "LABEL_1"),
                ("Negatif (Negative)", "LABEL_0"),
            ],
            value="LABEL_1",
        ),
    ],
    outputs=[gr.Textbox(label=label) for label in output_labels],
)

demo.launch()