p1atdev committed
Commit 617aba5
1 parent: f9f5197

feat: create demo app

Files changed (5)
  1. .gitmodules +3 -0
  2. README.md +3 -3
  3. RetNet +1 -0
  4. app.py +157 -0
  5. requirements.txt +1 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
+[submodule "RetNet"]
+	path = RetNet
+	url = https://github.com/syncdoth/RetNet.git
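The submodule supplies the RetNet modeling code that app.py imports from `RetNet.retnet.modeling_retnet`, so the checkout must exist before the demo starts. As a minimal, hypothetical sketch (not part of this commit), a guard like the following could fail fast with a clear message when the submodule has not been initialized:

```python
# Hypothetical guard (not part of this commit): verify the RetNet submodule
# is checked out before app.py tries `from RetNet.retnet.modeling_retnet import ...`.
import sys
from pathlib import Path

if not (Path(__file__).parent / "RetNet" / "retnet" / "modeling_retnet.py").exists():
    sys.exit(
        "RetNet submodule is missing. Clone with --recurse-submodules "
        "or run `git submodule update --init` first."
    )
```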
README.md CHANGED
@@ -1,12 +1,12 @@
 ---
 title: LightNovel Intro RetNet 400M Demo
-emoji: 🌍
+emoji: 🖋️
 colorFrom: gray
-colorTo: red
+colorTo: blue
 sdk: gradio
 sdk_version: 3.47.1
 app_file: app.py
-pinned: false
+pinned: true
 license: mit
 ---

RetNet ADDED
@@ -0,0 +1 @@
+Subproject commit a4253a7bc16519459320c140a7e3d14b5f017b32
app.py ADDED
@@ -0,0 +1,157 @@
+import torch
+
+from RetNet.retnet.modeling_retnet import RetNetForCausalLM
+from transformers import AutoTokenizer
+
+import gradio as gr
+
+MODEL_NAME = "p1atdev/LightNovel-Intro-RetNet-400M"
+
+DEFAULT_INPUT_TEXT = "目が覚めると、"
+
+EXAMPLE_INPUTS = [
+    DEFAULT_INPUT_TEXT,
+    "冒険者ギルドには",
+    "真っ白い部屋の中、そこには",
+    "20XX年、",
+    "「なんだって!?」",
+    "どうやらトラックにはねられ、俺は",
+]
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+model = RetNetForCausalLM.from_pretrained(MODEL_NAME)
+model.eval()
+
+
+@torch.no_grad()
+def generate(
+    input_text,
+    max_new_tokens=128,
+    do_sample=True,
+    temperature=1.0,
+    top_p=0.95,
+    top_k=20,
+):
+    if input_text.strip() == "":
+        return ""
+
+    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
+    generated = model.custom_generate(
+        **inputs,
+        parallel_compute_prompt=True,
+        max_new_tokens=max_new_tokens,
+        do_sample=do_sample,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+    )
+    return tokenizer.batch_decode(generated)[0]
+
+
+def continue_generate(
+    input_text,
+    *args,
+):
+    return input_text, generate(input_text, *args)
+
+
+with gr.Blocks() as demo:
+    gr.Markdown(
+        """\
+# LightNovel-Intro-RetNet-400M-Demo
+
+ライトノベルの冒頭だけを学習した 400M パラメータの RetNet モデルのデモです。
+
+### 参考:
+
+- https://github.com/syncdoth/RetNet
+"""
+    )
+
+    input_text = gr.Textbox(
+        value=DEFAULT_INPUT_TEXT,
+        placeholder="私の名前は...",
+        lines=2,
+    )
+    output_text = gr.Textbox(
+        value="",
+        placeholder="ここに出力が表示されます...",
+        lines=8,
+        interactive=False,
+    )
+
+    with gr.Row():
+        generate_btn = gr.Button("Generate ✒️", variant="primary")
+        continue_btn = gr.Button("Continue ➡️", variant="secondary")
+        clear_btn = gr.ClearButton(
+            value="Clear 🧹",
+            components=[input_text, output_text],
+        )
+
+    with gr.Accordion("Advanced settings", open=False):
+        max_tokens = gr.Slider(
+            label="Max tokens",
+            minimum=8,
+            maximum=512,
+            value=128,
+            step=4,
+        )
+        do_sample = gr.Checkbox(
+            label="Do sample",
+            value=True,
+        )
+        temperature = gr.Slider(
+            label="Temperature",
+            minimum=0,
+            maximum=2,
+            value=1,
+            step=0.05,
+        )
+        top_p = gr.Slider(
+            label="Top p",
+            minimum=0,
+            maximum=1,
+            value=0.95,
+            step=0.05,
+        )
+        top_k = gr.Slider(
+            label="Top k",
+            minimum=0,
+            maximum=100,
+            value=20,
+            step=1,
+        )
+
+    gr.Examples(
+        examples=EXAMPLE_INPUTS,
+        inputs=input_text,
+    )
+
+    generate_btn.click(
+        fn=generate,
+        inputs=[
+            input_text,
+            max_tokens,
+            do_sample,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=output_text,
+        queue=False,
+    )
+    continue_btn.click(
+        fn=continue_generate,
+        inputs=[
+            input_text,
+            max_tokens,
+            do_sample,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=[input_text, output_text],
+        queue=False,
+    )
+
+demo.launch()
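app.py is the whole Space: it loads p1atdev/LightNovel-Intro-RetNet-400M (per the in-app description, a 400M-parameter RetNet trained only on light-novel openings), wraps the RetNet repo's `custom_generate` in a `@torch.no_grad()` helper, and wires it to a Gradio Blocks UI with sampling controls. A minimal sketch of the same generation path without the UI, assuming the submodule is checked out and the model is downloadable from the Hub:

```python
# Minimal sketch of the generation path used by the demo, without Gradio.
# Assumes the RetNet submodule is present and the Hub model is accessible.
import torch
from transformers import AutoTokenizer

from RetNet.retnet.modeling_retnet import RetNetForCausalLM

MODEL_NAME = "p1atdev/LightNovel-Intro-RetNet-400M"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = RetNetForCausalLM.from_pretrained(MODEL_NAME).eval()

inputs = tokenizer("目が覚めると、", return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
    generated = model.custom_generate(
        **inputs,
        parallel_compute_prompt=True,  # same flag the demo passes to custom_generate
        max_new_tokens=64,
        do_sample=True,
        temperature=1.0,
        top_p=0.95,
        top_k=20,
    )
print(tokenizer.batch_decode(generated)[0])
```

Note that `continue_generate` only echoes the prompt back into the input box alongside a fresh completion; it does not feed the previous output back into the model.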
requirements.txt ADDED
@@ -0,0 +1 @@
+transformers==4.34.0
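requirements.txt pins only transformers. On the Space, gradio 3.47.1 is supplied via `sdk_version` in README.md, and torch is assumed to be available in the runtime; a local run would need both installed as well. A hypothetical local pin set (only transformers==4.34.0 and gradio==3.47.1 are versions named in this commit; torch is left unpinned):

```
# Hypothetical requirements for running the demo locally (not part of the commit).
transformers==4.34.0
gradio==3.47.1
torch
```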