D2F-eval / sanitize.py
Bailan-Alex's picture
Upload folder using huggingface_hub
2f3e169 verified
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# Modified from Dream repos: https://github.com/HKUNLP/Dream
"""Post-processing LLM-generated Python code implemented using tree-sitter."""
import os
import sys
import pathlib
ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.extend([os.path.dirname(ROOT), os.path.dirname(os.path.dirname(ROOT))])
import ast
import traceback
from typing import Dict, List, Optional, Set, Tuple
def refine_text(text: str) -> str:
text = text.replace("\t", " ")
text = text.replace("\r\n", "\n").replace("\r", "\n")
return text.strip() + "\n"
def syntax_check(code, verbose = False):
try:
ast.parse(code)
return True
except (SyntaxError, MemoryError):
if verbose:
traceback.print_exc()
return False
def extract_longest_valid_code(text: str) -> str:
lines = text.splitlines()
if len(lines) > 100:
lines = lines[:100]
max_valid_lines = 0
max_valid_snippet = ""
for i in range(len(lines)):
for j in range(i, len(lines)):
current_snippet = "\n".join(lines[i:j+1])
if syntax_check(current_snippet):
valid_line_count = sum(1 for line in lines[i:j+1] if line.strip())
if valid_line_count > max_valid_lines:
max_valid_lines = valid_line_count
max_valid_snippet = current_snippet
return max_valid_snippet
def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]:
name2deps = {}
for name, node in nodes:
deps = set()
stack = [node]
while stack:
current = stack.pop()
for child in ast.iter_child_nodes(current):
if isinstance(child, ast.Name):
deps.add(child.id)
elif isinstance(child, ast.Attribute):
deps.add(child.attr)
else:
stack.append(child)
name2deps[name] = deps
return name2deps
def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
visited = set()
to_visit = [entrypoint]
while to_visit:
current = to_visit.pop(0)
if current not in visited:
visited.add(current)
to_visit.extend(call_graph.get(current, set()) - visited)
return visited
def get_definition_name(node: ast.AST) -> Optional[str]:
if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
return node.name
elif isinstance(node, ast.Assign):
targets = node.targets
if targets and isinstance(targets[0], ast.Name):
return targets[0].id
return None
def has_return_statement(node: ast.AST) -> bool:
return any(isinstance(n, ast.Return) for n in ast.walk(node))
def sanitize(text: str, entrypoint: Optional[str] = None) -> str:
text = refine_text(text)
# text = python_extract(text)
code = extract_longest_valid_code(text)
tree = ast.parse(code)
definitions = {}
imports = []
for node in tree.body:
if isinstance(node, (ast.Import, ast.ImportFrom)):
imports.append(node)
elif isinstance(node, ast.ClassDef):
name = node.name
definitions[name] = ('class', node)
elif isinstance(node, ast.FunctionDef):
name = node.name
if has_return_statement(node):
definitions[name] = ('function', node)
elif isinstance(node, ast.Assign):
name = get_definition_name(node)
if name:
definitions[name] = ('variable', node)
if entrypoint:
name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()])
reachable = get_function_dependency(entrypoint, name2deps)
sanitized_output = []
for node in imports:
sanitized_output.append(ast.unparse(node))
for name, (_, node) in definitions.items():
if not entrypoint or name in reachable:
sanitized_output.append(ast.unparse(node))
return "\n".join(sanitized_output)