| import argparse |
| import os |
| import re |
| import subprocess |
| from datetime import date, datetime |
| from urllib.error import HTTPError |
| from urllib.request import Request, urlopen |
|
|
| from huggingface_hub import paper_info |
|
|
|
|
# Repository root, derived from the CWD: assumes the script is launched from a
# path containing "utils" (e.g. `python utils/add_dates.py` from the repo root).
ROOT = os.getcwd().split("utils")[0]
DOCS_PATH = os.path.join(ROOT, "docs/source/en/model_doc")  # English model-card markdown files
MODELS_PATH = os.path.join(ROOT, "src/transformers/models")  # per-model source packages
# NOTE(review): GITHUB_REPO_URL is not referenced in this file's visible code — confirm before removing.
GITHUB_REPO_URL = "https://github.com/huggingface/transformers"
GITHUB_RAW_URL = "https://raw.githubusercontent.com/huggingface/transformers/main"
|
|
| COPYRIGHT_DISCLAIMER = """<!--Copyright 2025 The HuggingFace Team. All rights reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
| the License. You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
| specific language governing permissions and limitations under the License. |
| |
| ⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be |
| rendered properly in your Markdown viewer. |
| |
| -->""" |
|
|
# Papers that exist on arXiv but are not indexed on hf.co/papers,
# keyed by model-card filename -> arXiv id.
# NOTE(review): not referenced anywhere in this file's visible code — presumably
# consumed elsewhere or kept as a manual allow-list; confirm before removing.
ARXIV_PAPERS_NOT_IN_HF_PAPERS = {
    "gemma3n.md": "2506.06644",
    "xmod.md": "2205.06266",
}
|
|
|
|
def check_file_exists_on_github(file_path: str) -> bool:
    """Check if a file exists on the main branch of the GitHub repository.

    Args:
        file_path: Relative path from repository root

    Returns:
        True if file exists on GitHub main branch (or if check failed), False only if confirmed 404

    Note:
        On network errors or other issues, returns True (assumes file exists) with a warning.
        This prevents the script from failing due to temporary network issues.
    """
    # Turn an absolute local path into a repo-relative path before building the URL.
    if file_path.startswith(ROOT):
        file_path = file_path[len(ROOT):].lstrip("/")

    raw_url = f"{GITHUB_RAW_URL}/{file_path}"

    try:
        head_request = Request(raw_url, method="HEAD")
        head_request.add_header("User-Agent", "transformers-add-dates-script")
        with urlopen(head_request, timeout=10) as response:
            return response.status == 200
    except HTTPError as http_err:
        # Only a confirmed 404 counts as "missing"; any other HTTP status is inconclusive.
        return http_err.code != 404
    except Exception:
        # Network problem or similar: assume the file exists rather than abort the run.
        return True
|
|
|
|
def get_modified_cards() -> list[str]:
    """Get the list of model names from modified files in docs/source/en/model_doc/.

    On `main` the diff is taken against HEAD (uncommitted changes); on any other
    branch it is taken against the merge-base with `main` (all branch changes).

    Returns:
        Model names (without the .md extension) for modified docs that still exist
        locally, excluding the special "auto" and "timm_wrapper" cards.
    """
    current_branch = subprocess.check_output(["git", "branch", "--show-current"], text=True).strip()
    if current_branch == "main":
        result = subprocess.check_output(["git", "diff", "--name-only", "HEAD"], text=True)
    else:
        # Consistently use list arguments + text=True (the original mixed
        # `.decode("utf-8")` with text=True and interpolated an unstripped SHA
        # into a whitespace-split command string).
        fork_point_sha = subprocess.check_output(["git", "merge-base", "main", "HEAD"], text=True).strip()
        result = subprocess.check_output(["git", "diff", "--name-only", fork_point_sha], text=True)

    model_names = []
    for line in result.strip().split("\n"):
        if not line:
            continue
        if line.startswith("docs/source/en/model_doc/") and line.endswith(".md"):
            # Skip files that were deleted in the diff.
            if os.path.exists(os.path.join(ROOT, line)):
                model_name = os.path.splitext(os.path.basename(line))[0]
                if model_name not in ["auto", "timm_wrapper"]:
                    model_names.append(model_name)

    return model_names
|
|
|
|
| def get_paper_link(model_card: str | None, path: str | None) -> str: |
| """Get the first paper link from the model card content.""" |
|
|
| if model_card is not None and not model_card.endswith(".md"): |
| model_card = f"{model_card}.md" |
| file_path = path or os.path.join(DOCS_PATH, f"{model_card}") |
| model_card = os.path.basename(file_path) |
| with open(file_path, "r", encoding="utf-8") as f: |
| content = f.read() |
|
|
| |
| paper_ids = re.findall(r"https://huggingface\.co/papers/\d+\.\d+", content) |
| paper_ids += re.findall(r"https://arxiv\.org/abs/\d+\.\d+", content) |
| paper_ids += re.findall(r"https://arxiv\.org/pdf/\d+\.\d+", content) |
|
|
| if len(paper_ids) == 0: |
| return "No_paper" |
|
|
| return paper_ids[0] |
|
|
|
|
def get_first_commit_date(model_name: str) -> str:
    """Get the first commit date (YYYY-MM-DD) of the model's init file or model .md doc.

    This date is considered the date the model was added to HF transformers.

    Args:
        model_name: Model name, with or without the ".md" extension.
        (Original annotation was `str | None`, but the first statement calls
        `.endswith` and would crash on None — tightened to `str`.)

    Returns:
        ISO date string. Falls back to today's date when the file is not yet on
        the GitHub main branch (e.g. a brand-new model in an open PR), since
        local git history would give a misleading date.
    """
    if model_name.endswith(".md"):
        model_name = model_name[:-3]

    # Doc filenames may use "-" where the source directory uses "_".
    model_name_src = model_name.replace("-", "_")
    file_path = os.path.join(MODELS_PATH, model_name_src, "__init__.py")

    # Some cards have no matching source package; fall back to the doc file itself.
    if not os.path.exists(file_path):
        file_path = os.path.join(DOCS_PATH, f"{model_name}.md")

    if not check_file_exists_on_github(file_path):
        # Not on main yet: use today as the addition date.
        return date.today().isoformat()

    # `--reverse` puts the earliest commit first; keep only the YYYY-MM-DD part.
    log_output = subprocess.check_output(
        ["git", "log", "--reverse", "--pretty=format:%ad", "--date=iso", file_path], text=True
    )
    return log_output.strip().split("\n")[0][:10]
|
|
|
|
def get_release_date(link: str) -> str:
    """Return the paper's publication date (YYYY-MM-DD) for the given link.

    Hugging Face paper links are resolved through the Hub API. Plain arXiv links
    — and any link that cannot be resolved — yield the literal "{release_date}"
    placeholder so a later run can fill in the real date.

    Fix: the original implicitly returned None when the Hub lookup failed or the
    link matched no branch, which made callers write the string "None" into the
    docs; the placeholder is returned instead (both are already treated as
    placeholders by the date-insertion logic).
    """
    if link.startswith("https://huggingface.co/papers/"):
        paper_id = link.replace("https://huggingface.co/papers/", "")
        try:
            info = paper_info(paper_id)
            return info.published_at.date().isoformat()
        except Exception:
            # Hub lookup failed (network error, unknown paper): fall through.
            pass
    # arXiv links and any unresolved link get the placeholder.
    return r"{release_date}"
|
|
|
|
def replace_paper_links(file_path: str) -> bool:
    """Replace arxiv links with huggingface links if valid, and replace hf.co with huggingface.co.

    Returns True when the file was rewritten, False when nothing changed.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        original_content = f.read()

    # Normalize the short hf.co domain first.
    updated = original_content.replace("https://hf.co/", "https://huggingface.co/")

    # Collect arXiv ids from both abs- and pdf-style links.
    found_ids = re.findall(r"https://arxiv\.org/abs/(\d+\.\d+)", updated)
    found_ids += re.findall(r"https://arxiv\.org/pdf/(\d+\.\d+)", updated)

    for arxiv_id in found_ids:
        try:
            # Only rewrite links whose paper is actually indexed on the Hub.
            paper_info(arxiv_id)
        except Exception:
            # Lookup failed or paper not on hf.co/papers: keep the arXiv link.
            continue
        stale_link = f"https://arxiv.org/abs/{arxiv_id}"
        if stale_link not in updated:
            stale_link = f"https://arxiv.org/pdf/{arxiv_id}"
        updated = updated.replace(stale_link, f"https://huggingface.co/papers/{arxiv_id}")

    if updated == original_content:
        return False
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(updated)
    return True
|
|
|
|
| def _normalize_model_card_name(model_card: str) -> str: |
| """Ensure model card has .md extension""" |
| return model_card if model_card.endswith(".md") else f"{model_card}.md" |
|
|
|
|
| def _should_skip_model_card(model_card: str) -> bool: |
| """Check if model card should be skipped""" |
| return model_card in ("auto.md", "timm_wrapper.md") |
|
|
|
|
def _read_model_card_content(model_card: str) -> str:
    """Return the full text of the given model card file under DOCS_PATH."""
    card_path = os.path.join(DOCS_PATH, model_card)
    with open(card_path, "r", encoding="utf-8") as handle:
        return handle.read()
|
|
|
|
| def _get_dates_pattern_match(content: str): |
| """Search for the dates pattern in content and return match object""" |
| pattern = r"\n\*This model was released on (.*) and added to Hugging Face Transformers on (\d{4}-\d{2}-\d{2})\.\*" |
| return re.search(pattern, content) |
|
|
|
|
| def _dates_differ_significantly(date1: str, date2: str) -> bool: |
| """Check if two dates differ by more than 1 day""" |
| try: |
| d1 = datetime.strptime(date1, "%Y-%m-%d") |
| d2 = datetime.strptime(date2, "%Y-%m-%d") |
| return abs((d1 - d2).days) > 1 |
| except Exception: |
| return True |
|
|
|
|
def check_missing_dates(model_card_list: list[str]) -> list[str]:
    """Return the model cards from the list that lack the release-date line."""
    cards_without_dates = []

    for card in model_card_list:
        card = _normalize_model_card_name(card)
        if _should_skip_model_card(card):
            continue
        if _get_dates_pattern_match(_read_model_card_content(card)) is None:
            cards_without_dates.append(card)

    return cards_without_dates
|
|
|
|
def check_incorrect_dates(model_card_list: list[str]) -> list[str]:
    """Return the model cards whose recorded HF-addition date disagrees with git history."""
    cards_with_wrong_dates = []

    for card in model_card_list:
        card = _normalize_model_card_name(card)
        if _should_skip_model_card(card):
            continue

        match = _get_dates_pattern_match(_read_model_card_content(card))
        if not match:
            # No dates line at all — reported by check_missing_dates, not here.
            continue

        recorded_hf_date = match.group(2)
        actual_hf_date = get_first_commit_date(model_name=card)
        if _dates_differ_significantly(recorded_hf_date, actual_hf_date):
            cards_with_wrong_dates.append(card)

    return cards_with_wrong_dates
|
|
|
|
def insert_dates(model_card_list: list[str]):
    """Insert or update release and commit dates in model cards.

    For each card: normalizes paper links in place, ensures a copyright header
    exists, then inserts (or updates) the
    "*This model was released on ... and added to ... on ...*" line right after
    the closing `-->` of the header comment.
    """
    for model_card in model_card_list:
        model_card = _normalize_model_card_name(model_card)
        if _should_skip_model_card(model_card):
            continue

        file_path = os.path.join(DOCS_PATH, model_card)

        # Rewrite arXiv/hf.co links on disk first so the content read below is current.
        replace_paper_links(file_path)

        content = _read_model_card_content(model_card)
        markers = list(re.finditer(r"-->", content))

        if len(markers) == 0:
            # No header comment at all: prepend the standard copyright block so
            # there is a `-->` marker to anchor the dates line on.
            content = COPYRIGHT_DISCLAIMER + "\n\n" + content
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            markers = list(re.finditer(r"-->", content))

        hf_commit_date = get_first_commit_date(model_name=model_card)
        paper_link = get_paper_link(model_card=model_card, path=file_path)

        if paper_link in ("No_paper", "blog"):
            # No resolvable paper: leave a literal "{release_date}" placeholder.
            release_date = r"{release_date}"
        else:
            release_date = get_release_date(paper_link)

        match = _get_dates_pattern_match(content)

        if match:
            # A dates line already exists — update it only when needed.
            existing_release_date = match.group(1)
            existing_hf_date = match.group(2)

            # Keep a real, manually-set release date; only placeholders get replaced.
            if existing_release_date not in (r"{release_date}", "None"):
                release_date = existing_release_date

            if _dates_differ_significantly(existing_hf_date, hf_commit_date) or existing_release_date != release_date:
                old_line = match.group(0)
                new_line = f"\n*This model was released on {release_date} and added to Hugging Face Transformers on {hf_commit_date}.*"
                content = content.replace(old_line, new_line)
                with open(file_path, "w", encoding="utf-8") as f:
                    f.write(content)
        else:
            # No dates line yet: insert it immediately after the first `-->`.
            insert_index = markers[0].end()
            date_info = f"\n*This model was released on {release_date} and added to Hugging Face Transformers on {hf_commit_date}.*"
            content = content[:insert_index] + date_info + content[insert_index:]
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
|
|
|
|
def get_all_model_cards():
    """Return the sorted list of all model-card names (without .md) in DOCS_PATH."""
    excluded = ("auto", "timm_wrapper")
    card_names = [
        os.path.splitext(file_name)[0]
        for file_name in os.listdir(DOCS_PATH)
        if file_name.endswith(".md") and os.path.splitext(file_name)[0] not in excluded
    ]
    return sorted(card_names)
|
|
|
|
def main(all=False, models=None, check_only=False):
    """Entry point: check or fix the date lines in the model cards.

    (`all` shadows the builtin, but the name is part of the public signature.)
    """
    if check_only:
        # Missing dates are checked across every card; incorrect dates only on
        # the cards touched by the current git diff.
        missing_dates = check_missing_dates(get_all_model_cards())
        incorrect_dates = check_incorrect_dates(get_modified_cards())

        if missing_dates or incorrect_dates:
            problematic_cards = missing_dates + incorrect_dates
            model_names = [card.replace(".md", "") for card in problematic_cards]
            raise ValueError(
                f"Missing or incorrect dates in the following model cards: {' '.join(problematic_cards)}\n"
                f"Run `python utils/add_dates.py --models {' '.join(model_names)}` to fix them."
            )
        return

    if all:
        model_cards = get_all_model_cards()
    elif models:
        model_cards = models
    else:
        model_cards = get_modified_cards()
        if not model_cards:
            return

    insert_dates(model_cards)
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Add release and commit dates to model cards")
    # The three modes are mutually exclusive; with none given, main() falls back
    # to processing only the model cards modified in the current git diff.
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument("--models", nargs="+", help="Specify model cards to process (without .md extension)")
    group.add_argument("--all", action="store_true", help="Process all model cards in the docs directory")
    group.add_argument("--check-only", action="store_true", help="Check if the dates are already present")

    args = parser.parse_args()
    try:
        main(args.all, args.models, args.check_only)
    except subprocess.CalledProcessError as e:
        # Git failures (e.g. shallow clone, detached HEAD in CI) are reported
        # but deliberately do not fail the script.
        print(
            f"An error occurred while executing git commands but it can be ignored (git issue) most probably local: {e}"
        )
|