Vipitis committed on
Commit
db24268
1 Parent(s): 9f8916a

improved model context

Browse files
Files changed (1) hide show
  1. app.py +39 -7
app.py CHANGED
@@ -272,10 +272,12 @@ outro_text ="""
272
  - [] generate whole shaders (via prompts guidance, recursive from errors)
273
  - [x] accordion with generation parameters (as pipeline_kwargs?) look up starcoder playround and take "inspiration" from there (implemented for both buttons, untested)
274
  - [] support FIM task for better model context
275
- - [~] include some context for prompt (title, comments before a functions) - now works with the first comment inside a function body (has to be first)
276
  - [] gradio examples
277
  - [] use GPU if available, respect memory restrictions.
278
- - [~] stream model generation (maybe in a new window?) - WIP for body gen right now -> janky solution works.
 
 
279
 
280
  ### Notes:
281
  - this is meant as a resource to show code generation for a "creative" task.
@@ -295,6 +297,7 @@ new_shadertoy_code = """void mainImage( out vec4 fragColor, in vec2 fragCoord )
295
 
296
  passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
297
  single_passes = passes_dataset.filter(lambda x: not x["has_inputs"] and x["num_passes"] == 1) #could also include shaders with no extra functions.
 
298
  all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
299
  num_samples = len(all_single_passes)
300
 
@@ -448,6 +451,34 @@ def _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_pe
448
  gen_kwargs["repetition_penalty"] = repetition_penalty
449
  return gen_kwargs
450
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
  def alter_body(old_code, func_id, funcs_list: list, temperature, max_new_tokens, top_p, repetition_penalty, pipeline=PIPE):
452
  """
453
  Replaces the body of a function with a generated one.
@@ -483,11 +514,12 @@ def alter_body(old_code, func_id, funcs_list: list, temperature, max_new_tokens,
483
  print(f"{old_code[body_start_idx:body_end_idx]=}")
484
  model_context = identifier_str # base case
485
  # add any comments at the beginning of the function to the model_context
486
- second_child = func_node.child_by_field_name("body").children[1] #might error out?
487
- if second_child.type == "comment":
488
- # print(second_child.text.decode())
489
- model_context += " { \n " + second_child.text.decode()
490
- print(f"{model_context=}")
 
491
  # generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
492
  generation = _run_generation(model_context, pipeline, generation_kwargs)
493
  for i in generation:
 
272
  - [] generate whole shaders (via prompts guidance, recursive from errors)
273
  - [x] accordion with generation parameters (as pipeline_kwargs?) look up starcoder playround and take "inspiration" from there (implemented for both buttons, untested)
274
  - [] support FIM task for better model context
275
+ - [x] include some context for prompt (title, comments before a functions) - now takes all comments directly before a function as well as all comments at the beginning inside a function.
276
  - [] gradio examples
277
  - [] use GPU if available, respect memory restrictions.
278
+ - [x] stream model generation (maybe in a new window?) - janky solution and only sometimes hangs up
279
+ - [] 2nd iFrame needs a lot of fixing (I am not a web developer, need help)
280
+ - [] (optional) filtering the dataset by license?
281
 
282
  ### Notes:
283
  - this is meant as a resource to show code generation for a "creative" task.
 
297
 
298
  passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
299
  single_passes = passes_dataset.filter(lambda x: not x["has_inputs"] and x["num_passes"] == 1) #could also include shaders with no extra functions.
300
+ # single_passes = single_passes.filter(lambda x: x["license"] not in "copyright") #to avoid any "do not display this" license?
301
  all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
302
  num_samples = len(all_single_passes)
303
 
 
451
  gen_kwargs["repetition_penalty"] = repetition_penalty
452
  return gen_kwargs
453
 
454
+ def _grab_before_comments(func_node):
455
+ """
456
+ returns the comments that happen just before a function node
457
+ """
458
+ precomment = ""
459
+ last_comment_line = 0
460
+ for node in func_node.parent.children: #could you optimize where to iterated from? directon?
461
+ if node.start_point[0] != last_comment_line + 1:
462
+ precomment = ""
463
+ if node.type == "comment":
464
+ precomment += node.text.decode() + "\n"
465
+ last_comment_line = node.start_point[0]
466
+ elif node == func_node:
467
+ return precomment
468
+ return precomment
469
+
470
+ def _get_docstrings(func_node):
471
+ """
472
+ returns the docstring of a function node
473
+ """
474
+ docstring = ""
475
+ for node in func_node.child_by_field_name("body").children[1:]:
476
+ if node.type == "comment":
477
+ docstring += node.text.decode() + "\n"
478
+ else:
479
+ return docstring
480
+ return docstring
481
+
482
  def alter_body(old_code, func_id, funcs_list: list, temperature, max_new_tokens, top_p, repetition_penalty, pipeline=PIPE):
483
  """
484
  Replaces the body of a function with a generated one.
 
514
  print(f"{old_code[body_start_idx:body_end_idx]=}")
515
  model_context = identifier_str # base case
516
  # add any comments at the beginning of the function to the model_context
517
+ # second_child = func_node.child_by_field_name("body").children[1] #might error out?
518
+ docstring = _get_docstrings(func_node) #might be empty?
519
+ if docstring:
520
+ model_context = model_context + "\n{\n" + docstring + "\n"
521
+ model_context = _grab_before_comments(func_node) + model_context
522
+ print(f"{model_context=}")
523
  # generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
524
  generation = _run_generation(model_context, pipeline, generation_kwargs)
525
  for i in generation: