edbeeching committed on
Commit
b6d1901
·
1 Parent(s): 5e25317

first MVP wip

Browse files
Files changed (2) hide show
  1. app.py +250 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from dataclasses import dataclass
4
+ import os
5
+ from supabase import create_client, Client
6
+ from supabase.client import ClientOptions
7
+ from enum import Enum
8
+ from datasets import get_dataset_infos
9
+ from transformers import AutoConfig
10
+
11
+
12
class GenerationStatus(Enum):
    """Lifecycle states of a generation request; the string value is what is
    persisted in the ``generation-requests`` table (see add_request_to_db)."""

    PENDING = "PENDING"      # submitted, not yet picked up
    RUNNING = "RUNNING"      # currently being processed
    COMPLETED = "COMPLETED"  # finished successfully
    FAILED = "FAILED"        # terminated with an error
17
+
18
+
19
# Service-wide limits enforced by validate_request().
MAX_SAMPLES = 10000  # max number of samples in the input dataset
MAX_TOKENS = 32768  # hard cap on requested generation tokens
# NOTE(review): MAX_MODEL_PARAMS is not checked anywhere in this file yet —
# presumably intended for a future model-size validation step; confirm.
MAX_MODEL_PARAMS = 20_000_000_000 # 20 billion parameters (for now)
22
+
23
@dataclass
class GenerationRequest:
    """A synthetic-data generation job as submitted through the Gradio UI.

    Field names mirror the columns inserted into the ``generation-requests``
    table by ``add_request_to_db``. Access tokens are carried so a worker can
    read private datasets/models and push results on the user's behalf.
    """

    id: str  # empty at submission time; assigned by the database
    status: GenerationStatus
    input_dataset_name: str
    input_dataset_config: str
    input_dataset_split: str
    prompt_column: str
    model_name_or_path: str
    model_revision: str
    model_token: str | None  # None for public models
    system_prompt: str | None  # None -> model default / no system prompt
    max_tokens: int
    temperature: float
    top_k: int
    top_p: float
    input_dataset_token: str | None  # None for public datasets
    output_dataset_token: str  # must grant write access to the output repo
    username: str
    email: str  # used for completion notification
    # Repo the generated data is pushed to. The UI constructs requests with
    # this keyword (output_dataset_name=output_ds), but the field was missing,
    # so every submission raised TypeError. Defaulted to keep any existing
    # positional/keyword construction backward-compatible.
    output_dataset_name: str = ""
43
+
44
+
45
def validate_request(request: GenerationRequest) -> None:
    """Validate a generation request before it is queued.

    Checks, in order: the input dataset is accessible, the requested config
    and split exist, the split is within MAX_SAMPLES, the prompt column
    exists, the model revision is accessible, max_tokens fits both the model
    context and MAX_TOKENS, sampling parameters are in range, and the email
    looks plausible.

    Raises:
        Exception: with a human-readable message on the first failed check.
    """
    # - input dataset exists and can be accessed with the provided token
    try:
        dataset_infos = get_dataset_infos(request.input_dataset_name, token=request.input_dataset_token)
    except Exception as e:
        raise Exception(f"Dataset {request.input_dataset_name} does not exist or cannot be accessed with the provided token.") from e

    # check the requested config exists (the original indexed blindly, so a
    # bad config surfaced as a bare KeyError instead of a useful message)
    if request.input_dataset_config not in dataset_infos:
        raise Exception(f"Config {request.input_dataset_config} does not exist in dataset {request.input_dataset_name}. Available configs: {list(dataset_infos.keys())}")
    input_dataset_info = dataset_infos[request.input_dataset_config]

    # check that the input dataset split exists
    if request.input_dataset_split not in input_dataset_info.splits:
        raise Exception(f"Dataset split {request.input_dataset_split} does not exist in dataset {request.input_dataset_name}. Available splits: {list(input_dataset_info.splits.keys())}")

    # check that the number of samples is less than MAX_SAMPLES.
    # SplitInfo exposes `num_examples`; the original read `num_samples`,
    # which does not exist and raised AttributeError on every request.
    if input_dataset_info.splits[request.input_dataset_split].num_examples > MAX_SAMPLES:
        raise Exception(f"Dataset split {request.input_dataset_split} in dataset {request.input_dataset_name} exceeds max sample limit of {MAX_SAMPLES}.")

    # check the prompt column exists in the dataset
    if request.prompt_column not in input_dataset_info.features:
        raise Exception(f"Prompt column {request.prompt_column} does not exist in dataset {request.input_dataset_name}. Available columns: {list(input_dataset_info.features.keys())}")

    # check the model exists
    try:
        model_config = AutoConfig.from_pretrained(request.model_name_or_path, revision=request.model_revision, token=request.model_token)
    except Exception as e:
        raise Exception(f"Model {request.model_name_or_path} revision {request.model_revision} does not exist or cannot be accessed with the provided token.") from e

    # check the model context window covers the requested max tokens, and the
    # request stays under the service cap. Not every config defines
    # max_position_embeddings; treat a missing attribute as "no model limit".
    max_position_embeddings = getattr(model_config, "max_position_embeddings", None)
    if max_position_embeddings is not None and max_position_embeddings < request.max_tokens:
        raise Exception(f"Model {request.model_name_or_path} max position embeddings {max_position_embeddings} is less than the requested max tokens {request.max_tokens}.")
    if request.max_tokens > MAX_TOKENS:
        raise Exception(f"Requested max tokens {request.max_tokens} exceeds the limit of {MAX_TOKENS}.")

    # check sampling parameters are valid
    if request.temperature < 0.0 or request.temperature > 2.0:
        raise Exception("Temperature must be between 0.0 and 2.0")
    if request.top_k < 1 or request.top_k > 100:
        raise Exception("Top K must be between 1 and 100")
    if request.top_p < 0.0 or request.top_p > 1.0:
        raise Exception("Top P must be between 0.0 and 1.0")

    # check valid email address TODO: use py3-validate-email https://stackoverflow.com/questions/8022530/how-to-check-for-valid-email-address
    if "@" not in request.email or "." not in request.email.split("@")[-1]:
        raise Exception("Invalid email address")
89
+
90
+
91
def add_request_to_db(request: GenerationRequest):
    """Insert the request into the Supabase ``generation-requests`` table.

    Credentials are read from the SUPABASE_URL / SUPABASE_KEY environment
    variables.

    Returns:
        The inserted row data as returned by Supabase.

    Raises:
        Exception: if credentials are missing or the insert fails.
    """
    url = os.getenv("SUPABASE_URL")
    key = os.getenv("SUPABASE_KEY")
    # fail with a clear message instead of letting create_client choke on None
    if not url or not key:
        raise Exception("SUPABASE_URL and SUPABASE_KEY environment variables must be set.")

    # ClientOptions is a class, not a mapping: the original passed a plain
    # dict, which supabase-py v2's create_client does not accept.
    options = ClientOptions(schema="public")
    supabase: Client = create_client(url, key, options)

    data = {
        "status": request.status.value,
        "input_dataset_name": request.input_dataset_name,
        "input_dataset_config": request.input_dataset_config,
        "input_dataset_split": request.input_dataset_split,
        "prompt_column": request.prompt_column,
        "model_name_or_path": request.model_name_or_path,
        "model_revision": request.model_revision,
        "model_token": request.model_token,
        "system_prompt": request.system_prompt,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_k": request.top_k,
        "top_p": request.top_p,
        "input_dataset_token": request.input_dataset_token,
        "output_dataset_token": request.output_dataset_token,
        "username": request.username,
        "email": request.email
    }

    # supabase-py v2 raises APIError on failure and its APIResponse has no
    # `status_code` attribute — the original `response.status_code != 201`
    # check itself crashed with AttributeError even on success.
    try:
        response = supabase.table("generation-requests").insert(data).execute()
    except Exception as e:
        raise Exception(f"Failed to add request to database: {e}") from e

    return response.data
124
+
125
+
126
def create_gradio_interface():
    """Build the Gradio Blocks UI for submitting generation requests.

    Returns:
        The constructed ``gr.Blocks`` interface (caller launches it).
    """
    with gr.Blocks(title="Synthetic Data Generation") as interface:
        with gr.Group():
            with gr.Row():
                gr.Markdown("# Synthetic Data Generation Request")
            with gr.Row():
                gr.Markdown("""
                Welcome to the Synthetic Data Generation service! This tool allows you to generate synthetic data using large language models.

                Generation is FREE for Hugging Face PRO users and uses idle GPUs on the HF science cluster.\n


                """)
        with gr.Group():
            with gr.Row():
                gr.Markdown("""
                **How it works:**
                1. Provide an input dataset with prompts
                2. Select a language model for generation
                3. Configure generation parameters
                4. Submit your request and receive generated data
                """)
                gr.Markdown("""

                **Requirements:**
                - Input dataset must be publicly accessible or you must provide a valid HuggingFace token
                - Output dataset repository must exist and you must have write access
                - Model must be accessible (public or with valid token)
                - Maximum 10,000 samples per dataset
                - Maximum of 32k generation tokens

                **Note:** Generation requests are processed asynchronously. You will be notified via email when your request is complete.
                """)

        with gr.Row():
            with gr.Group():
                gr.Markdown("## Dataset information")
                with gr.Column():
                    with gr.Row():
                        input_dataset_name = gr.Textbox(label="Input Dataset Name", placeholder="e.g., simplescaling/s1K-1.1")
                        input_dataset_split = gr.Textbox(label="Input Dataset Split", value="train", placeholder="e.g., train, test, validation")
                        input_dataset_config = gr.Textbox(label="Input Dataset Config", value="default", placeholder="e.g., default, custom")
                        prompt_column = gr.Textbox(label="Prompt Column", placeholder="e.g., text, prompt, question")
                    with gr.Column():
                        output_dataset_name = gr.Textbox(label="Output Dataset Name", placeholder="e.g., MyOrg/my-generated-dataset")
        with gr.Group():
            gr.Markdown("## Model information")
            with gr.Column():
                with gr.Row():
                    model_name_or_path = gr.Textbox(label="Model Name or Path", placeholder="e.g., Qwen/Qwen3-4B-Instruct-2507")
                    model_revision = gr.Textbox(label="Model Revision", value="main", placeholder="e.g., main, v1.0")
                    model_token = gr.Textbox(label="Model Token (Optional)", type="password", placeholder="Your HF token with read/write access to the model...")
        with gr.Group():
            gr.Markdown("## Generation Parameters")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        max_tokens = gr.Slider(label="Max Tokens", value=512, minimum=256, maximum=MAX_TOKENS, step=256)
                        temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, value=0.7, step=0.1)
                    with gr.Row():
                        top_k = gr.Slider(label="Top K", value=50, minimum=5, maximum=100, step=5)
                        top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.1)
                with gr.Column():
                    system_prompt = gr.Textbox(label="System Prompt (Optional)", lines=3, placeholder="Optional system prompt... e.g., You are a helpful assistant.")

        with gr.Group():
            gr.Markdown("## User Information, for tokens refer to guide [here](https://huggingface.co/docs/hub/en/security-tokens#user-access-tokens)")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        username = gr.Textbox(label="Hugging Face Username", placeholder="Your HF username")
                        email = gr.Textbox(label="Email", placeholder="your.email@example.com")
                    with gr.Row():
                        input_dataset_token = gr.Textbox(label="Input dataset token", type="password", placeholder="Your HF token with read access to the input dataset, leave blank if public dataset")
                        output_dataset_token = gr.Textbox(label="Output dataset token", type="password", placeholder="Your HF token with write access to the output dataset")

        submit_btn = gr.Button("Submit Generation Request", variant="primary")
        output_status = gr.Textbox(label="Status", interactive=False)

        # NOTE: the original closed over the `input_dataset_config` Textbox
        # *component* (it was missing from both the parameter list and the
        # click() inputs), so the stored config was a component object, never
        # the user's text. It is now wired through as a real input value.
        def submit_request(input_ds, input_split, input_config, prompt_col, model_name, model_rev, model_token, sys_prompt,
                           max_tok, temp, top_k_val, top_p_val, output_ds, user, email_addr, input_dataset_token, output_dataset_token):
            try:
                request = GenerationRequest(
                    id="",  # Will be generated when adding to the database
                    status=GenerationStatus.PENDING,
                    input_dataset_name=input_ds,
                    input_dataset_split=input_split,
                    input_dataset_config=input_config,
                    prompt_column=prompt_col,
                    model_name_or_path=model_name,
                    model_revision=model_rev,
                    model_token=model_token if model_token else None,
                    system_prompt=sys_prompt if sys_prompt else None,
                    max_tokens=int(max_tok),
                    temperature=temp,
                    top_k=int(top_k_val),
                    top_p=top_p_val,
                    output_dataset_name=output_ds,
                    input_dataset_token=input_dataset_token if input_dataset_token else None,
                    output_dataset_token=output_dataset_token,
                    username=user,
                    email=email_addr
                )

                # check the input dataset exists and can be accessed with the provided token
                validate_request(request)
                add_request_to_db(request)

                return "Request submitted successfully!"
            except Exception as e:
                return f"Error: {str(e)}"

        submit_btn.click(
            submit_request,
            inputs=[input_dataset_name, input_dataset_split, input_dataset_config, prompt_column, model_name_or_path,
                    model_revision, model_token, system_prompt, max_tokens, temperature, top_k, top_p,
                    output_dataset_name, username, email, input_dataset_token, output_dataset_token],
            outputs=output_status
        )

    return interface
247
+
248
if __name__ == "__main__":
    # Build the UI and start serving it.
    create_gradio_interface().launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ supabase
2
+ gradio
3
+ transformers
4
+ datasets