Spaces:
Sleeping
Sleeping
File size: 5,100 Bytes
d4efbb2 63e7eb5 437fb7f 63e7eb5 d4efbb2 c6dbfda 63e7eb5 ce0ea7d 35d002b 63e7eb5 d4efbb2 63e7eb5 d4efbb2 35d002b 1556875 63e7eb5 437fb7f d4efbb2 35d002b 1556875 c6dbfda ce0ea7d 1556875 35d002b ce0ea7d 437fb7f 17757ae ce0ea7d 17757ae c6dbfda ce0ea7d c6dbfda 35d002b 991fc8d ce0ea7d c6dbfda 1556875 ce0ea7d c6dbfda d367528 2d27f56 d367528 437fb7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import gdown
from git import Repo
import gradio as gr
from huggingface_hub import snapshot_download
import os
import penman
import sys
import time
import torch
from transformers import pipeline
# The TextToAMR / StyleDetector / StyleRewriting / AMRToTextWithTaufiqMethod
# classes come from the amr-tst-indo repository, which is not installed as a
# package: clone it on first start (an existing checkout is reused as-is) and
# put its root on the import path.
if not os.path.exists("amr-tst-indo"):
    Repo.clone_from(
        "https://github.com/AbdiHaryadi/amr-tst-indo.git",
        "amr-tst-indo",
    )
sys.path.append("./amr-tst-indo")

# These imports only resolve after the path change above. Their relative order
# is preserved in case any of the modules has import-time side effects.
from text_to_amr import TextToAMR  # noqa: E402
from style_detector import StyleDetector  # noqa: E402
from style_rewriting import StyleRewriting  # noqa: E402
from amr_to_text import AMRToTextWithTaufiqMethod  # noqa: E402
# Text -> AMR parser.
# The checkpoint is downloaded into AMRBART-id/models/<name> inside the cloned
# repository, while TextToAMR only receives the model name. Presumably it
# resolves the name relative to that folder (TODO confirm in text_to_amr.py).
# Files matching "*log*" or "*checkpoint*" are excluded from the download.
amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"
snapshot_download(
    repo_id="abdiharyadi/" + amr_parsing_model_name,
    local_dir="./amr-tst-indo/AMRBART-id/models/" + amr_parsing_model_name,
    ignore_patterns=["*log*", "*checkpoint*"],
)
t2a = TextToAMR(model_name=amr_parsing_model_name)
# Style detector: `run` uses it to extract triplets from the input and to
# derive the style words from them. Its config comes from the
# indonesian-aste-generative project inside the cloned repository. The
# trained weights are hosted on Google Drive.
#
# The weights are only fetched when they are not already on disk, which is the
# same policy the repository clone at the top of the file uses. Downloading
# unconditionally on every start is slow, and it can abort start-up outright
# when Google Drive rate-limits the file, even though a usable copy exists.
_style_detector_weights = "./model-best.pt"
if not os.path.exists(_style_detector_weights):
    gdown.download(
        "https://drive.google.com/uc?id=1J_6PbYsQ6Kl4Qfs1wBVwd52_r9uTpIxx",
        _style_detector_weights,
    )
sd = StyleDetector(
    config_path="./amr-tst-indo/indonesian-aste-generative/resources/exp-v2/exp-m0.yaml",
    model_path=_style_detector_weights,
)
# Text-classification pipeline that is handed to StyleRewriting below.
# It runs on the GPU when torch can see one, and on the CPU otherwise.
if torch.cuda.is_available():
    device_type = "cuda"
else:
    device_type = "cpu"
clf_pipeline = pipeline(
    task="text-classification",
    model="abdiharyadi/roberta-base-indonesian-522M-with-sa-william-dataset",
    device=device_type,
)
# Style rewriter. It combines the classifier pipeline above with a fastText
# skip-gram model that is hosted on Google Drive.
#
# The fastText file is large, so it is only fetched when it is not already on
# disk. This matches the check-before-fetch policy of the repository clone at
# the top of the file. An unconditional download repeats the transfer on every
# start, and a Google Drive rate-limit error would then abort start-up even
# though a complete copy is present.
_fasttext_model_path = "./fasttext_skipgram_indo.bin"
if not os.path.exists(_fasttext_model_path):
    gdown.download(
        "https://drive.google.com/uc?id=15KctCcsHgTFMUh_tWNBNUiCyX56fq6p-",
        _fasttext_model_path,
    )
sr = StyleRewriting(
    clf_pipeline=clf_pipeline,
    fasttext_model_path=_fasttext_model_path,
    # Strategy switches for the rewriting algorithm. Their meaning is defined
    # in style_rewriting.py.
    position_aware_concatenation=False,
    reset_sense_strategy=False,
    max_score_strategy=True,
    maximize_style_words_expansion=False,
)
# AMR -> text generator.
# Only files matching "*checkpoint-3*" are downloaded from the model
# repository, and the generator is loaded from that checkpoint-3 sub-folder.
# lowercase=True presumably matches the "uncased" model (TODO confirm in
# amr_to_text.py).
amr_gen_model_name = "taufiq-indo-amr-generation-gold-uncased"
model_path = "./" + amr_gen_model_name
snapshot_download(
    repo_id="abdiharyadi/" + amr_gen_model_name,
    local_dir=model_path,
    allow_patterns=["*checkpoint-3*"],
)
a2t = AMRToTextWithTaufiqMethod(
    model_path=os.path.join(model_path, "checkpoint-3"),
    lowercase=True,
)
def run(text, source_style):
    """Run the style-transfer pipeline on *text*, streaming progress.

    This is a generator. Gradio pushes every yielded tuple into the five
    output boxes, in this order:

    1. source AMR
    2. triplets
    3. style words
    4. target AMR
    5. final sentence

    The box currently being computed shows "(Memproses ...)" ("processing").
    Boxes still queued show "(Menunggu ...)" ("waiting"). Each finished box
    ends with the wall-clock seconds elapsed since the pipeline started. The
    time is cumulative, not per stage.

    Args:
        text: The input sentence.
        source_style: Style label of the input, as chosen in the UI radio
            ("LABEL_1" / "LABEL_0"). It is forwarded to the style rewriter.
    """
    busy = "(Memproses ...)"
    waiting = "(Menunggu ...)"
    boxes = [busy, waiting, waiting, waiting, waiting]
    yield tuple(boxes)

    start_time = time.time()

    def publish(index, value):
        # Append the cumulative elapsed time to a finished stage's text, store
        # it in its box and flag the next box (if any) as in progress.
        value += f"\n\n({time.time() - start_time:.2f} s)"
        boxes[index] = value
        if index + 1 < len(boxes):
            boxes[index + 1] = busy
        return tuple(boxes)

    # Stage 1: parse the sentence into an AMR graph. Metadata is cleared first
    # so that penman.encode emits only the graph itself.
    source_amr, *_ = t2a([text])
    source_amr.metadata = {}
    yield publish(0, penman.encode(source_amr))

    # Stage 2: extract triplets, one "(a, b, c)" line each.
    triplets = sd.get_triplets(text)
    yield publish(1, "\n".join(f"({t[0]}, {t[1]}, {t[2]})" for t in triplets))

    # Stage 3: derive the style words from those triplets.
    style_words = sd.get_style_words_from_triplets(triplets)
    yield publish(2, ", ".join(style_words))

    # Stage 4: rewrite the source AMR towards the opposite style.
    target_amr = sr(text, source_amr, source_style, style_words)
    yield publish(3, penman.encode(target_amr))

    # Stage 5: generate the output sentence from the rewritten AMR.
    result, *_ = a2t([target_amr])
    yield publish(4, result)
# (display label, value passed to `run` as source_style)
_SOURCE_STYLE_CHOICES = [
    ("Positif (Positive)", "LABEL_1"),
    ("Negatif (Negative)", "LABEL_0"),
]

# Acknowledgement panel shown next to the demo (Indonesian, user-facing text).
_ACKNOWLEDGEMENT_MD = """
# Pengakuan
Demo ini disiapkan untuk Program Penelitian dan Pengabdian Masyarakat STEI ITB 2024.
**Tim Peneliti**:
- Masayu Leylia Khodra (masayu@staff.stei.itb.ac.id)
- M. Abdi Haryadi. H (abdiharyadi.ah@gmail.com)
"""

# Two-column page: the streaming pipeline interface on the left and the
# acknowledgements on the right. The five output boxes match, in order, the
# 5-tuples yielded by `run`.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Interface(
                fn=run,
                inputs=[
                    gr.Textbox(label="Teks (Text)"),
                    gr.Radio(
                        label="Gaya sumber (Source style)",
                        choices=_SOURCE_STYLE_CHOICES,
                        value="LABEL_1",
                    ),
                ],
                outputs=[
                    gr.Textbox(label="Graf AMR sumber (Source AMR graph)"),
                    gr.Textbox(label="Triplet (Triplets)"),
                    gr.Textbox(label="Kata bergaya (Style words)"),
                    gr.Textbox(label="Graf AMR target (Target AMR graph)"),
                    gr.Textbox(label="Hasil (Result)"),
                ],
            )
        with gr.Column():
            gr.Markdown(_ACKNOWLEDGEMENT_MD)

demo.launch()
|