| | import os |
| |
|
| | import libcst as cst |
| |
|
| |
|
| | |
| | |
# Per-library list of file descriptors, keyed by external library name; each
# descriptor carries a file "name" and a "type".
# NOTE(review): presumably these are files to skip when processing that library —
# the consumers of this table are not visible here, confirm against callers.
EXCLUDED_EXTERNAL_FILES = {"habana": [{"name": "modeling_all_models", "type": "modeling"}]}
| |
|
| |
|
def convert_relative_import_to_absolute(
    import_node: cst.ImportFrom,
    file_path: str,
    package_name: str | None = "transformers",
) -> cst.ImportFrom:
    """
    Convert a relative `libcst.ImportFrom` node into an absolute one, using the
    location of the importing file to resolve the leading dots.

    Args:
        import_node: The import to convert (e.g. `from ..utils import helper`).
            Nodes that are already absolute are returned untouched.
        file_path: Path to the file containing the import. May be absolute or
            relative; it is resolved with `os.path.abspath`.
        package_name: Name of the package directory the file lives in
            (e.g. 'myproject').

    Returns:
        A new `ImportFrom` node whose module is the absolute dotted path and whose
        `relative` list is empty.

    Raises:
        ValueError: If no directory of `file_path` is named `package_name`, or if
            the import climbs more levels than exist below the package directory.
    """
    if not import_node.relative:
        return import_node

    rel_level = len(import_node.relative)
    abs_path = os.path.abspath(file_path)

    # Path components with the ".py" suffix removed, so the last component is the module name.
    file_parts = abs_path.removesuffix(".py").split(os.path.sep)
    # Only directories can be the package root; the trailing component is the module itself.
    dir_parts = file_parts[:-1]

    if package_name not in dir_parts:
        raise ValueError(f"Package name '{package_name}' not found in file path '{abs_path}'")

    # Use the LAST directory named `package_name`: a checkout such as
    # `/work/transformers/src/transformers/models/x.py` contains the name twice and
    # only the innermost one is the importable package.
    pkg_index = len(dir_parts) - 1 - dir_parts[::-1].index(package_name)

    # Everything below the package directory, e.g. ['models', 'bert', 'modeling_bert'].
    module_parts = file_parts[pkg_index + 1 :]
    if len(module_parts) < rel_level:
        raise ValueError(f"Relative import level ({rel_level}) goes beyond package root.")

    # One dot designates the containing package, so dropping `rel_level` trailing
    # components (module name first) yields the package the import is anchored at.
    base_parts = module_parts[:-rel_level]

    def flatten_module(module: cst.BaseExpression | None) -> list[str]:
        """Return the dotted components of a Name/Attribute chain ([] for None)."""
        collected: list[str] = []
        while isinstance(module, cst.Attribute):
            collected.append(module.attr.value)
            module = module.value
        if isinstance(module, cst.Name):
            collected.append(module.value)
        collected.reverse()
        return collected

    import_parts = flatten_module(import_node.module)

    full_parts = [package_name] + base_parts + import_parts

    # For libraries other than transformers, the directory holding the package is
    # treated as a namespace prefix (e.g. `optimum` for `optimum/habana`) unless it is
    # a `src` layout directory. It is only prepended when it is a valid identifier:
    # an empty root component or a name like `site-packages` cannot be part of a
    # dotted module path.
    parent_dir = file_parts[pkg_index - 1] if pkg_index > 0 else ""
    if package_name != "transformers" and parent_dir != "src" and parent_dir.isidentifier():
        full_parts = [parent_dir] + full_parts

    # Fold the components into a left-nested Attribute chain: a.b.c -> Attribute(Attribute(a, b), c).
    dotted_module: cst.BaseExpression = cst.Name(full_parts[0])
    for part in full_parts[1:]:
        dotted_module = cst.Attribute(value=dotted_module, attr=cst.Name(part))

    return import_node.with_changes(module=dotted_module, relative=[])
| |
|
| |
|
def convert_to_relative_import(import_node: cst.ImportFrom, file_path: str, package_name: str) -> cst.ImportFrom:
    """
    Convert an absolute import to a relative one if it belongs to `package_name`.

    An import is considered to belong to the package when its module is
    `package_name`, starts with `package_name.`, or starts with
    `optimum.package_name.`.

    Parameters:
    - import_node: The ImportFrom node to possibly transform. Relative imports are returned untouched.
    - file_path: Path to the file containing the import (e.g., '/path/to/mypackage/foo/bar.py').
    - package_name: The top-level package name (e.g., 'mypackage').

    Returns:
    - A possibly modified ImportFrom node. The original node is returned when the
      import does not target the package or when no directory of `file_path` is
      named `package_name`.
    """
    if import_node.relative:
        return import_node

    # Flatten the Name/Attribute chain of the imported module into its dotted components.
    module_parts: list[str] = []
    module = import_node.module
    while isinstance(module, cst.Attribute):
        module_parts.append(module.attr.value)
        module = module.value
    if isinstance(module, cst.Name):
        module_parts.append(module.value)
    module_parts.reverse()

    # Number of leading components that merely name the package and must be stripped.
    if len(module_parts) > 2 and module_parts[:2] == ["optimum", package_name]:
        prefix_len = 2
    elif module_parts[:1] == [package_name]:
        prefix_len = 1
    else:
        # Not an import from the target package.
        return import_node

    # Module path relative to the package root, e.g. ['models', 'bert', 'modeling_bert'].
    target_parts = module_parts[prefix_len:]

    path_parts = os.path.normpath(file_path).split(os.sep)
    # Only directories can be the package root; the last component is the file itself.
    dir_parts = path_parts[:-1]
    if package_name not in dir_parts:
        # The file does not live inside the package: nothing to be relative to.
        return import_node

    # Use the LAST directory named `package_name` so that a checkout such as
    # `/work/mypackage/src/mypackage/...` resolves against the importable package.
    pkg_index = len(dir_parts) - 1 - dir_parts[::-1].index(package_name)
    file_dirs = dir_parts[pkg_index + 1 :]

    # Count the sub-packages shared by the importing file and the imported module;
    # those do not need to be climbed out of nor spelled out again.
    shared = 0
    for directory, component in zip(file_dirs, target_parts):
        if directory != component:
            break
        shared += 1

    # One dot reaches the file's own package; each non-shared directory adds one more.
    depth = max(1 + len(file_dirs) - shared, 1)
    # Build a distinct Dot node per level instead of repeating a single instance.
    relative = [cst.Dot() for _ in range(depth)]

    remaining = target_parts[shared:]
    if not remaining:
        # The import targets the package reached by the dots: `from .. import name`.
        new_module = None
    else:
        new_module = cst.Name(remaining[0])
        for part in remaining[1:]:
            new_module = cst.Attribute(value=new_module, attr=cst.Name(part))

    return import_node.with_changes(module=new_module, relative=relative)
| |
|
| |
|
class AbsoluteImportTransformer(cst.CSTTransformer):
    """
    CST transformer that passes every `from ... import ...` statement through
    `convert_relative_import_to_absolute`.
    """

    def __init__(self, relative_path: str, source_library: str):
        """
        Args:
            relative_path: Path of the file being transformed; forwarded as `file_path`.
            source_library: Package name; forwarded as `package_name`.
        """
        super().__init__()
        self.relative_path = relative_path
        self.source_library = source_library

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        """Return the replacement for `updated_node` computed by the conversion helper."""
        absolute_node = convert_relative_import_to_absolute(
            updated_node,
            self.relative_path,
            self.source_library,
        )
        return absolute_node
| |
|
| |
|
class RelativeImportTransformer(cst.CSTTransformer):
    """
    CST transformer that passes every `from ... import ...` statement through
    `convert_to_relative_import`.
    """

    def __init__(self, relative_path: str, source_library: str):
        """
        Args:
            relative_path: Path of the file being transformed; forwarded as `file_path`.
            source_library: Package name; forwarded as `package_name`.
        """
        super().__init__()
        self.relative_path = relative_path
        self.source_library = source_library

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        """Return the replacement for `updated_node` computed by the conversion helper."""
        relative_node = convert_to_relative_import(
            import_node=updated_node,
            file_path=self.relative_path,
            package_name=self.source_library,
        )
        return relative_node
| |
|