Spaces:
Runtime error
Runtime error
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import gradio as gr | |
class PoetryGenerator:
    """Generate Vietnamese poems with a causal language-model checkpoint.

    The prompt sent to the model is ``bos_token + genre + sep_token +
    start_words``. The decoded output is cut at the tokenizer's SEP token and
    split on its EOS token to obtain individual verses.
    """

    # Poem forms accepted by ``generate``; the chosen string is inserted
    # verbatim into the prompt.
    GENRES = (
        'bốn chữ',
        'năm chữ',
        'sáu chữ',
        'bảy chữ',
        'tám chữ',
        'lục bát',
        'song thất lục bát',
    )

    def __init__(
        self,
        model_name_or_path: str = './checkpoint',
        max_length: int = 70,
    ):
        """Load the tokenizer and model.

        Args:
            model_name_or_path: Path or identifier passed to
                ``from_pretrained`` for both the tokenizer and the model.
            max_length: Value forwarded as ``max_length`` to
                ``model.generate`` (in Transformers this bounds the total
                sequence length, prompt included).
        """
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

    def generate(
        self,
        start_words: str,
        genre: str,
        n_poems: int = 1,
        collate: bool = False,
        *,
        n_verses: int = 4,
    ):
        """Generate ``n_poems`` poems starting with ``start_words``.

        Args:
            start_words: Opening words appended to the prompt after SEP.
            genre: One of :attr:`GENRES`.
            n_poems: Number of poems. Each comes from an independent
                ``model.generate`` call. Non-int numbers (e.g. a float from a
                UI slider) are truncated with ``int()``.
            collate: If True, return one string in which every poem is
                prefixed by a ``BÀI <n>`` header and poems are separated by a
                blank line.
            n_verses: Maximum number of EOS-separated verses kept per poem.

        Returns:
            A list of poem strings (verses joined by ``'\\n'``), or a single
            string when ``collate`` is True.

        Raises:
            ValueError: If ``genre`` is not one of :attr:`GENRES`.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if genre not in self.GENRES:
            raise ValueError(f"Expect genre in {self.GENRES}. Got {genre}.")
        # UI sliders may deliver floats, which ``range`` rejects.
        n_poems = int(n_poems)

        tokenized = self.tokenizer(
            self.tokenizer.bos_token
            + genre
            + self.tokenizer.sep_token
            + start_words,
            return_tensors='pt',
        )
        generated = [self._sample_once(tokenized) for _ in range(n_poems)]
        poems = [self._extract_poem(ids, n_verses) for ids in generated]

        if not collate:
            return poems
        # A single text output box can only show one string, so number the
        # poems and join them.
        return '\n\n'.join(
            f'BÀI {number}\n{poem}'
            for number, poem in enumerate(poems, start=1)
        )

    def _sample_once(self, tokenized):
        """Run one sampled beam-search generation; return its single sequence."""
        return self.model.generate(
            **tokenized,
            do_sample=True,
            max_length=self.max_length,
            top_k=4,
            num_beams=5,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
        )[0]

    def _extract_poem(self, token_ids, n_verses):
        """Decode ``token_ids`` and keep the first ``n_verses`` verses as text."""
        decoded = self.tokenizer.decode(token_ids)
        # Text between the first and second SEP token is the poem body.
        body = decoded.split(self.tokenizer.sep_token)[1]
        verses = body.split(self.tokenizer.eos_token)[:n_verses]
        return '\n'.join(verses)
if __name__ == '__main__':
    generator = PoetryGenerator()
    MAX_POEMS = 5

    def _generate_for_ui(start_words, genre, n_poems):
        """Gradio callback returning all poems collated into one string.

        The Slider may deliver its value as a float, so it is coerced to int
        before being used as a poem count.
        """
        return generator.generate(start_words, genre, int(n_poems), collate=True)

    gr.Interface(
        _generate_for_ui,
        inputs=[
            gr.Textbox(label="Start words"),
            gr.Dropdown(choices=PoetryGenerator.GENRES, label="Genre"),
            gr.Slider(1, MAX_POEMS, step=1, label="Number of poems"),
        ],
        outputs='text',
        examples=[
            ['thân em', 'lục bát', 2],
            ['chiều chiều', 'bảy chữ', 1],
        ],
    ).launch()