|
from __future__ import annotations |
|
|
|
import re |
|
import requests |
|
from dataclasses import dataclass |
|
|
|
import gradio as gr |
|
from tree_sitter import Tree, Node |
|
from tree_sitter_languages import get_parser |
|
|
|
def non_whitespace_len(s: str) -> int: |
|
return len(re.sub("\s", "", s)) |
|
|
|
def get_line_number(index: int, source_code: str | bytes) -> int:
    """Return the 0-based index of the line that contains offset ``index``.

    Offsets are measured in the units of ``source_code``: characters for
    ``str`` and bytes for ``bytes``. Lines are split with ``splitlines`` for
    that type, and each line's terminator counts as part of the line.

    When ``index`` is at or beyond the end of ``source_code``, the total
    number of lines is returned (one past the last line index). This makes
    the result usable as an exclusive end bound. An empty ``source_code``
    yields 0.
    """
    total_chars = 0
    # Pre-bind so empty input returns 0 instead of raising UnboundLocalError.
    line_number = 0
    for line_number, line in enumerate(source_code.splitlines(keepends=True), start=1):
        total_chars += len(line)
        if total_chars > index:
            return line_number - 1
    return line_number
|
|
|
@dataclass
class Span:
    """A half-open interval ``[start, end)``.

    Within this file the same type carries both byte offsets into encoded
    source and 0-based line indices; which one an instance holds depends on
    the code that built it.
    """

    start: int = 0
    end: int = 0

    def __post_init__(self):
        # Passing ``end=None`` explicitly yields an empty span at ``start``.
        if self.end is None:
            self.end = self.start

    def extract(self, s: str) -> str:
        """Return the part of ``s`` covered by this span."""
        lo, hi = self.start, self.end
        return s[lo:hi]

    def extract_lines(self, s: str) -> str:
        """Return lines ``start`` .. ``end - 1`` of ``s`` joined by newlines."""
        selected = s.splitlines()[self.start:self.end]
        return "\n".join(selected)

    def __add__(self, other: Span | int) -> Span:
        """Shift by an int, or join with another span.

        ``span + n`` moves both endpoints by ``n``. ``span + other`` returns a
        span from ``self.start`` to ``other.end``; the result is only a
        sensible union when ``other`` does not begin before ``self``.

        Raises:
            NotImplementedError: if ``other`` is neither an ``int`` nor a ``Span``.
        """
        if isinstance(other, int):
            return Span(self.start + other, self.end + other)
        if isinstance(other, Span):
            return Span(self.start, other.end)
        raise NotImplementedError()

    def __len__(self) -> int:
        """Distance from ``start`` to ``end``."""
        return self.end - self.start
|
|
|
def chunk_tree(
    tree: Tree,
    source_code: bytes,
    MAX_CHARS=512 * 3,
    coalesce=50,
) -> list[Span]:
    """Split a parsed file into line-based chunks along syntax-tree boundaries.

    Args:
        tree: Tree-sitter parse tree of ``source_code``.
        source_code: The exact bytes that were parsed. Node offsets are byte
            offsets into this buffer, and it is decoded as UTF-8 for the
            size checks.
        MAX_CHARS: Rough upper bound, in bytes, for a chunk built by packing
            consecutive sibling nodes. The gap before the next sibling is not
            counted when deciding whether it fits. A single node larger than
            this is split recursively into its children.
        coalesce: A chunk is emitted only once it contains a newline and more
            than this many non-whitespace characters. Smaller chunks are
            merged into the following one.

    Returns:
        Half-open ``Span``s of 0-based line indices in document order, with
        spans that collapse to zero lines removed.
    """

    def chunk_node(node: Node) -> list[Span]:
        # Greedily pack consecutive children into byte spans of at most
        # MAX_CHARS, recursing into any child that is too large on its own.
        chunks: list[Span] = []
        current_chunk: Span = Span(node.start_byte, node.start_byte)
        for child in node.children:
            child_len = child.end_byte - child.start_byte
            if child_len > MAX_CHARS:
                chunks.append(current_chunk)
                current_chunk = Span(child.end_byte, child.end_byte)
                chunks.extend(chunk_node(child))
            elif child_len + len(current_chunk) > MAX_CHARS:
                chunks.append(current_chunk)
                current_chunk = Span(child.start_byte, child.end_byte)
            else:
                current_chunk += Span(child.start_byte, child.end_byte)
        chunks.append(current_chunk)
        return chunks

    # 1. Recursively form byte-offset chunks. The list is never empty because
    #    chunk_node always appends its trailing chunk.
    chunks = chunk_node(tree.root_node)

    # 2. Fill the gaps so consecutive chunks are contiguous, then stretch the
    #    final chunk to the end of the tree. This sets the last chunk's *end*,
    #    outside the loop, so it never reads a loop variable that is unbound
    #    when there is only one chunk.
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    chunks[-1].end = tree.root_node.end_byte

    # 3. Merge chunks that are too small into the following one. Spans hold
    #    byte offsets, so slice the bytes first and decode only the slice.
    #    Indexing a decoded str by byte offsets drifts on non-ASCII input.
    def chunk_text(span: Span) -> str:
        return source_code[span.start:span.end].decode("utf-8", errors="replace")

    new_chunks: list[Span] = []
    current_chunk = Span(0, 0)
    for chunk in chunks:
        current_chunk += chunk
        text = chunk_text(current_chunk)
        if non_whitespace_len(text) > coalesce and "\n" in text:
            new_chunks.append(current_chunk)
            current_chunk = Span(chunk.end, chunk.end)
    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)

    # 4. Convert byte offsets to 0-based line indices. The end index is
    #    exclusive: a line belongs to the chunk holding its final part.
    line_chunks = [
        Span(
            get_line_number(chunk.start, source_code),
            get_line_number(chunk.end, source_code),
        )
        for chunk in new_chunks
    ]

    # 5. Drop chunks that collapse to zero lines.
    return [chunk for chunk in line_chunks if len(chunk) > 0]
|
|
|
# Extra CSS handed to gr.Blocks. The ``code_container`` class is attached to
# the input editor via ``elem_classes``; the rule is currently empty.
css = """

.code_container {

}

"""
|
|
|
def chunk_code(
    code: str,
    language: str,
    MAX_CHARS: int,
    coalesce: int
):
    """Chunk ``code`` and render the chunks as a single display string.

    Args:
        code: Source text to chunk.
        language: Language name passed to ``get_parser``.
        MAX_CHARS: Forwarded to ``chunk_tree``.
        coalesce: Forwarded to ``chunk_tree``.

    Returns:
        The chunks' text separated by a ``====================`` line, or the
        exception message if parsing or chunking fails. This is the UI
        boundary, so errors are shown to the user rather than raised.
    """
    try:
        source = code.encode("utf-8")
        parser = get_parser(language)
        tree = parser.parse(source)
        chunks = chunk_tree(tree, source, MAX_CHARS=MAX_CHARS, coalesce=coalesce)
        # chunk_tree maps byte offsets to line indices by splitting the
        # encoded *bytes*, so slice the same byte lines here. str.splitlines
        # also breaks on \f, \v, \x1c-\x1e, \x85, \u2028 and \u2029, which
        # would shift every line index after such a character. Byte lines
        # split only at ASCII \n / \r, so each piece decodes cleanly.
        lines = source.splitlines()
        pieces = [
            b"\n".join(lines[chunk.start:chunk.end]).decode("utf-8")
            for chunk in chunks
        ]
        return "\n\n====================\n\n".join(pieces)
    except Exception as e:
        return str(e)
|
|
|
with gr.Blocks(css=css) as demo:
    gr.Markdown("## Code Chunking Demo")
    # NOTE(review): this paragraph ends with "interactive notebooks at " but
    # no link follows; add the URL or drop the sentence.
    gr.Markdown("Start typing below and the chunked output will automatically show up. Checkout how this algorithm works at https://docs.sweep.dev/blogs/chunking-2m-files and https://docs.sweep.dev/blogs/chunking-improvements. We also have interactive notebooks at ")

    # Sample file shown on first load, pinned to a fixed commit.
    default_file = "https://raw.githubusercontent.com/sweepai/sweep/b267b613d4c706eaf959fe6789f11e9a856521d1/sweepai/handlers/on_check_suite.py"
    # Bound the request so a stalled network cannot hang startup, and start
    # with an empty editor instead of crashing if the download fails.
    try:
        response = requests.get(default_file, timeout=30)
        response.raise_for_status()
        default_code = response.text
    except requests.RequestException:
        default_code = ""

    with gr.Row():
        # NOTE(review): confirm gr.Code accepts every entry below as a
        # highlighting language in the installed Gradio version.
        language = gr.Dropdown(["python", "javascript", "go", "ruby", "java", "php", "c", "cpp", "rust", "haskell"], label="Language", value="python")
        max_chars = gr.Slider(1, 3000, 1500, label="Max Characters")
        coalesce = gr.Slider(0, 300, 100, label="Coalesce")
    with gr.Row():
        inp = gr.Code(placeholder="Enter the code here", label="Code to Chunk", language=language.value, lines=60, elem_classes="code_container", value=default_code)
        out = gr.Code(label="Chunked Code", language=language.value, lines=60, value=chunk_code(default_code, language.value, max_chars.value, coalesce.value))

    def update_language(inp, language, max_chars, coalesce):
        """Switch highlighting on both editors and re-chunk with the new parser.

        Event handlers receive the *values* of their input components, so
        ``inp`` is the editor's current text (a ``str``), not the ``gr.Code``
        component. It has no ``.value`` attribute.
        """
        return (
            gr.update(language=language),
            gr.update(language=language, value=chunk_code(inp, language, max_chars, coalesce)),
        )

    language.change(fn=update_language, inputs=[inp, language, max_chars, coalesce], outputs=[inp, out])
    max_chars.change(fn=chunk_code, inputs=[inp, language, max_chars, coalesce], outputs=out)
    coalesce.change(fn=chunk_code, inputs=[inp, language, max_chars, coalesce], outputs=out)
    inp.change(fn=chunk_code, inputs=[inp, language, max_chars, coalesce], outputs=out)

demo.launch()
|
|