from __future__ import annotations import argparse import datetime import json import os from operator import itemgetter from pathlib import Path from typing import Callable import requests import pynvml import PIL.Image import torch from pydantic import BaseSettings, BaseModel class Args(BaseSettings): @classmethod def from_args(cls): parser = argparse.ArgumentParser() for field in cls.__fields__.values(): if issubclass(field.type_, BaseModel): prefix = field.type_.__name__.lower() for subfield in field.type_.__fields__.values(): short = "".join([x[0] for x in subfield.name.split("_")]) parser.add_argument(f"--{prefix}.{subfield.name}", default=subfield.default, required=subfield.required) else: short = "".join([x[0] for x in field.name.split("_")]) parser.add_argument(f"-{short}", f"--{field.name}", default=field.default, required=field.required) args = vars(parser.parse_known_args()[0]) to_delete = set() for field in cls.__fields__.values(): if issubclass(field.type_, BaseModel): prefix = field.type_.__name__.lower() sub_args = {} for k, v in args.items(): if k.startswith(prefix): to_delete.add(k) sub_args[k.replace(f"{prefix}.", "")] = v args[field.name] = sub_args args = {k: v for k, v in args.items() if k not in to_delete} return cls(**args) class Config: env_file = ".env" env_file_encoding = "utf-8" env_prefix = "ARG_" def inject_args(func: Callable) -> Callable: """Decorates a function to inject the arguments.""" injected = None for type_ in func.__annotations__.values(): if issubclass(type_, Args): injected = type_.from_args() break if injected is None: raise ValueError(f"Function {func.__name__} is not annotated with an Args subclass.") def wrapper(*args, **kwargs): return func(injected, *args, **kwargs) return wrapper def get_free_gpu() -> int: pynvml.nvmlInit() total = torch.cuda.device_count() gpus = [] for i in range(total): handle = pynvml.nvmlDeviceGetHandleByIndex(i) info = pynvml.nvmlDeviceGetMemoryInfo(handle) gpus.append((i, info.free)) gpus = sorted(gpus, key=itemgetter(1), reverse=True) return gpus[0][0] def get_user_name() -> str: return Path(os.environ["HOME"]).stem def get_storage_dir() -> Path: return Path(f"/fsx/{get_user_name()}") def get_checkpoints_dir(*, timestamp: bool) -> Path: base_dir = get_storage_dir()/"checkpoints" return Path(f"{base_dir}/{now()}") if timestamp else base_dir def now() -> str: return datetime.datetime.utcnow().strftime("%Y%m%d_%H%M%S") def read_jsonl(path: Path) -> list: return [json.loads(x) for x in Path(path).read_text().split("\n") if x] def write_jsonl(path: Path, data: list): with Path(path).open("w") as f: for x in data: f.write(json.dumps(x) + "\n") def get_image(url: str, filename: Path | None = None): if filename is None: filename = Path(f"{url.split('/')[-1]}.jpg") filename = Path(filename) if filename.exists(): return filename PIL.Image.open(requests.get(url, stream=True).raw).save(filename) return filename def find_latest_checkpoint(dirname: Path) -> Path: dirname = Path(dirname) checkpoints = list(dirname.glob("*.ckpt")) if not checkpoints: return None latest = max(checkpoints, key=lambda path: path.stat().st_mtime) return latest def list_files(dirname: Path, exts: list[str] | None = None) -> list: files = Path(dirname).iterdir() if not exts: return list(files) return [fn for fn in files for ext in exts if fn.match(f"*.{ext}")]