File size: 4,160 Bytes
ea87393
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import ast
import io
import tokenize
from pathlib import Path

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch import nn


# Checkpoint whose tokenizer is used to encode classifier inputs.
model_id = "answerdotai/ModernBERT-base"
# Checkpoint loaded as the sequence-classification model used by `predict`.
path = "DanGalt/modernbert-code-comrel-synthetic"
# Both objects are loaded once, at import time, and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(path)
# Literal separator placed between function definition, code context and
# comment when building the model input string (see `prepare_input`).
sep = "[SEP]"


def prepare_input(example):
    """Tokenize one parsed comment entry for the classifier.

    The entry's ``"function_definition"``, ``"code"`` and ``"comment"``
    strings are joined, in that order, with the module-level ``sep`` marker
    and passed to the module-level ``tokenizer``.

    Args:
        example: Mapping with the three string fields named above.

    Returns:
        The tokenizer output as PyTorch tensors, truncated to at most
        1024 tokens.
    """
    fields = (
        example["function_definition"],
        example["code"],
        example["comment"],
    )
    joined = sep.join(fields)
    return tokenizer(joined, truncation=True, max_length=1024, return_tensors="pt")


def _signature_text(lines, first, last, col):
    """Return the header text of a function that starts at ``lines[first][col:]``.

    Heuristic: the header is assumed to end on the first line in
    ``first..last`` that contains ``"):"`` or ``"->"``. When no such line
    exists, only the first line is returned.
    """
    header = [lines[first][col:]]
    for cur in range(first, last + 1):
        if "):" in lines[cur] or "->" in lines[cur]:
            header += lines[first + 1:cur + 1]
            break
    return '\n'.join(header).strip()


def _function_spans(tree, lines):
    """Collect ``(first_line, last_line, signature)`` for every function in *tree*.

    Line numbers are 0-based indices into *lines*. Nested functions, methods
    and ``async def`` functions are included.
    """
    spans = []
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        first = node.lineno - 1
        ast_last = node.end_lineno - 1
        col = node.col_offset
        # The AST span ends at the last statement of the body. Absorb the
        # blank lines and comment-only lines that follow it while they are
        # indented deeper than the ``def`` itself, so trailing comments in
        # the body still belong to the function.
        last = ast_last
        while last + 1 < len(lines):
            following = lines[last + 1]
            stripped = following.lstrip()
            indent = len(following) - len(stripped)
            if stripped and not (stripped.startswith('#') and indent > col):
                break
            last += 1
        spans.append((first, last, _signature_text(lines, first, ast_last, col)))
    return spans


def parse_text(text):
    """Extract every comment in *text* together with its surrounding context.

    Comments are located with the ``tokenize`` module, so a ``#`` inside a
    string literal is not mistaken for a comment. Each comment is paired with
    the signature of the innermost function whose span contains it.

    Args:
        text: Python source code. It must be parseable by ``ast.parse``.

    Returns:
        A list of dicts, one per comment in source order, with the keys:
        ``"function_definition"`` (signature text of the enclosing function,
        or the string ``"None"`` when the comment is outside any function),
        ``"code"`` (up to four lines before the comment line, the comment
        line itself and up to three lines after it, joined by newlines),
        ``"comment"`` (the comment text starting at ``#``) and ``"lineno"``
        (0-based line index of the comment).

    Raises:
        SyntaxError: If *text* is not valid Python.
    """
    lines = text.split('\n')
    # Parse first so invalid input fails with SyntaxError before tokenizing.
    spans = _function_spans(ast.parse(text), lines)

    inputs = []
    for tok in tokenize.generate_tokens(io.StringIO(text).readline):
        if tok.type != tokenize.COMMENT:
            continue
        lineno = tok.start[0] - 1

        # Innermost enclosing function = the containing span that starts last.
        enclosing = [span for span in spans if span[0] <= lineno <= span[1]]
        fdef = max(enclosing, key=lambda span: span[0])[2] if enclosing else "None"

        # Clamp the window start: a negative start index would wrap around
        # to the end of the list and yield a wrong (often empty) context.
        code = '\n'.join(lines[max(0, lineno - 4):lineno + 4])

        inputs.append({
            "function_definition": fdef,
            "code": code,
            "comment": tok.string,
            "lineno": lineno
        })
    return inputs


def predict(inp, model=model):
    """Run the classifier on one tokenized input and return a probability.

    Args:
        inp: Keyword arguments for the model's forward call, e.g. the
            output of ``prepare_input``.
        model: Model to evaluate; defaults to the module-level ``model``.

    Returns:
        The softmax probability, as a Python float, at class index 1 for
        the first item in the batch.
    """
    # Inference only: gradients are not tracked.
    with torch.no_grad():
        logits = model(**inp).logits
    probabilities = nn.functional.softmax(logits, dim=-1)
    return probabilities[0, 1].item()


def parse_and_predict(text, thrd=0.0):
    """Score every comment found in *text*.

    Args:
        text: Python source code to analyse.
        thrd: Decision threshold. When it is greater than 0, each score is
            replaced by the boolean ``thrd > score``.

    Returns:
        A list of ``(lineno, value)`` tuples in comment order, where
        ``lineno`` comes from ``parse_text`` and ``value`` is either the raw
        score from ``predict`` (``thrd <= 0``) or the boolean described
        above.
    """
    entries = parse_text(text)
    # Score everything first, then pair each score with its entry.
    scores = [predict(prepare_input(entry)) for entry in entries]

    if thrd > 0:
        return [
            (entry["lineno"], thrd > score)
            for entry, score in zip(entries, scores)
        ]
    return [(entry["lineno"], score) for entry, score in zip(entries, scores)]


def parse_and_predict_file(path, thrd=0.0):
    """Read a Python source file and score every comment in it.

    Args:
        path: Path of the file to read.
        thrd: Threshold forwarded unchanged to ``parse_and_predict``.

    Returns:
        Whatever ``parse_and_predict`` returns for the file's contents.
    """
    # read_text opens and closes the file itself, so no handle is leaked.
    # Decode as UTF-8, the default encoding of Python source files, rather
    # than relying on the platform's locale encoding.
    text = Path(path).read_text(encoding="utf-8")
    return parse_and_predict(text, thrd)


def parse_and_predict_pretty_out(text, thrd=0.0):
    """Produce a human-readable report of the comment predictions for *text*.

    Args:
        text: Python source code to analyse.
        thrd: Threshold forwarded to ``parse_and_predict``. When greater
            than 0, only comments flagged by the threshold are reported;
            otherwise every comment is listed with its probability.

    Returns:
        One message per reported comment, joined by newlines. The string is
        empty when there is nothing to report.

    NOTE(review): the line numbers in the messages are the 0-based indices
    returned by ``parse_and_predict`` -- confirm whether 1-based numbers
    are intended for display.
    """
    predictions = parse_and_predict(text, thrd=thrd)
    source_lines = text.split('\n')

    if thrd > 0:
        # Threshold mode: each value is a boolean "warn about this comment".
        messages = [
            f"The comment on line {lineno} is incorrect: '{source_lines[lineno]}'."
            for lineno, flagged in predictions
            if flagged
        ]
    else:
        # Raw mode: each value is the predicted probability.
        messages = [
            f"The comment on line {lineno} is estimated to be correct with probability {prob:.2f}: '{source_lines[lineno]}'."
            for lineno, prob in predictions
        ]
    return '\n'.join(messages)


# Sample snippet offered as a ready-made example in the UI. The comments on
# `c` (says "sum", code subtracts) and `e` (says "c and d", code uses c and b)
# do not match their code -- presumably on purpose, to demonstrate the
# classifier.
example_text = """a = 3
b = 2
# The code below does some calculations based on a predefined rule that is very important
c = a - b  # Calculate and store the sum of a and b in c
d = a + b  # Calculate and store the sum of a and b in d
e = c * b  # Calculate and store the product of c and d in e
print(f"Wow, maths: {[a, b, c, d, e]}")"""

# Web UI: a source-code textbox plus a threshold slider (default 0.8) feed
# `parse_and_predict_pretty_out`; the report is shown in a second textbox.
# The two examples run the sample snippet with thresholds 0.0 and 0.53.
gradio_app = gr.Interface(
    fn=parse_and_predict_pretty_out,
    inputs=[
        gr.Textbox(label="Input", lines=7),
        gr.Slider(value=0.8, minimum=0.0, maximum=1.0, step=0.05)],
    outputs=[gr.Textbox(label="Predictions", lines=7)],
    examples=[[example_text, 0.0], [example_text, 0.53]],
    title="Comment \"Correctness\" Classifier",
    description='Calculates probabilities for each comment in text to be "correct"/"relevant". If the threshold is 0, outputs raw predictions. Otherwise, will report only "incorrect" comments.'
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    gradio_app.launch()