|
from __future__ import annotations

import argparse
import datetime
import functools
import json
import os
from operator import itemgetter
from pathlib import Path
from typing import Callable

import PIL.Image
import pynvml
import requests
import torch
from pydantic import BaseSettings, BaseModel
|
|
|
|
|
class Args(BaseSettings):
    """Base class for pydantic-backed CLI/environment settings.

    Subclasses declare fields; ``from_args()`` builds an argparse parser from
    those fields (including one level of nested ``BaseModel`` fields, exposed
    as ``--<modelname>.<field>``) and instantiates the settings from the
    parsed command line. Values may also come from the environment / ``.env``
    via the pydantic ``Config`` below.
    """

    @classmethod
    def from_args(cls):
        """Parse known command-line arguments into a ``cls`` instance."""
        parser = argparse.ArgumentParser()
        for field in cls.__fields__.values():
            if issubclass(field.type_, BaseModel):
                # Nested model: expose each subfield as --<prefix>.<name>.
                # (No short flag here — a per-subfield short option would
                # collide too easily across models.)
                prefix = field.type_.__name__.lower()
                for subfield in field.type_.__fields__.values():
                    parser.add_argument(f"--{prefix}.{subfield.name}", default=subfield.default, required=subfield.required)
            else:
                # Flat field: add a short flag built from the name's initials
                # (e.g. batch_size -> -bs) plus the long --name flag.
                short = "".join([x[0] for x in field.name.split("_")])
                parser.add_argument(f"-{short}", f"--{field.name}", default=field.default, required=field.required)

        # parse_known_args: ignore unrecognized flags so this composes with
        # other parsers in the same process.
        args = vars(parser.parse_known_args()[0])

        # Fold "<prefix>.<name>" entries back into one dict per nested model,
        # then drop the dotted originals.
        to_delete = set()
        for field in cls.__fields__.values():
            if issubclass(field.type_, BaseModel):
                prefix = field.type_.__name__.lower()
                sub_args = {}
                for k, v in args.items():
                    # NOTE(review): startswith() would also swallow a flat
                    # field whose name begins with this prefix — confirm no
                    # subclass declares such a field.
                    if k.startswith(prefix):
                        to_delete.add(k)
                        sub_args[k.replace(f"{prefix}.", "")] = v
                args[field.name] = sub_args
        args = {k: v for k, v in args.items() if k not in to_delete}
        return cls(**args)

    class Config:
        # pydantic BaseSettings configuration: also read ARG_-prefixed
        # variables, from the environment and a local .env file.
        env_file = ".env"
        env_file_encoding = "utf-8"
        env_prefix = "ARG_"
|
|
|
|
|
def inject_args(func: Callable) -> Callable:
    """Decorate *func* so its ``Args``-annotated parameter is parsed and injected.

    Scans the function's annotations for the first subclass of ``Args``,
    builds it once at decoration time via ``from_args()``, and passes the
    instance as the first positional argument on every call.

    Raises:
        ValueError: if no parameter is annotated with an ``Args`` subclass.
    """
    injected = None
    for type_ in func.__annotations__.values():
        # Annotations are not always classes (e.g. ``None`` return
        # annotations, or strings under ``from __future__ import
        # annotations``); a bare issubclass() would raise TypeError.
        if isinstance(type_, type) and issubclass(type_, Args):
            injected = type_.from_args()
            break

    if injected is None:
        raise ValueError(f"Function {func.__name__} is not annotated with an Args subclass.")

    @functools.wraps(func)  # preserve the wrapped function's name/docstring
    def wrapper(*args, **kwargs):
        return func(injected, *args, **kwargs)

    return wrapper
|
|
|
|
|
def get_free_gpu() -> int:
    """Return the index of the CUDA device with the most free memory.

    Raises:
        RuntimeError: if no CUDA devices are visible.
    """
    pynvml.nvmlInit()
    try:
        total = torch.cuda.device_count()
        if total == 0:
            raise RuntimeError("No CUDA devices available.")
        # Pair each device index with its free memory as reported by NVML.
        free_memory = []
        for index in range(total):
            handle = pynvml.nvmlDeviceGetHandleByIndex(index)
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            free_memory.append((index, info.free))
        # Device with the largest amount of free memory wins.
        return max(free_memory, key=itemgetter(1))[0]
    finally:
        # Always release the NVML context, even on error.
        pynvml.nvmlShutdown()
|
|
|
|
|
def get_user_name() -> str:
    """Return the current user's name, derived from ``$HOME``.

    Raises:
        KeyError: if the HOME environment variable is not set.
    """
    # .name keeps the full last path component; .stem would wrongly strip
    # everything after a dot in usernames like "jane.doe".
    return Path(os.environ["HOME"]).name
|
|
|
|
|
def get_storage_dir() -> Path:
    """Return the user's storage root under /fsx."""
    return Path("/fsx") / get_user_name()
|
|
|
|
|
def get_checkpoints_dir(*, timestamp: bool) -> Path:
    """Return the user's checkpoints directory.

    When *timestamp* is true, a UTC-timestamped subdirectory path is
    returned instead of the bare checkpoints root.
    """
    checkpoints_root = get_storage_dir() / "checkpoints"
    if timestamp:
        return checkpoints_root / now()
    return checkpoints_root
|
|
|
|
|
def now() -> str:
    """Return the current UTC time formatted as ``YYYYMMDD_HHMMSS``."""
    # datetime.utcnow() is deprecated (and naive); use an explicitly
    # UTC-aware now() — strftime output is identical.
    return datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d_%H%M%S")
|
|
|
|
|
def read_jsonl(path: Path) -> list:
    """Read a JSON-lines file into a list, skipping blank lines."""
    text = Path(path).read_text(encoding="utf-8")
    # splitlines() handles \n, \r\n and a trailing newline uniformly;
    # strip() also skips whitespace-only lines that would crash json.loads.
    return [json.loads(line) for line in text.splitlines() if line.strip()]
|
|
|
|
|
def write_jsonl(path: Path, data: list):
    """Write *data* to *path* as JSON lines (one object per line).

    Overwrites any existing file.
    """
    # Explicit encoding: the platform default may not be UTF-8.
    with Path(path).open("w", encoding="utf-8") as f:
        for record in data:
            f.write(json.dumps(record) + "\n")
|
|
|
|
|
def get_image(url: str, filename: Path | None = None):
    """Download an image from *url* to *filename*, returning the local path.

    Defaults the filename to the last URL segment with ".jpg" appended.
    Acts as a cache: if the file already exists it is returned unchanged.
    """
    if filename is None:
        filename = Path(f"{url.split('/')[-1]}.jpg")
    filename = Path(filename)
    if filename.exists():
        # Cache hit — never re-download.
        return filename
    response = requests.get(url, stream=True)
    # Fail loudly on HTTP errors instead of saving an error page as an image.
    response.raise_for_status()
    # .raw does not decode gzip/deflate transfer encoding unless asked to.
    response.raw.decode_content = True
    PIL.Image.open(response.raw).save(filename)
    return filename
|
|
|
|
|
def find_latest_checkpoint(dirname: Path) -> Path: |
|
dirname = Path(dirname) |
|
checkpoints = list(dirname.glob("*.ckpt")) |
|
if not checkpoints: |
|
return None |
|
latest = max(checkpoints, key=lambda path: path.stat().st_mtime) |
|
return latest |
|
|
|
|
|
def list_files(dirname: Path, exts: list[str] | None = None) -> list: |
|
files = Path(dirname).iterdir() |
|
if not exts: |
|
return list(files) |
|
return [fn for fn in files for ext in exts if fn.match(f"*.{ext}")] |
|
|