File size: 5,920 Bytes
d596fb5
 
 
 
 
 
f749736
d596fb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f749736
d596fb5
 
 
 
f749736
d596fb5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from __future__ import annotations

import re
import requests
from dataclasses import dataclass

import gradio as gr
from tree_sitter import Tree, Node
from tree_sitter_languages import get_parser

def non_whitespace_len(s: str) -> int: # new len function
    return len(re.sub("\s", "", s))

def get_line_number(index: int, source_code: str | bytes) -> int:
    """Return the 0-based number of the line that contains offset ``index``.

    ``index`` is measured in the same unit as ``len(source_code)``:
    characters for ``str`` input, bytes for ``bytes`` input. Line lengths
    include their line terminators.

    If ``index`` is at or past the end of ``source_code``, the total number
    of lines is returned (``0`` for empty input).
    """
    # Pre-bind so the fall-through return is defined when there are no lines.
    line_number = 0
    total_chars = 0
    for line_number, line in enumerate(source_code.splitlines(keepends=True), start=1):
        total_chars += len(line)
        if total_chars > index:
            return line_number - 1
    return line_number

@dataclass
class Span:
    """A ``[start, end)`` range used to slice a string by index or by line."""

    start: int = 0
    end: int = 0

    def __post_init__(self):
        # An explicit end=None collapses the span to the empty range at start.
        if self.end is None:
            self.end = self.start

    def extract(self, s: str) -> str:
        """Return the slice ``s[start:end]``."""
        return s[self.start:self.end]

    def extract_lines(self, s: str) -> str:
        """Return lines ``start`` up to (not including) ``end`` of ``s``, newline-joined."""
        selected = s.splitlines()[self.start:self.end]
        return "\n".join(selected)

    def __add__(self, other: Span | int) -> Span:
        """Concatenate with another span, or shift both ends by an integer.

        ``Span(1, 2) + Span(2, 4)`` gives ``Span(1, 4)``. Nothing checks that
        the two spans are adjacent: ``Span(a, b) + Span(c, d)`` is always
        ``Span(a, d)``. Adding an ``int`` offsets ``start`` and ``end`` by it.

        Raises:
            NotImplementedError: if ``other`` is neither a ``Span`` nor an ``int``.
        """
        if isinstance(other, Span):
            return Span(self.start, other.end)
        if isinstance(other, int):
            return Span(self.start + other, self.end + other)
        raise NotImplementedError()

    def __len__(self) -> int:
        """Return ``end - start``."""
        return self.end - self.start

def chunk_tree(
    tree: Tree,
    source_code: bytes,
    MAX_CHARS=512 * 3,
    coalesce=50,
) -> list[Span]:
    """Split a parsed source file into line-based chunks.

    Args:
        tree: Syntax tree whose nodes expose ``start_byte``, ``end_byte`` and
            ``children``.
        source_code: The bytes the tree's byte offsets refer to.
        MAX_CHARS: Upper bound on a chunk's size, compared against byte
            lengths (``end_byte - start_byte``). A node larger than this is
            split by recursing into its children.
        coalesce: A chunk is emitted only once it holds more than this many
            non-whitespace characters and contains a newline; until then it
            keeps absorbing the chunks that follow it.

    Returns:
        Non-empty ``Span`` objects whose ``start``/``end`` are line numbers
        as computed by ``get_line_number``.
    """

    # 1. Recursively form chunks based on the last post (https://docs.sweep.dev/blogs/chunking-2m-files)
    def chunk_node(node: Node) -> list[Span]:
        chunks: list[Span] = []
        current_chunk: Span = Span(node.start_byte, node.start_byte)
        for child in node.children:
            child_size = child.end_byte - child.start_byte
            if child_size > MAX_CHARS:
                # Too big to fit in any chunk: flush what we have and split
                # the child itself.
                chunks.append(current_chunk)
                current_chunk = Span(child.end_byte, child.end_byte)
                chunks.extend(chunk_node(child))
            elif child_size + len(current_chunk) > MAX_CHARS:
                # Adding the child would overflow: start a new chunk with it.
                chunks.append(current_chunk)
                current_chunk = Span(child.start_byte, child.end_byte)
            else:
                current_chunk += Span(child.start_byte, child.end_byte)
        chunks.append(current_chunk)
        return chunks

    # chunk_node always returns at least one span, so chunks[-1] is safe.
    chunks = chunk_node(tree.root_node)

    # 2. Filling in the gaps: each chunk runs up to the start of the next one,
    # and the final chunk runs to the end of the tree. Indexing chunks[-1]
    # (rather than reusing the loop variable) also covers the single-chunk
    # case, where the loop body never runs.
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    chunks[-1].end = tree.root_node.end_byte

    # 3. Combining small chunks with bigger ones
    new_chunks = []
    current_chunk = Span(0, 0)
    for chunk in chunks:
        current_chunk += chunk
        # Spans hold byte offsets, so slice the bytes and decode only that
        # slice; slicing the decoded text would misalign on non-ASCII input.
        text = current_chunk.extract(source_code).decode("utf-8", errors="replace")
        if non_whitespace_len(text) > coalesce and "\n" in text:
            new_chunks.append(current_chunk)
            current_chunk = Span(chunk.end, chunk.end)
    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)

    # 4. Changing line numbers
    line_chunks = [
        Span(
            get_line_number(chunk.start, source_code),
            get_line_number(chunk.end, source_code)
        )
        for chunk in new_chunks
    ]

    # 5. Eliminating empty chunks
    return [chunk for chunk in line_chunks if len(chunk) > 0]

# Custom stylesheet passed to gr.Blocks(css=...). The .code_container class is
# attached to the input editor via elem_classes; its rule is currently empty.
css = """
.code_container {
}
"""

def chunk_code(
    code: str,
    language: str,
    MAX_CHARS: int,
    coalesce: int
):
    """Chunk ``code`` and return the chunks as one separator-delimited string.

    The source is parsed with the parser returned by ``get_parser(language)``,
    split into line spans by ``chunk_tree``, and each span is rendered back
    to text. Any exception raised along the way is returned as its message
    string instead of propagating.
    """
    separator = "\n\n====================\n\n"
    try:
        parser = get_parser(language)
        encoded = code.encode("utf-8")
        spans = chunk_tree(parser.parse(encoded), encoded, MAX_CHARS=MAX_CHARS, coalesce=coalesce)
        return separator.join(span.extract_lines(code) for span in spans)
    except Exception as error:
        return str(error)

# Build the demo UI: a code editor on the left, its chunked form on the right,
# re-chunked whenever the code, language, or either slider changes.
with gr.Blocks(css=css) as demo:
    gr.Markdown("Start typing below and the chunked output will automatically show up.")

    default_file = "https://raw.githubusercontent.com/sweepai/sweep/b267b613d4c706eaf959fe6789f11e9a856521d1/sweepai/handlers/on_check_suite.py"
    # Best-effort fetch of the sample file. Without a timeout the request can
    # block startup indefinitely, and without a status check an HTTP error
    # body would be shown as the sample code. On failure, start with an empty
    # editor instead.
    try:
        response = requests.get(default_file, timeout=10)
        response.raise_for_status()
        default_code = response.text
    except requests.RequestException:
        default_code = ""

    with gr.Row():
        language = gr.Dropdown(["python", "javascript", "go", "ruby", "java", "php", "c", "cpp", "rust", "haskell"], label="Language", value="python")
        max_chars = gr.Slider(100, 3000, 1500, label="Max Characters")
        coalesce = gr.Slider(0, 300, 100, label="Coalesce")
    with gr.Row():
        inp = gr.Code(placeholder="Enter the code here", label="Code to Chunk", language=language.value, lines=60, elem_classes="code_container", value=default_code)
        out = gr.Code(label="Chunked Code", language=language.value, lines=60, value=chunk_code(default_code, language.value, max_chars.value, coalesce.value))

    def update_language(inp, language, max_chars, coalesce):
        """Switch both editors to ``language`` and re-chunk the current code."""
        # Event handlers receive component *values*, so `inp` is already the
        # code string here (not the gr.Code component defined above).
        return (
            gr.update(language=language),
            gr.update(language=language, value=chunk_code(inp, language, max_chars, coalesce))
        )

    language.change(fn=update_language, inputs=[inp, language, max_chars, coalesce], outputs=[inp, out])
    max_chars.change(fn=chunk_code, inputs=[inp, language, max_chars, coalesce], outputs=out)
    coalesce.change(fn=chunk_code, inputs=[inp, language, max_chars, coalesce], outputs=out)
    inp.change(fn=chunk_code, inputs=[inp, language, max_chars, coalesce], outputs=out)

demo.launch()