Vipitis commited on
Commit
50c1955
1 Parent(s): 0a4dd43

improved docstring extraction

Browse files
Files changed (3) hide show
  1. utils/__init__.py +2 -2
  2. utils/generation.py +2 -6
  3. utils/tree_utils.py +27 -5
utils/__init__.py CHANGED
@@ -1,8 +1,8 @@
1
- from .tree_utils import (parse_functions, get_docstrings, grab_before_comments, line_chr2char, replace_function, get_root, node_str_idx, give_tree)
2
  from .html_utils import (make_iframe, make_script, construct_embed)
3
  from .generation import (combine_generation_kwargs, stream_generation, construct_model_context)
4
 
5
- tree_funcs = ["parse_functions", "get_docstrings", "grab_before_comments", "line_chr2char", "replace_function", "get_root", "node_str_idx", "give_tree"]
6
  html_funcs = ["make_iframe", "make_script", "construct_embed"]
7
  gen_funcs = ["combine_generation_kwargs", "stream_generation", "construct_model_context"]
8
 
 
1
+ from .tree_utils import (parse_functions, get_docstrings, grab_before_comments, line_chr2char, replace_function, get_root, node_str_idx, give_tree, full_func_head, has_docstrings)
2
  from .html_utils import (make_iframe, make_script, construct_embed)
3
  from .generation import (combine_generation_kwargs, stream_generation, construct_model_context)
4
 
5
+ tree_funcs = ["parse_functions", "get_docstrings", "grab_before_comments", "line_chr2char", "replace_function", "get_root", "node_str_idx", "give_tree", "full_func_head", "has_docstrings"]
6
  html_funcs = ["make_iframe", "make_script", "construct_embed"]
7
  gen_funcs = ["combine_generation_kwargs", "stream_generation", "construct_model_context"]
8
 
utils/generation.py CHANGED
@@ -1,6 +1,6 @@
1
  from transformers import TextIteratorStreamer
2
  from threading import Thread
3
- from .tree_utils import get_docstrings, grab_before_comments
4
 
5
  def combine_generation_kwargs(temperature=2.0, max_new_tokens=512, top_p=0.95, repetition_penalty=1.2):
6
  """
@@ -47,11 +47,7 @@ def construct_model_context(func_node, prompt="") -> str:
47
  """
48
  Constructs the model context from a function node.
49
  """
50
- model_context = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode() #func_start_idx:body_start_idx?
51
- docstring = get_docstrings(func_node) #might be empty?
52
- if docstring:
53
- model_context = model_context + "\n" + docstring
54
- model_context = grab_before_comments(func_node) + model_context #prepend comments
55
  if prompt != "":
56
  model_context = "//Title: " + prompt + "\n" + model_context #prepend user prompt/title
57
  model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context #prepend system prompt, language hint
 
1
  from transformers import TextIteratorStreamer
2
  from threading import Thread
3
+ from .tree_utils import full_func_head, grab_before_comments
4
 
5
  def combine_generation_kwargs(temperature=2.0, max_new_tokens=512, top_p=0.95, repetition_penalty=1.2):
6
  """
 
47
  """
48
  Constructs the model context from a function node.
49
  """
50
+ model_context = grab_before_comments(func_node) + full_func_head(func_node) # (identifier + docstrings)
 
 
 
 
51
  if prompt != "":
52
  model_context = "//Title: " + prompt + "\n" + model_context #prepend user prompt/title
53
  model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context #prepend system prompt, language hint
utils/tree_utils.py CHANGED
@@ -56,13 +56,28 @@ def get_docstrings(func_node):
56
  returns the docstring of a function node
57
  """
58
  docstring = ""
59
- for node in func_node.child_by_field_name("body").children:
60
- if node.type == "comment" or node.type == "{":
61
- docstring += node.text.decode() + "\n"
62
- else:
63
- return docstring
 
 
 
 
 
64
  return docstring
65
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def grab_before_comments(func_node):
68
  """
@@ -80,6 +95,13 @@ def grab_before_comments(func_node):
80
  return precomment
81
  return precomment
82
 
 
 
 
 
 
 
 
83
  def line_chr2char(text, line_idx, chr_idx):
84
  """
85
  returns the character index at the given line and character index.
 
56
  returns the docstring of a function node
57
  """
58
  docstring = ""
59
+ for node in func_node.children:
60
+ if node.type == "comment": #comment in like the declarator
61
+ docstring += node.text.decode()
62
+ elif node.type == "compound_statement": #body below here
63
+ for body_node in node.children:
64
+ if body_node.type == "comment" or body_node.type == "{":
65
+ docstring += " " * body_node.start_point[1] #add in indentation
66
+ docstring += body_node.text.decode() + "\n"
67
+ else:
68
+ return docstring
69
  return docstring
70
 
71
+ def full_func_head(func_node):
72
+ """
73
+ returns function head including docstrings before any real body code
74
+ """
75
+ cursor = func_node.child_by_field_name("body").walk()
76
+ cursor.goto_first_child()
77
+ while cursor.node.type == "comment" or cursor.node.type == "{":
78
+ cursor.goto_next_sibling()
79
+ end = cursor.node.start_point
80
+ return "\n".join(func_node.text.decode().split("\n")[:(end[0]-func_node.start_point[0])])
81
 
82
  def grab_before_comments(func_node):
83
  """
 
95
  return precomment
96
  return precomment
97
 
98
+ def has_docstrings(func_node):
99
+ """
100
+ returns whether a function node has a docstring
101
+ """
102
+ return get_docstrings(func_node).strip() != "{" or grab_before_comments(func_node) != ""
103
+
104
+
105
  def line_chr2char(text, line_idx, chr_idx):
106
  """
107
  returns the character index at the given line and character index.