import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import datasets
import asyncio
import numpy as np


def make_script(shader_code):
    """
    Returns the markup string that is used for the shader preview.

    Args:
        shader_code (str): The shader code that should be previewed.
            NOTE(review): the body below never references this argument; the template
            it was meant to be inserted into may be missing from the literal - confirm.
    Returns:
        str: The content of the `script` literal, unchanged.
    """
    # code copied and fixed(escaping single quotes to double quotes!!!) from https://webglfundamentals.org/webgl/webgl-shadertoy.html
    script = (""" WebGL - Shadertoy
blank canvas here indicates that some of the shadertoy specific functions are not yet supported with this implementation (like #define I believe). you can always copy and paste the code into a shadertoy.com window to try. """)
    return script


def make_iframe(shader_code): #keep a single function?
    """
    Builds the preview script for the given shader code and returns the embed string.

    Args:
        shader_code (str): The shader code, forwarded to make_script.
    Returns:
        str: The embed markup.
            NOTE(review): the returned f-string is empty and `script` is not interpolated
            into it - presumably the iframe markup was lost; confirm.
    """
    script = make_script(shader_code)
    return f""""""


# Markdown shown at the top of the page (passed to gr.Markdown further down in this file).
# NOTE(review): headings and list items sit on a single line here; the original line breaks
# may have been lost - confirm before relying on how this renders.
intro_text = """ # Welcome to the interactive shadercoding demo. This gives you access to a filtered version of the [Shadertoys](https://huggingface.co/datasets/Vipitis/Shadertoys) dataset, only shaders that consist of a single pass are available. And then lets you use code generation models to make alterations to part of the shadercode. ## How To Use: 1. Load any Model for [`text-generation`](https://huggingface.co/models?pipeline_tag=text-generation) and hit ENTER. 2. Use the slider to sample a shader from the dataset. - The original shader will be embedding on the left, click on title to get to the source. - The shadercode will be displayed on the right, this is interactive. - A preview of the currently displayed shadercode will be displayed on the lower left. (hover to advance time) 3. use the dropdown to select a function to modify. 4. press either button to make modifications to that function 5. you can also edit the code manually. """

# Markdown shown at the bottom of the page: model suggestions, TODO list and notes.
# NOTE(review): same caveat about lost line breaks as for intro_text.
outro_text =""" ## Models to try (look at [ShaderEval](https://huggingface.co/spaces/Vipitis/ShaderEval) for an indication of how helpful they will be): - [gpt2](https://huggingface.co/gpt2) baseline for language models, really struggles with shadercode. - [bigscience/bloom-1b1](https://huggingface.co/bigscience/bloom-1b1) a newer and larger freely available model. Does understand a big of code. - [codeparrot/codeparrot-small](https://huggingface.co/codeparrot/codeparrot-small) a model trained on code, but not on shadercode. Manages to graps the patterns. - [salesforce/codegen-2B-multi](https://huggingface.co/salesforce/codegen-2B-multi) a larger model that indicates some potential. 
- [bigcode/santacoder](https://huggingface.co/bigcode/santacoder) a model trained on subset of [TheStack](https://huggingface.co/datasets/bigcode/the-stack), struggles with shadercode. - [Vipitis/santacoder-finetuned-the-stack-glsl](https://huggingface.co/Vipitis/santacoder-finetuned-the-stack-glsl) fine-tuned by me on the glsl subset of [TheStack](https://huggingface.co/datasets/bigcode/the-stack), is an improvement. - [Vipitis/santacoder-finetuned-Shadertoys](https://huggingface.co/Vipitis/santacoder-finetuned-Shadertoys) fine-tuned by me on whole shaders from [Shadertoys](https://huggingface.co/datasets/Vipitis/Shadertoys). Does overfit quite a bit with greedy decoding. - [Vipitis/santacoder-finetuned-Shadertoys-fine](https://huggingface.co/Vipitis/santacoder-finetuned-Shadertoys-fine) fine-tuned by me just functions from [Shadertoys-fine](https://huggingface.co/datasets/Vipitis/Shadertoys-fine). Memorizes the exact function about half the time. - [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) a very large model which I haven't tried yet. - **any other model you want to** ## TODO (feel free to contribute with a [Pull-Request](https://huggingface.co/Vipitis/santacoder-finetuned-the-stack-glsl/discussions?status=open&type=pull_request)): - [x] use embedded Shadertoy for reference/attribution (done, but some errors) - [~] working render implementation on CPU only space (as webgl via webglfundamentals, ccs needs fixing for iframe (or hijack Shadertoy iframe)) - [~] generate variations of return statements [ShaderEval task1](https://huggingface.co/spaces/Vipitis/ShaderEval) (needs to be reworked using the other parts) - [x] generate whole functions (seems to work quite well) - [] dropdown for model selection (from curated list or all supported models?) - [] generation history stating which function and orig/generated returns. (use State ??). do it as comments in the code? 
- [] display errros/issues to the user (raise gr.Error could be one idea, but highlighting in the code would be awesome) - [] generate whole shaders (via prompts guidance, recursive from errors) - [] accordion with generation parameters (as pipeline_kwargs?) look up starcoder playround and take "inspiration" from there - [] support FIM task for better model context - [] gradio examples ### Notes: - this is meant as a resource to show code generation for a "creative" task. - the goal is not to not replace shader artists, but aims to be an assistant instead. - the space still lacks quite a lot of features, but will continue to evolve. - this demo can be useful to sannity check evaluation results, where the academic numbers are made. - If you create a remix with these tools, please attribute the original creator of your starting point when sharing the results. (And perhaps share in the [discussion tab](https://huggingface.co/Vipitis/santacoder-finetuned-the-stack-glsl/discussions?status=open&type=discussion) too) """

# Load the Shadertoys dataset at import time; this runs once when the module is loaded.
passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
# Keep only entries that have no inputs and consist of exactly one pass.
single_passes = passes_dataset.filter(lambda x: not x["has_inputs"] and x["num_passes"] == 1) #could also include shaders with no extra functions.
# Merge the filtered train and test splits so the slider can address every single-pass shader.
all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
num_samples = len(all_single_passes)

import re

import tree_sitter
from tree_sitter import Language, Parser

# Build and load the GLSL grammar once at import time; the parser is shared by all callbacks.
Language.build_library("./build/my-languages.so", ['tree-sitter-glsl'])
GLSL_LANGUAGE = Language('./build/my-languages.so', 'glsl')
parser = Parser()
parser.set_language(GLSL_LANGUAGE)

# Matches the keyword only, not identifiers or comment words that merely contain "return".
_RETURN_RE = re.compile(r"\breturn\b")


def grab_sample(sample_idx):
    """
    Looks up one shader in the dataset and prepares the values for the UI.

    Args:
        sample_idx (int | float): The index picked on the slider.
    Returns:
        tuple: (sample_pass, sample_code, source_iframe), matching the three outputs
            this callback is wired to. The function dropdown is refreshed separately
            through the change event of the code component.
    """
    # the slider may hand over a float, the dataset needs an int.
    sample_pass = all_single_passes[int(sample_idx)]
    sample_code = sample_pass["code"]
    sample_source = sample_pass["source"]
    source_iframe = construct_embed(sample_source)
    print(f"{source_iframe=}")
    return sample_pass, sample_code, source_iframe


def _parse_functions(in_code):
    """
    returns all functions in the code as their actual nodes.

    Args:
        in_code (str): The shader code to parse.
    Returns:
        list: The top level `function_definition` nodes, in source order.
    """
    tree = parser.parse(bytes(in_code, "utf8"))
    funcs = [n for n in tree.root_node.children if n.type == "function_definition"]
    return funcs


PIPE = None


def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #bad default model for testing
    """
    Loads a checkpoint as a text-generation pipeline and remembers it in PIPE.

    Args:
        model_cp (str): The model checkpoint to load.
    Returns:
        Pipeline: The loaded text-generation pipeline.
    """
    global PIPE # without this the assignment below only creates a local.
    tokenizer = AutoTokenizer.from_pretrained(model_cp, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_cp, trust_remote_code=True)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True)
    PIPE = pipe
    print(f"loaded model {model_cp} as a pipline")
    return pipe


def _resolve_pipeline(pipeline):
    """
    Returns a usable pipeline: the given one, else the cached PIPE, else the default model.
    """
    if pipeline is not None:
        return pipeline
    if PIPE is not None:
        return PIPE
    print("no pipeline found, loading default one")
    return _make_pipeline()


def _parse_func_idx(func_id):
    """
    Extracts the leading number of a dropdown entry such as " 3: main(...)".

    Raises:
        gr.Error: If nothing is selected or the entry has no leading number.
    """
    if func_id is None:
        raise gr.Error("please choose a function from the dropdown first.")
    try:
        return int(str(func_id).split(":")[0].strip())
    except ValueError as err:
        raise gr.Error(f"can't read a function index from {func_id!r}") from err


def _byte2char(text, byte_idx):
    """
    Converts a byte offset into the utf8 encoding of text to a character offset.
    """
    return len(text.encode("utf8")[:byte_idx].decode("utf8", errors="ignore"))


def _find_return_spans(code):
    """
    Finds all terminated return statements as (keyword_start, keyword_end, semicolon_idx).
    """
    spans = []
    pos = 0
    while (match := _RETURN_RE.search(code, pos)) is not None:
        stmt_end = code.find(";", match.end())
        if stmt_end == -1:
            break # unterminated statement at the end of the code, nothing to replace.
        spans.append((match.start(), match.end(), stmt_end))
        pos = stmt_end
    return spans


def process_retn(retn):
    """
    Truncates a generation to the part before the first semicolon.
    """
    return retn.split(";")[0].strip()


def get_full_replacement(orig_code, retn_start_idx, retn_end_idx, prediction) -> str:
    """
    Patches the generated return statement into the code and returns the full altered code.

    Args:
        orig_code (str): The original code.
        retn_start_idx (int): Index of the first character to replace.
        retn_end_idx (int): Index of the first character to keep again.
        prediction (str): The raw generation, gets truncated by process_retn.
    Returns:
        str: The altered code.
    """
    print(f"{orig_code[retn_start_idx:retn_end_idx]=}")
    generated = process_retn(prediction)
    print(f"{generated=}")
    variation = orig_code[:retn_start_idx] + generated + orig_code[retn_end_idx:]
    return variation


def alter_return(orig_code, func_idx="0:", pipeline=PIPE): # the default is bound at definition time, so it is None.
    """
    Replaces the expression of one return statement with a generated one.

    Args:
        orig_code (str): The original code.
        func_idx (str): The dropdown entry; its leading number picks the n-th return
            statement of the whole code (clamped to the valid range).
        pipeline (Pipeline): The pipeline to use for generation.
    Returns:
        str: The altered code, or the original code if it has no return statement.
    Raises:
        gr.Error: If no dropdown entry is selected.
    """
    pipeline = _resolve_pipeline(pipeline)
    print(f"{func_idx=}")
    func_idx = _parse_func_idx(func_idx)

    retrns = _find_return_spans(orig_code)
    if not retrns:
        print("no return statement found, returning original code")
        return orig_code
    func_idx = max(0, min(func_idx, len(retrns) - 1)) #clamp to valid range
    kw_start_idx, kw_end_idx, retrn_end_idx = retrns[func_idx]

    model_inp = orig_code[:kw_end_idx] # everything up to and including the keyword. TODO: maximal context?
    new_toks = (retrn_end_idx - kw_start_idx) * 2 #TODO: approximation, we do have early stopping? maybe also use a number instead?
    pipe_generation = pipeline(model_inp, max_new_tokens=new_toks, return_full_text=False)[0]["generated_text"] #pipeline kwargs are missing?!

    print(f"{orig_code[kw_end_idx:retrn_end_idx]=}")
    generated = process_retn(pipe_generation)
    print(f"{generated=}")
    # splice right after the keyword instead of a fixed offset, so "return;" and
    # "return(x);" are not corrupted; only add the space if there is something to separate.
    separator = " " if generated else ""
    return orig_code[:kw_end_idx] + separator + generated + orig_code[retrn_end_idx:]


def _line_chr2char(text, line_idx, chr_idx):
    """
    returns the character index at the given line and character index.

    Args:
        text (str): The text to index into.
        line_idx (int): Zero based line number.
        chr_idx (int): Zero based position inside that line, counted in characters.
    Returns:
        int: The index into text.
    """
    lines = text.split("\n")
    # every skipped line contributes its length plus the newline that was split away.
    return sum(len(line) + 1 for line in lines[:line_idx]) + chr_idx


def alter_body(old_code, func_id: str, funcs_list: list, pipeline=PIPE):
    """
    Replaces the body of a function with a generated one.

    Args:
        old_code (str): The original code.
        func_id (str): The dropdown entry; its leading number indexes funcs_list.
        funcs_list (list): The function nodes of old_code, as returned by list_dropdown.
        pipeline (Pipeline): The pipeline to use for generation.
    Returns:
        tuple: (altered_code, pipeline) - the altered code and the pipeline that was used.
    Raises:
        gr.Error: If no function is selected or the selection doesn't exist in funcs_list.
    """
    print(f"{func_id=}")
    func_id = _parse_func_idx(func_id)
    if not 0 <= func_id < len(funcs_list):
        raise gr.Error(f"function {func_id} is not part of the current code.")
    func_node = funcs_list[func_id]
    print(f"using for generation: {func_node=}")

    print(f"{pipeline=}") # check if default even loaded
    pipeline = _resolve_pipeline(pipeline)

    identifier_str = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode()
    body_node = func_node.child_by_field_name("body")
    # the parser works on utf8 bytes, so node offsets have to be mapped back to characters.
    body_start_idx = _byte2char(old_code, body_node.start_byte)
    body_end_idx = _byte2char(old_code, body_node.end_byte)
    print(f"{old_code[body_start_idx:body_end_idx]=}")

    model_context = identifier_str # just this
    # HARD MAX for performance limits: min() caps the budget, max() would only set a floor.
    num_new_tokens = min(160, (body_end_idx - body_start_idx) + 10) #TODO: approximation, we do have early stopping?
    print(f"generating up to {num_new_tokens} after {model_context!r}")
    generation = pipeline(model_context, max_new_tokens=num_new_tokens, return_full_text=False)[0]["generated_text"]
    print(f"{generation=}")

    id_with_generation = identifier_str + generation
    print(f"{id_with_generation=}")
    try:
        #strip the body
        first_gened_func = _parse_functions(id_with_generation)[0] # truncate generation to a single function?
    except IndexError:
        print("generation wasn't a full function.")
        altered_code = old_code[:body_start_idx] + generation + "//the generation didn't complete the function!\n" + old_code[body_end_idx:] #needs a newline to break out of the comment.
        return altered_code, pipeline
    print(f"{first_gened_func=}")
    generated_body = first_gened_func.child_by_field_name("body").text.decode()
    print(f"{generated_body=}")
    altered_code = old_code[:body_start_idx] + generated_body + old_code[body_end_idx:]
    return altered_code, pipeline


def add_history(func_id, orig_rtn, gened_rtn, history):
    """
    Records the original and generated return of a function in the history mapping.

    Args:
        func_id: The key to store the pair under.
        orig_rtn: The original return.
        gened_rtn: The generated return.
        history (dict): The mapping to update, it is modified in place.
    Returns:
        tuple: The updated history twice.
    """
    # is this a list? or a JSON dict?
    history[func_id] = (orig_rtn, gened_rtn)
    return history, history


def list_dropdown(in_code): #only used for auto update, not on sample pick?
    """
    Parses the code and updates the dropdown with all functions found.

    Args:
        in_code (str): The currently displayed code.
    Returns:
        tuple: (funcs, dropdown_update) - the function nodes and the new dropdown choices.
    """
    funcs = _parse_functions(in_code)
    func_identifiers = [f"{idx:2d}: {n.child_by_field_name('declarator').text.decode()}" for idx, n in enumerate(funcs)]
    print(f"updating drop down to:{func_identifiers}")
    return funcs, gr.Dropdown.update(choices=func_identifiers)


def construct_embed(source_url):
    """
    Builds the embed markup for the original shader.

    Args:
        source_url (str): The url of the shader, its last path segment is the shader id.
    Returns:
        str: The embed markup.
            NOTE(review): the returned f-string is empty and shader_id is never
            interpolated - presumably the iframe markup was lost; confirm.
    """
    shader_id = source_url.split("/")[-1]
    return f''


# the last valid dataset index, also guards against an empty dataset.
_max_sample_idx = max(num_samples - 1, 0)

with gr.Blocks() as site:
    top_md = gr.Markdown(intro_text)
    model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
    # maximum has to be the last index (not the length) and the default has to stay inside the range.
    sample_idx = gr.Slider(minimum=0, maximum=_max_sample_idx, value=min(3211, _max_sample_idx), label="pick sample from dataset", step=1.0)
    func_dropdown = gr.Dropdown(label="chose a function to modify") #breaks if I add a string in before that?
    with gr.Row():
        gen_return_button = gr.Button("generate a alternate return statement", label="generate return")
        gen_func_button = gr.Button("generate an alternate function body", label="generate function")
        # update_funcs_button = gr.Button("update functions", label="update functions")
    with gr.Row():
        # left column: original shader on top, our preview below; the code goes to the right.
        with gr.Column():
            source_embed = gr.HTML('', label="How this shader originally renders")
            our_embed = gr.HTML(label="glsl render of the current code")
        sample_code = gr.Code("// touch the slider to select a shader", label="Current Code (will update changes you generate)", language=None)
    bot_md = gr.Markdown(outro_text)
    sample_pass = gr.State(value={})
    pipe = gr.State(value=PIPE)
    pipe.value=_make_pipeline("Vipitis/santacoder-finetuned-Shadertoys-fine") # set a default like this?
    funcs = gr.State(value=[])
    # hist_state = gr.State(Value={})
    # history_table = gr.JSON()

    model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe]) # how can we trigger this on load?
    sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, source_embed])
    gen_return_button.click(fn=alter_return, inputs=[sample_code, func_dropdown, pipe], outputs=[sample_code])
    gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, pipe], outputs=[sample_code, pipe])
    sample_code.change(fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]) # to update this after generation, so spans aren't messed up
    sample_code.change(fn=make_iframe, inputs=[sample_code], outputs=[our_embed]) #twice could cause issues, find better ways.

site.launch()