|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import os |
|
import re |
|
|
|
|
|
|
|
|
|
PATH_TO_DIFFUSERS = "src/diffusers" |
|
|
|
|
|
_re_backend = re.compile(r"is\_([a-z_]*)_available\(\)") |
|
|
|
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") |
|
|
|
|
|
DUMMY_CONSTANT = """ |
|
{0} = None |
|
""" |
|
|
|
DUMMY_CLASS = """ |
|
class {0}(metaclass=DummyObject): |
|
_backends = {1} |
|
|
|
def __init__(self, *args, **kwargs): |
|
requires_backends(self, {1}) |
|
|
|
@classmethod |
|
def from_config(cls, *args, **kwargs): |
|
requires_backends(cls, {1}) |
|
|
|
@classmethod |
|
def from_pretrained(cls, *args, **kwargs): |
|
requires_backends(cls, {1}) |
|
""" |
|
|
|
|
|
DUMMY_FUNCTION = """ |
|
def {0}(*args, **kwargs): |
|
requires_backends({0}, {1}) |
|
""" |
|
|
|
|
|
def find_backend(line): |
|
"""Find one (or multiple) backend in a code line of the init.""" |
|
backends = _re_backend.findall(line) |
|
if len(backends) == 0: |
|
return None |
|
|
|
return "_and_".join(backends) |
|
|
|
|
|
def read_init(): |
|
"""Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects.""" |
|
with open(os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f: |
|
lines = f.readlines() |
|
|
|
|
|
line_index = 0 |
|
while not lines[line_index].startswith("if TYPE_CHECKING"): |
|
line_index += 1 |
|
|
|
backend_specific_objects = {} |
|
|
|
while line_index < len(lines): |
|
|
|
backend = find_backend(lines[line_index]) |
|
if backend is not None: |
|
while not lines[line_index].startswith(" else:"): |
|
line_index += 1 |
|
line_index += 1 |
|
objects = [] |
|
|
|
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8): |
|
line = lines[line_index] |
|
single_line_import_search = _re_single_line_import.search(line) |
|
if single_line_import_search is not None: |
|
objects.extend(single_line_import_search.groups()[0].split(", ")) |
|
elif line.startswith(" " * 12): |
|
objects.append(line[12:-2]) |
|
line_index += 1 |
|
|
|
if len(objects) > 0: |
|
backend_specific_objects[backend] = objects |
|
else: |
|
line_index += 1 |
|
|
|
return backend_specific_objects |
|
|
|
|
|
def create_dummy_object(name, backend_name): |
|
"""Create the code for the dummy object corresponding to `name`.""" |
|
if name.isupper(): |
|
return DUMMY_CONSTANT.format(name) |
|
elif name.islower(): |
|
return DUMMY_FUNCTION.format(name, backend_name) |
|
else: |
|
return DUMMY_CLASS.format(name, backend_name) |
|
|
|
|
|
def create_dummy_files(backend_specific_objects=None): |
|
"""Create the content of the dummy files.""" |
|
if backend_specific_objects is None: |
|
backend_specific_objects = read_init() |
|
|
|
dummy_files = {} |
|
|
|
for backend, objects in backend_specific_objects.items(): |
|
backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]" |
|
dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" |
|
dummy_file += "from ..utils import DummyObject, requires_backends\n\n" |
|
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects]) |
|
dummy_files[backend] = dummy_file |
|
|
|
return dummy_files |
|
|
|
|
|
def check_dummies(overwrite=False): |
|
"""Check if the dummy files are up to date and maybe `overwrite` with the right content.""" |
|
dummy_files = create_dummy_files() |
|
|
|
short_names = {"torch": "pt"} |
|
|
|
|
|
path = os.path.join(PATH_TO_DIFFUSERS, "utils") |
|
dummy_file_paths = { |
|
backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py") |
|
for backend in dummy_files.keys() |
|
} |
|
|
|
actual_dummies = {} |
|
for backend, file_path in dummy_file_paths.items(): |
|
if os.path.isfile(file_path): |
|
with open(file_path, "r", encoding="utf-8", newline="\n") as f: |
|
actual_dummies[backend] = f.read() |
|
else: |
|
actual_dummies[backend] = "" |
|
|
|
for backend in dummy_files.keys(): |
|
if dummy_files[backend] != actual_dummies[backend]: |
|
if overwrite: |
|
print( |
|
f"Updating diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main " |
|
"__init__ has new objects." |
|
) |
|
with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f: |
|
f.write(dummy_files[backend]) |
|
else: |
|
raise ValueError( |
|
"The main __init__ has objects that are not present in " |
|
f"diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` " |
|
"to fix this." |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") |
|
args = parser.parse_args() |
|
|
|
check_dummies(args.fix_and_overwrite) |
|
|