Wendyy commited on
Commit
624b44d
1 Parent(s): 2602495

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py CHANGED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
2
+ from peft import PeftModel
3
+ import torch
4
+ import gradio as gr
5
+ import os
6
+ import re
7
+
8
+ class ChineseCharacterStop(StoppingCriteria):
9
+ def __init__(self, chars: list[str]):
10
+ self.chars = [
11
+ tokenizer(i, add_special_tokens=False, return_tensors='pt').input_ids
12
+ for i in chars
13
+ ]
14
+ # for chars, tokens in zip(chars, self.chars):
15
+ # print(f"'{chars}':{tokens}")
16
+
17
+ def __call__(self, input_ids: torch.LongTensor,
18
+ scores: torch.FloatTensor, **kwargs) -> bool:
19
+ for c in self.chars:
20
+ c = c.to(input_ids.device)
21
+ match = torch.eq(input_ids[..., -c.shape[1]:], c)
22
+ if torch.any(torch.all(match, dim=1)):
23
+ return True
24
+ return False
25
+
26
+
27
+ tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Wenzhong-GPT2-110M")
28
+ tokenizer.pad_token = tokenizer.eos_token
29
+ gpt2_model = AutoModelForCausalLM.from_pretrained("IDEA-CCNL/Wenzhong-GPT2-110M")
30
+ model = PeftModel.from_pretrained(gpt2_model, 'checkpoint_lora_v4.1')
31
+
32
+
33
+ def cang_tou(tou: str):
34
+ poem_now = "写一首唐诗:"
35
+ for c in tou:
36
+ poem_now += c
37
+ print(poem_now)
38
+ inputs = tokenizer(poem_now, return_tensors='pt')
39
+ outputs = model.generate(
40
+ **inputs,
41
+ return_dict_in_generate=True,
42
+ max_length=150,
43
+ do_sample=True,
44
+ top_p=0.4,
45
+ num_beams=1,
46
+ num_return_sequences=1,
47
+ stopping_criteria=[ChineseCharacterStop(['。', ','])],
48
+ pad_token_id=tokenizer.pad_token_id
49
+ )
50
+ poem_now = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
51
+ print(poem_now)
52
+ return poem_now[6:]
53
+
54
+
55
+ def prompt_gen(prompt):
56
+ inputs = tokenizer(prompt, return_tensors='pt')
57
+ outputs = model.generate(
58
+ **inputs,
59
+ return_dict_in_generate=True,
60
+ max_length=200,
61
+ do_sample=True,
62
+ top_p=0.8,
63
+ num_beams=5,
64
+ num_return_sequences=3,
65
+ # stopping_criteria=[ChineseCharacterStop(['。', ',', ''])],
66
+ pad_token_id=tokenizer.pad_token_id
67
+ )
68
+ res = ''
69
+ for line in tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True):
70
+ line = line[len(prompt):]
71
+ res = res+line+'\n'
72
+ return res
73
+
74
+ css = """
75
+ #col-container {max-width: 510px; margin-left: auto; margin-right: auto;}
76
+ a {text-decoration-line: underline; font-weight: 600;}
77
+ .animate-spin {
78
+ animation: spin 1s linear infinite;
79
+ }
80
+ """
81
+
82
+ with gr.Blocks(css=css) as demo:
83
+ with gr.Column(elem_id="col-container"):
84
+ gr.Markdown(
85
+ """
86
+ <h1 style="text-align: center;">✨古诗生成</h1>
87
+ <p style="text-align: center;">
88
+ 根据输入的提示生成古诗、藏头诗<br />
89
+ </p>
90
+ """
91
+ )
92
+ with gr.Tab("提示"):
93
+ prompt_in = gr.Textbox(label="Prompt", placeholder="写一首关于思乡的古诗:", elem_id="prompt-in")
94
+ #neg_prompt = gr.Textbox(label="Negative prompt", value="text, watermark, copyright, blurry, nsfw", elem_id="neg-prompt-in")
95
+ #inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=40, interactive=False)
96
+ submit_btn = gr.Button("Submit")
97
+ poetry_result = gr.Textbox(label="Output", elem_id="poetry-output")
98
+
99
+ submit_btn.click(fn=prompt_gen,
100
+ inputs=[prompt_in],
101
+ outputs=[poetry_result])
102
+
103
+ with gr.Tab("藏头诗"):
104
+ tou_in = gr.Textbox(label="Prompt", placeholder="一见如故", elem_id="tou-in")
105
+ #neg_prompt = gr.Textbox(label="Negative prompt", value="text, watermark, copyright, blurry, nsfw", elem_id="neg-prompt-in")
106
+ #inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=40, interactive=False)
107
+ submit_btn = gr.Button("Submit")
108
+ cangtou_result = gr.Textbox(label="Output", elem_id="cangtou-output")
109
+ submit_btn.click(fn=cang_tou,
110
+ inputs=[tou_in],
111
+ outputs=[cangtou_result])
112
+
113
+
114
+
115
+ demo.queue(max_size=12).launch()