#!/usr/bin/env python3
"""Generate the "Pre-compiled CUDA wheels (Linux)" pages of the k2 docs.

One ``<torch_version>.rst`` page is written per torch version, plus an
``index.rst`` with a toctree linking to them.
"""

import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

from update_readme import generate_url, get_all_files

# Parses wheel file names such as
#   k2-1.24.4.dev20240301+cuda12.1.torch2.3.0.dev20240229-cp311-...
# Every numeric component uses ``\d+`` so multi-digit versions (e.g. the "24"
# in k2 1.24) are captured in full. A repeated single-digit group like
# ``(\d)+`` would only keep its LAST digit, turning 1.23 into minor == 3.
_WHEEL_PATTERN = re.compile(
    r"k2-(?P<major>\d+)\.(?P<minor>\d+)(?:\.(?P<patch>\d+))?"
    r"\.dev(?P<date>\d{8})"
    r"\+cuda(?P<cuda_major>\d+)\.(?P<cuda_minor>\d+)"
    r"\.torch(?P<torch>\d+\.\d+\.\d+(?:\.dev\d{8})?)"
    r"-cp(?P<py>\d+)"
)


class Wheel:
    """Version metadata parsed from a k2 CUDA wheel file name, plus its URL."""

    def __init__(self, full_name: str, url: str):
        """
        Args:
          full_name:
            Example:
              k2-1.23.4.dev20230224+cuda10.1.torch1.6.0-cp36-cp36m-linux_x86_64.whl
              k2-1.24.4.dev20240301+cuda12.1.torch2.3.0.dev20240229-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
          url:
            Link to the wheel; used as the target in the generated docs.

        Raises:
          ValueError: if ``full_name`` does not look like a k2 CUDA wheel name.
        """
        self.full_name = full_name

        m = _WHEEL_PATTERN.search(full_name)
        if m is None:
            raise ValueError(f"Unrecognized k2 CUDA wheel name: {full_name}")

        self.k2_major = int(m.group("major"))
        self.k2_minor = int(m.group("minor"))
        # The patch component is optional in the name; treat a missing one as 0
        # instead of crashing on int(None).
        self.k2_patch = int(m.group("patch") or 0)
        self.k2_date = int(m.group("date"))
        self.cuda_major_version = int(m.group("cuda_major"))
        self.cuda_minor_version = int(m.group("cuda_minor"))
        # Kept as a string, e.g. "2.3.0" or "2.3.0.dev20240229".
        self.torch_version = m.group("torch")
        # e.g. 311 for CPython 3.11.
        self.py_version = int(m.group("py"))
        self.url = url

    def __str__(self):
        return self.url

    def __repr__(self):
        return self.url


def _base_torch_version(torch_version: str) -> str:
    """Drop any suffix after major.minor.patch, e.g. "2.3.0.dev20240229" -> "2.3.0"."""
    return ".".join(torch_version.split(".")[:3])


def generate_index(filename: str, torch_versions) -> None:
    """Write the index page whose toctree links to one page per torch version.

    Args:
      filename: Path of the index ``.rst`` file; it is overwritten.
      torch_versions: Torch versions such as "2.3.0", in the order they
        should appear in the toctree.
    """
    entries = "\n".join(f"   ./{t}.rst" for t in torch_versions)
    s = f"""\
Pre-compiled CUDA wheels (Linux)
================================

This page describes pre-compiled ``CUDA`` wheels for `k2`_ on Linux.

.. toctree::
   :maxdepth: 2

{entries}
"""
    with open(filename, "w", encoding="utf-8") as f:
        f.write(s)


def sort_by_wheel(x: Wheel):
    """Sort key: k2 version, build date, CUDA version, then Python version."""
    return (
        x.k2_major,
        x.k2_minor,
        x.k2_patch,
        x.k2_date,
        x.cuda_major_version,
        x.cuda_minor_version,
        x.py_version,
    )


def sort_by_torch(x: str):
    """Sort key for a "major.minor.patch" torch version string."""
    major, minor, patch = x.split(".")
    return int(major), int(minor), int(patch)


def get_all_torch_versions(wheels: List[Wheel]) -> List[str]:
    """Return the distinct major.minor.patch torch versions, highest first."""
    versions = {_base_torch_version(w.torch_version) for w in wheels}
    return sorted(versions, key=sort_by_torch, reverse=True)


def get_doc_dir() -> Path:
    """Return the docs directory for the CUDA wheel pages.

    Raises:
      ValueError: if ``K2_DIR`` is unset or the directory does not exist.
    """
    k2_dir = os.getenv("K2_DIR")
    if k2_dir is None:
        raise ValueError("Please set the environment variable K2_DIR")

    cuda = Path(k2_dir) / "docs/source/installation/pre-compiled-cuda-wheels-linux"
    if not cuda.is_dir():
        raise ValueError(f"{cuda} does not exist")

    print(f"k2 doc cuda: {cuda}")
    return cuda


def remove_all_files(d: str):
    """Delete every file that ``get_all_files(d, "*.rst")`` reports."""
    files = get_all_files(d, "*.rst")
    for f in files:
        print(f"removing {f}")
        os.remove(f)


def get_all_cuda_wheels():
    """Collect ``*.whl`` files under ``cuda`` and ``ubuntu-cuda`` and convert them
    with ``generate_url`` (presumably one download URL per file — see
    update_readme)."""
    cuda = get_all_files("cuda", suffix="*.whl")
    cuda += get_all_files("ubuntu-cuda", suffix="*.whl")
    return generate_url(cuda)


def generate_file(d: str, torch_version: str, wheels: List[Wheel]) -> None:
    """Write ``{d}/{torch_version}.rst`` listing the wheels built for that torch.

    Wheels are matched on the major.minor.patch part of their torch version
    and listed newest first according to ``sort_by_wheel``.
    """
    title = f"torch {torch_version}"
    selected = sorted(
        (w for w in wheels if _base_torch_version(w.torch_version) == torch_version),
        key=sort_by_wheel,
        reverse=True,
    )

    s = title + "\n" + "=" * len(title) + "\n" * 3
    s += "".join(f"- `{w.full_name} <{w.url}>`_\n" for w in selected)

    with open(f"{d}/{torch_version}.rst", "w", encoding="utf-8") as f:
        f.write(s)


def main():
    d = get_doc_dir()
    remove_all_files(d)

    urls = get_all_cuda_wheels()
    # The wheel file name is the last path component of its URL.
    wheels = [Wheel(url.rsplit("/", maxsplit=1)[-1], url) for url in urls]

    torch_versions = get_all_torch_versions(wheels)
    for t in torch_versions:
        generate_file(d, t, wheels)

    generate_index(f"{d}/index.rst", torch_versions)


if __name__ == "__main__":
    main()