File size: 2,959 Bytes
d4efbb2
63e7eb5
437fb7f
63e7eb5
d4efbb2
c6dbfda
63e7eb5
 
d4efbb2
 
63e7eb5
 
 
d4efbb2
63e7eb5
 
 
 
 
 
 
 
 
 
 
437fb7f
d4efbb2
 
 
 
 
 
 
 
 
c6dbfda
d4efbb2
 
 
 
c6dbfda
437fb7f
d4efbb2
 
 
 
 
c6dbfda
d4efbb2
 
 
 
 
 
 
 
c6dbfda
 
 
991fc8d
 
c6dbfda
 
 
 
 
 
 
 
 
 
cb5c6f6
 
c446280
c6dbfda
 
cb5c6f6
c6dbfda
 
cb5c6f6
c6dbfda
 
 
437fb7f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gdown
from git import Repo
import gradio as gr
from huggingface_hub import snapshot_download
import os
import penman
import sys

# One-time bootstrap: clone the project repository only when the
# "amr-tst-indo" directory is not already present in the working directory.
if not os.path.exists("amr-tst-indo"):
    Repo.clone_from("https://github.com/AbdiHaryadi/amr-tst-indo.git", "amr-tst-indo")
# Put the checkout on the module search path so the imports below resolve.
sys.path.append("./amr-tst-indo")

# These imports must stay after the sys.path change above.
# NOTE(review): presumably both modules live in the cloned repository -- confirm.
from text_to_amr import TextToAMR
from style_detector import StyleDetector

# Download the AMR parsing model snapshot from the Hugging Face Hub into the
# checkout's AMRBART-id/models/<model name> directory, skipping any files that
# match the log/checkpoint patterns.
amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"
snapshot_download(
    repo_id=f"abdiharyadi/{amr_parsing_model_name}",
    local_dir=f"./amr-tst-indo/AMRBART-id/models/{amr_parsing_model_name}",
    ignore_patterns=[
        "*log*",
        "*checkpoint*",
    ]
)
# NOTE(review): in the visible source, t2a is referenced only from
# commented-out code inside run().
t2a = TextToAMR(model_name=amr_parsing_model_name)

# Fetch the style detector weights to ./model-best.pt. This call is made
# unconditionally on every start (there is no existence check here).
gdown.download(
    "https://drive.google.com/uc?id=1J_6PbYsQ6Kl4Qfs1wBVwd52_r9uTpIxx",
    "./model-best.pt"
)
# Style detector used by run(); built from a config file inside the checkout
# and the weights downloaded just above.
sd = StyleDetector(
    config_path="./amr-tst-indo/indonesian-aste-generative/resources/exp-v2/exp-m0.yaml",
    model_path="./model-best.pt"
)

def run(text, source_style):
    """Stream the intermediate outputs of the style-transfer demo.

    This is a generator. Every yielded value is a 5-tuple of strings holding,
    in order: the source AMR display, the triplets display, the style words
    display, the target AMR display and the final result. Fields whose stage
    has not been reached yet are shown as "...".

    Args:
        text: Input text; passed to the style detector ``sd``.
        source_style: Selected source style value. It is currently only
            echoed in the placeholder result string.

    Yields:
        tuple: The current value of all five output fields.
    """
    # Placeholder source graph; the real parsing call is kept for reference.
    # source_amr, *_ = t2a([text])
    # source_amr.metadata = {}
    # source_amr_display = penman.encode(source_amr)
    source_amr_display = "(z0 / halo)"
    yield source_amr_display, "...", "...", "...", "..."

    triplets_display = "..."

    def triplets_callback(triplets: list) -> None:
        # This must be a plain function. With a `yield` in its body it would
        # be a generator function: calling it would only create a generator
        # object, the body would never run, and the triplets would never be
        # recorded. A nested `yield` cannot emit a value for `run` anyway.
        # NOTE(review): assumes `sd` ignores the callback's return value --
        # confirm against StyleDetector.
        nonlocal triplets_display
        triplets_display = "\n".join(
            f"({x[0]}, {x[1]}, {x[2]})" for x in triplets
        )

    style_words = sd(text, triplets_callback=triplets_callback)
    style_words_display = ", ".join(style_words)
    yield source_amr_display, triplets_display, style_words_display, "...", "..."

    # Placeholder target graph, round-tripped through penman.
    target_amr = penman.decode("(z0 / dunia)")
    target_amr_display = penman.encode(target_amr)
    yield source_amr_display, triplets_display, style_words_display, target_amr_display, "..."

    # Placeholder result that echoes the inputs.
    result = f"dunia ({text=}, {source_style=})"
    yield source_amr_display, triplets_display, style_words_display, target_amr_display, result

# Gradio UI: a text box and a source-style radio feed `run`; the five text
# boxes below are the interface outputs, in this order.
_source_style_choices = [
    ("Positif (Positive)", "LABEL_1"),
    ("Negatif (Negative)", "LABEL_0"),
]
_input_components = [
    gr.Textbox(label="Teks (Text)"),
    gr.Radio(
        label="Gaya sumber (Source style)",
        choices=_source_style_choices,
        value="LABEL_1",
    ),
]
_output_labels = (
    "Graf AMR sumber (Source AMR graph)",
    "Triplet (Triplets)",
    "Kata bergaya (Style words)",
    "Graf AMR target (Target AMR graph)",
    "Hasil (Result)",
)
_output_components = [gr.Textbox(label=caption) for caption in _output_labels]

demo = gr.Interface(
    fn=run,
    inputs=_input_components,
    outputs=_output_components,
)
demo.launch()