Spaces:
Sleeping
Sleeping
import os | |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | |
import gradio as gr | |
from transformers import pipeline | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from Ashaar.utils import get_output_df, get_highlighted_patterns_html | |
from Ashaar.bait_analysis import BaitAnalysis | |
from langs import * | |
import sys | |
import json | |
import argparse | |
# Command-line interface: ``--lang`` picks the UI language (Arabic by default).
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--lang', type=str, default='ar')
args = arg_parser.parse_args()
lang = args.lang

# Every value other than 'ar' selects the English strings. Conditional
# expressions are used so that only the chosen language's names are evaluated.
_is_arabic = lang == 'ar'
TITLE = TITLE_ar if _is_arabic else TITLE_en
DESCRIPTION = DESCRIPTION_ar if _is_arabic else DESCRIPTION_en
textbox_trg_text = textbox_trg_text_ar if _is_arabic else textbox_trg_text_en
textbox_inp_text = textbox_inp_text_ar if _is_arabic else textbox_inp_text_en
btn_trg_text = btn_trg_text_ar if _is_arabic else btn_trg_text_en
btn_inp_text = btn_inp_text_ar if _is_arabic else btn_inp_text_en
# The Arabic UI renders its text boxes right-to-left; English needs no CSS.
css = """ #textbox{ direction: RTL;}""" if _is_arabic else ""
# Tokenizer and causal language model used by generate().
gpt_tokenizer = AutoTokenizer.from_pretrained('arbml/ashaar_tokenizer')
model = AutoModelForCausalLM.from_pretrained('arbml/Ashaar_model')

# Name -> control-token tables, plus their inverses. The files are opened in
# a ``with`` block so the handles are closed, and decoded explicitly as UTF-8
# so the result does not depend on the platform's default encoding.
with open("extra/theme_tokens.json", "r", encoding="utf-8") as _theme_file:
    theme_to_token = json.load(_theme_file)
token_to_theme = {t: m for m, t in theme_to_token.items()}

with open("extra/meter_tokens.json", "r", encoding="utf-8") as _meter_file:
    meter_to_token = json.load(_meter_file)
token_to_meter = {t: m for m, t in meter_to_token.items()}

analysis = BaitAnalysis()

# Shared state: written by analyze(), read by generate() when it builds the
# prompt.
meter, theme, qafiyah = "", "", ""
def analyze(poem):
    """Analyze a poem and remember its meter, qafiyah and theme.

    The results are stored in the module-level ``meter``, ``qafiyah`` and
    ``theme`` variables, which ``generate`` reads when it builds its prompt.

    Args:
        poem: Poem text with one shatr per line. Each consecutive pair of
            non-blank lines is joined into one bait; a trailing unpaired
            line is ignored.

    Returns:
        A tuple of the HTML produced by ``get_highlighted_patterns_html``
        and a button update that makes the target button interactive.
    """
    global meter, theme, qafiyah
    # Drop blank / whitespace-only lines first: a stray empty line would
    # otherwise shift the pairing of every shatr that follows it.
    shatrs = [line.strip() for line in poem.splitlines() if line.strip()]
    baits = [' # '.join(pair) for pair in zip(shatrs[::2], shatrs[1::2])]
    output = analysis.analyze(baits, override_tashkeel=True)
    meter = output['meter']
    qafiyah = output['qafiyah'][0]
    theme = output['theme'][-1]
    df = get_output_df(output)
    return get_highlighted_patterns_html(df), gr.Button.update(interactive=True)
def generate(inputs, top_p = 3):
    """Generate a continuation of the given verses with the language model.

    The prompt header is built from the module-level ``meter``, ``qafiyah``
    and ``theme`` values that ``analyze`` stores, so ``analyze`` is expected
    to have run first; with the initial empty values the token lookups will
    most likely fail.

    Args:
        inputs: Seed verses, one shatr per line. Blank lines are ignored and
            a trailing unpaired line is dropped.
        top_p: Forwarded to ``model.generate`` as its ``top_p`` argument.
            NOTE(review): nucleus sampling is normally configured with a
            value in (0, 1]; the default of 3 is kept only for backward
            compatibility - confirm the intended value.

    Returns:
        A tuple of the generated text (at most ten newline-terminated
        lines) and a button update that disables the triggering button.
    """
    max_lines = 10
    line_break_tokens = ("<|vsep|>", "</|bsep|>")
    ignored_tokens = ('<|bsep|>', '<|psep|>', '</|psep|>')

    # Blank lines would otherwise end up as empty verses inside the prompt.
    shatrs = [line.strip() for line in inputs.splitlines() if line.strip()]
    if len(shatrs) % 2 != 0:
        shatrs = shatrs[:-1]
    poem = ' '.join(
        '<|bsep|> ' + first + ' <|vsep|> ' + second + ' </|bsep|>'
        for first, second in zip(shatrs[::2], shatrs[1::2])
    )

    # Layout: header line, section separator, then the seed verses.
    prompt = (
        f"{meter_to_token[meter]} {qafiyah} {theme_to_token[theme]}\n"
        "<|psep|>\n"
        f"{poem}"
    ).strip()
    print(prompt)

    encoded_input = gpt_tokenizer(prompt, return_tensors='pt')
    output = model.generate(**encoded_input, max_length=512, top_p=top_p, do_sample=True)

    # Decode only the newly generated tokens, one at a time, turning verse
    # separators into line breaks.
    prompt_length = len(encoded_input.input_ids[0])
    pieces = []
    line_cnts = 0
    for sequence in output[:, prompt_length:]:
        if line_cnts >= max_lines:
            break
        for token in sequence:
            if line_cnts >= max_lines:
                break
            decoded = gpt_tokenizer.decode(token)
            # Presumably a meter/theme control token, i.e. the model has
            # started a new header - stop reading this sequence.
            if 'meter' in decoded or 'theme' in decoded:
                break
            if decoded in line_break_tokens:
                pieces.append("\n")
                line_cnts += 1
            elif decoded not in ignored_tokens:
                pieces.append(decoded)
        else:
            # The sequence was consumed without hitting a stop condition.
            break
    return "".join(pieces), gr.Button.update(interactive=False)
# Sample inputs offered through gr.Examples: each inner list holds one
# two-line string (two shatrs separated by a newline).
# NOTE(review): the verse text below appears mis-encoded in this copy
# (it looks like Arabic UTF-8 decoded with the wrong codec) - restore the
# strings from the upstream file before relying on them.
examples = [ | |
[ | |
"""ุงูููุจ ุฃุนูู ูุง ุนุฐูู ุจุฏุงุฆู | |
ูุฃุญู ู ูู ุจุฌููู ูุจู ุงุฆู""" | |
], | |
[ | |
"""ุฑู ุชู ุงููุคุงุฏู ู ููุญุฉ ุนุฐุฑุงุกู | |
ุจุณูุงู ู ูุญุธู ู ุง ููููู ุฏูุงุกู""" | |
], | |
[ | |
"""ุฃุฐูููู ุงูุญูุฑูุตู ูุงูุทููู ูุนู ุงูุฑูููุงุจูุง | |
ูููุฏ ููุนูู ุงูููุฑูู ูุ ุฅุฐุง ุงุณุชูุฑูุงุจูุง""" | |
] | |
] | |
# Build the page. Components are created in the same order as before, so the
# rendered layout is unchanged.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    # Header: title and description.
    with gr.Row(), gr.Column():
        gr.HTML(TITLE)
        gr.HTML(DESCRIPTION)

    # The two text areas, side by side.
    with gr.Row():
        with gr.Column():
            textbox_output = gr.Textbox(lines=10, label=textbox_trg_text, elem_id="textbox")
        with gr.Column():
            inputs = gr.Textbox(lines=10, label=textbox_inp_text, elem_id="textbox")

    # One button under each text area. Exactly one of them is created with
    # interactive=False, depending on whether the Arabic UI is selected; the
    # other is created without the argument.
    _arabic_ui = lang == 'ar'
    _trg_btn_options = {'interactive': False} if _arabic_ui else {}
    _inp_btn_options = {} if _arabic_ui else {'interactive': False}
    with gr.Row():
        with gr.Column():
            trg_btn = gr.Button(btn_trg_text, **_trg_btn_options)
        with gr.Column():
            inp_btn = gr.Button(btn_inp_text, **_inp_btn_options)

    with gr.Row():
        html_output = gr.HTML()

    # The 'en' layout swaps the roles of the two boxes and the two buttons;
    # every other language value uses the default wiring.
    if lang == 'en':
        _source_box, _target_box = textbox_output, inputs
        _generate_btn, _analyze_btn = inp_btn, trg_btn
    else:
        _source_box, _target_box = inputs, textbox_output
        _generate_btn, _analyze_btn = trg_btn, inp_btn

    gr.Examples(examples, _source_box)
    # generate() fills the other box and updates its own button; analyze()
    # fills the HTML panel and updates the generate button.
    _generate_btn.click(generate, inputs=_source_box, outputs=[_target_box, _generate_btn])
    _analyze_btn.click(analyze, inputs=_source_box, outputs=[html_output, _generate_btn])

# Alternative for a publicly shared server:
#   demo.launch(server_name='0.0.0.0', share=True)
demo.launch()