"""Gradio demo application for the amr-tst-indo pipeline.

On import this script clones the ``amr-tst-indo`` repository (if it is not
already present), downloads the AMR parsing model and the style detector
checkpoint, builds the pipeline objects, and launches a Gradio interface.
"""

import os
import sys

import gdown
from git import Repo
import gradio as gr
from huggingface_hub import snapshot_download
import penman

# The pipeline code lives in a separate repository; clone it once and make it
# importable before importing anything from it.
if not os.path.exists("amr-tst-indo"):
    Repo.clone_from("https://github.com/AbdiHaryadi/amr-tst-indo.git", "amr-tst-indo")

sys.path.append("./amr-tst-indo")

# These imports must stay below the sys.path tweak above.
from text_to_amr import TextToAMR
from style_detector import StyleDetector

amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"
snapshot_download(
    repo_id=f"abdiharyadi/{amr_parsing_model_name}",
    local_dir=f"./amr-tst-indo/AMRBART-id/models/{amr_parsing_model_name}",
    ignore_patterns=[
        "*log*",
        "*checkpoint*",
    ]
)

t2a = TextToAMR(model_name=amr_parsing_model_name)

gdown.download(
    "https://drive.google.com/uc?id=1J_6PbYsQ6Kl4Qfs1wBVwd52_r9uTpIxx",
    "./model-best.pt"
)

sd = StyleDetector(
    config_path="./amr-tst-indo/indonesian-aste-generative/resources/exp-v2/exp-m0.yaml",
    model_path="./model-best.pt"
)


def run(text, source_style):
    """Run the demo pipeline on ``text`` and stream progress to the UI.

    This is a generator: each yielded 5-tuple fills the five output textboxes
    of ``demo`` in order (source AMR, triplets, style words, target AMR,
    result). Stages that have not produced a value yet are shown as ``"..."``.

    Args:
        text: Input text entered by the user.
        source_style: Value of the selected source-style radio choice
            (``"LABEL_1"`` or ``"LABEL_0"``).

    Yields:
        Tuples of five display strings, one per output textbox.
    """
    # NOTE: the real text-to-AMR step is currently disabled; a fixed
    # placeholder graph is shown instead.
    # source_amr, *_ = t2a([text])
    # source_amr.metadata = {}
    # source_amr_display = penman.encode(source_amr)
    source_amr_display = "(z0 / halo)"
    yield source_amr_display, "...", "...", "...", "..."

    # Filled in by the callback below; stays "..." if it is never invoked.
    triplets_display = "..."

    def triplets_callback(triplets: list):
        """Record the triplets as one ``(a, b, c)`` line per triplet.

        This must be a plain function: with a ``yield`` in its body it would
        become a generator function, and a normal call would only create a
        generator object without ever executing the body, so the triplets
        would never be recorded.
        """
        nonlocal triplets_display
        triplets_display = "\n".join(f"({x[0]}, {x[1]}, {x[2]})" for x in triplets)

    # Presumably StyleDetector calls the callback with the extracted triplets
    # while computing the style words -- confirm against its implementation.
    style_words = sd(text, triplets_callback=triplets_callback)

    # Stub values that can replace the detector call during UI development:
    # style_words = ["bagus", "bersih"]
    # triplets = [
    #     ("kamar", "sangat bagus", "positif"),
    #     ("kamar", "bersih", "positif")
    # ]
    # triplets_display = "\n".join(f"({x[0]}, {x[1]}, {x[2]})" for x in triplets)

    style_words_display = ", ".join(style_words)
    yield source_amr_display, triplets_display, style_words_display, "...", "..."

    # NOTE: the target AMR and the final result are placeholders for now.
    target_amr = penman.decode("(z0 / dunia)")
    target_amr_display = penman.encode(target_amr)
    yield source_amr_display, triplets_display, style_words_display, target_amr_display, "..."

    result = f"dunia ({text=}, {source_style=})"
    yield source_amr_display, triplets_display, style_words_display, target_amr_display, result


demo = gr.Interface(
    fn=run,
    inputs=[
        gr.Textbox(label="Teks (Text)"),
        gr.Radio(label="Gaya sumber (Source style)", choices=[
            ("Positif (Positive)", "LABEL_1"),
            ("Negatif (Negative)", "LABEL_0"),
        ], value="LABEL_1"),
    ],
    outputs=[
        gr.Textbox(label="Graf AMR sumber (Source AMR graph)"),
        gr.Textbox(label="Triplet (Triplets)"),
        gr.Textbox(label="Kata bergaya (Style words)"),
        gr.Textbox(label="Graf AMR target (Target AMR graph)"),
        gr.Textbox(label="Hasil (Result)"),
    ]
)

demo.launch()