Vipitis commited on
Commit
829134c
1 Parent(s): 198a9e6

fix recursive call in get_root

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. utils/tree_utils.py +5 -4
app.py CHANGED
@@ -235,9 +235,9 @@ def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2,
235
  func_start_idx = line_chr2char(old_code, func_node.start_point[0], func_node.start_point[1])
236
  identifier_str = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode() #func_start_idx:body_start_idx?
237
  body_node = func_node.child_by_field_name("body")
238
- # body_start_idx, body_end_idx = node_str_idx(body_node) #can cause index error, needs better testing!
239
- body_start_idx = line_chr2char(old_code, body_node.start_point[0], body_node.start_point[1])
240
- body_end_idx = line_chr2char(old_code, body_node.end_point[0], body_node.end_point[1])
241
  print(f"{old_code[body_start_idx:body_end_idx]=}")
242
  model_context = identifier_str # base case
243
  # add any comments at the beginning of the function to the model_context
 
235
  func_start_idx = line_chr2char(old_code, func_node.start_point[0], func_node.start_point[1])
236
  identifier_str = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode() #func_start_idx:body_start_idx?
237
  body_node = func_node.child_by_field_name("body")
238
+ body_start_idx, body_end_idx = node_str_idx(body_node) #can cause index error, needs better testing!
239
+ # body_start_idx = line_chr2char(old_code, body_node.start_point[0], body_node.start_point[1])
240
+ # body_end_idx = line_chr2char(old_code, body_node.end_point[0], body_node.end_point[1])
241
  print(f"{old_code[body_start_idx:body_end_idx]=}")
242
  model_context = identifier_str # base case
243
  # add any comments at the beginning of the function to the model_context
utils/tree_utils.py CHANGED
@@ -18,11 +18,12 @@ def replace_function(old_func_node, new_func_node):
18
 
19
  def get_root(node):
20
  """
21
- returns the root node of a node
22
  """
23
- while node.parent != None:
24
- node = node.parent
25
- return node.parent #still need to call parent here
 
26
 
27
  def node_str_idx(node):
28
  """
 
18
 
19
  def get_root(node):
20
  """
21
+ returns the root node the tree of the given node (recursively)
22
  """
23
+ if node.parent is None:
24
+ return node
25
+ else:
26
+ return get_root(node.parent)
27
 
28
  def node_str_idx(node):
29
  """