"""Gradio demo: analyse a poem with Ashaar and generate a continuation.

Run with ``--lang ar`` (the default) for the Arabic interface. Any other
value selects the English interface.
"""
import os

# Kept ahead of the remaining imports, exactly as in the original ordering,
# so the variable is already in the environment when they are loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import gradio as gr
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
from Ashaar.utils import get_output_df, get_highlighted_patterns_html
from Ashaar.bait_analysis import BaitAnalysis
from langs import *
import sys
import json
import argparse

arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--lang', type=str, default='ar')
args = arg_parser.parse_args()
lang = args.lang

# Single source of truth for the interface language. Every language-dependent
# branch below tests this flag, so an unexpected --lang value consistently
# falls back to the English interface (texts, button states and wiring).
is_arabic = lang == 'ar'

if is_arabic:
    TITLE = TITLE_ar
    DESCRIPTION = DESCRIPTION_ar
    textbox_trg_text = textbox_trg_text_ar
    textbox_inp_text = textbox_inp_text_ar
    btn_trg_text = btn_trg_text_ar
    btn_inp_text = btn_inp_text_ar
    css = """ #textbox{ direction: RTL;}"""
else:
    TITLE = TITLE_en
    DESCRIPTION = DESCRIPTION_en
    textbox_trg_text = textbox_trg_text_en
    textbox_inp_text = textbox_inp_text_en
    btn_trg_text = btn_trg_text_en
    btn_inp_text = btn_inp_text_en
    css = ""

# Upper bound on the number of line breaks emitted by generate().
MAX_GENERATED_LINES = 10

gpt_tokenizer = AutoTokenizer.from_pretrained('arbml/ashaar_tokenizer')
model = AutoModelForCausalLM.from_pretrained('arbml/Ashaar_model')

# Load the name -> token tables and build the reverse lookups. The files are
# opened with context managers so the handles are closed right after parsing.
with open("extra/theme_tokens.json", "r") as theme_file:
    theme_to_token = json.load(theme_file)
token_to_theme = {t: m for m, t in theme_to_token.items()}

with open("extra/meter_tokens.json", "r") as meter_file:
    meter_to_token = json.load(meter_file)
token_to_meter = {t: m for m, t in meter_to_token.items()}

analysis = BaitAnalysis()

# Result of the most recent analyze() call; generate() reads these to build
# its prompt.
meter, theme, qafiyah = "", "", ""


def analyze(poem):
    """Analyse *poem* and remember its meter, qafiyah and theme.

    The text is split on newlines and consecutive lines are paired with
    ``' # '`` before being handed to ``BaitAnalysis.analyze``. A trailing
    unpaired line is ignored.

    Side effect: updates the module-level ``meter``, ``qafiyah`` and
    ``theme`` that ``generate`` uses.

    Returns:
        A tuple of the highlighted-patterns HTML and a button update that
        makes the target button interactive.
    """
    global meter, theme, qafiyah
    shatrs = poem.split("\n")
    # zip() stops at the shorter slice, so an odd final line is dropped.
    baits = [' # '.join(pair) for pair in zip(shatrs[::2], shatrs[1::2])]
    output = analysis.analyze(baits, override_tashkeel=True)
    meter = output['meter']
    qafiyah = output['qafiyah'][0]
    theme = output['theme'][-1]
    df = get_output_df(output)
    return get_highlighted_patterns_html(df), gr.Button.update(interactive=True)


def generate(inputs, top_p=3):
    """Generate a continuation of the poem in *inputs*.

    The prompt is built from the meter, qafiyah and theme stored by the last
    ``analyze`` call, followed by the input lines taken in pairs.

    Args:
        inputs: Poem text, one line per row. A trailing unpaired line is
            ignored.
        top_p: Forwarded to ``model.generate``. Previously this argument was
            ignored and the literal 3 was always used; the default keeps
            that behaviour.
            NOTE(review): nucleus sampling conventionally expects a value in
            (0, 1] - confirm that 3 is intended.

    Returns:
        A tuple of the generated text and a button update that disables the
        target button.

    Raises:
        KeyError: if the stored meter or theme has no entry in the token
            tables (for example when ``analyze`` has not been run yet).
    """
    baits = inputs.split('\n')
    # Pairing even- and odd-indexed lines drops a trailing unpaired line.
    poem = ' '.join(
        f'<|bsep|> {first} <|vsep|> {second} '
        for first, second in zip(baits[::2], baits[1::2])
    )
    prompt = (
        f"{meter_to_token[meter]} {qafiyah} {theme_to_token[theme]}"
        f" <|psep|> {poem}"
    ).strip()
    print(prompt)

    encoded_input = gpt_tokenizer(prompt, return_tensors='pt')
    output = model.generate(
        **encoded_input, max_length=512, top_p=top_p, do_sample=True
    )

    result = ""
    line_cnts = 0
    prompt_length = len(encoded_input.input_ids[0])
    # Only the tokens after the prompt are decoded.
    for beam in output[:, prompt_length:]:
        if line_cnts >= MAX_GENERATED_LINES:
            break
        for token in beam:
            if line_cnts >= MAX_GENERATED_LINES:
                break
            decoded = gpt_tokenizer.decode(token)
            # A meter/theme token ends this sequence.
            if 'meter' in decoded or 'theme' in decoded:
                break
            if decoded in ["<|vsep|>", ""]:
                result += "\n"
                line_cnts += 1
            # NOTE(review): the '' entry below can never match because the
            # branch above already handles ''; kept unchanged - confirm
            # which token was intended.
            elif decoded in ['<|bsep|>', '<|psep|>', '']:
                pass
            else:
                result += decoded
        else:
            # The sequence was consumed without an early stop: finish here.
            break
    return result, gr.Button.update(interactive=False)


# NOTE(review): each example is a single line here, while analyze() and
# generate() split their input on newlines - confirm whether a line break
# between the two halves was intended.
examples = [
    [
        """القلب أعلم يا عذول بدائه وأحق منك بجفنه وبمائه"""
    ],
    [
        """رمتِ الفؤادَ مليحة عذراءُ بسهامِ لحظٍ ما لهنَّ دواءُ"""
    ],
    [
        """أذَلَّ الحِرْصُ والطَّمَعُ الرِّقابَا وقَد يَعفو الكَريمُ، إذا استَرَابَا"""
    ],
]

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(TITLE)
            gr.HTML(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            textbox_output = gr.Textbox(
                lines=10, label=textbox_trg_text, elem_id="textbox"
            )
        with gr.Column():
            inputs = gr.Textbox(
                lines=10, label=textbox_inp_text, elem_id="textbox"
            )

    # Whichever button triggers generate() starts disabled; analyze()
    # enables it once it has run.
    with gr.Row():
        with gr.Column():
            if is_arabic:
                trg_btn = gr.Button(btn_trg_text, interactive=False)
            else:
                trg_btn = gr.Button(btn_trg_text)
        with gr.Column():
            if is_arabic:
                inp_btn = gr.Button(btn_inp_text)
            else:
                inp_btn = gr.Button(btn_inp_text, interactive=False)

    with gr.Row():
        html_output = gr.HTML()

    if is_arabic:
        # Arabic layout: the right textbox is the input, the left the output.
        gr.Examples(examples, inputs)
        trg_btn.click(
            generate, inputs=inputs, outputs=[textbox_output, trg_btn]
        )
        inp_btn.click(analyze, inputs=inputs, outputs=[html_output, trg_btn])
    else:
        # English layout: the roles of the two columns are mirrored.
        gr.Examples(examples, textbox_output)
        inp_btn.click(
            generate, inputs=textbox_output, outputs=[inputs, inp_btn]
        )
        trg_btn.click(
            analyze, inputs=textbox_output, outputs=[html_output, inp_btn]
        )

demo.launch()