mamiksik committed on
Commit
83b862c
1 Parent(s): 6af1ce6

Add t5predictor to app.py

Browse files
Files changed (2) hide show
  1. app.py +89 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import T5ForConditionalGeneration, RobertaTokenizer
6
+
7
+
8
# Checkpoint on the Hugging Face Hub and the pinned revision shared by the
# tokenizer and the model, so both are always loaded from the same snapshot.
_CHECKPOINT = "mamiksik/CommitPredictorT5PL"
_REVISION = "fb08d01"

tokenizer = RobertaTokenizer.from_pretrained(_CHECKPOINT, revision=_REVISION)
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT, revision=_REVISION)
10
+
11
+
12
def parse_files(accumulator: list[str], patch: str):
    """Convert a ``git diff`` patch into tagged lines and append them to *accumulator*.

    Every emitted line is prefixed with one of three tags:

    * ``<add>`` -- the line was added (``+`` in the patch),
    * ``<del>`` -- the line was removed (``-`` in the patch),
    * ``<ide>`` -- any other (context) line.

    File headers (``--- <old>`` / ``+++ <new>``) are turned into ``<path>``
    entries: a single ``<ide><path>`` entry when both names match, otherwise
    an ``<add><path>`` entry for the new name followed by a ``<del><path>``
    entry for the old one. ``diff``/``index`` lines and bare hunk headers
    (``@@ ... @@``) are dropped.

    Args:
        accumulator: List that receives the tagged lines; mutated in place.
        patch: Patch text as produced by ``git diff``.

    Returns:
        The same *accumulator* list, for convenience.
    """
    filename_before = None
    for line in patch.splitlines():
        if line.startswith(("index", "diff")):
            continue

        # Require the separating space: a removed "--" or an added "++" line
        # shows up as "---" / "+++" in the patch and is content, not a header.
        # Without this check the split below raised IndexError on such lines.
        if line.startswith("--- "):
            # Drop the first character of the path (the "a" of "a/file").
            filename_before = line.split(" ", 1)[1][1:]
            continue

        if line.startswith("+++ "):
            filename_after = line.split(" ", 1)[1][1:]

            if filename_before == filename_after:
                accumulator.append(f"<ide><path>{filename_before}")
            else:
                accumulator.append(f"<add><path>{filename_after}")
                # A "+++" header without a preceding "---" has no old name;
                # do not emit a bogus "<del><path>None" entry for it.
                if filename_before is not None:
                    accumulator.append(f"<del><path>{filename_before}")
            # The old name belongs to this file only.
            filename_before = None
            continue

        # Strip the "@@ -a,b +c,d @@" hunk marker, keeping any trailing context.
        line = re.sub(r"@@[^@]*@@", "", line)
        if not line:
            continue

        if line[0] == "+":
            line = f"<add>{line[1:]}"
        elif line[0] == "-":
            line = f"<del>{line[1:]}"
        else:
            line = f"<ide>{line}"

        accumulator.append(line)

    return accumulator
47
+
48
+
49
def predict(patch, max_length, min_length, num_beams, prediction_count):
    """Generate commit-message candidates for a ``git diff`` patch.

    Args:
        patch: Patch text as produced by ``git diff``.
        max_length: Maximum length passed to ``model.generate``.
        min_length: Minimum length passed to ``model.generate``.
        num_beams: Number of beams used for beam search.
        prediction_count: Number of candidate messages to return; capped at
            ``num_beams``.

    Returns:
        A 3-tuple of: the token count of the parsed patch (tokenized without
        truncation), the parsed patch text itself, and a dict mapping each
        generated message to a score of ``0``.
    """
    # The values come from UI sliders and may arrive as floats, while
    # generate() needs integer lengths and beam counts.
    max_length = int(max_length)
    min_length = int(min_length)
    num_beams = int(num_beams)
    # Beam search cannot return more sequences than it has beams; cap the
    # request instead of letting generate() raise.
    prediction_count = min(int(prediction_count), num_beams)

    input_text = "\n".join(parse_files([], patch))

    with torch.no_grad():
        # Untruncated count, so the caller can see how long the input really is.
        token_count = tokenizer(input_text, return_tensors="pt").input_ids.shape[1]

        input_ids = tokenizer(
            input_text,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).input_ids

        outputs = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            num_return_sequences=prediction_count,
        )

    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return token_count, input_text, {k: 0 for k in result}
74
+
75
+
76
# Gradio UI around predict(). Every slider feeds an integer-valued generation
# parameter, so the step is pinned to 1; without an explicit step Gradio
# derives one from the slider range, which can be fractional for small ranges.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Patch (as generated by git diff)"),
        gr.Slider(1, 128, value=20, step=1, label="Max message length"),
        gr.Slider(1, 128, value=5, step=1, label="Min message length"),
        gr.Slider(1, 10, value=7, step=1, label="Number of beams"),
        gr.Slider(1, 15, value=5, step=1, label="Number of predictions"),
    ],
    outputs=[
        gr.Textbox(label="Token count"),
        gr.Textbox(label="Parsed patch"),
        gr.Label(label="Predictions"),
    ],
)

if __name__ == "__main__":
    iface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio~=3.16.2
2
+ transformers~=4.25.1
3
+ torch~=1.13.1