Mohaddz committed on
Commit
08b5ccb
·
verified ·
1 Parent(s): 3bc3003

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +277 -0
app.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RND1 Diffusion Model Demo for Hugging Face Spaces with ZeroGPU
4
+ """
5
+
6
+ import torch
7
+ import gradio as gr
8
+ import spaces
9
+ import random
10
+ import numpy as np
11
+ from transformers import AutoTokenizer
12
+
13
# Global model and tokenizer.
# Both start as None; load_model() assigns them and generate_text() reads them.
model = None
tokenizer = None
# Device string that generate_text() moves the input ids to and creates its
# torch.Generator on.
device = "cuda"
17
+
18
+
19
def set_seed(seed: int):
    """Seed every random number generator this app relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When CUDA is available, all CUDA devices are seeded as well.

    Args:
        seed: Value handed to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
26
+
27
+
28
def load_model():
    """Load the RND1 tokenizer and model into the module-level globals.

    Assigns ``tokenizer`` and ``model``, switches the model to eval mode and
    prints progress messages along the way. Intended to be called once at
    startup.
    """
    global model, tokenizer

    # Project-local modules are imported here rather than at file top.
    from rnd.configuration_rnd import RND1Config
    from rnd.modeling_rnd import RND1LM

    repo_id = "radicalnumerics/RND1-Base-0910"

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)

    print("Loading model...")
    config = RND1Config.from_pretrained(repo_id)
    # Override selected config fields before the weights are instantiated.
    overrides = (
        ("model_type", "rnd1"),
        ("attn_implementation", "sdpa"),
        ("moe_backend", "hf"),
    )
    for field_name, field_value in overrides:
        setattr(config, field_name, field_value)

    load_kwargs = dict(
        config=config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )
    model = RND1LM.from_pretrained(repo_id, **load_kwargs)
    model.eval()
    print("Model loaded successfully!")
57
+
58
+
59
def _format_prompt(prompt: str, mode: str) -> str:
    """Return the text to tokenize: task mode prefixes ``Question: `` unless already present."""
    if mode == "task" and not prompt.strip().startswith("Question:"):
        return f"Question: {prompt}\n"
    return prompt


@spaces.GPU(duration=120)  # Request GPU for up to 120 seconds
def generate_text(
    prompt: str,
    mode: str,
    num_steps: int,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    seed: int,
    progress=gr.Progress()
):
    """
    Generate text using RND1 diffusion model.

    Reads the module-level ``model``, ``tokenizer`` and ``device``; these must
    have been initialised (see ``load_model``) before this is called.

    Args:
        prompt: Input text prompt
        mode: Generation mode ('task' or 'completion')
        num_steps: Number of diffusion steps
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature (exactly 1.0 selects greedy decoding)
        top_k: Top-k filtering (0 to disable)
        top_p: Top-p nucleus filtering (0 to disable)
        seed: Random seed
        progress: Gradio progress tracker

    Returns:
        The decoded text of the newly generated tokens (special tokens
        skipped), or a warning message when the prompt is empty or blank.
    """
    # Guard against None as well as blank input so .strip() cannot raise.
    if not prompt or not prompt.strip():
        return "⚠️ Please enter a prompt."

    # UI sliders may deliver numeric values as floats; the seed in particular
    # must be an int for torch.Generator.manual_seed.
    seed = int(seed)
    num_steps = int(num_steps)
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    progress(0, desc="Setting seed...")
    set_seed(seed)

    progress(0.1, desc="Preparing prompt...")

    # Format prompt based on mode
    formatted_prompt = _format_prompt(prompt, mode)

    # Tokenize
    progress(0.2, desc="Tokenizing...")
    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    input_ids = inputs.input_ids.to(device)

    # Prepare generation config
    from rnd.generation_config import RND1GenerationConfig

    # A temperature of exactly 1.0 is this app's switch for greedy decoding.
    greedy = (temperature == 1.0)

    # Fall back to the hard-coded id only when the tokenizer defines no EOS
    # token; an id of 0 is a legitimate value and must not trigger the fallback.
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_id = 151645

    gen_config = RND1GenerationConfig(
        max_new_tokens=max_new_tokens,
        num_diffusion_steps=num_steps,
        # Hard-coded mask token id; presumably specific to this checkpoint's
        # tokenizer - TODO confirm.
        mask_token_id=151669,
        temperature=float(temperature),
        top_k=top_k if top_k > 0 else None,
        top_p=top_p if top_p > 0 else None,
        greedy=greedy,
        eos_token_id=eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )

    # Generate
    progress(0.3, desc=f"Generating ({num_steps} diffusion steps)...")

    # Dedicated generator seeded with the same seed as the global RNGs.
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    with torch.no_grad():
        output = model.generate(
            inputs=input_ids,
            generation_config=gen_config,
            generator=generator,
        )

    progress(0.9, desc="Decoding...")

    # Decode generated tokens: drop the prompt prefix from the first sequence.
    prompt_length = input_ids.shape[1]
    generated_tokens = output[0][prompt_length:]
    generation = tokenizer.decode(
        generated_tokens.tolist(),
        skip_special_tokens=True
    )

    progress(1.0, desc="Complete!")

    return generation
150
+
151
+
152
+ # Create Gradio interface
153
def create_interface():
    """Build and return the Gradio Blocks UI for the RND1 demo.

    The left column holds the prompt, the generation mode and the tuning
    sliders plus the generate button; the right column shows the generated
    text. Example rows and the button both call ``generate_text`` with the
    same list of input components.

    NOTE(review): the container nesting below was reconstructed from a
    listing whose indentation had been flattened - confirm the placement of
    the seed slider and the two accordions against the deployed layout.

    Returns:
        The assembled ``gr.Blocks`` instance (not yet queued or launched).
    """
    # Each row: prompt, mode, steps, max new tokens, temperature, top-k, top-p, seed.
    example_rows = [
        ["Write a Python function that finds the longest common subsequence of two strings.", "task", 256, 256, 1.0, 0, 0.0, 12345],
        ["Explain the concept of recursion with a simple example.", "task", 256, 200, 1.0, 0, 0.0, 42],
        ["The key to understanding quantum computing lies in", "completion", 256, 256, 1.0, 0, 0.0, 9876],
        ["Once upon a time in a distant galaxy,", "completion", 256, 300, 1.0, 0, 0.0, 7777],
    ]

    with gr.Blocks(title="RND1 Diffusion Language Model", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🌊 RND1 Diffusion Language Model

        Generate text using a diffusion-based language model. The model uses iterative denoising
        to progressively refine masked tokens into coherent text.

        **Note:** First generation may take longer as the model loads.
        """)

        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=4,
                    value="Write a Python function that finds the longest common subsequence of two strings.",
                )

                mode = gr.Radio(
                    choices=["task", "completion"],
                    value="task",
                    label="Generation Mode",
                    info="Task: Q&A format for instructions | Completion: Continue the text",
                )

                with gr.Accordion("Generation Settings", open=True):
                    num_steps = gr.Slider(
                        minimum=16,
                        maximum=512,
                        value=256,
                        step=16,
                        label="Diffusion Steps",
                        info="More steps = better quality but slower",
                    )

                    max_new_tokens = gr.Slider(
                        minimum=32,
                        maximum=512,
                        value=256,
                        step=32,
                        label="Max New Tokens",
                    )

                with gr.Accordion("Sampling Parameters", open=False):
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Temperature",
                        info="1.0 = greedy/deterministic",
                    )

                    top_k = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=0,
                        step=1,
                        label="Top-K",
                        info="0 to disable",
                    )

                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.0,
                        step=0.05,
                        label="Top-P (Nucleus)",
                        info="0 to disable",
                    )

                    seed = gr.Slider(
                        minimum=0,
                        maximum=100000,
                        value=12345,
                        step=1,
                        label="Random Seed",
                    )

                generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

            # Right column: generated text.
            with gr.Column(scale=1):
                output = gr.Textbox(
                    label="Generated Text",
                    lines=20,
                    show_copy_button=True,
                )

        gr.Markdown("""
        ### Examples
        Try these prompts to see what the model can do!
        """)

        # Same component order as generate_text's positional parameters.
        controls = [prompt, mode, num_steps, max_new_tokens, temperature, top_k, top_p, seed]

        gr.Examples(
            examples=example_rows,
            inputs=controls,
            outputs=output,
            fn=generate_text,
            cache_examples=False,
        )

        generate_btn.click(
            fn=generate_text,
            inputs=list(controls),
            outputs=output,
        )

    return demo
268
+
269
+
270
if __name__ == "__main__":
    # Load model at startup, before the UI is built, so that generate_text()
    # finds the module-level model and tokenizer already assigned.
    load_model()

    # Launch Gradio interface
    demo = create_interface()
    demo.queue(max_size=10)  # Enable queue for ZeroGPU
    demo.launch()