# kancilgpt / app.py
# Author: abdiharyadi — "feat: remove max length limit" (commit 5a389e5)
import gradio as gr
from transformers import GPT2LMHeadModel
from indobenchmark import IndoNLGTokenizer

# Tokenizer for the IndoGPT family of models (indobenchmark).
gpt_tokenizer = IndoNLGTokenizer.from_pretrained("indobenchmark/indogpt")
# The tokenizer ships without a pad token; reuse EOS so generate() can pad.
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
# GPT-2 checkpoint this app samples Indonesian stories from.
kancilgpt = GPT2LMHeadModel.from_pretrained("abdiharyadi/kancilgpt")
def generate_story():
    """Incrementally sample an Indonesian short story and stream previews.

    Yields:
        str: first a "..." placeholder, then "<title>..." while the title is
        being sampled, then "<title>\\n----\\n<body>..." snapshots as the body
        grows, and finally the full story terminated with "tamat." (the end).
    """
    stop = False
    # Prompt format the model was trained on (fields separated by " | "):
    #   "<s> awal cerita | judul: <title> | <body> | <ending marker>"
    # "awal cerita" = start of story, "judul" = title; the ending marker is
    # expected to start with "bersambung" (to be continued) or "tamat" (end).
    prompt = "<s> awal cerita | judul:"
    judul = ""      # title sampled so far
    isi = ""        # story body sampled so far
    end_part = ""   # ending marker text ("bersambung ..." / "tamat ...")
    isi_not_checked = True  # body has not yet passed the quote-balance check
    yield "..."
    while not stop:
        prompt_stop = False
        while not prompt_stop:
            gpt_input = gpt_tokenizer(prompt, return_tensors='pt')
            # Sample only 2 new tokens per call so every iteration can yield
            # an updated preview (streaming-style generation).
            gpt_out = kancilgpt.generate(
                **gpt_input,
                do_sample=True,
                max_new_tokens=2,
                pad_token_id=gpt_tokenizer.eos_token_id,
                eos_token_id=gpt_tokenizer.eos_token_id
            )
            gpt_out = gpt_out[0]
            result = gpt_tokenizer.decode(gpt_out)
            splitted_result = result.split(" | ")
            if len(splitted_result) <= 2:
                # Still generating the title: "<s> awal cerita | judul: ..."
                _, judul_prompt = splitted_result
                _, *judul_words = judul_prompt.split()  # drop the "judul:" tag
                judul = " ".join(judul_words)
                yield judul + "..."
                if "." in judul:
                    # A period inside the title is treated as the model
                    # running past the title field; restart from scratch.
                    print("Invalid judul!")
                    prompt = "<s> awal cerita | judul:"
                    continue
                isi = ""
                end_part = ""
                if gpt_out[-1] == gpt_tokenizer.eos_token_id:
                    # EOS before any body was produced; resample with the
                    # same prompt.
                    continue
            else:
                # Title is complete. Parts are [header, "judul: <title>",
                # body, *ending]; any extra trailing parts are folded into
                # the ending marker.
                _, judul_prompt, isi, *end_part = splitted_result
                end_part = "".join(end_part)
                _, *judul_words = judul_prompt.split()
                judul = " ".join(judul_words)
                yield judul + "\n" + ("-" * len(judul)) + "\n" + isi + f"..."
                if len(splitted_result) == 3:
                    # Body is still being generated (no ending marker yet).
                    if gpt_out[-1] == gpt_tokenizer.eos_token_id:
                        # EOS while the ending marker is missing; resample.
                        continue
                    elif isi_not_checked:
                        # One-time sanity check: an odd number of double
                        # quotes means an unterminated quotation in the body.
                        quote_count = 0
                        prev_i = 0  # index of the last quote seen
                        for i, c in enumerate(isi):
                            if c == "\"":
                                quote_count += 1
                                prev_i = i
                        if quote_count % 2 != 0:
                            # Trim back to just before the dangling quote and
                            # regenerate the body from that point.
                            print("Invalid isi!")
                            trimmed_isi = isi[:prev_i].rstrip()
                            prompt = f"<s> awal cerita | judul: {judul} | {trimmed_isi}"
                            continue
                        isi_not_checked = False
            if gpt_out[-1] == gpt_tokenizer.eos_token_id:
                # Model signalled end-of-sequence: stop extending this prompt.
                prompt_stop = True
            else:
                # Otherwise keep growing the text from the decoded output.
                prompt = result
        # prompt_stop
        if (not any(end_part.startswith(x) for x in ["bersambung", "tamat"])):
            # The sampled ending must be "bersambung" (to be continued) or
            # "tamat" (the end); otherwise regenerate just the ending, keeping
            # the accepted title and body.
            print("Invalid ending! Regenerating ....")
            prompt = f"<s> awal cerita | judul: {judul} | {isi} |"
            continue
        stop = True
    total_isi = isi
    # NOTE(review): continuation of "bersambung" stories is disabled below
    # (see TODO); the final yield always ends with "tamat." regardless of the
    # sampled ending marker.
    print("We skip the rest of the part for debug.")
    # TODO: Solve this.
    # ellipsis = "..."
    # while not end_part.startswith("tamat"):
    #     yield judul + "\n" + ("-" * len(judul)) + "\n" + total_isi + f" {ellipsis}"
    #     ellipsis += "."
    #     i = 0
    #     in_quote = False
    #     end_sentence = False
    #     limit = 1750
    #     while i < len(isi) and not (end_sentence and (not in_quote) and isi[i] == " " and (len(isi) - i) < limit):
    #         if isi[i] == "\"":
    #             in_quote = not in_quote
    #         if end_sentence:
    #             end_sentence = isi[i] not in "abcdefghijklmnopqrstuvwxyz"
    #         else:
    #             end_sentence = isi[i] in ".?!"
    #         i += 1
    #     # i == len(isi) or end_sentence or (not in_quote) or isi[i] == " "
    #     while i < len(isi) and not (isi[i] in "abcdefghijklmnopqrstuvwxyz\""):
    #         i += 1
    #     # i == len(isi) or isi[i] in "abcdefghijklmnopqrstuvwxyz\""
    #     if i == len(isi):
    #         raise ValueError("What???")
    #     next_isi = isi[i:]
    #     stop = False
    #     while not stop:
    #         gpt_input = gpt_tokenizer(f'<s> pertengahan cerita | judul: {judul} | {next_isi}', return_tensors='pt')
    #         gpt_out = kancilgpt.generate(**gpt_input, do_sample=True, max_length=512, pad_token_id=gpt_tokenizer.eos_token_id)
    #         result = gpt_tokenizer.decode(gpt_out[0])
    #         _, judul_prompt, isi, *end_part = result.split(" | ")
    #         end_part = "".join(end_part)
    #         _, *judul_words = judul_prompt.split()
    #         judul = " ".join(judul_words)
    #         if isi[len(next_isi) + 1:].strip() != "":
    #             print(isi[len(next_isi) + 1:])
    #         if "</s>" in isi or "|" in isi or (not any(end_part.startswith(x) for x in ["bersambung", "tamat"])):
    #             print("Invalid output! Regenerating ....")
    #             continue
    #         quote_count = 0
    #         for c in isi:
    #             if c == "\"":
    #                 quote_count += 1
    #         if quote_count % 2 != 0:
    #             print("Invalid output! Regenerating ....")
    #             continue
    #         stop = True
    #     total_isi += " " + isi[len(next_isi) + 1:]
    #     ellipsis = "..."
    yield judul + "\n" + ("-" * len(judul)) + "\n" + total_isi + "\n\ntamat."
# Wire the story generator into a minimal Gradio app: no user inputs, one
# multi-line textbox that streams partial results as the generator yields.
story_box = gr.Textbox(label="cerita", lines=7)
demo = gr.Interface(fn=generate_story, inputs=None, outputs=[story_box])
demo.launch()