File size: 4,466 Bytes
6e8c033
2d141af
 
 
6e8c033
 
 
2d141af
27e63ab
 
 
 
 
 
 
c6524f1
27e63ab
 
 
 
829134c
27e63ab
829134c
 
 
 
27e63ab
 
 
 
 
 
abed9bd
 
 
 
27e63ab
 
 
 
 
 
 
2d141af
3f8d823
2d141af
 
 
 
c6524f1
2d141af
 
 
 
 
3f8d823
2d141af
 
 
 
50c1955
 
 
 
 
 
 
 
 
 
2d141af
 
abed9bd
50c1955
 
 
 
 
 
abed9bd
50c1955
 
abed9bd
c6524f1
2d141af
3f8d823
2d141af
 
 
 
 
491ed03
2d141af
 
 
 
61d090a
 
2d141af
 
 
80cd783
 
491ed03
 
2d141af
50c1955
 
 
 
491ed03
50c1955
 
3f8d823
2d141af
abed9bd
2d141af
 
 
 
 
abed9bd
 
 
 
2d141af
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import tree_sitter_glsl as tsglsl
import tree_sitter
from tree_sitter import Language, Parser

# Language object wrapping the grammar handle returned by tree_sitter_glsl.
GLSL_LANGUAGE = Language(tsglsl.language())

# Single module-level parser, constructed with the GLSL language and reused
# by the parsing helpers in this module (give_tree, parse_functions).
parser = Parser(GLSL_LANGUAGE)

def replace_function(old_func_node, new_func_node):
    """
    returns the whole source code with the old function swapped for the new one

    old_func_node: node inside the tree whose source should be modified.
    new_func_node: node (possibly from another tree) whose text is inserted.

    The byte offsets of a node are relative to the complete source, so the
    splice is done on the root node's bytes instead of re-parsing the parent's
    text. That keeps the offsets valid even when the function is not a direct
    child of the root, and avoids an extra parse.
    """
    source = get_root(old_func_node).text
    old_func_start, old_func_end = node_str_idx(old_func_node)
    # splice on bytes (the offsets are byte offsets) and decode only once
    new_source = source[:old_func_start] + new_func_node.text + source[old_func_end:]
    return new_source.decode()

def get_root(node):
    """
    walks up the parent links of the given node and returns the topmost node
    (the one whose parent is None)
    """
    current = node
    while current.parent is not None:
        current = current.parent
    return current

def node_str_idx(node):
    """
    returns the (start, end) byte offsets of a node

    These are the node's start_byte and end_byte, i.e. offsets into the
    encoded source bytes, not character indices of the decoded string.
    The two only coincide when the preceding text is pure ASCII.
    """
    # no need to look at the root text: the node already carries its offsets
    return node.start_byte, node.end_byte

def give_tree(func_node):
    """
    return a freshly parsed tree of the whole source this function node is in

    The text is taken from the root of the node's tree rather than from its
    direct parent. This way the root node itself (which has no parent) is
    accepted, and a nested node still yields a tree of the complete source
    instead of only a fragment.
    """
    return parser.parse(get_root(func_node).text)

def parse_functions(in_code):
    """
    parses the given code string and returns the function nodes found at the
    top level, i.e. the direct children of the root node whose type is
    "function_definition", in source order.
    """
    parsed = parser.parse(bytes(in_code, encoding="utf-8"))
    functions = []
    for child in parsed.root_node.children:
        if child.type == "function_definition":
            functions.append(child)
    return functions


def get_docstrings(func_node):
    """
    collects the comment text that acts as the docstring of a function node

    Comments that are direct children of the function are appended as they are.
    Inside the body (compound_statement) the opening "{" and every leading
    comment are appended with their column indentation and a newline; the
    first other body node ends the collection immediately.
    """
    pieces = []
    for child in func_node.children:
        kind = child.type
        if kind == "comment":  # comment sitting next to the declarator
            pieces.append(child.text.decode())
            continue
        if kind != "compound_statement":
            continue
        # function body: only the leading "{" / comment nodes count
        for stmt in child.children:
            if stmt.type not in ("comment", "{"):
                return "".join(pieces)
            indent = " " * stmt.start_point[1]  # keep the original indentation
            pieces.append(indent + stmt.text.decode() + "\n")
    return "".join(pieces)

def full_func_head(func_node) -> str:
    """
    returns function head including docstrings before any real body code

    The head runs from the start of the function up to the end of the last
    leading "{" or comment node of the "body" field. If the body does not
    start with such a node, the head ends where the body starts.

    Iterating over the body's children (instead of stepping a cursor) always
    terminates, even when every child is a "{" or a comment, and the end
    offset is always defined.
    """
    body = func_node.child_by_field_name("body")
    head_end = body.start_byte  # fallback: nothing of the body belongs to the head
    for child in body.children:
        if child.type not in ("comment", "{"):
            break
        head_end = child.end_byte
    # end offsets are absolute, the node text starts at func_node.start_byte
    return func_node.text[: head_end - func_node.start_byte].decode()

def grab_before_comments(func_node):
    """
    returns the comments that happen just before a function node

    Returns a tuple (precomment, start_byte): the text of the run of comments
    sitting on the lines directly above the function (each followed by a
    newline) and the byte offset where that run starts. Without such a run
    the text is "" and the offset is the start of the function itself.

    A node continues the current run only if it starts on the line right
    after the line where the previous comment ENDED, so a block comment that
    spans several lines still counts as directly preceding the function.
    """
    precomment = ""
    last_comment_line = 0
    start_byte = func_node.start_byte
    for node in func_node.parent.children:
        # a gap (or any node not on the following line) breaks the run
        if node.start_point[0] != last_comment_line + 1:
            precomment = ""
        if node.type == "comment":
            if precomment == "":
                start_byte = node.start_byte  # first comment of a new run
            precomment += node.text.decode() + "\n"
            # use the end row: multi-line comments end below their start row
            last_comment_line = node.end_point[0]
        elif node == func_node:
            if precomment == "":
                start_byte = node.start_byte
            return precomment, start_byte
    return precomment, start_byte

def has_docstrings(func_node):
    """
    tells whether a function node is documented: either its collected
    docstring is more than just the opening "{", or there are comments
    directly before the function.
    """
    inner_doc = get_docstrings(func_node).strip()
    if inner_doc != "{":
        return True
    # only look at preceding comments when the body holds no docstring
    leading_comments = grab_before_comments(func_node)[0]
    return leading_comments != ""


def line_chr2char(text, line_idx, chr_idx):
    """
    converts a (line, column) position into a flat character index of text

    Every line before line_idx contributes its length plus one for the
    newline; chr_idx is added on top. Raises IndexError if text has fewer
    than line_idx lines.
    note: nodes already carry start_byte and end_byte, prefer those.
    """
    rows = text.split("\n")
    offset = 0
    for i in range(line_idx):
        try:
            row = rows[i]
        except IndexError as e:
            raise IndexError(f"{i=} of {line_idx=} does not exist in {text=}") from e
        offset += len(row) + 1  # +1 for the newline removed by split
    return offset + chr_idx