AraPoet / app.py
Badr AlKhamissi
updated defaults
c476a27
# coding=utf8
import json
import torch
import gradio as gr
import pyarabic.araby as araby
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
feature_names = [
"Title",
"Meter",
"Theme",
"Name",
"Era",
"Country",
"Type"
]
with open("./poet_names.json", 'r', encoding="utf-8") as fin:
poet_names = json.load(fin)
def normalize_text(text):
text = araby.strip_tatweel(text)
return text
def generate_poem(country, era, meter, theme, lang_type, poet, num_lines, num_poems, title):
num_poems = int(num_poems)
prompt = title
prompt = normalize_text(prompt)
features = [prompt, meter, theme, poet, era, country, lang_type]
prompt = ""
for name, feat in zip(feature_names, features):
prompt += f"{name}: {feat}; "
prompt += f"Length: {num_lines}; Poem:"
num_beams = 5
top_k = 50
top_p = 0.9
r_penalty = 5.
input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
print(f"> Running: {prompt} | {num_poems} Poems")
outputs = model.generate(input_ids=input_ids,
min_length=32,
max_length=256,
do_sample=True,
top_k=top_k,
top_p=top_p,
repetition_penalty=r_penalty,
num_beams=num_beams,
num_return_sequences=num_poems,
early_stopping=True
)
poems = []
print(f"> # of Outputs: {len(outputs)}")
for output in outputs:
raw = tokenizer.decode(output)
raw = raw.replace("<pad>", "").replace("</s>", "")
print("="*100)
print(raw)
print("="*100)
poems += ['\n'.join(raw.split("<s>"))]
return "\n\n".join(poems)
meters = ['البسيط', 'التفعيله', 'الحداء', 'الخفيف', 'الدوبيت', 'الرجز', 'الرمل', 'السريع', 'السلسلة', 'الصخري', 'الطويل', 'الكامل', 'الكان كان', 'اللويحاني', 'المتدارك', 'المتقارب', 'المجتث', 'المديد', 'المسحوب', 'المضارع', 'المقتضب', 'المنسرح', 'المواليا', 'الموشح', 'الهجيني', 'الهزج', 'الوافر', 'بحر أحذ الكامل', 'بحر أحذ المديد', 'بحر أحذ الوافر', 'بحر البسيط', 'بحر التفعيله', 'بحر الخبب', 'بحر الخفيف', 'بحر الدوبيت', 'بحر الرجز', 'بحر الرمل', 'بحر السريع', 'بحر السلسلة', 'بحر الطويل', 'بحر القوما', 'بحر الكامل', 'بحر الكامل المقطوع', 'بحر المتدارك', 'بحر المتدارك المنهوك', 'بحر المتقارب', 'بحر المجتث', 'بحر المديد', 'بحر المضارع', 'بحر المقتضب', 'بحر المنسرح', 'بحر المواليا', 'بحر الهزج', 'بحر الوافر', 'بحر تفعيلة الرجز', 'بحر تفعيلة الرمل', 'بحر تفعيلة الكامل', 'بحر تفعيلة المتقارب', 'بحر مجزوء البسيط', 'بحر مجزوء الخفيف', 'بحر مجزوء الدوبيت', 'بحر مجزوء الرجز', 'بحر مجزوء الرمل', 'بحر مجزوء الرمل ', 'بحر مجزوء السريع', 'بحر مجزوء الطويل', 'بحر مجزوء الكامل', 'بحر مجزوء المتدارك', 'بحر مجزوء المتقارب', 'بحر مجزوء المجتث', 'بحر مجزوء المديد', 'بحر مجزوء المنسرح', 'بحر مجزوء المواليا', 'بحر مجزوء الهزج', 'بحر مجزوء الوافر', 'بحر مجزوء موشح', 'بحر مخلع البسيط', 'بحر مخلع الرجز', 'بحر مخلع الرمل', 'بحر مخلع السريع', 'بحر مخلع الكامل', 'بحر مخلع موشح', 'بحر مربع البسيط', 'بحر مربع الرجز', 'بحر مشطور الرجز', 'بحر مشطور السريع', 'بحر مشطور الطويل', 'بحر منهوك البسيط', 'بحر منهوك الرجز', 'بحر منهوك الكامل', 'بحر منهوك المنسرح', 'بحر موشح', 'بسيط', 'زجل', 'شعر التفعيلة', 'شعر حر', 'عامي', 'عدة أبحر', 'عموديه', 'مجزوء الخفيف', 'نثريه', 'None']
themes = ['قصيدة اعتذار', 'قصيدة الاناشيد', 'قصيدة المعلقات', 'قصيدة حزينه', 'قصيدة دينية', 'قصيدة ذم', 'قصيدة رثاء', 'قصيدة رومنسيه', 'قصيدة سياسية', 'قصيدة شوق', 'قصيدة عامه', 'قصيدة عتاب', 'قصيدة غزل', 'قصيدة فراق', 'قصيدة قصيره', 'قصيدة مدح', 'قصيدة هجاء', 'قصيدة وطنيه', 'None']
language_types = ['شعبي', 'عامي', 'فصحى', 'فصيح', '-', 'None']
poet_era = ['العصر الأموي', 'العصر الأندلسي', 'العصر الأيوبي', 'العصر الإسلامي', 'العصر الجاهلي', 'العصر الحديث', 'العصر العباسي', 'العصر العثماني', 'العصر الفاطمي', 'العصر المملوكي', 'المخضرمين', 'المغرب والأندلس', 'عصر بين الدولتين', 'قبل الإسلام', 'None']
countries = ['الأردن', 'الإمارات', 'البحرين', 'الجزائر', 'السعودية', 'السنغال', 'السودان', 'الصومال', 'العراق', 'الكويت', 'المغرب', 'اليمن', 'تونس', 'سوريا', 'سورية', 'عمان', 'فلسطين', 'قطر', 'لبنان', 'ليبيا', 'مصر', 'موريتانيا', 'None']
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("bkhmsi/arapoet-mt5", use_auth_token="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")
model: AutoModelForSeq2SeqLM = AutoModelForSeq2SeqLM.from_pretrained("bkhmsi/arapoet-mt5", use_auth_token="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")
model.eval()
title = ""
with gr.Blocks(title=title) as demo:
inputs = []
gr.Markdown(
"""
# AraPoet: Controlled Arabic Poetry Generation
The model hosted here is a finetuned version of [mT5-large](https://huggingface.co/google/mt5-large) (∼ 1.2B parameters) on the largest repository of Arabic poems, the [ashaar](https://huggingface.co/datasets/arbml/ashaar) dataset.
The model can be conditioned on a set of attributes to control the style of the generated poem.
Namely: the poet name, country, era, meter, theme, language type, title and the length of the poem.
You can start by clicking on one of the examples below or try your own input.
"""
)
with gr.Row():
inputs += [gr.Dropdown(countries, label="Country", value="مصر")]
inputs += [gr.Dropdown(poet_era, label="Era", value="العصر الحديث")]
with gr.Row():
inputs += [gr.Dropdown(meters, label="Meter", value="بحر السريع")]
inputs += [gr.Dropdown(themes, label="Theme", value="قصيدة رومنسيه")]
with gr.Row():
inputs += [gr.Dropdown(language_types, label="Language Type", value="فصحى")]
inputs += [gr.Dropdown(poet_names, label="Poet", value="أحمد شوقي")]
with gr.Row():
inputs += [gr.Slider(2, 20, value=6, step=1, label="Number of Lines")]
inputs += [gr.Slider(1, 4, value=1, step=1, label="Number of Samples")]
with gr.Row():
inputs += [gr.Textbox(label="Title", value="إثن عنان القلب واسلم به")]
btn = gr.Button("Generate")
examples = gr.Examples(examples="./examples", inputs=inputs)
btn.click(generate_poem, inputs, gr.TextArea(label="Generation"))
gr.Markdown(
"""
Checkout our [AraPoet Preprint](https://github.com/BKHMSI/BKHMSI.github.io/blob/master/archive/resources/AraPoet.pdf) for more details about the model.
"""
)
demo.launch()