U3reb / app.py
moThecarpenter80's picture
Update app.py
429ace2 verified
import streamlit as st
from transformers import MarianTokenizer, MarianMTModel , BertTokenizer, AutoModelForSeq2SeqLM, pipeline
from ar_corrector.corrector import Corrector
import mishkal.tashkeel
from arabert.preprocess import ArabertPreprocessor
# Hugging Face Hub identifiers of the pretrained checkpoints used by the app.
mname = "marefa-nlp/marefa-mt-en-ar"
model_name = "malmarjeh/mbert2mbert-arabic-text-summarization"


# Streamlit re-executes this script from top to bottom on every widget
# interaction.  Loading is wrapped in ``st.cache_resource`` so each heavy object
# is built once per server process and reused on later reruns instead of being
# re-created every time the text area changes.
# NOTE(review): cached resources are shared between sessions -- confirm the
# Mishkal vocalizer and the corrector are safe to share that way.
@st.cache_resource
def _load_text_tools():
    """Build the Mishkal vocalizer, the AraBERT preprocessor and the corrector."""
    return (
        mishkal.tashkeel.TashkeelClass(),
        ArabertPreprocessor(model_name=""),
        Corrector(),
    )


@st.cache_resource
def _load_translation(checkpoint):
    """Load the Marian tokenizer and translation model for *checkpoint*."""
    return (
        MarianTokenizer.from_pretrained(checkpoint),
        MarianMTModel.from_pretrained(checkpoint),
    )


@st.cache_resource
def _load_summarization(checkpoint):
    """Load the summarization tokenizer and model and wrap them in a pipeline."""
    summ_tokenizer = BertTokenizer.from_pretrained(checkpoint)
    summ_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    summ_pipeline = pipeline(
        "text2text-generation", model=summ_model, tokenizer=summ_tokenizer
    )
    return summ_tokenizer, summ_model, summ_pipeline


# Bind the same module-level names the rest of the file relies on.
vocalizer, preprocessor, corr = _load_text_tools()
tokenizer, model = _load_translation(mname)
(
    tokenizer_summarization,
    model_summarization,
    pipeline_summarization,
) = _load_summarization(model_name)
def main():
    """Render the U3reb demo page.

    Reads text from a text area and, when it contains non-whitespace
    characters, shows four results under their own sub-headers: the Mishkal
    output, the Marian translation, the ``ar_corrector`` contextual
    correction, and the summarization pipeline output.  The sub-headers are
    always rendered; the results appear only once text has been entered.
    """
    st.title("U3reb Demo")

    # Text Input
    input_text = st.text_area("Enter Arabic Text:")
    # Treat whitespace-only input like empty input so the models are not run
    # on blank text.
    has_text = bool(input_text and input_text.strip())

    # Tokenization
    st.subheader("Tokenization (Mishkal)")
    if has_text:
        text_mishkal = vocalizer.tashkeel(input_text)
        st.write("Tokenized Text (with diacritics):", text_mishkal)

    # Translation
    # NOTE(review): the checkpoint id in ``mname`` ("...-en-ar") looks like an
    # English->Arabic model while the prompt asks for Arabic text -- confirm
    # the intended translation direction.
    st.subheader("Translation")
    if has_text:
        # Call the tokenizer directly instead of the deprecated
        # ``prepare_seq2seq_batch``; padding/truncation are passed explicitly
        # to keep the defaults that helper applied.
        batch = tokenizer(
            [input_text],
            return_tensors="pt",
            padding="longest",
            truncation=True,
        )
        translated_tokens = model.generate(**batch)
        translated_text = [
            tokenizer.decode(t, skip_special_tokens=True)
            for t in translated_tokens
        ]
        st.write("Translated Text:", translated_text)

    # Arabic Text Correction
    st.subheader("Arabic Text Correction (ar_correct)")
    if has_text:
        corrected_text = corr.contextual_correct(input_text)
        st.write("Corrected Text:", corrected_text)

    # Text Summarization
    st.subheader("Text Summarization")
    if has_text:
        preprocessed_text = preprocessor.preprocess(input_text)
        outputs = pipeline_summarization(
            preprocessed_text,
            pad_token_id=tokenizer_summarization.eos_token_id,
            num_beams=3,
            repetition_penalty=3.0,
            max_length=200,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
        )
        # The pipeline returns one result dict per input; a single text was
        # passed, so take the first.
        result = outputs[0]['generated_text']
        st.write("Summarized Text:", result)
# Script entry point: build the page only when this file is run as the main
# module, not when it is imported.
if __name__ == "__main__":
    main()