Spaces:
Runtime error
Runtime error
refactor tree_utils
Browse files- app.py +11 -65
- tree_utils.py +59 -0
app.py
CHANGED
@@ -6,6 +6,9 @@ import numpy as np
|
|
6 |
import torch
|
7 |
from threading import Thread
|
8 |
|
|
|
|
|
|
|
9 |
def make_script(shader_code):
|
10 |
# code copied and fixed(escaping single quotes to double quotes!!!) from https://webglfundamentals.org/webgl/webgl-shadertoy.html
|
11 |
script = ("""
|
@@ -295,18 +298,6 @@ new_shadertoy_code = """void mainImage( out vec4 fragColor, in vec2 fragCoord )
|
|
295 |
fragColor = vec4(col,1.0);
|
296 |
}"""
|
297 |
|
298 |
-
passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
|
299 |
-
single_passes = passes_dataset.filter(lambda x: not x["has_inputs"] and x["num_passes"] == 1) #could also include shaders with no extra functions.
|
300 |
-
# single_passes = single_passes.filter(lambda x: x["license"] not in "copyright") #to avoid any "do not display this" license?
|
301 |
-
all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
|
302 |
-
num_samples = len(all_single_passes)
|
303 |
-
|
304 |
-
import tree_sitter
|
305 |
-
from tree_sitter import Language, Parser
|
306 |
-
Language.build_library("./build/my-languages.so", ['tree-sitter-glsl'])
|
307 |
-
GLSL_LANGUAGE = Language('./build/my-languages.so', 'glsl')
|
308 |
-
parser = Parser()
|
309 |
-
parser.set_language(GLSL_LANGUAGE)
|
310 |
|
311 |
def grab_sample(sample_idx):
|
312 |
sample_pass = all_single_passes[sample_idx]
|
@@ -322,19 +313,6 @@ def grab_sample(sample_idx):
|
|
322 |
# print(f"updating drop down to:{func_identifiers}")
|
323 |
return sample_pass, sample_code, sample_title, source_iframe, funcs#, gr.Dropdown.update(choices=func_identifiers) #, sample_title, sample_auhtor
|
324 |
|
325 |
-
|
326 |
-
def _parse_functions(in_code):
|
327 |
-
"""
|
328 |
-
returns all functions in the code as their actual nodes.
|
329 |
-
includes any comment made directly after the function definition or diretly after #copilot trigger
|
330 |
-
"""
|
331 |
-
tree = parser.parse(bytes(in_code, "utf8"))
|
332 |
-
funcs = [n for n in tree.root_node.children if n.type == "function_definition"]
|
333 |
-
|
334 |
-
return funcs
|
335 |
-
|
336 |
-
PIPE = None
|
337 |
-
|
338 |
def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #bad default model for testing
|
339 |
# if torch.cuda.is_available():
|
340 |
# device = "cuda"
|
@@ -436,16 +414,6 @@ def alter_return(orig_code, func_idx, temperature, max_new_tokens, top_p, repeti
|
|
436 |
|
437 |
return altered_code
|
438 |
|
439 |
-
def _line_chr2char(text, line_idx, chr_idx):
|
440 |
-
"""
|
441 |
-
returns the character index at the given line and character index.
|
442 |
-
"""
|
443 |
-
lines = text.split("\n")
|
444 |
-
char_idx = 0
|
445 |
-
for i in range(line_idx):
|
446 |
-
char_idx += len(lines[i]) + 1
|
447 |
-
char_idx += chr_idx
|
448 |
-
return char_idx
|
449 |
|
450 |
def _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty):
|
451 |
gen_kwargs = {}
|
@@ -455,34 +423,6 @@ def _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_pe
|
|
455 |
gen_kwargs["repetition_penalty"] = repetition_penalty
|
456 |
return gen_kwargs
|
457 |
|
458 |
-
def _grab_before_comments(func_node):
|
459 |
-
"""
|
460 |
-
returns the comments that happen just before a function node
|
461 |
-
"""
|
462 |
-
precomment = ""
|
463 |
-
last_comment_line = 0
|
464 |
-
for node in func_node.parent.children: #could you optimize where to iterated from? directon?
|
465 |
-
if node.start_point[0] != last_comment_line + 1:
|
466 |
-
precomment = ""
|
467 |
-
if node.type == "comment":
|
468 |
-
precomment += node.text.decode() + "\n"
|
469 |
-
last_comment_line = node.start_point[0]
|
470 |
-
elif node == func_node:
|
471 |
-
return precomment
|
472 |
-
return precomment
|
473 |
-
|
474 |
-
def _get_docstrings(func_node):
|
475 |
-
"""
|
476 |
-
returns the docstring of a function node
|
477 |
-
"""
|
478 |
-
docstring = ""
|
479 |
-
for node in func_node.child_by_field_name("body").children:
|
480 |
-
if node.type == "comment" or node.type == "{":
|
481 |
-
docstring += node.text.decode() + "\n"
|
482 |
-
else:
|
483 |
-
return docstring
|
484 |
-
return docstring
|
485 |
-
|
486 |
def alter_body(old_code, func_id, funcs_list: list, prompt, temperature, max_new_tokens, top_p, repetition_penalty, pipeline=PIPE):
|
487 |
"""
|
488 |
Replaces the body of a function with a generated one.
|
@@ -581,7 +521,7 @@ def construct_embed(source_url):
|
|
581 |
with gr.Blocks() as site:
|
582 |
top_md = gr.Markdown(intro_text)
|
583 |
model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
|
584 |
-
sample_idx = gr.Slider(minimum=0, maximum=
|
585 |
func_dropdown = gr.Dropdown(value=["0: edit the Code (or load a shader) to update this dropdown"], label="chose a function to modify") #breaks if I add a string in before that? #TODO: use type="index" to get int - always gives None?
|
586 |
prompt_text = gr.Textbox(value="the title used by the model has generation hint", label="prompt text", info="leave blank to skip", interactive=True)
|
587 |
with gr.Accordion("Advanced settings", open=False): # from: https://huggingface.co/spaces/bigcode/bigcode-playground/blob/main/app.py
|
@@ -644,7 +584,7 @@ with gr.Blocks() as site:
|
|
644 |
|
645 |
model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe]) # how can we trigger this on load?
|
646 |
sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, prompt_text, source_embed]) #funcs here?
|
647 |
-
gen_return_button.click(fn=alter_return, inputs=[sample_code, func_dropdown, pipe], outputs=[sample_code])
|
648 |
gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, prompt_text, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code, pipe]).then(
|
649 |
fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]
|
650 |
)
|
@@ -652,5 +592,11 @@ with gr.Blocks() as site:
|
|
652 |
fn=make_iframe, inputs=[sample_code], outputs=[our_embed])
|
653 |
|
654 |
if __name__ == "__main__": #works on huggingface?
|
|
|
|
|
|
|
|
|
|
|
|
|
655 |
site.queue()
|
656 |
site.launch()
|
|
|
6 |
import torch
|
7 |
from threading import Thread
|
8 |
|
9 |
+
from tree_utils import _parse_functions, _get_docstrings, _grab_before_comments, _line_chr2char
|
10 |
+
PIPE = None
|
11 |
+
|
12 |
def make_script(shader_code):
|
13 |
# code copied and fixed(escaping single quotes to double quotes!!!) from https://webglfundamentals.org/webgl/webgl-shadertoy.html
|
14 |
script = ("""
|
|
|
298 |
fragColor = vec4(col,1.0);
|
299 |
}"""
|
300 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
|
302 |
def grab_sample(sample_idx):
|
303 |
sample_pass = all_single_passes[sample_idx]
|
|
|
313 |
# print(f"updating drop down to:{func_identifiers}")
|
314 |
return sample_pass, sample_code, sample_title, source_iframe, funcs#, gr.Dropdown.update(choices=func_identifiers) #, sample_title, sample_auhtor
|
315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #bad default model for testing
|
317 |
# if torch.cuda.is_available():
|
318 |
# device = "cuda"
|
|
|
414 |
|
415 |
return altered_code
|
416 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
417 |
|
418 |
def _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty):
|
419 |
gen_kwargs = {}
|
|
|
423 |
gen_kwargs["repetition_penalty"] = repetition_penalty
|
424 |
return gen_kwargs
|
425 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
def alter_body(old_code, func_id, funcs_list: list, prompt, temperature, max_new_tokens, top_p, repetition_penalty, pipeline=PIPE):
|
427 |
"""
|
428 |
Replaces the body of a function with a generated one.
|
|
|
521 |
with gr.Blocks() as site:
|
522 |
top_md = gr.Markdown(intro_text)
|
523 |
model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
|
524 |
+
sample_idx = gr.Slider(minimum=0, maximum=10513, value=3211, label="pick sample from dataset", step=1.0)
|
525 |
func_dropdown = gr.Dropdown(value=["0: edit the Code (or load a shader) to update this dropdown"], label="chose a function to modify") #breaks if I add a string in before that? #TODO: use type="index" to get int - always gives None?
|
526 |
prompt_text = gr.Textbox(value="the title used by the model has generation hint", label="prompt text", info="leave blank to skip", interactive=True)
|
527 |
with gr.Accordion("Advanced settings", open=False): # from: https://huggingface.co/spaces/bigcode/bigcode-playground/blob/main/app.py
|
|
|
584 |
|
585 |
model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe]) # how can we trigger this on load?
|
586 |
sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, prompt_text, source_embed]) #funcs here?
|
587 |
+
gen_return_button.click(fn=alter_return, inputs=[sample_code, func_dropdown, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code])
|
588 |
gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, prompt_text, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code, pipe]).then(
|
589 |
fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]
|
590 |
)
|
|
|
592 |
fn=make_iframe, inputs=[sample_code], outputs=[our_embed])
|
593 |
|
594 |
if __name__ == "__main__": #works on huggingface?
|
595 |
+
passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
|
596 |
+
single_passes = passes_dataset.filter(lambda x: not x["has_inputs"] and x["num_passes"] == 1) #could also include shaders with no extra functions.
|
597 |
+
# single_passes = single_passes.filter(lambda x: x["license"] not in "copyright") #to avoid any "do not display this" license?
|
598 |
+
all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
|
599 |
+
num_samples = len(all_single_passes)
|
600 |
+
|
601 |
site.queue()
|
602 |
site.launch()
|
tree_utils.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tree_sitter
|
2 |
+
from tree_sitter import Language, Parser
|
3 |
+
|
4 |
+
Language.build_library("./build/my-languages.so", ['tree-sitter-glsl'])
|
5 |
+
GLSL_LANGUAGE = Language('./build/my-languages.so', 'glsl')
|
6 |
+
parser = Parser()
|
7 |
+
parser.set_language(GLSL_LANGUAGE)
|
8 |
+
|
9 |
+
|
10 |
+
def _parse_functions(in_code):
|
11 |
+
"""
|
12 |
+
returns all functions in the code as their actual nodes.
|
13 |
+
includes any comment made directly after the function definition or diretly after #copilot trigger
|
14 |
+
"""
|
15 |
+
tree = parser.parse(bytes(in_code, "utf8"))
|
16 |
+
funcs = [n for n in tree.root_node.children if n.type == "function_definition"]
|
17 |
+
|
18 |
+
return funcs
|
19 |
+
|
20 |
+
|
21 |
+
def _get_docstrings(func_node):
|
22 |
+
"""
|
23 |
+
returns the docstring of a function node
|
24 |
+
"""
|
25 |
+
docstring = ""
|
26 |
+
for node in func_node.child_by_field_name("body").children:
|
27 |
+
if node.type == "comment" or node.type == "{":
|
28 |
+
docstring += node.text.decode() + "\n"
|
29 |
+
else:
|
30 |
+
return docstring
|
31 |
+
return docstring
|
32 |
+
|
33 |
+
|
34 |
+
def _grab_before_comments(func_node):
|
35 |
+
"""
|
36 |
+
returns the comments that happen just before a function node
|
37 |
+
"""
|
38 |
+
precomment = ""
|
39 |
+
last_comment_line = 0
|
40 |
+
for node in func_node.parent.children: #could you optimize where to iterated from? directon?
|
41 |
+
if node.start_point[0] != last_comment_line + 1:
|
42 |
+
precomment = ""
|
43 |
+
if node.type == "comment":
|
44 |
+
precomment += node.text.decode() + "\n"
|
45 |
+
last_comment_line = node.start_point[0]
|
46 |
+
elif node == func_node:
|
47 |
+
return precomment
|
48 |
+
return precomment
|
49 |
+
|
50 |
+
def _line_chr2char(text, line_idx, chr_idx):
|
51 |
+
"""
|
52 |
+
returns the character index at the given line and character index.
|
53 |
+
"""
|
54 |
+
lines = text.split("\n")
|
55 |
+
char_idx = 0
|
56 |
+
for i in range(line_idx):
|
57 |
+
char_idx += len(lines[i]) + 1
|
58 |
+
char_idx += chr_idx
|
59 |
+
return char_idx
|