|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from itertools import chain |
|
import gradio as gr |
|
import torch |
|
|
|
# Run on GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Chinese GPT-2 trained on CLUECorpusSmall; used below for story generation.
MODEL_NAME = "uer/gpt2-chinese-cluecorpussmall"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
|
|
|
def generate_text(prompt, length=500):
    """Generate a Chinese story continuation for *prompt*.

    Args:
        prompt: Seed text the story should start from.
        length: Maximum total length, in tokens, of the generated sequence
            (passed to ``model.generate`` as ``max_length``).

    Returns:
        The decoded story text with the tokenizer's special tokens removed.
    """
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(device)

    output_ids = model.generate(
        inputs["input_ids"],
        max_length=length,
        num_beams=2,
        no_repeat_ngram_size=2,
        early_stopping=True,
        pad_token_id=0,
    )[0]

    # Fix: strip special tokens at decode time via skip_special_tokens=True.
    # The previous post-processing deleted the individual characters
    # '[', ']', 'S', 'E', 'P', 'U', 'N', 'K' from the decoded string —
    # intended to remove markers like [SEP]/[UNK]/[PAD], but it also
    # corrupted any legitimate occurrence of those letters in the story.
    txt = tokenizer.decode(output_ids, skip_special_tokens=True)

    # NOTE(review): this BERT-style tokenizer decodes with spaces between
    # tokens; they were present in the original output too, so they are
    # deliberately left untouched here.
    return txt
|
|
|
# Assemble the Gradio UI: a header, a prompt box, a submit button, and an
# output box wired to generate_text.
with gr.Blocks() as web:
    # Page header and illustration (HTML rendered through Markdown).
    gr.Markdown("<h1><center>Andrew Lim Chinese stories </center></h1>")
    gr.Markdown("""<h2><center>让人工智能讲故事:<br><br>
    <img src=https://images.unsplash.com/photo-1550450339-e7a4787a2074?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1252&q=80></center></h2>""")
    gr.Markdown("""<center>******</center>""")

    # Input prompt, trigger button, and generated-story display.
    prompt_box = gr.Textbox(label="故事的开始", lines=6)
    submit_btn = gr.Button("Submit ")
    story_box = gr.Textbox(lines=6, label="人工智能讲一个故事 :")

    # Clicking the button feeds the prompt to the model and shows the story.
    submit_btn.click(fn=generate_text, inputs=[prompt_box], outputs=story_box)

web.launch()