import argparse import re from pathlib import Path import nbformat import nbconvert from traitlets.config import Config # Notebooks that are excluded from the CI tests EXCLUDED_NOTEBOOKS = ["data-preparation-ct-scan.ipynb", "pytorch-monai-training.ipynb"] DEVICE_WIDGET = "device = widgets.Dropdown(" def disable_gradio_debug(nb, notebook_path): found = False for cell in nb["cells"]: if "gradio" in cell["source"] and "debug" in cell["source"]: found = True cell["source"] = cell["source"].replace("debug=True", "debug=False") if found: print(f"Disabled gradio debug mode for {notebook_path}") return nb def disable_skip_ext(nb, notebook_path, test_device=""): found = False skip_for_device = None if test_device else False for cell in nb["cells"]: if test_device is not None and skip_for_device is None: if ( 'skip_for_device = "{}" in device.value'.format(test_device) in cell["source"] and "to_quantize = widgets.Checkbox(value=not skip_for_device" in cell["source"] ): skip_for_device = True if "%%skip" in cell["source"]: found = True if not skip_for_device: cell["source"] = re.sub(r"%%skip.*.\n", "\n", cell["source"]) else: cell["source"] = '"""\n' + cell["source"] + '\n"""' if found: print(f"Disabled skip extension mode for {notebook_path}") return nb def remove_ov_install(cell): updated_lines = [] def has_additional_deps(str_part): if "%pip" in str_part: return False if "install" in str_part: return False if str_part.startswith("-"): return False if str_part.startswith("https://"): return False return True lines = cell["source"].split("\n") for line in lines: if "openvino" in line: updated_line_content = [] empty = True package_found = False for part in line.split(" "): if "openvino-dev" in part: package_found = True continue if "openvino-nightly" in part: package_found = True continue if "openvino-tokenizers" in part: package_found = True continue if "openvino>" in part or "openvino=" in part or "openvino" == part: package_found = True continue if empty: empty = not has_additional_deps(part) updated_line_content.append(part) if package_found: if not empty: updated_line = " ".join(updated_line_content) if line.startswith(" "): for token in line: if token != " ": break # keep indention updated_line = " " + updated_line updated_lines.append(updated_line + "\n# " + line) else: updated_lines.append(line) else: updated_lines.append(line) cell["source"] = "\n".join(updated_lines) def patch_notebooks(notebooks_dir, test_device="", skip_ov_install=False): """ Patch notebooks in notebooks directory with replacement values found in notebook metadata to speed up test execution. This function is specific for the OpenVINO notebooks Github Actions CI. For example: change nr of epochs from 15 to 1 in tensorflow-training-openvino-nncf.ipynb by adding {"test_replace": {"epochs = 15": "epochs = 1"} to the cell metadata of the cell that contains `epochs = 15` :param notebooks_dir: Directory that contains the notebook subdirectories. For example: openvino_notebooks/notebooks """ nb_convert_config = Config() nb_convert_config.NotebookExporter.preprocessors = ["nbconvert.preprocessors.ClearOutputPreprocessor"] output_remover = nbconvert.NotebookExporter(nb_convert_config) for notebookfile in Path(notebooks_dir).glob("**/*.ipynb"): if not str(notebookfile.name).startswith("test_") and notebookfile.name not in EXCLUDED_NOTEBOOKS: nb = nbformat.read(notebookfile, as_version=nbformat.NO_CONVERT) found = False device_found = False for cell in nb["cells"]: if skip_ov_install and "%pip" in cell["source"]: remove_ov_install(cell) if test_device and DEVICE_WIDGET in cell["source"]: device_found = True cell["source"] = re.sub(r"value=.*,", f"value='{test_device.upper()}',", cell["source"]) cell["source"] = re.sub( r"options=", f"options=['{test_device.upper()}'] + ", cell["source"], ) print(f"Replaced testing device to {test_device}") replace_dict = cell.get("metadata", {}).get("test_replace") if replace_dict is not None: found = True for source_value, target_value in replace_dict.items(): if source_value not in cell["source"]: raise ValueError(f"Processing {notebookfile} failed: {source_value} does not exist in cell") cell["source"] = cell["source"].replace(source_value, target_value) cell["source"] = "# Modified for testing\n" + cell["source"] print(f"Processed {notebookfile}: {source_value} -> {target_value}") if test_device and not device_found: print(f"No device replacement found for {notebookfile}") if not found: print(f"No replacements found for {notebookfile}") disable_gradio_debug(nb, notebookfile) disable_skip_ext(nb, notebookfile, args.test_device) nb_without_out, _ = output_remover.from_notebook_node(nb) with notebookfile.with_name(f"test_{notebookfile.name}").open("w", encoding="utf-8") as out_file: out_file.write(nb_without_out) if __name__ == "__main__": parser = argparse.ArgumentParser("Notebook patcher") parser.add_argument("notebooks_dir", default=".") parser.add_argument("-td", "--test_device", default="") parser.add_argument("--skip_ov_install", action="store_true") args = parser.parse_args() if not Path(args.notebooks_dir).is_dir(): raise ValueError(f"'{args.notebooks_dir}' is not an existing directory") patch_notebooks(args.notebooks_dir, args.test_device, args.skip_ov_install)