# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Check that the scheduler and pipeline parts of the documentation table of content are clean.

"Clean" means: no duplicated entries, entries sorted alphabetically by title, and an optional single
"overview" entry placed first. Run with `--fix_and_overwrite` to rewrite the table of content in place
instead of raising an error.
"""

import argparse
from collections import defaultdict

import yaml


PATH_TO_TOC = "docs/source/en/_toctree.yml"

# Keys under which a pipeline entry may hold nested entries. "sections" is the key used at every other
# level of the table of content in this script; "section" is still honoured because it is the key this
# script historically looked up.
_SUBSECTION_KEYS = ("sections", "section")


def clean_doc_toc(doc_list):
    """
    Cleans the table of content of the model documentation by removing duplicates and sorting models alphabetically.

    Args:
        doc_list: list of table-of-content entries (dicts). Every entry must have a "title"; entries may
            have a "local" key, which identifies duplicates. The input list is not modified.

    Returns:
        A new list with the (at most one) "overview" entry first, followed by the remaining entries sorted
        case-insensitively by title. Entries sharing the same "local" are merged into a single
        `{"local": ..., "title": ...}` entry; entries without a "local" key are kept untouched.

    Raises:
        ValueError: if there is more than one "overview" entry, or if entries sharing the same "local"
            have different titles.
    """
    overview_doc = []
    other_docs = []
    for doc in doc_list:
        if doc["title"].lower() == "overview":
            overview_doc.append({"local": doc["local"], "title": doc["title"]})
        else:
            other_docs.append(doc)

    # "overview" gets special treatment and is always first. Validate it before de-duplicating so that a
    # repeated overview is reported as such rather than failing later in the duplicate handling.
    if len(overview_doc) > 1:
        raise ValueError(f"{doc_list} has two 'overview' docs which is not allowed.")

    # Count how often each "local" occurs among the non-overview entries.
    counts = defaultdict(int)
    for doc in other_docs:
        if "local" in doc:
            counts[doc["local"]] += 1

    new_doc = []
    for duplicate_key in [key for key, value in counts.items() if value > 1]:
        titles = list({doc["title"] for doc in other_docs if "local" in doc and doc["local"] == duplicate_key})
        if len(titles) > 1:
            raise ValueError(
                f"{duplicate_key} is present several times in the documentation table of content at "
                "`docs/source/en/_toctree.yml` with different *Title* values. Choose one of those and remove the "
                "others."
            )
        # Only add this once
        new_doc.append({"local": duplicate_key, "title": titles[0]})

    # Add non-duplicated entries. Entries without a "local" key (e.g. groups that only hold nested
    # entries) cannot be duplicates and are always kept.
    new_doc.extend(doc for doc in other_docs if "local" not in doc or counts[doc["local"]] == 1)
    new_doc.sort(key=lambda s: s["title"].lower())

    return overview_doc + new_doc


def _load_toc():
    """Parse the table of content file at `PATH_TO_TOC` and return its content."""
    with open(PATH_TO_TOC, encoding="utf-8") as f:
        return yaml.safe_load(f.read())


def _write_toc(content):
    """Serialize `content` back to the table of content file at `PATH_TO_TOC`."""
    with open(PATH_TO_TOC, "w", encoding="utf-8") as f:
        f.write(yaml.dump(content, allow_unicode=True))


def _find_section_index(sections, title):
    """Return the index of the first entry of `sections` whose "title" equals `title`.

    Raises:
        ValueError: if no entry has that title (instead of running past the end of the list).
    """
    for idx, section in enumerate(sections):
        if section.get("title") == title:
            return idx
    raise ValueError(f"No section titled {title!r} was found in `{PATH_TO_TOC}`.")


def check_scheduler_doc(overwrite=False):
    """Check that the "Schedulers" part of the "API" table of content is clean.

    Args:
        overwrite: when `True`, rewrite the table of content file with the cleaned section instead of
            raising.

    Raises:
        ValueError: if the section is not clean and `overwrite` is `False`, or if the "API" or
            "Schedulers" section cannot be found.
    """
    content = _load_toc()

    # Get to the API doc, then to the scheduler doc
    api_doc = content[_find_section_index(content, "API")]["sections"]
    scheduler_section = api_doc[_find_section_index(api_doc, "Schedulers")]

    scheduler_doc = scheduler_section["sections"]
    new_scheduler_doc = clean_doc_toc(scheduler_doc)
    if new_scheduler_doc == scheduler_doc:
        return

    if not overwrite:
        raise ValueError(
            "The scheduler doc part of the table of content is not properly sorted, run `make style` to fix this."
        )

    # `scheduler_section` is the very object held inside `content`, so updating it updates `content`.
    scheduler_section["sections"] = new_scheduler_doc
    _write_toc(content)


def check_pipeline_doc(overwrite=False):
    """Check that the "Pipelines" part of the "API" table of content is clean.

    Both the top-level pipeline entries and the entries nested inside each of them are checked.

    Args:
        overwrite: when `True`, rewrite the table of content file with the cleaned section instead of
            raising.

    Raises:
        ValueError: if the section is not clean and `overwrite` is `False`, or if the "API" or
            "Pipelines" section cannot be found.
    """
    content = _load_toc()

    # Get to the API doc, then to the pipeline doc
    api_doc = content[_find_section_index(content, "API")]["sections"]
    pipeline_section = api_doc[_find_section_index(api_doc, "Pipelines")]
    pipeline_docs = pipeline_section["sections"]

    diff = False

    # Sort sub pipeline docs. Differences are tracked separately here: when overwriting, the nested lists
    # are replaced in place, so the top-level comparison below cannot notice them.
    for pipeline_doc in pipeline_docs:
        for key in _SUBSECTION_KEYS:
            if key not in pipeline_doc:
                continue
            sub_pipeline_doc = pipeline_doc[key]
            new_sub_pipeline_doc = clean_doc_toc(sub_pipeline_doc)
            if new_sub_pipeline_doc != sub_pipeline_doc:
                diff = True
                if overwrite:
                    pipeline_doc[key] = new_sub_pipeline_doc

    # Sort overall pipeline doc
    new_pipeline_docs = clean_doc_toc(pipeline_docs)
    if new_pipeline_docs != pipeline_docs:
        diff = True
        if overwrite:
            pipeline_section["sections"] = new_pipeline_docs

    if not diff:
        return

    if not overwrite:
        raise ValueError(
            "The pipeline doc part of the table of content is not properly sorted, run `make style` to fix this."
        )

    _write_toc(content)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    check_scheduler_doc(args.fix_and_overwrite)
    check_pipeline_doc(args.fix_and_overwrite)