Vipitis commited on
Commit
fee32de
1 Parent(s): 0680c21

generate function body button

Browse files
app.py CHANGED
@@ -142,18 +142,30 @@ If I find an efficient way, the shaders might run in real time and be interactiv
142
 
143
  ## TODO:
144
  - [x] use embedded Shadertoy for reference/attribution (done, but some errors)
145
- - [] working render implementation on CPU only space (use the browser for WebGPU?)
146
  - [~] generate variations of return statements [ShaderEval task1](https://huggingface.co/spaces/Vipitis/ShaderEval) (missing all of the generation parameters)
147
  - [] generation history stating which function and orig/generated returns. (use State ??). do it as comments in the code?
148
- - [] generate whole functions
149
  - [] generate whole shaders (via prompts?)
150
  """
151
  passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
152
- single_passes = passes_dataset.filter(lambda x: not x["has_inputs"] and x["num_passes"] == 1 and x["code"].count("return") >= 1) #filter easier than having a custom loader script?
153
  all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
154
  num_samples = len(all_single_passes)
155
 
 
 
 
 
 
 
 
 
156
  async def get_image(code, time= 0.0, resolution=(512, 420)):
 
 
 
 
157
  shader = ShadertoyCustom(code, resolution, OffscreenCanvas, run_offscreen) #pass offscreen canvas here.
158
  shader._uniform_data["time"] = time #set any time you want
159
  shader._canvas.request_draw(shader._draw_frame)
@@ -172,11 +184,25 @@ def grab_sample(sample_idx):
172
  sample_auhtor = sample_pass["author"]
173
  source_iframe = construct_embed(sample_source)
174
  print(f"{source_iframe=}")
175
- return sample_pass, sample_code, source_iframe #, sample_title, sample_auhtor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
- PIPE = None #gloabl var in CAPS indicates constant? so why are we changing it?
178
 
179
- def _make_pipeline(model_cp = "gpt2"): #bad default model for testing
180
  tokenizer = AutoTokenizer.from_pretrained(model_cp, trust_remote_code=True)
181
  model = AutoModelForCausalLM.from_pretrained(model_cp, trust_remote_code=True)
182
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True)
@@ -184,6 +210,7 @@ def _make_pipeline(model_cp = "gpt2"): #bad default model for testing
184
  print(f"loaded model {model_cp} as a pipline")
185
  return pipe
186
 
 
187
  def process_retn(retn):
188
  return retn.split(";")[0].strip()
189
 
@@ -231,11 +258,67 @@ def alter_return(orig_code, func_idx=0, pipeline=PIPE): #default pipeline can't
231
 
232
  return altered_code
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  def add_history(func_id, orig_rtn, gened_rtn, history):
235
  # is this a list? or a JSON dict?
236
  history[func_id] = (orig_rtn, gened_rtn)
237
  return history, history
238
 
 
 
 
 
 
 
 
 
 
239
  def construct_embed(source_url):
240
  shader_id = source_url.split("/")[-1]
241
  return f'<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/{shader_id}?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>'
@@ -244,23 +327,30 @@ with gr.Blocks() as site:
244
  text_md = gr.Markdown(text)
245
  model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
246
  sample_idx = gr.Slider(minimum=0, maximum=num_samples, value=3211, label="pick sample from dataset", step=1.0)
247
- run_button = gr.Button("generate a alternate return statement for one function", label="generate code")
 
 
 
 
248
  render_button = gr.Button("render frame0 (can carsh the sapce on invalid shadercode)",label="render frame")
249
  time_slider = gr.Slider(minimum=0, maximum=10, value=0, label="time (update on release, also used to pick other functions as a bodge)", step=0.02)
250
  with gr.Row():
251
  with gr.Column():
252
  source_embed = gr.HTML('<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/WsBcWV?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>', label="How this shader originally renders")
253
- rendered_frame = gr.Image(shape=(512, 420), label=f"rendered frame preview")
254
- sample_code = gr.Code(label="Current Code (will update changes you generate)", language=None, readonly=True, lines=20)
255
 
256
  sample_pass = gr.State(value={})
257
  pipe = gr.State(value=PIPE)
 
258
  # hist_state = gr.State(Value={})
259
  # history_table = gr.JSON()
260
 
261
  model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe])
262
- sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, source_embed])
263
- run_button.click(fn=alter_return, inputs=[sample_code, time_slider, pipe], outputs=[sample_code])
 
 
264
  # run_button.click(fn=add_history, inputs=[time_slider, sample_pass, sample_code, hist_state], outputs=[history_table, hist_state])
265
  # sample_idx.release(fn=construct_embed, inputs=[sample_idx], outputs=[source_embed]) #twice to make have different outputs?
266
  time_slider.release(fn=lambda code, time: asyncio.run(get_image(code, time)), inputs=[sample_code, time_slider], outputs=rendered_frame)
 
142
 
143
  ## TODO:
144
  - [x] use embedded Shadertoy for reference/attribution (done, but some errors)
145
+ - [] working render implementation on CPU only space (use the browser for WebGPU?, maybe via an iFrame too?)
146
  - [~] generate variations of return statements [ShaderEval task1](https://huggingface.co/spaces/Vipitis/ShaderEval) (missing all of the generation parameters)
147
  - [] generation history stating which function and orig/generated returns. (use State ??). do it as comments in the code?
148
+ - [x?] generate whole functions (only works once)
149
  - [] generate whole shaders (via prompts?)
150
  """
151
  passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
152
+ single_passes = passes_dataset.filter(lambda x: not x["has_inputs"] and x["num_passes"] == 1) #could also include shaders with no extra functions.
153
  all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
154
  num_samples = len(all_single_passes)
155
 
156
+ import tree_sitter
157
+ from tree_sitter import Language, Parser
158
+ Language.build_library("build/my-languages.so", ['tree-sitter-glsl'])
159
+ GLSL_LANGUAGE = Language('build/my-languages.so', 'glsl')
160
+ parser = Parser()
161
+ parser.set_language(GLSL_LANGUAGE)
162
+
163
+
164
  async def get_image(code, time= 0.0, resolution=(512, 420)):
165
+ tree = parser.parse(bytes(code, "utf8"))
166
+ if tree.root_node.has_error:
167
+ print("ERROR in the tree, aborting.")
168
+ return None
169
  shader = ShadertoyCustom(code, resolution, OffscreenCanvas, run_offscreen) #pass offscreen canvas here.
170
  shader._uniform_data["time"] = time #set any time you want
171
  shader._canvas.request_draw(shader._draw_frame)
 
184
  sample_auhtor = sample_pass["author"]
185
  source_iframe = construct_embed(sample_source)
186
  print(f"{source_iframe=}")
187
+ sample_funcs = _parse_functions(sample_code)
188
+ funcs = _parse_functions(sample_code)
189
+ func_identifiers = [(idx,n.child_by_field_name("declarator").text.decode()) for idx, n in enumerate(funcs)]
190
+ print(f"updating drop down to:{func_identifiers}")
191
+ return sample_pass, sample_code, source_iframe, funcs, gr.Dropdown.update(choices=func_identifiers) #, sample_title, sample_auhtor
192
+
193
+
194
+ def _parse_functions(in_code):
195
+ """
196
+ returns all functions in the code as their actual nodes.
197
+
198
+ """
199
+ tree = parser.parse(bytes(in_code, "utf8"))
200
+ funcs = [n for n in tree.root_node.children if n.type == "function_definition"]
201
+ return funcs
202
 
203
+ PIPE = None
204
 
205
+ def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #bad default model for testing
206
  tokenizer = AutoTokenizer.from_pretrained(model_cp, trust_remote_code=True)
207
  model = AutoModelForCausalLM.from_pretrained(model_cp, trust_remote_code=True)
208
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True)
 
210
  print(f"loaded model {model_cp} as a pipline")
211
  return pipe
212
 
213
+
214
  def process_retn(retn):
215
  return retn.split(";")[0].strip()
216
 
 
258
 
259
  return altered_code
260
 
261
+ def _line_chr2char(text, line_idx, chr_idx):
262
+ """
263
+ returns the character index at the given line and character index.
264
+ """
265
+ lines = text.split("\n")
266
+ char_idx = 0
267
+ for i in range(line_idx):
268
+ char_idx += len(lines[i]) + 1
269
+ char_idx += chr_idx
270
+ return char_idx
271
+
272
+ def alter_body(old_code, func_id, funcs_list, pipeline=PIPE):
273
+ """
274
+ Replaces the body of a function with a generated one.
275
+ Args:
276
+ old_code (str): The original code.
277
+ func_node (Node): The node of the function to replace the body of.
278
+ pipeline (Pipeline): The pipeline to use for generation.
279
+ Returns:
280
+ str: The altered code.
281
+ """
282
+ print(f"{func_id=}")
283
+ func_id = int(func_id.split(",")[0]) #undo their string casting?
284
+ func_node = funcs_list[func_id]
285
+ print(f"using for generation: {func_node=}")
286
+
287
+
288
+
289
+ if pipeline is None:
290
+ print("no pipeline found, loading default one")
291
+ pipeline = _make_pipeline("Vipitis/santacoder-finetuned-Shadertoys-fine")
292
+
293
+ func_start_idx = _line_chr2char(old_code, func_node.start_point[0], func_node.start_point[1])
294
+ body_node = func_node.child_by_field_name("body")
295
+ body_start_idx = _line_chr2char(old_code, body_node.start_point[0], body_node.start_point[1])
296
+ body_end_idx = _line_chr2char(old_code, body_node.end_point[0], body_node.end_point[1])
297
+ print(f"{old_code[body_start_idx:body_end_idx]=}")
298
+ model_context = old_code[:body_start_idx]
299
+ generation = pipeline(model_context, max_new_tokens=(body_end_idx - body_start_idx)*2, return_full_text=False)[0]["generated_text"]
300
+ print(f"{generation=}")
301
+ first_gened_func = _parse_functions(old_code[func_start_idx:body_start_idx] + generation)[0] # truncate generation to a single function?
302
+ # strip just the body.
303
+ generated_body = first_gened_func.child_by_field_name("body").text.decode()
304
+ print(f"{generated_body=}")
305
+ altered_code = old_code[:body_start_idx] + generated_body + old_code[body_end_idx:]
306
+ return altered_code
307
+
308
  def add_history(func_id, orig_rtn, gened_rtn, history):
309
  # is this a list? or a JSON dict?
310
  history[func_id] = (orig_rtn, gened_rtn)
311
  return history, history
312
 
313
+ def list_dropdown(in_code):
314
+ funcs = _parse_functions(in_code)
315
+
316
+ # print(f"updating drop down to:{func_identifiers=}")
317
+ func_identifiers = [(idx,n.child_by_field_name("declarator").text.decode()) for idx, n in enumerate(funcs)]
318
+ # funcs = [n for n in funcs] #wrapped as set to avoid json issues?
319
+ print(f"updating drop down to:{func_identifiers}")
320
+ return funcs, gr.Dropdown.update(choices=func_identifiers)
321
+
322
  def construct_embed(source_url):
323
  shader_id = source_url.split("/")[-1]
324
  return f'<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/{shader_id}?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>'
 
327
  text_md = gr.Markdown(text)
328
  model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
329
  sample_idx = gr.Slider(minimum=0, maximum=num_samples, value=3211, label="pick sample from dataset", step=1.0)
330
+ func_dropdown = gr.Dropdown(label="chose a function to modify") #breaks if I add a string in before that?
331
+ with gr.Row():
332
+ gen_return_button = gr.Button("generate a alternate return statement", label="generate return")
333
+ gen_func_button = gr.Button("generate an alternate function body", label="generate function")
334
+ # update_funcs_button = gr.Button("update functions", label="update functions")
335
  render_button = gr.Button("render frame0 (can carsh the sapce on invalid shadercode)",label="render frame")
336
  time_slider = gr.Slider(minimum=0, maximum=10, value=0, label="time (update on release, also used to pick other functions as a bodge)", step=0.02)
337
  with gr.Row():
338
  with gr.Column():
339
  source_embed = gr.HTML('<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/WsBcWV?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>', label="How this shader originally renders")
340
+ rendered_frame = gr.Image(shape=(512, 420), label=f"rendered frame preview", type="pil") #colors are messed up?
341
+ sample_code = gr.Code(label="Current Code (will update changes you generate)", language=None)
342
 
343
  sample_pass = gr.State(value={})
344
  pipe = gr.State(value=PIPE)
345
+ funcs = gr.State(value=[])
346
  # hist_state = gr.State(Value={})
347
  # history_table = gr.JSON()
348
 
349
  model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe])
350
+ sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, source_embed, funcs, func_dropdown])
351
+ # sample_idx.release(fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]) #use multiple event handles to call other functions! seems to not work really well. always messes up
352
+ gen_return_button.click(fn=alter_return, inputs=[sample_code, time_slider, pipe], outputs=[sample_code])
353
+ gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, pipe], outputs=[sample_code])
354
  # run_button.click(fn=add_history, inputs=[time_slider, sample_pass, sample_code, hist_state], outputs=[history_table, hist_state])
355
  # sample_idx.release(fn=construct_embed, inputs=[sample_idx], outputs=[source_embed]) #twice to make have different outputs?
356
  time_slider.release(fn=lambda code, time: asyncio.run(get_image(code, time)), inputs=[sample_code, time_slider], outputs=rendered_frame)
build/my-languages.exp ADDED
Binary file (744 Bytes). View file
 
build/my-languages.lib ADDED
Binary file (1.8 kB). View file
 
build/my-languages.so ADDED
Binary file (631 kB). View file
 
requirements.txt CHANGED
@@ -7,4 +7,5 @@ pillow
7
  gradio
8
  numpy
9
  glfw
10
- jupylet
 
 
7
  gradio
8
  numpy
9
  glfw
10
+ jupylet
11
+ tree-sitter