|
import torch |
|
import gradio as gr |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
|
|
tokenizer = T5Tokenizer.from_pretrained("VietAI/vit5-base") |
|
model = T5ForConditionalGeneration.from_pretrained("Libosa2707/vietnamese-poem-t5") |
|
|
|
|
|
def generate_poem(input_text): |
|
|
|
input_text = input_text.strip() |
|
input_text = input_text.lower() |
|
|
|
|
|
min_length = 50 |
|
max_length = 512 |
|
rep_penalty = 1.2 |
|
temp = 0.7 |
|
top_k = 50 |
|
top_p = 0.92 |
|
no_repeat_ngram_size = 2 |
|
|
|
|
|
input_ids = tokenizer( |
|
input_text, |
|
return_tensors="pt", |
|
padding="max_length", |
|
truncation=True, |
|
max_length=42, |
|
).input_ids.to(model.device) |
|
|
|
|
|
model.eval() |
|
with torch.no_grad(): |
|
output = model.generate( |
|
do_sample=True, |
|
input_ids=input_ids, |
|
min_length=min_length, |
|
max_length=max_length, |
|
top_p=top_p, |
|
top_k=top_k, |
|
temperature=temp, |
|
repetition_penalty=rep_penalty, |
|
no_repeat_ngram_size=no_repeat_ngram_size, |
|
num_return_sequences=1, |
|
) |
|
|
|
|
|
gen = tokenizer.decode( |
|
output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False |
|
) |
|
sentences = gen.split("<unk>") |
|
gen_poem = "\n".join(sentences).replace("<pad>", "").replace("</s>", "") |
|
gen_poem = gen_poem.strip() |
|
|
|
|
|
pretty_text = "" |
|
for line in gen_poem.split("\n"): |
|
line = line.strip() |
|
if not line: |
|
continue |
|
line = line[0].upper() + line[1:] |
|
pretty_text += line + "\n" |
|
|
|
|
|
return pretty_text |
|
|
|
|
|
generate_poem_interface = gr.Interface( |
|
title="Làm thơ theo yêu cầu", |
|
fn=generate_poem, |
|
inputs=[ |
|
gr.components.Textbox( |
|
lines=1, |
|
placeholder="Làm thơ với thể thơ tám chữ và tiêu đề mùa xuân nho nhỏ", |
|
label="Yêu cầu về thể thơ và tiêu đề", |
|
), |
|
], |
|
outputs="text", |
|
) |
|
|