Vipitis commited on
Commit
46dd33b
1 Parent(s): 283b861

fix iframe, add generation

Browse files
Files changed (1) hide show
  1. app.py +78 -12
app.py CHANGED
@@ -141,9 +141,10 @@ In the near future there will be some buttons and sliders to generate variations
141
  If I find an efficient way, the shaders might run in real time and be interactive.
142
 
143
  ## TODO:
144
- - [~] use embedded Shadertoy for reference/attribution (seems to not on change??)
145
  - [] working render implementation on CPU only space (use the browser for WebGPU?)
146
- - [] generate variations of return statements (ShaderEval task1)
 
147
  - [] generate whole functions
148
  - [] generate whole shaders (via prompts?)
149
  """
@@ -170,30 +171,95 @@ def grab_sample(sample_idx):
170
  sample_title = sample_pass["title"]
171
  sample_auhtor = sample_pass["author"]
172
  source_iframe = construct_embed(sample_source)
 
173
  return sample_pass, sample_code, source_iframe #, sample_title, sample_auhtor
174
 
175
- def _make_pipeline(model_cp):
 
 
176
  tokenizer = AutoTokenizer.from_pretrained(model_cp, trust_remote_code=True)
177
  model = AutoModelForCausalLM.from_pretrained(model_cp, trust_remote_code=True)
178
- return pipeline("text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  def construct_embed(source_url):
181
- return f'<iframe width="640" height="360" frameborder="0" src="{source_url}?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>'
 
182
 
183
  with gr.Blocks() as site:
184
  text_md = gr.Markdown(text)
185
- model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys", label="Model Checkpoint", interactive=True)
186
- sample_idx = gr.Slider(minimum=0, maximum=num_samples, value=5, label="pick sample from dataset", step=1.0)
187
- # run_button = gr.Button(label="generate code")
188
- render_button = gr.Button("render frame0",label="render frame")
189
- time_slider = gr.Slider(minimum=0, maximum=10, value=0, label="time (update on release)", step=0.02)
190
  #output = gr.Textbox(label="Output")
191
  rendered_frame = gr.Image(shape=(512, 420), label=f"rendered frame preview")
192
  # info_md = gr.Markdown(value="code_source", label="source URL for this shader", interactive=False)
193
- source_embed = gr.HTML('<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/XtcSDf?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>')
194
- sample_code = gr.Code(label="Sample Code", language=None, readonly=True, lines=20)
195
  sample_pass = gr.State(value={})
 
 
 
 
196
  sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, source_embed])
 
 
197
  # sample_idx.release(fn=construct_embed, inputs=[sample_idx], outputs=[source_embed]) #twice to make have different outputs?
198
  time_slider.release(fn=lambda code, time: asyncio.run(get_image(code, time)), inputs=[sample_code, time_slider], outputs=rendered_frame)
199
  render_button.click(fn=lambda code: asyncio.run(get_image(code)), inputs=[sample_code], outputs=rendered_frame)
 
141
  If I find an efficient way, the shaders might run in real time and be interactive.
142
 
143
  ## TODO:
144
+ - [x] use embedded Shadertoy for reference/attribution (done, but some errors)
145
  - [] working render implementation on CPU only space (use the browser for WebGPU?)
146
+ - [~] generate variations of return statements [ShaderEval task1](https://huggingface.co/spaces/Vipitis/ShaderEval) (missing all of the generation parameters)
147
+ - [] generation history stating which function and orig/generated returns. (use State ??).
148
  - [] generate whole functions
149
  - [] generate whole shaders (via prompts?)
150
  """
 
171
  sample_title = sample_pass["title"]
172
  sample_auhtor = sample_pass["author"]
173
  source_iframe = construct_embed(sample_source)
174
+ print(f"{source_iframe=}")
175
  return sample_pass, sample_code, source_iframe #, sample_title, sample_auhtor
176
 
177
+ PIPE = None #gloabl var in CAPS indicates constant? so why are we changing it?
178
+
179
+ def _make_pipeline(model_cp = "gpt2"): #bad default model for testing
180
  tokenizer = AutoTokenizer.from_pretrained(model_cp, trust_remote_code=True)
181
  model = AutoModelForCausalLM.from_pretrained(model_cp, trust_remote_code=True)
182
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True)
183
+ PIPE = pipe # set the global?
184
+ print(f"loaded model {model_cp} as a pipline")
185
+ return pipe
186
+
187
+ def process_retn(retn):
188
+ return retn.split(";")[0].strip()
189
+
190
+ def get_full_replacement(orig_code, retn_start_idx, retn_end_idx, prediction) -> str:
191
+ """
192
+ Batches the generated return statement into the code and returns the full altered code.
193
+ """
194
+ print(f"{orig_code[retn_start_idx:retn_end_idx]=}")
195
+ generated = process_retn(prediction)
196
+ print(f"{generated=}")
197
+ variation = orig_code[:retn_start_idx] + generated + orig_code[retn_end_idx:]
198
+ return variation
199
+
200
+ def alter_return(orig_code, func_idx=0, pipeline=PIPE): #default pipeline can't be passed as gloabl?
201
+ """
202
+ Replaces the return statement of a function with a generated one.
203
+ Args:
204
+ orig_code (str): The original code.
205
+ func_idx (int): The index of the function to replace the return statement of.
206
+ pipeline (Pipeline): The pipeline to use for generation.
207
+ Returns:
208
+ str: The altered code.
209
+ """
210
+ if pipeline is None:
211
+ print("no pipeline found, loading default one")
212
+ pipeline = _make_pipeline()
213
+
214
+ retrns = []
215
+ retrn_start_idx = orig_code.find("return")
216
+ while retrn_start_idx != -1:
217
+ retrn_end_idx = orig_code.find(";", retrn_start_idx)
218
+ retrns.append((retrn_start_idx, retrn_end_idx))
219
+ retrn_start_idx = orig_code.find("return", retrn_end_idx)
220
+ num_returns = len(retrns)
221
+ if num_returns == 0:
222
+ print("no return statement found, returning original code")
223
+ return orig_code
224
+ func_idx = int(max(0, min(func_idx, num_returns - 1))) #clamp to valid range, cast to int as a bodge.
225
+ retrn_start_idx, retrn_end_idx = retrns[func_idx]
226
+ model_context = orig_code[:retrn_start_idx] #TODO: maximal context?
227
+ model_inp = model_context + "return"
228
+ new_toks = (retrn_end_idx - retrn_start_idx) * 2 #TODO: approximation, we do have early stopping? maybe also use a number instead?
229
+ pipe_generation = pipeline(model_inp, max_new_tokens=new_toks, return_full_text=False)[0]["generated_text"] #pipeline kwargs are missing?!
230
+ altered_code = get_full_replacement(orig_code, retrn_start_idx+7, retrn_end_idx, pipe_generation)
231
+
232
+ return altered_code
233
+
234
+ def add_history(func_id, orig_rtn, gened_rtn, history):
235
+ # is this a list? or a JSON dict?
236
+ history[func_id] = (orig_rtn, gened_rtn)
237
+ return history, history
238
 
239
  def construct_embed(source_url):
240
+ shader_id = source_url.split("/")[-1]
241
+ return f'<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/{shader_id}?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>'
242
 
243
  with gr.Blocks() as site:
244
  text_md = gr.Markdown(text)
245
+ model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
246
+ sample_idx = gr.Slider(minimum=0, maximum=num_samples, value=3211, label="pick sample from dataset", step=1.0)
247
+ run_button = gr.Button("generate a alternate return statement for one function", label="generate code")
248
+ render_button = gr.Button("render frame0 (can carsh the sapce on invalid shadercode)",label="render frame")
249
+ time_slider = gr.Slider(minimum=0, maximum=10, value=0, label="time (update on release, also used to pick other functions as a bodge)", step=0.02)
250
  #output = gr.Textbox(label="Output")
251
  rendered_frame = gr.Image(shape=(512, 420), label=f"rendered frame preview")
252
  # info_md = gr.Markdown(value="code_source", label="source URL for this shader", interactive=False)
253
+ source_embed = gr.HTML('<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/WsBcWV?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>', label="How this shader originally renders")
254
+ sample_code = gr.Code(label="Current Code (will update changes you generate)", language=None, readonly=True, lines=20)
255
  sample_pass = gr.State(value={})
256
+ pipe = gr.State(value=PIPE)
257
+ # hist_state = gr.State(Value={})
258
+ # history_table = gr.JSON()
259
+ model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe])
260
  sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, source_embed])
261
+ run_button.click(fn=alter_return, inputs=[sample_code, time_slider, pipe], outputs=[sample_code])
262
+ # run_button.click(fn=add_history, inputs=[time_slider, sample_pass, sample_code, hist_state], outputs=[history_table, hist_state])
263
  # sample_idx.release(fn=construct_embed, inputs=[sample_idx], outputs=[source_embed]) #twice to make have different outputs?
264
  time_slider.release(fn=lambda code, time: asyncio.run(get_image(code, time)), inputs=[sample_code, time_slider], outputs=rendered_frame)
265
  render_button.click(fn=lambda code: asyncio.run(get_image(code)), inputs=[sample_code], outputs=rendered_frame)