File size: 4,812 Bytes
3ff9962
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import ast
import importlib
import os
from typing import Optional, Sequence

class DeleteSpecificNodes(ast.NodeTransformer):
    """Remove an explicit collection of node objects from an AST.

    Nodes are matched by object identity, so only the exact node instances
    passed in are removed (a structurally identical copy elsewhere in the tree
    is kept). If a removal empties a statement list that Python requires to be
    non-empty (``body`` / ``finalbody`` of a compound statement), a ``pass``
    statement is inserted so that ``ast.unparse`` still yields valid source.
    """

    def __init__(self, nodes_to_remove: list[ast.AST]):
        """Store the nodes to delete.

        Args:
            nodes_to_remove: Node objects (from the tree being transformed)
                that should be dropped from their parent.
        """
        self.nodes_to_remove = nodes_to_remove
        # Identity set: O(1) lookups instead of scanning the list for every
        # node visited. AST nodes compare by identity, so this matches the
        # previous ``node in list`` semantics.
        self._ids_to_remove = {id(node) for node in nodes_to_remove}

    def visit(self, node: ast.AST) -> Optional[ast.AST]:
        """Return None (delete) for marked nodes, otherwise transform normally."""
        if id(node) in self._ids_to_remove:
            return None
        return super().visit(node)

    def generic_visit(self, node: ast.AST) -> ast.AST:
        """Visit children, then backfill statement lists emptied by removal."""
        # Record which mandatory statement lists held statements before the
        # children were visited; a module is allowed to be empty, so skip it.
        emptied_candidates = []
        if not isinstance(node, ast.Module):
            for field in ('body', 'finalbody'):
                value = getattr(node, field, None)
                if isinstance(value, list) and value:
                    emptied_candidates.append(field)
        new_node = super().generic_visit(node)
        for field in emptied_candidates:
            if not getattr(new_node, field):
                # e.g. ``try: <removed import>`` would otherwise unparse to
                # a bare ``try:`` line, which is a syntax error.
                setattr(new_node, field, [ast.copy_location(ast.Pass(), new_node)])
        return new_node

def convert_to_relative_import(module_name: str, original_parent_module_name: Optional[str]) -> str:
    """Turn a dotted module name into a relative-import module string.

    Only the final dotted component of ``module_name`` is kept. When that
    component equals ``original_parent_module_name`` the result is ``'.'``;
    otherwise it is ``'.'`` followed by the component.

    Args:
        module_name: Absolute dotted module name, e.g. ``'pkg.sub.mod'``.
        original_parent_module_name: Name to compare the final component
            against, or None to never match.

    Returns:
        ``'.'`` or ``'.<last component>'``.
    """
    leaf = module_name.rsplit('.', 1)[-1]
    return '.' if leaf == original_parent_module_name else f'.{leaf}'

def find_module_file(module_name: str) -> str:
    """Import ``module_name`` and return the path stored in its ``__file__``.

    Args:
        module_name: Dotted module name to import.

    Returns:
        The module's ``__file__`` value.

    Raises:
        ValueError: If ``module_name`` is empty, or the imported module has
            no associated file (``__file__`` missing or None).
        ImportError: Propagated from ``importlib.import_module`` when the
            module cannot be imported.
    """
    if not module_name:
        raise ValueError(f'Invalid input: module_name={module_name!r}')
    # Note: importing executes the module's top-level code.
    module = importlib.import_module(module_name)
    # Built-in modules such as ``sys`` have no ``__file__`` attribute at all,
    # and some modules set it to None; report both as "no file" via
    # ValueError instead of leaking an AttributeError.
    module_file = getattr(module, '__file__', None)
    if module_file is None:
        raise ValueError(f'Could not find file for module: {module_name}')
    return module_file

def _flatten_import(node: ast.ImportFrom, flatten_imports_prefix: Sequence[str]) -> bool:
    """Returns True if import should be flattened.

    Checks whether the node starts the same as any of the imports in
    flatten_imports_prefix.
    """
    for import_prefix in flatten_imports_prefix:
        if node.module is not None and node.module.startswith(import_prefix):
            return True
    return False

def _remove_import(node: ast.ImportFrom, remove_imports_prefix: Sequence[str]) -> bool:
    """Returns True if import should be removed.

    Checks whether the node starts the same as any of the imports in
    remove_imports_prefix.
    """
    for import_prefix in remove_imports_prefix:
        if node.module is not None and node.module.startswith(import_prefix):
            return True
    return False

def _is_dunder_all_assignment(node: ast.AST) -> bool:
    """Return True if ``node`` assigns, annotates, or augments ``__all__``."""
    if isinstance(node, ast.Assign):
        targets = node.targets
    elif isinstance(node, (ast.AnnAssign, ast.AugAssign)):
        targets = [node.target]
    else:
        return False
    return len(targets) == 1 and isinstance(targets[0], ast.Name) and targets[0].id == '__all__'


def process_file(file_path: str, folder_path: str, flatten_imports_prefix: Sequence[str], remove_imports_prefix: Sequence[str]) -> list[str]:
    """Rewrite one Python file and write the result into ``folder_path``.

    Transformations applied to the parsed source:

    * ``from X import ...`` whose module matches ``remove_imports_prefix`` is
      deleted (removal takes precedence over flattening).
    * ``from X import ...`` whose module matches ``flatten_imports_prefix`` is
      rewritten to a relative import, and the file of ``X`` is collected.
    * Class definitions whose name starts with ``Composer`` are deleted.
    * Statements assigning to ``__all__`` (``=``, annotated, or ``+=``) are
      deleted. Augmented assignments are included because leaving
      ``__all__ += ...`` after removing ``__all__ = ...`` would raise
      NameError on import.

    The output is written under the input's basename, except that a package
    ``__init__.py`` is written as ``<package directory name>.py``.

    Args:
        file_path: Path of the Python file to read.
        folder_path: Directory the rewritten file is written to.
        flatten_imports_prefix: Module prefixes whose imports are flattened.
        remove_imports_prefix: Module prefixes whose imports are removed.

    Returns:
        File paths of the modules referenced by flattened imports, in the
        order encountered.

    Raises:
        ValueError: From ``find_module_file`` when a flattened module has no
            file.
        ImportError: From ``find_module_file`` when a flattened module cannot
            be imported.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        source = f.read()

    # For a package __init__.py, the package directory name is used both for
    # resolving self-referencing imports and as the output file name.
    parent_module_name = None
    if os.path.basename(file_path) == '__init__.py':
        parent_module_name = os.path.basename(os.path.dirname(file_path))

    tree = ast.parse(source)
    new_files_to_process = []
    nodes_to_remove = []
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and node.module is not None:
            if _remove_import(node, remove_imports_prefix):
                nodes_to_remove.append(node)
            elif _flatten_import(node, flatten_imports_prefix):
                module_path = find_module_file(node.module)
                # level stays 0; ast.unparse emits the leading '.' from the
                # module string, producing e.g. ``from .mod import X``.
                node.module = convert_to_relative_import(node.module, parent_module_name)
                new_files_to_process.append(module_path)
        elif isinstance(node, ast.ClassDef) and node.name.startswith('Composer'):
            nodes_to_remove.append(node)
        elif _is_dunder_all_assignment(node):
            nodes_to_remove.append(node)

    new_tree = DeleteSpecificNodes(nodes_to_remove).visit(tree)
    # Checked before opening the output so a failure cannot truncate it.
    assert new_tree is not None
    new_source = ast.unparse(new_tree)

    # Use os.path (not split('/')) so the package name is derived correctly
    # with any path separator.
    if parent_module_name is not None:
        new_filename = parent_module_name + '.py'
    else:
        new_filename = os.path.basename(file_path)
    new_file_path = os.path.join(folder_path, new_filename)
    with open(new_file_path, 'w', encoding='utf-8') as f:
        f.write(new_source)
    return new_files_to_process

def edit_files_for_hf_compatibility(folder: str, flatten_imports_prefix: Sequence[str]=('llmfoundry',), remove_imports_prefix: Sequence[str]=('composer', 'omegaconf', 'llmfoundry.metrics')) -> None:
    """Edit files to be compatible with Hugging Face Hub.

    Every ``.py`` entry in ``folder`` is processed. Module files discovered
    through flattened imports are processed too, each path at most once, and
    all output is written into ``folder``.

    Args:
        folder (str): The folder to process.
        flatten_imports_prefix (Sequence[str], optional): Sequence of prefixes to flatten. Defaults to ('llmfoundry',).
        remove_imports_prefix (Sequence[str], optional): Sequence of prefixes to remove. Takes precedence over flattening.
            Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics').
    """
    pending = [
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.endswith('.py')
    ]
    # Every path ever queued, so nothing is processed twice.
    seen = set(pending)
    while pending:
        current = pending.pop()
        # Skip directories that merely end in '.py' and non-.py discoveries.
        if not (os.path.isfile(current) and current.endswith('.py')):
            continue
        discovered = process_file(current, folder, flatten_imports_prefix, remove_imports_prefix)
        for path in discovered:
            if path in seen:
                continue
            pending.append(path)
            seen.add(path)