import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from wgpu.utils.shadertoy import *
from wgpu.gui.offscreen import WgpuCanvas as OffscreenCanvas, run as run_offscreen
import wgpu
import time
import ctypes
import datasets
from PIL import Image
import asyncio
import numpy as np


# reimplement the Shadertoy class with offscreen canvas!
class ShadertoyCustom(Shadertoy):
    """Shadertoy variant with an injectable canvas class and run function.

    Passing an offscreen canvas class (and its matching ``run`` function)
    lets a shader be rendered without opening a GUI window.

    Args:
        shader_code: shader source handed to the Shadertoy base class.
        resolution: size tuple; stored into the "resolution" uniform with a
            trailing 1 appended, and used as the canvas size.
        canvas_class: class instantiated in ``_prepare_render`` to create the canvas.
        run_fn: zero-argument callable invoked by ``show`` after a draw is requested.
    """

    def __init__(self, shader_code, resolution=(800, 450), canvas_class=WgpuCanvas, run_fn=run):
        # Set before super().__init__() so that the overridden _prepare_render
        # can read them if the base constructor calls it.
        self._canvas_class = canvas_class
        self._fun_fn = run_fn
        super().__init__(shader_code, resolution)
        # NOTE(review): the setup below repeats what the base constructor
        # presumably already does (uniforms, render pipeline, event binding).
        # If so, the canvas/device are created twice — confirm against the
        # installed wgpu version before removing.
        self._uniform_data = UniformArray(
            ("mouse", "f", 4),
            ("resolution", "f", 3),
            ("time", "f", 1),
            ("time_delta", "f", 1),
            ("frame", "I", 1),
        )
        self._shader_code = shader_code
        self._uniform_data["resolution"] = resolution + (1,)
        self._prepare_render()
        self._bind_events()

    def _prepare_render(self):
        """Create the canvas, device, shader modules, uniform buffer and render pipeline.

        Raises:
            ValueError: if ``self.shader_type`` is neither "glsl" nor "wgsl".
        """
        # Imported only for its side effect (presumably selecting the Rust backend).
        import wgpu.backends.rs  # noqa

        self._canvas = self._canvas_class(title="Shadertoy", size=self.resolution, max_fps=60)

        adapter = wgpu.request_adapter(
            canvas=self._canvas, power_preference="high-performance"
        )
        self._device = adapter.request_device()

        self._present_context = self._canvas.get_context()

        # We use "bgra8unorm" not "bgra8unorm-srgb" here because we want to let the shader fully control the color-space.
        self._present_context.configure(
            device=self._device, format=wgpu.TextureFormat.bgra8unorm
        )

        shader_type = self.shader_type
        if shader_type == "glsl":
            vertex_shader_code = vertex_code_glsl
            frag_shader_code = (
                builtin_variables_glsl + self.shader_code + fragment_code_glsl
            )
        elif shader_type == "wgsl":
            vertex_shader_code = vertex_code_wgsl
            frag_shader_code = (
                builtin_variables_wgsl + self.shader_code + fragment_code_wgsl
            )
        else:
            # Fail clearly instead of hitting an unbound local name below.
            raise ValueError(f"Unsupported shader type: {shader_type!r}")

        vertex_shader_program = self._device.create_shader_module(
            label="triangle_vert", code=vertex_shader_code
        )
        frag_shader_program = self._device.create_shader_module(
            label="triangle_frag", code=frag_shader_code
        )

        self._uniform_buffer = self._device.create_buffer(
            size=self._uniform_data.nbytes,
            usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST,
        )

        bind_group_layout = self._device.create_bind_group_layout(
            entries=binding_layout
        )

        self._bind_group = self._device.create_bind_group(
            layout=bind_group_layout,
            entries=[
                {
                    "binding": 0,
                    "resource": {
                        "buffer": self._uniform_buffer,
                        "offset": 0,
                        "size": self._uniform_data.nbytes,
                    },
                },
            ],
        )

        self._render_pipeline = self._device.create_render_pipeline(
            layout=self._device.create_pipeline_layout(
                bind_group_layouts=[bind_group_layout]
            ),
            vertex={
                "module": vertex_shader_program,
                "entry_point": "main",
                "buffers": [],
            },
            primitive={
                "topology": wgpu.PrimitiveTopology.triangle_list,
                "front_face": wgpu.FrontFace.ccw,
                "cull_mode": wgpu.CullMode.none,
            },
            depth_stencil=None,
            multisample=None,
            fragment={
                "module": frag_shader_program,
                "entry_point": "main",
                "targets": [
                    {
                        "format": wgpu.TextureFormat.bgra8unorm,
                        "blend": {
                            "color": (
                                wgpu.BlendFactor.one,
                                wgpu.BlendFactor.zero,
                                wgpu.BlendOperation.add,
                            ),
                            "alpha": (
                                wgpu.BlendFactor.one,
                                wgpu.BlendFactor.zero,
                                wgpu.BlendOperation.add,
                            ),
                        },
                    },
                ],
            },
        )

    def show(self, time: float = 0.0):
        """Request a draw of the next frame, then call the configured run function.

        ``time`` is accepted for interface compatibility but is not used.
        """
        self._canvas.request_draw(self._draw_frame)
        self._fun_fn()


text = """
# Welcome to the interactive shadercoding demo.
## (WIP), you can try and explore the dataset a bit right now. (frames are rendered on the fly, not part of the dataset(yet))

This gives you access to a filtered version of the [Shadertoys](https://huggingface.co/datasets/Vipitis/Shadertoys) dataset, only shaders that consist of a single pass (and have at least one function with a return statement) are available.
In the near future there will be some buttons and sliders to generate variations of the shadercode itself, and hence get some different images.
If I find an efficient way, the shaders might run in real time and be interactive.
"""

passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
# Keep single-pass shaders without inputs whose code contains "return".
# Filtering here is simpler than writing a custom loader script.
single_passes = passes_dataset.filter(
    lambda x: not x["has_inputs"] and x["num_passes"] == 1 and "return" in x["code"]
)
all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
num_samples = len(all_single_passes)


async def get_image(code, time=0.0, resolution=(512, 420)):
    """Render one frame of ``code`` on an offscreen canvas and return it as an RGB PIL image.

    Args:
        code: shader source passed to ShadertoyCustom.
        time: value written to the "time" uniform before the frame is drawn.
        resolution: size tuple forwarded to ShadertoyCustom.
    """
    shader = ShadertoyCustom(code, resolution, OffscreenCanvas, run_offscreen)
    shader._uniform_data["time"] = time  # set any time you want
    shader._canvas.request_draw(shader._draw_frame)
    frame = np.asarray(shader._canvas.draw())
    # Drop the alpha channel so transparent pixels don't show up in the preview.
    return Image.fromarray(frame).convert("RGB")


def grab_sample(sample_idx):
    """Return ``(code, source)`` for the dataset entry at ``sample_idx``.

    The index is cast to ``int`` because slider values may arrive as floats.
    """
    sample_pass = all_single_passes[int(sample_idx)]
    return sample_pass["code"], sample_pass["source"]


def _make_pipeline(model_cp):
    """Build a text-generation pipeline for the model checkpoint ``model_cp``."""
    tokenizer = AutoTokenizer.from_pretrained(model_cp, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_cp, trust_remote_code=True)
    return pipeline("text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True)


def _render_frame(code, frame_time=0.0):
    """Synchronously run the async ``get_image`` for a Gradio event handler."""
    return asyncio.run(get_image(code, frame_time))


with gr.Blocks() as site:
    text_md = gr.Markdown(text)
    model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys", label="Model Checkpoint", interactive=True)
    # Valid dataset indices are 0 .. num_samples - 1.
    sample_idx = gr.Slider(minimum=0, maximum=max(num_samples - 1, 0), value=5, label="pick sample from dataset", step=1.0)
    # run_button = gr.Button(label="generate code")
    render_button = gr.Button("render frame0", label="render frame")
    time_slider = gr.Slider(minimum=0, maximum=10, value=0, label="time (update on release)", step=0.02)
    # output = gr.Textbox(label="Output")
    rendered_frame = gr.Image(shape=(512, 420), label="rendered frame preview")
    info_md = gr.Markdown(value="code_source", label="source URL for this shader", interactive=False)
    sample_code = gr.Code(label="Sample Code", language=None, readonly=True, lines=20)
    sample_pass = gr.State(value=None)

    sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_code, info_md])
    time_slider.release(fn=_render_frame, inputs=[sample_code, time_slider], outputs=rendered_frame)
    render_button.click(fn=_render_frame, inputs=[sample_code], outputs=rendered_frame)
    # run_button.click(fn=print, inputs=[model_cp, sample_idx], outputs=output)

site.launch()