Update syntax_match.py
Browse files- syntax_match.py +73 -73
syntax_match.py
CHANGED
@@ -1,73 +1,73 @@
|
|
1 |
-
# Copyright (c) Microsoft Corporation.
|
2 |
-
# Licensed under the MIT license.
|
3 |
-
|
4 |
-
from parser import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp
|
5 |
-
from parser import (remove_comments_and_docstrings,
|
6 |
-
tree_to_token_index,
|
7 |
-
index_to_code_token,
|
8 |
-
tree_to_variable_index)
|
9 |
-
from tree_sitter import Language, Parser
|
10 |
-
|
11 |
-
dfg_function={
|
12 |
-
'python':DFG_python,
|
13 |
-
'java':DFG_java,
|
14 |
-
'ruby':DFG_ruby,
|
15 |
-
'go':DFG_go,
|
16 |
-
'php':DFG_php,
|
17 |
-
'javascript':DFG_javascript,
|
18 |
-
'c_sharp':DFG_csharp,
|
19 |
-
}
|
20 |
-
|
21 |
-
def calc_syntax_match(references, candidate, lang):
|
22 |
-
return corpus_syntax_match([references], [candidate], lang)
|
23 |
-
|
24 |
-
def corpus_syntax_match(references, candidates, lang):
|
25 |
-
JAVA_LANGUAGE = Language('parser/my-languages.so', lang)
|
26 |
-
parser = Parser()
|
27 |
-
parser.set_language(JAVA_LANGUAGE)
|
28 |
-
match_count = 0
|
29 |
-
total_count = 0
|
30 |
-
|
31 |
-
for i in range(len(candidates)):
|
32 |
-
references_sample = references[i]
|
33 |
-
candidate = candidates[i]
|
34 |
-
for reference in references_sample:
|
35 |
-
try:
|
36 |
-
candidate=remove_comments_and_docstrings(candidate,'java')
|
37 |
-
except:
|
38 |
-
pass
|
39 |
-
try:
|
40 |
-
reference=remove_comments_and_docstrings(reference,'java')
|
41 |
-
except:
|
42 |
-
pass
|
43 |
-
|
44 |
-
candidate_tree = parser.parse(bytes(candidate,'utf8')).root_node
|
45 |
-
|
46 |
-
reference_tree = parser.parse(bytes(reference,'utf8')).root_node
|
47 |
-
|
48 |
-
def get_all_sub_trees(root_node):
|
49 |
-
node_stack = []
|
50 |
-
sub_tree_sexp_list = []
|
51 |
-
depth = 1
|
52 |
-
node_stack.append([root_node, depth])
|
53 |
-
while len(node_stack) != 0:
|
54 |
-
cur_node, cur_depth = node_stack.pop()
|
55 |
-
sub_tree_sexp_list.append([cur_node.sexp(), cur_depth])
|
56 |
-
for child_node in cur_node.children:
|
57 |
-
if len(child_node.children) != 0:
|
58 |
-
depth = cur_depth + 1
|
59 |
-
node_stack.append([child_node, depth])
|
60 |
-
return sub_tree_sexp_list
|
61 |
-
cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)]
|
62 |
-
ref_sexps = get_all_sub_trees(reference_tree)
|
63 |
-
|
64 |
-
# print(cand_sexps)
|
65 |
-
# print(ref_sexps)
|
66 |
-
|
67 |
-
for sub_tree, depth in ref_sexps:
|
68 |
-
if sub_tree in cand_sexps:
|
69 |
-
match_count += 1
|
70 |
-
total_count += len(ref_sexps)
|
71 |
-
|
72 |
-
score = match_count / total_count
|
73 |
-
return score
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
from .parser import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp
|
5 |
+
from .parser import (remove_comments_and_docstrings,
|
6 |
+
tree_to_token_index,
|
7 |
+
index_to_code_token,
|
8 |
+
tree_to_variable_index)
|
9 |
+
from tree_sitter import Language, Parser
|
10 |
+
|
11 |
+
dfg_function={
|
12 |
+
'python':DFG_python,
|
13 |
+
'java':DFG_java,
|
14 |
+
'ruby':DFG_ruby,
|
15 |
+
'go':DFG_go,
|
16 |
+
'php':DFG_php,
|
17 |
+
'javascript':DFG_javascript,
|
18 |
+
'c_sharp':DFG_csharp,
|
19 |
+
}
|
20 |
+
|
21 |
+
def calc_syntax_match(references, candidate, lang):
|
22 |
+
return corpus_syntax_match([references], [candidate], lang)
|
23 |
+
|
24 |
+
def corpus_syntax_match(references, candidates, lang):
|
25 |
+
JAVA_LANGUAGE = Language('parser/my-languages.so', lang)
|
26 |
+
parser = Parser()
|
27 |
+
parser.set_language(JAVA_LANGUAGE)
|
28 |
+
match_count = 0
|
29 |
+
total_count = 0
|
30 |
+
|
31 |
+
for i in range(len(candidates)):
|
32 |
+
references_sample = references[i]
|
33 |
+
candidate = candidates[i]
|
34 |
+
for reference in references_sample:
|
35 |
+
try:
|
36 |
+
candidate=remove_comments_and_docstrings(candidate,'java')
|
37 |
+
except:
|
38 |
+
pass
|
39 |
+
try:
|
40 |
+
reference=remove_comments_and_docstrings(reference,'java')
|
41 |
+
except:
|
42 |
+
pass
|
43 |
+
|
44 |
+
candidate_tree = parser.parse(bytes(candidate,'utf8')).root_node
|
45 |
+
|
46 |
+
reference_tree = parser.parse(bytes(reference,'utf8')).root_node
|
47 |
+
|
48 |
+
def get_all_sub_trees(root_node):
|
49 |
+
node_stack = []
|
50 |
+
sub_tree_sexp_list = []
|
51 |
+
depth = 1
|
52 |
+
node_stack.append([root_node, depth])
|
53 |
+
while len(node_stack) != 0:
|
54 |
+
cur_node, cur_depth = node_stack.pop()
|
55 |
+
sub_tree_sexp_list.append([cur_node.sexp(), cur_depth])
|
56 |
+
for child_node in cur_node.children:
|
57 |
+
if len(child_node.children) != 0:
|
58 |
+
depth = cur_depth + 1
|
59 |
+
node_stack.append([child_node, depth])
|
60 |
+
return sub_tree_sexp_list
|
61 |
+
cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)]
|
62 |
+
ref_sexps = get_all_sub_trees(reference_tree)
|
63 |
+
|
64 |
+
# print(cand_sexps)
|
65 |
+
# print(ref_sexps)
|
66 |
+
|
67 |
+
for sub_tree, depth in ref_sexps:
|
68 |
+
if sub_tree in cand_sexps:
|
69 |
+
match_count += 1
|
70 |
+
total_count += len(ref_sexps)
|
71 |
+
|
72 |
+
score = match_count / total_count
|
73 |
+
return score
|