p1atdev commited on
Commit
f8eb38f
1 Parent(s): 47d556d

feat: gradio interface

Browse files
Files changed (3) hide show
  1. README.md +3 -3
  2. app.py +191 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
  title: SDPrompt RetNet 300M Demo
3
  emoji: ⚡
4
- colorFrom: gray
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.1.2
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
  ---
12
 
1
  ---
2
  title: SDPrompt RetNet 300M Demo
3
  emoji: ⚡
4
+ colorFrom: purple
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.1.2
8
  app_file: app.py
9
+ pinned: true
10
  license: mit
11
  ---
12
 
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Thread
2
+
3
+ import torch
4
+
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
6
+
7
+ import gradio as gr
8
+
9
+ MODEL_NAME = "isek-ai/SDPrompt-RetNet-300M"
10
+
11
+ DEFAULT_INPUT_TEXT = "1girl,"
12
+
13
+ EXAMPLE_INPUTS = [DEFAULT_INPUT_TEXT, "oil painting of", "high quality photo of"]
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
16
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
17
+ model.eval()
18
+
19
+ # streamer = TextStreamer(
20
+ # tokenizer,
21
+ # skip_prompt=False,
22
+ # skip_special_tokens=True,
23
+ # )
24
+
25
+
26
+ @torch.no_grad()
27
+ def generate(
28
+ input_text,
29
+ max_new_tokens=128,
30
+ do_sample=True,
31
+ temperature=1.0,
32
+ top_p=0.95,
33
+ top_k=20,
34
+ # no_repeat_ngram_size=3,
35
+ repetition_penalty=1.2,
36
+ num_beams=1,
37
+ ):
38
+ if input_text.strip() == "":
39
+ return ""
40
+
41
+ inputs = tokenizer(
42
+ f"<s>{input_text}", return_tensors="pt", add_special_tokens=False
43
+ )["input_ids"]
44
+
45
+ generated = model.generate(
46
+ inputs,
47
+ max_new_tokens=max_new_tokens,
48
+ do_sample=do_sample,
49
+ temperature=temperature,
50
+ top_p=top_p,
51
+ top_k=top_k,
52
+ # no_repeat_ngram_size=no_repeat_ngram_size,
53
+ repetition_penalty=repetition_penalty,
54
+ num_beams=num_beams,
55
+ # streamer=streamer,
56
+ )
57
+
58
+ return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
59
+
60
+
61
+ def continue_generate(
62
+ input_text,
63
+ *args,
64
+ ):
65
+ return input_text, generate(input_text, *args)
66
+
67
+
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown(
70
+ """\
71
+ # SDPrompt-RetNet-300M-Demo
72
+ A RetNet model trained with Stable Diffusion prompts and Danbooru tags.
73
+
74
+ Model: https://huggingface.co/isek-ai/SDPrompt-RetNet-300M
75
+
76
+ ### Reference:
77
+ - https://github.com/syncdoth/RetNet
78
+ """
79
+ )
80
+
81
+ input_text = gr.Textbox(
82
+ label="Input text",
83
+ value=DEFAULT_INPUT_TEXT,
84
+ placeholder="beautiful photo of ...",
85
+ lines=2,
86
+ )
87
+ output_text = gr.Textbox(
88
+ label="Output text",
89
+ value="",
90
+ placeholder="Output will appear here...",
91
+ lines=8,
92
+ interactive=False,
93
+ )
94
+
95
+ with gr.Row():
96
+ generate_btn = gr.Button("Generate ✒️", variant="primary")
97
+ continue_btn = gr.Button("Continue ➡️", variant="secondary")
98
+ clear_btn = gr.ClearButton(
99
+ value="Clear 🧹",
100
+ components=[input_text, output_text],
101
+ )
102
+
103
+ with gr.Accordion("Advanced settings", open=False):
104
+ max_tokens = gr.Slider(
105
+ label="Max tokens",
106
+ minimum=8,
107
+ maximum=512,
108
+ value=75,
109
+ step=4,
110
+ )
111
+ do_sample = gr.Checkbox(
112
+ label="Do sample",
113
+ value=True,
114
+ )
115
+ temperature = gr.Slider(
116
+ label="Temperature",
117
+ minimum=0,
118
+ maximum=1,
119
+ value=0.9,
120
+ step=0.05,
121
+ )
122
+ top_p = gr.Slider(
123
+ label="Top p",
124
+ minimum=0,
125
+ maximum=1,
126
+ value=0.95,
127
+ step=0.05,
128
+ )
129
+ top_k = gr.Slider(
130
+ label="Top k",
131
+ minimum=0,
132
+ maximum=100,
133
+ value=50,
134
+ step=1,
135
+ )
136
+ repetition_penalty = gr.Slider(
137
+ label="Repetition penalty",
138
+ minimum=0,
139
+ maximum=2,
140
+ value=1,
141
+ step=0.1,
142
+ )
143
+ num_beams = gr.Slider(
144
+ label="Num beams",
145
+ minimum=1,
146
+ maximum=10,
147
+ value=4,
148
+ step=1,
149
+ )
150
+
151
+ gr.Examples(
152
+ examples=EXAMPLE_INPUTS,
153
+ inputs=input_text,
154
+ )
155
+
156
+ generate_btn.click(
157
+ fn=generate,
158
+ inputs=[
159
+ input_text,
160
+ max_tokens,
161
+ do_sample,
162
+ temperature,
163
+ top_p,
164
+ top_k,
165
+ repetition_penalty,
166
+ num_beams,
167
+ ],
168
+ outputs=output_text,
169
+ queue=False,
170
+ )
171
+ continue_btn.click(
172
+ fn=continue_generate,
173
+ inputs=[
174
+ output_text,
175
+ max_tokens,
176
+ do_sample,
177
+ temperature,
178
+ top_p,
179
+ top_k,
180
+ repetition_penalty,
181
+ num_beams,
182
+ ],
183
+ outputs=[input_text, output_text],
184
+ queue=False,
185
+ )
186
+
187
+ demo.queue()
188
+ demo.launch(
189
+ debug=True,
190
+ show_error=True,
191
+ )
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ torch
2
+ transformers==4.34.0
3
+ numpy
4
+ timm
5
+ safetensors