import gradio as gr
import torch
from transformers import (
    MT5ForConditionalGeneration,
    MT5TokenizerFast,
    pipeline,
)

MODEL_ID = "tacab/mt5-beero_somali"

# Load the fine-tuned mT5 checkpoint and its tokenizer once at startup.
tokenizer = MT5TokenizerFast.from_pretrained(MODEL_ID)
model = MT5ForConditionalGeneration.from_pretrained(MODEL_ID)

# pipeline() device convention: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

# Text2text generation pipeline with beam search. min_length=100 keeps
# answers from being cut short, but can make very simple questions ramble.
qa = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
    max_new_tokens=200,
    min_length=100,
    num_beams=4,
    length_penalty=0.7,
    no_repeat_ngram_size=3,
    early_stopping=False,
)
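
# For reference, a minimal sketch of roughly what the pipeline call does
# under the hood, assuming the same generation settings. This helper is
# illustrative only and is not used by the app.
def _answer_via_generate(question: str) -> str:
    # Tokenize and move inputs to wherever the pipeline placed the model.
    inputs = tokenizer(f"Su'aal: {question}", return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=200,
        min_length=100,
        num_beams=4,
        length_penalty=0.7,
        no_repeat_ngram_size=3,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)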

def answer(question: str) -> str:
    # Prefix the question with "Su'aal:" (Somali for "Question:"),
    # matching the prompt format this app feeds the model.
    prompt = f"Su'aal: {question}"
    out = qa(prompt)
    return out[0]["generated_text"]
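
# Example call (hypothetical question; uncomment to smoke-test without the UI):
# print(answer("Maxaan ku beeri karaa xilliga roobka?"))  # "What can I plant in the rainy season?"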

demo = gr.Interface(
    fn=answer,
    inputs=gr.Textbox(
        lines=2,
        placeholder="Gali su'aashaada...",  # "Enter your question..."
        label="Su'aal",  # "Question"
    ),
    outputs=gr.Textbox(label="Jawaab"),  # "Answer"
    title="Beero Somali Q&A",
    # "Question-and-answer module based on tacab/mt5-beero_somali"
    description="Su'aal–Jawaab module ku saleysan tacab/mt5-beero_somali",
)
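
# Optional: beam-search generation can be slow, and Gradio's request queue
# helps avoid client timeouts. Left disabled here since the original app
# does not use it.
# demo.queue()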

if __name__ == "__main__":
    demo.launch()