File size: 4,286 Bytes
2d141af
 
 
0907806
2d141af
 
 
 
27e63ab
 
 
 
 
 
 
 
 
 
 
 
829134c
27e63ab
829134c
 
 
 
27e63ab
 
 
 
 
 
abed9bd
 
 
 
27e63ab
 
 
 
 
 
 
2d141af
3f8d823
2d141af
 
 
 
 
 
 
 
 
 
3f8d823
2d141af
 
 
 
50c1955
 
 
 
 
 
 
 
 
 
2d141af
 
abed9bd
50c1955
 
 
 
 
 
abed9bd
50c1955
 
abed9bd
 
2d141af
3f8d823
2d141af
 
 
 
 
 
 
 
 
 
 
 
 
abed9bd
2d141af
50c1955
 
 
 
 
 
 
3f8d823
2d141af
abed9bd
2d141af
 
 
 
 
abed9bd
 
 
 
2d141af
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import tree_sitter
from tree_sitter import Language, Parser

# Compile the grammar sources in ./tree-sitter-glsl into the shared library
# ./build/my-languages.so and load the 'glsl' language from it. This runs at
# import time, so importing this module builds/loads the grammar.
# NOTE(review): Language.build_library and Parser.set_language are the older
# py-tree-sitter API (removed/deprecated in later releases) — confirm the
# pinned tree_sitter version.
Language.build_library("./build/my-languages.so", ['./tree-sitter-glsl'])
GLSL_LANGUAGE = Language('./build/my-languages.so', 'glsl')
# Module-wide parser used by give_tree and parse_functions below.
parser = Parser()
parser.set_language(GLSL_LANGUAGE)

def replace_function(old_func_node, new_func_node):
    """
    Return the text of ``old_func_node``'s parent with the old function's text
    replaced by ``new_func_node``'s text.

    :param old_func_node: tree-sitter node to replace; must have a parent.
    :param new_func_node: node whose ``text`` is spliced in place of the old one.
    :returns: ``bytes`` — the parent's text with the substitution applied.
    """
    container = old_func_node.parent
    source = container.text
    # start_byte/end_byte are offsets into the whole parsed source, but
    # `source` begins at container.start_byte (a root node can start after
    # leading whitespace), so make the offsets relative before slicing.
    base = container.start_byte
    start = old_func_node.start_byte - base
    end = old_func_node.end_byte - base
    return source[:start] + new_func_node.text + source[end:]

def get_root(node):
    """
    Follow ``parent`` links upward from ``node`` and return the topmost node,
    i.e. the first one whose ``parent`` is ``None`` (``node`` itself if it has
    no parent).
    """
    current = node
    while current.parent is not None:
        current = current.parent
    return current

def node_str_idx(node):
    """
    Return the ``(start, end)`` offsets of ``node`` in the parsed source.

    These are tree-sitter *byte* offsets (``start_byte`` / ``end_byte``), not
    character indices. They differ from ``str`` indices whenever multi-byte
    characters precede the node, so use them to slice the source ``bytes``,
    not a decoded string.

    :param node: tree-sitter node.
    :returns: tuple ``(start_byte, end_byte)``.
    """
    return node.start_byte, node.end_byte

def give_tree(func_node):
    """
    Re-parse the text of ``func_node``'s parent with the module parser and
    return the resulting new tree.

    Note: this is a fresh tree, not the tree ``func_node`` came from. Its
    nodes are distinct objects, and its byte offsets start at 0 for the
    parent's first byte.
    """
    # TODO: find a way to reach the original tree instead of re-parsing.
    container_text = func_node.parent.text
    fresh_tree = parser.parse(container_text)
    return fresh_tree

def parse_functions(in_code):
    """
    Parse GLSL source and return its top-level function definition nodes.

    Only direct children of the root node with type ``function_definition``
    are returned, in source order. Comments are not attached to them; see
    ``get_docstrings`` and ``grab_before_comments`` for that.

    :param in_code: source as ``str`` (encoded to UTF-8 before parsing) or as
        already-encoded bytes-like data (``bytes``/``bytearray``/``memoryview``),
        e.g. the output of ``replace_function``.
    :returns: list of tree-sitter ``function_definition`` nodes.
    """
    if isinstance(in_code, (bytes, bytearray, memoryview)):
        source = bytes(in_code)
    else:
        source = bytes(in_code, "utf8")
    tree = parser.parse(source)
    return [node for node in tree.root_node.children if node.type == "function_definition"]


def get_docstrings(func_node):
    """
    Collect the comment text attached to a function node.

    - Comment nodes that are direct children of ``func_node`` (e.g. inside the
      declarator) are appended verbatim, with no separator.
    - Inside the body (``compound_statement``), the leading ``{`` and any
      comments right after it are appended, each prefixed with spaces up to
      its start column and terminated by a newline. Collection stops (and the
      result is returned) at the first other body child.
    """
    pieces = []
    for child in func_node.children:
        if child.type == "comment":
            pieces.append(child.text.decode())
            continue
        if child.type != "compound_statement":
            continue
        for stmt in child.children:
            if stmt.type not in ("comment", "{"):
                # first real statement (or closing brace): header is over
                return "".join(pieces)
            pieces.append(" " * stmt.start_point[1])  # preserve indentation
            pieces.append(stmt.text.decode() + "\n")
    return "".join(pieces)

def full_func_head(func_node) -> str:
    """
    Return the function's head: its source text from the start of the
    signature through the end of the last comment that directly follows the
    body's opening ``{``. If no comment follows it, the head ends at the ``{``.

    :param func_node: tree-sitter function node with a ``body`` field.
    :returns: the head as a decoded ``str``.
    :raises ValueError: if the body does not begin with ``{`` or a comment.
    """
    body = func_node.child_by_field_name("body")
    head_end = None
    # The leading "{" and comments form the head; stop at the first real
    # statement or the closing "}". Iterating the children (instead of a
    # cursor loop that ignored goto_next_sibling()'s result) cannot spin
    # forever when the body's last child is itself a comment or "{".
    for child in body.children:
        if child.type not in ("comment", "{"):
            break
        head_end = child.end_byte
    if head_end is None:
        raise ValueError("function body does not start with '{' or a comment")
    # end_byte is a byte offset: slice the raw bytes before decoding so that
    # multi-byte UTF-8 characters do not shift the cut point.
    return func_node.text[: head_end - func_node.start_byte].decode()

def grab_before_comments(func_node):
    """
    Return the run of comments immediately preceding ``func_node``.

    The siblings of ``func_node`` are walked in order, keeping a run of
    consecutive comment nodes. The run is discarded when a non-comment
    sibling intervenes, or when more than one line separates the end of a
    comment from the start of the next comment or of the function. Each
    comment's text is terminated with a newline.

    :param func_node: tree-sitter node with a parent (e.g. a top-level
        function definition).
    :returns: the concatenated comment text, or ``""`` if there is none.
    """
    parts = []
    prev_end_line = None  # last line of the most recent comment in the run
    for node in func_node.parent.children:
        # Use the previous comment's *end* line so multi-line /* */ comments
        # stay attached. A same-line or next-line start keeps the run going.
        detached = prev_end_line is not None and node.start_point[0] - prev_end_line > 1
        if node == func_node:
            return "" if detached else "".join(parts)
        if node.type == "comment":
            if detached:
                parts = []
            parts.append(node.text.decode() + "\n")
            prev_end_line = node.end_point[0]
        else:
            # Any other sibling separates earlier comments from the function.
            parts = []
            prev_end_line = None
    return "".join(parts)

def has_docstrings(func_node):
    """
    Report whether ``func_node`` carries any documentation comments.

    True when ``get_docstrings`` yields anything other than a lone ``{``
    (after stripping whitespace); otherwise True only if
    ``grab_before_comments`` finds comments directly above the function.
    """
    inner = get_docstrings(func_node).strip()
    if inner != "{":
        return True
    return grab_before_comments(func_node) != ""


def line_chr2char(text, line_idx, chr_idx):
    """
    Deprecated: prefer a node's ``start_byte`` / ``end_byte`` instead.

    Convert a (line, column) position into a flat index into ``text``.
    Lines are split on ``"\\n"`` and each newline counts as one character.

    :raises IndexError: (chained from the underlying lookup) if ``line_idx``
        exceeds the number of lines in ``text``.
    """
    rows = text.split("\n")
    total = 0
    for i in range(line_idx):
        try:
            row_len = len(rows[i])
        except IndexError as err:
            raise IndexError(f"{i=} of {line_idx=} does not exist in {text=}") from err
        total += row_len + 1  # +1 for the newline removed by split
    return total + chr_idx